Add support for default updates in OpFromGraph
brandonwillard committed Feb 19, 2023
1 parent 48409dd commit d0f42fe
Showing 4 changed files with 291 additions and 57 deletions.
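
For context, here is a minimal sketch of the behavior this commit enables (illustrative only; the variable names are invented and not taken from the diff below): a shared variable carrying a `default_update` can now appear in an `OpFromGraph`'s inner graph, with the update carried along as an extra inner output.

```python
import aesara
import aesara.tensor as at
from aesara.compile.builders import OpFromGraph

x = at.scalar("x")
count = aesara.shared(0, name="count")
# The default update below is appended to the inner FunctionGraph's
# outputs and tracked through its `update_mapping`.
count.default_update = count + 1

ofg = OpFromGraph([x], [x + count])
fn = aesara.function([x], ofg(x))
```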
199 changes: 155 additions & 44 deletions aesara/compile/builders.py
@@ -2,12 +2,13 @@
from collections import OrderedDict
from copy import copy
from functools import partial
from typing import Dict, List, Optional, Sequence, Tuple, cast
from typing import List, Optional, Sequence, Tuple, cast

import aesara.tensor as at
from aesara import function
from aesara.compile.function.pfunc import rebuild_collect_shared
from aesara.compile.io import In, Out
from aesara.compile.mode import optdb
from aesara.compile.ops import update_placeholder
from aesara.compile.sharedvalue import SharedVariable
from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, Rop, grad
@@ -83,13 +84,26 @@ def local_traverse(out):

def construct_nominal_fgraph(
inputs: Sequence[Variable], outputs: Sequence[Variable]
) -> Tuple[
FunctionGraph,
Sequence[Variable],
Dict[Variable, Variable],
Dict[Variable, Variable],
]:
"""Construct an inner-`FunctionGraph` with ordered nominal inputs."""
) -> Tuple[FunctionGraph, Sequence[Variable],]:
r"""Construct an inner-`FunctionGraph` with ordered nominal inputs.
.. note::
Updates (e.g. from `SharedVariable.default_update`) are appended to the resulting
`FunctionGraph`'s outputs.
Parameters
==========
inputs
A list of inputs.
outputs
A list of outputs.
Returns
=======
The `FunctionGraph` and a list of shared inputs.
"""
dummy_inputs = []
for n, inp in enumerate(inputs):
if (
@@ -105,6 +119,7 @@

dummy_shared_inputs = []
shared_inputs = []
default_updates = {}
for var in graph_inputs(outputs, inputs):
if isinstance(var, SharedVariable):
# To correctly support shared variables the inner-graph should
@@ -113,14 +128,18 @@
# That's why we collect the shared variables and replace them
# with dummies.
shared_inputs.append(var)
dummy_shared_inputs.append(var.type())
dummy_var = var.type()
dummy_shared_inputs.append(dummy_var)

if var.default_update is not None:
default_updates[dummy_var] = var.default_update
elif var not in inputs and not isinstance(var, Constant):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")

replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))

new = rebuild_collect_shared(
cast(Sequence[Variable], outputs),
outputs=cast(Sequence[Variable], outputs + list(default_updates.values())),
inputs=inputs + shared_inputs,
replace=replacements,
copy_inputs_over=False,
@@ -131,13 +150,23 @@
(clone_d, update_d, update_expr, new_shared_inputs),
) = new

local_default_updates = local_outputs[len(outputs) :]
update_d.update(
{clone_d[k]: v for k, v in zip(default_updates.keys(), local_default_updates)}
)
update_expr.extend(local_default_updates)

assert len(local_inputs) == len(inputs) + len(shared_inputs)
assert len(local_outputs) == len(outputs)
assert not update_d
assert not update_expr
assert len(local_outputs) == len(outputs) + len(default_updates)
assert not new_shared_inputs

fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
update_mapping = {
local_outputs.index(v): local_inputs.index(k) for k, v in update_d.items()
}

fgraph = FunctionGraph(
local_inputs, local_outputs, clone=False, update_mapping=update_mapping
)

# The inputs need to be `NominalVariable`s so that we can merge
# inner-graphs
@@ -153,7 +182,7 @@
fgraph.clients.pop(inp, None)
fgraph.add_input(nom_inp)

return fgraph, shared_inputs, update_d, update_expr
return fgraph, shared_inputs
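
Put together, the revised contract of `construct_nominal_fgraph` looks like this (a sketch; `inputs` and `outputs` stand for the hypothetical argument lists described in the docstring):

```python
fgraph, shared_inputs = construct_nominal_fgraph(inputs, outputs)

# Default updates now live in the graph itself: they are appended to
# `fgraph.outputs`, and `fgraph.update_mapping` maps each update's
# output index to the index of the input it updates.
n_updates = len(fgraph.update_mapping)
user_outputs = fgraph.outputs[: len(fgraph.outputs) - n_updates]
```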


class OpFromGraph(Op, HasInnerGraph):
@@ -316,37 +345,44 @@ def __init__(
name: Optional[str] = None,
**kwargs,
):
"""
r"""Construct an `OpFromGraph` instance.
.. note::
`SharedVariable`\s in `outputs` will have their `SharedVariable.default_update` values
altered in order to support in-lining in the presence of updates.
Parameters
----------
inputs
The inputs to the graph.
outputs
The outputs to the graph.
inline
``True`` : Causes the :class:`Op`'s original graph to be used during
compilation; the :class:`Op` itself will not be visible in the compiled
graph, only its internal graph.
``False`` : Use a pre-compiled function inside.
Defaults to ``False``.
grad_overrides
Defaults to ``'default'``.
This argument is mutually exclusive with ``lop_overrides``.
``'default'`` : Do not override, use default grad() result
``'default'`` : Do not override, use default :meth:`Op.grad` result
`OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs in the same order and types as the ``inputs`` and ``output_grads``
arguments as one would specify in :meth:`Op.grad`() method.
arguments as one would specify in :meth:`Op.grad` method.
`callable`: Should take two args: ``inputs`` and ``output_grads``.
Each argument is expected to be a list of :class:`Variable`.
Must return a list of :class:`Variable`.
Defaults to ``'default'``.
lop_overrides
This argument is mutually exclusive with ``grad_overrides``.
These options are similar to the ``grad_overrides`` above, but for
@@ -355,7 +391,7 @@ def __init__(
``'default'``: Do not override, use the default :meth:`Op.L_op` result
`OpFromGraph`: Override with another `OpFromGraph`, should
accept inputs as the same order and types of ``inputs``,
accept inputs in the same order and types as `inputs`,
``outputs`` and ``output_grads`` arguments as one would specify in
:meth:`Op.grad` method.
@@ -371,11 +407,11 @@ def __init__(
:class:`Variable`. Each list element corresponds to the gradient of
a specific input, and the length of the list must equal the number of inputs.
Defaults to ``'default'``.
rop_overrides
One of ``{'default', OpFromGraph, callable, Variable}``.
Defaults to ``'default'``.
``'default'``: Do not override, use the default :meth:`Op.R_op` result
`OpFromGraph`: Override with another `OpFromGraph`, should
@@ -397,11 +433,13 @@ def __init__(
must be equal to number of outputs.
Defaults to ``'default'``.
connection_pattern
If not ``None``, this will be used as the ``connection_pattern`` for this
:class:`Op`.
name
A name for debugging purposes.
kwargs
Check :func:`aesara.function` for more arguments, only works when not
inline.
See :func:`aesara.function`.
"""

if not (isinstance(inputs, list) and isinstance(outputs, list)):
@@ -418,9 +456,24 @@ def __init__(

self.is_inline = inline

self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
inputs, outputs
)
# These `shared_inputs` are the original variables in `outputs`
# (i.e. not clones).
self.fgraph, shared_inputs = construct_nominal_fgraph(inputs, outputs)

# We need to hold on to the original variables so that gradients can be
# taken wrt. them. Ideally, we wouldn't hold on to specific `Variable`
# references like this outside of graph, but we're maintaining support
# for old functionality right now.
self.shared_inputs = []
for v in shared_inputs:
# This is needed so that `aesara.function` will create an update
# output placeholder in the `FunctionGraph` it compiles. We need
# placeholders like this in order to properly inline `OpFromGraph`s
# containing updates.
# FYI: When the corresponding updates aren't used, they should be
# removed at the `aesara.function` level.
v.default_update = update_placeholder(v)
self.shared_inputs.append(v)

self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
@@ -933,7 +986,36 @@ def fn(self):
if getattr(self, "_fn", None) is not None:
return self._fn

self._fn = function(self.inner_inputs, self.inner_outputs, **self.kwargs)
from aesara.compile.function.pfunc import pfunc

# We don't want calls/evaluations of this `Op` to change
# the inner-graph, so we need to clone it
fgraph, _ = self.fgraph.clone_get_equiv(copy_inputs=False, copy_orphans=False)

wrapped_inputs = [In(x, borrow=False) for x in fgraph.inputs]
wrapped_outputs = [Out(x, borrow=True) for x in fgraph.outputs]

n_inputs = len(fgraph.inputs)

for out_idx, in_idx in fgraph.update_mapping.items():
shared_input = self.shared_inputs[in_idx - n_inputs]
in_var = fgraph.inputs[in_idx]
updated_wrapped_input = In(
variable=in_var,
value=shared_input.container,
update=fgraph.outputs[out_idx],
implicit=True,
shared=True,
)
wrapped_inputs[in_idx] = updated_wrapped_input

self._fn = pfunc(
wrapped_inputs,
wrapped_outputs,
fgraph=fgraph,
no_default_updates=True,
**self.kwargs,
)
self._fn.trust_input = True

return self._fn
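
The `In(..., update=...)` wiring above reproduces, inside the precompiled inner function, what the user-level ``updates`` argument to `aesara.function` does. A general sketch of that mechanism (independent of this diff):

```python
import aesara

s = aesara.shared(0.0, name="s")
# `aesara.function` turns the `updates` entry into an implicit `In`
# whose update output is written back to `s`'s container after each
# call -- the same thing the loop above sets up by hand.
f = aesara.function([], s * 2.0, updates={s: s + 1.0})
f()  # returns array(0.0); `s` now holds 1.0
f()  # returns array(2.0); `s` now holds 2.0
```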
@@ -944,6 +1026,11 @@ def inner_inputs(self):

@property
def inner_outputs(self):
"""Return all the outputs except those used for updates."""
n_updates = len(self.fgraph.update_mapping)
if n_updates > 0:
return self.fgraph.outputs[:-n_updates]

return self.fgraph.outputs

def clone(self):
@@ -952,28 +1039,52 @@ def clone(self):
return res

def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
assert len(variables) == len(outputs)
for output, variable in zip(outputs, variables):
output[0] = variable
results = self.fn(*inputs)
for output, res in zip(outputs, results):
output[0] = res


@node_rewriter([OpFromGraph])
def inline_ofg_expansion(fgraph, node):
"""
This optimization expands internal graph of OpFromGraph.
Only performed if node.op.is_inline == True
Doing so can improve optimization at the cost of compilation speed.
"""Expand the internal graph of an `OpFromGraph`.
Only performed if ``node.op.is_inline == True``.
"""
op = node.op
if not isinstance(op, OpFromGraph):
return False

if not op.is_inline:
return False
return clone_replace(
op.inner_outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)}

outputs = clone_replace(
op.fgraph.outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)}
)

replacements = {
old_var: new_var
for old_var, new_var in zip(node.outputs, outputs[: len(op.inner_outputs)])
}

# Add the updates from `OpFromGraph` into the outer-graph
for out_idx, in_idx in op.fgraph.update_mapping.items():
shared_input = node.inputs[in_idx]
assert isinstance(shared_input, SharedVariable)

outer_in_idx = fgraph.inputs.index(shared_input)

# There should be a placeholder output in `fgraph.outputs` that we can
# use. If there isn't, then someone forgot/removed the
# `SharedVariable.default_update`s on the inputs to the `OpFromGraph`
# (i.e. at the user-level/graph construction-time).
outer_out_idx = fgraph.inv_update_mapping[outer_in_idx]
update_var = fgraph.outputs[outer_out_idx]

assert update_var is not shared_input

replacements[update_var] = outputs[out_idx]

return replacements
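
For reference, this rewrite only fires for instances constructed with ``inline=True`` (a hypothetical sketch):

```python
import aesara.tensor as at
from aesara.compile.builders import OpFromGraph

x = at.vector("x")
ofg = OpFromGraph([x], [at.exp(x) + 1], inline=True)
y = ofg(x)
# During compilation, `inline_ofg_expansion` replaces the `OpFromGraph`
# node with its inner graph, so the compiled function computes
# `exp(x) + 1` directly; any update placeholders in the outer graph are
# remapped to the corresponding inner update outputs.
```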


# We want to run this before the first merge optimizer
# and before the first scan optimizer.
4 changes: 3 additions & 1 deletion aesara/scan/op.py
@@ -750,7 +750,9 @@ def __init__(
If ``True``, all the shared variables used in the inner-graph must be provided.
"""
self.fgraph, shared_inputs, _, _ = construct_nominal_fgraph(inputs, outputs)
self.fgraph, shared_inputs = construct_nominal_fgraph(inputs, outputs)

assert not self.fgraph.update_mapping

# The shared variables should have been removed, so, if there are
# any, it's because the user didn't specify an input.