Skip to content

Commit

Permalink
added support for _format_arg and _parse_inputs methods
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed May 31, 2024
1 parent 5e8da87 commit 843a83a
Show file tree
Hide file tree
Showing 6 changed files with 488 additions and 279 deletions.
4 changes: 3 additions & 1 deletion nipype2pydra/cli/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def convert(
# Clean previous version of output dir
package_dir = converter.package_dir(package_root)
if converter.interface_only:
shutil.rmtree(package_dir / "auto")
auto_dir = package_dir / "auto"
if auto_dir.exists():
shutil.rmtree(auto_dir)
else:
for fspath in package_dir.iterdir():
if fspath.parent == package_dir and fspath.name in (
Expand Down
299 changes: 299 additions & 0 deletions nipype2pydra/interface/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
types_converter,
from_dict_converter,
unwrap_nested_type,
get_local_functions,
get_local_constants,
get_return_line,
cleanup_function_body,
insert_args_in_signature,
extract_args,
)
from ..statements import (
ImportStatement,
Expand Down Expand Up @@ -397,6 +403,10 @@ def nipype_output_spec(self) -> nipype.interfaces.base.BaseTraitedSpec:
def input_fields(self):
    """Return the first element of ``self._convert_input_fields``."""
    converted = self._convert_input_fields
    return converted[0]

@property
def input_names(self):
    """List the first item of every entry in ``self.input_fields``, in order."""
    names = [field[0] for field in self.input_fields]
    return names

@cached_property
def input_templates(self):
    """Return the second element of ``self._convert_input_fields`` (cached)."""
    converted = self._convert_input_fields
    return converted[1]
Expand Down Expand Up @@ -440,6 +450,65 @@ def _converted(self):
self.input_fields, self.nonstd_types, self.output_fields
)

@property
def referenced_local_functions(self):
    """First item of the ``_referenced_funcs_and_methods`` 4-tuple."""
    funcs, _methods, _args, _returns = self._referenced_funcs_and_methods
    return funcs

@property
def referenced_methods(self):
    """Second item of the ``_referenced_funcs_and_methods`` 4-tuple."""
    _funcs, meths, _args, _returns = self._referenced_funcs_and_methods
    return meths

@property
def method_args(self):
    """Third item of the ``_referenced_funcs_and_methods`` 4-tuple."""
    _funcs, _methods, extra_args, _returns = self._referenced_funcs_and_methods
    return extra_args

@property
def method_returns(self):
    """Fourth item of the ``_referenced_funcs_and_methods`` 4-tuple."""
    _funcs, _methods, _args, extra_returns = self._referenced_funcs_and_methods
    return extra_returns

@cached_property
def _referenced_funcs_and_methods(self):
    """Gather what the interface's included methods reference (cached).

    Every name in ``self.INCLUDED_METHODS`` that is defined directly on
    ``self.nipype_interface`` (i.e. present in its ``__dict__``) is passed to
    ``self._get_referenced``, which fills the four collections in place.

    Returns
    -------
    tuple
        ``(referenced_funcs, referenced_methods, method_args, method_returns)``:
        two sets followed by two dicts.
    """
    interface = self.nipype_interface
    funcs, meths = set(), set()
    extra_args, extra_returns = {}, {}
    # Every included method is marked as processed up front, including the
    # ones that are skipped below
    seen = {getattr(interface, name) for name in self.INCLUDED_METHODS}
    # Don't include base methods
    overridden = (
        name for name in self.INCLUDED_METHODS if name in interface.__dict__
    )
    for name in overridden:
        self._get_referenced(
            getattr(interface, name),
            funcs,
            meths,
            extra_args,
            extra_returns,
            already_processed=seen,
        )
    return funcs, meths, extra_args, extra_returns

@cached_property
def source_code(self):
    """Return the full text of the file that defines ``self.nipype_interface``.

    The file is opened with ``tokenize.open`` so that it is decoded the way
    Python itself decodes source files (coding cookie, BOM, UTF-8 default)
    instead of with the platform's locale encoding.
    """
    import tokenize  # local import: only needed by this property

    with tokenize.open(inspect.getsourcefile(self.nipype_interface)) as f:
        return f.read()

@cached_property
def methods(self):
    """Plain functions found on the interface class (cached).

    Walks ``dir(self.nipype_interface)``, ignores every name that starts with
    a double underscore, and keeps the attributes for which
    ``inspect.isfunction`` is true, in ``dir()`` order.
    """
    interface = self.nipype_interface
    candidates = (
        getattr(interface, name)
        for name in dir(interface)
        if not name.startswith("__")
    )
    return [member for member in candidates if inspect.isfunction(member)]

@cached_property
def local_function_names(self):
    """``__name__`` of every entry in ``self.local_functions``, in order (cached)."""
    names = [func.__name__ for func in self.local_functions]
    return names

def write(
self,
package_root: Path,
Expand Down Expand Up @@ -653,6 +722,8 @@ def function_callables(self):
for fun_nm in fun_names:
fun = getattr(self.callables_module, fun_nm)
fun_str += inspect.getsource(fun) + "\n"
list_outputs = getattr(self.callables_module, "_list_outputs")
fun_str += inspect.getsource(list_outputs) + "\n"
return fun_str

def pydra_type_converter(self, field, spec_type, name):
Expand Down Expand Up @@ -904,6 +975,234 @@ def create_doctests(self, input_fields, nonstd_types):

return " Examples\n -------\n\n" + doctest_str

def _get_referenced(
    self,
    method: ty.Callable,
    referenced_funcs: ty.Set[ty.Callable],
    referenced_methods: ty.Optional[ty.Set[ty.Callable]] = None,
    method_args: ty.Optional[ty.Dict[str, ty.List[str]]] = None,
    method_returns: ty.Optional[ty.Dict[str, ty.List[str]]] = None,
    already_processed: ty.Optional[ty.Set[ty.Callable]] = None,
) -> ty.Tuple[ty.Set[str], ty.List[str]]:
    """Recursively collect what the source code of ``method`` references

    The source of ``method`` is scanned with regular expressions (after ``#``
    comments have been stripped) for calls to local functions, calls to
    ``self.<method>(...)``, reads of ``self.inputs.<name>`` and assignments to
    ``self.<name>``. Referenced local functions and methods are then scanned
    in turn.

    Parameters
    ----------
    method: callable
        the function or method whose source is scanned
    referenced_funcs: set[function]
        the set of local functions that have been referenced so far, updated
        in place
    referenced_methods: set[function], optional
        the set of methods that have been referenced so far, updated in place
    method_args: dict[str, list[str]], optional
        a dictionary to hold additional arguments that need to be added to each
        method, where the dictionary key is the names of the methods. Updated in
        place with a sorted list of input names for each newly scanned method
    method_returns: dict[str, list[str]], optional
        a dictionary to hold the return values of each method, where the
        dictionary key is the names of the methods. Updated in place
    already_processed: set[callable], optional
        callables that have already been scanned (or must not be scanned),
        updated in place and used to stop infinite recursion

    Returns
    -------
    referenced_inputs: set[str]
        inputs that have been referenced
    referenced_outputs: list[str]
        outputs that have been referenced, sorted
    """
    # Compare against None instead of relying on truthiness: an empty set
    # supplied by the caller must be updated in place, not silently replaced
    if already_processed is None:
        already_processed = set()
    already_processed.add(method)
    # The remaining accumulators are optional too, so give them a concrete
    # container before they are dereferenced below
    if referenced_methods is None:
        referenced_methods = set()
    if method_args is None:
        method_args = {}
    if method_returns is None:
        method_returns = {}

    method_body = inspect.getsource(method)
    method_body = re.sub(r"\s*#.*", "", method_body)  # Strip out comments
    return_value = get_return_line(method_body)

    ref_local_func_names = set(re.findall(r"(?<!self\.)(\w+)\(", method_body))
    ref_local_funcs = {
        f
        for f in self.local_functions
        if f.__name__ in ref_local_func_names and f not in referenced_funcs
    }

    ref_method_names = set(re.findall(r"(?<=self\.)(\w+)\(", method_body))
    ref_methods = {m for m in self.methods if m.__name__ in ref_method_names}

    referenced_funcs.update(ref_local_funcs)
    referenced_methods.update(ref_methods)

    referenced_inputs = set(re.findall(r"(?<=self\.inputs\.)(\w+)", method_body))
    referenced_outputs = set(re.findall(r"self\.(\w+) *=", method_body))
    if return_value and return_value.startswith("self."):
        # The return expression is matched literally, so it has to be escaped
        # before it is embedded in the pattern
        referenced_outputs.update(
            re.findall(
                re.escape(return_value) + r"\[(?:'|\")(\w+)(?:'|\")\] *=",
                method_body,
            )
        )
    for func in ref_local_funcs:
        if func in already_processed:
            continue
        # Pass the method dictionaries through so that anything discovered
        # below a local function is recorded in the caller's dictionaries
        rf_inputs, rf_outputs = self._get_referenced(
            func,
            referenced_funcs,
            referenced_methods,
            method_args=method_args,
            method_returns=method_returns,
            already_processed=already_processed,
        )
        referenced_inputs.update(rf_inputs)
        referenced_outputs.update(rf_outputs)
    for meth in ref_methods:
        if meth in already_processed:
            continue
        ref_inputs, ref_outputs = self._get_referenced(
            meth,
            referenced_funcs,
            referenced_methods,
            method_args=method_args,
            method_returns=method_returns,
            already_processed=already_processed,
        )
        # Stored sorted so that iteration order is the same on every run
        method_args[meth.__name__] = sorted(ref_inputs)
        method_returns[meth.__name__] = ref_outputs
        referenced_inputs.update(ref_inputs)
        referenced_outputs.update(ref_outputs)
    return referenced_inputs, sorted(referenced_outputs)

@cached_property
def local_functions(self):
    """Result of ``get_local_functions`` for ``self.nipype_module`` (cached)."""
    module = self.nipype_module
    return get_local_functions(module)

@cached_property
def local_constants(self):
    """Result of ``get_local_constants`` for ``self.nipype_module`` (cached)."""
    module = self.nipype_module
    return get_local_constants(module)

def process_method(
    self,
    method: ty.Callable,
    input_names: ty.List[str],
    output_names: ty.List[str],
    method_args: ty.Dict[str, ty.List[str]] = None,
    method_returns: ty.Dict[str, ty.List[str]] = None,
) -> str:
    """Rewrite the source of a method as the source of a stand-alone function

    Parameters
    ----------
    method: callable
        the method to convert; its source is read with ``inspect.getsource``
    input_names: list[str]
        passed through to ``process_method_body``
    output_names: list[str]
        passed through to ``process_method_body``
    method_args: dict[str, list[str]], optional
        not read by this implementation, ``self.method_args`` is used instead
    method_returns: dict[str, list[str]], optional
        not read by this implementation, ``self.method_returns`` is used instead

    Returns
    -------
    str
        the rewritten signature followed by the rewritten body
    """
    src = inspect.getsource(method)
    # NOTE(review): presumably ``pre`` is the text up to the opening bracket of
    # the signature, ``args`` a list of argument strings and ``post`` everything
    # after the arguments -- confirm against ``extract_args``
    pre, args, post = extract_args(src)
    # Drop the arguments that have no meaning outside of the class
    try:
        args.remove("self")
    except ValueError:
        pass
    if "runtime" in args:
        args.remove("runtime")
    # Attributes the method used to read are passed in as keyword arguments
    if method.__name__ in self.method_args:
        args += [f"{a}=None" for a in self.method_args[method.__name__]]
    # Insert method args in signature if present
    # Text before the first ":" is kept as the tail of the signature, the
    # remainder of that line is discarded and the rest is the body
    return_types, method_body = post.split(":", maxsplit=1)
    method_body = method_body.split("\n", maxsplit=1)[1]
    method_body = self.process_method_body(method_body, input_names, output_names)
    if self.method_returns.get(method.__name__):
        return_args = self.method_returns[method.__name__]
        # Initialise every additional return value at the top of the body
        # NOTE(review): the literal prefix must match the indentation used by
        # the rest of the body -- confirm
        method_body = (
            " " + " = ".join(return_args) + " = attrs.NOTHING\n" + method_body
        )
        method_lines = method_body.rstrip().splitlines()
        method_body = "\n".join(method_lines[:-1])
        last_line = method_lines[-1]
        # Either extend the final return statement or add a new one
        # NOTE(review): this is a substring test, so any last line that merely
        # contains "return" is treated as a return statement
        if "return" in last_line:
            method_body += "\n" + last_line + "," + ",".join(return_args)
        else:
            method_body += (
                "\n" + last_line + "\n return " + ",".join(return_args)
            )
    # Strip leading whitespace from every line that precedes the arguments
    pre = re.sub(r"^\s*", "", pre, flags=re.MULTILINE)
    pre = pre.replace("@staticmethod\n", "")
    return f"{pre}{', '.join(args)}{return_types}:\n{method_body}"

def process_method_body(
    self, method_body: str, input_names: ty.List[str], output_names: ty.List[str]
) -> str:
    """Rewrite references to inputs and outputs in a method body

    ``self.inputs.<name>`` becomes ``<name>`` and ``<return value>["<name>"]``
    becomes ``<name>``, where the return value is the expression reported by
    ``get_return_line``. Names that are not found in ``input_names`` or
    ``output_names`` are logged as warnings but are still rewritten. The result
    is passed through ``self.unwrap_nested_methods``.

    Parameters
    ----------
    method_body: str
        the source of the body of the method
    input_names: list[str]
        the names of the recognised inputs
    output_names: list[str]
        the names of the recognised outputs

    Returns
    -------
    str
        the rewritten body
    """
    return_value = get_return_line(method_body)
    method_body = method_body.replace("if self.output_spec:", "if True:")
    # Replace self.inputs.<name> with <name> in the function body
    input_re = re.compile(r"self\.inputs\.(\w+)\b(?!\()")
    unrecognised_inputs = {
        m for m in input_re.findall(method_body) if m not in input_names
    }
    if unrecognised_inputs:
        logger.warning(
            "Found the following unrecognised (potentially dynamic) inputs %s in "
            "'%s' task",
            unrecognised_inputs,
            self.task_name,
        )
    method_body = input_re.sub(r"\1", method_body)

    if return_value:
        # The return expression is matched literally, so it is escaped before
        # being embedded in the pattern (an expression such as "a ** 2" is not
        # a valid regular expression on its own)
        output_re = re.compile(
            re.escape(return_value) + r"\[(?:'|\")(\w+)(?:'|\")\]"
        )
        unrecognised_outputs = {
            m for m in output_re.findall(method_body) if m not in output_names
        }
        if unrecognised_outputs:
            logger.warning(
                "Found the following unrecognised (potentially dynamic) outputs %s in "
                "'%s' task",
                unrecognised_outputs,
                self.task_name,
            )
        method_body = output_re.sub(r"\1", method_body)
    # Strip initialisation of outputs
    # The "." and "()" in this pattern are deliberately left unescaped so that
    # it keeps matching exactly what it matched before
    method_body = re.sub(
        r"outputs = self.output_spec().*", r"outputs = {}", method_body
    )
    return self.unwrap_nested_methods(method_body)

def unwrap_nested_methods(self, method_body):
    """
    Converts nested method calls into function calls

    Every ``self.<name>(`` in ``method_body`` is replaced by ``<name>(``, with
    the entries of ``self.method_args[<name>]`` inserted as ``arg=arg`` keyword
    arguments. When ``self.method_returns[<name>]`` is not empty, those names
    are added to the assignment targets on the line of the call. Finally,
    assignments to ``self.<attr>`` are turned into assignments to ``<attr>``
    and the result is passed through ``cleanup_function_body``.

    Raises
    ------
    AssertionError
        if a called method is not one of ``self.referenced_methods``
    NotImplementedError
        if a method with additional return values is called anywhere other
        than at the start of a line or directly on the right-hand side of an
        assignment
    """
    # Add args to the function signature of method calls
    method_re = re.compile(r"self\.(\w+)(?=\()", flags=re.MULTILINE | re.DOTALL)
    method_names = [m.__name__ for m in self.referenced_methods]
    unrecognised_methods = set(
        m for m in method_re.findall(method_body) if m not in method_names
    )
    # NOTE(review): this check disappears when Python is run with -O
    assert (
        not unrecognised_methods
    ), f"Found the following unrecognised methods {unrecognised_methods}"
    # The pattern has a single group, so the split alternates between plain
    # text and method names: [text, name, text, name, text, ...]. Each text
    # chunk after a name starts at the "(" of the call because the look-ahead
    # does not consume it
    splits = method_re.split(method_body)
    new_body = splits[0]
    for name, args in zip(splits[1::2], splits[2::2]):
        # Assign additional return values (which were previously saved to member
        # attributes) to new variables from the method call
        # NOTE(review): raises KeyError if ``name`` has no entry in
        # ``self.method_returns`` -- confirm every referenced method gets one
        if self.method_returns[name]:
            # Only the text between the start of the current line and the call
            # is examined
            last_line = new_body.splitlines()[-1]
            # Matches a line that is empty so far, or that only holds
            # assignment targets followed by "="
            match = re.match(r" *([a-zA-Z0-9\,\.\_ ]+ *=)? *$", last_line)
            if match:
                if match.group(1):
                    # Existing assignment: append the additional names to the
                    # targets
                    new_body_lines = new_body.splitlines()
                    new_body = "\n".join(new_body_lines[:-1])
                    last_line = new_body_lines[-1]
                    new_body += "\n" + re.sub(
                        r"^( *)([a-zA-Z0-9\,\.\_ ]+) *= *$",
                        r"\1\2, " + ",".join(self.method_returns[name]) + " = ",
                        last_line,
                        flags=re.MULTILINE,
                    )
                else:
                    # Bare call: assign the additional names directly
                    new_body += ",".join(self.method_returns[name]) + " = "
            else:
                raise NotImplementedError(
                    "Could not augment the return value of the method converted from "
                    "a function with the previously assigned attributes as it is used "
                    "directly. Need to replace the method call with a variable and "
                    "assign the return value to it on a previous line"
                )
        # Insert additional arguments to the method call (which were previously
        # accessed via member attributes)
        new_body += name + insert_args_in_signature(
            args, [f"{a}={a}" for a in self.method_args[name]]
        )
    method_body = new_body
    # Convert assignment to self attributes into method-scoped variables (hopefully
    # there aren't any name clashes)
    method_body = re.sub(
        r"self\.(\w+ *)(?==)", r"\1", method_body, flags=re.MULTILINE | re.DOTALL
    )
    return cleanup_function_body(method_body)

INPUT_KEYS = [
"allowed_values",
"argstr",
Expand Down
Loading

0 comments on commit 843a83a

Please sign in to comment.