From 49d1a1f013c64e2580794baa04540978852469ad Mon Sep 17 00:00:00 2001 From: Romain Date: Mon, 4 Dec 2023 12:57:42 -0800 Subject: [PATCH] Fix an issue with aliases in environment escape server mappings. (#1649) We were not properly creating all the aliased classes in the conda environment (only creating the "canonical" one (which is the first one listed in the server_mappings.py file). Similarly, the hierarchy of exceptions was not properly created in all cases. --- metaflow/plugins/env_escape/client.py | 14 ++------ metaflow/plugins/env_escape/client_modules.py | 32 +++++++++++-------- metaflow/plugins/env_escape/server.py | 29 ++++++++++------- metaflow/plugins/env_escape/utils.py | 12 +++++++ 4 files changed, 50 insertions(+), 37 deletions(-) diff --git a/metaflow/plugins/env_escape/client.py b/metaflow/plugins/env_escape/client.py index 0c72bf10533..60ffdf09b68 100644 --- a/metaflow/plugins/env_escape/client.py +++ b/metaflow/plugins/env_escape/client.py @@ -37,6 +37,7 @@ from .exception_transferer import load_exception from .override_decorators import LocalAttrOverride, LocalException, LocalOverride from .stub import create_class +from .utils import get_canonical_name BIND_TIMEOUT = 0.1 BIND_RETRY = 0 @@ -336,7 +337,7 @@ def decode(self, json_obj): def get_local_class(self, name, obj_id=None): # Gets (and creates if needed), the class mapping to the remote # class of name 'name'. - name = self._get_canonical_name(name) + name = get_canonical_name(name, self._aliases) if name == "function": # Special handling of pickled functions. We create a new class that # simply has a __call__ method that will forward things back to @@ -398,17 +399,6 @@ def unpickle_object(self, obj): local_instance = local_class(self, remote_class_name, obj_id) return local_instance - def _get_canonical_name(self, name): - # We look at the aliases looking for the most specific match first - base_name = self._aliases.get(name) - if base_name is not None: - return base_name - for idx in reversed([pos for pos, char in enumerate(name) if char == "."]): - base_name = self._aliases.get(name[:idx]) - if base_name is not None: - return ".".join([base_name, name[idx + 1 :]]) - return name - def _communicate(self, msg): if os.getpid() != self._active_pid: raise RuntimeError( diff --git a/metaflow/plugins/env_escape/client_modules.py b/metaflow/plugins/env_escape/client_modules.py index 9ad3516084b..2896f2c5156 100644 --- a/metaflow/plugins/env_escape/client_modules.py +++ b/metaflow/plugins/env_escape/client_modules.py @@ -8,6 +8,7 @@ from .consts import OP_CALLFUNC, OP_GETVAL, OP_SETVAL from .client import Client from .override_decorators import LocalException +from .utils import get_canonical_name def _clean_client(client): @@ -23,6 +24,7 @@ def __init__(self, loader, prefix, exports, exception_classes, client): r"^%s\.([a-zA-Z_][a-zA-Z0-9_]*)$" % prefix.replace(".", r"\.") # noqa W605 ) self._exports = {} + self._aliases = exports["aliases"] for k in ("classes", "functions", "values"): result = [] for item in exports[k]: @@ -43,6 +45,11 @@ def __getattr__(self, name): return self._prefix if name in ("__file__", "__path__"): return self._client.name + + # Make the name canonical because the prefix is also canonical. + name = get_canonical_name(self._prefix + "." + name, self._aliases)[ + len(self._prefix) + 1 : + ] if name in self._exports["classes"]: # We load classes lazily return self._client.get_local_class("%s.%s" % (self._prefix, name)) @@ -87,6 +94,7 @@ def __setattr__(self, name, value): "_client", "_exports", "_exception_classes", + "_aliases", ): object.__setattr__(self, name, value) return @@ -95,6 +103,11 @@ def __setattr__(self, name, value): # module when loading object.__setattr__(self, name, value) return + + # Make the name canonical because the prefix is also canonical. + name = get_canonical_name(self._prefix + "." + name, self._aliases)[ + len(self._prefix) + 1 : + ] if name in self._exports["values"]: self._client.stub_request( None, OP_SETVAL, "%s.%s" % (self._prefix, name), value @@ -126,7 +139,7 @@ def __init__( def find_module(self, fullname, path=None): if self._handled_modules is not None: - if fullname in self._handled_modules: + if get_canonical_name(fullname, self._aliases) in self._handled_modules: return self return None if any([fullname.startswith(prefix) for prefix in self._module_prefixes]): @@ -224,24 +237,15 @@ def load_module(self, fullname): self._handled_modules[prefix] = _WrappedModule( self, prefix, exports, formed_exception_classes, self._client ) - fullname = self._get_canonical_name(fullname) - module = self._handled_modules.get(fullname) + canonical_fullname = get_canonical_name(fullname, self._aliases) + # Modules are created canonically but we need to return something for any + # of the aliases. + module = self._handled_modules.get(canonical_fullname) if module is None: raise ImportError sys.modules[fullname] = module return module - def _get_canonical_name(self, name): - # We look at the aliases looking for the most specific match first - base_name = self._aliases.get(name) - if base_name is not None: - return base_name - for idx in reversed([pos for pos, char in enumerate(name) if char == "."]): - base_name = self._aliases.get(name[:idx]) - if base_name is not None: - return ".".join([base_name, name[idx + 1 :]]) - return name - def create_modules(python_executable, pythonpath, max_pickle_version, path, prefixes): # This is an extra verification to make sure we are not trying to use the diff --git a/metaflow/plugins/env_escape/server.py b/metaflow/plugins/env_escape/server.py index 11ef2cf4d65..a78e870c883 100644 --- a/metaflow/plugins/env_escape/server.py +++ b/metaflow/plugins/env_escape/server.py @@ -53,7 +53,7 @@ RemoteExceptionSerializer, ) from .exception_transferer import dump_exception -from .utils import get_methods +from .utils import get_methods, get_canonical_name BIND_TIMEOUT = 0.1 BIND_RETRY = 1 @@ -61,7 +61,6 @@ class Server(object): def __init__(self, config_dir, max_pickle_version): - self._max_pickle_version = data_transferer.defaultProtocol = max_pickle_version try: mappings = importlib.import_module(".server_mappings", package=config_dir) @@ -108,6 +107,11 @@ def __init__(self, config_dir, max_pickle_version): for alias in aliases: a = self._aliases.setdefault(alias, base_name) if a != base_name: + # Technically we could have a that aliases b and b that aliases c + # and then a that aliases c. This would error out in that case + # even though it is valid. It is easy for the user to get around + # this by listing aliases in the same order so we don't support + # it for now. raise ValueError( "%s is an alias to both %s and %s" % (alias, base_name, a) ) @@ -155,12 +159,13 @@ def __init__(self, config_dir, max_pickle_version): parent_to_child = {} for ex_name, ex_cls in self._known_exceptions.items(): + ex_name_canonical = get_canonical_name(ex_name, self._aliases) parents = [] for base in ex_cls.__mro__[1:]: if base is object: raise ValueError( - "Exported exceptions not rooted in a builtin exception are not supported: %s" - % ex_name + "Exported exceptions not rooted in a builtin exception " + "are not supported: %s." % ex_name ) if base.__module__ == "builtins": # We found our base exception @@ -168,17 +173,19 @@ def __init__(self, config_dir, max_pickle_version): break else: fqn = ".".join([base.__module__, base.__name__]) - if fqn in self._known_exceptions: - parents.append(fqn) - children = parent_to_child.setdefault(fqn, []) - children.append(ex_name) + canonical_fqn = get_canonical_name(fqn, self._aliases) + if canonical_fqn in self._known_exceptions: + parents.append(canonical_fqn) + children = parent_to_child.setdefault(canonical_fqn, []) + children.append(ex_name_canonical) else: raise ValueError( "Exported exception %s has non exported and non builtin parent " - "exception: %s" % (ex_name, fqn) + "exception: %s. Known exceptions: %s" + % (ex_name, fqn, str(self._known_exceptions)) ) - name_to_parent_count[ex_name] = len(parents) - 1 - name_to_parents[ex_name] = parents + name_to_parent_count[ex_name_canonical] = len(parents) - 1 + name_to_parents[ex_name_canonical] = parents # We now form the exceptions and put them in self._known_exceptions in # the proper order (topologically) diff --git a/metaflow/plugins/env_escape/utils.py b/metaflow/plugins/env_escape/utils.py index 3267cd23aed..f9fe1aacf0f 100644 --- a/metaflow/plugins/env_escape/utils.py +++ b/metaflow/plugins/env_escape/utils.py @@ -20,3 +20,15 @@ def get_methods(class_object): elif isinstance(attribute, classmethod): all_methods["___c___%s" % name] = inspect.getdoc(attribute) return all_methods + + +def get_canonical_name(name, aliases): + # We look at the aliases looking for the most specific match first + base_name = aliases.get(name) + if base_name is not None: + return base_name + for idx in reversed([pos for pos, char in enumerate(name) if char == "."]): + base_name = aliases.get(name[:idx]) + if base_name is not None: + return ".".join([base_name, name[idx + 1 :]]) + return name