Skip to content

Commit

Permalink
refactor: update default tensor call for pytorch 2.1. refactors tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Jun 11, 2024
1 parent 6416caa commit 8f0c3a7
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions tests/analysis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ def test_analysis_modules(device: str) -> None:
low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim), device=device
)

def simulator(parameter_set):
return 1.0 + parameter_set + torch.randn(parameter_set.shape) * 0.1
def simulator(theta):
return 1.0 + theta + torch.randn(theta.shape, device=theta.device) * 0.1

theta = prior.sample((300,)).to("cpu")
theta = prior.sample((300,)).to(device)
x = simulator(theta)

inf = SNPE(prior=prior, device=device)
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def set_seed():

@pytest.fixture(scope="session", autouse=True)
def set_default_tensor_type():
torch.set_default_tensor_type("torch.FloatTensor")
torch.set_default_dtype(torch.float32)


# Pytest hook to skip GPU tests if no devices are available.
Expand Down

0 comments on commit 8f0c3a7

Please sign in to comment.