Skip to content

Commit

Permalink
Fix an issue with aliases in environment escape server mappings.
Browse files Browse the repository at this point in the history
We were not properly creating all the aliased classes in the conda
environment (only creating the "canonical" one (which is the first
one listed in the server_mappings.py file).

Similarly, the hierarchy of exceptions was not properly created
in all cases.
  • Loading branch information
romain-intel committed Dec 1, 2023
1 parent 181aeca commit 2e1fc04
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 37 deletions.
14 changes: 2 additions & 12 deletions metaflow/plugins/env_escape/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .exception_transferer import load_exception
from .override_decorators import LocalAttrOverride, LocalException, LocalOverride
from .stub import create_class
from .utils import get_canonical_name

BIND_TIMEOUT = 0.1
BIND_RETRY = 0
Expand Down Expand Up @@ -336,7 +337,7 @@ def decode(self, json_obj):
def get_local_class(self, name, obj_id=None):
# Gets (and creates if needed), the class mapping to the remote
# class of name 'name'.
name = self._get_canonical_name(name)
name = get_canonical_name(name, self._aliases)
if name == "function":
# Special handling of pickled functions. We create a new class that
# simply has a __call__ method that will forward things back to
Expand Down Expand Up @@ -398,17 +399,6 @@ def unpickle_object(self, obj):
local_instance = local_class(self, remote_class_name, obj_id)
return local_instance

def _get_canonical_name(self, name):
# We look at the aliases looking for the most specific match first
base_name = self._aliases.get(name)
if base_name is not None:
return base_name
for idx in reversed([pos for pos, char in enumerate(name) if char == "."]):
base_name = self._aliases.get(name[:idx])
if base_name is not None:
return ".".join([base_name, name[idx + 1 :]])
return name

def _communicate(self, msg):
if os.getpid() != self._active_pid:
raise RuntimeError(
Expand Down
32 changes: 18 additions & 14 deletions metaflow/plugins/env_escape/client_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .consts import OP_CALLFUNC, OP_GETVAL, OP_SETVAL
from .client import Client
from .override_decorators import LocalException
from .utils import get_canonical_name


def _clean_client(client):
Expand All @@ -23,6 +24,7 @@ def __init__(self, loader, prefix, exports, exception_classes, client):
r"^%s\.([a-zA-Z_][a-zA-Z0-9_]*)$" % prefix.replace(".", r"\.") # noqa W605
)
self._exports = {}
self._aliases = exports["aliases"]
for k in ("classes", "functions", "values"):
result = []
for item in exports[k]:
Expand All @@ -43,6 +45,11 @@ def __getattr__(self, name):
return self._prefix
if name in ("__file__", "__path__"):
return self._client.name

# Make the name canonical because the prefix is also canonical.
name = get_canonical_name(self._prefix + "." + name, self._aliases)[
len(self._prefix) + 1 :
]
if name in self._exports["classes"]:
# We load classes lazily
return self._client.get_local_class("%s.%s" % (self._prefix, name))
Expand Down Expand Up @@ -87,6 +94,7 @@ def __setattr__(self, name, value):
"_client",
"_exports",
"_exception_classes",
"_aliases",
):
object.__setattr__(self, name, value)
return
Expand All @@ -95,6 +103,11 @@ def __setattr__(self, name, value):
# module when loading
object.__setattr__(self, name, value)
return

