Skip to content

Commit

Permalink
add get_data() method
Browse files Browse the repository at this point in the history
  • Loading branch information
drammock committed Apr 14, 2023
1 parent 50f1527 commit 6a4f5cc
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion mne/time_frequency/spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,37 @@ def get_data(self, picks=None, exclude='bads', fmin=0, fmax=np.inf,
The frequency values for the requested data range. Only returned if
``return_freqs`` is ``True``.
"""
pass
tmin = self.times[0] if tmin is None else tmin
tmax = self.times[-1] if tmax is None else tmax
picks = _picks_to_idx(self.info, picks, 'data_or_ica', exclude=exclude,
with_ref_meg=False)
fmin_idx = np.searchsorted(self.freqs, fmin)
fmax_idx = np.searchsorted(self.freqs, fmax, side='right')
tmin_idx = np.searchsorted(self.times, tmin)
tmax_idx = np.searchsorted(self.times, tmax, side='right')
freq_picks = np.arange(fmin_idx, fmax_idx)
time_picks = np.arange(tmin_idx, tmax_idx)
freq_axis = self._dims.index('freq')
time_axis = self._dims.index('time')
chan_axis = self._dims.index('channel')
# normally there's a risk of np.take reducing array dimension if there
# were only one channel or frequency selected, but `_picks_to_idx`
# and np.arange both always return arrays, so we're safe; the result
# will always be 3D.
data = (self._data.take(picks, chan_axis)
.take(freq_picks, freq_axis)
.take(time_picks, time_axis)
)
out = (data,)
if return_times:
times = self._times[tmin_idx:tmax_idx]
out.append(times)
if return_freqs:
freqs = self._freqs[fmin_idx:fmax_idx]
out.append(freqs)
if not return_times and not return_freqs:
return out[0]
return out


class RawTFR(BaseTFR):
Expand Down

0 comments on commit 6a4f5cc

Please sign in to comment.