-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Copy in Sam's output parser from pyiron_contrib issue #717
And use his examples as tests
- Loading branch information
Showing
2 changed files
with
121 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
""" | ||
Inspects code to automatically parse return values as strings | ||
""" | ||
|
||
import ast | ||
import inspect | ||
import re | ||
|
||
|
||
def _remove_spaces_until_character(string): | ||
pattern = r'\s+(?=\s)' | ||
modified_string = re.sub(pattern, '', string) | ||
return modified_string | ||
|
||
|
||
class ParseOutput: | ||
def __init__(self, function): | ||
self._func = function | ||
self._source = None | ||
|
||
@property | ||
def func(self): | ||
return self._func | ||
|
||
@property | ||
def node_return(self): | ||
tree = ast.parse(inspect.getsource(self.func)) | ||
for node in ast.walk(tree): | ||
if isinstance(node, ast.Return): | ||
return node | ||
|
||
@property | ||
def source(self): | ||
if self._source is None: | ||
self._source = [ | ||
line.rsplit("\n", 1)[0] for line in inspect.getsourcelines(self.func)[0] | ||
] | ||
return self._source | ||
|
||
def get_string(self, node): | ||
string = "" | ||
for ll in range(node.lineno - 1, node.end_lineno): | ||
if ll == node.lineno - 1 == node.end_lineno - 1: | ||
string += _remove_spaces_until_character( | ||
self.source[ll][node.col_offset:node.end_col_offset] | ||
) | ||
elif ll == node.lineno - 1: | ||
string += _remove_spaces_until_character( | ||
self.source[ll][node.col_offset:] | ||
) | ||
elif ll == node.end_lineno - 1: | ||
string += _remove_spaces_until_character( | ||
self.source[ll][:node.end_col_offset] | ||
) | ||
else: | ||
string += _remove_spaces_until_character(self.source[ll]) | ||
return string | ||
|
||
@property | ||
def output(self): | ||
if self.node_return is None: | ||
return | ||
if isinstance(self.node_return.value, ast.Tuple): | ||
return [self.get_string(s) for s in self.node_return.value.dims] | ||
return [self.get_string(self.node_return.value)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from sys import version_info | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
from pyiron_contrib.workflow.output_parser import ParseOutput | ||
|
||
|
||
@unittest.skipUnless( | ||
version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+" | ||
) | ||
class TestParseOutput(unittest.TestCase): | ||
def test_parsing(self): | ||
with self.subTest("Single return"): | ||
def identity(x): | ||
return x | ||
self.assertListEqual(ParseOutput(identity).output, ["x"]) | ||
|
||
with self.subTest("Expression return"): | ||
def add(x, y): | ||
return x + y | ||
self.assertListEqual(ParseOutput(add).output, ["x + y"]) | ||
|
||
with self.subTest("Weird whitespace"): | ||
def add(x, y): | ||
return x + y | ||
self.assertListEqual(ParseOutput(add).output, ["x + y"]) | ||
|
||
with self.subTest("Multiple expressions"): | ||
def add_and_subtract(x, y): | ||
return x + y, x - y | ||
self.assertListEqual(ParseOutput(add).output, ["x + y", "x - y"]) | ||
|
||
with self.subTest("Best-practice (well-named return vars)"): | ||
def md(job): | ||
temperature = job.output.temperature | ||
energy = job.output.energy | ||
return temperature, energy | ||
self.assertListEqual(ParseOutput(md).output, ["temperature", "energy"]) | ||
|
||
with self.subTest("Function call returns"): | ||
def function_return(i, j): | ||
return ( | ||
np.arange( | ||
i, dtype=int | ||
), | ||
np.shape(i, j) | ||
) | ||
self.assertListEqual( | ||
ParseOutput(function_return).output, | ||
["np.arange( i, dtype=int )", "np.shape(i, j)"] | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |