Skip to content

Commit

Permalink
handled super methods properly
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed May 31, 2024
1 parent 07700f2 commit 985ab46
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 26 deletions.
44 changes: 28 additions & 16 deletions nipype2pydra/interface/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,15 +1094,11 @@ def _get_referenced(
re.findall(return_value + r"\[(?:'|\")(\w+)(?:'|\")\] *=", method_body)
)
for match in re.findall(r"super\([^\)]*\)\.(\w+)\(", method_body):
super_method = None
for base in self.nipype_interface.__mro__[1:]:
if match in base.__dict__: # Found the match
super_method = getattr(base, match)
break
assert super_method is not None, (
f"Could not find super of '{match}' method in base classes of "
f"{self.nipype_interface}"
)
super_method, base = find_super_method(super_base, match)
if any(
base.__module__.startswith(m) for m in UsedSymbols.ALWAYS_OMIT_MODULES
):
continue
func_name = self._common_parent_pkg_prefix(base) + match
if func_name not in referenced_supers:
referenced_supers[func_name] = (super_method, base)
Expand Down Expand Up @@ -1289,12 +1285,16 @@ def process_method_body(
def replace_supers(self, method_body, super_base=None):
if super_base is None:
super_base = self.nipype_interface
super_name_map = self.method_supers[super_base]
return re.sub(
r"super\([^\)]*\)\.(\w+)\(",
lambda m: super_name_map[m.group(1)] + "(",
method_body,
)
name_map = self.method_supers[super_base]

def replace_super(match):
super_method = find_super_method(super_base, match.group(1))[0]
try:
return self.SPECIAL_SUPER_MAPPINGS[super_method]
except KeyError:
return name_map[match.group(1)] + "(" + match.group(2) + ")"

return re.sub(r"super\([^\)]*\)\.(\w+)\(([^\)]*)\)", replace_super, method_body)

def unwrap_nested_methods(self, method_body, additional_args=()):
"""
Expand Down Expand Up @@ -1354,7 +1354,7 @@ def unwrap_nested_methods(self, method_body, additional_args=()):
)
return cleanup_function_body(method_body)

SUPER_MAPPINGS = {CommandLine: {"_list_outputs": "{}"}}
SPECIAL_SUPER_MAPPINGS = {CommandLine._list_outputs: "{}"}

INPUT_KEYS = [
"allowed_values",
Expand Down Expand Up @@ -1407,3 +1407,15 @@ def pytest_configure(config):
else:
CATCH_CLI_EXCEPTIONS = True
"""


def find_super_method(
super_base: type, method_name: str
) -> ty.Tuple[ty.Callable, type]:
for base in super_base.__mro__[1:]:
if method_name in base.__dict__: # Found the match
return getattr(base, method_name), base
raise RuntimeError(
f"Could not find super of '{method_name}' method in base classes of "
f"{super_base}"
)
21 changes: 11 additions & 10 deletions nipype2pydra/interface/shell_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import cached_property
from copy import copy
from operator import attrgetter, itemgetter
from importlib import import_module
from nipype.interfaces.base import BaseInterface, TraitedSpec
from .base import BaseInterfaceConverter
from ..utils import (
Expand Down Expand Up @@ -162,14 +163,6 @@ def types_to_names(spec_fields):
new_name=new_name,
)

imports = self.construct_imports(
nonstd_types,
spec_str,
include_task=False,
base=base_imports,
)
# spec_str = "\n".join(str(i) for i in imports) + "\n\n" + spec_str

used = UsedSymbols.find(
self.nipype_module,
[
Expand All @@ -188,7 +181,7 @@ def types_to_names(spec_fields):
)
for super_method, base in self.referenced_supers.values():
super_used = UsedSymbols.find(
base,
import_module(base.__module__),
[super_method],
omit_classes=self.package.omit_classes + [BaseInterface, TraitedSpec],
omit_modules=self.package.omit_modules,
Expand All @@ -197,10 +190,18 @@ def types_to_names(spec_fields):
always_include=self.package.all_explicit,
translations=self.package.all_import_translations,
absolute_imports=True,
collapse_intra_pkg=True,
)
used.update(super_used)

used.imports.update(imports)
used.imports.update(
self.construct_imports(
nonstd_types,
spec_str,
include_task=False,
base=base_imports,
)
)

return spec_str, used

Expand Down

0 comments on commit 985ab46

Please sign in to comment.