Skip to content

Commit

Permalink
Clone inner-graph before compiling in OpFromGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jun 4, 2022
1 parent 88f0299 commit c09d92b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
6 changes: 3 additions & 3 deletions aesara/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import List, Optional, Sequence, cast

import aesara.tensor as at
from aesara import function
from aesara.compile.function.pfunc import rebuild_collect_shared
from aesara.compile.function.types import orig_function
from aesara.compile.mode import optdb
from aesara.compile.sharedvalue import SharedVariable
from aesara.configdefaults import config
Expand Down Expand Up @@ -326,7 +326,7 @@ def __init__(
name
A name for debugging purposes.
kwargs
Check :func:`orig_function` for more arguments, only works when not
Check :func:`aesara.function` for more arguments, only works when not
inline.
"""

Expand Down Expand Up @@ -903,7 +903,7 @@ def fn(self):
if getattr(self, "_fn", None) is not None:
return self._fn

self._fn = orig_function(self.inner_inputs, self.inner_outputs, **self.kwargs)
self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
self._fn.trust_input = True

return self._fn
Expand Down
19 changes: 19 additions & 0 deletions tests/compile/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,25 @@ def test_shared_to_nonshared_input(self):

assert np.array_equal(res_2, 1.0)

def test_outputs_consistency(self):
"""Make sure that `OpFromGraph.fn` doesn't change the value of `OpFromGraph.inner_outputs`."""

x = scalar("x")
op = OpFromGraph([x], [x**2 / x], mode="FAST_RUN")

# Confirm that the inner-graph is as expected
assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])

# These outputs of the compiled `op.fgraph` should differ from the
# original, uncompiled `op.fgraph` outputs
fn = op.fn
new_inputs = fn.maker.fgraph.inputs
new_outputs = fn.maker.fgraph.outputs
assert not equal_computations(new_outputs, [x**2 / x], new_inputs, [x])

# The original `op.fgraph` outputs should stay the same, though
assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])


@config.change_flags(floatX="float64")
def test_debugprint():
Expand Down

0 comments on commit c09d92b

Please sign in to comment.