# Make the name canonical because the prefix is also canonical.
name = get_canonical_name(self._prefix + "." + name, self._aliases)[
len(self._prefix) + 1 :
]
if name in self._exports["values"]:
self._client.stub_request(
None, OP_SETVAL, "%s.%s" % (self._prefix, name), value
Expand Down Expand Up @@ -126,7 +139,7 @@ def __init__(

def find_module(self, fullname, path=None):
if self._handled_modules is not None:
if fullname in self._handled_modules:
if get_canonical_name(fullname, self._aliases) in self._handled_modules:
return self
return None
if any([fullname.startswith(prefix) for prefix in self._module_prefixes]):
Expand Down Expand Up @@ -224,24 +237,15 @@ def load_module(self, fullname):
self._handled_modules[prefix] = _WrappedModule(
self, prefix, exports, formed_exception_classes, self._client
)
fullname = self._get_canonical_name(fullname)
module = self._handled_modules.get(fullname)
canonical_fullname = get_canonical_name(fullname, self._aliases)
# Modules are created canonically but we need to return something for any
# of the aliases.
module = self._handled_modules.get(canonical_fullname)
if module is None:
raise ImportError
sys.modules[fullname] = module
return module

def _get_canonical_name(self, name):
# We look at the aliases looking for the most specific match first
base_name = self._aliases.get(name)
if base_name is not None:
return base_name
for idx in reversed([pos for pos, char in enumerate(name) if char == "."]):
base_name = self._aliases.get(name[:idx])
if base_name is not None:
return ".".join([base_name, name[idx + 1 :]])
return name


def create_modules(python_executable, pythonpath, max_pickle_version, path, prefixes):
# This is an extra verification to make sure we are not trying to use the
Expand Down
29 changes: 18 additions & 11 deletions metaflow/plugins/env_escape/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,14 @@
RemoteExceptionSerializer,
)
from .exception_transferer import dump_exception
from .utils import get_methods
from .utils import get_methods, get_canonical_name

BIND_TIMEOUT = 0.1
BIND_RETRY = 1


class Server(object):
def __init__(self, config_dir, max_pickle_version):

self._max_pickle_version = data_transferer.defaultProtocol = max_pickle_version
try:
mappings = importlib.import_module(".server_mappings", package=config_dir)
Expand Down Expand Up @@ -108,6 +107,11 @@ def __init__(self, config_dir, max_pickle_version):
for alias in aliases:
a = self._aliases.setdefault(alias, base_name)
if a != base_name:
# Technically we could have a that aliases b and b that aliases c
# and then a that aliases c. This would error out in that case
# even though it is valid. It is easy for the user to get around
# this by listing aliases in the same order so we don't support
# it for now.
raise ValueError(
"%s is an alias to both %s and %s" % (alias, base_name, a)
)
Expand Down Expand Up @@ -155,30 +159,33 @@ def __init__(self, config_dir, max_pickle_version):
parent_to_child = {}

for ex_name, ex_cls in self._known_exceptions.items():
ex_name_canonical = get_canonical_name(ex_name, self._aliases)
parents = []
for base in ex_cls.__mro__[1:]:
if base is object:
raise ValueError(
"Exported exceptions not rooted in a builtin exception are not supported: %s"
% ex_name
"Exported exceptions not rooted in a builtin exception "
"are not supported: %s." % ex_name
)
if base.__module__ == "builtins":
# We found our base exception
parents.append("builtins." + base.__name__)
break
else:
fqn = ".".join([base.__module__, base.__name__])
if fqn in self._known_exceptions:
parents.append(fqn)
children = parent_to_child.setdefault(fqn, [])
children.append(ex_name)
canonical_fqn = get_canonical_name(fqn, self._aliases)
if canonical_fqn in self._known_exceptions:
parents.append(canonical_fqn)
children = parent_to_child.setdefault(canonical_fqn, [])
children.append(ex_name_canonical)
else:
raise ValueError(
"Exported exception %s has non exported and non builtin parent "
"exception: %s" % (ex_name, fqn)
"exception: %s. Known exceptions: %s"
% (ex_name, fqn, str(self._known_exceptions))
)
name_to_parent_count[ex_name] = len(parents) - 1
name_to_parents[ex_name] = parents
name_to_parent_count[ex_name_canonical] = len(parents) - 1
name_to_parents[ex_name_canonical] = parents

# We now form the exceptions and put them in self._known_exceptions in
# the proper order (topologically)
Expand Down
12 changes: 12 additions & 0 deletions metaflow/plugins/env_escape/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,15 @@ def get_methods(class_object):
elif isinstance(attribute, classmethod):
all_methods["___c___%s" % name] = inspect.getdoc(attribute)
return all_methods


def get_canonical_name(name, aliases):
# We look at the aliases looking for the most specific match first
base_name = aliases.get(name)
if base_name is not None:
return base_name
for idx in reversed([pos for pos, char in enumerate(name) if char == "."]):
base_name = aliases.get(name[:idx])
if base_name is not None:
return ".".join([base_name, name[idx + 1 :]])
return name

0 comments on commit 2e1fc04

Please sign in to comment.