diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index a3b7425d0..fe39a87bf 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpyro +from numpyro import distributions as dist from numpyro.distributions import constraints from numpyro.distributions.transforms import biject_to from numpyro.distributions.util import is_identically_one, sum_rightmost @@ -685,6 +686,18 @@ def initialize_model( has_enumerate_support, model_trace, ) = _get_model_transforms(substituted_model, model_args, model_kwargs) + + for name, site in model_trace.items(): + if ( + site["type"] == "sample" + and isinstance(site["fn"], dist.Delta) + and not site["is_observed"] + ): + raise ValueError( + f"Sample site '{name}' has a delta distribution; use " + "`numpyro.deterministic` to add this value to the trace instead." + ) + # substitute param sites from model_trace to model so # we don't need to generate again parameters of `numpyro.module` model = substitute( diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 0e2e53aa4..ec26d0ae9 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -1350,3 +1350,18 @@ def model(): assert jnp.allclose( jnp.stack(list(median.values())).ravel(), params["auto_loc"].ravel() ) + + +def test_autoguide_with_delta_site() -> None: + def model(x): + numpyro.sample("x", dist.Delta(3.0), obs=x) + # Need to sample a latent variable so the guide is not empty. + numpyro.sample("y", dist.Normal()) + + guide = AutoDiagonalNormal(lambda: model(None)) + with pytest.raises(ValueError, match="has a delta distribution"): + numpyro.handlers.seed(guide, 9)() + + # Check delta distributions are fine if observed. + guide = AutoDiagonalNormal(lambda: model(3.0)) + numpyro.handlers.seed(guide, 9)()