Skip to content

Commit

Permalink
Get tests running
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenoit26 committed Feb 27, 2025
1 parent 9f7c9e5 commit 446e708
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 6 deletions.
3 changes: 2 additions & 1 deletion ml4gw/transforms/spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def __init__(
self.register_buffer("time_idxs", time_idxs)

def _check_and_format_kwargs(self, kwargs: Dict[str, List]) -> List:
lengths = sorted((len(v) for v in kwargs.values()))
lengths = sorted(len(v) for v in kwargs.values())
lengths = list(set(lengths))

if lengths[-1] > 3:
warnings.warn(
Expand Down
4 changes: 3 additions & 1 deletion ml4gw/transforms/spline_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,14 @@ def __init__(
y_out: Optional[Tensor] = None,
):
super().__init__()
if y_in is None:
y_in = Tensor([1])
self.kx = kx
self.ky = ky
self.sx = sx
self.sy = sy
self.register_buffer("x_in", x_in)
self.register_buffer("y_in", y_in or Tensor([1]))
self.register_buffer("y_in", y_in)
self.register_buffer("x_out", x_out)
self.register_buffer("y_out", y_out)

Expand Down
4 changes: 3 additions & 1 deletion tests/dataloading/test_chunked_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def dataset(
)

def test_iter(self, dataset, coincident):
length = 0
for x in dataset:
assert x.shape == (8, 2, 192)
if coincident:
Expand All @@ -76,5 +77,6 @@ def test_iter(self, dataset, coincident):
diffs = torch.diff(x[:, 0], axis=-1)
expected = torch.ones_like(diffs)
torch.testing.assert_close(diffs, expected, rtol=0, atol=0)
length += 1

assert len(dataset) == 42
assert length == 42
8 changes: 5 additions & 3 deletions tests/test_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,9 @@ def test_reweight_snrs(_get_waveforms_from_lalsimulation):
)
# mutate data in the hp timeseries, and recompute snr using LAL
hp.data.data = reweighted_response[..., 0, :].numpy().flatten()
ligo_snr = lalsimulation.MeasureSNR(hp, psd_1, 1, sample_rate / 2)
hp.data.data = reweighted_response[..., 1, :].numpy().flatten()
virgo_snr = lalsimulation.MeasureSNR(hp, psd_1, 1, sample_rate / 2)
network_snr = (ligo_snr**2 + virgo_snr**2) ** 0.5

assert lalsimulation.MeasureSNR(hp, psd_1, 1, 100) == pytest.approx(
target_network_snr.numpy()
)
assert network_snr == pytest.approx(target_network_snr.numpy(), rel=1e-1)

0 comments on commit 446e708

Please sign in to comment.