diff --git a/pyiron_contrib/workflow/output_parser.py b/pyiron_contrib/workflow/output_parser.py new file mode 100644 index 000000000..e5405c239 --- /dev/null +++ b/pyiron_contrib/workflow/output_parser.py @@ -0,0 +1,65 @@ +""" +Inspects code to automatically parse return values as strings +""" + +import ast +import inspect +import re + + +def _remove_spaces_until_character(string): + pattern = r'\s+(?=\s)' + modified_string = re.sub(pattern, '', string) + return modified_string + + +class ParseOutput: + def __init__(self, function): + self._func = function + self._source = None + + @property + def func(self): + return self._func + + @property + def node_return(self): + tree = ast.parse(inspect.getsource(self.func)) + for node in ast.walk(tree): + if isinstance(node, ast.Return): + return node + + @property + def source(self): + if self._source is None: + self._source = [ + line.rsplit("\n", 1)[0] for line in inspect.getsourcelines(self.func)[0] + ] + return self._source + + def get_string(self, node): + string = "" + for ll in range(node.lineno - 1, node.end_lineno): + if ll == node.lineno - 1 == node.end_lineno - 1: + string += _remove_spaces_until_character( + self.source[ll][node.col_offset:node.end_col_offset] + ) + elif ll == node.lineno - 1: + string += _remove_spaces_until_character( + self.source[ll][node.col_offset:] + ) + elif ll == node.end_lineno - 1: + string += _remove_spaces_until_character( + self.source[ll][:node.end_col_offset] + ) + else: + string += _remove_spaces_until_character(self.source[ll]) + return string + + @property + def output(self): + if self.node_return is None: + return + if isinstance(self.node_return.value, ast.Tuple): + return [self.get_string(s) for s in self.node_return.value.dims] + return [self.get_string(self.node_return.value)] \ No newline at end of file diff --git a/tests/unit/workflow/test_output_parser.py b/tests/unit/workflow/test_output_parser.py new file mode 100644 index 000000000..87c44be45 --- /dev/null +++ b/tests/unit/workflow/test_output_parser.py @@ -0,0 +1,56 @@ +from sys import version_info +import unittest + +import numpy as np + +from pyiron_contrib.workflow.output_parser import ParseOutput + + +@unittest.skipUnless( + version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+" +) +class TestParseOutput(unittest.TestCase): + def test_parsing(self): + with self.subTest("Single return"): + def identity(x): + return x + self.assertListEqual(ParseOutput(identity).output, ["x"]) + + with self.subTest("Expression return"): + def add(x, y): + return x + y + self.assertListEqual(ParseOutput(add).output, ["x + y"]) + + with self.subTest("Weird whitespace"): + def add(x, y): + return x + y + self.assertListEqual(ParseOutput(add).output, ["x + y"]) + + with self.subTest("Multiple expressions"): + def add_and_subtract(x, y): + return x + y, x - y + self.assertListEqual(ParseOutput(add).output, ["x + y", "x - y"]) + + with self.subTest("Best-practice (well-named return vars)"): + def md(job): + temperature = job.output.temperature + energy = job.output.energy + return temperature, energy + self.assertListEqual(ParseOutput(md).output, ["temperature", "energy"]) + + with self.subTest("Function call returns"): + def function_return(i, j): + return ( + np.arange( + i, dtype=int + ), + np.shape(i, j) + ) + self.assertListEqual( + ParseOutput(function_return).output, + ["np.arange( i, dtype=int )", "np.shape(i, j)"] + ) + + +if __name__ == '__main__': + unittest.main()