From 8f0c3a71e1b7bceab3a5359a512981e8d0cf255a Mon Sep 17 00:00:00 2001 From: janfb Date: Tue, 11 Jun 2024 15:44:57 +0200 Subject: [PATCH] refactor: update default tensor call for pytorch 2.1. refactors tests. --- tests/analysis_test.py | 6 +++--- tests/conftest.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/analysis_test.py b/tests/analysis_test.py index 76ecf28d4..afec426b9 100644 --- a/tests/analysis_test.py +++ b/tests/analysis_test.py @@ -28,10 +28,10 @@ def test_analysis_modules(device: str) -> None: low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim), device=device ) - def simulator(parameter_set): - return 1.0 + parameter_set + torch.randn(parameter_set.shape) * 0.1 + def simulator(theta): + return 1.0 + theta + torch.randn(theta.shape, device=theta.device) * 0.1 - theta = prior.sample((300,)).to("cpu") + theta = prior.sample((300,)).to(device) x = simulator(theta) inf = SNPE(prior=prior, device=device) diff --git a/tests/conftest.py b/tests/conftest.py index 445a27bae..cccf5dbf6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,7 @@ def set_seed(): @pytest.fixture(scope="session", autouse=True) def set_default_tensor_type(): - torch.set_default_tensor_type("torch.FloatTensor") + torch.set_default_dtype(torch.float32) # Pytest hook to skip GPU tests if no devices are available.