Skip to content

Commit

Permalink
Use num_samples fixture everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenoit26 committed Feb 6, 2025
1 parent 3629932 commit 556cfe3
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 212 deletions.
65 changes: 32 additions & 33 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,8 @@ def validate(whitened, highpass, lowpass, sample_rate, df):
return validate


# number of samples to draw from
# the distributions for testing
N_SAMPLES = 1000
# A num_samples fixture should be defined for any
# test that wants to use these fixtures


@pytest.fixture(params=[256, 1024, 2048])
Expand All @@ -104,90 +103,90 @@ def sample_rate(request):


@pytest.fixture()
def chirp_mass(request):
def chirp_mass(num_samples, seed_everything):
dist = Uniform(5, 100)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def mass_ratio():
def mass_ratio(num_samples, seed_everything):
dist = Uniform(0.125, 0.99)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def a_1(request):
def a_1(num_samples, seed_everything):
dist = Uniform(0, 0.90)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def a_2(request):
def a_2(num_samples, seed_everything):
dist = Uniform(0, 0.90)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def tilt_1(request):
def tilt_1(num_samples, seed_everything):
dist = Uniform(0, torch.pi)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def tilt_2(request):
def tilt_2(num_samples, seed_everything):
dist = Uniform(0, torch.pi)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def phi_12(request):
def phi_12(num_samples, seed_everything):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def phi_jl(request):
def phi_jl(num_samples, seed_everything):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def distance(request):
def distance(num_samples, seed_everything):
dist = Uniform(100, 3000)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def distance_far(request):
def distance_far(num_samples, seed_everything):
dist = Uniform(500, 3000)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def distance_close(request):
def distance_close(num_samples, seed_everything):
dist = Uniform(100, 500)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def theta_jn(request):
def theta_jn(num_samples, seed_everything):
dist = Uniform(0, torch.pi)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def phase(request):
def phase(num_samples, seed_everything):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def chi1(request):
def chi1(num_samples, seed_everything):
dist = Uniform(-0.999, 0.999)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))


@pytest.fixture()
def chi2(request):
def chi2(num_samples, seed_everything):
dist = Uniform(-0.999, 0.999)
return dist.sample((N_SAMPLES,))
return dist.sample((num_samples,))
69 changes: 2 additions & 67 deletions tests/transforms/test_iirfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
from astropy import units as u
from scipy.signal import filtfilt, iirfilter
from torch.distributions import Uniform

from ml4gw.constants import MSUN
from ml4gw.transforms.iirfilter import IIRFilter
Expand Down Expand Up @@ -124,73 +123,9 @@ def test_filters_synthetic_signal(sample_rate, order):
)


N_SAMPLES = 1


@pytest.fixture()
def chirp_mass(seed_everything, request):
dist = Uniform(5, 100)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def mass_ratio(seed_everything):
dist = Uniform(0.125, 0.99)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def a_1(seed_everything, request):
dist = Uniform(0, 0.90)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def a_2(seed_everything, request):
dist = Uniform(0, 0.90)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def tilt_1(seed_everything, request):
dist = Uniform(0, torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def tilt_2(seed_everything, request):
dist = Uniform(0, torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def phi_12(seed_everything, request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def phi_jl(seed_everything, request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def distance(seed_everything, request):
dist = Uniform(100, 3000)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def theta_jn(seed_everything, request):
dist = Uniform(0, torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))


@pytest.fixture()
def phase(seed_everything, request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample(torch.Size((N_SAMPLES,)))
def num_samples():
return 1


@pytest.fixture(params=[20, 40])
Expand Down
Loading

0 comments on commit 556cfe3

Please sign in to comment.