@@ -805,6 +805,50 @@ def test_render_model(self):
805
805
pyro .render_model (model )
806
806
807
807
808
+ @pytest .mark .parametrize ("keep_dist" , [False , True ])
809
+ @pytest .mark .parametrize (
810
+ "loc_x, scale_x, loc_y, scale_y" , [(0.0 , 1.0 , 5.0 , 2.0 ), (5.0 , 2.0 , 0.0 , 1.0 )]
811
+ )
812
+ def test_condition_by_equalize (loc_x , scale_x , loc_y , scale_y , keep_dist ):
813
+ # Create model and equalize it.
814
+ def model ():
815
+ x = pyro .sample ("x" , dist .Normal (loc_x , scale_x ))
816
+ y = pyro .sample ("y" , dist .Normal (loc_y , scale_y ))
817
+ return x , y
818
+
819
+ equalized_model = pyro .poutine .equalize (model , ["x" , "y" ], keep_dist = keep_dist )
820
+
821
+ # Fit guide to model
822
+ guide = pyro .infer .autoguide .AutoNormal (equalized_model )
823
+ optim = pyro .optim .Adam (dict (lr = 0.1 ))
824
+ svi = pyro .infer .SVI (
825
+ equalized_model ,
826
+ guide ,
827
+ optim ,
828
+ loss = pyro .infer .TraceGraph_ELBO (num_particles = 1000 , vectorize_particles = True ),
829
+ )
830
+ for step_num in range (100 ):
831
+ svi .step ()
832
+
833
+ # Get guide distribution parameters
834
+ loc , scale = guide ._get_loc_and_scale ("x" )
835
+ loc = float (loc .detach ().numpy ())
836
+ scale = float (scale .detach ().numpy ())
837
+
838
+ # Verify against expected distribution parameters
839
+ if keep_dist :
840
+ # Both 'x' and 'y' are sampled and the model is conditioned on 'x' and 'y' having the same value.
841
+ expected_var = 1 / (1 / scale_x ** 2 + 1 / scale_y ** 2 )
842
+ expected_loc = (loc_x / scale_x ** 2 + loc_y / scale_y ** 2 ) * expected_var
843
+ expected_scale = expected_var ** 0.5
844
+ else :
845
+ # The random variable 'x' is sampled and its value is assigned to 'y'.
846
+ expected_loc = loc_x
847
+ expected_scale = scale_x
848
+ assert_close (loc , expected_loc , atol = 0.05 )
849
+ assert_close (scale , expected_scale , atol = 0.05 )
850
+
851
+
808
852
@pytest .mark .parametrize ("first_available_dim" , [- 1 , - 2 , - 3 ])
809
853
@pytest .mark .parametrize ("depth" , [0 , 1 , 2 ])
810
854
def test_enumerate_poutine (depth , first_available_dim ):
0 commit comments