Skip to content

Commit 2d1e1ff

Browse files
author
Ben Zickel
committed
Added tests for the keep distribution option of the EqualizeMessenger effect handler.
1 parent e2a6725 commit 2d1e1ff

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

tests/poutine/test_poutines.py

+44
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,50 @@ def test_render_model(self):
805805
pyro.render_model(model)
806806

807807

808+
@pytest.mark.parametrize("keep_dist", [False, True])
809+
@pytest.mark.parametrize(
810+
"loc_x, scale_x, loc_y, scale_y", [(0.0, 1.0, 5.0, 2.0), (5.0, 2.0, 0.0, 1.0)]
811+
)
812+
def test_condition_by_equalize(loc_x, scale_x, loc_y, scale_y, keep_dist):
813+
# Create model and equalize it.
814+
def model():
815+
x = pyro.sample("x", dist.Normal(loc_x, scale_x))
816+
y = pyro.sample("y", dist.Normal(loc_y, scale_y))
817+
return x, y
818+
819+
equalized_model = pyro.poutine.equalize(model, ["x", "y"], keep_dist=keep_dist)
820+
821+
# Fit guide to model
822+
guide = pyro.infer.autoguide.AutoNormal(equalized_model)
823+
optim = pyro.optim.Adam(dict(lr=0.1))
824+
svi = pyro.infer.SVI(
825+
equalized_model,
826+
guide,
827+
optim,
828+
loss=pyro.infer.TraceGraph_ELBO(num_particles=1000, vectorize_particles=True),
829+
)
830+
for step_num in range(100):
831+
svi.step()
832+
833+
# Get guide distribution parameters
834+
loc, scale = guide._get_loc_and_scale("x")
835+
loc = float(loc.detach().numpy())
836+
scale = float(scale.detach().numpy())
837+
838+
# Verify against expected distribution parameters
839+
if keep_dist:
840+
# Both 'x' and 'y' are sampled and the model is conditioned on 'x' and 'y' having the same value.
841+
expected_var = 1 / (1 / scale_x**2 + 1 / scale_y**2)
842+
expected_loc = (loc_x / scale_x**2 + loc_y / scale_y**2) * expected_var
843+
expected_scale = expected_var**0.5
844+
else:
845+
# The random variable 'x' is sampled and its value is assigned to 'y'.
846+
expected_loc = loc_x
847+
expected_scale = scale_x
848+
assert_close(loc, expected_loc, atol=0.05)
849+
assert_close(scale, expected_scale, atol=0.05)
850+
851+
808852
@pytest.mark.parametrize("first_available_dim", [-1, -2, -3])
809853
@pytest.mark.parametrize("depth", [0, 1, 2])
810854
def test_enumerate_poutine(depth, first_available_dim):

0 commit comments

Comments
 (0)