Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lowpass option for SNR calculation #202

Merged
merged 4 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 43 additions & 7 deletions ml4gw/gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from jaxtyping import Float
from torch import Tensor

from ml4gw.utils.interferometer import InterferometerGeometry

from .constants import C
from .types import (
BatchTensor,
Expand All @@ -28,6 +26,7 @@
VectorGeometry,
WaveformTensor,
)
from .utils.interferometer import InterferometerGeometry


def outer(x: VectorGeometry, y: VectorGeometry) -> TensorGeometry:
Expand Down Expand Up @@ -285,6 +284,7 @@ def compute_ifo_snr(
psd: PSDTensor,
sample_rate: float,
highpass: Union[float, Float[Tensor, " frequency"], None] = None,
lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
) -> Float[Tensor, "batch num_ifos"]:
r"""Compute the SNRs of a batch of interferometer responses

Expand All @@ -300,7 +300,8 @@ def compute_ifo_snr(
{S_n^{(j)}(f)}df$$

Where $f_{\text{min}}$ is a minimum frequency denoted
by `highpass`, `f_{\text{max}}` is the Nyquist frequency
by `highpass`, `f_{\text{max}}` is the maximum frequency
denoted by `lowpass`, which defaults to the Nyquist frequency
dictated by `sample_rate`; `\tilde{h_{ij}}` and `\tilde{h_{ij}}*`
indicate the fourier transform of the $i$th waveform at
the $j$th inteferometer and its complex conjugate, respectively;
Expand Down Expand Up @@ -328,8 +329,15 @@ def compute_ifo_snr(
If a tensor is provided, it will be assumed to be a
pre-computed mask used to 0-out low frequency components.
If a float, it will be used to compute such a mask. If
left as `None`, all frequencies up to `sample_rate / 2`
left as `None`, all frequencies up to `lowpass`
will contribute to the SNR calculation.
lowpass:
The maximum frequency below which to compute the SNR.
If a tensor is provided, it will be assumed to be a
pre-computed mask used to 0-out high frequency components.
If a float, it will be used to compute such a mask. If
left as `None`, all frequencies from `highpass` up to
the Nyquist freqyency will contribute to the SNR calculation.
Returns:
Batch of SNRs computed for each interferometer
"""
Expand All @@ -346,7 +354,7 @@ def compute_ifo_snr(
integrand = fft / (psd**0.5)
integrand = integrand.type(torch.float32) ** 2

# mask out low frequency components if a critical
# mask out frequency components if a critical
# frequency or frequency mask was provided
if highpass is not None:
if not isinstance(highpass, torch.Tensor):
Expand All @@ -360,6 +368,18 @@ def compute_ifo_snr(
)
)
integrand *= highpass.to(integrand.device)
if lowpass is not None:
if not isinstance(lowpass, torch.Tensor):
freqs = torch.fft.rfftfreq(responses.shape[-1], 1 / sample_rate)
lowpass = freqs < lowpass
elif len(lowpass) != integrand.shape[-1]:
raise ValueError(
"Can't apply lowpass filter mask with {} frequecy bins"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor typo in 'frequency'

"to signal fft with {} frequency bins".format(
len(lowpass), integrand.shape[-1]
)
)
integrand *= lowpass.to(integrand.device)

# sum over the desired frequency range and multiply
# by df to turn it into an integration (and get
Expand All @@ -386,6 +406,7 @@ def compute_network_snr(
psd: PSDTensor,
sample_rate: float,
highpass: Union[float, Float[Tensor, " frequency"], None] = None,
lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
) -> BatchTensor:
r"""
Compute the total SNR from a gravitational waveform
Expand Down Expand Up @@ -422,10 +443,17 @@ def compute_network_snr(
If a float, it will be used to compute such a mask. If
left as `None`, all frequencies up to `sample_rate / 2`
will contribute to the SNR calculation.
lowpass:
The maximum frequency below which to compute the SNR.
If a tensor is provided, it will be assumed to be a
pre-computed mask used to 0-out high frequency components.
If a float, it will be used to compute such a mask. If
left as `None`, all frequencies from `highpass` up to
the Nyquist freqyency will contribute to the SNR calculation.
Returns:
Batch of SNRs for each waveform across the interferometer network
"""
snrs = compute_ifo_snr(responses, psd, sample_rate, highpass)
snrs = compute_ifo_snr(responses, psd, sample_rate, highpass, lowpass)
snrs = snrs**2
return snrs.sum(axis=-1) ** 0.5

