Skip to content

Commit

Permalink
DRY / fixes for method=stockwell
Browse files Browse the repository at this point in the history
  • Loading branch information
drammock committed Apr 14, 2023
1 parent 6a4f5cc commit 897a1e4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
26 changes: 15 additions & 11 deletions mne/time_frequency/_stockwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,20 @@ def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W):
return psd, itc


def _compute_freqs_st(fmin, fmax, n_fft, sfreq):
from scipy.fft import fftfreq
freqs = fftfreq(n_fft, 1. / sfreq)
if fmin is None:
fmin = freqs[freqs > 0][0]
if fmax is None:
fmax = freqs.max()

start_f = np.abs(freqs - fmin).argmin()
stop_f = np.abs(freqs - fmax).argmin()
freqs = freqs[start_f:stop_f]
return start_f, stop_f, freqs


@verbose
def tfr_array_stockwell(data, sfreq, fmin=None, fmax=None, n_fft=None,
width=1.0, decim=1, return_itc=False, n_jobs=None,
Expand Down Expand Up @@ -158,7 +172,6 @@ def tfr_array_stockwell(data, sfreq, fmin=None, fmax=None, n_fft=None,
----------
.. footbibliography::
"""
from scipy.fft import fftfreq
_validate_type(data, np.ndarray, 'data')
if data.ndim != 3:
raise ValueError(
Expand All @@ -167,16 +180,7 @@ def tfr_array_stockwell(data, sfreq, fmin=None, fmax=None, n_fft=None,
n_epochs, n_channels = data.shape[:2]
n_out = data.shape[2] // decim + bool(data.shape[-1] % decim)
data, n_fft_, zero_pad = _check_input_st(data, n_fft)

freqs = fftfreq(n_fft_, 1. / sfreq)
if fmin is None:
fmin = freqs[freqs > 0][0]
if fmax is None:
fmax = freqs.max()

start_f = np.abs(freqs - fmin).argmin()
stop_f = np.abs(freqs - fmax).argmin()
freqs = freqs[start_f:stop_f]
start_f, stop_f, freqs = _compute_freqs_st(fmin, fmax, n_fft_, sfreq)

W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width)
n_freq = stop_f - start_f
Expand Down
9 changes: 4 additions & 5 deletions mne/time_frequency/spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ..utils import (GetEpochsMixin, SizeMixin, TimeMixin,
_check_method_kwargs, _check_option, _time_mask, fill_doc,
object_diff, repr_html, warn)
from ._stockwell import _check_input_st, tfr_array_stockwell
from ._stockwell import _check_input_st, _compute_freqs_st, tfr_array_stockwell
from .multitaper import tfr_array_multitaper
from .spectrum import _get_instance_type_string
from .tfr import tfr_array_morlet
Expand All @@ -28,8 +28,6 @@ class BaseTFR(ContainsMixin, UpdateChannelsMixin, SizeMixin, TimeMixin):

def __init__(self, inst, method, freqs, tmin, tmax, picks,
proj, *, decim, n_jobs, verbose=None, **method_kw):
from scipy.fft import fftfreq

# arg checking
method = _check_option('method', method,
['morlet', 'multitaper', 'stockwell'])
Expand All @@ -52,8 +50,9 @@ def __init__(self, inst, method, freqs, tmin, tmax, picks,
f'or "auto", got {freqs}.')
method_kw.update(fmin=fmin, fmax=fmax)
# compute freqs
n_fft = method_kw.get('n_fft', _check_input_st(self._times, None))
self._freqs = fftfreq(n_fft, 1. / self._sfreq)
_, default_nfft, _ = _check_input_st(self._times, None)
n_fft = method_kw.get('n_fft', default_nfft)
*_, self._freqs = _compute_freqs_st(fmin, fmax, n_fft, self._sfreq)
else:
method_kw.update(freqs=freqs)
self._freqs = freqs
Expand Down

0 comments on commit 897a1e4

Please sign in to comment.