diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 2b12dd637..593baff3e 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -300,7 +300,15 @@ def body_fn(wrapped_carry, x): return (i + 1, rng_key, carry), (PytreeTrace(trace), y) wrapped_carry = device_put((0, rng_key, init)) - return lax.scan(body_fn, wrapped_carry, xs, length=length, reverse=reverse) + last_carry, (pytree_trace, ys) = lax.scan( + body_fn, wrapped_carry, xs, length=length, reverse=reverse + ) + for name, site in pytree_trace.trace.items(): + if site["type"] != "sample": + continue + # we haven't promote shapes of values yet during `lax.scan`, so we do it here + site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"]) + return last_carry, (pytree_trace, ys) def scan(f, init, xs, length=None, reverse=False, history=1): diff --git a/test/contrib/test_control_flow.py b/test/contrib/test_control_flow.py index 61eca0025..c4b2e641d 100644 --- a/test/contrib/test_control_flow.py +++ b/test/contrib/test_control_flow.py @@ -196,3 +196,17 @@ def false_fun(_): atol=0.1, ) assert_allclose([x.mean(), x.std()], [2.0, jnp.sqrt(5.0)], atol=0.5) + + +def test_scan_promote(): + def model(): + def transition_fn(c, val): + with numpyro.plate("N", 3, dim=-1): + numpyro.sample("x", dist.Normal(0, 1), obs=1.0) + return None, None + + scan(transition_fn, None, None, length=10) + + tr = numpyro.handlers.trace(model).get_trace() + assert tr["x"]["value"].shape == (10, 1) + assert tr["x"]["fn"].log_prob(tr["x"]["value"]).shape == (10, 3)