Skip to content

Commit

Permalink
Move generic to_funsor registration for distributions into funsor.dis…
Browse files Browse the repository at this point in the history
…tributions.make_dist
  • Loading branch information
eb8680 committed Nov 3, 2020
1 parent ec22df5 commit 62135e5
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 25 deletions.
22 changes: 7 additions & 15 deletions funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _infer_param_domain(cls, name, raw_shape):
################################################################################


def make_dist(backend_dist_class, param_names=()):
def make_dist(backend_dist_class, param_names=(), generate_eager=True, generate_to_funsor=True):
if not param_names:
param_names = tuple(name for name in inspect.getfullargspec(backend_dist_class.__init__)[0][1:]
if name in backend_dist_class.arg_constraints)
Expand All @@ -244,7 +244,11 @@ def dist_init(self, **kwargs):
'__init__': dist_init,
})

eager.register(dist_class, *((Tensor,) * (len(param_names) + 1)))(dist_class.eager_log_prob)
if generate_eager:
eager.register(dist_class, *((Tensor,) * (len(param_names) + 1)))(dist_class.eager_log_prob)

if generate_to_funsor:
to_funsor.register(backend_dist_class)(functools.partial(backenddist_to_funsor, dist_class))

return dist_class

Expand Down Expand Up @@ -277,9 +281,7 @@ def dist_init(self, **kwargs):
# Converting backend Distributions to funsors
###############################################

def backenddist_to_funsor(backend_dist, output=None, dim_to_name=None):
funsor_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
funsor_dist_class = getattr(funsor_dist, type(backend_dist).__name__.split("Wrapper_")[-1])
def backenddist_to_funsor(funsor_dist_class, backend_dist, output=None, dim_to_name=None):
params = [to_funsor(
getattr(backend_dist, param_name),
output=funsor_dist_class._infer_param_domain(
Expand Down Expand Up @@ -311,16 +313,6 @@ def transformeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
raise NotImplementedError("TODO implement conversion of TransformedDistribution")


def mvndist_to_funsor(backend_dist, output=None, dim_to_name=None, real_inputs=OrderedDict()):
funsor_dist = backenddist_to_funsor(backend_dist, output=output, dim_to_name=dim_to_name)
if len(real_inputs) == 0:
return funsor_dist
discrete, gaussian = funsor_dist(value="value").terms
inputs = OrderedDict((k, v) for k, v in gaussian.inputs.items() if v.dtype != 'real')
inputs.update(real_inputs)
return discrete + Gaussian(gaussian.info_vec, gaussian.precision, inputs)


class CoerceDistributionToFunsor:
"""
Handler to reinterpret a backend distribution ``D`` as a corresponding
Expand Down
9 changes: 3 additions & 6 deletions funsor/jax/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
indepdist_to_funsor,
make_dist,
maskeddist_to_funsor,
mvndist_to_funsor,
transformeddist_to_funsor,
)
from funsor.domains import Real, Reals
Expand Down Expand Up @@ -167,19 +166,17 @@ def _infer_param_domain(cls, name, raw_shape):
# Converting PyTorch Distributions to funsors
###############################################

to_funsor.register(dist.Distribution)(backenddist_to_funsor)
to_funsor.register(dist.Independent)(indepdist_to_funsor)
if hasattr(dist, "MaskedDistribution"):
to_funsor.register(dist.MaskedDistribution)(maskeddist_to_funsor)
to_funsor.register(dist.TransformedDistribution)(transformeddist_to_funsor)
to_funsor.register(dist.MultivariateNormal)(mvndist_to_funsor)


@to_funsor.register(dist.BinomialProbs)
@to_funsor.register(dist.BinomialLogits)
def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
new_pyro_dist = _NumPyroWrapper_Binomial(probs=numpyro_dist.probs)
return backenddist_to_funsor(new_pyro_dist, output, dim_to_name)
return backenddist_to_funsor(Binomial, new_pyro_dist, output, dim_to_name) # noqa: F821


@to_funsor.register(dist.CategoricalProbs)
Expand All @@ -188,14 +185,14 @@ def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
@to_funsor.register(dist.CategoricalLogits)
def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
new_pyro_dist = _NumPyroWrapper_Categorical(probs=numpyro_dist.probs)
return backenddist_to_funsor(new_pyro_dist, output, dim_to_name)
return backenddist_to_funsor(Categorical, new_pyro_dist, output, dim_to_name) # noqa: F821


@to_funsor.register(dist.MultinomialProbs)
@to_funsor.register(dist.MultinomialLogits)
def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
new_pyro_dist = _NumPyroWrapper_Multinomial(probs=numpyro_dist.probs)
return backenddist_to_funsor(new_pyro_dist, output, dim_to_name)
return backenddist_to_funsor(Multinomial, new_pyro_dist, output, dim_to_name) # noqa: F821


JointDirichletMultinomial = Contraction[
Expand Down
5 changes: 1 addition & 4 deletions funsor/torch/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
indepdist_to_funsor,
make_dist,
maskeddist_to_funsor,
mvndist_to_funsor,
transformeddist_to_funsor,
)
from funsor.domains import Real, Reals
Expand Down Expand Up @@ -158,17 +157,15 @@ def _infer_param_domain(cls, name, raw_shape):
# Converting PyTorch Distributions to funsors
###############################################

to_funsor.register(torch.distributions.Distribution)(backenddist_to_funsor)
to_funsor.register(torch.distributions.Independent)(indepdist_to_funsor)
to_funsor.register(MaskedDistribution)(maskeddist_to_funsor)
to_funsor.register(torch.distributions.TransformedDistribution)(transformeddist_to_funsor)
to_funsor.register(torch.distributions.MultivariateNormal)(mvndist_to_funsor)


@to_funsor.register(torch.distributions.Bernoulli)
def bernoulli_to_funsor(pyro_dist, output=None, dim_to_name=None):
new_pyro_dist = _PyroWrapper_BernoulliLogits(logits=pyro_dist.logits)
return backenddist_to_funsor(new_pyro_dist, output, dim_to_name)
return backenddist_to_funsor(BernoulliLogits, new_pyro_dist, output, dim_to_name) # noqa: F821


JointDirichletMultinomial = Contraction[
Expand Down

0 comments on commit 62135e5

Please sign in to comment.