Skip to content

Commit

Permalink
Allow newlines in __str__ output printed by fgraph_to_python
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jun 3, 2022
1 parent c028e38 commit 45642af
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
6 changes: 3 additions & 3 deletions aesara/link/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,9 +760,9 @@ def fgraph_to_python(

node_output_names = [unique_name(v) for v in node.outputs]

body_assigns.append(
f"# {node}\n{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
)
assign_comment_str = f"{indent(str(node), '# ')}"
assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
body_assigns.append(f"{assign_comment_str}\n{assign_str}")

fgraph_input_names = [unique_name(v) for v in fgraph.inputs]
fgraph_output_names = [unique_name(v) for v in fgraph.outputs]
Expand Down
51 changes: 51 additions & 0 deletions tests/link/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from functools import singledispatch

import numpy as np
Expand Down Expand Up @@ -108,6 +109,56 @@ def func(*args, op=op):
assert op2.called == 2


def test_fgraph_to_python_multiline_str():
"""Make sure that multiline `__str__` values are supported by `fgraph_to_python`."""

x = vector("x")
y = vector("y")

class TestOp(Op):
def __init__(self):
super().__init__()

def make_node(self, *args):
return Apply(self, list(args), [x.type() for x in args])

def perform(self, inputs, outputs):
for i, inp in enumerate(inputs):
outputs[i][0] = inp[0]

def __str__(self):
return "Test\nOp()"

@to_python.register(TestOp)
def to_python_TestOp(op, **kwargs):
def func(*args, op=op):
return list(args)

return func

op1 = TestOp()
op2 = TestOp()

q, r = op1(x, y)
outs = op2(q + r, q + r)

out_fg = FunctionGraph([x, y], outs, clone=False)
assert len(out_fg.outputs) == 2

out_py = fgraph_to_python(out_fg, to_python)

out_py_src = inspect.getsource(out_py)

assert (
"""
# Elemwise{add,no_inplace}(Test
# Op().0, Test
# Op().1)
"""
in out_py_src
)


def test_unique_name_generator():

unique_names = unique_name_generator(["blah"], suffix_sep="_")
Expand Down

0 comments on commit 45642af

Please sign in to comment.