diff --git a/funsor/distribution.py b/funsor/distribution.py
index 90633d8ab..d4edeb05b 100644
--- a/funsor/distribution.py
+++ b/funsor/distribution.py
@@ -16,7 +16,7 @@
 import funsor.ops as ops
 from funsor.affine import is_affine
 from funsor.cnf import Contraction, GaussianMixture
-from funsor.domains import Array, Real, Reals, RealsType
+from funsor.domains import Array, Real, Reals
 from funsor.gaussian import Gaussian
 from funsor.interpreter import gensym
 from funsor.tensor import (Tensor, align_tensors, dummy_numeric_array, get_default_prototype,
@@ -72,16 +72,18 @@ def __call__(cls, *args, **kwargs):
 
             # broadcast individual param domains with Funsor inputs
             # this avoids .expand-ing underlying parameter tensors
-            if isinstance(v, Funsor) and isinstance(v.output, RealsType):
-                domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)]
+            dtype = domains[k].dtype
+            if isinstance(v, Funsor):
+                domains[k] = Array[dtype, broadcast_shape(v.shape, domains[k].shape)]
             elif ops.is_numeric_array(v):
-                domains[k] = Reals[broadcast_shape(v.shape, domains[k].shape)]
+                domains[k] = Array[dtype, broadcast_shape(v.shape, domains[k].shape)]
 
         # now use the broadcasted parameter shapes to infer the event_shape
         domains["value"] = cls._infer_value_domain(**domains)
-        if isinstance(kwargs["value"], Funsor) and isinstance(kwargs["value"].output, RealsType):
+        if isinstance(kwargs["value"], Funsor):
             # try to broadcast the event shape with the value, in case they disagree
-            domains["value"] = Reals[broadcast_shape(domains["value"].shape, kwargs["value"].output.shape)]
+            dtype = domains["value"].dtype
+            domains["value"] = Array[dtype, broadcast_shape(domains["value"].shape, kwargs["value"].output.shape)]
 
         # finally, perform conversions to funsors
         kwargs = OrderedDict((k, to_funsor(v, output=domains[k])) for k, v in kwargs.items())
diff --git a/funsor/domains.py b/funsor/domains.py
index 084f89c35..ce7b254f1 100644
--- a/funsor/domains.py
+++ b/funsor/domains.py
@@ -197,13 +197,13 @@ def find_domain(op, *domains):
             shape = op.shape
         elif isinstance(op, ops.AssociativeOp):
             shape = ()
-        return Reals[shape] if dtype == "real" else Bint[dtype]
+        return Array[dtype, shape]
 
     lhs, rhs = domains
     if isinstance(op, ops.GetitemOp):
         dtype = lhs.dtype
         shape = lhs.shape[:op.offset] + lhs.shape[1 + op.offset:]
-        return Reals[shape] if dtype == "real" else Bint[dtype]
+        return Array[dtype, shape]
     elif op == ops.matmul:
         assert lhs.shape and rhs.shape
         if len(rhs.shape) == 1:
@@ -232,7 +232,7 @@ def find_domain(op, *domains):
         shape = lhs.shape
     else:
         shape = broadcast_shape(lhs.shape, rhs.shape)
-    return Reals[shape] if dtype == "real" else Bint[dtype]
+    return Array[dtype, shape]
 
 
 __all__ = [
diff --git a/funsor/terms.py b/funsor/terms.py
index 34501c58d..db31a209a 100644
--- a/funsor/terms.py
+++ b/funsor/terms.py
@@ -17,7 +17,7 @@
 
 import funsor.interpreter as interpreter
 import funsor.ops as ops
-from funsor.domains import Bint, Domain, Real, Reals, find_domain
+from funsor.domains import Array, Bint, Domain, Real, find_domain
 from funsor.interpreter import PatternMissingError, dispatched_interpretation, interpret
 from funsor.ops import AssociativeOp, GetitemOp, Op
 from funsor.util import getargspec, lazy_property, pretty, quote
@@ -1086,7 +1086,7 @@ def __init__(self, data, dtype=None):
             assert isinstance(dtype, str) and dtype == "real"
             data = float(data)
         inputs = OrderedDict()
-        output = Real if dtype == "real" else Bint[dtype]
+        output = Array[dtype, ()]
         super(Number, self).__init__(inputs, output)
         self.data = data
 
@@ -1450,7 +1450,7 @@ def __init__(self, var, expr):
         inputs = expr.inputs.copy()
         inputs.pop(var.name, None)
         shape = (var.dtype,) + expr.output.shape
-        output = Reals[shape] if expr.dtype == "real" else Bint[expr.size]
+        output = Array[expr.dtype, shape]
         fresh = frozenset()
         bound = frozenset({var.name})
         super(Lambda, self).__init__(inputs, output, fresh, bound)
@@ -1502,11 +1502,11 @@ def __init__(self, fn, reals_var, bint_var, diag_var):
         assert isinstance(fn.inputs[bint_var].dtype, int)
         assert isinstance(diag_var, str)
         assert diag_var in fn.inputs
-        assert fn.inputs[diag_var].dtype == 'real'
         inputs = fn.inputs.copy()
-        shape = (inputs.pop(bint_var).dtype,) + inputs.pop(diag_var).shape
+        diag_input = inputs.pop(diag_var)
+        shape = (inputs.pop(bint_var).dtype,) + diag_input.shape
         assert reals_var not in inputs
-        inputs[reals_var] = Reals[shape]
+        inputs[reals_var] = Array[diag_input.dtype, shape]
         fresh = frozenset({reals_var})
         bound = frozenset({bint_var, diag_var})
         super(Independent, self).__init__(inputs, fn.output, fresh, bound)
diff --git a/test/test_distribution.py b/test/test_distribution.py
index d72e83f4b..f79202b0b 100644
--- a/test/test_distribution.py
+++ b/test/test_distribution.py
@@ -14,7 +14,7 @@
 from funsor.cnf import Contraction, GaussianMixture
 from funsor.delta import Delta
 from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND
-from funsor.domains import Bint, Real, Reals
+from funsor.domains import Array, Bint, Real, Reals
 from funsor.integrate import Integrate
 from funsor.interpreter import interpretation, reinterpret
 from funsor.tensor import Einsum, Tensor, numeric_array, stack
@@ -1186,3 +1186,32 @@ def test_mvnormal_event_dim_conversion(batch_shape, event_shape):
         funsor.to_data(data, name_to_dim=name_to_dim))
     assert actual_log_prob.shape == expected_log_prob.shape
     assert_close(actual_log_prob, expected_log_prob)
