Skip to content

Commit

Permalink
fix: DAG rendering from code package (#290)
Browse files Browse the repository at this point in the history
* inherit DAGNode from metaflow client instead of reimplementing it.

* inherit most of the FlowGraph functionality from the metaflow client as well, overriding only when necessary.

* draft unit test for legacy dag parsing

* first passing unit test for custom flowgraph

* add more unit tests to cover custom flowgraph behaviour

* revert client FlowGraph inheritance

* code style fixes

* fix compatibility with imported DAGNode

Co-authored-by: Brendan Gibson <[email protected]>
  • Loading branch information
saikonen and obgibson authored Jan 27, 2022
1 parent 3953b86 commit 3a33276
Show file tree
Hide file tree
Showing 2 changed files with 291 additions and 118 deletions.
134 changes: 16 additions & 118 deletions services/ui_backend_service/data/cache/custom_flowgraph.py
Original file line number Diff line number Diff line change
@@ -1,110 +1,9 @@
import ast
from metaflow.graph import deindent_docstring, DAGNode


class DAGNode(object):
    """One step of a flow, described from the AST of its step function.

    The constructor records the step's name, line number, decorators and
    docstring, then inspects the last statement of the function body to
    classify the step's ``self.next(...)`` transition (see ``_parse``).

    Args:
        func_ast: ``ast.FunctionDef`` node of the step function.
        decos: decorators attached to the step; stored as-is and only
            stringified by ``__str__``.
        doc: docstring of the step, or ``None``/empty when there is none.
    """

    def __init__(self, func_ast, decos, doc):
        self.name = func_ast.name
        self.func_lineno = func_ast.lineno
        self.decorators = decos
        # ast.get_docstring() returns None for undocumented functions, so
        # treat a missing docstring as an empty one instead of crashing.
        self.doc = (doc or '').rstrip()

        # these attributes are populated by _parse
        self.tail_next_lineno = 0
        self.type = None
        self.out_funcs = []
        self.has_tail_next = False
        self.invalid_tail_next = False
        self.num_args = 0
        self.condition = None
        self.foreach_param = None
        self._parse(func_ast)

        # these attributes are populated by _traverse_graph
        self.in_funcs = set()
        self.split_parents = []
        self.matching_join = None

        # these attributes are populated by _postprocess
        self.is_inside_foreach = False

    def _expr_str(self, expr):
        """Render an ``ast.Attribute`` on a plain name as ``'name.attr'``.

        Raises AttributeError when ``expr`` does not have that shape; the
        caller relies on this to reject unsupported expressions.
        """
        return '%s.%s' % (expr.value.id, expr.attr)

    @staticmethod
    def _keyword_value(value_node):
        """Return the literal value carried by a keyword-argument AST node.

        ``ast.Constant.value`` is read directly rather than through the
        deprecated ``.s`` alias. The ``.s`` fallback only serves legacy
        ``ast.Str`` nodes; for any other node it raises AttributeError,
        which ``_parse`` treats as an invalid transition.
        """
        if isinstance(value_node, ast.Constant):
            return value_node.value
        return value_node.s

    def _parse(self, func_ast):
        """Classify the step from the last statement of its function body.

        Sets ``num_args`` and, when the tail is a ``self.next(...)`` call,
        ``has_tail_next``, ``tail_next_lineno``, ``out_funcs``, ``type`` and
        ``foreach_param``/``condition``. ``invalid_tail_next`` stays True
        when the call does not match any supported transition shape.
        """
        self.num_args = len(func_ast.args.args)
        tail = func_ast.body[-1]

        # end doesn't need a transition
        if self.name == 'end':
            # TYPE: end
            self.type = 'end'

        # ensure that the tail is an expression
        if not isinstance(tail, ast.Expr):
            return

        # determine the type of self.next transition
        try:
            if not self._expr_str(tail.value.func) == 'self.next':
                return

            # Assume the transition is invalid until a supported shape
            # is recognised below.
            self.has_tail_next = True
            self.invalid_tail_next = True
            self.tail_next_lineno = tail.lineno
            self.out_funcs = [e.attr for e in tail.value.args]
            keywords = dict(
                (k.arg, self._keyword_value(k.value))
                for k in tail.value.keywords
            )

            if len(keywords) == 1:
                if 'foreach' in keywords:
                    # TYPE: foreach
                    self.type = 'foreach'
                    if len(self.out_funcs) == 1:
                        self.foreach_param = keywords['foreach']
                        self.invalid_tail_next = False
                elif 'condition' in keywords:
                    # TYPE: split-or
                    self.type = 'split-or'
                    if len(self.out_funcs) == 2:
                        self.condition = keywords['condition']
                        self.invalid_tail_next = False
            elif len(keywords) == 0:
                if len(self.out_funcs) > 1:
                    # TYPE: split-and
                    self.type = 'split-and'
                    self.invalid_tail_next = False
                elif len(self.out_funcs) == 1:
                    # TYPE: linear
                    # a step taking an extra argument besides self is a join
                    if self.num_args > 1:
                        self.type = 'join'
                    else:
                        self.type = 'linear'
                    self.invalid_tail_next = False

        except AttributeError:
            # The tail did not have the AST shape of a supported
            # self.next(...) call; keep whatever was determined so far.
            return

    def __str__(self):
        """Return a multi-line, human-readable dump of the node."""
        return """
*[{0.name} {0.type} (line {0.func_lineno})]*
in_funcs={in_funcs}
split_parents={parents}
matching_join={matching_join}
is_inside_foreach={is_inside_foreach}
decorators={decos}
num_args={0.num_args}
has_tail_next={0.has_tail_next} (line {0.tail_next_lineno})
invalid_tail_next={0.invalid_tail_next}
condition={0.condition}
foreach_param={0.foreach_param}
-> {out}"""\
            .format(self,
                    matching_join=self.matching_join and '[%s]' % self.matching_join,
                    is_inside_foreach=self.is_inside_foreach,
                    in_funcs=', '.join('[%s]' % x for x in self.in_funcs),
                    parents=', '.join('[%s]' % x for x in self.split_parents),
                    decos=' | '.join(map(str, self.decorators)),
                    out=', '.join('[%s]' % x for x in self.out_funcs))
# NOTE: This is a custom implementation of the FlowGraph class from the Metaflow client
# which can parse a graph out of a flow_name and a source code string, instead of relying on
# importing the source code as a module.


class StepVisitor(ast.NodeVisitor):
Expand Down Expand Up @@ -141,7 +40,7 @@ def _flow(n):
[root] = list(filter(_flow, ast.parse(source).body))
self.name = root.name
doc = ast.get_docstring(root)
self.doc = doc if doc else ''
self.doc = deindent_docstring(doc) if doc else ''
nodes = {}
StepVisitor(nodes).visit(root)
return nodes
Expand All @@ -151,20 +50,18 @@ def _postprocess(self):
# has is_inside_foreach=True *unless* all of those foreaches
# are joined by the node
for node in self.nodes.values():
foreaches = [p for p in node.split_parents
if self.nodes[p].type == 'foreach']
if [f for f in foreaches
if self.nodes[f].matching_join != node.name]:
foreaches = [
p for p in node.split_parents if self.nodes[p].type == "foreach"
]
if [f for f in foreaches if self.nodes[f].matching_join != node.name]:
node.is_inside_foreach = True

def _traverse_graph(self):

def traverse(node, seen, split_parents):

if node.type in ('split-or', 'split-and', 'foreach'):
if node.type in ("split", "foreach"):
node.split_parents = split_parents
split_parents = split_parents + [node.name]
elif node.type == 'join':
elif node.type == "join":
# ignore joins without splits
if split_parents:
self[split_parents[-1]].matching_join = node.name
Expand All @@ -182,8 +79,8 @@ def traverse(node, seen, split_parents):
child.in_funcs.add(node.name)
traverse(child, seen + [n], split_parents)

if 'start' in self:
traverse(self['start'], [], [])
if "start" in self:
traverse(self["start"], [], [])

# fix the order of in_funcs
for node in self.nodes.values():
Expand All @@ -199,8 +96,9 @@ def __iter__(self):
return iter(self.nodes.values())

def __str__(self):
return '\n'.join(str(n) for _, n in sorted((n.func_lineno, n)
for n in self.nodes.values()))
return "\n".join(
str(n) for _, n in sorted((n.func_lineno, n) for n in self.nodes.values())
)

def output_steps(self):

Expand Down
Loading

0 comments on commit 3a33276

Please sign in to comment.