Skip to content

Commit

Permalink
Raise explicit error for unobserved Delta sample sites during intia…
Browse files Browse the repository at this point in the history
…lization.
  • Loading branch information
tillahoffmann committed Jan 16, 2025
1 parent 8d766bf commit 15b04ca
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
13 changes: 13 additions & 0 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import jax.numpy as jnp

import numpyro
from numpyro import distributions as dist
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to
from numpyro.distributions.util import is_identically_one, sum_rightmost
Expand Down Expand Up @@ -685,6 +686,18 @@ def initialize_model(
has_enumerate_support,
model_trace,
) = _get_model_transforms(substituted_model, model_args, model_kwargs)

for name, site in model_trace.items():
if (
site["type"] == "sample"
and isinstance(site["fn"], dist.Delta)
and not site["is_observed"]
):
raise ValueError(
f"Sample site '{name}' has a delta distribution; use "
"`numpyro.deterministic` to add this value to the trace instead."
)

# substitute param sites from model_trace to model so
# we don't need to generate again parameters of `numpyro.module`
model = substitute(
Expand Down
15 changes: 15 additions & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,3 +1350,18 @@ def model():
assert jnp.allclose(
jnp.stack(list(median.values())).ravel(), params["auto_loc"].ravel()
)


def test_autoguide_with_delta_site() -> None:
def model(x):
numpyro.sample("x", dist.Delta(3.0), obs=x)
# Need to sample a latent variable so the guide is not empty.
numpyro.sample("y", dist.Normal())

guide = AutoDiagonalNormal(lambda: model(None))
with pytest.raises(ValueError, match="has a delta distribution"):
numpyro.handlers.seed(guide, 9)()

# Check delta distributions are fine if observed.
guide = AutoDiagonalNormal(lambda: model(3.0))
numpyro.handlers.seed(guide, 9)()

0 comments on commit 15b04ca

Please sign in to comment.