Skip to content

Commit

Permalink
Materialize diagonal name substitutions in Tensor.eager_subs (#382)
Browse files Browse the repository at this point in the history
* Materialize diagonal name substitutions in Tensor.eager_subs

* nit
  • Loading branch information
eb8680 authored Oct 13, 2020
1 parent 330a07d commit 97c1a43
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
7 changes: 6 additions & 1 deletion funsor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import itertools
import typing
import warnings
from collections import OrderedDict
from collections import Counter, OrderedDict
from contextlib import contextmanager
from functools import reduce

Expand Down Expand Up @@ -187,6 +187,11 @@ def eager_subs(self, subs):
if not subs:
return self

# Handle diagonal variable substitution
var_counts = Counter(v for v in subs.values() if isinstance(v, Variable))
subs = OrderedDict((k, self.materialize(v) if var_counts[v] > 1 else v)
for k, v in subs.items())

# Handle renaming to enable cons hashing, and
# handle slicing to avoid copying data.
if any(isinstance(v, (Variable, Slice)) for v in subs.values()):
Expand Down
9 changes: 9 additions & 0 deletions test/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,3 +1047,12 @@ def test_detach():
assert_close(x, y)
assert x.data.requires_grad
assert not y.data.requires_grad


def test_diagonal_rename():
x = Tensor(randn(2, 2, 3), OrderedDict(a=funsor.Bint[2], b=funsor.Bint[2], c=funsor.Bint[3]), 'real')
d = Variable("d", funsor.Bint[2])
dt = x.materialize(d)
yt = x(a=dt, b=dt)
y = x(a=d, b=d)
assert_close(y, yt)

0 comments on commit 97c1a43

Please sign in to comment.