From 62135e5d70ebf5ed6e6f989eead3f052ce66efd0 Mon Sep 17 00:00:00 2001
From: Eli
Date: Tue, 3 Nov 2020 15:09:10 -0500
Subject: [PATCH 1/4] Move generic to_funsor registration for distributions into funsor.distributions.make_dist

---
 funsor/distribution.py        | 22 +++++++---------------
 funsor/jax/distributions.py   |  9 +++------
 funsor/torch/distributions.py |  5 +----
 3 files changed, 11 insertions(+), 25 deletions(-)

diff --git a/funsor/distribution.py b/funsor/distribution.py
index 51fbd0a68..3e41dc932 100644
--- a/funsor/distribution.py
+++ b/funsor/distribution.py
@@ -230,7 +230,7 @@ def _infer_param_domain(cls, name, raw_shape):
 ################################################################################
 
 
-def make_dist(backend_dist_class, param_names=()):
+def make_dist(backend_dist_class, param_names=(), generate_eager=True, generate_to_funsor=True):
     if not param_names:
         param_names = tuple(name for name in inspect.getfullargspec(backend_dist_class.__init__)[0][1:]
                             if name in backend_dist_class.arg_constraints)
@@ -244,7 +244,11 @@ def dist_init(self, **kwargs):
         '__init__': dist_init,
     })
 
-    eager.register(dist_class, *((Tensor,) * (len(param_names) + 1)))(dist_class.eager_log_prob)
+    if generate_eager:
+        eager.register(dist_class, *((Tensor,) * (len(param_names) + 1)))(dist_class.eager_log_prob)
+
+    if generate_to_funsor:
+        to_funsor.register(backend_dist_class)(functools.partial(backenddist_to_funsor, dist_class))
 
     return dist_class
 
@@ -277,9 +281,7 @@ def dist_init(self, **kwargs):
 # Converting backend Distributions to funsors
 ###############################################
 
-def backenddist_to_funsor(backend_dist, output=None, dim_to_name=None):
-    funsor_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
-    funsor_dist_class = getattr(funsor_dist, type(backend_dist).__name__.split("Wrapper_")[-1])
+def backenddist_to_funsor(funsor_dist_class, backend_dist, output=None, dim_to_name=None):
     params = [to_funsor(
         getattr(backend_dist, param_name),
         output=funsor_dist_class._infer_param_domain(
@@ -311,16 +313,6 @@ def transformeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
     raise NotImplementedError("TODO implement conversion of TransformedDistribution")
 
 
-def mvndist_to_funsor(backend_dist, output=None, dim_to_name=None, real_inputs=OrderedDict()):
-    funsor_dist = backenddist_to_funsor(backend_dist, output=output, dim_to_name=dim_to_name)
-    if len(real_inputs) == 0:
-        return funsor_dist
-    discrete, gaussian = funsor_dist(value="value").terms
-    inputs = OrderedDict((k, v) for k, v in gaussian.inputs.items() if v.dtype != 'real')
-    inputs.update(real_inputs)
-    return discrete + Gaussian(gaussian.info_vec, gaussian.precision, inputs)
-
-
 class CoerceDistributionToFunsor:
     """
     Handler to reinterpret a backend distribution ``D`` as a corresponding
diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py
index 70b6b2aec..dd780adab 100644
--- a/funsor/jax/distributions.py
+++ b/funsor/jax/distributions.py
@@ -32,7 +32,6 @@
     indepdist_to_funsor,
     make_dist,
     maskeddist_to_funsor,
-    mvndist_to_funsor,
     transformeddist_to_funsor,
 )
 from funsor.domains import Real, Reals
@@ -167,19 +166,17 @@ def _infer_param_domain(cls, name, raw_shape):
 # Converting PyTorch Distributions to funsors
 ###############################################
 
-to_funsor.register(dist.Distribution)(backenddist_to_funsor)
 to_funsor.register(dist.Independent)(indepdist_to_funsor)
 if hasattr(dist, "MaskedDistribution"):
     to_funsor.register(dist.MaskedDistribution)(maskeddist_to_funsor)
 to_funsor.register(dist.TransformedDistribution)(transformeddist_to_funsor)
-to_funsor.register(dist.MultivariateNormal)(mvndist_to_funsor)
 
 
 @to_funsor.register(dist.BinomialProbs)
 @to_funsor.register(dist.BinomialLogits)
 def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
     new_pyro_dist = _NumPyroWrapper_Binomial(probs=numpyro_dist.probs)
-    return backenddist_to_funsor(new_pyro_dist, output, dim_to_name)
+    return backenddist_to_funsor(Binomial, new_pyro_dist, output, dim_to_name)  # noqa: F821
 
 
 @to_funsor.register(dist.CategoricalProbs)
@@ -188,14 +185,14 @@ def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
 @to_funsor.register(dist.CategoricalLogits)
 def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
     new_pyro_dist = _NumPyroWrapper_Categorical(probs=numpyro_dist.probs)
-    return backenddist_to_funsor(new_pyro_dist, output, dim_to_name)
+    return backenddist_to_funsor(Categorical, new_pyro_dist, output, dim_to_name)  # noqa: F821
 
 
 @to_funsor.register(dist.MultinomialProbs)
 @to_funsor.register(dist.MultinomialLogits)
 def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
     new_pyro_dist = _NumPyroWrapper_Multinomial(probs=numpyro_dist.probs)
-    return backenddist_to_funsor(new_pyro_dist, output, dim_to_name)
+    return backenddist_to_funsor(Multinomial, new_pyro_dist, output, dim_to_name)  # noqa: F821
 
 
 JointDirichletMultinomial = Contraction[
diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py
index c2918eb97..ba1798a5a 100644
--- a/funsor/torch/distributions.py
+++ b/funsor/torch/distributions.py
@@ -35,7 +35,6 @@
     indepdist_to_funsor,
     make_dist,
     maskeddist_to_funsor,
-    mvndist_to_funsor,
     transformeddist_to_funsor,
 )
 from funsor.domains import Real, Reals
@@ -158,17 +157,15 @@ def _infer_param_domain(cls, name, raw_shape):
 # Converting PyTorch Distributions to funsors
 ###############################################
 
-to_funsor.register(torch.distributions.Distribution)(backenddist_to_funsor)
 to_funsor.register(torch.distributions.Independent)(indepdist_to_funsor)
 to_funsor.register(MaskedDistribution)(maskeddist_to_funsor)
 to_funsor.register(torch.distributions.TransformedDistribution)(transformeddist_to_funsor)
-to_funsor.register(torch.distributions.MultivariateNormal)(mvndist_to_funsor)
 
 
 @to_funsor.register(torch.distributions.Bernoulli)
 def bernoulli_to_funsor(pyro_dist, output=None, dim_to_name=None):
     new_pyro_dist = _PyroWrapper_BernoulliLogits(logits=pyro_dist.logits)
-    return backenddist_to_funsor(new_pyro_dist, output, dim_to_name)
+    return backenddist_to_funsor(BernoulliLogits, new_pyro_dist, output, dim_to_name)  # noqa: F821
 
 
 JointDirichletMultinomial = Contraction[
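With this patch, the generic to_funsor conversion is registered per wrapped backend class by make_dist (via functools.partial(backenddist_to_funsor, dist_class)) rather than once for torch.distributions.Distribution. A minimal usage sketch, not part of the patch series, assuming the torch backend with pyro installed; the Beta distribution, batch size, and the name "plate" are chosen only for illustration:

from collections import OrderedDict

import pyro.distributions as dist
import torch

import funsor
from funsor.domains import Real

funsor.set_backend("torch")  # select the torch backend before converting

# A batched pyro distribution converts through the registration that
# make_dist installs for its backend class.
d = dist.Beta(torch.ones(3), torch.ones(3))  # batch_shape == (3,)
f = funsor.to_funsor(d, output=Real, dim_to_name=OrderedDict([(-1, "plate")]))
# f is a funsor Beta with a "plate" input of size 3 and a free "value" variable;
# a bare torch.distributions.Beta is no longer covered by a blanket registration.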
From 808cdefaf49bb5b98d202f1b321059300a9aaf4c Mon Sep 17 00:00:00 2001
From: Eli
Date: Tue, 3 Nov 2020 15:53:02 -0500
Subject: [PATCH 2/4] fix funsor.pyro.convert

---
 funsor/pyro/convert.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/funsor/pyro/convert.py b/funsor/pyro/convert.py
index c12e1764d..18fe149ba 100644
--- a/funsor/pyro/convert.py
+++ b/funsor/pyro/convert.py
@@ -140,7 +140,13 @@ def mvn_to_funsor(pyro_dist, event_inputs=(), real_inputs=OrderedDict()):
     assert isinstance(event_inputs, tuple)
     assert isinstance(real_inputs, OrderedDict)
     dim_to_name = default_dim_to_name(pyro_dist.batch_shape, event_inputs)
-    return to_funsor(pyro_dist, Real, dim_to_name, real_inputs=real_inputs)
+    funsor_dist = to_funsor(pyro_dist, Real, dim_to_name)
+    if len(real_inputs) == 0:
+        return funsor_dist
+    discrete, gaussian = funsor_dist(value="value").terms
+    inputs = OrderedDict((k, v) for k, v in gaussian.inputs.items() if v.dtype != 'real')
+    inputs.update(real_inputs)
+    return discrete + Gaussian(gaussian.info_vec, gaussian.precision, inputs)
 
 
 def funsor_to_mvn(gaussian, ndims, event_inputs=()):
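The inlined logic restores what the deleted mvndist_to_funsor used to do: convert the MVN, and, when real_inputs is non-empty, split the result at value="value" into its discrete normalizer and Gaussian terms and rebuild the Gaussian over the named real inputs. A small sketch of the call pattern this keeps working, mirroring the eeg_slds.py usage in the next patch (the hidden_dim value is arbitrary):

from collections import OrderedDict

import pyro
import torch

import funsor

funsor.set_backend("torch")  # select the backend before importing the pyro converters

from funsor.pyro.convert import mvn_to_funsor

hidden_dim = 2
x_init_mvn = pyro.distributions.MultivariateNormal(torch.zeros(hidden_dim),
                                                   torch.eye(hidden_dim))
# With real_inputs, the flattened "value" variable is re-expressed over the
# named real inputs, here a single hidden state x_0 of shape (hidden_dim,).
x_init = mvn_to_funsor(x_init_mvn,
                       real_inputs=OrderedDict([("x_0", funsor.Reals[hidden_dim])]))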
From ef3817efb7fb92914c79f2c2edb0801287fc188b Mon Sep 17 00:00:00 2001
From: Eli
Date: Tue, 3 Nov 2020 18:22:29 -0500
Subject: [PATCH 3/4] switch to pyro distributions in eeg_slds

---
 examples/eeg_slds.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/examples/eeg_slds.py b/examples/eeg_slds.py
index a9dd65f2c..aef92a9ad 100644
--- a/examples/eeg_slds.py
+++ b/examples/eeg_slds.py
@@ -19,6 +19,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import pyro
 
 import funsor
 import funsor.torch.distributions as dist
@@ -82,7 +83,7 @@ def __init__(self,
         self.log_obs_noise = nn.Parameter(0.1 * torch.randn(obs_dim))
 
         # define the prior distribution p(x_0) over the continuous latent at the initial time step t=0
-        x_init_mvn = torch.distributions.MultivariateNormal(torch.zeros(self.hidden_dim), torch.eye(self.hidden_dim))
+        x_init_mvn = pyro.distributions.MultivariateNormal(torch.zeros(self.hidden_dim), torch.eye(self.hidden_dim))
         self.x_init_mvn = mvn_to_funsor(x_init_mvn, real_inputs=OrderedDict([('x_0', funsor.Reals[self.hidden_dim])]))
 
         # we construct the various funsors used to compute the marginal log probability and other model quantities.
@@ -92,10 +93,10 @@ def get_tensors_and_dists(self):
         trans_logits = self.transition_logits - self.transition_logits.logsumexp(dim=-1, keepdim=True)
         trans_probs = funsor.Tensor(trans_logits, OrderedDict([("s", funsor.Bint[self.num_components])]))
 
-        trans_mvn = torch.distributions.MultivariateNormal(torch.zeros(self.hidden_dim),
-                                                           self.log_transition_noise.exp().diag_embed())
-        obs_mvn = torch.distributions.MultivariateNormal(torch.zeros(self.obs_dim),
-                                                         self.log_obs_noise.exp().diag_embed())
+        trans_mvn = pyro.distributions.MultivariateNormal(torch.zeros(self.hidden_dim),
+                                                          self.log_transition_noise.exp().diag_embed())
+        obs_mvn = pyro.distributions.MultivariateNormal(torch.zeros(self.obs_dim),
+                                                        self.log_obs_noise.exp().diag_embed())
 
         event_dims = ("s",) if self.fine_transition_matrix or self.fine_transition_noise else ()
         x_trans_dist = matrix_and_mvn_to_funsor(self.transition_matrix, trans_mvn, event_dims, "x", "y")

From 5a7e42895259fc8c47dcf69bb2e5739bc3a8506f Mon Sep 17 00:00:00 2001
From: Eli
Date: Tue, 3 Nov 2020 19:50:50 -0500
Subject: [PATCH 4/4] use pyro distribution in sensor.py

---
 examples/sensor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/sensor.py b/examples/sensor.py
index 6a003f78a..fca7d2f55 100644
--- a/examples/sensor.py
+++ b/examples/sensor.py
@@ -101,7 +101,7 @@ def forward(self, observations, add_bias=True):
             )
         )(value=bias)
 
-        init_dist = torch.distributions.MultivariateNormal(
+        init_dist = dist.MultivariateNormal(
             torch.zeros(4), scale_tril=100. * torch.eye(4))
         self.init = dist_to_funsor(init_dist)(value="state")
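The two example changes above follow from the first patch: with the blanket to_funsor registration for torch.distributions.Distribution (and the MultivariateNormal special case) removed, only the wrapped pyro distribution classes still convert, so the examples construct pyro MVNs before handing them to the funsor.pyro converters. A sketch of the resulting sensor.py pattern, assuming dist there refers to pyro.distributions (the commit message implies it, but the import line is outside this hunk):

import pyro.distributions as dist
import torch

import funsor

funsor.set_backend("torch")

from funsor.pyro.convert import dist_to_funsor

# Broad Gaussian prior over the example's 4-dimensional state.
init_dist = dist.MultivariateNormal(torch.zeros(4),
                                    scale_tril=100. * torch.eye(4))
init = dist_to_funsor(init_dist)(value="state")
# init is a funsor with a real input named "state" of shape (4,); the example
# stores it as self.init and uses it as the prior factor of its model.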