+
+
+@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
+@pytest.mark.parametrize('event_shape', [(), (4,), (4, 7), (1, 4), (4, 1), (4, 1, 7)], ids=str)
+def test_categorical_event_dim_conversion(batch_shape, event_shape):
+
+    dtype = 6
+    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
+    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))
+
+    value = Variable("value", Array[dtype, event_shape])
+    probs = Tensor(rand(batch_shape + event_shape + (6,)), inputs)
+
+    with interpretation(lazy):
+        actual = dist.Categorical(probs=probs, value=value)
+
+    expected_inputs = inputs.copy()
+    expected_inputs.update({"value": Array[dtype, event_shape]})
+    check_funsor(actual, expected_inputs, Real)
+
+    name_to_dim = {batch_dim: -1-i for i, batch_dim in enumerate(batch_dims)}
+    rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
+    data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0].terms[0][1][0]
+
+    actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim)
+    expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob(
+        funsor.to_data(data, name_to_dim=name_to_dim))
+    assert actual_log_prob.shape == expected_log_prob.shape
+    assert_close(actual_log_prob, expected_log_prob)
diff --git a/test/test_terms.py b/test/test_terms.py
index da3d25569..6e5167e95 100644
--- a/test/test_terms.py
+++ b/test/test_terms.py
@@ -404,16 +404,17 @@ def test_lambda(base_shape):
     check_funsor(zij[:, i], zj.inputs, zj.output)
 
 
-def test_independent():
-    f = Variable('x_i', Reals[4, 5]) + random_tensor(OrderedDict(i=Bint[3]))
-    assert f.inputs['x_i'] == Reals[4, 5]
+@pytest.mark.parametrize("dtype", ["real", 2, 3])
+def test_independent(dtype):
+    f = Variable('x_i', Array[dtype, (4, 5)]) + random_tensor(OrderedDict(i=Bint[3]), output=Array[dtype, ()])
+    assert f.inputs['x_i'] == Array[dtype, (4, 5)]
     assert f.inputs['i'] == Bint[3]
 
     actual = Independent(f, 'x', 'i', 'x_i')
-    assert actual.inputs['x'] == Reals[3, 4, 5]
+    assert actual.inputs['x'] == Array[dtype, (3, 4, 5)]
     assert 'i' not in actual.inputs
 
-    x = Variable('x', Reals[3, 4, 5])
+    x = Variable('x', Array[dtype, (3, 4, 5)])
     expected = f(x_i=x['i']).reduce(ops.add, 'i')
     assert actual.inputs == expected.inputs
     assert actual.output == expected.output
@@ -581,3 +582,20 @@ def test_cat_slice_tensor(start, stop, step):
         actual = reinterpret(actual)
 
     assert_close(actual, expected)
+
+
+@pytest.mark.parametrize("dtype", ["real", 2, 3])
+def test_stack_lambda(dtype):
+
+    x1 = Number(0, dtype)
+    x2 = Number(1, dtype)
+
+    y = Stack("i", (x1, x2))
+
+    z = Lambda(Variable("i", Bint[2]), y)
+
+    assert y.shape == ()
+    assert z.output == Array[dtype, (2,)]
+
+    assert z[0] is x1
+    assert z[1] is x2