From bc4e010753dbbccf263ea286e83103b8e34c5879 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Fri, 20 Dec 2024 20:37:09 -0500 Subject: [PATCH 1/3] Fixed the FOBS lazy loading issue --- nvflare/fuel/utils/fobs/fobs.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/nvflare/fuel/utils/fobs/fobs.py b/nvflare/fuel/utils/fobs/fobs.py index a1450b5b95..a907ba0f5c 100644 --- a/nvflare/fuel/utils/fobs/fobs.py +++ b/nvflare/fuel/utils/fobs/fobs.py @@ -66,15 +66,10 @@ def _get_type_name(cls: Type) -> str: def _load_class(type_name: str): try: - parts = type_name.split(".") - if len(parts) == 1: - parts = ["builtins", type_name] + module_name, class_name = type_name.rsplit('.', 1) + module = importlib.import_module(module_name) - mod = __import__(parts[0]) - for comp in parts[1:]: - mod = getattr(mod, comp) - - return mod + return getattr(module, class_name) except Exception as ex: raise TypeError(f"Can't load class {type_name}: {ex}") @@ -243,7 +238,7 @@ def register_folder(folder: str, package: str): # classes who are abstract or take extra args in __init__ can't be auto-registered if issubclass(cls_obj, Decomposer) and not inspect.isabstract(cls_obj) and len(spec.args) == 1: register(cls_obj) - except (ModuleNotFoundError, RuntimeError) as e: + except (ModuleNotFoundError, RuntimeError, ValueError) as e: log.debug( f"Try to import module {decomposers}, but failed: {secure_format_exception(e)}. " f"Can't use name in config to refer to classes in module: {decomposers}." @@ -275,7 +270,7 @@ def register_custom_folder(folder: str): log.warning( f"Invalid Decomposer from {module}: can't have argument in Decomposer's constructor" ) - except (ModuleNotFoundError, RuntimeError): + except (ModuleNotFoundError, RuntimeError, ValueError): pass From 69a84a2212fd0b69bec7b1a0e2b00bfd9bf99b82 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Fri, 20 Dec 2024 21:52:56 -0500 Subject: [PATCH 2/3] Fixed format --- nvflare/fuel/utils/fobs/fobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nvflare/fuel/utils/fobs/fobs.py b/nvflare/fuel/utils/fobs/fobs.py index a907ba0f5c..a3a49d8a7b 100644 --- a/nvflare/fuel/utils/fobs/fobs.py +++ b/nvflare/fuel/utils/fobs/fobs.py @@ -66,7 +66,7 @@ def _get_type_name(cls: Type) -> str: def _load_class(type_name: str): try: - module_name, class_name = type_name.rsplit('.', 1) + module_name, class_name = type_name.rsplit(".", 1) module = importlib.import_module(module_name) return getattr(module, class_name) From f6c4f84d539b35dba581ac35f1b95e82e724ea6f Mon Sep 17 00:00:00 2001 From: Zhihong Zhang Date: Fri, 3 Jan 2025 19:59:22 -0500 Subject: [PATCH 3/3] Added check for builtin classes --- nvflare/fuel/utils/fobs/fobs.py | 11 +++++++---- tests/unit_test/fuel/utils/fobs/fobs_test.py | 9 +++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/nvflare/fuel/utils/fobs/fobs.py b/nvflare/fuel/utils/fobs/fobs.py index a3a49d8a7b..7a69ea17e7 100644 --- a/nvflare/fuel/utils/fobs/fobs.py +++ b/nvflare/fuel/utils/fobs/fobs.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import builtins import importlib import inspect import logging @@ -66,10 +67,12 @@ def _get_type_name(cls: Type) -> str: def _load_class(type_name: str): try: - module_name, class_name = type_name.rsplit(".", 1) - module = importlib.import_module(module_name) - - return getattr(module, class_name) + if "." in type_name: + module_name, class_name = type_name.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) + else: + return getattr(builtins, type_name) except Exception as ex: raise TypeError(f"Can't load class {type_name}: {ex}") diff --git a/tests/unit_test/fuel/utils/fobs/fobs_test.py b/tests/unit_test/fuel/utils/fobs/fobs_test.py index 360fe415b4..4e800c491c 100644 --- a/tests/unit_test/fuel/utils/fobs/fobs_test.py +++ b/tests/unit_test/fuel/utils/fobs/fobs_test.py @@ -28,6 +28,7 @@ class TestFobs: NUMBER = 123456 FLOAT = 123.456 NAME = "FOBS Test" + SET = {4, 5, 6} NOW = datetime.now() test_data = { @@ -35,7 +36,7 @@ class TestFobs: "number": NUMBER, "float": FLOAT, "list": [7, 8, 9], - "set": {4, 5, 6}, + "set": SET, "tuple": ("abc", "xyz"), "time": NOW, } @@ -44,11 +45,7 @@ def test_builtin(self): buf = fobs.dumps(TestFobs.test_data) data = fobs.loads(buf) assert data["number"] == TestFobs.NUMBER - - def test_aliases(self): - buf = fobs.dumps(TestFobs.test_data) - data = fobs.loads(buf) - assert data["number"] == TestFobs.NUMBER + assert data["set"] == TestFobs.SET def test_unsupported_classes(self): with pytest.raises(TypeError):