Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SciPy filters #189

Merged
merged 19 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ml4gw/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .constants import *
100 changes: 100 additions & 0 deletions ml4gw/transforms/iirfilter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from typing import Union

import torch
from scipy.signal import iirfilter
from torchaudio.functional import filtfilt


class IIRFilter(torch.nn.Module):
r"""
IIR digital and analog filter design given order and critical points.
Design an Nth-order digital or analog filter and apply it to a signal.
Uses SciPy's `iirfilter` function to create the filter coefficients.
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirfilter.html # noqa E501

The forward call of this module accepts a batch tensor of shape
(n_waveforms, n_samples) and returns the filtered waveforms.

Args:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ravioli1369 no point in repeating verbatim the scipy documentation.

I would just point users to their documentation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be still good to add the documentation so that it at least appears on hover? Anyways it's mentioned that it's a wrapper for scipy

N:
The order of the filter.
Wn:
A scalar or length-2 sequence giving the critical frequencies.
For digital filters, Wn are in the same units as fs. By
default, fs is 2 half-cycles/sample, so these are normalized
from 0 to 1, where 1 is the Nyquist frequency. (Wn is thus in
half-cycles / sample). For analog filters, Wn is an angular
frequency (e.g., rad/s). When Wn is a length-2 sequence,`Wn[0]`
must be less than `Wn[1]`.
rp:
For Chebyshev and elliptic filters, provides the maximum ripple in
the passband. (dB)
rs:
For Chebyshev and elliptic filters, provides the minimum
attenuation in the stop band. (dB)
btype:
The type of filter. Default is 'bandpass'.
analog:
When True, return an analog filter, otherwise a digital filter
is returned.
ftype:
The type of IIR filter to design:

- Butterworth : 'butter'
- Chebyshev I : 'cheby1'
- Chebyshev II : 'cheby2'
- Cauer/elliptic: 'ellip'
- Bessel/Thomson: 'bessel's
fs:
The sampling frequency of the digital system.

Returns:
Filtered signal on the forward pass.
"""

def __init__(
self,
N: int,
Wn: Union[float, torch.Tensor],
rs: Union[None, float, torch.Tensor] = None,
rp: Union[None, float, torch.Tensor] = None,
btype="band",
analog=False,
ftype="butter",
fs=None,
) -> None:
super().__init__()

if isinstance(Wn, torch.Tensor):
Wn = Wn.numpy()
if isinstance(rs, torch.Tensor):
rs = rs.numpy()
if isinstance(rp, torch.Tensor):
rp = rp.numpy()

b, a = iirfilter(
N,
Wn,
rs=rs,
rp=rp,
btype=btype,
analog=analog,
ftype=ftype,
output="ba",
fs=fs,
)
self.register_buffer("b", torch.tensor(b))
self.register_buffer("a", torch.tensor(a))

def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""
Apply the filter to the input signal.

Args:
x:
The input signal to be filtered.

Returns:
The filtered signal.
"""
return filtfilt(x, self.a, self.b, clamp=False)
Loading
Loading