Skip to content

Commit

Permalink
Support wrapping functions.
Browse files Browse the repository at this point in the history
Add tests.
  • Loading branch information
eskildsf committed Dec 22, 2024
1 parent 5b2b4e3 commit e477a9e
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 13 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,17 @@ def function():
pass
```

It is also possible to mark a function for reload after defining it:
```python
from reloading import reloading

def function():
# This function will be reloaded before each function call
pass

function = reloading(function)
```

## Additional Options

To iterate forever in a `for` loop you can omit the argument:
Expand Down
39 changes: 36 additions & 3 deletions reloading/reloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,10 @@ def strip_reloading_decorator(function_with_decorator: ast.FunctionDef):
decorator_names = [get_decorator_name_or_none(decorator)
for decorator
in fwod.decorator_list]
# Find index of "reloading" decorator
reloading_index = decorator_names.index("reloading")
fwod.decorator_list = fwod.decorator_list[reloading_index + 1:]
if "reloading" in decorator_names:
# Find index of "reloading" decorator
reloading_index = decorator_names.index("reloading")
fwod.decorator_list = fwod.decorator_list[reloading_index + 1:]
function_without_decorator = fwod
return function_without_decorator

Expand All @@ -511,6 +512,7 @@ def isolate_function_def(function_frame_info: inspect.FrameInfo,
class_name = qualname.split(".")[length - 2] if length > 1 else None

candidates = []
weak_candidates = []
for node in ast.walk(reloaded_file_ast):
if isinstance(node, ast.ClassDef) and node.name == class_name:
for subnode in node.body:
Expand All @@ -527,6 +529,24 @@ def isolate_function_def(function_frame_info: inspect.FrameInfo,
for decorator in node.decorator_list
]:
candidates.append(node)
else:
# The function was not decorated... Hmm
# This could be because the function was wrapped. Example:
# def f(x):
# return x
# f = reloading(f)
# If we only find ONE function which matches by name,
# then we return it. If we're confused though, due to
# multiple definitions of function with the same
# name then we raise and exception. Example:
# def f(x):
# return x
# def g(x):
# def f(x):
# return x
# return f(x)
# f = reloading(f)
weak_candidates.append(node)
# Select the candidate node which is closest to function_frame_info
if len(candidates):
def sorting_function(candidate):
Expand All @@ -535,6 +555,19 @@ def sorting_function(candidate):
function_node = strip_reloading_decorator(candidate)
function_node_ast = ast.Module([function_node], type_ignores=[])
return function_node_ast
elif len(weak_candidates) == 1:
candidate = weak_candidates[0]
function_node = strip_reloading_decorator(candidate)
function_node_ast = ast.Module([function_node], type_ignores=[])
return function_node_ast
elif len(weak_candidates) > 1:
raise ReloadingException(
f'The file "{function_frame_info.filename}" contains '
f'{len(weak_candidates)} definitions of functions with the name '
f'"{function_name}" so it is not possible to figure out which '
'one to reload. This can be resolved by decorating the function '
'instead of wrapping it.'
)
else:
raise ReloadingException(
f'Unable to reload function "{function_name}" '
Expand Down
79 changes: 69 additions & 10 deletions reloading/test_reloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,12 @@ def test_empty_function_definition(self):
def function():
pass

def test_empty_function_wrapped(self):
def function():
pass

function = reloading(function)

def test_empty_function_run(self):
@reloading
def function():
Expand All @@ -210,6 +216,13 @@ def function():

self.assertEqual(function(), "result")

def test_function_return_value_wrapped(self):
def function():
return "result"

function = reloading(function)
self.assertEqual(function(), "result")

def test_nested_function(self):
def outer():
@reloading
Expand Down Expand Up @@ -397,7 +410,7 @@ def f():
print(f()+'g')
sleep(0.2)
"""
stdout, stderr = run_and_update_source(
stdout, _ = run_and_update_source(
init_src=code,
updated_src=code.replace("'f'", "'F'").
replace("'g'", "'G'"),
Expand Down Expand Up @@ -441,7 +454,7 @@ def f():
sleep(0.2)
i += 1
"""
stdout, stderr = run_and_update_source(
stdout, _ = run_and_update_source(
init_src=code,
updated_src=code.replace("'f'", "'F'").
replace("'g'", "'G'"),
Expand Down Expand Up @@ -546,22 +559,25 @@ def g():
self.assertIn("fg", stdout)
self.assertIn("FG", stdout)


class TestReloadingMixedWithChanges(unittest.TestCase):
def test_function_for_loop(self):
def test_multiple_functions_not_decorated(self):
code = """
from reloading import reloading
from time import sleep
@reloading
def f():
return 'f'
for i in reloading(range(10)):
print(f()+'g')
def g():
return 'g'
f = reloading(f)
g = reloading(g)
for i in range(10):
print(f()+g())
sleep(0.2)
"""
stdout, stderr = run_and_update_source(
stdout, _ = run_and_update_source(
init_src=code,
updated_src=code.replace("'f'", "'F'").
replace("'g'", "'G'"),
Expand All @@ -570,6 +586,49 @@ def f():
self.assertIn("fg", stdout)
self.assertIn("FG", stdout)

def test_class_decorates_methods(self):
code = """
from reloading import reloading
from time import sleep
def get_subclass_methods(cls):
methods = set(dir(cls(_get_subclass_methods=True)))
unique_methods = methods.difference(
*(dir(base()) for base in cls.__bases__)
)
return list(unique_methods)
class ClassWhichMarksSubclassMethodsForReload:
def __init__(self, *args, **kwargs):
if (self.__class__.__name__ != super().__thisclass__.__name__
and not '_get_subclass_methods' in kwargs):
methods_of_subclass = get_subclass_methods(self.__class__)
for method in methods_of_subclass:
setattr(self.__class__, method,
reloading(getattr(self.__class__, method)))
def f(self):
return 'f'
class Subclass(ClassWhichMarksSubclassMethodsForReload):
def g(self):
return 'g'
obj = Subclass()
for i in range(10):
print(obj.f()+obj.g())
sleep(0.2)
"""
stdout, _ = run_and_update_source(
init_src=code,
updated_src=code.replace("'f'", "'F'").
replace("'g'", "'G'"),
)

self.assertIn("fg", stdout)
self.assertIn("fG", stdout)
self.assertNotIn("FG", stdout)

def test_function_while_loop(self):
code = """
from reloading import reloading
Expand All @@ -585,7 +644,7 @@ def f():
sleep(0.2)
i += 1
"""
stdout, stderr = run_and_update_source(
stdout, _ = run_and_update_source(
init_src=code,
updated_src=code.replace("'f'", "'F'").
replace("'g'", "'G'"),
Expand Down

0 comments on commit e477a9e

Please sign in to comment.