diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 593baff3e..15afef28c 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -14,6 +14,8 @@ def _subs_wrapper(subs_map, i, length, site): + if site["type"] != "sample": + return value = None if isinstance(subs_map, dict) and site["name"] in subs_map: value = subs_map[site["name"]] diff --git a/numpyro/contrib/funsor/infer_util.py b/numpyro/contrib/funsor/infer_util.py index 5a65a2db5..430434899 100644 --- a/numpyro/contrib/funsor/infer_util.py +++ b/numpyro/contrib/funsor/infer_util.py @@ -162,7 +162,7 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op): time_to_init_vars = defaultdict(frozenset) # PP... variables time_to_markov_dims = defaultdict(frozenset) # dimensions at markov sites sum_vars, prod_vars = frozenset(), frozenset() - history = 1 + history = 0 log_measures = {} for site in model_trace.values(): if site["type"] == "sample": @@ -186,10 +186,14 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op): for dim, name in dim_to_name.items(): if name.startswith("_time"): time_dim = funsor.Variable(name, funsor.Bint[log_prob.shape[dim]]) - time_to_factors[time_dim].append(log_prob_factor) history = max( history, max(_get_shift(s) for s in dim_to_name.values()) ) + if history == 0: + log_factors.append(log_prob_factor) + prod_vars |= frozenset({name}) + else: + time_to_factors[time_dim].append(log_prob_factor) time_to_init_vars[time_dim] |= frozenset( s for s in dim_to_name.values() if s.startswith("_PREV_") ) diff --git a/test/contrib/test_funsor.py b/test/contrib/test_funsor.py index 795dbcc49..73cacdb1b 100644 --- a/test/contrib/test_funsor.py +++ b/test/contrib/test_funsor.py @@ -8,6 +8,7 @@ from numpy.testing import assert_allclose import pytest +import jax from jax import random import jax.numpy as jnp @@ -516,6 +517,29 @@ def transition_fn(carry, y): assert_allclose(actual_x_curr, expected_x_curr) +def test_scan_enum_history_0(): + def model(ys): + z = numpyro.sample("z", dist.Bernoulli(0.2), infer={"enumerate": "parallel"}) + + def transition_fn(c, y): + numpyro.sample("y", dist.Normal(z, 1), obs=y) + return None, None + + scan(transition_fn, None, ys) + + actual, trace = log_density( + model=enum(model, first_available_dim=-1), + model_args=(jnp.arange(3),), + model_kwargs={}, + params={}, + ) + z_factor = trace["z"]["fn"].log_prob(trace["z"]["value"]) + prev_y_factor = trace["_PREV_y"]["fn"].log_prob(trace["_PREV_y"]["value"]) + y_factor = trace["y"]["fn"].log_prob(trace["y"]["value"]).sum(0) + expected = jax.nn.logsumexp(z_factor + prev_y_factor + y_factor) + assert_allclose(actual, expected) + + def test_missing_plate(monkeypatch): K, N = 3, 1000