Skip to content

Commit

Permalink
Add seeding to fixture args
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenoit26 committed Feb 6, 2025
1 parent d2ee746 commit 3629932
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 26 deletions.
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from torch.distributions import Uniform


# If a fixture is doing anything random,
# it should take this function as an argument
@pytest.fixture(autouse=True)
def seed_everything():
seed = 101589
Expand Down
22 changes: 11 additions & 11 deletions tests/transforms/test_iirfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,67 +128,67 @@ def test_filters_synthetic_signal(sample_rate, order):


@pytest.fixture()
def chirp_mass(request):
def chirp_mass(seed_everything, request):
dist = Uniform(5, 100)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def mass_ratio():
def mass_ratio(seed_everything):
dist = Uniform(0.125, 0.99)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def a_1(request):
def a_1(seed_everything, request):
dist = Uniform(0, 0.90)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def a_2(request):
def a_2(seed_everything, request):
dist = Uniform(0, 0.90)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def tilt_1(request):
def tilt_1(seed_everything, request):
dist = Uniform(0, torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def tilt_2(request):
def tilt_2(seed_everything, request):
dist = Uniform(0, torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def phi_12(request):
def phi_12(seed_everything, request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def phi_jl(request):
def phi_jl(seed_everything, request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def distance(request):
def distance(seed_everything, request):
dist = Uniform(100, 3000)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def theta_jn(request):
def theta_jn(seed_everything, request):
dist = Uniform(0, torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def phase(request):
def phase(seed_everything, request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))

Expand Down
30 changes: 15 additions & 15 deletions tests/waveforms/cbc/test_cbc_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,91 +23,91 @@ def sample_rate(request):


@pytest.fixture()
def chirp_mass(request):
def chirp_mass(seed_everything, request):
dist = Uniform(5, 100)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def mass_ratio():
def mass_ratio(seed_everything):
dist = Uniform(0.125, 0.99)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def a_1(request):
def a_1(seed_everything, request):
dist = Uniform(0, 0.90)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def a_2(request):
def a_2(seed_everything, request):
dist = Uniform(0, 0.90)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def tilt_1(request):
def tilt_1(seed_everything, request):
dist = Uniform(0, torch.pi)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def tilt_2(request):
def tilt_2(seed_everything, request):
dist = Uniform(0, torch.pi)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def phi_12(request):
def phi_12(seed_everything, request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def phi_jl(request):
def phi_jl(seed_everything, request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def distance(request):
def distance(seed_everything, request):
dist = Uniform(100, 3000)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def distance_far(request):
def distance_far(seed_everything, request):
dist = Uniform(400, 3000)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def distance_close(request):
def distance_close(seed_everything, request):
dist = Uniform(100, 400)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def theta_jn(request):
def theta_jn(seed_everything, request):
dist = Uniform(0, torch.pi)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def phase(request):
def phase(seed_everything, request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def chi_1(request):
def chi_1(seed_everything, request):
dist = Uniform(-0.999, 0.999)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def chi_2(request):
def chi_2(seed_everything, request):
dist = Uniform(-0.999, 0.999)
return dist.sample((N_SAMPLES,))

Expand Down

0 comments on commit 3629932

Please sign in to comment.