Skip to content

Commit

Permalink
Copy in Sam's output parser from pyiron_contrib issue #717
Browse files Browse the repository at this point in the history
And use his examples as tests
  • Loading branch information
samwaseda authored and liamhuber committed Jul 13, 2023
1 parent e342302 commit 0887a9b
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 0 deletions.
65 changes: 65 additions & 0 deletions pyiron_contrib/workflow/output_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Inspects code to automatically parse return values as strings
"""

import ast
import inspect
import re


def _remove_spaces_until_character(string):
pattern = r'\s+(?=\s)'
modified_string = re.sub(pattern, '', string)
return modified_string


class ParseOutput:
def __init__(self, function):
self._func = function
self._source = None

@property
def func(self):
return self._func

@property
def node_return(self):
tree = ast.parse(inspect.getsource(self.func))
for node in ast.walk(tree):
if isinstance(node, ast.Return):
return node

@property
def source(self):
if self._source is None:
self._source = [
line.rsplit("\n", 1)[0] for line in inspect.getsourcelines(self.func)[0]
]
return self._source

def get_string(self, node):
string = ""
for ll in range(node.lineno - 1, node.end_lineno):
if ll == node.lineno - 1 == node.end_lineno - 1:
string += _remove_spaces_until_character(
self.source[ll][node.col_offset:node.end_col_offset]
)
elif ll == node.lineno - 1:
string += _remove_spaces_until_character(
self.source[ll][node.col_offset:]
)
elif ll == node.end_lineno - 1:
string += _remove_spaces_until_character(
self.source[ll][:node.end_col_offset]
)
else:
string += _remove_spaces_until_character(self.source[ll])
return string

@property
def output(self):
if self.node_return is None:
return
if isinstance(self.node_return.value, ast.Tuple):
return [self.get_string(s) for s in self.node_return.value.dims]
return [self.get_string(self.node_return.value)]
56 changes: 56 additions & 0 deletions tests/unit/workflow/test_output_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from sys import version_info
import unittest

import numpy as np

from pyiron_contrib.workflow.output_parser import ParseOutput


@unittest.skipUnless(
version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+"
)
class TestParseOutput(unittest.TestCase):
def test_parsing(self):
with self.subTest("Single return"):
def identity(x):
return x
self.assertListEqual(ParseOutput(identity).output, ["x"])

with self.subTest("Expression return"):
def add(x, y):
return x + y
self.assertListEqual(ParseOutput(add).output, ["x + y"])

with self.subTest("Weird whitespace"):
def add(x, y):
return x + y
self.assertListEqual(ParseOutput(add).output, ["x + y"])

with self.subTest("Multiple expressions"):
def add_and_subtract(x, y):
return x + y, x - y
self.assertListEqual(ParseOutput(add).output, ["x + y", "x - y"])

with self.subTest("Best-practice (well-named return vars)"):
def md(job):
temperature = job.output.temperature
energy = job.output.energy
return temperature, energy
self.assertListEqual(ParseOutput(md).output, ["temperature", "energy"])

with self.subTest("Function call returns"):
def function_return(i, j):
return (
np.arange(
i, dtype=int
),
np.shape(i, j)
)
self.assertListEqual(
ParseOutput(function_return).output,
["np.arange( i, dtype=int )", "np.shape(i, j)"]
)


if __name__ == '__main__':
unittest.main()

0 comments on commit 0887a9b

Please sign in to comment.