diff --git a/tests/analysis_test.py b/tests/analysis_test.py
index 76ecf28d4..afec426b9 100644
--- a/tests/analysis_test.py
+++ b/tests/analysis_test.py
@@ -28,10 +28,10 @@ def test_analysis_modules(device: str) -> None:
         low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim), device=device
     )

-    def simulator(parameter_set):
-        return 1.0 + parameter_set + torch.randn(parameter_set.shape) * 0.1
+    def simulator(theta):
+        return 1.0 + theta + torch.randn(theta.shape, device=theta.device) * 0.1

-    theta = prior.sample((300,)).to("cpu")
+    theta = prior.sample((300,)).to(device)
     x = simulator(theta)

     inf = SNPE(prior=prior, device=device)
diff --git a/tests/conftest.py b/tests/conftest.py
index 445a27bae..cccf5dbf6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -18,7 +18,7 @@ def set_seed():

 @pytest.fixture(scope="session", autouse=True)
 def set_default_tensor_type():
-    torch.set_default_tensor_type("torch.FloatTensor")
+    torch.set_default_dtype(torch.float32)

 # Pytest hook to skip GPU tests if no devices are available.
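
For context, here is a minimal standalone sketch (not part of the diff) of the two patterns the change adopts: drawing the simulator noise on the same device as the input tensor, and setting only the default floating-point dtype rather than the global tensor type. The shapes and the `"cpu"` device below are illustrative, not taken from the test.

```python
import torch

# torch.set_default_tensor_type is deprecated in recent PyTorch releases;
# setting the default dtype covers the same use case for the tests.
torch.set_default_dtype(torch.float32)

def simulator(theta: torch.Tensor) -> torch.Tensor:
    # Allocate the noise on theta's device so the simulator works unchanged
    # whether theta lives on CPU or GPU.
    return 1.0 + theta + torch.randn(theta.shape, device=theta.device) * 0.1

theta = torch.rand(300, 2, device="cpu")  # illustrative; use "cuda" when available
x = simulator(theta)
```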