Skip to content

Commit

Permalink
WIP AverageTFR
Browse files Browse the repository at this point in the history
  • Loading branch information
drammock committed Jan 2, 2024
1 parent de575c6 commit 15f2ad3
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 2 deletions.
14 changes: 12 additions & 2 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,13 +1143,23 @@ def compute_tfr(
Notes
-----
.. versionadded:: 1.3
.. versionadded:: 1.6
References
----------
.. footbibliography::
"""
pass
return AverageTFR(
self,
method=method,
freqs=freqs,
picks=picks,
proj=proj,
decim=decim,
n_jobs=n_jobs,
verbose=verbose,
**method_kw,
)

@verbose
def plot_psd(
Expand Down
87 changes: 87 additions & 0 deletions mne/time_frequency/spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,3 +1094,90 @@ def average(self, method="mean", *, dim="epochs", copy=False):
"""
# TODO probably can reuse the existing method from EpochsTFR
pass


class AverageTFR(BaseTFR):
"""Data object for spectrotemporal representations of averaged data.
.. warning:: The preferred means of creating AverageTFR objects is via the instance
method :meth:`mne.Evoked.compute_tfr` or via
:meth:`mne.time_frequency.EpochsTFR.average`. Direct class
instantiation is not supported.
Parameters
----------
inst : instance of Evoked
The data from which to compute the time-frequency representation.
%(method_tfr)s
%(freqs_tfr)s
%(picks_good_data_noref)s
%(proj_psd)s
%(decim_tfr)s
%(n_jobs)s
%(verbose)s
%(method_kw_tfr)s
Attributes
----------
ch_names : list
The channel names.
freqs : array
Frequencies at which the amplitude, power, or fourier coefficients
have been computed.
%(info_not_none)s
method : str
The method used to compute the spectra (``'morlet'``, ``'multitaper'``
or ``'stockwell'``).
See Also
--------
RawTFR
EpochsTFR
mne.Evoked.compute_tfr
mne.time_frequency.EpochsTFR.average
References
----------
.. footbibliography::
"""

def __init__(
self,
inst,
method,
freqs,
*,
tmin=None,
tmax=None,
picks=None,
proj=False,
decim=1,
n_jobs=None,
verbose=None,
**method_kw,
):
from ..evoked import Evoked

# instance type check
msg = f"AverageTFR got instance type {type(inst).__name__}"
assert isinstance(inst, (Evoked, dict)), msg

super().__init__(
inst,
method,
freqs,
tmin=tmin,
tmax=tmax,
picks=picks,
proj=proj,
decim=decim,
n_jobs=n_jobs,
verbose=verbose,
**method_kw,
)

def _get_instance_data(self):
# prepend a singleton "epochs" axis
return self.inst.get_data(picks=self._picks)[np.newaxis, :, self._time_mask]

0 comments on commit 15f2ad3

Please sign in to comment.