Skip to content

Commit

Permalink
Fix possible race condition.
Browse files Browse the repository at this point in the history
Harmonize naming.
  • Loading branch information
eskildsf committed Dec 19, 2024
1 parent 6b64d94 commit a9fd696
Showing 1 changed file with 33 additions and 30 deletions.
63 changes: 33 additions & 30 deletions reloading/reloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,31 +139,31 @@ def format_iteration_variables(ast_node: Union[ast.Name,
return ", ".join(names)


def load_file(filepath: str) -> str:
def load_file(filename: str) -> str:
"""
Read contents of file containing reloading code.
Handle case of file appearing empty on read.
"""
while True:
with open(filepath, "r") as f:
with open(filename, "r") as f:
src = f.read()
if len(src):
return src + "\n"


def parse_file_until_successful(filepath: str) -> ast.Module:
def parse_file_until_successful(filename: str) -> ast.Module:
"""
Parse source code of file containing reloading code.
File may appear incomplete as as it is read so retry until successful.
"""
source = load_file(filepath)
source = load_file(filename)
while True:
try:
tree = ast.parse(source)
return tree
except SyntaxError:
handle_exception(filepath)
source = load_file(filepath)
handle_exception(filename)
source = load_file(filename)


break_ast = ast.parse('raise Exception("break")').body
Expand Down Expand Up @@ -283,28 +283,28 @@ def get_loop_id(ast_node: Union[ast.For, ast.While]) -> str:

def get_loop_code(loop_frame_info: inspect.FrameInfo,
loop_id: Union[None, str]) -> Union[WhileLoop, ForLoop]:
filepath: str = loop_frame_info.filename
filename: str = loop_frame_info.filename
while True:
reloaded_file_ast: ast.Module = parse_file_until_successful(filepath)
reloaded_file_ast: ast.Module = parse_file_until_successful(filename)
try:
return get_loop_object(
loop_frame_info, reloaded_file_ast, loop_id=loop_id
)
except (LookupError, ReloadingException):
handle_exception(filepath)
handle_exception(filename)


def handle_exception(filepath: str):
def handle_exception(filename: str):
"""
Output helpful error message to user regarding exception in reloaded code.
"""
exception = traceback.format_exc()
exception = exception.replace('File "<string>"', f'File "{filepath}"')
exception = exception.replace('File "<string>"', f'File "{filename}"')
sys.stderr.write(exception + "\n")

if sys.stdin.isatty():
print(
f"An error occurred. Please edit the file '{filepath}' to fix "
f"An error occurred. Please edit the file '{filename}' to fix "
"the issue and press return to continue or Ctrl+C to exit."
)
try:
Expand All @@ -318,12 +318,12 @@ def handle_exception(filepath: str):
print(line_number)
raise Exception(
'An error occurred. Please fix the issue in the file'
f'"{filepath}" and run the script again.'
f'"{filename}" and run the script again.'
)


def execute_for_loop(seq: Iterable, loop_frame_info: inspect.FrameInfo):
filepath = loop_frame_info.filename
filename = loop_frame_info.filename
caller_globals: Dict[str, Any] = loop_frame_info.frame.f_globals
caller_locals: Dict[str, Any] = loop_frame_info.frame.f_locals

Expand All @@ -337,15 +337,16 @@ def execute_for_loop(seq: Iterable, loop_frame_info: inspect.FrameInfo):

for i, iteration_variable_values in enumerate(seq):
# Reload code if possibly modified
if file_stat != os.stat(filepath).st_mtime_ns:
file_stat_: int = os.stat(filename).st_mtime_ns
if file_stat != file_stat_:
if i > 0:
log.info(f'For loop at line {loop_frame_info.lineno} of file '
f'"{filepath}" has been reloaded.')
f'"{filename}" has been reloaded.')
for_loop = get_loop_code(
loop_frame_info, loop_id=for_loop.id
)
assert isinstance(for_loop, ForLoop)
file_stat = os.stat(filepath).st_mtime_ns
file_stat = file_stat_
# Make up a name for a variable which is not already present in
# the global or local namespace.
vacant_variable_name = unique_name(
Expand Down Expand Up @@ -374,15 +375,15 @@ def execute_for_loop(seq: Iterable, loop_frame_info: inspect.FrameInfo):
if exception.args == ("continue",):
continue
else:
handle_exception(filepath)
handle_exception(filename)


def execute_while_loop(loop_frame_info: inspect.FrameInfo):
filepath = loop_frame_info.filename
filename = loop_frame_info.filename
caller_globals: Dict[str, Any] = loop_frame_info.frame.f_globals
caller_locals: Dict[str, Any] = loop_frame_info.frame.f_locals

file_stat: int = os.stat(filepath).st_mtime_ns
file_stat: int = os.stat(filename).st_mtime_ns
while_loop = get_loop_code(
loop_frame_info, loop_id=None
)
Expand All @@ -394,13 +395,14 @@ def condition(while_loop):
while condition(while_loop):
i += 1
# Reload code if possibly modified
if file_stat != os.stat(filepath).st_mtime_ns:
file_stat_: int = os.stat(filename).st_mtime_ns
if file_stat != file_stat_:
log.info(f'While loop at line {loop_frame_info.lineno} of file '
f'"{filepath}" has been reloaded.')
f'"{filename}" has been reloaded.')
while_loop = get_loop_code(
loop_frame_info, loop_id=while_loop.id
)
file_stat = os.stat(filepath).st_mtime_ns
file_stat = file_stat_
try:
exec(while_loop.compiled_body, caller_globals, caller_locals)
except Exception as exception:
Expand All @@ -412,7 +414,7 @@ def condition(while_loop):
if exception.args == ("continue",):
continue
else:
handle_exception(filepath)
handle_exception(filename)


def _reloading_loop(seq: Union[Iterable, bool]) -> Iterable:
Expand Down Expand Up @@ -556,12 +558,12 @@ def _reloading_function(function: Callable) -> Callable:
assert stack[1].function == "reloading"
# The third element is the loop which called reloading.
function_frame_info: inspect.FrameInfo = stack[2]
filepath: str = function_frame_info.filename
filename: str = function_frame_info.filename

caller_globals = function_frame_info.frame.f_globals
caller_locals = function_frame_info.frame.f_locals

file_stat: int = os.stat(filepath).st_mtime_ns
file_stat: int = os.stat(filename).st_mtime_ns
rfunction = get_reloaded_function(caller_globals,
caller_locals,
function_frame_info,
Expand All @@ -571,22 +573,23 @@ def _reloading_function(function: Callable) -> Callable:
def wrapped(*args, **kwargs):
nonlocal file_stat, function, rfunction, i
# Reload code if possibly modified
if file_stat != os.stat(filepath).st_mtime_ns:
file_stat_: int = os.stat(filename).st_mtime_ns
if file_stat != file_stat_:
log.info(f'Function "{function.__name__}" at line '
f'{function_frame_info.lineno} '
f'of file "{filepath}" has been reloaded.')
f'of file "{filename}" has been reloaded.')
rfunction = get_reloaded_function(caller_globals,
caller_locals,
function_frame_info,
function)
file_stat = os.stat(filepath).st_mtime_ns
file_stat = file_stat_
i += 1
while True:
try:
result = rfunction(*args, **kwargs)
return result
except Exception:
handle_exception(filepath)
handle_exception(filename)

wrapped.__signature__ = inspect.signature(function) # type: ignore
caller_locals[function.__name__] = wrapped
Expand Down

0 comments on commit a9fd696

Please sign in to comment.