|
2 | 2 | # SPDX-License-Identifier: Apache-2.0
|
3 | 3 |
|
4 | 4 | import functools
|
| 5 | +from typing import Tuple, Union |
5 | 6 |
|
6 | 7 | import pyro.distributions as dist
|
7 | 8 | import pyro.distributions.testing.fakes as fakes
|
8 | 9 | from pyro.distributions.torch_distribution import MaskedDistribution
|
9 | 10 | import torch
|
10 | 11 |
|
| 12 | +from funsor.cnf import Contraction |
11 | 13 | from funsor.distribution import ( # noqa: F401
|
12 | 14 | Bernoulli,
|
13 | 15 | FUNSOR_DIST_NAMES,
|
|
31 | 33 | transformeddist_to_funsor,
|
32 | 34 | )
|
33 | 35 | from funsor.domains import reals
|
| 36 | +import funsor.ops as ops |
34 | 37 | from funsor.tensor import Tensor, dummy_numeric_array
|
35 |
| -from funsor.terms import Funsor, Variable, eager, to_funsor |
| 38 | +from funsor.terms import Binary, Funsor, Variable, eager, to_funsor |
36 | 39 |
|
37 | 40 |
|
38 | 41 | __all__ = list(x[0] for x in FUNSOR_DIST_NAMES)
|
@@ -137,3 +140,32 @@ def bernoulli_to_funsor(pyro_dist, output=None, dim_to_name=None):
|
137 | 140 | eager.register(Delta, Variable, Variable, Variable)(eager_delta_variable_variable) # noqa: F821
|
138 | 141 | eager.register(Normal, Funsor, Tensor, Funsor)(eager_normal) # noqa: F821
|
139 | 142 | eager.register(MultivariateNormal, Funsor, Tensor, Funsor)(eager_mvn) # noqa: F821
|
| 143 | + |
| 144 | + |
| 145 | +@eager.register(Contraction, ops.LogAddExpOp, ops.AddOp, frozenset, Dirichlet, Multinomial) # noqa: F821 |
| 146 | +def eager_dirichlet_multinomial(red_op, bin_op, reduced_vars, x, y): |
| 147 | + dirichlet_reduction = frozenset(x.inputs).intersection(reduced_vars) |
| 148 | + if dirichlet_reduction: |
| 149 | + return DirichletMultinomial(concentration=x.concentration, # noqa: F821 |
| 150 | + total_count=y.total_count, |
| 151 | + value=y.value) |
| 152 | + else: |
| 153 | + return eager(Contraction, red_op, bin_op, reduced_vars, (x, y)) |
| 154 | + |
| 155 | + |
| 156 | +JointDirichletMultinomial = Contraction[ |
| 157 | + Union[ops.LogAddExpOp, ops.NullOp], |
| 158 | + ops.AddOp, |
| 159 | + frozenset, |
| 160 | + Tuple[Dirichlet, Multinomial], # noqa: F821 |
| 161 | +] |
| 162 | + |
| 163 | + |
| 164 | +@eager.register(Binary, ops.SubOp, JointDirichletMultinomial, DirichletMultinomial) # noqa: F821 |
| 165 | +def eager_dirichlet_posterior(op, c, z): |
| 166 | + if (z.concentration is c.terms[0].concentration) and (c.terms[1].total_count is z.total_count): |
| 167 | + return Dirichlet( # noqa: F821 |
| 168 | + concentration=z.concentration + c.terms[1].value, |
| 169 | + value=c.terms[0].value) |
| 170 | + else: |
| 171 | + return None |
0 commit comments