Skip to content

Commit

Permalink
Update test to reflect changes
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Apr 8, 2024
1 parent 6062709 commit 99d4b7a
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions tests/tests_common/test_subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
FastMRIMagicMaskFunc,
Gaussian1DMaskFunc,
Gaussian2DMaskFunc,
RadialMaskFunc,
SpiralMaskFunc,
VariableDensityPoissonMaskFunc,
],
)
Expand All @@ -44,6 +42,30 @@ def test_mask_reuse(mask_func, center_fracs, accelerations, batch_size, dim):
assert torch.all(mask2 == mask3)


@pytest.mark.parametrize(
"mask_func",
[
RadialMaskFunc,
SpiralMaskFunc,
],
)
@pytest.mark.parametrize(
"accelerations, batch_size, dim",
[
([4], 4, 320),
([4, 8], 2, 368),
],
)
def test_mask_reuse_circus(mask_func, accelerations, batch_size, dim):
mask_func = mask_func(accelerations=accelerations)
shape = (batch_size, dim, dim, 2)
mask1 = mask_func(shape, seed=123)
mask2 = mask_func(shape, seed=123)
mask3 = mask_func(shape, seed=123)
assert torch.all(mask1 == mask2)
assert torch.all(mask2 == mask3)


@pytest.mark.parametrize(
"mask_func",
[
Expand Down

0 comments on commit 99d4b7a

Please sign in to comment.