diff --git a/nvflare/fuel/utils/fobs/fobs.py b/nvflare/fuel/utils/fobs/fobs.py index a1450b5b95..7a69ea17e7 100644 --- a/nvflare/fuel/utils/fobs/fobs.py +++ b/nvflare/fuel/utils/fobs/fobs.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import builtins import importlib import inspect import logging @@ -66,15 +67,12 @@ def _get_type_name(cls: Type) -> str: def _load_class(type_name: str): try: - parts = type_name.split(".") - if len(parts) == 1: - parts = ["builtins", type_name] - - mod = __import__(parts[0]) - for comp in parts[1:]: - mod = getattr(mod, comp) - - return mod + if "." in type_name: + module_name, class_name = type_name.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) + else: + return getattr(builtins, type_name) except Exception as ex: raise TypeError(f"Can't load class {type_name}: {ex}") @@ -243,7 +241,7 @@ def register_folder(folder: str, package: str): # classes who are abstract or take extra args in __init__ can't be auto-registered if issubclass(cls_obj, Decomposer) and not inspect.isabstract(cls_obj) and len(spec.args) == 1: register(cls_obj) - except (ModuleNotFoundError, RuntimeError) as e: + except (ModuleNotFoundError, RuntimeError, ValueError) as e: log.debug( f"Try to import module {decomposers}, but failed: {secure_format_exception(e)}. " f"Can't use name in config to refer to classes in module: {decomposers}." @@ -275,7 +273,7 @@ def register_custom_folder(folder: str): log.warning( f"Invalid Decomposer from {module}: can't have argument in Decomposer's constructor" ) - except (ModuleNotFoundError, RuntimeError): + except (ModuleNotFoundError, RuntimeError, ValueError): pass diff --git a/tests/unit_test/fuel/utils/fobs/fobs_test.py b/tests/unit_test/fuel/utils/fobs/fobs_test.py index 360fe415b4..4e800c491c 100644 --- a/tests/unit_test/fuel/utils/fobs/fobs_test.py +++ b/tests/unit_test/fuel/utils/fobs/fobs_test.py @@ -28,6 +28,7 @@ class TestFobs: NUMBER = 123456 FLOAT = 123.456 NAME = "FOBS Test" + SET = {4, 5, 6} NOW = datetime.now() test_data = { @@ -35,7 +36,7 @@ class TestFobs: "number": NUMBER, "float": FLOAT, "list": [7, 8, 9], - "set": {4, 5, 6}, + "set": SET, "tuple": ("abc", "xyz"), "time": NOW, } @@ -44,11 +45,7 @@ def test_builtin(self): buf = fobs.dumps(TestFobs.test_data) data = fobs.loads(buf) assert data["number"] == TestFobs.NUMBER - - def test_aliases(self): - buf = fobs.dumps(TestFobs.test_data) - data = fobs.loads(buf) - assert data["number"] == TestFobs.NUMBER + assert data["set"] == TestFobs.SET def test_unsupported_classes(self): with pytest.raises(TypeError):