Commit a6fa083

Enable non-scalar Bint domains by replacing some uses of funsor.Reals with funsor.Array (#417)
eb8680 authored Jan 7, 2021
1 parent 27cc0e3 commit a6fa083
Showing 5 changed files with 70 additions and 21 deletions.
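
For orientation: funsor.domains.Array[dtype, shape] is the general domain constructor this commit standardizes on. Reals[shape] is its dtype="real" special case and Bint[n] its scalar integer case, so replacing Reals with Array lets integer-dtype domains carry a non-trivial shape. A minimal sketch, assuming funsor at this commit, where all four constructors are importable from funsor.domains:

from funsor.domains import Array, Bint, Real, Reals

assert Array["real", (3, 4)] == Reals[3, 4]   # real-valued array domain
assert Array["real", ()] == Real              # scalar real
assert Array[5, ()] == Bint[5]                # scalar bounded integer
nonscalar_bint = Array[5, (2, 3)]             # non-scalar Bint domain: the case
                                              # this commit enables
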
funsor/distribution.py (14 changes: 8 additions & 6 deletions)
@@ -16,7 +16,7 @@
 import funsor.ops as ops
 from funsor.affine import is_affine
 from funsor.cnf import Contraction, GaussianMixture
-from funsor.domains import Array, Real, Reals, RealsType
+from funsor.domains import Array, Real, Reals
 from funsor.gaussian import Gaussian
 from funsor.interpreter import gensym
 from funsor.tensor import (Tensor, align_tensors, dummy_numeric_array, get_default_prototype,
@@ -72,16 +72,18 @@ def __call__(cls, *args, **kwargs):
 
             # broadcast individual param domains with Funsor inputs
             # this avoids .expand-ing underlying parameter tensors
-            if isinstance(v, Funsor) and isinstance(v.output, RealsType):
-                domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)]
+            dtype = domains[k].dtype
+            if isinstance(v, Funsor):
+                domains[k] = Array[dtype, broadcast_shape(v.shape, domains[k].shape)]
             elif ops.is_numeric_array(v):
-                domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)]
+                domains[k] = Array[dtype, broadcast_shape(v.shape, domains[k].shape)]
 
         # now use the broadcasted parameter shapes to infer the event_shape
         domains["value"] = cls._infer_value_domain(**domains)
-        if isinstance(kwargs["value"], Funsor) and isinstance(kwargs["value"].output, RealsType):
+        if isinstance(kwargs["value"], Funsor):
             # try to broadcast the event shape with the value, in case they disagree
-            domains["value"] = Reals[broadcast_shape(domains["value"].shape, kwargs["value"].output.shape)]
+            dtype = domains["value"].dtype
+            domains["value"] = Array[dtype, broadcast_shape(domains["value"].shape, kwargs["value"].output.shape)]
 
         # finally, perform conversions to funsors
         kwargs = OrderedDict((k, to_funsor(v, output=domains[k])) for k, v in kwargs.items())
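
The user-visible effect of the metaclass change: a distribution parameter or value whose domain has an integer dtype now keeps that dtype when its shape is broadcast, instead of being coerced to Reals. A standalone sketch under stated assumptions (broadcast_shape below is a local stand-in for the helper used in distribution.py; only the Array[dtype, shape] indexing comes from the diff):

from funsor.domains import Array

def broadcast_shape(*shapes):
    # minimal NumPy-style shape broadcasting, standing in for funsor's helper
    rank = max(map(len, shapes))
    result = [1] * rank
    for shape in shapes:
        for i, size in enumerate(reversed(shape)):
            j = rank - 1 - i
            if result[j] == 1:
                result[j] = size
            elif size not in (1, result[j]):
                raise ValueError("shapes do not broadcast: {}".format(shapes))
    return tuple(result)

declared = Array[6, (1, 4)]   # a parameter domain with integer dtype 6
value_shape = (5, 4)          # shape of the funsor the caller passed in
dtype = declared.dtype        # preserved; the old code rebuilt the domain as Reals
assert Array[dtype, broadcast_shape(value_shape, declared.shape)] == Array[6, (5, 4)]
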
funsor/domains.py (6 changes: 3 additions & 3 deletions)
@@ -197,13 +197,13 @@ def find_domain(op, *domains):
             shape = op.shape
         elif isinstance(op, ops.AssociativeOp):
             shape = ()
-        return Reals[shape] if dtype == "real" else Bint[dtype]
+        return Array[dtype, shape]
 
     lhs, rhs = domains
     if isinstance(op, ops.GetitemOp):
         dtype = lhs.dtype
         shape = lhs.shape[:op.offset] + lhs.shape[1 + op.offset:]
-        return Reals[shape] if dtype == "real" else Bint[dtype]
+        return Array[dtype, shape]
     elif op == ops.matmul:
         assert lhs.shape and rhs.shape
         if len(rhs.shape) == 1:
@@ -232,7 +232,7 @@ def find_domain(op, *domains):
         shape = lhs.shape
     else:
         shape = broadcast_shape(lhs.shape, rhs.shape)
-    return Reals[shape] if dtype == "real" else Bint[dtype]
+    return Array[dtype, shape]
 
 
 __all__ = [
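
A sketch of what the unified returns in find_domain buy, using ops.getitem (the offset-0 GetitemOp); per the hunk above, the result takes its dtype from lhs and drops the indexed dim, so an integer dtype with a non-scalar result shape is now representable:

import funsor.ops as ops
from funsor.domains import Array, Bint, find_domain

# indexing an Array[5, (3, 4)] along its first dim keeps dtype 5
d = find_domain(ops.getitem, Array[5, (3, 4)], Bint[3])
assert d == Array[5, (4,)]
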
funsor/terms.py (12 changes: 6 additions & 6 deletions)
@@ -17,7 +17,7 @@
 
 import funsor.interpreter as interpreter
 import funsor.ops as ops
-from funsor.domains import Bint, Domain, Real, Reals, find_domain
+from funsor.domains import Array, Bint, Domain, Real, find_domain
 from funsor.interpreter import PatternMissingError, dispatched_interpretation, interpret
 from funsor.ops import AssociativeOp, GetitemOp, Op
 from funsor.util import getargspec, lazy_property, pretty, quote
@@ -1086,7 +1086,7 @@ def __init__(self, data, dtype=None):
             assert isinstance(dtype, str) and dtype == "real"
             data = float(data)
         inputs = OrderedDict()
-        output = Real if dtype == "real" else Bint[dtype]
+        output = Array[dtype, ()]
         super(Number, self).__init__(inputs, output)
         self.data = data
 
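
With the Number change above, scalar outputs are built through a single constructor. A quick sketch (assuming, as the replaced branch implies, that Array[n, ()] and Bint[n] denote the same domain):

from funsor.domains import Array, Bint, Real
from funsor.terms import Number

assert Number(1.5).output == Real            # dtype defaults to "real"
assert Number(1, 3).output == Array[3, ()]   # integer dtypes take the same path
assert Array[3, ()] == Bint[3]               # scalar integer Array is a Bint
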
@@ -1450,7 +1450,7 @@ def __init__(self, var, expr):
         inputs = expr.inputs.copy()
         inputs.pop(var.name, None)
         shape = (var.dtype,) + expr.output.shape
-        output = Reals[shape] if expr.dtype == "real" else Bint[expr.size]
+        output = Array[expr.dtype, shape]
         fresh = frozenset()
         bound = frozenset({var.name})
         super(Lambda, self).__init__(inputs, output, fresh, bound)
@@ -1502,11 +1502,11 @@ def __init__(self, fn, reals_var, bint_var, diag_var):
         assert isinstance(fn.inputs[bint_var].dtype, int)
         assert isinstance(diag_var, str)
         assert diag_var in fn.inputs
-        assert fn.inputs[diag_var].dtype == 'real'
         inputs = fn.inputs.copy()
-        shape = (inputs.pop(bint_var).dtype,) + inputs.pop(diag_var).shape
+        diag_input = inputs.pop(diag_var)
+        shape = (inputs.pop(bint_var).dtype,) + diag_input.shape
         assert reals_var not in inputs
-        inputs[reals_var] = Reals[shape]
+        inputs[reals_var] = Array[diag_input.dtype, shape]
         fresh = frozenset({reals_var})
         bound = frozenset({bint_var, diag_var})
         super(Independent, self).__init__(inputs, fn.output, fresh, bound)
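
The Lambda and Independent edits above follow the same pattern: rebuild the derived domain from the body's dtype instead of hardcoding "real". A small sketch of the Lambda case, mirroring the new test_stack_lambda in test/test_terms.py below (Independent is covered by the parametrized test_independent there):

from funsor.domains import Array, Bint
from funsor.terms import Lambda, Number, Stack, Variable

i = Variable("i", Bint[2])
body = Stack("i", (Number(0, 3), Number(1, 3)))   # scalar body with dtype 3
fn = Lambda(i, body)
assert fn.output == Array[3, (2,)]                # shaped Bint-dtype output
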
test/test_distribution.py (31 changes: 30 additions & 1 deletion)
@@ -14,7 +14,7 @@
 from funsor.cnf import Contraction, GaussianMixture
 from funsor.delta import Delta
 from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND
-from funsor.domains import Bint, Real, Reals
+from funsor.domains import Array, Bint, Real, Reals
 from funsor.integrate import Integrate
 from funsor.interpreter import interpretation, reinterpret
 from funsor.tensor import Einsum, Tensor, numeric_array, stack
@@ -1186,3 +1186,32 @@ def test_mvnormal_event_dim_conversion(batch_shape, event_shape):
         funsor.to_data(data, name_to_dim=name_to_dim))
     assert actual_log_prob.shape == expected_log_prob.shape
     assert_close(actual_log_prob, expected_log_prob)
+
+
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
+@pytest.mark.parametrize('event_shape', [(), (4,), (4, 7), (1, 4), (4, 1), (4, 1, 7)], ids=str)
+def test_categorical_event_dim_conversion(batch_shape, event_shape):
+
+    dtype = 6
+    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
+    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))
+
+    value = Variable("value", Array[dtype, event_shape])
+    probs = Tensor(rand(batch_shape + event_shape + (6,)), inputs)
+
+    with interpretation(lazy):
+        actual = dist.Categorical(probs=probs, value=value)
+
+    expected_inputs = inputs.copy()
+    expected_inputs.update({"value": Array[dtype, event_shape]})
+    check_funsor(actual, expected_inputs, Real)
+
+    name_to_dim = {batch_dim: -1-i for i, batch_dim in enumerate(batch_dims)}
+    rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
+    data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0].terms[0][1][0]
+
+    actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim)
+    expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob(
+        funsor.to_data(data, name_to_dim=name_to_dim))
+    assert actual_log_prob.shape == expected_log_prob.shape
+    assert_close(actual_log_prob, expected_log_prob)
test/test_terms.py (28 changes: 23 additions & 5 deletions)
@@ -404,16 +404,17 @@ def test_lambda(base_shape):
     check_funsor(zij[:, i], zj.inputs, zj.output)
 
 
-def test_independent():
-    f = Variable('x_i', Reals[4, 5]) + random_tensor(OrderedDict(i=Bint[3]))
-    assert f.inputs['x_i'] == Reals[4, 5]
+@pytest.mark.parametrize("dtype", ["real", 2, 3])
+def test_independent(dtype):
+    f = Variable('x_i', Array[dtype, (4, 5)]) + random_tensor(OrderedDict(i=Bint[3]), output=Array[dtype, ()])
+    assert f.inputs['x_i'] == Array[dtype, (4, 5)]
     assert f.inputs['i'] == Bint[3]
 
     actual = Independent(f, 'x', 'i', 'x_i')
-    assert actual.inputs['x'] == Reals[3, 4, 5]
+    assert actual.inputs['x'] == Array[dtype, (3, 4, 5)]
     assert 'i' not in actual.inputs
 
-    x = Variable('x', Reals[3, 4, 5])
+    x = Variable('x', Array[dtype, (3, 4, 5)])
     expected = f(x_i=x['i']).reduce(ops.add, 'i')
     assert actual.inputs == expected.inputs
     assert actual.output == expected.output
@@ -581,3 +582,20 @@ def test_cat_slice_tensor(start, stop, step):
     actual = reinterpret(actual)
 
     assert_close(actual, expected)
+
+
+@pytest.mark.parametrize("dtype", ["real", 2, 3])
+def test_stack_lambda(dtype):
+
+    x1 = Number(0, dtype)
+    x2 = Number(1, dtype)
+
+    y = Stack("i", (x1, x2))
+
+    z = Lambda(Variable("i", Bint[2]), y)
+
+    assert y.shape == ()
+    assert z.output == Array[dtype, (2,)]
+
+    assert z[0] is x1
+    assert z[1] is x2
