Skip to content

Commit 661ed03

Browse files
author
Sam Anklesaria
authored
Add Dirichlet Multinomial conjugacy (#345)
1 parent c739031 commit 661ed03

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

funsor/torch/distributions.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import functools
5+
from typing import Tuple, Union
56

67
import pyro.distributions as dist
78
import pyro.distributions.testing.fakes as fakes
89
from pyro.distributions.torch_distribution import MaskedDistribution
910
import torch
1011

12+
from funsor.cnf import Contraction
1113
from funsor.distribution import ( # noqa: F401
1214
Bernoulli,
1315
FUNSOR_DIST_NAMES,
@@ -31,8 +33,9 @@
3133
transformeddist_to_funsor,
3234
)
3335
from funsor.domains import reals
36+
import funsor.ops as ops
3437
from funsor.tensor import Tensor, dummy_numeric_array
35-
from funsor.terms import Funsor, Variable, eager, to_funsor
38+
from funsor.terms import Binary, Funsor, Variable, eager, to_funsor
3639

3740

3841
__all__ = list(x[0] for x in FUNSOR_DIST_NAMES)
@@ -137,3 +140,32 @@ def bernoulli_to_funsor(pyro_dist, output=None, dim_to_name=None):
137140
eager.register(Delta, Variable, Variable, Variable)(eager_delta_variable_variable) # noqa: F821
138141
eager.register(Normal, Funsor, Tensor, Funsor)(eager_normal) # noqa: F821
139142
eager.register(MultivariateNormal, Funsor, Tensor, Funsor)(eager_mvn) # noqa: F821
143+
144+
145+
@eager.register(Contraction, ops.LogAddExpOp, ops.AddOp, frozenset, Dirichlet, Multinomial)  # noqa: F821
def eager_dirichlet_multinomial(red_op, bin_op, reduced_vars, x, y):
    """Eager rule for a logaddexp-reduced sum of a Dirichlet and a Multinomial.

    :param red_op: the reduction op of the contraction (registered for ``ops.LogAddExpOp``).
    :param bin_op: the binary op joining the terms (registered for ``ops.AddOp``).
    :param frozenset reduced_vars: names of the variables being reduced.
    :param x: the Dirichlet term.
    :param y: the Multinomial term.
    :return: a ``DirichletMultinomial`` built from ``x.concentration``,
        ``y.total_count`` and ``y.value`` when at least one input of ``x`` is
        being reduced; otherwise the result of re-dispatching the contraction
        with the terms packed into a tuple.
    """
    # None of the Dirichlet's inputs are reduced: defer to the tuple-of-terms rule.
    if frozenset(x.inputs).isdisjoint(reduced_vars):
        return eager(Contraction, red_op, bin_op, reduced_vars, (x, y))

    return DirichletMultinomial(  # noqa: F821
        concentration=x.concentration,
        total_count=y.total_count,
        value=y.value,
    )
154+
155+
156+
# Type pattern for a joint (Dirichlet + Multinomial) contraction: the terms are
# combined with ops.AddOp and the reduction is either logaddexp or a null op.
JointDirichletMultinomial = Contraction[
    Union[ops.LogAddExpOp, ops.NullOp], ops.AddOp, frozenset, Tuple[Dirichlet, Multinomial],  # noqa: F821
]
162+
163+
164+
@eager.register(Binary, ops.SubOp, JointDirichletMultinomial, DirichletMultinomial)  # noqa: F821
def eager_dirichlet_posterior(op, c, z):
    """Eager rule for ``joint - marginal`` with a Dirichlet/Multinomial joint.

    :param op: the binary op (registered for ``ops.SubOp``).
    :param c: the joint contraction whose terms are ``(Dirichlet, Multinomial)``.
    :param z: the ``DirichletMultinomial`` being subtracted.
    :return: a ``Dirichlet`` whose concentration is ``z.concentration`` plus the
        Multinomial term's ``value`` and whose value is the Dirichlet term's
        ``value``, provided ``z`` shares (by identity) its concentration with
        the Dirichlet term and its ``total_count`` with the Multinomial term;
        ``None`` otherwise.
    """
    dirichlet_term = c.terms[0]
    if z.concentration is not dirichlet_term.concentration:
        # NOTE(review): None presumably signals "no eager rule applies" -- confirm.
        return None

    multinomial_term = c.terms[1]
    if multinomial_term.total_count is not z.total_count:
        return None

    posterior_concentration = z.concentration + multinomial_term.value
    return Dirichlet(  # noqa: F821
        concentration=posterior_concentration,
        value=dirichlet_term.value,
    )

test/test_distribution.py

+27
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,33 @@ def dirichlet_multinomial(concentration, total_count, value):
268268
assert_close(actual, expected)
269269

270270

271+
@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@pytest.mark.parametrize('event_shape', [(2,), (4,), (5,)], ids=str)
@pytest.mark.xfail(get_backend() != 'torch', reason="DirichletMultinomial is not implemented yet in NumPyro")
def test_dirichlet_multinomial_conjugate(batch_shape, event_shape):
    """Check Dirichlet/Multinomial conjugacy rules on a joint ``Dirichlet + Multinomial``.

    Builds the joint density ``p`` with a shared ``prior`` variable, then checks:

    * reducing ``"value"`` yields a ``Dirichlet``;
    * reducing ``"prior"`` yields a ``DirichletMultinomial`` carrying the
      original ``concentration`` and ``total_count``;
    * ``p - reduced`` evaluated at ``value`` is a ``Dirichlet`` whose
      concentration is ``concentration + value``.
    """
    max_count = 10
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
    full_shape = batch_shape + event_shape
    prior = Variable("prior", reals(*full_shape))[batch_dims]

    # exp() of random draws, used as the Dirichlet concentration.
    concentration = Tensor(ops.exp(randn(full_shape)), inputs)
    value_data = ops.astype(randint(0, max_count, size=full_shape), 'float32')
    value = Tensor(value_data, inputs)
    # total_count is the sum of the observed counts plus a further random offset per batch element.
    total_count_data = value_data.sum(-1) + ops.astype(randint(0, max_count, size=batch_shape), 'float32')
    total_count = Tensor(total_count_data, inputs)

    # Joint density: Dirichlet over `prior` plus Multinomial with probs=`prior`.
    p = dist.Dirichlet(concentration, value=prior)
    p += dist.Multinomial(probs=prior, total_count=total_count)

    marginalized = p.reduce(ops.logaddexp, {"value"})
    assert isinstance(marginalized, dist.Dirichlet)

    reduced = p.reduce(ops.logaddexp, {"prior"})
    assert isinstance(reduced, dist.DirichletMultinomial)
    assert_close(reduced.concentration, concentration)
    assert_close(reduced.total_count, total_count)

    # joint - marginal, evaluated at the observed counts.
    result = (p - reduced)(value=value)
    assert isinstance(result, dist.Dirichlet)
    assert_close(result.concentration, concentration + value)
296+
297+
271298
@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
272299
def test_lognormal_density(batch_shape):
273300
batch_dims = ('i', 'j', 'k')[:len(batch_shape)]

0 commit comments

Comments
 (0)