Skip to content

Commit

Permalink
Add Constant.align method (#558)
Browse files Browse the repository at this point in the history
* add Constant.align method

* test_align
  • Loading branch information
ordabayevy authored Sep 30, 2021
1 parent 86d4a22 commit ca1557b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/source/funsors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,11 @@ Integrate
:undoc-members:
:show-inheritance:
:member-order: bysource

Constant
--------
.. automodule:: funsor.constant
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
12 changes: 12 additions & 0 deletions funsor/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,18 @@ def eager_reduce(self, op, reduced_vars):
assert reduced_vars.issubset(self.arg.inputs)
return Constant(self.const_inputs, self.arg.reduce(op, reduced_vars))

def align(self, names):
assert isinstance(names, tuple)
assert all(name in self.inputs for name in names)
if not names or names == tuple(self.inputs):
return self

const_names = names[: len(self.const_inputs)]
arg_names = names[len(self.const_inputs) :]
assert frozenset(self.const_inputs) == frozenset(const_names)
const_inputs = OrderedDict((name, self.inputs[name]) for name in const_names)
return Constant(const_inputs, self.arg.align(arg_names))


@eager.register(Reduce, ops.AddOp, Constant, frozenset)
@eager.register(Reduce, ops.MulOp, Constant, frozenset)
Expand Down
12 changes: 12 additions & 0 deletions test/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ def test_add_constant_delta():
assert d2 + c is d2 + c(y=point2)


def test_align():
data = Tensor(randn((2, 3)), OrderedDict(i=Bint[2], j=Bint[3]))
x = Constant(OrderedDict(a=Real, b=Bint[4]), data)
y = x.align(("b", "a", "j", "i"))
assert isinstance(y, Constant)
assert tuple(y.inputs) == ("b", "a", "j", "i")
for b in range(4):
for i in range(2):
for j in range(3):
assert x(a=0, b=b, i=i, j=j) == y(a=0, b=b, i=i, j=j)


@pytest.mark.skipif(
get_backend() != "torch",
reason="numpy/jax backend has not implemented ProvenanceTensor",
Expand Down

0 comments on commit ca1557b

Please sign in to comment.