Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(framework) Improve import management in load_app function in object_ref.py #4452

Merged
merged 6 commits into from
Nov 7, 2024
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 40 additions & 33 deletions src/py/flwr/common/object_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def validate(
specified attribute within it.
project_dir : Optional[Union[str, Path]] (default: None)
The directory containing the module. If None, the current working directory
is used. If `check_module` is True, the `project_dir` will be inserted into
the system path, and the previously inserted `project_dir` will be removed.
is used. If `check_module` is True, the `project_dir` will be temporarily
inserted into the system path and then removed after the validation is complete.

Returns
-------
Expand All @@ -66,8 +66,8 @@ def validate(

Note
----
This function will modify `sys.path` by inserting the provided `project_dir`
and removing the previously inserted `project_dir`.
This function will temporarily modify `sys.path` by inserting the provided
`project_dir`, which will be removed after the validation is complete.
"""
module_str, _, attributes_str = module_attribute_str.partition(":")
if not module_str:
Expand All @@ -82,11 +82,19 @@ def validate(
)

if check_module:
if project_dir is None:
project_dir = Path.cwd()
project_dir = Path(project_dir).absolute()
# Set the system path
_set_sys_path(project_dir)
sys.path.insert(0, str(project_dir))

# Load module
module = find_spec(module_str)

# Unset the system path
sys.path.remove(str(project_dir))

# Check if the module and the attribute exist
if module and module.origin:
if not _find_attribute_in_module(module.origin, attributes_str):
return (
Expand Down Expand Up @@ -133,8 +141,10 @@ def load_app( # pylint: disable= too-many-branches

Note
----
This function will modify `sys.path` by inserting the provided `project_dir`
and removing the previously inserted `project_dir`.
- This function will unload all modules in the previously provided `project_dir`,
if it is invoked again.
- This function will modify `sys.path` by inserting the provided `project_dir`
and removing the previously inserted `project_dir`.
"""
valid, error_msg = validate(module_attribute_str, check_module=False)
if not valid and error_msg:
Expand All @@ -143,33 +153,21 @@ def load_app( # pylint: disable= too-many-branches
module_str, _, attributes_str = module_attribute_str.partition(":")

try:
if _current_sys_path:
# Hack: `tabnet` does not work with reloading
if "tabnet" in sys.modules:
log(
WARN,
"Cannot reload module `%s` from disk due to compatibility issues "
"with the `tabnet` library. The module will be loaded from the "
"cache instead. If you experience issues, consider restarting "
"the application.",
module_str,
)
else:
_unload_modules(Path(_current_sys_path))
_set_sys_path(project_dir)

if module_str not in sys.modules:
module = importlib.import_module(module_str)
# Hack: `tabnet` does not work with `importlib.reload`
elif "tabnet" in sys.modules:
log(
WARN,
"Cannot reload module `%s` from disk due to compatibility issues "
"with the `tabnet` library. The module will be loaded from the "
"cache instead. If you experience issues, consider restarting "
"the application.",
module_str,
)
module = sys.modules[module_str]
else:
module = sys.modules[module_str]

if project_dir is None:
project_dir = Path.cwd()

# Reload cached modules in the project directory
for m in list(sys.modules.values()):
path: Optional[str] = getattr(m, "__file__", None)
if path is not None and path.startswith(str(project_dir)):
importlib.reload(m)

module = importlib.import_module(module_str)
except ModuleNotFoundError as err:
raise error_type(
f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
Expand All @@ -189,6 +187,15 @@ def load_app( # pylint: disable= too-many-branches
return attribute


def _unload_modules(project_dir: Path) -> None:
"""Unload modules from the project directory."""
dir_str = str(project_dir.absolute())
for name, m in list(sys.modules.items()):
path: Optional[str] = getattr(m, "__file__", None)
if path is not None and path.startswith(dir_str):
del sys.modules[name]


def _set_sys_path(directory: Optional[Union[str, Path]]) -> None:
"""Set the system path."""
if directory is None:
Expand Down
Loading