Expand All @@ -436,6 +464,7 @@ def reweight_snrs(
psd: PSDTensor,
sample_rate: float,
highpass: Union[float, Float[Tensor, " frequency"], None] = None,
lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
) -> WaveformTensor:
"""Scale interferometer responses such that they have a desired SNR

Expand Down Expand Up @@ -466,10 +495,17 @@ def reweight_snrs(
If a float, it will be used to compute such a mask. If
left as `None`, all frequencies up to `sample_rate / 2`
will contribute to the SNR calculation.
lowpass:
The maximum frequency below which to compute the SNR.
If a tensor is provided, it will be assumed to be a
pre-computed mask used to 0-out high frequency components.
If a float, it will be used to compute such a mask. If
left as `None`, all frequencies from `highpass` up to
the Nyquist freqyency will contribute to the SNR calculation.
Returns:
Rescaled interferometer responses
"""

snrs = compute_network_snr(responses, psd, sample_rate, highpass)
snrs = compute_network_snr(responses, psd, sample_rate, highpass, lowpass)
weights = target_snrs / snrs
return responses * weights[:, None, None]
2 changes: 1 addition & 1 deletion ml4gw/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def truncate_inverse_power_spectrum(
as `None`, no lowpass filtering will be applied.
Returns:
The PSD with its time domain response truncated
to `fduration` and any highpassed frequencies
to `fduration` and any filtered frequencies
tapered.
"""

Expand Down
21 changes: 17 additions & 4 deletions ml4gw/transforms/snr_rescaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ def __init__(
sample_rate: float,
waveform_duration: float,
highpass: Optional[float] = None,
lowpass: Optional[float] = None,
dtype: torch.dtype = torch.float32,
) -> None:
super().__init__()
self.highpass = highpass
self.sample_rate = sample_rate
self.num_channels = num_channels

Expand All @@ -29,9 +29,18 @@ def __init__(

if highpass is not None:
freqs = torch.fft.rfftfreq(waveform_size, 1 / sample_rate)
self.register_buffer("mask", freqs >= highpass, persistent=False)
self.register_buffer(
"highpass_mask", freqs >= highpass, persistent=False
)
else:
self.highpass_mask = None
if lowpass is not None:
freqs = torch.fft.rfftfreq(waveform_size, 1 / sample_rate)
self.register_buffer(
"lowpass_mask", freqs < lowpass, persistent=False
)
else:
self.mask = None
self.lowpass_mask = None

def fit(
self,
Expand Down Expand Up @@ -63,7 +72,11 @@ def forward(
target_snrs: Optional[BatchTensor] = None,
):
snrs = compute_network_snr(
responses, self.background, self.sample_rate, self.mask
responses,
self.background,
self.sample_rate,
self.highpass_mask,
self.lowpass_mask,
)
if target_snrs is None:
idx = torch.randperm(len(snrs))
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def validate(whitened, highpass, lowpass, sample_rate, df):
stds = whitened.std(axis=-1)
target = torch.ones_like(stds)

# if we're highpassingi or lowpassing, then we
# if we're highpassing or lowpassing, then we
# shouldn't expect the standard deviation to be
# one because we're subtracting some power, so
# remove roughly the expected power contributed
Expand Down
62 changes: 54 additions & 8 deletions tests/test_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_compute_antenna_responses(
):
ra, dec, psi, phi, gps_times, plus, cross = data
expected = bilby_get_ifo_response(
ra, dec, psi, gps_times, ["plus", "cross"]
ra, dec, psi, gps_times, ["plus", "cross", "breathing"]
)

phi = torch.tensor(phi)
Expand All @@ -96,10 +96,15 @@ def test_compute_antenna_responses(
tensors, vertices = injection.get_ifo_geometry(*ifos)
tensors = tensors.type(torch.float64)

with pytest.raises(ValueError) as exc:
injection.compute_antenna_responses(
np.pi / 2 - dec, psi, phi, tensors, ["dummy_mode"]
)
assert str(exc.value).startswith("No polarization mode")
result = injection.compute_antenna_responses(
np.pi / 2 - dec, psi, phi, tensors, ["plus", "cross"]
np.pi / 2 - dec, psi, phi, tensors, ["plus", "cross", "breathing"]
)
assert result.shape == (batch_size, 2, len(ifos))
assert result.shape == (batch_size, 3, len(ifos))
compare_against_numpy(result, expected)


Expand Down Expand Up @@ -278,26 +283,67 @@ def _get_O4_psd(
return psd


def test_compute_ifo_snr(_get_waveforms_from_lalsimulation):
@pytest.fixture(params=[None, 32])
def highpass(request):
return request.param


@pytest.fixture(params=[None, 64])
def lowpass(request):
return request.param


def test_compute_ifo_snr(_get_waveforms_from_lalsimulation, highpass, lowpass):
"""Test SNR for stellar mass system against lalsimulation
with a relative tolerance.
"""
hp, hc = _get_waveforms_from_lalsimulation
sample_rate = 1024
psd = _get_O4_psd(sample_rate, hp.data.data.shape[-1])
# All systems in test have ISCO < 100
snr_hp_lal = lalsimulation.MeasureSNR(hp, psd, 1, 100)
snr_hc_lal = lalsimulation.MeasureSNR(hc, psd, 1, 100)
lal_highpass = highpass or 1
lal_lowpass = lowpass or 100
snr_hp_lal = lalsimulation.MeasureSNR(hp, psd, lal_highpass, lal_lowpass)
snr_hc_lal = lalsimulation.MeasureSNR(hc, psd, lal_highpass, lal_lowpass)

backgrounds = psd.data.data[: len(hp.data.data) // 2 + 1]
backgrounds = torch.from_numpy(backgrounds)
hp_torch = torch.from_numpy(hp.data.data)
hc_torch = torch.from_numpy(hc.data.data)

num_freqs = hp_torch.shape[-1] // 2 + 1
with pytest.raises(ValueError) as exc:
snr_hp_compute_ifo_snr = injection.compute_ifo_snr(
hp_torch,
backgrounds,
sample_rate=sample_rate,
highpass=torch.ones(num_freqs - 1),
lowpass=lowpass,
)
assert str(exc.value).startswith("Can't apply highpass")
with pytest.raises(ValueError) as exc:
snr_hp_compute_ifo_snr = injection.compute_ifo_snr(
hp_torch,
backgrounds,
sample_rate=sample_rate,
highpass=highpass,
lowpass=torch.ones(num_freqs - 1),
)
assert str(exc.value).startswith("Can't apply lowpass")

snr_hp_compute_ifo_snr = injection.compute_ifo_snr(
hp_torch, backgrounds, sample_rate=sample_rate
hp_torch,
backgrounds,
sample_rate=sample_rate,
highpass=highpass,
lowpass=lowpass,
)
snr_hc_compute_ifo_snr = injection.compute_ifo_snr(
hc_torch, backgrounds, sample_rate=sample_rate
hc_torch,
backgrounds,
sample_rate=sample_rate,
highpass=highpass,
lowpass=lowpass,
)

assert snr_hp_lal == pytest.approx(
Expand Down
37 changes: 30 additions & 7 deletions tests/transforms/test_snr_rescaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,26 @@ def waveform_duration(self, request):
def highpass(self, request):
return request.param

@pytest.fixture(params=[None, 64])
def lowpass(self, request):
return request.param

@pytest.fixture
def transform(
self, num_channels, sample_rate, waveform_duration, highpass
self, num_channels, sample_rate, waveform_duration, highpass, lowpass
):
return SnrRescaler(
num_channels, sample_rate, waveform_duration, highpass
num_channels, sample_rate, waveform_duration, highpass, lowpass
)

def test_init(
self, transform, num_channels, sample_rate, waveform_duration, highpass
self,
transform,
num_channels,
sample_rate,
waveform_duration,
highpass,
lowpass,
):
assert not transform.built

Expand All @@ -53,11 +63,24 @@ def test_init(
assert transform.background.shape == shape
assert (transform.background == 0).all().item()

if highpass is not None:
if highpass is not None and lowpass is not None:
start = int(highpass * waveform_duration)
stop = int(lowpass * waveform_duration)
assert len(transform._buffers) == 3
assert not (transform.highpass_mask[:start].any().item())
assert not (transform.lowpass_mask[stop:].any().item())
assert transform.lowpass_mask[start:stop].all().item()
elif highpass is not None:
idx = int(highpass * waveform_duration)
assert len(transform._buffers) == 2
assert not (transform.mask[:idx]).any().item()
assert (transform.mask[idx:]).all().item()
assert not (transform.highpass_mask[:idx]).any().item()
assert (transform.highpass_mask[idx:]).all().item()
elif lowpass is not None:
idx = int(lowpass * waveform_duration)
assert len(transform._buffers) == 2
assert not (transform.lowpass_mask[idx:]).any().item()
assert (transform.lowpass_mask[:idx]).all().item()
else:
assert len(transform._buffers) == 1
assert transform.mask is None
assert transform.highpass_mask is None
assert transform.lowpass_mask is None
Loading