diff --git a/tests/contrib/test_zuko.py b/tests/contrib/test_zuko.py index 4cc068556f..0e22b898a5 100644 --- a/tests/contrib/test_zuko.py +++ b/tests/contrib/test_zuko.py @@ -27,6 +27,7 @@ def test_ZukoToPyro(multivariate: bool, rsample_and_log_prob: bool): dist = normal(mu, sigma) if rsample_and_log_prob: + def dummy(self, shape): x = self.rsample(shape) return x, self.log_prob(x)