From c09d92b13e8534ebd89147a9db1fa6a84c4408ce Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Fri, 3 Jun 2022 15:16:02 -0500 Subject: [PATCH] Clone inner-graph before compiling in OpFromGraph --- aesara/compile/builders.py | 6 +++--- tests/compile/test_builders.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/aesara/compile/builders.py b/aesara/compile/builders.py index 8d0cb223f4..738d724ac6 100644 --- a/aesara/compile/builders.py +++ b/aesara/compile/builders.py @@ -5,8 +5,8 @@ from typing import List, Optional, Sequence, cast import aesara.tensor as at +from aesara import function from aesara.compile.function.pfunc import rebuild_collect_shared -from aesara.compile.function.types import orig_function from aesara.compile.mode import optdb from aesara.compile.sharedvalue import SharedVariable from aesara.configdefaults import config @@ -326,7 +326,7 @@ def __init__( name A name for debugging purposes. kwargs - Check :func:`orig_function` for more arguments, only works when not + Check :func:`aesara.function` for more arguments, only works when not inline. """ @@ -903,7 +903,7 @@ def fn(self): if getattr(self, "_fn", None) is not None: return self._fn - self._fn = orig_function(self.inner_inputs, self.inner_outputs, **self.kwargs) + self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs) self._fn.trust_input = True return self._fn diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index a4ba20b53b..2038b93666 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -541,6 +541,25 @@ def test_shared_to_nonshared_input(self): assert np.array_equal(res_2, 1.0) + def test_outputs_consistency(self): + """Make sure that `OpFromGraph.fn` doesn't change the value of `OpFromGraph.inner_outputs`.""" + + x = scalar("x") + op = OpFromGraph([x], [x**2 / x], mode="FAST_RUN") + + # Confirm that the inner-graph is as expected + assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x]) + + # These outputs of the compiled `op.fgraph` should differ from the + # original, uncompiled `op.fgraph` outputs + fn = op.fn + new_inputs = fn.maker.fgraph.inputs + new_outputs = fn.maker.fgraph.outputs + assert not equal_computations(new_outputs, [x**2 / x], new_inputs, [x]) + + # The original `op.fgraph` outputs should stay the same, though + assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x]) + @config.change_flags(floatX="float64") def test_debugprint():