From d0f42fe4c124ba1b123bb686b311bd2c22611cd6 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Thu, 9 Feb 2023 18:13:07 -0600 Subject: [PATCH] Add support for default updates in OpFromGraph --- aesara/compile/builders.py | 199 +++++++++++++++++++++++++-------- aesara/scan/op.py | 4 +- tests/compile/test_builders.py | 124 +++++++++++++++++++- tests/test_printing.py | 21 ++-- 4 files changed, 291 insertions(+), 57 deletions(-) diff --git a/aesara/compile/builders.py b/aesara/compile/builders.py index cca88ffbc7..031e619456 100644 --- a/aesara/compile/builders.py +++ b/aesara/compile/builders.py @@ -2,12 +2,13 @@ from collections import OrderedDict from copy import copy from functools import partial -from typing import Dict, List, Optional, Sequence, Tuple, cast +from typing import List, Optional, Sequence, Tuple, cast import aesara.tensor as at -from aesara import function from aesara.compile.function.pfunc import rebuild_collect_shared +from aesara.compile.io import In, Out from aesara.compile.mode import optdb +from aesara.compile.ops import update_placeholder from aesara.compile.sharedvalue import SharedVariable from aesara.configdefaults import config from aesara.gradient import DisconnectedType, Rop, grad @@ -83,13 +84,26 @@ def local_traverse(out): def construct_nominal_fgraph( inputs: Sequence[Variable], outputs: Sequence[Variable] -) -> Tuple[ - FunctionGraph, - Sequence[Variable], - Dict[Variable, Variable], - Dict[Variable, Variable], -]: - """Construct an inner-`FunctionGraph` with ordered nominal inputs.""" +) -> Tuple[FunctionGraph, Sequence[Variable],]: + r"""Construct an inner-`FunctionGraph` with ordered nominal inputs. + + .. note:: + + Updates (e.g. from `SharedVariable.default_update`) are appended to the resulting + `FunctionGraph`'s outputs. + + Parameters + ========== + inputs + A list of inputs. + outputs + A list of outputs. + + Returns + ======= + The `FunctionGraph` and a list of shared inputs. + + """ dummy_inputs = [] for n, inp in enumerate(inputs): if ( @@ -105,6 +119,7 @@ def construct_nominal_fgraph( dummy_shared_inputs = [] shared_inputs = [] + default_updates = {} for var in graph_inputs(outputs, inputs): if isinstance(var, SharedVariable): # To correctly support shared variables the inner-graph should @@ -113,14 +128,18 @@ def construct_nominal_fgraph( # That's why we collect the shared variables and replace them # with dummies. 
shared_inputs.append(var) - dummy_shared_inputs.append(var.type()) + dummy_var = var.type() + dummy_shared_inputs.append(dummy_var) + + if var.default_update: + default_updates[dummy_var] = var.default_update elif var not in inputs and not isinstance(var, Constant): raise MissingInputError(f"OpFromGraph is missing an input: {var}") replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs)) new = rebuild_collect_shared( - cast(Sequence[Variable], outputs), + outputs=cast(Sequence[Variable], outputs + list(default_updates.values())), inputs=inputs + shared_inputs, replace=replacements, copy_inputs_over=False, @@ -131,13 +150,23 @@ def construct_nominal_fgraph( (clone_d, update_d, update_expr, new_shared_inputs), ) = new + local_default_updates = local_outputs[len(outputs) :] + update_d.update( + {clone_d[k]: v for k, v in zip(default_updates.keys(), local_default_updates)} + ) + update_expr.extend(local_default_updates) + assert len(local_inputs) == len(inputs) + len(shared_inputs) - assert len(local_outputs) == len(outputs) - assert not update_d - assert not update_expr + assert len(local_outputs) == len(outputs) + len(default_updates) assert not new_shared_inputs - fgraph = FunctionGraph(local_inputs, local_outputs, clone=False) + update_mapping = { + local_outputs.index(v): local_inputs.index(k) for k, v in update_d.items() + } + + fgraph = FunctionGraph( + local_inputs, local_outputs, clone=False, update_mapping=update_mapping + ) # The inputs need to be `NominalVariable`s so that we can merge # inner-graphs @@ -153,7 +182,7 @@ def construct_nominal_fgraph( fgraph.clients.pop(inp, None) fgraph.add_input(nom_inp) - return fgraph, shared_inputs, update_d, update_expr + return fgraph, shared_inputs class OpFromGraph(Op, HasInnerGraph): @@ -316,7 +345,13 @@ def __init__( name: Optional[str] = None, **kwargs, ): - """ + r"""Construct an `OpFromGraph` instance. + + .. note:: + + `SharedVariable`\s in `outputs` will have their `SharedVariable.default_update` values + altered in order to support in-lining in the presence of updates. + Parameters ---------- inputs @@ -324,29 +359,30 @@ def __init__( outputs The outputs to the graph. inline - Defaults to ``False`` - ``True`` : Cause the :class:`Op`'s original graph being used during compilation, the :class:`Op` will not be visible in the compiled graph but rather its internal graph. ``False`` : will use a pre-compiled function inside. + + Defaults to ``False``. + grad_overrides - Defaults to ``'default'``. This argument is mutually exclusive with ``lop_overrides``. - ``'default'`` : Do not override, use default grad() result + ``'default'`` : Do not override, use default :meth:`Op.grad` result `OpFromGraph`: Override with another `OpFromGraph`, should accept inputs as the same order and types of ``inputs`` and ``output_grads`` - arguments as one would specify in :meth:`Op.grad`() method. + arguments as one would specify in :meth:`Op.grad` method. `callable`: Should take two args: ``inputs`` and ``output_grads``. Each argument is expected to be a list of :class:`Variable `. Must return list of :class:`Variable `. - lop_overrides + Defaults to ``'default'``. + lop_overrides This argument is mutually exclusive with ``grad_overrides``. 
These options are similar to the ``grad_overrides`` above, but for @@ -355,7 +391,7 @@ def __init__( ``'default'``: Do not override, use the default :meth:`Op.L_op` result `OpFromGraph`: Override with another `OpFromGraph`, should - accept inputs as the same order and types of ``inputs``, + accept inputs in the same order and types as `inputs`, ``outputs`` and ``output_grads`` arguments as one would specify in :meth:`Op.grad` method. @@ -371,11 +407,11 @@ def __init__( :class:`Variable`. Each list element corresponds to gradient of a specific input, length of list must be equal to number of inputs. + Defaults to ``'default'``. + rop_overrides One of ``{'default', OpFromGraph, callable, Variable}``. - Defaults to ``'default'``. - ``'default'``: Do not override, use the default :meth:`Op.R_op` result `OpFromGraph`: Override with another `OpFromGraph`, should @@ -397,11 +433,13 @@ def __init__( must be equal to number of outputs. connection_pattern If not ``None``, this will be used as the connection_pattern for this :class:`Op`. + + Defaults to ``'default'``. + name A name for debugging purposes. kwargs - Check :func:`aesara.function` for more arguments, only works when not - inline. + See :func:`aesara.function`. """ if not (isinstance(inputs, list) and isinstance(outputs, list)): @@ -418,9 +456,24 @@ def __init__( self.is_inline = inline - self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph( - inputs, outputs - ) + # These `shared_inputs` are the original variables in `outputs` + # (i.e. not clones). + self.fgraph, shared_inputs = construct_nominal_fgraph(inputs, outputs) + + # We need to hold on to the original variables so that gradients can be + # taken wrt. them. Ideally, we wouldn't hold on to specific `Variable` + # references like this outside of graph, but we're maintaining support + # for old functionality right now. + self.shared_inputs = [] + for v in shared_inputs: + # This is needed so that `aesara.function` will create an update + # output placeholder in the `FunctionGraph` it compiles. We need + # placeholders like this in order to properly inline `OpFromGraph`s + # containing updates. + # FYI: When the corresponding updates aren't used, they should be + # removed at the `aesara.function` level. 
+ v.default_update = update_placeholder(v) + self.shared_inputs.append(v) self.kwargs = kwargs self.input_types = [inp.type for inp in inputs] @@ -933,7 +986,36 @@ def fn(self): if getattr(self, "_fn", None) is not None: return self._fn - self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs) + from aesara.compile.function.pfunc import pfunc + + # We don't want calls/evaluations of this `Op` to change + # the inner-graph, so we need to clone it + fgraph, _ = self.fgraph.clone_get_equiv(copy_inputs=False, copy_orphans=False) + + wrapped_inputs = [In(x, borrow=False) for x in fgraph.inputs] + wrapped_outputs = [Out(x, borrow=True) for x in fgraph.outputs] + + n_inputs = len(fgraph.inputs) + + for out_idx, in_idx in fgraph.update_mapping.items(): + shared_input = self.shared_inputs[in_idx - n_inputs] + in_var = fgraph.inputs[in_idx] + updated_wrapped_input = In( + variable=in_var, + value=shared_input.container, + update=fgraph.outputs[out_idx], + implicit=True, + shared=True, + ) + wrapped_inputs[in_idx] = updated_wrapped_input + + self._fn = pfunc( + wrapped_inputs, + wrapped_outputs, + fgraph=fgraph, + no_default_updates=True, + **self.kwargs, + ) self._fn.trust_input = True return self._fn @@ -944,6 +1026,11 @@ def inner_inputs(self): @property def inner_outputs(self): + """Return all the outputs except those used for updates.""" + n_updates = len(self.fgraph.update_mapping) + if n_updates > 0: + return self.fgraph.outputs[:-n_updates] + return self.fgraph.outputs def clone(self): @@ -952,28 +1039,52 @@ def clone(self): return res def perform(self, node, inputs, outputs): - variables = self.fn(*inputs) - assert len(variables) == len(outputs) - for output, variable in zip(outputs, variables): - output[0] = variable + results = self.fn(*inputs) + for output, res in zip(outputs, results): + output[0] = res @node_rewriter([OpFromGraph]) def inline_ofg_expansion(fgraph, node): - """ - This optimization expands internal graph of OpFromGraph. - Only performed if node.op.is_inline == True - Doing so can improve optimization at the cost of compilation speed. + """Expand the internal graph of an `OpFromGraph`. + + Only performed if ``node.op.is_inline == True``. + """ op = node.op - if not isinstance(op, OpFromGraph): - return False + if not op.is_inline: return False - return clone_replace( - op.inner_outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)} + + outputs = clone_replace( + op.fgraph.outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)} ) + replacements = { + old_var: new_var + for old_var, new_var in zip(node.outputs, outputs[: len(op.inner_outputs)]) + } + + # Add the updates from `OpFromGraph` into the outer-graph + for out_idx, in_idx in op.fgraph.update_mapping.items(): + shared_input = node.inputs[in_idx] + assert isinstance(shared_input, SharedVariable) + + outer_in_idx = fgraph.inputs.index(shared_input) + + # There should be a placeholder output in `fgraph.outputs` that we can + # use. If there isn't, then someone forgot/removed the + # `SharedVariable.default_update`s on the inputs to the `OpFromGraph` + # (i.e. at the user-level/graph construction-time). + outer_out_idx = fgraph.inv_update_mapping[outer_in_idx] + update_var = fgraph.outputs[outer_out_idx] + + assert update_var is not shared_input + + replacements[update_var] = outputs[out_idx] + + return replacements + # We want to run this before the first merge optimizer # and before the first scan optimizer. 
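Note on the ``update_mapping`` convention used by ``construct_nominal_fgraph`` above: it maps the index of an inner output holding an update expression to the index of the inner input that it updates (the rewrite later consults the reverse view via ``inv_update_mapping``). A minimal, illustrative sketch of that convention, not part of the patch itself:

    import aesara.tensor as at
    from aesara.graph.fg import FunctionGraph

    # Illustrative only: output 1 (`x + 1`) is declared as the update for
    # input 0 (`x`), i.e. the `{output index: input index}` convention that
    # `construct_nominal_fgraph` builds for default updates.
    x = at.scalar("x")
    fgraph = FunctionGraph(
        [x], [2 * x, x + 1], clone=False, update_mapping={1: 0}
    )
    assert fgraph.update_mapping == {1: 0}
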
diff --git a/aesara/scan/op.py b/aesara/scan/op.py index 60efdf9f13..0c68c9be3e 100644 --- a/aesara/scan/op.py +++ b/aesara/scan/op.py @@ -750,7 +750,9 @@ def __init__( If ``True``, all the shared variables used in the inner-graph must be provided. """ - self.fgraph, shared_inputs, _, _ = construct_nominal_fgraph(inputs, outputs) + self.fgraph, shared_inputs = construct_nominal_fgraph(inputs, outputs) + + assert not self.fgraph.update_mapping # The shared variables should have been removed, so, if there are # any, it's because the user didn't specify an input. diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index 0ca4cabf53..c67480cef2 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -7,9 +7,12 @@ from aesara.compile import shared from aesara.compile.builders import OpFromGraph from aesara.compile.function import function +from aesara.compile.mode import get_mode +from aesara.compile.ops import update_placeholder +from aesara.compile.sharedvalue import SharedVariable from aesara.configdefaults import config from aesara.gradient import DisconnectedType, Rop, disconnected_type, grad -from aesara.graph.basic import equal_computations +from aesara.graph.basic import Constant, equal_computations from aesara.graph.fg import FunctionGraph from aesara.graph.null_type import NullType from aesara.graph.rewriting.utils import rewrite_graph @@ -424,8 +427,8 @@ def test_connection_pattern(self, cls_ofg): assert results == expect_result def test_infer_shape(self): - # test infer shape does not need to against inline case - # since the Op is remove during optimization phase + # N.B. this test does not need to be run against the inline case, + # because the `Op` is supposed to be removed during optimization phase. x = matrix("x") y = matrix("y") o1 = x + y @@ -560,6 +563,121 @@ def test_outputs_consistency(self): # The original `op.fgraph` outputs should stay the same, though assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x]) + def test_default_updates(self): + """Make sure that default updates on shared variables raise errors.""" + + srng = at.random.RandomStream(1290984) + x = srng.gamma(0.5, 0.5, name="x") + + srng = at.random.RandomStream(239) + y = srng.normal(x, name="y") + + y_shared_variable = y.owner.inputs[0] + assert y_shared_variable.default_update is not None + + # We don't want to in-place the RNG updates; that would avoid the + # update logic we're testing + mode = get_mode("FAST_RUN").excluding("inplace") + + ofg = OpFromGraph([x], [y], inline=False, mode=mode) + + x_val = at.as_tensor(0.0, dtype=x.type.dtype) + z = ofg(x_val) + + z_fn = function([], z, mode=mode) + + # The `y`'s mean should no longer be `x` + assert isinstance(z_fn.maker.fgraph.outputs[0].owner.inputs[0], Constant) + + # There should be placeholder default updates in the resulting graph + (shared_variable,) = [ + v for v in z_fn.maker.fgraph.variables if isinstance(v, SharedVariable) + ] + + placeholder_var = shared_variable.default_update + assert placeholder_var.owner.op == update_placeholder + assert placeholder_var.owner.inputs[0] is shared_variable + + # Since this `OpFromGraph` wasn't inlined, we should've removed the + # placeholder update. 
+ assert not z_fn.maker.fgraph.update_mapping + + # The compiled inner-graph should be performing the RNG updates + ig_ofg = z_fn.maker.fgraph.outputs[0].owner.op + # `OpFromGraph`s are cloned + # TODO: It's always worth revisiting whether or not we should be doing + # this (and if/when we can go without it). + # assert ig_ofg is ofg + assert ig_ofg.fgraph.update_mapping == {1: 1} + + srng_exp = at.random.RandomStream(239) + exp_fn = function([], srng_exp.normal(x_val, name="y"), mode=mode) + + (exp_shared_variable,) = [ + v for v in exp_fn.maker.fgraph.variables if isinstance(v, SharedVariable) + ] + assert exp_shared_variable.default_update is not None + + z_res = z_fn() + exp_res = exp_fn() + + assert np.array_equal(z_res, exp_res) + + # Execute again to make sure that the RNG is updated correctly + z_res_next = z_fn() + exp_res_next = exp_fn() + + # Make sure nothing weird is going on here + assert not np.array_equal(exp_res, exp_res_next) + + assert np.array_equal(z_res_next, exp_res_next) + + def test_default_updates_inlined(self): + + srng = at.random.RandomStream(1290984) + x = srng.gamma(0.5, 0.5, name="x") + + srng = at.random.RandomStream(239) + y = srng.normal(x, name="y") + + y_shared_variable = y.owner.inputs[0] + assert y_shared_variable.default_update is not None + + mode = get_mode("FAST_RUN").excluding("inplace") + + ofg = OpFromGraph([x], [y], inline=True, mode=mode) + + x_val = at.as_tensor(0.0, dtype=x.type.dtype) + z = ofg(x_val) + + z_fn = function([], z, mode=mode) + + assert not any( + isinstance(node.op, OpFromGraph) for node in z_fn.maker.fgraph.apply_nodes + ) + + srng_exp = at.random.RandomStream(239) + exp_fn = function([], srng_exp.normal(x_val, name="y"), mode=mode) + + (exp_shared_variable,) = [ + v for v in exp_fn.maker.fgraph.variables if isinstance(v, SharedVariable) + ] + assert exp_shared_variable.default_update is not None + + z_res = z_fn() + exp_res = exp_fn() + + assert np.array_equal(z_res, exp_res) + + # Execute again to make sure that the RNG is updated correctly + z_res_next = z_fn() + exp_res_next = exp_fn() + + # Make sure nothing weird is going on here + assert not np.array_equal(exp_res, exp_res_next) + + assert np.array_equal(z_res_next, exp_res_next) + @config.change_flags(floatX="float64") def test_debugprint(): diff --git a/tests/test_printing.py b/tests/test_printing.py index 7251072304..56d8f362ca 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -510,10 +510,12 @@ def test_debugprint_inner_graph_default_updates(): from aesara.compile.builders import OpFromGraph - igo = OpFromGraph([igo_in_1], [igo_out_1]) + igo = OpFromGraph([r1, r2, igo_in_1], [igo_out_1]) r3 = MyVariable("3") - out = igo(r3) + r4 = MyVariable("4") + r5 = MyVariable("5") + out = igo(r3, r4, r5) s = StringIO() debugprint(out, file=s, print_default_updates=True) @@ -523,20 +525,21 @@ def test_debugprint_inner_graph_default_updates(): r""" OpFromGraph{inline=False} [id A] |3 [id B] - |s [id C] <- [id D] + |4 [id C] + |5 [id D] + |s [id E] <- [id F] Inner graphs: OpFromGraph{inline=False} [id A] - >op2 [id E] 'igo1' - > |*0- [id F] - > |*1- [id G] + >op2 [id G] 'igo1' + > |*2- [id H] + > |*3- [id I] Default updates: - op1 [id D] 'o1' - |1 [id H] - |2 [id I] + UpdatePlaceholder [id F] + |s [id E] <- [id F] """ ).lstrip()
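
For context, a minimal usage sketch of the behaviour this patch enables, distilled from ``test_default_updates`` above (the variable names are illustrative and the explicit input is simplified to a plain scalar):

    import numpy as np

    import aesara
    import aesara.tensor as at
    from aesara.compile.builders import OpFromGraph
    from aesara.compile.mode import get_mode

    # `RandomStream` attaches a `default_update` to its hidden RNG shared
    # variable; with this patch that update is carried through the
    # `OpFromGraph`, so the RNG state advances on every call.
    srng = at.random.RandomStream(239)
    x = at.scalar("x")
    y = srng.normal(x, name="y")

    # As in the tests, avoid in-place RNG rewrites so the update logic
    # introduced here is what drives the state changes.
    mode = get_mode("FAST_RUN").excluding("inplace")

    ofg = OpFromGraph([x], [y], inline=False, mode=mode)
    z = ofg(at.as_tensor(0.0, dtype=x.type.dtype))

    z_fn = aesara.function([], z, mode=mode)

    # Each call should now produce a fresh draw rather than reusing the
    # initial RNG state.
    draws = [z_fn() for _ in range(3)]
    assert not np.allclose(draws[0], draws[1])

Previously, ``OpFromGraph`` did not propagate such default updates; supporting them is what the added machinery in ``construct_nominal_fgraph``, ``OpFromGraph.fn``, and ``inline_ofg_expansion`` addresses.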