diff --git a/doc/api/time_frequency.rst b/doc/api/time_frequency.rst index f8948909491..8923920bdba 100644 --- a/doc/api/time_frequency.rst +++ b/doc/api/time_frequency.rst @@ -14,7 +14,12 @@ Time-Frequency :toctree: ../generated/ AverageTFR + AverageTFRArray + BaseTFR EpochsTFR + EpochsTFRArray + RawTFR + RawTFRArray CrossSpectralDensity Spectrum SpectrumArray diff --git a/doc/changes/devel/11282.apichange.rst b/doc/changes/devel/11282.apichange.rst new file mode 100644 index 00000000000..9112db897cf --- /dev/null +++ b/doc/changes/devel/11282.apichange.rst @@ -0,0 +1 @@ +The default value of the ``zero_mean`` parameter of :func:`mne.time_frequency.tfr_array_morlet` will change from ``False`` to ``True`` in version 1.8, for consistency with related functions. By `Daniel McCloy`_. diff --git a/doc/changes/devel/11282.bugfix.rst b/doc/changes/devel/11282.bugfix.rst new file mode 100644 index 00000000000..72e6e73a42a --- /dev/null +++ b/doc/changes/devel/11282.bugfix.rst @@ -0,0 +1 @@ +Fixes to interactivity in time-frequency objects: the rectangle selector now works on TFR image plots of gradiometer data; and in ``TFR.plot_joint()`` plots, the colormap limits of interactively-generated topomaps match the colormap limits of the main plot. By `Daniel McCloy`_. \ No newline at end of file diff --git a/doc/changes/devel/11282.newfeature.rst b/doc/changes/devel/11282.newfeature.rst new file mode 100644 index 00000000000..5c19d68f351 --- /dev/null +++ b/doc/changes/devel/11282.newfeature.rst @@ -0,0 +1 @@ +New class :class:`mne.time_frequency.RawTFR` and new methods :meth:`mne.io.Raw.compute_tfr`, :meth:`mne.Epochs.compute_tfr`, and :meth:`mne.Evoked.compute_tfr`. These new methods supersede functions :func:`mne.time_frequency.tfr_morlet`, and :func:`mne.time_frequency.tfr_multitaper`, and :func:`mne.time_frequency.tfr_stockwell`, which are now considered "legacy" functions. By `Daniel McCloy`_. 
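Note: a minimal sketch of the unified API that the changelog entries above describe, assuming already-loaded `raw`, `epochs`, and `evoked` objects (variable names and parameter values are illustrative only); the superseded legacy call is shown for comparison:

    import numpy as np
    from mne.time_frequency import tfr_morlet

    freqs = np.arange(5.0, 40.0, 2.0)

    # new unified entry points introduced by this changeset
    raw_tfr = raw.compute_tfr("morlet", freqs=freqs, n_cycles=freqs / 2.0)
    epochs_tfr = epochs.compute_tfr("multitaper", freqs=freqs, n_cycles=freqs / 2.0)
    evoked_tfr = evoked.compute_tfr("morlet", freqs=freqs, n_cycles=freqs / 2.0)

    # legacy function, roughly equivalent to epochs.compute_tfr("morlet", ..., average=True)
    power = tfr_morlet(epochs, freqs=freqs, n_cycles=freqs / 2.0, return_itc=False)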
\ No newline at end of file diff --git a/doc/conf.py b/doc/conf.py index b2dbe387f27..cc2a25d7089 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -231,7 +231,11 @@ "EvokedArray": "mne.EvokedArray", "BiHemiLabel": "mne.BiHemiLabel", "AverageTFR": "mne.time_frequency.AverageTFR", + "AverageTFRArray": "mne.time_frequency.AverageTFRArray", "EpochsTFR": "mne.time_frequency.EpochsTFR", + "EpochsTFRArray": "mne.time_frequency.EpochsTFRArray", + "RawTFR": "mne.time_frequency.RawTFR", + "RawTFRArray": "mne.time_frequency.RawTFRArray", "Raw": "mne.io.Raw", "ICA": "mne.preprocessing.ICA", "Covariance": "mne.Covariance", diff --git a/examples/decoding/decoding_csp_timefreq.py b/examples/decoding/decoding_csp_timefreq.py index 6f13175846e..c389645d668 100644 --- a/examples/decoding/decoding_csp_timefreq.py +++ b/examples/decoding/decoding_csp_timefreq.py @@ -32,7 +32,7 @@ from mne.datasets import eegbci from mne.decoding import CSP from mne.io import concatenate_raws, read_raw_edf -from mne.time_frequency import AverageTFR +from mne.time_frequency import AverageTFRArray # %% # Set parameters and read data @@ -173,13 +173,15 @@ # Plot time-frequency results # Set up time frequency object -av_tfr = AverageTFR( - create_info(["freq"], sfreq), - tf_scores[np.newaxis, :], - centered_w_times, - freqs[1:], - 1, +av_tfr = AverageTFRArray( + info=create_info(["freq"], sfreq), + data=tf_scores[np.newaxis, :], + times=centered_w_times, + freqs=freqs[1:], + nave=1, ) chance = np.mean(y) # set chance level to white in the plot -av_tfr.plot([0], vmin=chance, title="Time-Frequency Decoding Scores", cmap=plt.cm.Reds) +av_tfr.plot( + [0], vlim=(chance, None), title="Time-Frequency Decoding Scores", cmap=plt.cm.Reds +) diff --git a/examples/inverse/dics_epochs.py b/examples/inverse/dics_epochs.py index d480b13f8a4..c359c30c0fb 100644 --- a/examples/inverse/dics_epochs.py +++ b/examples/inverse/dics_epochs.py @@ -22,7 +22,7 @@ import mne from mne.beamformer import apply_dics_tfr_epochs, make_dics from mne.datasets import somato -from mne.time_frequency import csd_tfr, tfr_morlet +from mne.time_frequency import csd_tfr print(__doc__) @@ -67,8 +67,8 @@ # decomposition for each epoch. We must pass ``output='complex'`` if we wish to # use this TFR later with a DICS beamformer. We also pass ``average=False`` to # compute the TFR for each individual epoch. -epochs_tfr = tfr_morlet( - epochs, freqs, n_cycles=5, return_itc=False, output="complex", average=False +epochs_tfr = epochs.compute_tfr( + "morlet", freqs, n_cycles=5, return_itc=False, output="complex", average=False ) # crop either side to use a buffer to remove edge artifact diff --git a/examples/time_frequency/time_frequency_erds.py b/examples/time_frequency/time_frequency_erds.py index 556730b6cab..1d805121739 100644 --- a/examples/time_frequency/time_frequency_erds.py +++ b/examples/time_frequency/time_frequency_erds.py @@ -45,7 +45,6 @@ from mne.datasets import eegbci from mne.io import concatenate_raws, read_raw_edf from mne.stats import permutation_cluster_1samp_test as pcluster_test -from mne.time_frequency import tfr_multitaper # %% # First, we load and preprocess the data. We use runs 6, 10, and 14 from @@ -96,8 +95,8 @@ # %% # Finally, we perform time/frequency decomposition over all epochs. 
-tfr = tfr_multitaper( - epochs, +tfr = epochs.compute_tfr( + method="multitaper", freqs=freqs, n_cycles=freqs, use_fft=True, diff --git a/examples/time_frequency/time_frequency_simulated.py b/examples/time_frequency/time_frequency_simulated.py index 85cc9a1f436..dc42f16da3a 100644 --- a/examples/time_frequency/time_frequency_simulated.py +++ b/examples/time_frequency/time_frequency_simulated.py @@ -25,16 +25,8 @@ from matplotlib import pyplot as plt from mne import Epochs, create_info -from mne.baseline import rescale from mne.io import RawArray -from mne.time_frequency import ( - AverageTFR, - tfr_array_morlet, - tfr_morlet, - tfr_multitaper, - tfr_stockwell, -) -from mne.viz import centers_to_edges +from mne.time_frequency import AverageTFRArray, EpochsTFRArray, tfr_array_morlet print(__doc__) @@ -112,12 +104,13 @@ "Sim: Less time smoothing,\nmore frequency smoothing", ], ): - power = tfr_multitaper( - epochs, + power = epochs.compute_tfr( + method="multitaper", freqs=freqs, n_cycles=n_cycles, time_bandwidth=time_bandwidth, return_itc=False, + average=True, ) ax.set_title(title) # Plot results. Baseline correct based on first 100 ms. @@ -125,8 +118,7 @@ [0], baseline=(0.0, 0.1), mode="mean", - vmin=vmin, - vmax=vmax, + vlim=(vmin, vmax), axes=ax, show=False, colorbar=False, @@ -146,7 +138,7 @@ fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained") fmin, fmax = freqs[[0, -1]] for width, ax in zip((0.2, 0.7, 3.0), axs): - power = tfr_stockwell(epochs, fmin=fmin, fmax=fmax, width=width) + power = epochs.compute_tfr(method="stockwell", freqs=(fmin, fmax), width=width) power.plot( [0], baseline=(0.0, 0.1), mode="mean", axes=ax, show=False, colorbar=False ) @@ -164,13 +156,14 @@ fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained") all_n_cycles = [1, 3, freqs / 2.0] for n_cycles, ax in zip(all_n_cycles, axs): - power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False) + power = epochs.compute_tfr( + method="morlet", freqs=freqs, n_cycles=n_cycles, return_itc=False, average=True + ) power.plot( [0], baseline=(0.0, 0.1), mode="mean", - vmin=vmin, - vmax=vmax, + vlim=(vmin, vmax), axes=ax, show=False, colorbar=False, @@ -190,7 +183,9 @@ fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained") bandwidths = [1.0, 2.0, 4.0] for bandwidth, ax in zip(bandwidths, axs): - data = np.zeros((len(ch_names), freqs.size, epochs.times.size), dtype=complex) + data = np.zeros( + (len(epochs), len(ch_names), freqs.size, epochs.times.size), dtype=complex + ) for idx, freq in enumerate(freqs): # Filter raw data and re-epoch to avoid the filter being longer than # the epoch data for low frequencies and short epochs, such as here. @@ -210,17 +205,13 @@ epochs_hilb = Epochs( raw_filter, events, tmin=0, tmax=n_times / sfreq, baseline=(0, 0.1) ) - tfr_data = epochs_hilb.get_data() - tfr_data = tfr_data * tfr_data.conj() # compute power - tfr_data = np.mean(tfr_data, axis=0) # average over epochs - data[:, idx] = tfr_data - power = AverageTFR(info, data, epochs.times, freqs, nave=n_epochs) - power.plot( + data[:, :, idx] = epochs_hilb.get_data() + power = EpochsTFRArray(epochs.info, data, epochs.times, freqs, method="hilbert") + power.average().plot( [0], baseline=(0.0, 0.1), mode="mean", - vmin=-0.1, - vmax=0.1, + vlim=(0, 0.1), axes=ax, show=False, colorbar=False, @@ -241,8 +232,8 @@ # :class:`mne.time_frequency.EpochsTFR` is returned. 
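Note: the example-script edits above all route through the single `Epochs.compute_tfr` entry point; a short sketch of how each method's options map onto it, assuming an existing `epochs` object (parameter values are illustrative):

    import numpy as np

    freqs = np.arange(5.0, 40.0, 2.0)

    # multitaper: frequency smoothing via time_bandwidth, optionally averaged over epochs
    power_mt = epochs.compute_tfr(
        "multitaper", freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, average=True
    )
    # stockwell: freqs is a (fmin, fmax) pair and the result is always averaged
    power_st = epochs.compute_tfr("stockwell", freqs=(freqs[0], freqs[-1]), width=1.0)
    # morlet with the default average=False yields an EpochsTFR (one TFR per epoch)
    power_ep = epochs.compute_tfr("morlet", freqs=freqs, n_cycles=freqs / 2.0)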
n_cycles = freqs / 2.0 -power = tfr_morlet( - epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False, average=False +power = epochs.compute_tfr( + method="morlet", freqs=freqs, n_cycles=n_cycles, return_itc=False, average=False ) print(type(power)) avgpower = power.average() @@ -250,8 +241,7 @@ [0], baseline=(0.0, 0.1), mode="mean", - vmin=vmin, - vmax=vmax, + vlim=(vmin, vmax), title="Using Morlet wavelets and EpochsTFR", show=False, ) @@ -260,10 +250,12 @@ # Operating on arrays # ------------------- # -# MNE also has versions of the functions above which operate on numpy arrays -# instead of MNE objects. They expect inputs of the shape -# ``(n_epochs, n_channels, n_times)``. They will also return a numpy array -# of shape ``(n_epochs, n_channels, n_freqs, n_times)``. +# MNE-Python also has functions that operate on :class:`NumPy arrays ` +# instead of MNE-Python objects. These are :func:`~mne.time_frequency.tfr_array_morlet` +# and :func:`~mne.time_frequency.tfr_array_multitaper`. They expect inputs of the shape +# ``(n_epochs, n_channels, n_times)`` and return an array of shape +# ``(n_epochs, n_channels, n_freqs, n_times)`` (or optionally, can collapse the epochs +# dimension if you want average power or inter-trial coherence; see ``output`` param). power = tfr_array_morlet( epochs.get_data(), @@ -271,12 +263,16 @@ freqs=freqs, n_cycles=n_cycles, output="avg_power", + zero_mean=False, +) +# Put it into a TFR container for easy plotting +tfr = AverageTFRArray( + info=epochs.info, data=power, times=epochs.times, freqs=freqs, nave=len(epochs) +) +tfr.plot( + baseline=(0.0, 0.1), + picks=[0], + mode="mean", + vlim=(vmin, vmax), + title="TFR calculated on a NumPy array", ) -# Baseline the output -rescale(power, epochs.times, (0.0, 0.1), mode="mean", copy=False) -fig, ax = plt.subplots(layout="constrained") -x, y = centers_to_edges(epochs.times * 1000, freqs) -mesh = ax.pcolormesh(x, y, power[0], cmap="RdBu_r", vmin=vmin, vmax=vmax) -ax.set_title("TFR calculated on a numpy array") -ax.set(ylim=freqs[[0, -1]], xlabel="Time (ms)") -fig.colorbar(mesh) diff --git a/mne/beamformer/tests/test_dics.py b/mne/beamformer/tests/test_dics.py index 1daaaf17eb0..bcde4503307 100644 --- a/mne/beamformer/tests/test_dics.py +++ b/mne/beamformer/tests/test_dics.py @@ -30,7 +30,7 @@ from mne.io import read_info from mne.proj import compute_proj_evoked, make_projector from mne.surface import _compute_nearest -from mne.time_frequency import CrossSpectralDensity, EpochsTFR, csd_morlet, csd_tfr +from mne.time_frequency import CrossSpectralDensity, EpochsTFRArray, csd_morlet, csd_tfr from mne.time_frequency.csd import _sym_mat_to_vector from mne.transforms import apply_trans, invert_transform from mne.utils import catch_logging, object_diff @@ -727,7 +727,7 @@ def test_apply_dics_tfr(return_generator): data = rng.random((n_epochs, n_chans, len(freqs), n_times)) data *= 1e-6 data = data + data * 1j # add imag. component to simulate phase - epochs_tfr = EpochsTFR(info, data, times=times, freqs=freqs) + epochs_tfr = EpochsTFRArray(info=info, data=data, times=times, freqs=freqs) # Create a DICS beamformer and convert the EpochsTFR to source space. 
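Note: besides the computed containers, the diff moves array-based construction to the keyword-only `*Array` classes (`AverageTFRArray`, `EpochsTFRArray`, `RawTFRArray`); a sketch of building containers from plain NumPy arrays, with made-up shapes and channel names:

    import numpy as np
    import mne
    from mne.time_frequency import AverageTFRArray, EpochsTFRArray

    rng = np.random.default_rng(0)
    info = mne.create_info(["EEG 001", "EEG 002"], sfreq=250.0, ch_types="eeg")
    freqs = np.linspace(8.0, 30.0, num=12)
    times = np.arange(100) / 250.0

    # EpochsTFRArray expects data shaped (n_epochs, n_channels, n_freqs, n_times)
    epo_data = rng.random((5, len(info["ch_names"]), freqs.size, times.size))
    epochs_tfr = EpochsTFRArray(info=info, data=epo_data, times=times, freqs=freqs)

    # AverageTFRArray expects (n_channels, n_freqs, n_times) plus the number of averaged trials
    average_tfr = AverageTFRArray(
        info=info, data=epo_data.mean(axis=0), times=times, freqs=freqs, nave=5
    )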
csd = csd_tfr(epochs_tfr) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 0d0af8279cb..341e355f363 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -153,7 +153,7 @@ def equalize_channels(instances, copy=True, verbose=None): from ..evoked import Evoked from ..forward import Forward from ..io import BaseRaw - from ..time_frequency import CrossSpectralDensity, _BaseTFR + from ..time_frequency import BaseTFR, CrossSpectralDensity # Instances need to have a `ch_names` attribute and a `pick_channels` # method that supports `ordered=True`. @@ -161,7 +161,7 @@ def equalize_channels(instances, copy=True, verbose=None): BaseRaw, BaseEpochs, Evoked, - _BaseTFR, + BaseTFR, Forward, Covariance, CrossSpectralDensity, @@ -607,8 +607,6 @@ def drop_channels(self, ch_names, on_missing="raise"): def _pick_drop_channels(self, idx, *, verbose=None): # avoid circular imports from ..io import BaseRaw - from ..time_frequency import AverageTFR, EpochsTFR - from ..time_frequency.spectrum import BaseSpectrum msg = "adding, dropping, or reordering channels" if isinstance(self, BaseRaw): @@ -633,10 +631,8 @@ def _pick_drop_channels(self, idx, *, verbose=None): if mat is not None: setattr(self, key, mat[idx][:, idx]) - if isinstance(self, BaseSpectrum): + if hasattr(self, "_dims"): # Spectrum and "new-style" TFRs axis = self._dims.index("channel") - elif isinstance(self, (AverageTFR, EpochsTFR)): - axis = -3 else: # All others (Evoked, Epochs, Raw) have chs axis=-2 axis = -2 if hasattr(self, "_data"): # skip non-preloaded Raw diff --git a/mne/conftest.py b/mne/conftest.py index 2d153f92f40..7dd02366ace 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -397,6 +397,34 @@ def epochs_spectrum(): return _get_epochs().load_data().compute_psd() +@pytest.fixture() +def epochs_tfr(): + """Get an EpochsTFR computed from mne.io.tests.data.""" + epochs = _get_epochs().load_data() + return epochs.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5)) + + +@pytest.fixture() +def average_tfr(epochs_tfr): + """Get an AverageTFR computed by averaging an EpochsTFR (this is small & fast).""" + return epochs_tfr.average() + + +@pytest.fixture() +def full_average_tfr(full_evoked): + """Get an AverageTFR computed from Evoked. + + This is slower than the `average_tfr` fixture, but a few TFR.plot_* tests need it. 
+ """ + return full_evoked.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5)) + + +@pytest.fixture() +def raw_tfr(raw): + """Get a RawTFR computed from mne.io.tests.data.""" + return raw.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5)) + + @pytest.fixture() def epochs_empty(): """Get empty epochs from mne.io.tests.data.""" @@ -408,22 +436,31 @@ def epochs_empty(): @pytest.fixture(scope="session", params=[testing._pytest_param()]) -def _evoked(): - # This one is session scoped, so be sure not to modify it (use evoked - # instead) - evoked = mne.read_evokeds( - fname_evoked, condition="Left Auditory", baseline=(None, 0) - ) - evoked.crop(0, 0.2) - return evoked +def _full_evoked(): + # This is session scoped, so be sure not to modify its return value (use + # `full_evoked` fixture instead) + return mne.read_evokeds(fname_evoked, condition="Left Auditory", baseline=(None, 0)) + + +@pytest.fixture(scope="session", params=[testing._pytest_param()]) +def _evoked(_full_evoked): + # This is session scoped, so be sure not to modify its return value (use `evoked` + # fixture instead) + return _full_evoked.copy().crop(0, 0.2) @pytest.fixture() def evoked(_evoked): - """Get evoked data.""" + """Get truncated evoked data.""" return _evoked.copy() +@pytest.fixture() +def full_evoked(_full_evoked): + """Get full-duration evoked data (needed for, e.g., testing TFR).""" + return _full_evoked.copy() + + @pytest.fixture(scope="function", params=[testing._pytest_param()]) def noise_cov(): """Get a noise cov from the testing dataset.""" diff --git a/mne/decoding/time_frequency.py b/mne/decoding/time_frequency.py index bd0076d0355..0555d190ddd 100644 --- a/mne/decoding/time_frequency.py +++ b/mne/decoding/time_frequency.py @@ -150,17 +150,17 @@ def transform(self, X): # Compute time-frequency Xt = _compute_tfr( X, - self.freqs, - self.sfreq, - self.method, - self.n_cycles, - True, - self.time_bandwidth, - self.use_fft, - self.decim, - self.output, - self.n_jobs, - self.verbose, + freqs=self.freqs, + sfreq=self.sfreq, + method=self.method, + n_cycles=self.n_cycles, + zero_mean=True, + time_bandwidth=self.time_bandwidth, + use_fft=self.use_fft, + decim=self.decim, + output=self.output, + n_jobs=self.n_jobs, + verbose=self.verbose, ) # Back to original shape diff --git a/mne/epochs.py b/mne/epochs.py index 14a0092c07a..9e48936f8bf 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -75,7 +75,7 @@ from .html_templates import _get_html_template from .parallel import parallel_func from .time_frequency.spectrum import EpochsSpectrum, SpectrumMixin, _validate_method -from .time_frequency.tfr import EpochsTFR +from .time_frequency.tfr import AverageTFR, EpochsTFR from .utils import ( ExtendedTimeMixin, GetEpochsMixin, @@ -419,6 +419,8 @@ class BaseEpochs( filename : str | None The filename (if the epochs are read from disk). %(metadata_epochs)s + + .. versionadded:: 0.16 %(event_repeated_epochs)s %(raw_sfreq)s annotations : instance of mne.Annotations | None @@ -2560,6 +2562,139 @@ def compute_psd( **method_kw, ) + @verbose + def compute_tfr( + self, + method, + freqs, + *, + tmin=None, + tmax=None, + picks=None, + proj=False, + output="power", + average=False, + return_itc=False, + decim=1, + n_jobs=None, + verbose=None, + **method_kw, + ): + """Compute a time-frequency representation of epoched data. 
+ + Parameters + ---------- + %(method_tfr_epochs)s + %(freqs_tfr_epochs)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(output_compute_tfr)s + average : bool + Whether to return average power across epochs (instead of single-trial + power). ``average=True`` is not compatible with ``output="complex"`` or + ``output="phase"``. Ignored if ``method="stockwell"`` (Stockwell method + *requires* averaging). Default is ``False``. + return_itc : bool + Whether to return inter-trial coherence (ITC) as well as power estimates. + If ``True`` then must specify ``average=True`` (or ``method="stockwell", + average="auto"``). Default is ``False``. + %(decim_tfr)s + %(n_jobs)s + %(verbose)s + %(method_kw_epochs_tfr)s + + Returns + ------- + tfr : instance of EpochsTFR or AverageTFR + The time-frequency-resolved power estimates. + itc : instance of AverageTFR + The inter-trial coherence (ITC). Only returned if ``return_itc=True``. + + Notes + ----- + If ``average=True`` (or ``method="stockwell", average="auto"``) the result will + be an :class:`~mne.time_frequency.AverageTFR` instead of an + :class:`~mne.time_frequency.EpochsTFR`. + + .. versionadded:: 1.7 + + References + ---------- + .. footbibliography:: + """ + if method == "stockwell" and not average: # stockwell method *must* average + logger.info( + 'Requested `method="stockwell"` so ignoring parameter `average=False`.' + ) + average = True + if average: + # augment `output` value for use by tfr_array_* functions + _check_option("output", output, ("power",), extra=" when average=True") + method_kw["output"] = "avg_power_itc" if return_itc else "avg_power" + else: + msg = ( + "compute_tfr() got incompatible parameters `average=False` and `{}` " + "({} requires averaging over epochs)." + ) + if return_itc: + raise ValueError(msg.format("return_itc=True", "computing ITC")) + if method == "stockwell": + raise ValueError(msg.format('method="stockwell"', "Stockwell method")) + # `average` and `return_itc` both False, so "phase" and "complex" are OK + _check_option("output", output, ("power", "phase", "complex")) + method_kw["output"] = output + + if method == "stockwell": + method_kw["return_itc"] = return_itc + method_kw.pop("output") + if isinstance(freqs, str): + _check_option("freqs", freqs, "auto") + else: + _validate_type(freqs, "array-like") + _check_option( + "freqs", np.array(freqs).shape, ((2,),), extra=" (wrong shape)." + ) + if average: + out = AverageTFR( + inst=self, + method=method, + freqs=freqs, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + decim=decim, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) + # tfr_array_stockwell always returns ITC (but sometimes it's None) + if hasattr(out, "_itc"): + if out._itc is not None: + state = out.__getstate__() + state["data"] = out._itc + state["data_type"] = "Inter-trial coherence" + itc = AverageTFR(inst=state) + del out._itc + return out, itc + del out._itc + return out + # now handle average=False + return EpochsTFR( + inst=self, + method=method, + freqs=freqs, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + decim=decim, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) + @verbose def plot_psd( self, @@ -3303,20 +3438,18 @@ class Epochs(BaseEpochs): %(on_missing_epochs)s %(reject_by_annotation_epochs)s %(metadata_epochs)s + + .. versionadded:: 0.16 %(event_repeated_epochs)s %(verbose)s Attributes ---------- %(info_not_none)s - event_id : dict - Names of conditions corresponding to event_ids. 
+ %(event_id_attr)s ch_names : list of string List of channel names. - selection : array - List of indices of selected events (not dropped or ignored etc.). For - example, if the original event array had 4 events and the second event - has been dropped, this attribute would be np.array([0, 2, 3]). + %(selection_attr)s preload : bool Indicates whether epochs are in memory. drop_log : tuple of tuple @@ -3535,6 +3668,8 @@ class EpochsArray(BaseEpochs): %(proj_epochs)s %(on_missing_epochs)s %(metadata_epochs)s + + .. versionadded:: 0.16 %(selection)s %(drop_log)s diff --git a/mne/evoked.py b/mne/evoked.py index f6f752cadbf..2e36f47f81b 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -48,6 +48,7 @@ from .html_templates import _get_html_template from .parallel import parallel_func from .time_frequency.spectrum import Spectrum, SpectrumMixin, _validate_method +from .time_frequency.tfr import AverageTFR from .utils import ( ExtendedTimeMixin, SizeMixin, @@ -1168,6 +1169,66 @@ def compute_psd( **method_kw, ) + @verbose + def compute_tfr( + self, + method, + freqs, + *, + tmin=None, + tmax=None, + picks=None, + proj=False, + output="power", + decim=1, + n_jobs=None, + verbose=None, + **method_kw, + ): + """Compute a time-frequency representation of evoked data. + + Parameters + ---------- + %(method_tfr)s + %(freqs_tfr)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(output_compute_tfr)s + %(decim_tfr)s + %(n_jobs)s + %(verbose)s + %(method_kw_tfr)s + + Returns + ------- + tfr : instance of AverageTFR + The time-frequency-resolved power estimates of the data. + + Notes + ----- + .. versionadded:: 1.7 + + References + ---------- + .. footbibliography:: + """ + _check_option("output", output, ("power", "phase", "complex")) + method_kw["output"] = output + return AverageTFR( + inst=self, + method=method, + freqs=freqs, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + decim=decim, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) + @verbose def plot_psd( self, diff --git a/mne/html_templates/repr/tfr.html.jinja b/mne/html_templates/repr/tfr.html.jinja new file mode 100644 index 00000000000..f6ab107ab0b --- /dev/null +++ b/mne/html_templates/repr/tfr.html.jinja @@ -0,0 +1,60 @@ + + + + + + {%- for unit in units %} + + {%- if loop.index == 1 %} + + {%- endif %} + + + {%- endfor %} + + + + + {%- if inst_type == "Epochs" %} + + + + + {% endif -%} + {%- if inst_type == "Evoked" %} + + + + + {% endif -%} + + + + + + + + + {% if "taper" in tfr._dims %} + + + + + {% endif %} + + + + + + + + + + + + + + + + +
+      <tr><th>Data type</th><td>{{ tfr._data_type }}</td></tr>
+      <tr><th>Units</th><td>{{ unit }}</td></tr>
+      <tr><th>Data source</th><td>{{ inst_type }}</td></tr>
+      <tr><th>Number of epochs</th><td>{{ tfr.shape[0] }}</td></tr>
+      <tr><th>Number of averaged trials</th><td>{{ nave }}</td></tr>
+      <tr><th>Dims</th><td>{{ tfr._dims | join(", ") }}</td></tr>
+      <tr><th>Estimation method</th><td>{{ tfr.method }}</td></tr>
+      <tr><th>Number of tapers</th><td>{{ tfr._mt_weights.size }}</td></tr>
+      <tr><th>Number of channels</th><td>{{ tfr.ch_names|length }}</td></tr>
+      <tr><th>Number of timepoints</th><td>{{ tfr.times|length }}</td></tr>
+      <tr><th>Number of frequency bins</th><td>{{ tfr.freqs|length }}</td></tr>
+      <tr><th>Frequency range</th><td>{{ '%.2f'|format(tfr.freqs[0]) }} – {{ '%.2f'|format(tfr.freqs[-1]) }} Hz</td></tr>
+    </table>
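Note: the `Epochs.compute_tfr` implementation added earlier in this diff ties `average`, `return_itc`, and the Stockwell method together; a sketch of the accepted combinations, assuming an existing `epochs` object:

    import numpy as np

    freqs = np.linspace(20.0, 40.0, num=5)

    # averaged power plus inter-trial coherence -> a (power, itc) pair of AverageTFR
    power, itc = epochs.compute_tfr("morlet", freqs=freqs, average=True, return_itc=True)

    # default average=False -> single-trial EpochsTFR; adding return_itc=True here raises ValueError
    power_per_epoch = epochs.compute_tfr("morlet", freqs=freqs)

    # Stockwell always averages, so average=False is ignored (with a log message); freqs="auto" is allowed
    power_stockwell = epochs.compute_tfr("stockwell", freqs="auto", average=False)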
diff --git a/mne/io/base.py b/mne/io/base.py index ed909e5658f..c7fb5e4ddd0 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -82,6 +82,7 @@ from ..html_templates import _get_html_template from ..parallel import parallel_func from ..time_frequency.spectrum import Spectrum, SpectrumMixin, _validate_method +from ..time_frequency.tfr import RawTFR from ..utils import ( SizeMixin, TimeMixin, @@ -2241,6 +2242,69 @@ def compute_psd( **method_kw, ) + @verbose + def compute_tfr( + self, + method, + freqs, + *, + tmin=None, + tmax=None, + picks=None, + proj=False, + output="power", + reject_by_annotation=True, + decim=1, + n_jobs=None, + verbose=None, + **method_kw, + ): + """Compute a time-frequency representation of sensor data. + + Parameters + ---------- + %(method_tfr)s + %(freqs_tfr)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(output_compute_tfr)s + %(reject_by_annotation_tfr)s + %(decim_tfr)s + %(n_jobs)s + %(verbose)s + %(method_kw_tfr)s + + Returns + ------- + tfr : instance of RawTFR + The time-frequency-resolved power estimates of the data. + + Notes + ----- + .. versionadded:: 1.7 + + References + ---------- + .. footbibliography:: + """ + _check_option("output", output, ("power", "phase", "complex")) + method_kw["output"] = output + return RawTFR( + self, + method=method, + freqs=freqs, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + reject_by_annotation=reject_by_annotation, + decim=decim, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) + @verbose def to_data_frame( self, diff --git a/mne/minimum_norm/tests/test_inverse.py b/mne/minimum_norm/tests/test_inverse.py index a620fdbbf29..e3be18a3fc9 100644 --- a/mne/minimum_norm/tests/test_inverse.py +++ b/mne/minimum_norm/tests/test_inverse.py @@ -55,7 +55,7 @@ from mne.source_estimate import VolSourceEstimate, read_source_estimate from mne.source_space._source_space import _get_src_nn from mne.surface import _normal_orth -from mne.time_frequency import EpochsTFR +from mne.time_frequency import EpochsTFRArray from mne.utils import _record_warnings, catch_logging test_path = testing.data_path(download=False) @@ -1375,11 +1375,11 @@ def test_apply_inverse_tfr(return_generator): times = np.arange(sfreq) / sfreq # make epochs 1s long data = rng.random((n_epochs, len(info.ch_names), freqs.size, times.size)) data = data + 1j * data # make complex to simulate amplitude + phase - epochs_tfr = EpochsTFR(info, data, times=times, freqs=freqs) + epochs_tfr = EpochsTFRArray(info=info, data=data, times=times, freqs=freqs) epochs_tfr.apply_baseline((0, 0.5)) pick_ori = "vector" - with pytest.raises(ValueError, match="Expected 2 inverse operators, " "got 3"): + with pytest.raises(ValueError, match="Expected 2 inverse operators, got 3"): apply_inverse_tfr_epochs(epochs_tfr, [inverse_operator] * 3, lambda2) # test epochs diff --git a/mne/time_frequency/__init__.pyi b/mne/time_frequency/__init__.pyi index 9fc0c271cc4..0faeb7263d8 100644 --- a/mne/time_frequency/__init__.pyi +++ b/mne/time_frequency/__init__.pyi @@ -1,12 +1,16 @@ __all__ = [ "AverageTFR", + "AverageTFRArray", + "BaseTFR", "CrossSpectralDensity", "EpochsSpectrum", "EpochsSpectrumArray", "EpochsTFR", + "EpochsTFRArray", + "RawTFR", + "RawTFRArray", "Spectrum", "SpectrumArray", - "_BaseTFR", "csd_array_fourier", "csd_array_morlet", "csd_array_multitaper", @@ -61,8 +65,12 @@ from .spectrum import ( ) from .tfr import ( AverageTFR, + AverageTFRArray, + BaseTFR, EpochsTFR, - _BaseTFR, + EpochsTFRArray, + RawTFR, + RawTFRArray, fwhm, morlet, read_tfrs, diff 
--git a/mne/time_frequency/_stockwell.py b/mne/time_frequency/_stockwell.py index d1108f8057b..08acf28b357 100644 --- a/mne/time_frequency/_stockwell.py +++ b/mne/time_frequency/_stockwell.py @@ -12,8 +12,8 @@ from .._fiff.pick import _pick_data_channels, pick_info from ..parallel import parallel_func -from ..utils import _validate_type, fill_doc, logger, verbose -from .tfr import AverageTFR, _get_data +from ..utils import _validate_type, legacy, logger, verbose +from .tfr import AverageTFRArray, _ensure_slice, _get_data def _check_input_st(x_in, n_fft): @@ -81,9 +81,10 @@ def _st(x, start_f, windows): def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W): """Aux function.""" + decim = _ensure_slice(decim) n_samp = x.shape[-1] - n_out = n_samp - zero_pad - n_out = n_out // decim + bool(n_out % decim) + decim_indices = decim.indices(n_samp - zero_pad) + n_out = len(range(*decim_indices)) psd = np.empty((len(W), n_out)) itc = np.empty_like(psd) if compute_itc else None X = fft(x) @@ -91,10 +92,7 @@ def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W): for i_f, window in enumerate(W): f = start_f + i_f ST = ifft(XX[:, f : f + n_samp] * window) - if zero_pad > 0: - TFR = ST[:, :-zero_pad:decim] - else: - TFR = ST[:, ::decim] + TFR = ST[:, slice(*decim_indices)] TFR_abs = np.abs(TFR) TFR_abs[TFR_abs == 0] = 1.0 if compute_itc: @@ -105,7 +103,22 @@ def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W): return psd, itc -@fill_doc +def _compute_freqs_st(fmin, fmax, n_fft, sfreq): + from scipy.fft import fftfreq + + freqs = fftfreq(n_fft, 1.0 / sfreq) + if fmin is None: + fmin = freqs[freqs > 0][0] + if fmax is None: + fmax = freqs.max() + + start_f = np.abs(freqs - fmin).argmin() + stop_f = np.abs(freqs - fmax).argmin() + freqs = freqs[start_f:stop_f] + return start_f, stop_f, freqs + + +@verbose def tfr_array_stockwell( data, sfreq, @@ -116,6 +129,8 @@ def tfr_array_stockwell( decim=1, return_itc=False, n_jobs=None, + *, + verbose=None, ): """Compute power and intertrial coherence using Stockwell (S) transform. @@ -143,11 +158,11 @@ def tfr_array_stockwell( The width of the Gaussian window. If < 1, increased temporal resolution, if > 1, increased frequency resolution. Defaults to 1. (classical S-Transform). - decim : int - The decimation factor on the time axis. To reduce memory usage. + %(decim_tfr)s return_itc : bool Return intertrial coherence (ITC) as well as averaged power. 
%(n_jobs)s + %(verbose)s Returns ------- @@ -177,26 +192,17 @@ def tfr_array_stockwell( "data must be 3D with shape (n_epochs, n_channels, n_times), " f"got {data.shape}" ) - n_epochs, n_channels = data.shape[:2] - n_out = data.shape[2] // decim + bool(data.shape[-1] % decim) + decim = _ensure_slice(decim) + _, n_channels, n_out = data[..., decim].shape data, n_fft_, zero_pad = _check_input_st(data, n_fft) - - freqs = fftfreq(n_fft_, 1.0 / sfreq) - if fmin is None: - fmin = freqs[freqs > 0][0] - if fmax is None: - fmax = freqs.max() - - start_f = np.abs(freqs - fmin).argmin() - stop_f = np.abs(freqs - fmax).argmin() - freqs = freqs[start_f:stop_f] + start_f, stop_f, freqs = _compute_freqs_st(fmin, fmax, n_fft_, sfreq) W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width) n_freq = stop_f - start_f psd = np.empty((n_channels, n_freq, n_out)) itc = np.empty((n_channels, n_freq, n_out)) if return_itc else None - parallel, my_st, n_jobs = parallel_func(_st_power_itc, n_jobs) + parallel, my_st, n_jobs = parallel_func(_st_power_itc, n_jobs, verbose=verbose) tfrs = parallel( my_st(data[:, c, :], start_f, return_itc, zero_pad, decim, W) for c in range(n_channels) @@ -209,6 +215,7 @@ def tfr_array_stockwell( return psd, itc, freqs +@legacy(alt='.compute_tfr(method="stockwell", freqs="auto")') @verbose def tfr_stockwell( inst, @@ -281,6 +288,7 @@ def tfr_stockwell( picks = _pick_data_channels(inst.info) info = pick_info(inst.info, picks) data = data[:, picks, :] + decim = _ensure_slice(decim) power, itc, freqs = tfr_array_stockwell( data, sfreq=info["sfreq"], @@ -292,18 +300,25 @@ def tfr_stockwell( return_itc=return_itc, n_jobs=n_jobs, ) - times = inst.times[::decim].copy() + times = inst.times[decim].copy() nave = len(data) - out = AverageTFR(info, power, times, freqs, nave, method="stockwell-power") + out = AverageTFRArray( + info=info, + data=power, + times=times, + freqs=freqs, + nave=nave, + method="stockwell-power", + ) if return_itc: out = ( out, - AverageTFR( - deepcopy(info), - itc, - times.copy(), - freqs.copy(), - nave, + AverageTFRArray( + info=deepcopy(info), + data=itc, + times=times.copy(), + freqs=freqs.copy(), + nave=nave, method="stockwell-itc", ), ) diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index 00e3c1c1e17..4a9e66c4673 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -488,7 +488,7 @@ def tfr_array_multitaper( The epochs. sfreq : float Sampling frequency of the data in Hz. - %(freqs_tfr)s + %(freqs_tfr_array)s %(n_cycles_tfr)s zero_mean : bool If True, make sure the wavelets have a mean of zero. Defaults to True. @@ -506,6 +506,7 @@ def tfr_array_multitaper( * ``'avg_power_itc'`` : average of single trial power and inter-trial coherence across trials. %(n_jobs)s + The parallelization is implemented across channels. %(verbose)s epoch_data : None Deprecated parameter for providing epoched data as of 1.7, will be replaced with diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index e46be389695..7300753c584 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -1,7 +1,4 @@ """Container classes for spectral data.""" - -# Authors: Dan McCloy -# # License: BSD-3-Clause # Copyright the MNE-Python contributors. 
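Note: with `_ensure_slice`, the Stockwell code above now handles `decim` uniformly as a slice, and `Epochs.compute_tfr` accepts either an integer or a slice (exercised by `test_tfr_decim_and_shift_time` later in this diff); a sketch, assuming an existing `epochs` object:

    import numpy as np

    freqs = np.linspace(20.0, 40.0, num=5)

    tfr_int = epochs.compute_tfr("morlet", freqs=freqs, decim=4)  # keep every 4th sample
    tfr_slc = epochs.compute_tfr("morlet", freqs=freqs, decim=slice(1, 200, 3))  # start/stop/step

    # the remaining number of time points follows slice.indices()
    decim = slice(1, 200, 3)
    assert tfr_slc.shape[-1] == len(range(*decim.indices(len(epochs.times))))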
@@ -25,6 +22,7 @@ from ..utils import ( GetEpochsMixin, _build_data_frame, + _check_method_kwargs, _check_pandas_index_arguments, _check_pandas_installed, _check_sphere, @@ -46,12 +44,13 @@ check_fname, ) from ..utils.misc import _identity_function, _pl -from ..utils.spectrum import _split_psd_kwargs +from ..utils.spectrum import _get_instance_type_string, _split_psd_kwargs from ..viz.topo import _plot_timeseries, _plot_timeseries_unified, _plot_topo from ..viz.topomap import _make_head_outlines, _prepare_topomap_plot, plot_psds_topomap from ..viz.utils import ( _format_units_psd, _get_plot_ch_type, + _make_combine_callable, _plot_psd, _prepare_sensor_names, plt_show, @@ -314,7 +313,7 @@ def __init__( ) # method self._inst_type = type(inst) - method = _validate_method(method, self._get_instance_type_string()) + method = _validate_method(method, _get_instance_type_string(self)) # don't allow complex output psd_funcs = dict(welch=psd_array_welch, multitaper=psd_array_multitaper) if method_kw.get("output", "") == "complex": @@ -324,16 +323,8 @@ def __init__( ) # triage method and kwargs. partial() doesn't check validity of kwargs, # so we do it manually to save compute time if any are invalid. - invalid_ix = np.isin( - list(method_kw), list(signature(psd_funcs[method]).parameters), invert=True - ) - if invalid_ix.any(): - invalid_kw = np.array(list(method_kw))[invalid_ix].tolist() - s = _pl(invalid_kw) - raise TypeError( - f'Got unexpected keyword argument{s} {", ".join(invalid_kw)} ' - f'for PSD method "{method}".' - ) + psd_funcs = dict(welch=psd_array_welch, multitaper=psd_array_multitaper) + _check_method_kwargs(psd_funcs[method], method_kw, msg=f'PSD method "{method}"') self._psd_func = partial(psd_funcs[method], remove_dc=remove_dc, **method_kw) # apply proj if desired @@ -352,7 +343,7 @@ def __init__( self.info = pick_info(inst.info, sel=self._picks, copy=True) # assign some attributes - self.preload = True # needed for __getitem__, doesn't mean anything + self.preload = True # needed for __getitem__, never False self._method = method # self._dims may also get updated by child classes self._dims = ( @@ -365,6 +356,8 @@ def __init__( self._data_type = ( "Fourier Coefficients" if "taper" in self._dims else "Power Spectrum" ) + # set nave (child constructor overrides this for Evoked input) + self._nave = None def __eq__(self, other): """Test equivalence of two Spectrum instances.""" @@ -372,7 +365,7 @@ def __eq__(self, other): def __getstate__(self): """Prepare object for serialization.""" - inst_type_str = self._get_instance_type_string() + inst_type_str = _get_instance_type_string(self) out = dict( method=self.method, data=self._data, @@ -382,6 +375,7 @@ def __getstate__(self): inst_type_str=inst_type_str, data_type=self._data_type, info=self.info, + nave=self.nave, ) return out @@ -398,6 +392,7 @@ def __setstate__(self, state): self._sfreq = state["sfreq"] self.info = Info(**state["info"]) self._data_type = state["data_type"] + self._nave = state.get("nave") # objs saved before #11282 won't have `nave` self.preload = True # instance type inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray) @@ -405,7 +400,7 @@ def __setstate__(self, state): def __repr__(self): """Build string representation of the Spectrum object.""" - inst_type_str = self._get_instance_type_string() + inst_type_str = _get_instance_type_string(self) # shape & dimension names dims = " × ".join( [f"{dim[0]} {dim[1]}s" for dim in zip(self.shape, self._dims)] @@ -419,7 +414,7 @@ def __repr__(self): 
@repr_html def _repr_html_(self, caption=None): """Build HTML representation of the Spectrum object.""" - inst_type_str = self._get_instance_type_string() + inst_type_str = _get_instance_type_string(self) units = [f"{ch_type}: {unit}" for ch_type, unit in self.units().items()] t = _get_html_template("repr", "spectrum.html.jinja") t = t.render(spectrum=self, inst_type=inst_type_str, units=units) @@ -466,25 +461,6 @@ def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose): del self._psd_func del self._time_mask - def _get_instance_type_string(self): - """Get string representation of the originating instance type.""" - from ..epochs import BaseEpochs - from ..evoked import Evoked, EvokedArray - from ..io import BaseRaw - - parent_classes = self._inst_type.__bases__ - if BaseRaw in parent_classes: - inst_type_str = "Raw" - elif BaseEpochs in parent_classes: - inst_type_str = "Epochs" - elif self._inst_type in (Evoked, EvokedArray): - inst_type_str = "Evoked" - elif self._inst_type is np.ndarray: - inst_type_str = "Array" - else: - raise RuntimeError(f"Unknown instance type {self._inst_type} in Spectrum") - return inst_type_str - @property def _detrend_picks(self): """Provide compatibility with __iter__.""" @@ -494,6 +470,10 @@ def _detrend_picks(self): def ch_names(self): return self.info["ch_names"] + @property + def data(self): + return self._data + @property def freqs(self): return self._freqs @@ -502,6 +482,10 @@ def freqs(self): def method(self): return self._method + @property + def nave(self): + return self._nave + @property def sfreq(self): return self._sfreq @@ -977,7 +961,7 @@ def to_data_frame( # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa # triage for Epoch-derived or unaggregated spectra - from_epo = self._dims[0] == "epoch" + from_epo = _get_instance_type_string(self) == "Epochs" unagg_welch = "segment" in self._dims unagg_mt = "taper" in self._dims # arg checking @@ -1083,8 +1067,10 @@ class Spectrum(BaseSpectrum): have been computed. %(info_not_none)s method : str - The method used to compute the spectrum (``'welch'`` or - ``'multitaper'``). + The method used to compute the spectrum (``'welch'`` or ``'multitaper'``). + nave : int | None + The number of trials averaged together when generating the spectrum. ``None`` + indicates no averaging is known to have occurred. See Also -------- @@ -1148,11 +1134,13 @@ def __init__( ) else: # Evoked data = self.inst.data[self._picks][:, self._time_mask] + # set nave + self._nave = getattr(inst, "nave", None) # compute the spectra self._compute_spectra(data, fmin, fmax, n_jobs, method_kw, verbose) # check for correct shape and bad values self._check_values() - del self._shape + del self._shape # calculated from self._data henceforth # save memory del self.inst @@ -1185,7 +1173,8 @@ def __getitem__(self, item): requested data values and the corresponding times), accessing :class:`~mne.time_frequency.Spectrum` values via subscript does **not** return the corresponding frequency bin values. If you need - them, use ``spectrum.freqs[freq_indices]``. + them, use ``spectrum.freqs[freq_indices]`` or + ``spectrum.get_data(..., return_freqs=True)``. """ from ..io import BaseRaw @@ -1220,7 +1209,7 @@ class SpectrumArray(Spectrum): data : array, shape (n_channels, n_freqs) The power spectral density for each channel. 
%(info_not_none)s - %(freqs_tfr)s + %(freqs_tfr_array)s %(verbose)s See Also @@ -1418,9 +1407,10 @@ def average(self, method="mean"): spectrum : instance of Spectrum The aggregated spectrum object. """ - if isinstance(method, str): - method = getattr(np, method) # mean, median, std, etc - method = partial(method, axis=0) + _validate_type(method, ("str", "callable")) + method = _make_combine_callable( + method, axis=0, valid=("mean", "median"), keepdims=False + ) if not callable(method): raise ValueError( '"method" must be a valid string or callable, ' @@ -1435,6 +1425,7 @@ def average(self, method="mean"): ) # serialize the object and update data, dims, and data type state = super().__getstate__() + state["nave"] = state["data"].shape[0] state["data"] = method(state["data"]) state["dims"] = state["dims"][1:] state["data_type"] = f'Averaged {state["data_type"]}' @@ -1464,7 +1455,7 @@ class EpochsSpectrumArray(EpochsSpectrum): data : array, shape (n_epochs, n_channels, n_freqs) The power spectral density for each channel in each epoch. %(info_not_none)s - %(freqs_tfr)s + %(freqs_tfr_array)s %(events_epochs)s %(event_id)s %(verbose)s diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index 18fbf4da483..752e1d000a1 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -125,11 +125,15 @@ def test_n_welch_windows(raw): ) -def _get_inst(inst, request, evoked): +def _get_inst(inst, request, *, evoked=None, average_tfr=None): # ↓ XXX workaround: # ↓ parametrized fixtures are not accessible via request.getfixturevalue # ↓ https://github.com/pytest-dev/pytest/issues/4666#issuecomment-456593913 - return evoked if inst == "evoked" else request.getfixturevalue(inst) + if inst == "evoked": + return evoked + elif inst == "average_tfr": + return average_tfr + return request.getfixturevalue(inst) @pytest.mark.parametrize("inst", ("raw", "epochs", "evoked")) @@ -137,7 +141,7 @@ def test_spectrum_io(inst, tmp_path, request, evoked): """Test save/load of spectrum objects.""" pytest.importorskip("h5io") fname = tmp_path / f"{inst}-spectrum.h5" - inst = _get_inst(inst, request, evoked) + inst = _get_inst(inst, request, evoked=evoked) orig = inst.compute_psd() orig.save(fname) loaded = read_spectrum(fname) @@ -214,7 +218,7 @@ def test_spectrum_to_data_frame(inst, request, evoked): # setup is_already_psd = inst in ("raw_spectrum", "epochs_spectrum") is_epochs = inst == "epochs_spectrum" - inst = _get_inst(inst, request, evoked) + inst = _get_inst(inst, request, evoked=evoked) extra_dim = () if is_epochs else (1,) extra_cols = ["freq", "condition", "epoch"] if is_epochs else ["freq"] # compute PSD diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index 4fc6f147377..5087a8c46a9 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -8,26 +8,37 @@ import matplotlib.pyplot as plt import numpy as np import pytest -from numpy.testing import assert_allclose, assert_array_equal, assert_equal +from matplotlib.collections import PathCollection +from numpy.testing import ( + assert_allclose, + assert_array_almost_equal, + assert_array_equal, + assert_equal, +) from scipy.signal import morlet2 import mne from mne import ( Epochs, EpochsArray, - Info, - Transform, create_info, pick_types, read_events, ) from mne.epochs import equalize_epoch_counts from mne.io import read_raw_fif -from mne.tests.test_epochs import assert_metadata_equal -from 
mne.time_frequency import tfr_array_morlet, tfr_array_multitaper -from mne.time_frequency.tfr import ( +from mne.time_frequency import ( AverageTFR, + AverageTFRArray, + EpochsSpectrum, EpochsTFR, + EpochsTFRArray, + RawTFR, + RawTFRArray, + tfr_array_morlet, + tfr_array_multitaper, +) +from mne.time_frequency.tfr import ( _compute_tfr, _make_dpss, combine_tfr, @@ -40,34 +51,41 @@ write_tfrs, ) from mne.utils import catch_logging, grand_average -from mne.viz.utils import _fake_click, _fake_keypress, _fake_scroll +from mne.utils._testing import _get_suptitle +from mne.viz.utils import ( + _channel_type_prettyprint, + _fake_click, + _fake_keypress, + _fake_scroll, +) + +from .test_spectrum import _get_inst data_path = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_path / "test_raw.fif" event_fname = data_path / "test-eve.fif" raw_ctf_fname = data_path / "test_ctf_raw.fif" +freqs_linspace = np.linspace(20, 40, num=5) +freqs_unsorted_list = [26, 33, 41, 20] +mag_names = [f"MEG 01{n}1" for n in (1, 2, 3)] -def _create_test_epochstfr(): - n_epos = 3 - ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] - n_picks = len(ch_names) - ch_types = ["eeg"] * n_picks - n_freqs = 5 - n_times = 6 - data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) - srate = 1000.0 - freqs = np.arange(5) - events = np.zeros((n_epos, 3), dtype=int) - events[:, 0] = np.arange(n_epos) - events[:, 2] = np.arange(5, 5 + n_epos) - event_id = {k: v for v, k in zip(events[:, 2], ["ha", "he", "hu"])} - info = mne.create_info(ch_names, srate, ch_types) - tfr = mne.time_frequency.EpochsTFR( - info, data, times, freqs, events=events, event_id=event_id - ) - return tfr +parametrize_morlet_multitaper = pytest.mark.parametrize( + "method", ("morlet", "multitaper") +) +parametrize_power_phase_complex = pytest.mark.parametrize( + "output", ("power", "phase", "complex") +) +parametrize_inst_and_ch_type = pytest.mark.parametrize( + "inst,ch_type", + ( + pytest.param("raw_tfr", "mag"), + pytest.param("raw_tfr", "grad"), + pytest.param("epochs_tfr", "mag"), # no grad pairs in epochs fixture + pytest.param("average_tfr", "mag"), + pytest.param("average_tfr", "grad"), + ), +) def test_tfr_ctf(): @@ -111,7 +129,7 @@ def test_morlet(sfreq, freq, n_cycles): assert_allclose(fwhm_formula, fwhm_empirical, atol=3 / sfreq) -def test_time_frequency(): +def test_tfr_morlet(): """Test time-frequency transform (PSD and ITC).""" # Set parameters event_id = 1 @@ -148,7 +166,8 @@ def test_time_frequency(): # Now compute evoked evoked = epochs.average() - pytest.raises(ValueError, tfr_morlet, evoked, freqs, 1.0, return_itc=True) + with pytest.raises(ValueError, match="Inter-trial coherence is not supported with"): + tfr_morlet(evoked, freqs, n_cycles=1.0, return_itc=True) power, itc = tfr_morlet( epochs, freqs=freqs, n_cycles=n_cycles, use_fft=True, return_itc=True ) @@ -542,216 +561,193 @@ def test_tfr_multitaper(): tfr_multitaper(epochs, freqs=np.arange(-4, -1), n_cycles=7) -def test_crop(): - """Test TFR cropping.""" - data = np.zeros((3, 4, 5)) - times = np.array([0.1, 0.2, 0.3, 0.4, 0.5]) - freqs = np.array([0.10, 0.20, 0.30, 0.40]) - info = mne.create_info( - ["MEG 001", "MEG 002", "MEG 003"], 1000.0, ["mag", "mag", "mag"] - ) - tfr = AverageTFR( - info, - data=data, - times=times, - freqs=freqs, - nave=20, - comment="test", - method="crazy-tfr", - ) - - tfr.crop(tmin=0.2) - assert_array_equal(tfr.times, [0.2, 0.3, 0.4, 0.5]) - assert tfr.data.ndim == 3 - assert tfr.data.shape[-1] == 4 - - 
tfr.crop(fmax=0.3) - assert_array_equal(tfr.freqs, [0.1, 0.2, 0.3]) - assert tfr.data.ndim == 3 - assert tfr.data.shape[-2] == 3 - - tfr.crop(tmin=0.3, tmax=0.4, fmin=0.1, fmax=0.2) - assert_array_equal(tfr.times, [0.3, 0.4]) - assert tfr.data.ndim == 3 - assert tfr.data.shape[-1] == 2 - assert_array_equal(tfr.freqs, [0.1, 0.2]) - assert tfr.data.shape[-2] == 2 - - -def test_decim_shift_time(): - """Test TFR decimation and shift_time.""" - data = np.zeros((3, 3, 3, 1000)) - times = np.linspace(0, 1, 1000) - freqs = np.array([0.10, 0.20, 0.30]) - info = mne.create_info( - ["MEG 001", "MEG 002", "MEG 003"], 1000.0, ["mag", "mag", "mag"] - ) - with info._unlock(): - info["lowpass"] = 100 - tfr = EpochsTFR(info, data=data, times=times, freqs=freqs) - tfr_ave = tfr.average() - assert_allclose(tfr.times, tfr_ave.times) - assert not hasattr(tfr_ave, "first") - tfr_ave.decimate(3) - assert not hasattr(tfr_ave, "first") - tfr.decimate(3) - assert tfr.times.size == 1000 // 3 + 1 - assert tfr.data.shape == ((3, 3, 3, 1000 // 3 + 1)) - tfr_ave_2 = tfr.average() - assert not hasattr(tfr_ave_2, "first") - assert_allclose(tfr.times, tfr_ave.times) - assert_allclose(tfr.times, tfr_ave_2.times) - assert_allclose(tfr_ave_2.data, tfr_ave.data) - tfr.shift_time(-0.1, relative=True) - tfr_ave.shift_time(-0.1, relative=True) - tfr_ave_3 = tfr.average() - assert_allclose(tfr_ave_3.times, tfr_ave.times) - assert_allclose(tfr_ave_3.data, tfr_ave.data) - assert_allclose(tfr_ave_2.data, tfr_ave_3.data) # data unchanged - - -def test_io(tmp_path): - """Test TFR IO capacities.""" - pd = pytest.importorskip("pandas") +@pytest.mark.parametrize( + "method,freqs", + ( + pytest.param("morlet", freqs_linspace, id="morlet"), + pytest.param("multitaper", freqs_linspace, id="multitaper"), + pytest.param("stockwell", freqs_linspace[[0, -1]], id="stockwell"), + ), +) +@pytest.mark.parametrize("decim", (4, slice(0, 200), slice(1, 200, 3))) +def test_tfr_decim_and_shift_time(epochs, method, freqs, decim): + """Test TFR decimation; slices must be long-ish to be longer than the wavelets.""" + tfr = epochs.compute_tfr(method, freqs=freqs, decim=decim) + if not isinstance(decim, slice): + decim = slice(None, None, decim) + # check n_times + want = len(range(*decim.indices(len(epochs.times)))) + assert tfr.shape[-1] == want + # Check that decim changes sfreq + assert tfr.sfreq == epochs.info["sfreq"] / (decim.step or 1) + # check after-the-fact decimation. 
The mixin .decimate method doesn't allow slices + if isinstance(decim, int): + tfr2 = epochs.compute_tfr(method, freqs=freqs, decim=1) + tfr2.decimate(decim) + assert tfr == tfr2 + # test .shift_time() too + shift = -0.137 + data, times, freqs = tfr.get_data(return_times=True, return_freqs=True) + tfr.shift_time(shift, relative=True) + assert_allclose(times + shift, tfr.times, rtol=0, atol=0.5 / tfr.sfreq) + # shift time should only affect times: + assert_array_equal(data, tfr.get_data()) + assert_array_equal(freqs, tfr.freqs) + + +@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr")) +def test_tfr_io(inst, average_tfr, request, tmp_path): + """Test TFR I/O.""" pytest.importorskip("h5io") + pd = pytest.importorskip("pandas") - fname = tmp_path / "test-tfr.h5" - data = np.zeros((3, 2, 3)) - times = np.array([0.1, 0.2, 0.3]) - freqs = np.array([0.10, 0.20]) - - info = mne.create_info( - ["MEG 001", "MEG 002", "MEG 003"], 1000.0, ["mag", "mag", "mag"] - ) - with info._unlock(check_after=True): - info["meas_date"] = datetime.datetime( - year=2020, month=2, day=5, tzinfo=datetime.timezone.utc - ) - tfr = AverageTFR( - info, - data=data, - times=times, - freqs=freqs, - nave=20, - comment="test", - method="crazy-tfr", - ) - tfr.save(fname) - tfr2 = read_tfrs(fname, condition="test") - assert isinstance(tfr2.info, Info) - assert isinstance(tfr2.info["dev_head_t"], Transform) - - assert_array_equal(tfr.data, tfr2.data) - assert_array_equal(tfr.times, tfr2.times) - assert_array_equal(tfr.freqs, tfr2.freqs) - assert_equal(tfr.comment, tfr2.comment) - assert_equal(tfr.nave, tfr2.nave) - - pytest.raises(OSError, tfr.save, fname) - - tfr.comment = None - # test old meas_date - with info._unlock(): - info["meas_date"] = (1, 2) + tfr = _get_inst(inst, request, average_tfr=average_tfr) + fname = tmp_path / "temp_tfr.hdf5" + # test .save() method + tfr.save(fname, overwrite=True) + assert read_tfrs(fname) == tfr + # test save single TFR with write_tfrs() + write_tfrs(fname, tfr, overwrite=True) + assert read_tfrs(fname) == tfr + # test save multiple TFRs with write_tfrs() + tfr2 = tfr.copy() + tfr2._data = np.zeros_like(tfr._data) + write_tfrs(fname, [tfr, tfr2], overwrite=True) + tfr_list = read_tfrs(fname) + assert tfr_list[0] == tfr + assert tfr_list[1] == tfr2 + # test condition-related errors + if isinstance(tfr, AverageTFR): + # auto-generated keys: first TFR has comment, so `0` not assigned + tfr2.comment = None + write_tfrs(fname, [tfr, tfr2], overwrite=True) + with pytest.raises(ValueError, match='Cannot find condition "0" in this'): + read_tfrs(fname, condition=0) + # second TFR had no comment, so should get auto-comment `1` assigned + read_tfrs(fname, condition=1) + return + else: + with pytest.raises(NotImplementedError, match="condition is only supported"): + read_tfrs(fname, condition="foo") + # the rest we only do for EpochsTFR (no need to parametrize) + if isinstance(tfr, RawTFR): + return + # make sure everything still works if there's metadata + tfr.metadata = pd.DataFrame(dict(foo=range(tfr.shape[0])), index=tfr.selection) + # test old-style meas date + sec_microsec_tuple = (1, 2) + with tfr.info._unlock(): + tfr.info["meas_date"] = sec_microsec_tuple tfr.save(fname, overwrite=True) - assert_equal(read_tfrs(fname, condition=0).comment, tfr.comment) - tfr.comment = "test-A" - tfr2.comment = "test-B" - - fname = tmp_path / "test2-tfr.h5" - write_tfrs(fname, [tfr, tfr2]) - tfr3 = read_tfrs(fname, condition="test-A") - assert_equal(tfr.comment, tfr3.comment) - - assert 
isinstance(tfr.info, mne.Info) - - tfrs = read_tfrs(fname, condition=None) - assert_equal(len(tfrs), 2) - tfr4 = tfrs[1] - assert_equal(tfr2.comment, tfr4.comment) - - pytest.raises(ValueError, read_tfrs, fname, condition="nonono") - # Test save of EpochsTFR. - n_events = 5 - data = np.zeros((n_events, 3, 2, 3)) - - # create fake metadata - rng = np.random.RandomState(42) - rt = np.round(rng.uniform(size=(n_events,)), 3) - trialtypes = np.array(["face", "place"]) - trial = trialtypes[(rng.uniform(size=(n_events,)) > 0.5).astype(int)] - meta = pd.DataFrame(dict(RT=rt, Trial=trial)) - # fake events and event_id - events = np.zeros([n_events, 3]) - events[:, 0] = np.arange(n_events) - events[:, 2] = np.ones(n_events) - event_id = {"a/b": 1} - # fake selection - n_dropped_epochs = 3 - selection = np.arange(n_events + n_dropped_epochs)[n_dropped_epochs:] - drop_log = tuple( - [("IGNORED",) for i in range(n_dropped_epochs)] + [() for i in range(n_events)] + tfr_loaded = read_tfrs(fname) + want = datetime.datetime( + year=1970, + month=1, + day=1, + hour=0, + minute=0, + second=sec_microsec_tuple[0], + microsecond=sec_microsec_tuple[1], + tzinfo=datetime.timezone.utc, ) - - tfr = EpochsTFR( - info, - data=data, - times=times, - freqs=freqs, - comment="test", - method="crazy-tfr", - events=events, - event_id=event_id, - selection=selection, - drop_log=drop_log, - metadata=meta, + assert tfr_loaded.info["meas_date"] == want + with tfr.info._unlock(): + tfr.info["meas_date"] = want + assert tfr_loaded == tfr + # test overwrite + with pytest.raises(OSError, match="Destination file exists."): + tfr.save(fname, overwrite=False) + + +def test_raw_tfr_init(raw): + """Test the RawTFR and RawTFRArray constructors.""" + one = RawTFR(inst=raw, method="morlet", freqs=freqs_linspace) + two = RawTFRArray(one.info, one.data, one.times, one.freqs, method="morlet") + # some attributes we know won't match: + for attr in ("_data_type", "_inst_type"): + assert getattr(one, attr) != getattr(two, attr) + delattr(one, attr) + delattr(two, attr) + assert one == two + # test RawTFR.__getitem__ + data = one[:5] + assert data.shape == (5,) + one.shape[1:] + # test missing method/freqs + with pytest.raises(ValueError, match="RawTFR got unsupported parameter value"): + RawTFR(inst=raw) + + +def test_average_tfr_init(full_evoked): + """Test the AverageTFR and AverageTFRArray constructors.""" + one = AverageTFR(inst=full_evoked, method="morlet", freqs=freqs_linspace) + two = AverageTFRArray( + one.info, + one.data, + one.times, + one.freqs, + method="morlet", + comment=one.comment, + nave=one.nave, ) - fname_save = fname - tfr.save(fname_save, True) - fname_write = tmp_path / "test3-tfr.h5" - write_tfrs(fname_write, tfr, overwrite=True) - for fname in [fname_save, fname_write]: - read_tfr = read_tfrs(fname)[0] - assert_array_equal(tfr.data, read_tfr.data) - assert_metadata_equal(tfr.metadata, read_tfr.metadata) - assert_array_equal(tfr.events, read_tfr.events) - assert tfr.event_id == read_tfr.event_id - assert_array_equal(tfr.selection, read_tfr.selection) - assert tfr.drop_log == read_tfr.drop_log - with pytest.raises(NotImplementedError, match="condition not supported"): - tfr = read_tfrs(fname, condition="a") - - -def test_init_EpochsTFR(): + # some attributes we know won't match, otherwise should be identical + assert one._data_type != two._data_type + one._data_type = two._data_type + assert one == two + # test missing method, bad freqs + with pytest.raises(ValueError, match="AverageTFR got unsupported parameter value"): 
+ AverageTFR(inst=full_evoked) + with pytest.raises(ValueError, match='must be a length-2 iterable or "auto"'): + AverageTFR(inst=full_evoked, method="stockwell", freqs=freqs_linspace) + + +def test_epochstfr_init_errors(epochs_tfr): """Test __init__ for EpochsTFR.""" - # Create fake data: - data = np.zeros((3, 3, 3, 3)) - times = np.array([0.1, 0.2, 0.3]) - freqs = np.array([0.10, 0.20, 0.30]) - info = mne.create_info( - ["MEG 001", "MEG 002", "MEG 003"], 1000.0, ["mag", "mag", "mag"] - ) - data_x = data[:, :, :, 0] - with pytest.raises(ValueError, match="data should be 4d. Got 3"): - tfr = EpochsTFR(info, data=data_x, times=times, freqs=freqs) - data_x = data[:, :-1, :, :] - with pytest.raises(ValueError, match="channels and data size don't"): - tfr = EpochsTFR(info, data=data_x, times=times, freqs=freqs) - times_x = times[:-1] - with pytest.raises(ValueError, match="times and data size don't match"): - tfr = EpochsTFR(info, data=data, times=times_x, freqs=freqs) - freqs_x = freqs[:-1] - with pytest.raises(ValueError, match="frequencies and data size don't"): - tfr = EpochsTFR(info, data=data, times=times_x, freqs=freqs_x) - del tfr - - -def test_equalize_epochs_tfr_counts(): + state = epochs_tfr.__getstate__() + with pytest.raises(ValueError, match="EpochsTFR data should be 4D, got 3"): + EpochsTFR(inst=state | dict(data=epochs_tfr.data[..., 0])) + with pytest.raises(ValueError, match="Channel axis of data .* doesn't match info"): + EpochsTFR(inst=state | dict(data=epochs_tfr.data[:, :-1])) + with pytest.raises(ValueError, match="Time axis of data.*doesn't match times attr"): + EpochsTFR(inst=state | dict(times=epochs_tfr.times[:-1])) + with pytest.raises(ValueError, match="Frequency axis of.*doesn't match freqs attr"): + EpochsTFR(inst=state | dict(freqs=epochs_tfr.freqs[:-1])) + + +@pytest.mark.parametrize("inst", ("epochs_tfr", "average_tfr")) +def test_tfr_init_deprecation(inst, average_tfr, request): + """Check for the deprecation warning message (not needed for RawTFR, it's new).""" + tfr = _get_inst(inst, request, average_tfr=average_tfr) + kwargs = dict(info=tfr.info, data=tfr.data, times=tfr.times, freqs=tfr.freqs) + Klass = tfr.__class__ + with pytest.warns(FutureWarning, match='"info", "data", "times" are deprecat'): + Klass(**kwargs) + with pytest.raises(ValueError, match="Do not pass `inst` alongside deprecated"): + with pytest.warns(FutureWarning, match='"info", "data", "times" are deprecat'): + Klass(**kwargs, inst="foo") + + +@pytest.mark.parametrize( + "method,freqs,match", + ( + ("morlet", None, "EpochsTFR got unsupported parameter value freqs=None."), + (None, freqs_linspace, "got unsupported parameter value method=None."), + (None, None, "got unsupported parameter values method=None and freqs=None."), + ), +) +def test_compute_tfr_init_errors(epochs, method, freqs, match): + """Test that method and freqs are always passed (if not using __setstate__).""" + with pytest.raises(ValueError, match=match): + epochs.compute_tfr(method=method, freqs=freqs) + + +def test_equalize_epochs_tfr_counts(epochs_tfr): """Test equalize_epoch_counts for EpochsTFR.""" - tfr = _create_test_epochstfr() - tfr2 = tfr.copy() + # make the fixture have 3 epochs instead of 1 + epochs_tfr._data = np.vstack((epochs_tfr._data, epochs_tfr._data, epochs_tfr._data)) + tfr2 = epochs_tfr.copy() tfr2 = tfr2[:-1] - equalize_epoch_counts([tfr, tfr2]) + equalize_epoch_counts([epochs_tfr, tfr2]) + assert epochs_tfr.shape == tfr2.shape def test_dB_computation(): @@ -765,9 +761,9 @@ def test_dB_computation(): 
["MEG 001", "MEG 002", "MEG 003"], 1000.0, ["mag", "mag", "mag"] ) kwargs = dict(times=times, freqs=freqs, nave=20, comment="test", method="crazy-tfr") - tfr = AverageTFR(info, data=data, **kwargs) - complex_tfr = AverageTFR(info, data=complex_data, **kwargs) - plot_kwargs = dict(dB=True, combine="mean", vmin=0, vmax=7) + tfr = AverageTFRArray(info=info, data=data, **kwargs) + complex_tfr = AverageTFRArray(info=info, data=complex_data, **kwargs) + plot_kwargs = dict(dB=True, combine="mean", vlim=(0, 7)) fig1 = tfr.plot(**plot_kwargs)[0] fig2 = complex_tfr.plot(**plot_kwargs)[0] # since we're fixing vmin/vmax, equal colors should mean ~equal input data @@ -785,8 +781,8 @@ def test_plot(): info = mne.create_info( ["MEG 001", "MEG 002", "MEG 003"], 1000.0, ["mag", "mag", "mag"] ) - tfr = AverageTFR( - info, + tfr = AverageTFRArray( + info=info, data=data, times=times, freqs=freqs, @@ -795,88 +791,6 @@ def test_plot(): method="crazy-tfr", ) - # test title=auto, combine=None, and correct length of figure list - picks = [1, 2] - figs = tfr.plot( - picks, title="auto", colorbar=False, mask=np.ones(tfr.data.shape[1:], bool) - ) - assert len(figs) == len(picks) - assert "MEG" in figs[0].texts[0].get_text() - plt.close("all") - - # test combine and title keyword - figs = tfr.plot( - picks, - title="title", - colorbar=False, - combine="rms", - mask=np.ones(tfr.data.shape[1:], bool), - ) - assert len(figs) == 1 - assert figs[0].texts[0].get_text() == "title" - figs = tfr.plot( - picks, - title="auto", - colorbar=False, - combine="mean", - mask=np.ones(tfr.data.shape[1:], bool), - ) - assert len(figs) == 1 - assert figs[0].texts[0].get_text() == "Mean of 2 sensors" - figs = tfr.plot( - picks, - title="auto", - colorbar=False, - combine=lambda x: x.mean(axis=0), - mask=np.ones(tfr.data.shape[1:], bool), - ) - assert len(figs) == 1 - - with pytest.raises(ValueError, match="Invalid value for the 'combine'"): - tfr.plot( - picks, - colorbar=False, - combine="something", - mask=np.ones(tfr.data.shape[1:], bool), - ) - with pytest.raises(RuntimeError, match="must operate on a single"): - tfr.plot(picks, combine=lambda x, y: x.mean(axis=0)) - with pytest.raises(RuntimeError, match=re.escape("of shape (n_freqs, n_times).")): - tfr.plot(picks, combine=lambda x: x.mean(axis=0, keepdims=True)) - with pytest.raises( - RuntimeError, - match=re.escape("return a numpy array of shape (n_freqs, n_times)."), - ): - tfr.plot(picks, combine=lambda x: 101) - - plt.close("all") - - # test axes argument - first with list of axes - ax = plt.subplot2grid((2, 2), (0, 0)) - ax2 = plt.subplot2grid((2, 2), (0, 1)) - ax3 = plt.subplot2grid((2, 2), (1, 0)) - figs = tfr.plot(picks=[0, 1, 2], axes=[ax, ax2, ax3]) - assert len(figs) == len([ax, ax2, ax3]) - # and as a single axes - figs = tfr.plot(picks=[0], axes=ax) - assert len(figs) == 1 - plt.close("all") - # and invalid inputs - with pytest.raises(ValueError, match="axes must be None"): - tfr.plot(picks, colorbar=False, axes={}, mask=np.ones(tfr.data.shape[1:], bool)) - - # different number of axes and picks should throw a RuntimeError - with pytest.raises(RuntimeError, match="There must be an axes"): - tfr.plot( - picks=[0], - colorbar=False, - axes=[ax, ax2], - mask=np.ones(tfr.data.shape[1:], bool), - ) - - tfr.plot_topo(picks=[1, 2]) - plt.close("all") - # interactive mode on by default fig = tfr.plot(picks=[1], cmap="RdBu_r")[0] _fake_keypress(fig, "up") @@ -907,65 +821,76 @@ def test_plot(): plt.close("all") -def test_plot_joint(): - """Test TFR joint plotting.""" - raw = 
read_raw_fif(raw_fname) - times = np.linspace(-0.1, 0.1, 200) - n_freqs = 3 - nave = 1 - rng = np.random.RandomState(42) - data = rng.randn(len(raw.ch_names), n_freqs, len(times)) - tfr = AverageTFR(raw.info, data, times, np.arange(n_freqs), nave) - - topomap_args = {"res": 8, "contours": 0, "sensors": False} - - for combine in ("mean", "rms", lambda x: x.mean(axis=0)): - with catch_logging() as log: - tfr.plot_joint( - title="auto", - colorbar=True, - combine=combine, - topomap_args=topomap_args, - verbose="debug", - ) - plt.close("all") - log = log.getvalue() - assert "Plotting topomap for grad data" in log - - # check various timefreqs - for timefreqs in ( - { - (tfr.times[0], tfr.freqs[1]): (0.1, 0.5), - (tfr.times[-1], tfr.freqs[-1]): (0.2, 0.6), - }, - [(tfr.times[1], tfr.freqs[1])], - ): - tfr.plot_joint(timefreqs=timefreqs, topomap_args=topomap_args) - plt.close("all") - - # test bad timefreqs - timefreqs = ( - [(-100, 1)], - tfr.times[1], - [1], - [(tfr.times[1], tfr.freqs[1], tfr.freqs[1])], +@pytest.mark.parametrize( + "timefreqs,title,combine", + ( + pytest.param( + {(0.33, 23): (0, 0), (0.25, 30): (0.1, 2)}, + "0.25 ± 0.05 s,\n30.0 ± 1.0 Hz", + "mean", + id="dict,mean", + ), + pytest.param([(0.25, 30)], "0.25 s,\n30.0 Hz", "rms", id="list,rms"), + pytest.param(None, None, lambda x: x.mean(axis=0), id="none,lambda"), + ), +) +@parametrize_inst_and_ch_type +def test_tfr_plot_joint( + inst, ch_type, combine, timefreqs, title, full_average_tfr, request +): + """Test {Raw,Epochs,Average}TFR.plot_joint().""" + tfr = _get_inst(inst, request, average_tfr=full_average_tfr) + with catch_logging() as log: + fig = tfr.plot_joint( + picks=ch_type, + timefreqs=timefreqs, + combine=combine, + topomap_args=dict(res=8, contours=0, sensors=False), # for speed + verbose="debug", + ) + assert f"Plotting topomap for {ch_type} data" in log.getvalue() + # check for correct number of axes + n_topomaps = 1 if timefreqs is None else len(timefreqs) + assert len(fig.axes) == n_topomaps + 2 # n_topomaps + 1 image + 1 colorbar + # title varies by `ch_type` when `timefreqs=None`, so we don't test that here + if title is not None: + assert fig.axes[0].get_title() == title + # test interactivity + ax = [ax for ax in fig.axes if ax.get_xlabel() == "Time (s)"][0] + kw = dict(fig=fig, ax=ax, xform="ax") + _fake_click(**kw, kind="press", point=(0.4, 0.4)) + _fake_click(**kw, kind="motion", point=(0.5, 0.5)) + _fake_click(**kw, kind="release", point=(0.6, 0.6)) + # make sure we actually got a pop-up figure, and it has a plausible title + fignums = plt.get_fignums() + assert len(fignums) == 2 + popup_fig = plt.figure(fignums[-1]) + assert re.match( + r"-?\d{1,2}\.\d{3} - -?\d{1,2}\.\d{3} s,\n\d{1,2}\.\d{2} - \d{1,2}\.\d{2} Hz", + _get_suptitle(popup_fig), ) - for these_timefreqs in timefreqs: - pytest.raises(ValueError, tfr.plot_joint, these_timefreqs) - # test that the object is not internally modified - tfr_orig = tfr.copy() - tfr.plot_joint( - baseline=(0, None), exclude=[tfr.ch_names[0]], topomap_args=topomap_args - ) - plt.close("all") - assert_array_equal(tfr.data, tfr_orig.data) - assert set(tfr.ch_names) == set(tfr_orig.ch_names) - assert set(tfr.times) == set(tfr_orig.times) - # test tfr with picked channels - tfr.pick(tfr.ch_names[:-1]) - tfr.plot_joint(title="auto", colorbar=True, topomap_args=topomap_args) +@pytest.mark.parametrize( + "match,timefreqs,topomap_args", + ( + (r"Requested time point \(-88.000 s\) exceeds the range of", [(-88, 1)], None), + (r"Requested frequency \(99.0 Hz\) exceeds the 
range of", [(0.0, 99)], None), + ("list of tuple pairs, or a dict of such tuple pairs, not 0", [0.0], None), + ("does not match the channel type present in", None, dict(ch_type="eeg")), + ), +) +def test_tfr_plot_joint_errors(full_average_tfr, match, timefreqs, topomap_args): + """Test AverageTFR.plot_joint() error messages.""" + with pytest.raises(ValueError, match=match): + full_average_tfr.plot_joint(timefreqs=timefreqs, topomap_args=topomap_args) + + +def test_tfr_plot_joint_doesnt_modify(full_average_tfr): + """Test that the object is unchanged after plot_joint().""" + tfr = full_average_tfr.copy() + full_average_tfr.plot_joint() + assert tfr == full_average_tfr def test_add_channels(): @@ -978,8 +903,8 @@ def test_add_channels(): 1000.0, ["mag", "mag", "mag", "eeg", "eeg", "stim"], ) - tfr = AverageTFR( - info, + tfr = AverageTFRArray( + info=info, data=data, times=times, freqs=freqs, @@ -1199,13 +1124,12 @@ def test_averaging_epochsTFR(): avgpower = power.average(method=method) assert_array_equal(func(power.data, axis=0), avgpower.data) with pytest.raises( - RuntimeError, match="You passed a function that " "resulted in data" + RuntimeError, match=r"EpochsTFR.average\(\) got .* shape \(\), but it should be" ): power.average(method=np.mean) -@pytest.mark.parametrize("copy", [True, False]) -def test_averaging_freqsandtimes_epochsTFR(copy): +def test_averaging_freqsandtimes_epochsTFR(): """Test that EpochsTFR averaging freqs methods work.""" # Setup for reading the raw data event_id = 1 @@ -1240,138 +1164,60 @@ def test_averaging_freqsandtimes_epochsTFR(copy): return_itc=False, ) - # Test average methods for freqs and times - for idx, (func, method) in enumerate( - zip( - [np.mean, np.median, np.mean, np.mean], - [ - "mean", - "median", - lambda x: np.mean(x, axis=2), - lambda x: np.mean(x, axis=3), - ], - ) + # Test averaging over freqs + kwargs = dict(dim="freqs", copy=True) + for method, func in zip( + ("mean", "median", lambda x: np.mean(x, axis=2)), (np.mean, np.median, np.mean) ): - if idx == 3: - with pytest.raises(RuntimeError, match="You passed a function"): - avgpower = power.copy().average(method=method, dim="freqs", copy=copy) - continue - avgpower = power.copy().average(method=method, dim="freqs", copy=copy) - assert_array_equal(func(power.data, axis=2, keepdims=True), avgpower.data) - assert avgpower.freqs == np.mean(power.freqs) + avgpower = power.average(method=method, **kwargs) + assert_array_equal(avgpower.data, func(power.data, axis=2, keepdims=True)) + assert_array_equal(avgpower.freqs, func(power.freqs, keepdims=True)) assert isinstance(avgpower, EpochsTFR) - - # average over epochs - avgpower = avgpower.average() + avgpower = avgpower.average() # average over epochs assert isinstance(avgpower, AverageTFR) - - # Test average methods for freqs and times - for idx, (func, method) in enumerate( - zip( - [np.mean, np.median, np.mean, np.mean], - [ - "mean", - "median", - lambda x: np.mean(x, axis=3), - lambda x: np.mean(x, axis=2), - ], - ) + with pytest.raises(RuntimeError, match=r"shape \(1, 2, 3\), but it should"): + # collapsing wrong axis (time instead of freq) + avgpower = power.average(method=lambda x: np.mean(x, axis=3), **kwargs) + + # Test averaging over times + kwargs = dict(dim="times", copy=False) + for method, func in zip( + ("mean", "median", lambda x: np.mean(x, axis=3)), (np.mean, np.median, np.mean) ): - if idx == 3: - with pytest.raises(RuntimeError, match="You passed a function"): - avgpower = power.copy().average(method=method, dim="times", 
copy=copy) - continue - avgpower = power.copy().average(method=method, dim="times", copy=copy) - assert_array_equal(func(power.data, axis=-1, keepdims=True), avgpower.data) - assert avgpower.times == np.mean(power.times) - assert isinstance(avgpower, EpochsTFR) + avgpower = power.average(method=method, **kwargs) + assert_array_equal(avgpower.data, func(power.data, axis=-1, keepdims=False)) + assert isinstance(avgpower, EpochsSpectrum) + with pytest.raises(RuntimeError, match=r"shape \(1, 2, 420\), but it should"): + # collapsing wrong axis (freq instead of time) + avgpower = power.average(method=lambda x: np.mean(x, axis=2), **kwargs) - # average over epochs - avgpower = avgpower.average() - assert isinstance(avgpower, AverageTFR) - -def test_getitem_epochsTFR(): - """Test GetEpochsMixin in the context of EpochsTFR.""" +@pytest.mark.parametrize("n_drop", (0, 2)) +def test_epochstfr_getitem(epochs_full, n_drop): + """Test EpochsTFR.__getitem__().""" pd = pytest.importorskip("pandas") - - # Setup for reading the raw data and select a few trials - raw = read_raw_fif(raw_fname) - events = read_events(event_fname) - # create fake data, test with and without dropping epochs - for n_drop_epochs in [0, 2]: - n_events = 12 - # create fake metadata - rng = np.random.RandomState(42) - rt = rng.uniform(size=(n_events,)) - trialtypes = np.array(["face", "place"]) - trial = trialtypes[(rng.uniform(size=(n_events,)) > 0.5).astype(int)] - meta = pd.DataFrame(dict(RT=rt, Trial=trial)) - event_id = dict(a=1, b=2, c=3, d=4) - epochs = Epochs( - raw, events[:n_events], event_id=event_id, metadata=meta, decim=1 - ) - epochs.drop(np.arange(n_drop_epochs)) - n_events -= n_drop_epochs - - freqs = np.arange(12.0, 17.0, 2.0) # define frequencies of interest - n_cycles = freqs / 2.0 # 0.5 second time windows for all frequencies - - # Choose time x (full) bandwidth product - time_bandwidth = 4.0 - # With 0.5 s time windows, this gives 8 Hz smoothing - kwargs = dict( - freqs=freqs, - n_cycles=n_cycles, - use_fft=True, - time_bandwidth=time_bandwidth, - return_itc=False, - average=False, - n_jobs=None, - ) - power = tfr_multitaper(epochs, **kwargs) - - # Check that power and epochs metadata is the same - assert_metadata_equal(epochs.metadata, power.metadata) - assert_metadata_equal(epochs[::2].metadata, power[::2].metadata) - assert_metadata_equal(epochs["RT < .5"].metadata, power["RT < .5"].metadata) - assert_array_equal(epochs.selection, power.selection) - assert epochs.drop_log == power.drop_log - - # Check that get power is functioning - assert_array_equal(power[3:6].data, power.data[3:6]) - assert_array_equal(power[3:6].events, power.events[3:6]) - assert_array_equal(epochs.selection[3:6], power.selection[3:6]) - - indx_check = power.metadata["Trial"] == "face" - try: - indx_check = indx_check.to_numpy() - except Exception: - pass # older Pandas - indx_check = indx_check.nonzero() - assert_array_equal(power['Trial == "face"'].events, power.events[indx_check]) - assert_array_equal(power['Trial == "face"'].data, power.data[indx_check]) - - # Check that the wrong Key generates a Key Error for Metadata search - with pytest.raises(KeyError): - power['Trialz == "place"'] - - # Test length function - assert len(power) == n_events - assert len(power[3:6]) == 3 - - # Test iteration function - for ind, power_ep in enumerate(power): - assert_array_equal(power_ep, power.data[ind]) - if ind == 5: - break - - # Test that current state is maintained - assert_array_equal(power.next(), power.data[ind + 1]) - - # Check decim 
affects sfreq - power_decim = tfr_multitaper(epochs, decim=2, **kwargs) - assert power.info["sfreq"] / 2.0 == power_decim.info["sfreq"] + from pandas.testing import assert_frame_equal + + epochs_full.metadata = pd.DataFrame(dict(foo=list("aaaabbb"), bar=np.arange(7))) + epochs_full.drop(np.arange(n_drop)) + tfr = epochs_full.compute_tfr(method="morlet", freqs=freqs_linspace) + # check that various attributes are preserved + assert_frame_equal(tfr.metadata, epochs_full.metadata) + assert epochs_full.drop_log == tfr.drop_log + for attr in ("events", "selection", "times"): + assert_array_equal(getattr(epochs_full, attr), getattr(tfr, attr)) + # test pandas query + foo_a = tfr["foo == 'a'"] + bar_3 = tfr["bar <= 3"] + assert foo_a == bar_3 + assert foo_a.shape[0] == 4 - n_drop + # test integer and slice + subset_ints = tfr[[0, 1, 2]] + subset_slice = tfr[:3] + assert subset_ints == subset_slice + # test iteration + for ix, epo in enumerate(tfr): + assert_array_equal(tfr[ix].data, epo.data.obj[np.newaxis]) def test_to_data_frame(): @@ -1393,8 +1239,13 @@ def test_to_data_frame(): events[:, 2] = np.arange(5, 5 + n_epos) event_id = {k: v for v, k in zip(events[:, 2], ["ha", "he", "hu"])} info = mne.create_info(ch_names, srate, ch_types) - tfr = mne.time_frequency.EpochsTFR( - info, data, times, freqs, events=events, event_id=event_id + tfr = EpochsTFRArray( + info=info, + data=data, + times=times, + freqs=freqs, + events=events, + event_id=event_id, ) # test index checking with pytest.raises(ValueError, match="options. Valid index options are"): @@ -1477,8 +1328,13 @@ def test_to_data_frame_index(index): events[:, 2] = np.arange(5, 8) event_id = {k: v for v, k in zip(events[:, 2], ["ha", "he", "hu"])} info = mne.create_info(ch_names, 1000.0, ch_types) - tfr = mne.time_frequency.EpochsTFR( - info, data, times, freqs, events=events, event_id=event_id + tfr = EpochsTFRArray( + info=info, + data=data, + times=times, + freqs=freqs, + events=events, + event_id=event_id, ) df = tfr.to_data_frame(picks=[0, 2, 3], index=index) # test index order/hierarchy preservation @@ -1502,17 +1358,333 @@ def test_to_data_frame_time_format(time_format): n_freqs = 5 n_times = 6 data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) + times = np.arange(6, dtype=float) freqs = np.arange(5) events = np.zeros((n_epos, 3), dtype=int) events[:, 0] = np.arange(n_epos) events[:, 2] = np.arange(5, 8) event_id = {k: v for v, k in zip(events[:, 2], ["ha", "he", "hu"])} info = mne.create_info(ch_names, 1000.0, ch_types) - tfr = mne.time_frequency.EpochsTFR( - info, data, times, freqs, events=events, event_id=event_id + tfr = EpochsTFRArray( + info=info, + data=data, + times=times, + freqs=freqs, + events=events, + event_id=event_id, ) # test time_format df = tfr.to_data_frame(time_format=time_format) dtypes = {None: np.float64, "ms": np.int64, "timedelta": pd.Timedelta} assert isinstance(df["time"].iloc[0], dtypes[time_format]) + + +@parametrize_morlet_multitaper +@parametrize_power_phase_complex +@pytest.mark.parametrize("picks", ("mag", mag_names, [2, 5, 8])) # all 3 equivalent +def test_raw_compute_tfr(raw, method, output, picks): + """Test Raw.compute_tfr() and picks handling.""" + full_tfr = raw.compute_tfr(method, output=output, freqs=freqs_linspace) + pick_tfr = raw.compute_tfr(method, output=output, freqs=freqs_linspace, picks=picks) + assert isinstance(pick_tfr, RawTFR), type(pick_tfr) + # ↓↓↓ can't use [2,5,8] because ch0 is IAS, so indices change between raw and TFR + want = 
full_tfr.get_data(picks=mag_names) + got = pick_tfr.get_data() + assert_array_equal(want, got) + + +@parametrize_morlet_multitaper +@parametrize_power_phase_complex +@pytest.mark.parametrize("freqs", (freqs_linspace, freqs_unsorted_list)) +def test_evoked_compute_tfr(full_evoked, method, output, freqs): + """Test Evoked.compute_tfr(), with a few different ways of specifying freqs.""" + tfr = full_evoked.compute_tfr(method, freqs, output=output) + assert isinstance(tfr, AverageTFR), type(tfr) + assert tfr.nave == full_evoked.nave + assert tfr.comment == full_evoked.comment + + +@parametrize_morlet_multitaper +@pytest.mark.parametrize( + "average,return_itc,dim,want_class", + ( + pytest.param(True, False, None, None, id="average,no_itc"), + pytest.param(True, True, None, None, id="average,itc"), + pytest.param(False, False, "freqs", EpochsTFR, id="no_average,agg_freqs"), + pytest.param(False, False, "epochs", AverageTFR, id="no_average,agg_epochs"), + pytest.param(False, False, "times", EpochsSpectrum, id="no_average,agg_times"), + ), +) +def test_epochs_compute_tfr_average_itc( + epochs, method, average, return_itc, dim, want_class +): + """Test Epochs.compute_tfr(), averaging (at call time and afterward), and ITC.""" + tfr = epochs.compute_tfr( + method, freqs=freqs_linspace, average=average, return_itc=return_itc + ) + if return_itc: + tfr, itc = tfr + assert isinstance(itc, AverageTFR), type(itc) + # for single-epoch input, ITC should be (nearly) unity + assert_array_almost_equal(itc.get_data(), 1.0, decimal=15) + # if not averaging initially, make sure the post-facto .average() works too + if average: + assert isinstance(tfr, AverageTFR), type(tfr) + assert tfr.nave == 1 + assert tfr.comment == "1" + else: + assert isinstance(tfr, EpochsTFR), type(tfr) + avg = tfr.average(dim=dim) + assert isinstance(avg, want_class), type(avg) + if dim == "epochs": + assert avg.nave == len(epochs) + assert avg.comment.startswith(f"mean of {len(epochs)} EpochsTFR") + + +def test_epochs_vs_evoked_compute_tfr(epochs): + """Compare result of averaging before or after the TFR computation. + + This is mostly a test of object structure / attribute preservation. In normal cases, + the data should not match: + - epochs.compute_tfr().average() is average of squared magnitudes + - epochs.average().compute_tfr() is squared magnitude of average + But the `epochs` fixture has only one epoch, so here data should be identical too. + + The three things that will always end up different are `._comment`, `._inst_type`, + and `._data_type`, so we ignore those here. 
+ """ + avg_first = epochs.average().compute_tfr(method="morlet", freqs=freqs_linspace) + avg_second = epochs.compute_tfr(method="morlet", freqs=freqs_linspace).average() + for attr in ("_comment", "_inst_type", "_data_type"): + assert getattr(avg_first, attr) != getattr(avg_second, attr) + delattr(avg_first, attr) + delattr(avg_second, attr) + assert avg_first == avg_second + + +morlet_kw = dict(n_cycles=freqs_linspace / 4, use_fft=False, zero_mean=True) +mt_kw = morlet_kw | dict(zero_mean=False, time_bandwidth=6) +stockwell_kw = dict(n_fft=1024, width=2) + + +@pytest.mark.parametrize( + "method,freqs,method_kw", + ( + pytest.param("morlet", freqs_linspace, morlet_kw, id="morlet-nondefaults"), + pytest.param("multitaper", freqs_linspace, mt_kw, id="multitaper-nondefaults"), + pytest.param("stockwell", "auto", stockwell_kw, id="stockwell-nondefaults"), + ), +) +def test_epochs_compute_tfr_method_kw(epochs, method, freqs, method_kw): + """Test Epochs.compute_tfr(**method_kw).""" + tfr = epochs.compute_tfr(method, freqs=freqs, average=True, **method_kw) + assert isinstance(tfr, AverageTFR), type(tfr) + + +@pytest.mark.parametrize( + "freqs", + (pytest.param("auto", id="freqauto"), pytest.param([20, 41], id="fminfmax")), +) +@pytest.mark.parametrize("return_itc", (False, True)) +def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc): + """Test Epochs.compute_tfr(method="stockwell").""" + tfr = epochs.compute_tfr("stockwell", freqs, return_itc=return_itc) + if return_itc: + tfr, itc = tfr + assert isinstance(itc, AverageTFR) + # for single-epoch input, ITC should be (nearly) unity + assert_array_almost_equal(itc.get_data(), 1.0, decimal=15) + assert isinstance(tfr, AverageTFR) + assert tfr.comment == "1" + + +@pytest.mark.parametrize("copy", (False, True)) +def test_epochstfr_iter_evoked(epochs_tfr, copy): + """Test EpochsTFR.iter_evoked().""" + avgs = list(epochs_tfr.iter_evoked(copy=copy)) + assert len(avgs) == len(epochs_tfr) + assert all(avg.nave == 1 for avg in avgs) + assert avgs[0].comment == str(epochs_tfr.events[0, -1]) + + +def test_tfr_proj(epochs): + """Test `compute_tfr(proj=True)`.""" + epochs.compute_tfr(method="morlet", freqs=freqs_linspace, proj=True) + + +def test_tfr_copy(average_tfr): + """Test BaseTFR.copy() method.""" + tfr_copy = average_tfr.copy() + # check that info is independent + tfr_copy.info["bads"] = tfr_copy.ch_names + assert average_tfr.info["bads"] == [] + # check that data is independent + tfr_copy.data = np.inf + assert np.isfinite(average_tfr.get_data()).all() + + +@pytest.mark.parametrize( + "mode", ("mean", "ratio", "logratio", "percent", "zscore", "zlogratio") +) +def test_tfr_apply_baseline(average_tfr, mode): + """Test TFR baselining.""" + average_tfr.apply_baseline((-0.1, -0.05), mode=mode) + + +def test_tfr_arithmetic(epochs): + """Test TFR arithmetic operations.""" + tfr, itc = epochs.compute_tfr( + "morlet", freqs=freqs_linspace, average=True, return_itc=True + ) + itc_copy = itc.copy() + # addition / subtraction of objects + double = tfr + tfr + double -= tfr + assert tfr == double + itc_copy += tfr + assert itc == itc_copy - tfr + # multiplication / division with scalars + bigger_itc = itc * 23 + assert_array_almost_equal(itc.data, (bigger_itc / 23).data, decimal=15) + # multiplication / division with arrays + arr = np.full_like(itc.data, 23) + assert_array_equal(bigger_itc.data, (itc * arr).data) + # in-place multiplication/division + bigger_itc *= 2 + bigger_itc /= 46 + assert_array_almost_equal(itc.data, bigger_itc.data, decimal=15) 
+ # check errors + with pytest.raises(RuntimeError, match="types do not match"): + tfr + epochs + with pytest.raises(RuntimeError, match="times do not match"): + tfr + tfr.copy().crop(tmax=0.2) + with pytest.raises(RuntimeError, match="freqs do not match"): + tfr + tfr.copy().crop(fmax=33) + + +def test_tfr_repr_html(epochs_tfr): + """Test TFR._repr_html_().""" + result = epochs_tfr._repr_html_(caption="Foo") + for heading in ("Data type", "Data source", "Estimation method"): + assert f"{heading}" in result + for data in ("Power Estimates", "Epochs", "morlet"): + assert f"{data}" in result + + +@pytest.mark.parametrize( + "picks,combine", + ( + pytest.param("mag", "mean", id="mean_of_mags"), + pytest.param("grad", "rms", id="rms_of_grads"), + pytest.param([1], "mean", id="single_channel"), + pytest.param([1, 2], None, id="two_separate_channels"), + ), +) +def test_tfr_plot_combine(epochs_tfr, picks, combine): + """Test TFR.plot() picks, combine, and title="auto". + + No need to parametrize over {Raw,Epochs,Evoked}TFR, the code path is shared. + """ + fig = epochs_tfr.plot(picks=picks, combine=combine, title="auto") + assert len(fig) == 1 if isinstance(picks, str) else len(picks) + # test `title="auto"` + for ix, _fig in enumerate(fig): + if isinstance(picks, str): + ch_type = _channel_type_prettyprint[picks] + want = rf"{'RMS' if combine == 'rms' else 'Mean'} of \d{{1,3}} {ch_type}s" + else: + want = epochs_tfr.ch_names[picks[ix]] + assert re.search(want, _get_suptitle(_fig)) + + +def test_tfr_plot_extras(epochs_tfr): + """Test other options of TFR.plot().""" + # test mask and custom title + picks = [1] + mask = np.ones(epochs_tfr.data.shape[2:], bool) + fig = epochs_tfr.plot(picks=picks, mask=mask, title="Foo") + assert _get_suptitle(fig[0]) == "Foo" + mask = np.ones(epochs_tfr.data.shape[1:], bool) + with pytest.raises(ValueError, match="mask must have the same shape as the data"): + epochs_tfr.plot(picks=picks, mask=mask) + # test combine-related errors + with pytest.raises(ValueError, match='"combine" must be None, a callable, or one'): + epochs_tfr.plot(picks=picks, combine="foo") + with pytest.raises(RuntimeError, match="Wrong type yielded by callable"): + epochs_tfr.plot(picks=picks, combine=lambda x: 777) + with pytest.raises(RuntimeError, match="Wrong shape yielded by callable"): + epochs_tfr.plot(picks=picks, combine=lambda x: np.array([777])) + with pytest.raises(ValueError, match="wrong with the callable passed to 'combine'"): + epochs_tfr.plot(picks=picks, combine=lambda x, y: x.mean(axis=0)) + # test custom Axes + fig, axs = plt.subplots(1, 5) + fig2 = epochs_tfr.plot(picks=[1, 2], combine=lambda x: x.mean(axis=0), axes=axs[0]) + fig3 = epochs_tfr.plot(picks=[1, 2, 3], axes=axs[1:-1]) + fig4 = epochs_tfr.plot(picks=[1], axes=axs[-1:].tolist()) + for _fig in fig2 + fig3 + fig4: + assert fig == _fig + with pytest.raises(ValueError, match="axes must be None"): + epochs_tfr.plot(picks=picks, axes={}) + with pytest.raises(RuntimeError, match="must be one axes for each picked channel"): + epochs_tfr.plot(picks=[1, 2], axes=axs[-1:]) + # test singleton check by faking having 2 epochs + epochs_tfr._data = np.vstack((epochs_tfr._data, epochs_tfr._data)) + with pytest.raises(NotImplementedError, match=r"Cannot call plot\(\) from"): + epochs_tfr.plot() + + +def test_tfr_plot_interactivity(epochs_tfr): + """Test interactivity of TFR.plot().""" + fig = epochs_tfr.plot(picks="mag", combine="mean")[0] + assert len(plt.get_fignums()) == 1 + # press and release in same spot (should do 
nothing) + kw = dict(fig=fig, ax=fig.axes[0], xform="ax") + _fake_click(**kw, point=(0.5, 0.5), kind="press") + _fake_click(**kw, point=(0.5, 0.5), kind="motion") + _fake_click(**kw, point=(0.5, 0.5), kind="release") + assert len(plt.get_fignums()) == 1 + # click and drag (should create popup topomap) + _fake_click(**kw, point=(0.4, 0.4), kind="press") + _fake_click(**kw, point=(0.5, 0.5), kind="motion") + _fake_click(**kw, point=(0.6, 0.6), kind="release") + assert len(plt.get_fignums()) == 2 + + +@parametrize_inst_and_ch_type +def test_tfr_plot_topo(inst, ch_type, average_tfr, request): + """Test {Raw,Epochs,Average}TFR.plot_topo().""" + tfr = _get_inst(inst, request, average_tfr=average_tfr) + fig = tfr.plot_topo(picks=ch_type) + assert fig is not None + + +@parametrize_inst_and_ch_type +def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request): + """Test {Raw,Epochs,Average}TFR.plot_topomap().""" + tfr = _get_inst(inst, request, average_tfr=full_average_tfr) + fig = tfr.plot_topomap(ch_type=ch_type) + # fake a click-drag-release to select all sensors & generate a pop-up TFR image + ax = fig.axes[0] + pts = [ + coll.get_offsets() + for coll in ax.collections + if isinstance(coll, PathCollection) + ][0] + # sometimes sensors are outside axes; make sure our click starts inside axes + lims = np.vstack((ax.get_xlim(), ax.get_ylim())) + pad = np.diff(lims, axis=1).ravel() / 100 + start = np.clip(pts.min(axis=0) - pad, *(lims.min(axis=1) + pad)) + stop = np.clip(pts.max(axis=0) + pad, *(lims.max(axis=1) - pad)) + kw = dict(fig=fig, ax=ax, xform="data") + _fake_click(**kw, kind="press", point=tuple(start)) + # ↓↓↓ possible bug? using (start+stop)/2 for the motion event causes the motion + # ↓↓↓ event (not release event) coords to propagate → fails to select sensors + _fake_click(**kw, kind="motion", point=tuple(stop)) + _fake_click(**kw, kind="release", point=tuple(stop)) + # make sure we actually got a pop-up figure, and it has a plausible title + fignums = plt.get_fignums() + assert len(fignums) == 2 + popup_fig = plt.figure(fignums[-1]) + assert re.match( + rf"Average over \d{{1,3}} {ch_type} channels\.", popup_fig.axes[0].get_title() + ) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index d7df408b564..97df892ad46 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -11,19 +11,17 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. 
+import inspect from copy import deepcopy from functools import partial +import matplotlib.pyplot as plt import numpy as np from scipy.fft import fft, ifft from scipy.signal import argrelmax from .._fiff.meas_info import ContainsMixin, Info -from .._fiff.pick import ( - _picks_to_idx, - channel_type, - pick_info, -) +from .._fiff.pick import _picks_to_idx, pick_info from ..baseline import _check_baseline, rescale from ..channels.channels import UpdateChannelsMixin from ..channels.layout import _find_topomap_coords, _merge_ch_data, _pair_grad_sensors @@ -37,27 +35,35 @@ _build_data_frame, _check_combine, _check_event_id, + _check_fname, + _check_method_kwargs, _check_option, _check_pandas_index_arguments, _check_pandas_installed, _check_time_format, _convert_times, + _ensure_events, _freq_mask, - _gen_events, _import_h5io_funcs, _is_numeric, + _pl, _prepare_read_metadata, _prepare_write_metadata, _time_mask, _validate_type, check_fname, + copy_doc, copy_function_doc_to_method_doc, fill_doc, + legacy, logger, + object_diff, + repr_html, sizeof_fmt, verbose, warn, ) +from ..utils.spectrum import _get_instance_type_string from ..viz.topo import _imshow_tfr, _imshow_tfr_unified, _plot_topo from ..viz.topomap import ( _add_colorbar, @@ -67,6 +73,7 @@ plot_topomap, ) from ..viz.utils import ( + _make_combine_callable, _prepare_joint_axes, _set_title_multiple_electrodes, _setup_cmap, @@ -75,7 +82,8 @@ figure_nobar, plt_show, ) -from .multitaper import dpss_windows +from .multitaper import dpss_windows, tfr_array_multitaper +from .spectrum import EpochsSpectrum @fill_doc @@ -239,7 +247,14 @@ def fwhm(freq, n_cycles): return n_cycles * np.sqrt(2 * np.log(2)) / (np.pi * freq) -def _make_dpss(sfreq, freqs, n_cycles=7.0, time_bandwidth=4.0, zero_mean=False): +def _make_dpss( + sfreq, + freqs, + n_cycles=7.0, + time_bandwidth=4.0, + zero_mean=False, + return_weights=False, +): """Compute DPSS tapers for the given frequency range. Parameters @@ -257,6 +272,8 @@ def _make_dpss(sfreq, freqs, n_cycles=7.0, time_bandwidth=4.0, zero_mean=False): Default is 4.0, giving 3 good tapers. zero_mean : bool | None, , default False Make sure the wavelet has a mean of zero. + return_weights : bool + Whether to return the concentration weights. Returns ------- @@ -304,7 +321,8 @@ def _make_dpss(sfreq, freqs, n_cycles=7.0, time_bandwidth=4.0, zero_mean=False): Wm.append(Wk) Ws.append(Wm) - + if return_weights: + return Ws, conc return Ws @@ -360,7 +378,7 @@ def _cwt_gen(X, Ws, *, fsize=0, mode="same", decim=1, use_fft=True): The time-frequency transform of the signals. """ _check_option("mode", mode, ["same", "valid", "full"]) - decim = _check_decim(decim) + decim = _ensure_slice(decim) X = np.asarray(X) # Precompute wavelets for given frequency range to save time @@ -426,6 +444,7 @@ def _compute_tfr( decim=1, output="complex", n_jobs=None, + *, verbose=None, ): """Compute time-frequency transforms. @@ -490,8 +509,7 @@ def _compute_tfr( ``'phase'`` results in shape of ``out`` being ``(n_epochs, n_chans, n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the real values in the ``output`` contain average power' and the imaginary - values contain the inter-trial coherence: - ``out = avg_power + i * ITC``. + values contain the ITC: ``out = avg_power + i * itc``. 
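# A small standalone sketch (hypothetical values, not from the patch) of unpacking
# the packed "avg_power_itc" output documented just above: the real part holds the
# average power and the imaginary part holds the inter-trial coherence (<= 1).
import numpy as np

out = np.array([[2.5 + 0.8j, 3.1 + 0.4j]])  # pretend slice of an avg_power_itc result
avg_power, itc = out.real, out.imag         # the same unpacking the TFR classes perform
assert ((0 <= itc) & (itc <= 1)).all()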
""" # Check data epoch_data = np.asarray(epoch_data) @@ -514,7 +532,7 @@ def _compute_tfr( output, ) - decim = _check_decim(decim) + decim = _ensure_slice(decim) if (freqs > sfreq / 2.0).any(): raise ValueError( "Cannot compute freq above Nyquist freq of the data " @@ -698,7 +716,7 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): dtype = np.complex128 # Init outputs - decim = _check_decim(decim) + decim = _ensure_slice(decim) n_tapers = len(Ws) n_epochs, n_times = X[:, decim].shape n_freqs = len(Ws[0]) @@ -790,7 +808,7 @@ def cwt(X, Ws, use_fft=True, mode="same", decim=1): def _cwt_array(X, Ws, nfft, mode, decim, use_fft): - decim = _check_decim(decim) + decim = _ensure_slice(decim) coefs = _cwt_gen(X, Ws, fsize=nfft, mode=mode, decim=decim, use_fft=use_fft) n_signals, n_times = X[:, decim].shape @@ -802,85 +820,31 @@ def _cwt_array(X, Ws, nfft, mode, decim, use_fft): def _tfr_aux( - method, inst, freqs, decim, return_itc, picks, average, output=None, **tfr_params + method, inst, freqs, decim, return_itc, picks, average, output, **tfr_params ): from ..epochs import BaseEpochs - """Help reduce redundancy between tfr_morlet and tfr_multitaper.""" - decim = _check_decim(decim) - data = _get_data(inst, return_itc) - info = inst.info.copy() # make a copy as sfreq can be altered - - info, data = _prepare_picks(info, data, picks, axis=1) - del picks - - if average: - if output == "complex": - raise ValueError('output must be "power" if average=True') - if return_itc: - output = "avg_power_itc" - else: - output = "avg_power" - else: - output = "power" if output is None else output - if return_itc: - raise ValueError( - "Inter-trial coherence is not supported" " with average=False" - ) - - out = _compute_tfr( - data, - freqs, - info["sfreq"], + kwargs = dict( method=method, - output=output, + freqs=freqs, + picks=picks, decim=decim, + output=output, **tfr_params, ) - times = inst.times[decim].copy() - with info._unlock(): - info["sfreq"] /= decim.step - - if average: - if return_itc: - power, itc = out.real, out.imag - else: - power = out - nave = len(data) - out = AverageTFR(info, power, times, freqs, nave, method="%s-power" % method) - if return_itc: - out = ( - out, - AverageTFR(info, itc, times, freqs, nave, method="%s-itc" % method), - ) - else: - power = out - if isinstance(inst, BaseEpochs): - meta = deepcopy(inst._metadata) - evs = deepcopy(inst.events) - ev_id = deepcopy(inst.event_id) - selection = deepcopy(inst.selection) - drop_log = deepcopy(inst.drop_log) - else: - # if the input is of class Evoked - meta = evs = ev_id = selection = drop_log = None - - out = EpochsTFR( - info, - power, - times, - freqs, - method="%s-power" % method, - events=evs, - event_id=ev_id, - selection=selection, - drop_log=drop_log, - metadata=meta, - ) - - return out - - + if isinstance(inst, BaseEpochs): + kwargs.update(average=average, return_itc=return_itc) + elif average: + logger.info("inst is Evoked, setting `average=False`") + average = False + if average and output == "complex": + raise ValueError('output must be "power" if average=True') + if not average and return_itc: + raise ValueError("Inter-trial coherence is not supported with average=False") + return inst.compute_tfr(**kwargs) + + +@legacy(alt='.compute_tfr(method="morlet")') @verbose def tfr_morlet( inst, @@ -906,7 +870,7 @@ def tfr_morlet( ---------- inst : Epochs | Evoked The epochs or evoked object. 
- %(freqs_tfr)s + %(freqs_tfr_array)s %(n_cycles_tfr)s use_fft : bool, default False The fft based convolution or not. @@ -977,7 +941,7 @@ def tfr_array_morlet( sfreq, freqs, n_cycles=7.0, - zero_mean=False, + zero_mean=None, use_fft=True, decim=1, output="complex", @@ -996,10 +960,15 @@ def tfr_array_morlet( The epochs. sfreq : float | int Sampling frequency of the data. - %(freqs_tfr)s + %(freqs_tfr_array)s %(n_cycles_tfr)s - zero_mean : bool + zero_mean : bool | None If True, make sure the wavelets have a mean of zero. default False. + + .. versionchanged:: 1.8 + The default will change from ``zero_mean=False`` in 1.6 to ``True`` in + 1.8, and (if not set explicitly) will raise a ``FutureWarning`` in 1.7. + use_fft : bool Use the FFT for convolutions or not. default True. %(decim_tfr)s @@ -1054,6 +1023,13 @@ def tfr_array_morlet( ---------- .. footbibliography:: """ + if zero_mean is None: + warn( + "The default value of `zero_mean` will change from `False` to `True` " + "in version 1.8. Set the value explicitly to avoid this warning.", + FutureWarning, + ) + zero_mean = False if epoch_data is not None: warn( "The parameter for providing data will be switched from `epoch_data` to " @@ -1077,6 +1053,7 @@ def tfr_array_morlet( ) +@legacy(alt='.compute_tfr(method="multitaper")') @verbose def tfr_multitaper( inst, @@ -1094,15 +1071,15 @@ def tfr_multitaper( ): """Compute Time-Frequency Representation (TFR) using DPSS tapers. - Same computation as `~mne.time_frequency.tfr_array_multitaper`, but - operates on `~mne.Epochs` or `~mne.Evoked` objects instead of + Same computation as :func:`~mne.time_frequency.tfr_array_multitaper`, but + operates on :class:`~mne.Epochs` or :class:`~mne.Evoked` objects instead of :class:`NumPy arrays `. Parameters ---------- inst : Epochs | Evoked The epochs or evoked object. - %(freqs_tfr)s + %(freqs_tfr_array)s %(n_cycles_tfr)s %(time_bandwidth_tfr)s use_fft : bool, default True @@ -1140,6 +1117,9 @@ def tfr_multitaper( .. versionadded:: 0.9.0 """ + from ..epochs import EpochsArray + from ..evoked import Evoked + tfr_params = dict( n_cycles=n_cycles, n_jobs=n_jobs, @@ -1147,23 +1127,578 @@ def tfr_multitaper( zero_mean=True, time_bandwidth=time_bandwidth, ) + if isinstance(inst, Evoked) and not average: + # convert AverageTFR to EpochsTFR for backwards compatibility + inst = EpochsArray(inst.data[np.newaxis], inst.info, tmin=inst.tmin, proj=False) return _tfr_aux( - "multitaper", inst, freqs, decim, return_itc, picks, average, **tfr_params + method="multitaper", + inst=inst, + freqs=freqs, + decim=decim, + return_itc=return_itc, + picks=picks, + average=average, + output="power", + **tfr_params, ) # TFR(s) class -class _BaseTFR(ContainsMixin, UpdateChannelsMixin, SizeMixin, ExtendedTimeMixin): - """Base TFR class.""" +@fill_doc +class BaseTFR(ContainsMixin, UpdateChannelsMixin, SizeMixin, ExtendedTimeMixin): + """Base class for RawTFR, EpochsTFR, and AverageTFR (for type checking only). + + .. note:: + This class should not be instantiated directly; it is provided in the public API + only for type-checking purposes (e.g., ``isinstance(my_obj, BaseTFR)``). To + create TFR objects, use the ``.compute_tfr()`` methods on :class:`~mne.io.Raw`, + :class:`~mne.Epochs`, or :class:`~mne.Evoked`, or use the constructors listed + below under "See Also". + + Parameters + ---------- + inst : instance of Raw, Epochs, or Evoked + The data from which to compute the time-frequency representation. 
+ %(method_tfr)s + %(freqs_tfr)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(decim_tfr)s + %(n_jobs)s + %(reject_by_annotation_tfr)s + %(verbose)s + %(method_kw_tfr)s + + See Also + -------- + mne.time_frequency.RawTFR + mne.time_frequency.RawTFRArray + mne.time_frequency.EpochsTFR + mne.time_frequency.EpochsTFRArray + mne.time_frequency.AverageTFR + mne.time_frequency.AverageTFRArray + """ + + def __init__( + self, + inst, + method, + freqs, + tmin, + tmax, + picks, + proj, + *, + decim, + n_jobs, + reject_by_annotation=None, + verbose=None, + **method_kw, + ): + from ..epochs import BaseEpochs + from ._stockwell import tfr_array_stockwell - def __init__(self): - self.baseline = None + # triage reading from file + if isinstance(inst, dict): + self.__setstate__(inst) + return + if method is None or freqs is None: + problem = [ + f"{k}=None" + for k, v in dict(method=method, freqs=freqs).items() + if v is None + ] + # TODO when py3.11 is min version, replace if/elif/else block with + # classname = inspect.currentframe().f_back.f_code.co_qualname.split(".")[0] + _varnames = inspect.currentframe().f_back.f_code.co_varnames + if "BaseRaw" in _varnames: + classname = "RawTFR" + elif "Evoked" in _varnames: + classname = "AverageTFR" + else: + assert "BaseEpochs" in _varnames and "Evoked" not in _varnames + classname = "EpochsTFR" + # end TODO + raise ValueError( + f'{classname} got unsupported parameter value{_pl(problem)} ' + f'{" and ".join(problem)}.' + ) + # shim for tfr_array_morlet deprecation warning (TODO: remove after 1.7 release) + if method == "morlet": + method_kw.setdefault("zero_mean", True) + # check method + valid_methods = ["morlet", "multitaper"] + if isinstance(inst, BaseEpochs): + valid_methods.append("stockwell") + method = _check_option("method", method, valid_methods) + # for stockwell, `tmin, tmax` already added to `method_kw` by calling method, + # and `freqs` vector has been pre-computed + if method != "stockwell": + method_kw.update(freqs=freqs) + # ↓↓↓ if constructor called directly, prevents key error + method_kw.setdefault("output", "power") + self._freqs = np.asarray(freqs, dtype=np.float64) + del freqs + # check validity of kwargs manually to save compute time if any are invalid + tfr_funcs = dict( + morlet=tfr_array_morlet, + multitaper=tfr_array_multitaper, + stockwell=tfr_array_stockwell, + ) + _check_method_kwargs(tfr_funcs[method], method_kw, msg=f'TFR method "{method}"') + self._tfr_func = partial(tfr_funcs[method], **method_kw) + # apply proj if desired + if proj: + inst = inst.copy().apply_proj() + self.inst = inst + + # prep picks and add the info object. bads and non-data channels are dropped by + # _picks_to_idx() so we update the info accordingly: + self._picks = _picks_to_idx(inst.info, picks, "data", with_ref_meg=False) + self.info = pick_info(inst.info, sel=self._picks, copy=True) + # assign some attributes + self._method = method + self._inst_type = type(inst) + self._baseline = None + self.preload = True # needed for __getitem__, never False for TFRs + # self._dims may also get updated by child classes + self._dims = ["channel", "freq", "time"] + self._needs_taper_dim = method == "multitaper" and method_kw["output"] in ( + "complex", + "phase", + ) + if self._needs_taper_dim: + self._dims.insert(1, "taper") + self._dims = tuple(self._dims) + # get the instance data. 
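# A brief usage sketch of the ``.compute_tfr()`` entry points referred to in the
# BaseTFR docstring above; the file path, frequency grid, and n_cycles values are
# illustrative assumptions, not taken from the patch or its tests.
import numpy as np
import mne

raw = mne.io.read_raw_fif("sample_raw.fif", preload=True)  # hypothetical file
freqs = np.arange(5.0, 40.0, 5.0)
raw_tfr = raw.compute_tfr("morlet", freqs=freqs, n_cycles=freqs / 2.0)
assert isinstance(raw_tfr, mne.time_frequency.RawTFR)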
+ time_mask = _time_mask(inst.times, tmin, tmax, sfreq=self.sfreq) + get_instance_data_kw = dict(time_mask=time_mask) + if reject_by_annotation is not None: + get_instance_data_kw.update(reject_by_annotation=reject_by_annotation) + data = self._get_instance_data(**get_instance_data_kw) + # compute the TFR + self._decim = _ensure_slice(decim) + self._raw_times = inst.times[time_mask] + self._compute_tfr(data, n_jobs, verbose) + self._update_epoch_attributes() + # "apply" decim to the rest of the object (data is decimated in _compute_tfr) + with self.info._unlock(): + self.info["sfreq"] /= self._decim.step + _decim_times = inst.times[self._decim] + _decim_time_mask = _time_mask(_decim_times, tmin, tmax, sfreq=self.sfreq) + self._raw_times = _decim_times[_decim_time_mask].copy() + self._set_times(self._raw_times) self._decim = 1 + # record data type (for repr and html_repr). ITC handled in the calling method. + if method == "stockwell": + self._data_type = "Power Estimates" + else: + data_types = dict( + power="Power Estimates", + avg_power="Average Power Estimates", + avg_power_itc="Average Power Estimates", + phase="Phase", + complex="Complex Amplitude", + ) + self._data_type = data_types[method_kw["output"]] + # check for correct shape and bad values. `tfr_array_stockwell` doesn't take kw + # `output` so it may be missing here, so use `.get()` + negative_ok = method_kw.get("output", "") in ("complex", "phase") + # if method_kw.get("output", None) in ("phase", "complex"): + # raise RuntimeError + self._check_values(negative_ok=negative_ok) + # we don't need these anymore, and they make save/load harder + del self._picks + del self._tfr_func + del self._needs_taper_dim + del self._shape # calculated from self._data henceforth + del self.inst # save memory + + def __abs__(self): + """Return the absolute value.""" + tfr = self.copy() + tfr.data = np.abs(tfr.data) + return tfr + + @fill_doc + def __add__(self, other): + """Add two TFR instances. + + %(__add__tfr)s + """ + self._check_compatibility(other) + out = self.copy() + out.data += other.data + return out + + @fill_doc + def __iadd__(self, other): + """Add a TFR instance to another, in-place. + + %(__iadd__tfr)s + """ + self._check_compatibility(other) + self.data += other.data + return self + + @fill_doc + def __sub__(self, other): + """Subtract two TFR instances. + + %(__sub__tfr)s + """ + self._check_compatibility(other) + out = self.copy() + out.data -= other.data + return out + + @fill_doc + def __isub__(self, other): + """Subtract a TFR instance from another, in-place. + + %(__isub__tfr)s + """ + self._check_compatibility(other) + self.data -= other.data + return self + + @fill_doc + def __mul__(self, num): + """Multiply a TFR instance by a scalar. + + %(__mul__tfr)s + """ + out = self.copy() + out.data *= num + return out + + @fill_doc + def __imul__(self, num): + """Multiply a TFR instance by a scalar, in-place. + + %(__imul__tfr)s + """ + self.data *= num + return self + + @fill_doc + def __truediv__(self, num): + """Divide a TFR instance by a scalar. + + %(__truediv__tfr)s + """ + out = self.copy() + out.data /= num + return out + + @fill_doc + def __itruediv__(self, num): + """Divide a TFR instance by a scalar, in-place. 
+ + %(__itruediv__tfr)s + """ + self.data /= num + return self + + def __eq__(self, other): + """Test equivalence of two TFR instances.""" + return object_diff(vars(self), vars(other)) == "" + + def __getstate__(self): + """Prepare object for serialization.""" + return dict( + method=self.method, + data=self._data, + sfreq=self.sfreq, + dims=self._dims, + freqs=self.freqs, + times=self.times, + inst_type_str=_get_instance_type_string(self), + data_type=self._data_type, + info=self.info, + baseline=self._baseline, + decim=self._decim, + ) + + def __setstate__(self, state): + """Unpack from serialized format.""" + from ..epochs import Epochs + from ..evoked import Evoked + from ..io import Raw + + defaults = dict( + method="unknown", + dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :], + baseline=None, + decim=1, + data_type="TFR", + inst_type_str="Unknown", + ) + defaults.update(**state) + self._method = defaults["method"] + self._data = defaults["data"] + self._freqs = np.asarray(defaults["freqs"], dtype=np.float64) + self._dims = defaults["dims"] + self._raw_times = np.asarray(defaults["times"], dtype=np.float64) + self._baseline = defaults["baseline"] + self.info = Info(**defaults["info"]) + self._data_type = defaults["data_type"] + self._decim = defaults["decim"] + self.preload = True + self._set_times(self._raw_times) + # Handle instance type. Prior to gh-11282, Raw was not a possibility so if + # `inst_type_str` is missing it must be Epochs or Evoked + unknown_class = Epochs if self._data.ndim == 4 else Evoked + inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class) + self._inst_type = inst_types[defaults["inst_type_str"]] + # sanity check data/freqs/times/info agreement + self._check_state() + + def __repr__(self): + """Build string representation of the TFR object.""" + inst_type_str = _get_instance_type_string(self) + nave = f" (nave={self.nave})" if hasattr(self, "nave") else "" + # shape & dimension names + dims = " × ".join( + [f"{size} {dim}s" for size, dim in zip(self.shape, self._dims)] + ) + freq_range = f"{self.freqs[0]:0.1f} - {self.freqs[-1]:0.1f} Hz" + time_range = f"{self.times[0]:0.2f} - {self.times[-1]:0.2f} s" + return ( + f"<{self._data_type} from {inst_type_str}{nave}, " + f"{self.method} method | {dims}, {freq_range}, {time_range}, " + f"{sizeof_fmt(self._size)}>" + ) + + @repr_html + def _repr_html_(self, caption=None): + """Build HTML representation of the TFR object.""" + from ..html_templates import _get_html_template + + inst_type_str = _get_instance_type_string(self) + nave = getattr(self, "nave", 0) + t = _get_html_template("repr", "tfr.html.jinja") + t = t.render(tfr=self, inst_type=inst_type_str, nave=nave, caption=caption) + return t + + def _check_compatibility(self, other): + """Check compatibility of two TFR instances, in preparation for arithmetic.""" + operation = inspect.currentframe().f_back.f_code.co_name.strip("_") + if operation.startswith("i"): + operation = operation[1:] + msg = f"Cannot {operation} the two TFR instances: {{}} do not match{{}}." 
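# A hedged sketch of the HDF5 round trip supported by the serialization hooks
# above (__getstate__/__setstate__); the channel names, data values, and output
# file name are illustrative only.
import numpy as np
import mne
from mne.time_frequency import AverageTFRArray, read_tfrs

info = mne.create_info(["MEG 001", "MEG 002"], 1000.0, "mag")
tfr = AverageTFRArray(
    info=info,
    data=np.random.rand(2, 3, 4),  # (n_channels, n_freqs, n_times)
    times=np.arange(4) / 1000.0,
    freqs=np.arange(1.0, 4.0),
    nave=1,
)
tfr.save("sample-tfr.h5", overwrite=True)
np.testing.assert_array_equal(read_tfrs("sample-tfr.h5").data, tfr.data)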
+ extra = "" + if not isinstance(other, type(self)): + problem = "types" + extra = f" (self is {type(self)}, other is {type(other)})" + elif not self.times.shape == other.times.shape or np.any( + self.times != other.times + ): + problem = "times" + elif not self.freqs.shape == other.freqs.shape or np.any( + self.freqs != other.freqs + ): + problem = "freqs" + else: # should be OK + return + raise RuntimeError(msg.format(problem, extra)) + + def _check_state(self): + """Check data/freqs/times/info agreement during __setstate__.""" + msg = "{} axis of data ({}) doesn't match {} attribute ({})" + n_chan_info = len(self.info["chs"]) + n_chan, n_freq, n_time = self._data.shape[self._dims.index("channel") :] + if n_chan_info != n_chan: + msg = msg.format("Channel", n_chan, "info", n_chan_info) + elif n_freq != len(self.freqs): + msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size) + elif n_time != len(self.times): + msg = msg.format("Time", n_time, "times", self.times.size) + else: + return + raise ValueError(msg) + + def _check_values(self, negative_ok=False): + """Check TFR results for correct shape and bad values.""" + assert len(self._dims) == self._data.ndim + assert self._data.shape == self._shape + # Check for implausible power values: take min() across all but the channel axis + # TODO: should this be more fine-grained (report "chan X in epoch Y")? + ch_dim = self._dims.index("channel") + dims = np.arange(self._data.ndim).tolist() + dims.pop(ch_dim) + negative_values = self._data.min(axis=tuple(dims)) < 0 + if negative_values.any() and not negative_ok: + chs = np.array(self.ch_names)[negative_values].tolist() + s = _pl(negative_values.sum()) + warn( + f"Negative value in time-frequency decomposition for channel{s} " + f'{", ".join(chs)}', + UserWarning, + ) + + def _compute_tfr(self, data, n_jobs, verbose): + result = self._tfr_func( + data, + self.sfreq, + decim=self._decim, + n_jobs=n_jobs, + verbose=verbose, + ) + # assign ._data and maybe ._itc + # tfr_array_stockwell always returns ITC (sometimes it's None) + if self.method == "stockwell": + self._data, self._itc, freqs = result + assert np.array_equal(self._freqs, freqs) + elif self._tfr_func.keywords.get("output", "").endswith("_itc"): + self._data, self._itc = result.real, result.imag + else: + self._data = result + # remove fake "epoch" dimension + if self.method != "stockwell" and _get_instance_type_string(self) != "Epochs": + self._data = np.squeeze(self._data, axis=0) + + # this is *expected* shape, it gets asserted later in _check_values() + # (and then deleted afterwards) + expected_shape = [ + len(self.ch_names), + len(self.freqs), + len(self._raw_times[self._decim]), # don't use self.times, not set yet + ] + # deal with the "taper" dimension + if self._needs_taper_dim: + expected_shape.insert(1, self._data.shape[1]) + self._shape = tuple(expected_shape) + + @verbose + def _onselect( + self, + eclick, + erelease, + picks=None, + exclude="bads", + combine="mean", + baseline=None, + mode=None, + cmap=None, + source_plot_joint=False, + topomap_args=None, + verbose=None, + ): + """Respond to rectangle selector in TFR image plots with a topomap plot.""" + if abs(eclick.x - erelease.x) < 0.1 or abs(eclick.y - erelease.y) < 0.1: + return + t_range = (min(eclick.xdata, erelease.xdata), max(eclick.xdata, erelease.xdata)) + f_range = (min(eclick.ydata, erelease.ydata), max(eclick.ydata, erelease.ydata)) + # snap to nearest measurement point + t_idx = np.abs(self.times - np.atleast_2d(t_range).T).argmin(axis=1) + f_idx = 
np.abs(self.freqs - np.atleast_2d(f_range).T).argmin(axis=1) + tmin, tmax = self.times[t_idx] + fmin, fmax = self.freqs[f_idx] + # immutable → mutable default + if topomap_args is None: + topomap_args = dict() + topomap_args.setdefault("cmap", cmap) + topomap_args.setdefault("vlim", (None, None)) + # figure out which channel types we're dealing with + types = list() + if "eeg" in self: + types.append("eeg") + if "mag" in self: + types.append("mag") + if "grad" in self: + grad_picks = _pair_grad_sensors( + self.info, topomap_coords=False, raise_error=False + ) + if len(grad_picks) > 1: + types.append("grad") + elif len(types) == 0: + logger.info( + "Need at least 2 gradiometer pairs to plot a gradiometer topomap." + ) + return # Don't draw a figure for nothing. + + fig = figure_nobar() + t_range = f"{tmin:.3f}" if tmin == tmax else f"{tmin:.3f} - {tmax:.3f}" + f_range = f"{fmin:.2f}" if fmin == fmax else f"{fmin:.2f} - {fmax:.2f}" + fig.suptitle(f"{t_range} s,\n{f_range} Hz") + + if source_plot_joint: + ax = fig.add_subplot() + data, times, freqs = self.get_data( + picks=picks, exclude=exclude, return_times=True, return_freqs=True + ) + # merge grads before baselining (makes ERDs visible) + ch_types = np.array(self.get_channel_types(unique=True)) + ch_type = ch_types.item() # will error if there are more than one + data, pos = _merge_if_grads( + data=data, + info=self.info, + ch_type=ch_type, + sphere=topomap_args.get("sphere"), + combine=combine, + ) + # baseline and crop + data, *_ = _prep_data_for_plot( + data, + times, + freqs, + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + baseline=baseline, + mode=mode, + verbose=verbose, + ) + # average over times and freqs + data = data.mean((-2, -1)) + + im, _ = plot_topomap(data, pos, axes=ax, show=False, **topomap_args) + _add_colorbar(ax, im, topomap_args["cmap"], title="AU") + plt_show(fig=fig) + else: + for idx, ch_type in enumerate(types): + ax = fig.add_subplot(1, len(types), idx + 1) + plot_tfr_topomap( + self, + ch_type=ch_type, + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + baseline=baseline, + mode=mode, + axes=ax, + **topomap_args, + ) + ax.set_title(ch_type) + + def _update_epoch_attributes(self): + # overwritten in EpochsTFR; adds things needed for to_data_frame and __getitem__ + pass + + @property + def _detrend_picks(self): + """Provide compatibility with __iter__.""" + return list() + + @property + def baseline(self): + """Start and end of the baseline period (in seconds).""" + return self._baseline + + @property + def ch_names(self): + """The channel names.""" + return self.info["ch_names"] @property def data(self): + """The time-frequency-resolved power estimates.""" return self._data @data.setter @@ -1171,9 +1706,29 @@ def data(self, data): self._data = data @property - def ch_names(self): - """Channel names.""" - return self.info["ch_names"] + def freqs(self): + """The frequencies at which power estimates were computed.""" + return self._freqs + + @property + def method(self): + """The method used to compute the time-frequency power estimates.""" + return self._method + + @property + def sfreq(self): + """Sampling frequency of the data.""" + return self.info["sfreq"] + + @property + def shape(self): + """Data shape.""" + return self._data.shape + + @property + def times(self): + """The time points present in the data (in seconds).""" + return self._times_readonly @fill_doc def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): @@ -1181,10 +1736,7 @@ def crop(self, tmin=None, tmax=None, 
fmin=None, fmax=None, include_tmax=True): Parameters ---------- - tmin : float | None - Start time of selection in seconds. - tmax : float | None - End time of selection in seconds. + %(tmin_tmax_psd)s fmin : float | None Lowest frequency of selection in Hz. @@ -1197,7 +1749,7 @@ def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): Returns ------- - inst : instance of AverageTFR + %(inst_tfr)s The modified instance. """ super().crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax) @@ -1209,7 +1761,7 @@ def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): else: freq_mask = slice(None) - self.freqs = self.freqs[freq_mask] + self._freqs = self.freqs[freq_mask] # Deal with broadcasting (boolean arrays do not broadcast, but indices # do, so we need to convert freq_mask to make use of broadcasting) if isinstance(freq_mask, np.ndarray): @@ -1218,12 +1770,12 @@ def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): return self def copy(self): - """Return a copy of the instance. + """Return copy of the TFR instance. Returns ------- - copy : instance of EpochsTFR | instance of AverageTFR - A copy of the instance. + %(inst_tfr)s + A copy of the object. """ return deepcopy(self) @@ -1233,14 +1785,9 @@ def apply_baseline(self, baseline, mode="mean", verbose=None): Parameters ---------- - baseline : array-like, shape (2,) - The time interval to apply rescaling / baseline correction. - If None do not apply it. If baseline is (a, b) - the interval is between "a (s)" and "b (s)". - If a is None the beginning of the data is used - and if b is None then b is set to the end of the interval. - If baseline is equal to (None, None) all the time - interval is used. + %(baseline_rescale)s + + How baseline is computed is determined by the ``mode`` parameter. mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio' Perform baseline correction by @@ -1259,524 +1806,313 @@ def apply_baseline(self, baseline, mode="mean", verbose=None): Returns ------- - inst : instance of AverageTFR + %(inst_tfr)s The modified instance. - """ # noqa: E501 - self.baseline = _check_baseline( - baseline, times=self.times, sfreq=self.info["sfreq"] - ) - rescale(self.data, self.times, self.baseline, mode, copy=False) + """ + self._baseline = _check_baseline(baseline, times=self.times, sfreq=self.sfreq) + rescale(self.data, self.times, self.baseline, mode, copy=False, verbose=verbose) return self - @verbose - def save(self, fname, overwrite=False, *, verbose=None): - """Save TFR object to hdf5 file. + @fill_doc + def get_data( + self, + picks=None, + exclude="bads", + fmin=None, + fmax=None, + tmin=None, + tmax=None, + return_times=False, + return_freqs=False, + ): + """Get time-frequency data in NumPy array format. Parameters ---------- - fname : path-like - The file name, which should end with ``-tfr.h5``. - %(overwrite)s - %(verbose)s + %(picks_good_data_noref)s + %(exclude_spectrum_get_data)s + %(fmin_fmax_tfr)s + %(tmin_tmax_psd)s + return_times : bool + Whether to return the time values for the requested time range. + Default is ``False``. + return_freqs : bool + Whether to return the frequency bin values for the requested + frequency range. Default is ``False``. - See Also - -------- - read_tfrs, write_tfrs - """ - write_tfrs(fname, self, overwrite=overwrite) + Returns + ------- + data : array + The requested data in a NumPy array. + times : array + The time values for the requested data range. Only returned if + ``return_times`` is ``True``. 
+ freqs : array + The frequency values for the requested data range. Only returned if + ``return_freqs`` is ``True``. - @verbose - def to_data_frame( - self, - picks=None, - index=None, - long_format=False, - time_format=None, - *, - verbose=None, - ): - """Export data in tabular structure as a pandas DataFrame. - - Channels are converted to columns in the DataFrame. By default, - additional columns ``'time'``, ``'freq'``, ``'epoch'``, and - ``'condition'`` (epoch event description) are added, unless ``index`` - is not ``None`` (in which case the columns specified in ``index`` will - be used to form the DataFrame's index instead). ``'epoch'``, and - ``'condition'`` are not supported for ``AverageTFR``. - - Parameters - ---------- - %(picks_all)s - %(index_df_epo)s - Valid string values are ``'time'``, ``'freq'``, ``'epoch'``, and - ``'condition'`` for ``EpochsTFR`` and ``'time'`` and ``'freq'`` - for ``AverageTFR``. - Defaults to ``None``. - %(long_format_df_epo)s - %(time_format_df)s - - .. versionadded:: 0.23 - %(verbose)s - - Returns - ------- - %(df_return)s + Notes + ----- + Returns a copy of the underlying data (not a view). """ - # check pandas once here, instead of in each private utils function - pd = _check_pandas_installed() # noqa - # arg checking - valid_index_args = ["time", "freq"] - if isinstance(self, EpochsTFR): - valid_index_args.extend(["epoch", "condition"]) - valid_time_formats = ["ms", "timedelta"] - index = _check_pandas_index_arguments(index, valid_index_args) - time_format = _check_time_format(time_format, valid_time_formats) - # get data - times = self.times - picks = _picks_to_idx(self.info, picks, "all", exclude=()) - if isinstance(self, EpochsTFR): - data = self.data[:, picks, :, :] - else: - data = self.data[np.newaxis, picks] # add singleton "epochs" axis - n_epochs, n_picks, n_freqs, n_times = data.shape - # reshape to (epochs*freqs*times) x signals - data = np.moveaxis(data, 1, -1) - data = data.reshape(n_epochs * n_freqs * n_times, n_picks) - # prepare extra columns / multiindex - mindex = list() - times = np.tile(times, n_epochs * n_freqs) - times = _convert_times(times, time_format, self.info["meas_date"]) - mindex.append(("time", times)) - freqs = self.freqs - freqs = np.tile(np.repeat(freqs, n_times), n_epochs) - mindex.append(("freq", freqs)) - if isinstance(self, EpochsTFR): - mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs))) - rev_event_id = {v: k for k, v in self.event_id.items()} - conditions = [rev_event_id[k] for k in self.events[:, 2]] - mindex.append(("condition", np.repeat(conditions, n_times * n_freqs))) - assert all(len(mdx) == len(mindex[0]) for mdx in mindex) - # build DataFrame - if isinstance(self, EpochsTFR): - default_index = ["condition", "epoch", "freq", "time"] - else: - default_index = ["freq", "time"] - df = _build_data_frame( - self, data, picks, long_format, mindex, index, default_index=default_index + tmin = self.times[0] if tmin is None else tmin + tmax = self.times[-1] if tmax is None else tmax + fmin = 0 if fmin is None else fmin + fmax = np.inf if fmax is None else fmax + picks = _picks_to_idx( + self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False ) - return df - - -@fill_doc -class AverageTFR(_BaseTFR): - """Container for Time-Frequency data. - - Can for example store induced power at sensor level or inter-trial - coherence. - - Parameters - ---------- - %(info_not_none)s - data : ndarray, shape (n_channels, n_freqs, n_times) - The data. 
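+        # map the (possibly defaulted) limits to index ranges; searchsorted with
+        # side="right" on the upper bounds keeps samples exactly equal to tmax/fmax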
- times : ndarray, shape (n_times,) - The time values in seconds. - freqs : ndarray, shape (n_freqs,) - The frequencies in Hz. - nave : int - The number of averaged TFRs. - comment : str | None, default None - Comment on the data, e.g., the experimental condition. - method : str | None, default None - Comment on the method used to compute the data, e.g., morlet wavelet. - %(verbose)s - - Attributes - ---------- - %(info_not_none)s - ch_names : list - The names of the channels. - nave : int - Number of averaged epochs. - data : ndarray, shape (n_channels, n_freqs, n_times) - The data array. - times : ndarray, shape (n_times,) - The time values in seconds. - freqs : ndarray, shape (n_freqs,) - The frequencies in Hz. - comment : str - Comment on dataset. Can be the condition. - method : str | None, default None - Comment on the method used to compute the data, e.g., morlet wavelet. - """ - - @verbose - def __init__( - self, info, data, times, freqs, nave, comment=None, method=None, verbose=None - ): - super().__init__() - self.info = info - if data.ndim != 3: - raise ValueError("data should be 3d. Got %d." % data.ndim) - n_channels, n_freqs, n_times = data.shape - if n_channels != len(info["chs"]): - raise ValueError( - "Number of channels and data size don't match" - " (%d != %d)." % (n_channels, len(info["chs"])) - ) - if n_freqs != len(freqs): - raise ValueError( - "Number of frequencies and data size don't match" - " (%d != %d)." % (n_freqs, len(freqs)) - ) - if n_times != len(times): - raise ValueError( - "Number of times and data size don't match" - " (%d != %d)." % (n_times, len(times)) - ) - self.data = data - self._set_times(np.array(times, dtype=float)) - self._raw_times = self.times.copy() - self.freqs = np.array(freqs, dtype=float) - self.nave = nave - self.comment = comment - self.method = method - self.preload = True + fmin_idx = np.searchsorted(self.freqs, fmin) + fmax_idx = np.searchsorted(self.freqs, fmax, side="right") + tmin_idx = np.searchsorted(self.times, tmin) + tmax_idx = np.searchsorted(self.times, tmax, side="right") + freq_picks = np.arange(fmin_idx, fmax_idx) + time_picks = np.arange(tmin_idx, tmax_idx) + freq_axis = self._dims.index("freq") + time_axis = self._dims.index("time") + chan_axis = self._dims.index("channel") + # normally there's a risk of np.take reducing array dimension if there + # were only one channel or frequency selected, but `_picks_to_idx` + # and np.arange both always return arrays, so we're safe; the result + # will always have the same `ndim` as it started with. + data = ( + self._data.take(picks, chan_axis) + .take(freq_picks, freq_axis) + .take(time_picks, time_axis) + ) + out = [data] + if return_times: + times = self._raw_times[tmin_idx:tmax_idx] + out.append(times) + if return_freqs: + freqs = self._freqs[fmin_idx:fmax_idx] + out.append(freqs) + if not return_times and not return_freqs: + return out[0] + return tuple(out) @verbose def plot( self, picks=None, - baseline=None, - mode="mean", + *, + exclude=(), tmin=None, tmax=None, - fmin=None, - fmax=None, + fmin=0.0, + fmax=np.inf, + baseline=None, + mode="mean", + dB=False, + combine=None, + layout=None, # TODO deprecate? 
not used in orig implementation either + yscale="auto", vmin=None, vmax=None, - cmap="RdBu_r", - dB=False, + vlim=(None, None), + cnorm=None, + cmap=None, colorbar=True, - show=True, - title=None, - axes=None, - layout=None, - yscale="auto", + title=None, # don't deprecate this one; has (useful) option title="auto" mask=None, mask_style=None, mask_cmap="Greys", mask_alpha=0.1, - combine=None, - exclude=(), - cnorm=None, + axes=None, + show=True, verbose=None, ): - """Plot TFRs as a two-dimensional image(s). + """Plot TFRs as two-dimensional time-frequency images. Parameters ---------- %(picks_good_data)s - baseline : None (default) or tuple, shape (2,) - The time interval to apply baseline correction. - If None do not apply it. If baseline is (a, b) - the interval is between "a (s)" and "b (s)". - If a is None the beginning of the data is used - and if b is None then b is set to the end of the interval. - If baseline is equal to (None, None) all the time - interval is used. - mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio' - Perform baseline correction by + %(exclude_spectrum_plot)s + %(tmin_tmax_psd)s + %(fmin_fmax_tfr)s + %(baseline_rescale)s - - subtracting the mean of baseline values ('mean') (default) - - dividing by the mean of baseline values ('ratio') - - dividing by the mean of baseline values and taking the log - ('logratio') - - subtracting the mean of baseline values followed by dividing by - the mean of baseline values ('percent') - - subtracting the mean of baseline values and dividing by the - standard deviation of baseline values ('zscore') - - dividing by the mean of baseline values, taking the log, and - dividing by the standard deviation of log baseline values - ('zlogratio') + How baseline is computed is determined by the ``mode`` parameter. + %(mode_tfr_plot)s + %(dB_spectrum_plot)s + %(combine_tfr_plot)s - tmin : None | float - The first time instant to display. If None the first time point - available is used. Defaults to None. - tmax : None | float - The last time instant to display. If None the last time point - available is used. Defaults to None. - fmin : None | float - The first frequency to display. If None the first frequency - available is used. Defaults to None. - fmax : None | float - The last frequency to display. If None the last frequency - available is used. Defaults to None. - vmin : float | None - The minimum value an the color scale. If vmin is None, the data - minimum value is used. Defaults to None. - vmax : float | None - The maximum value an the color scale. If vmax is None, the data - maximum value is used. Defaults to None. - cmap : matplotlib colormap | 'interactive' | (colormap, bool) - The colormap to use. If tuple, the first value indicates the - colormap to use and the second value is a boolean defining - interactivity. In interactive mode the colors are adjustable by - clicking and dragging the colorbar with left and right mouse - button. Left mouse button moves the scale up and down and right - mouse button adjusts the range. Hitting space bar resets the range. - Up and down arrows can be used to change the colormap. If - 'interactive', translates to ('RdBu_r', True). Defaults to - 'RdBu_r'. - - .. warning:: Interactive mode works smoothly only for a small - amount of images. - - dB : bool - If True, 10*log10 is applied to the data to get dB. - Defaults to False. - colorbar : bool - If true, colorbar will be added to the plot. Defaults to True. - show : bool - Call pyplot.show() at the end. Defaults to True. 
- title : str | 'auto' | None - String for ``title``. Defaults to None (blank/no title). If - 'auto', and ``combine`` is None, the title for each figure - will be the channel name. If 'auto' and ``combine`` is not None, - ``title`` states how many channels were combined into that figure - and the method that was used for ``combine``. If str, that String - will be the title for each figure. - axes : instance of Axes | list | None - The axes to plot to. If list, the list must be a list of Axes of - the same length as ``picks``. If instance of Axes, there must be - only one channel plotted. If ``combine`` is not None, ``axes`` - must either be an instance of Axes, or a list of length 1. - layout : Layout | None - Layout instance specifying sensor positions. Used for interactive - plotting of topographies on rectangle selection. If possible, the - correct layout is inferred from the data. - yscale : 'auto' (default) | 'linear' | 'log' - The scale of y (frequency) axis. 'linear' gives linear y axis, - 'log' leads to log-spaced y axis and 'auto' detects if frequencies - are log-spaced and only then sets the y axis to 'log'. + .. versionchanged:: 1.3 + Added support for ``callable``. + %(layout_spectrum_plot_topo)s + %(yscale_tfr_plot)s .. versionadded:: 0.14.0 - mask : ndarray | None - An array of booleans of the same shape as the data. Entries of the - data that correspond to False in the mask are plotted - transparently. Useful for, e.g., masking for statistical - significance. + %(vmin_vmax_tfr_plot)s + %(vlim_tfr_plot)s + %(cnorm)s + + .. versionadded:: 0.24 + %(cmap_topomap)s + %(colorbar)s + %(title_tfr_plot)s + %(mask_tfr_plot)s .. versionadded:: 0.16.0 - mask_style : None | 'both' | 'contour' | 'mask' - If ``mask`` is not None: if ``'contour'``, a contour line is drawn - around the masked areas (``True`` in ``mask``). If ``'mask'``, - entries not ``True`` in ``mask`` are shown transparently. If - ``'both'``, both a contour and transparency are used. - If ``None``, defaults to ``'both'`` if ``mask`` is not None, and is - ignored otherwise. + %(mask_style_tfr_plot)s .. versionadded:: 0.17 - mask_cmap : matplotlib colormap | (colormap, bool) | 'interactive' - The colormap chosen for masked parts of the image (see below), if - ``mask`` is not ``None``. If None, ``cmap`` is reused. Defaults to - ``'Greys'``. Not interactive. Otherwise, as ``cmap``. + %(mask_cmap_tfr_plot)s .. versionadded:: 0.17 - mask_alpha : float - A float between 0 and 1. If ``mask`` is not None, this sets the - alpha level (degree of transparency) for the masked-out segments. - I.e., if 0, masked-out segments are not visible at all. - Defaults to 0.1. + %(mask_alpha_tfr_plot)s .. versionadded:: 0.16.0 - combine : 'mean' | 'rms' | callable | None - Type of aggregation to perform across selected channels. If - None, plot one figure per selected channel. If a function, it must - operate on an array of shape ``(n_channels, n_freqs, n_times)`` and - return an array of shape ``(n_freqs, n_times)``. - - .. versionchanged:: 1.3 - Added support for ``callable``. - exclude : list of str | 'bads' - Channels names to exclude from being shown. If 'bads', the - bad channels are excluded. Defaults to an empty list. - %(cnorm)s - - .. versionadded:: 0.24 + %(axes_tfr_plot)s + %(show)s %(verbose)s Returns ------- figs : list of instances of matplotlib.figure.Figure A list of figures containing the time-frequency power. 
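+
+        Notes
+        -----
+        A minimal usage sketch (assuming ``epochs`` is an existing
+        :class:`~mne.Epochs` object containing magnetometer channels)::
+
+            import numpy as np
+
+            freqs = np.arange(6.0, 31.0, 2.0)
+            power = epochs.compute_tfr("morlet", freqs=freqs, average=True)
+            power.plot(
+                picks="mag", combine="mean", baseline=(None, 0), mode="logratio"
+            )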
- """ # noqa: E501 - return self._plot( - picks=picks, - baseline=baseline, - mode=mode, + """ + # deprecations + vlim = _warn_deprecated_vmin_vmax(vlim, vmin, vmax) + # the rectangle selector plots topomaps, which needs all channels uncombined, + # so we keep a reference to that state here, and (because the topomap plotting + # function wants an AverageTFR) update it with `comment` and `nave` values in + # case we started out with a singleton EpochsTFR or RawTFR + initial_state = self.__getstate__() + initial_state.setdefault("comment", "") + initial_state.setdefault("nave", 1) + # `_picks_to_idx` also gets done inside `get_data()`` below, but we do it here + # because we need the indices later + idx_picks = _picks_to_idx( + self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False + ) + pick_names = np.array(self.ch_names)[idx_picks].tolist() # for titles + ch_types = self.get_channel_types(idx_picks) + # get data arrays + data, times, freqs = self.get_data( + picks=idx_picks, exclude=(), return_times=True, return_freqs=True + ) + # pass tmin/tmax here ↓↓↓, not here ↑↑↑; we want to crop *after* baselining + data, times, freqs = _prep_data_for_plot( + data, + times, + freqs, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, - vmin=vmin, - vmax=vmax, - cmap=cmap, + baseline=baseline, + mode=mode, dB=dB, - colorbar=colorbar, - show=show, - title=title, - axes=axes, - layout=layout, - yscale=yscale, - mask=mask, - mask_style=mask_style, - mask_cmap=mask_cmap, - mask_alpha=mask_alpha, - combine=combine, - exclude=exclude, - cnorm=cnorm, verbose=verbose, ) - - @verbose - def _plot( - self, - picks=None, - baseline=None, - mode="mean", - tmin=None, - tmax=None, - fmin=None, - fmax=None, - vmin=None, - vmax=None, - cmap="RdBu_r", - dB=False, - colorbar=True, - show=True, - title=None, - axes=None, - layout=None, - yscale="auto", - mask=None, - mask_style=None, - mask_cmap="Greys", - mask_alpha=0.25, - combine=None, - exclude=None, - copy=True, - source_plot_joint=False, - topomap_args=None, - ch_type=None, - cnorm=None, - verbose=None, - ): - """Plot TFRs as a two-dimensional image(s). - - See self.plot() for parameters description. 
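+        # `data` is now baselined / dB-scaled (if requested) and cropped to the
+        # requested time and frequency ranges; next, figure out the expected shape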
- """ - _validate_type(topomap_args, (dict, None), "topomap_args") - topomap_args = {} if topomap_args is None else topomap_args - import matplotlib.pyplot as plt - - # channel selection - # simply create a new tfr object(s) with the desired channel selection - tfr = _preproc_tfr_instance( - self, - picks, - tmin, - tmax, - fmin, - fmax, - vmin, - vmax, - dB, - mode, - baseline, - exclude, - copy, + # shape + ch_axis = self._dims.index("channel") + freq_axis = self._dims.index("freq") + time_axis = self._dims.index("time") + want_shape = list(self.shape) + want_shape[ch_axis] = len(idx_picks) if combine is None else 1 + want_shape[freq_axis] = len(freqs) # in case there was fmin/fmax cropping + want_shape[time_axis] = len(times) # in case there was tmin/tmax cropping + want_shape = tuple(want_shape) + # combine + combine_was_none = combine is None + combine = _make_combine_callable( + combine, axis=ch_axis, valid=("mean", "rms"), keepdims=True ) - del picks - - data = tfr.data - n_picks = len(tfr.ch_names) if combine is None else 1 - - # combine picks - _validate_type(combine, (None, str, "callable")) - if isinstance(combine, str): - _check_option("combine", combine, ("mean", "rms")) - if combine == "mean": - data = data.mean(axis=0, keepdims=True) - elif combine == "rms": - data = np.sqrt((data**2).mean(axis=0, keepdims=True)) - elif combine is not None: # callable - # It must operate on (n_channels, n_freqs, n_times) and return - # (n_freqs, n_times). Operates on a copy in-case 'combine' does - # some in-place operations. - try: - data = combine(data.copy()) - except TypeError: - raise RuntimeError( - "A callable 'combine' must operate on a single argument, " - "a numpy array of shape (n_channels, n_freqs, n_times)." - ) - if not isinstance(data, np.ndarray) or data.shape != tfr.data.shape[1:]: - raise RuntimeError( - "A callable 'combine' must return a numpy array of shape " - "(n_freqs, n_times)." - ) - # keep initial dimensions + try: + data = combine(data) # no need to copy; get_data() never returns a view + except Exception as e: + msg = ( + "Something went wrong with the callable passed to 'combine'; see " + "traceback." + ) + raise ValueError(msg) from e + # call succeeded, check type and shape + mismatch = False + if not isinstance(data, np.ndarray): + mismatch = "type" + extra = "" + elif data.shape not in (want_shape, want_shape[1:]): + mismatch = "shape" + extra = f" of shape {data.shape}" + if mismatch: + raise RuntimeError( + f"Wrong {mismatch} yielded by callable passed to 'combine'. Make sure " + "your function takes a single argument (an array of shape " + "(n_channels, n_freqs, n_times)) and returns an array of shape " + f"(n_freqs, n_times); yours yielded: {type(data)}{extra}." + ) + # restore singleton collapsed axis (removed by user-provided callable): + # (n_freqs, n_times) → (1, n_freqs, n_times) + if data.shape == (len(freqs), len(times)): data = data[np.newaxis] - # figure overhead - # set plot dimension - tmin, tmax = tfr.times[[0, -1]] - if vmax is None: - vmax = np.abs(data).max() - if vmin is None: - vmin = -np.abs(data).max() - - # set colorbar - cmap = _setup_cmap(cmap) - - # make sure there are as many axes as there will be channels to plot - if isinstance(axes, list) or isinstance(axes, np.ndarray): - figs_and_axes = [(ax.get_figure(), ax) for ax in axes] + assert data.shape == want_shape + # cmap handling. power may be negative depending on baseline strategy so set + # `norm` empirically — but only if user didn't set limits explicitly. 
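+        # (with no explicit limits, `norm` stays False and `_setup_vmin_vmax` falls
+        # back to color limits symmetric about zero, suiting a diverging colormap)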
+ norm = False if vlim == (None, None) else data.min() >= 0.0 + vmin, vmax = _setup_vmin_vmax(data, *vlim, norm=norm) + cmap = _setup_cmap(cmap, norm=norm) + # prepare figure(s) + if axes is None: + figs = [plt.figure(layout="constrained") for _ in range(data.shape[0])] + axes = [fig.add_subplot() for fig in figs] elif isinstance(axes, plt.Axes): - figs_and_axes = [(ax.get_figure(), ax) for ax in [axes]] - elif axes is None: - figs = [plt.figure(layout="constrained") for i in range(n_picks)] - figs_and_axes = [(fig, fig.add_subplot(111)) for fig in figs] + figs = [axes.get_figure()] + axes = [axes] + elif isinstance(axes, np.ndarray): # allow plotting into a grid of axes + figs = [ax.get_figure() for ax in axes.flat] + elif hasattr(axes, "__iter__") and len(axes): + figs = [ax.get_figure() for ax in axes] else: - raise ValueError("axes must be None, plt.Axes, or list " "of plt.Axes.") - if len(figs_and_axes) != n_picks: - raise RuntimeError("There must be an axes for each picked " "channel.") - - for idx in range(n_picks): - fig = figs_and_axes[idx][0] - ax = figs_and_axes[idx][1] - onselect_callback = partial( - tfr._onselect, + raise ValueError( + f"axes must be None, Axes, or list/array of Axes, got {type(axes)}" + ) + if len(axes) != data.shape[0]: + raise RuntimeError( + f"Mismatch between picked channels ({data.shape[0]}) and axes " + f"({len(axes)}); there must be one axes for each picked channel." + ) + # check if we're being called from within plot_joint(). If so, get the + # `topomap_args` from the calling context and pass it to the onselect handler. + # (we need 2 `f_back` here because of the verbose decorator) + calling_frame = inspect.currentframe().f_back.f_back + source_plot_joint = calling_frame.f_code.co_name == "plot_joint" + topomap_args = ( + dict() + if not source_plot_joint + else calling_frame.f_locals.get("topomap_args", dict()) + ) + # plot + for ix, _fig in enumerate(figs): + # restrict the onselect instance to the channel type of the picks used in + # the image plot + uniq_types = np.unique(ch_types) + ch_type = None if len(uniq_types) > 1 else uniq_types.item() + this_tfr = AverageTFR(inst=initial_state).pick(ch_type, verbose=verbose) + _onselect_callback = partial( + this_tfr._onselect, + picks=None, # already restricted the picks in `this_tfr` + exclude=(), + baseline=baseline, + mode=mode, cmap=cmap, source_plot_joint=source_plot_joint, - topomap_args={ - k: v - for k, v in topomap_args.items() - if k not in {"vmin", "vmax", "cmap", "axes"} - }, + topomap_args=topomap_args, ) + # draw the image plot _imshow_tfr( - ax, - 0, - tmin, - tmax, - vmin, - vmax, - onselect_callback, + ax=axes[ix], + tfr=data[[ix]], + ch_idx=0, + tmin=times[0], + tmax=times[-1], + vmin=vmin, + vmax=vmax, + onselect=_onselect_callback, ylim=None, - tfr=data[idx : idx + 1], - freq=tfr.freqs, + freq=freqs, x_label="Time (s)", y_label="Frequency (Hz)", colorbar=colorbar, @@ -1788,123 +2124,83 @@ def _plot( mask_alpha=mask_alpha, cnorm=cnorm, ) - + # handle title. 
automatic title is: + # f"{Baselined} {power} ({ch_name})" or + # f"{Baselined} {power} ({combination} of {N} {ch_type}s)" if title == "auto": - if len(tfr.info["ch_names"]) == 1 or combine is None: - subtitle = tfr.info["ch_names"][idx] - else: - subtitle = _set_title_multiple_electrodes( - None, combine, tfr.info["ch_names"], all_=True, ch_type=ch_type + if combine_was_none: # one plot per channel + which_chs = pick_names[ix] + elif len(pick_names) == 1: # there was only one pick anyway + which_chs = pick_names[0] + else: # one plot for all chs combined + which_chs = _set_title_multiple_electrodes( + None, combine, pick_names, all_=True, ch_type=ch_type ) + _prefix = "Power" if baseline is None else "Baselined power" + _title = f"{_prefix} ({which_chs})" else: - subtitle = title - fig.suptitle(subtitle) - + _title = title + _fig.suptitle(_title) plt_show(show) - return [fig for (fig, ax) in figs_and_axes] + return figs @verbose def plot_joint( self, + *, timefreqs=None, picks=None, - baseline=None, - mode="mean", + exclude=(), + combine="mean", tmin=None, tmax=None, fmin=None, fmax=None, + baseline=None, + mode="mean", + dB=False, + yscale="auto", vmin=None, vmax=None, - cmap="RdBu_r", - dB=False, + vlim=(None, None), + cnorm=None, + cmap=None, colorbar=True, + title=None, # TODO consider deprecating this one, or adding an "auto" option show=True, - title=None, - yscale="auto", - combine="mean", - exclude=(), topomap_args=None, image_args=None, verbose=None, ): - """Plot TFRs as a two-dimensional image with topomaps. + """Plot TFRs as a two-dimensional image with topomap highlights. Parameters ---------- - timefreqs : None | list of tuple | dict of tuple - The time-frequency point(s) for which topomaps will be plotted. - See Notes. + %(timefreqs)s %(picks_good_data)s - baseline : None (default) or tuple of length 2 - The time interval to apply baseline correction. - If None do not apply it. If baseline is (a, b) - the interval is between "a (s)" and "b (s)". - If a is None, the beginning of the data is used. - If b is None, then b is set to the end of the interval. - If baseline is equal to (None, None), the entire time - interval is used. - mode : None | str - If str, must be one of 'ratio', 'zscore', 'mean', 'percent', - 'logratio' and 'zlogratio'. - Do baseline correction with ratio (power is divided by mean - power during baseline) or zscore (power is divided by standard - deviation of power during baseline after subtracting the mean, - power = [power - mean(power_baseline)] / std(power_baseline)), - mean simply subtracts the mean power, percent is the same as - applying ratio then mean, logratio is the same as mean but then - rendered in log-scale, zlogratio is the same as zscore but data - is rendered in log-scale first. - If None no baseline correction is applied. - %(tmin_tmax_psd)s - %(fmin_fmax_psd)s - vmin : float | None - The minimum value of the color scale for the image (for - topomaps, see ``topomap_args``). If vmin is None, the data - absolute minimum value is used. - vmax : float | None - The maximum value of the color scale for the image (for - topomaps, see ``topomap_args``). If vmax is None, the data - absolute maximum value is used. - cmap : matplotlib colormap - The colormap to use. - dB : bool - If True, 10*log10 is applied to the data to get dB. - colorbar : bool - If true, colorbar will be added to the plot (relating to the - topomaps). For user defined axes, the colorbar cannot be drawn. - Defaults to True. - show : bool - Call pyplot.show() at the end. 
- title : str | None - String for title. Defaults to None (blank/no title). - yscale : 'auto' (default) | 'linear' | 'log' - The scale of y (frequency) axis. 'linear' gives linear y axis, - 'log' leads to log-spaced y axis and 'auto' detects if frequencies - are log-spaced and only then sets the y axis to 'log'. - combine : 'mean' | 'rms' | callable - Type of aggregation to perform across selected channels. If a - function, it must operate on an array of shape - ``(n_channels, n_freqs, n_times)`` and return an array of shape - ``(n_freqs, n_times)``. + %(exclude_psd)s + Default is an empty :class:`tuple` which includes all channels. + %(combine_tfr_plot_joint)s .. versionchanged:: 1.3 - Added support for ``callable``. - exclude : list of str | 'bads' - Channels names to exclude from being shown. If 'bads', the - bad channels are excluded. Defaults to an empty list, i.e., ``[]``. - topomap_args : None | dict - A dict of ``kwargs`` that are forwarded to - :func:`mne.viz.plot_topomap` to style the topomaps. ``axes`` and - ``show`` are ignored. If ``times`` is not in this dict, automatic - peak detection is used. Beyond that, if ``None``, no customizable - arguments will be passed. - Defaults to ``None``. - image_args : None | dict - A dict of ``kwargs`` that are forwarded to :meth:`AverageTFR.plot` - to style the image. ``axes`` and ``show`` are ignored. Beyond that, - if ``None``, no customizable arguments will be passed. - Defaults to ``None``. + Added support for ``callable``. + %(tmin_tmax_psd)s + %(fmin_fmax_tfr)s + %(baseline_rescale)s + + How baseline is computed is determined by the ``mode`` parameter. + %(mode_tfr_plot)s + %(dB_tfr_plot_topo)s + %(yscale_tfr_plot)s + %(vmin_vmax_tfr_plot)s + %(vlim_tfr_plot_joint)s + %(cnorm)s + %(cmap_tfr_plot_topo)s + %(colorbar_tfr_plot_joint)s + %(title_none)s + %(show)s + %(topomap_args)s + %(image_args)s %(verbose)s Returns @@ -1914,68 +2210,37 @@ def plot_joint( Notes ----- - ``timefreqs`` has three different modes: tuples, dicts, and auto. - For (list of) tuple(s) mode, each tuple defines a pair - (time, frequency) in s and Hz on the TFR plot. For example, to - look at 10 Hz activity 1 second into the epoch and 3 Hz activity - 300 msec into the epoch, :: - - timefreqs=((1, 10), (.3, 3)) - - If provided as a dictionary, (time, frequency) tuples are keys and - (time_window, frequency_window) tuples are the values - indicating the - width of the windows (centered on the time and frequency indicated by - the key) to be averaged over. For example, :: - - timefreqs={(1, 10): (0.1, 2)} - - would translate into a window that spans 0.95 to 1.05 seconds, as - well as 9 to 11 Hz. If None, a single topomap will be plotted at the - absolute peak across the time-frequency representation. + %(notes_timefreqs_tfr_plot_joint)s .. versionadded:: 0.16.0 - """ # noqa: E501 + """ + from matplotlib import ticker from matplotlib.patches import ConnectionPatch - ##################################### - # Handle channels (picks and types) # - ##################################### - - # it would be nicer to let this happen in self._plot, - # but we need it here to do the loop over the remaining channel - # types in case a user supplies `picks` that pre-select only one - # channel type. - # Nonetheless, it should be refactored for code reuse. 
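+        # overview: recurse per channel type if needed, validate `timefreqs`, compute
+        # the topomap data (one per requested time-frequency point), draw the image
+        # plot via `self.plot()`, then add the topomaps and their connection lines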
- copy = any(var is not None for var in (exclude, picks, baseline)) - tfr = self - if copy: - tfr = tfr.copy() - picks = "data" if picks is None else picks - tfr.pick(picks, exclude=() if exclude is None else exclude) - del picks - ch_types = tfr.info.get_channel_types(unique=True) - - # if multiple sensor types: one plot per channel type, recursive call - if len(ch_types) > 1: - logger.info( - "Multiple channel types selected, returning one " "figure per type." - ) + # deprecations + vlim = _warn_deprecated_vmin_vmax(vlim, vmin, vmax) + # handle recursion + picks = _picks_to_idx( + self.info, picks, "data_or_ica", exclude=exclude, with_ref_meg=False + ) + all_ch_types = np.array(self.get_channel_types()) + uniq_ch_types = sorted(set(all_ch_types[picks])) + if len(uniq_ch_types) > 1: + msg = "Multiple channel types selected, returning one figure per type." + logger.info(msg) figs = list() - for this_type in ch_types: # pick corresponding channel type - type_picks = [ - idx - for idx in range(tfr.info["nchan"]) - if channel_type(tfr.info, idx) == this_type - ] - tf_ = tfr.copy().pick(type_picks) - if len(tf_.info.get_channel_types(unique=True)) > 1: - raise RuntimeError( - "Possibly infinite loop due to channel selection " - "problem. This should never happen! Please check " - "your channel types." - ) + for this_type in uniq_ch_types: + this_picks = np.intersect1d( + picks, + np.nonzero(np.isin(all_ch_types, this_type))[0], + assume_unique=True, + ) + # TODO might be nice to not "copy first, then pick"; alternative might + # be to subset the data with `this_picks` and then construct the "copy" + # using __getstate__ and __setstate__ + _tfr = self.copy().pick(this_picks) figs.append( - tf_.plot_joint( + _tfr.plot_joint( timefreqs=timefreqs, picks=None, baseline=baseline, @@ -1984,8 +2249,7 @@ def plot_joint( tmax=tmax, fmin=fmin, fmax=fmax, - vmin=vmin, - vmax=vmax, + vlim=vlim, cmap=cmap, dB=dB, colorbar=colorbar, @@ -1993,205 +2257,181 @@ def plot_joint( title=title, yscale=yscale, combine=combine, - exclude=None, + exclude=(), topomap_args=topomap_args, verbose=verbose, ) ) return figs else: - ch_type = ch_types.pop() - - # Handle timefreqs - timefreqs = _get_timefreqs(tfr, timefreqs) - n_timefreqs = len(timefreqs) - - if topomap_args is None: - topomap_args = dict() - topomap_args_pass = { - k: v - for k, v in topomap_args.items() - if k not in ("axes", "show", "colorbar") - } - topomap_args_pass["outlines"] = topomap_args.get("outlines", "head") - topomap_args_pass["contours"] = topomap_args.get("contours", 6) - topomap_args_pass["ch_type"] = ch_type - - ############## - # Image plot # - ############## - - fig, tf_ax, map_ax = _prepare_joint_axes(n_timefreqs) - - cmap = _setup_cmap(cmap) - - # image plot - # we also use this to baseline and truncate (times and freqs) - # (a copy of) the instance - if image_args is None: - image_args = dict() - fig = tfr._plot( - picks=None, - baseline=baseline, - mode=mode, + ch_type = uniq_ch_types[0] + + # handle defaults + _validate_type(combine, ("str", "callable"), item_name="combine") # no `None` + image_args = dict() if image_args is None else image_args + topomap_args = dict() if topomap_args is None else topomap_args.copy() + # make sure if topomap_args["ch_type"] is set, it matches what is in `self.info` + topomap_args.setdefault("ch_type", ch_type) + if topomap_args["ch_type"] != ch_type: + raise ValueError( + f"topomap_args['ch_type'] is {topomap_args['ch_type']} which does not " + f"match the channel type present in the object 
({ch_type})." + ) + # some necessary defaults + topomap_args.setdefault("outlines", "head") + topomap_args.setdefault("contours", 6) + # don't pass these: + topomap_args.pop("axes", None) + topomap_args.pop("show", None) + topomap_args.pop("colorbar", None) + + # get the time/freq limits of the image plot, to make sure requested annotation + # times/freqs are in range + _, times, freqs = self.get_data( + picks=picks, + exclude=(), tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, - vmin=vmin, - vmax=vmax, - cmap=cmap, + return_times=True, + return_freqs=True, + ) + # validate requested annotation times and freqs + timefreqs = _get_timefreqs(self, timefreqs) + valid_timefreqs = dict() + while timefreqs: + (_time, _freq), (t_win, f_win) = timefreqs.popitem() + # convert to half-windows + t_win /= 2 + f_win /= 2 + # make sure the times / freqs are in-bounds + msg = ( + "Requested {} exceeds the range of the data ({}). Choose different " + "`timefreqs`." + ) + if (times > _time).all() or (times < _time).all(): + _var = f"time point ({_time:0.3f} s)" + _range = f"{times[0]:0.3f} - {times[-1]:0.3f} s" + raise ValueError(msg.format(_var, _range)) + elif (freqs > _freq).all() or (freqs < _freq).all(): + _var = f"frequency ({_freq:0.1f} Hz)" + _range = f"{freqs[0]:0.1f} - {freqs[-1]:0.1f} Hz" + raise ValueError(msg.format(_var, _range)) + # snap the times/freqs to the nearest point we have an estimate for, and + # store the validated points + if t_win == 0: + _time = times[np.argmin(np.abs(times - _time))] + if f_win == 0: + _freq = freqs[np.argmin(np.abs(freqs - _freq))] + valid_timefreqs[(_time, _freq)] = (t_win, f_win) + + # prep data for topomaps (unlike image plot, must include all channels of the + # current ch_type). Don't pass tmin/tmax here (crop later after baselining) + topomap_picks = _picks_to_idx(self.info, ch_type) + data, times, freqs = self.get_data( + picks=topomap_picks, exclude=(), return_times=True, return_freqs=True + ) + # merge grads before baselining (makes ERDS visible) + info = pick_info(self.info, sel=topomap_picks, copy=True) + data, pos = _merge_if_grads( + data=data, + info=info, + ch_type=ch_type, + sphere=topomap_args.get("sphere"), + combine=combine, + ) + # loop over intended topomap locations, to find one vlim that works for all. + tf_array = np.array(list(valid_timefreqs)) # each row is [time, freq] + tf_array = tf_array[tf_array[:, 0].argsort()] # sort by time + _vmin, _vmax = (np.inf, -np.inf) + topomap_arrays = list() + topomap_titles = list() + for _time, _freq in tf_array: + # reduce data to the range of interest in the TF plane (i.e., finally crop) + t_win, f_win = valid_timefreqs[(_time, _freq)] + _tmin, _tmax = np.array([-1, 1]) * t_win + _time + _fmin, _fmax = np.array([-1, 1]) * f_win + _freq + _data, *_ = _prep_data_for_plot( + data, + times, + freqs, + tmin=_tmin, + tmax=_tmax, + fmin=_fmin, + fmax=_fmax, + baseline=baseline, + mode=mode, + verbose=verbose, + ) + _data = _data.mean(axis=(-1, -2)) # avg over times and freqs + topomap_arrays.append(_data) + _vmin = min(_data.min(), _vmin) + _vmax = max(_data.max(), _vmax) + # construct topopmap subplot title + t_pm = "" if t_win == 0 else f" ± {t_win:0.2f}" + f_pm = "" if f_win == 0 else f" ± {f_win:0.1f}" + _title = f"{_time:0.2f}{t_pm} s,\n{_freq:0.1f}{f_pm} Hz" + topomap_titles.append(_title) + # handle cmap. Power may be negative depending on baseline strategy so set + # `norm` empirically. vmin/vmax will be handled separately within the `plot()` + # call for the image plot. 
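+        # (a single vlim shared by all topomaps keeps their colors directly comparable)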
+ norm = np.min(topomap_arrays) >= 0.0 + cmap = _setup_cmap(cmap, norm=norm) + topomap_args.setdefault("cmap", cmap[0]) # prevent interactive cbar + # finalize topomap vlims and compute contour locations. + # By passing `data=None` here ↓↓↓↓ we effectively assert vmin & vmax aren't None + _vlim = _setup_vmin_vmax(data=None, vmin=_vmin, vmax=_vmax, norm=norm) + topomap_args.setdefault("vlim", _vlim) + locator, topomap_args["contours"] = _set_contour_locator( + *topomap_args["vlim"], topomap_args["contours"] + ) + # initialize figure and do the image plot. `self.plot()` needed to wait to be + # called until after `topomap_args` was fully populated --- we don't pass the + # dict through to `self.plot()` explicitly here, but we do "reach back" and get + # it if it's needed by the interactive rectangle selector. + fig, image_ax, topomap_axes = _prepare_joint_axes(len(valid_timefreqs)) + fig = self.plot( + picks=picks, + exclude=(), + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + baseline=baseline, + mode=mode, dB=dB, + combine=combine, + yscale=yscale, + vlim=vlim, + cnorm=cnorm, + cmap=cmap, colorbar=False, - show=False, title=title, - axes=tf_ax, - yscale=yscale, - combine=combine, - exclude=None, - copy=False, - source_plot_joint=True, - topomap_args=topomap_args_pass, - ch_type=ch_type, + # mask, mask_style, mask_cmap, mask_alpha + axes=image_ax, + show=False, + verbose=verbose, **image_args, - )[0] - - # set and check time and freq limits ... - # can only do this after the tfr plot because it may change these - # parameters - tmax, tmin = tfr.times.max(), tfr.times.min() - fmax, fmin = tfr.freqs.max(), tfr.freqs.min() - for time, freq in timefreqs.keys(): - if not (tmin <= time <= tmax): - error_value = "time point (" + str(time) + " s)" - elif not (fmin <= freq <= fmax): - error_value = "frequency (" + str(freq) + " Hz)" - else: - continue - raise ValueError( - "Requested " + error_value + " exceeds the range" - "of the data. Choose different `timefreqs`." - ) - - ############ - # Topomaps # - ############ - - titles, all_data, all_pos, vlims = [], [], [], [] - - # the structure here is a bit complicated to allow aggregating vlims - # over all topomaps. First, one loop over all timefreqs to collect - # vlims. Then, find the max vlims and in a second loop over timefreqs, - # do the actual plotting. 
- timefreqs_array = np.array([np.array(keys) for keys in timefreqs]) - order = timefreqs_array[:, 0].argsort() # sort by time - - for ii, (time, freq) in enumerate(timefreqs_array[order]): - avg = timefreqs[(time, freq)] - # set up symmetric windows - time_half_range, freq_half_range = avg / 2.0 - - if time_half_range == 0: - time = tfr.times[np.argmin(np.abs(tfr.times - time))] - if freq_half_range == 0: - freq = tfr.freqs[np.argmin(np.abs(tfr.freqs - freq))] - - if (time_half_range == 0) and (freq_half_range == 0): - sub_map_title = f"({time:.2f} s,\n{freq:.1f} Hz)" - else: - sub_map_title = ( - f"({time:.1f} \u00b1 {time_half_range:.1f} " - f"s,\n{freq:.1f} \u00b1 {freq_half_range:.1f} Hz)" - ) - - tmin = time - time_half_range - tmax = time + time_half_range - fmin = freq - freq_half_range - fmax = freq + freq_half_range - - data = tfr.data - - # merging grads here before rescaling makes ERDs visible - - sphere = topomap_args.get("sphere") - if ch_type == "grad": - picks = _pair_grad_sensors(tfr.info, topomap_coords=False) - pos = _find_topomap_coords(tfr.info, picks=picks[::2], sphere=sphere) - method = combine if isinstance(combine, str) else "rms" - data, _ = _merge_ch_data(data[picks], ch_type, [], method=method) - del picks, method - else: - pos, _ = _get_pos_outlines(tfr.info, None, sphere) - del sphere - - all_pos.append(pos) - - data, times, freqs, _, _ = _preproc_tfr( - data, - tfr.times, - tfr.freqs, - tmin, - tmax, - fmin, - fmax, - mode, - baseline, - vmin, - vmax, - None, - tfr.info["sfreq"], - ) - - vlims.append(np.abs(data).max()) - titles.append(sub_map_title) - all_data.append(data) - new_t = tfr.times[np.abs(tfr.times - np.median([times])).argmin()] - new_f = tfr.freqs[np.abs(tfr.freqs - np.median([freqs])).argmin()] - timefreqs_array[ii] = (new_t, new_f) - - # passing args to the topomap calls - max_lim = max(vlims) - _vlim = list(topomap_args.get("vlim", (None, None))) - # fall back on ± max_lim - for sign, index in zip((-1, 1), (0, 1)): - if _vlim[index] is None: - _vlim[index] = sign * max_lim - topomap_args_pass["vlim"] = tuple(_vlim) - locator, contours = _set_contour_locator(*_vlim, topomap_args_pass["contours"]) - topomap_args_pass["contours"] = contours - - for ax, title, data, pos in zip(map_ax, titles, all_data, all_pos): + )[0] # [0] because `.plot()` always returns a list + # now, actually plot the topomaps + for ax, title, _data in zip(topomap_axes, topomap_titles, topomap_arrays): ax.set_title(title) - plot_topomap( - data.mean(axis=(-1, -2)), - pos, - cmap=cmap[0], - axes=ax, - show=False, - **topomap_args_pass, - ) - - ############# - # Finish up # - ############# + plot_topomap(_data, pos, axes=ax, show=False, **topomap_args) + # draw colorbar if colorbar: - from matplotlib import ticker - cbar = fig.colorbar(ax.images[0]) - if locator is None: - locator = ticker.MaxNLocator(nbins=5) - cbar.locator = locator + cbar.locator = ticker.MaxNLocator(nbins=5) if locator is None else locator cbar.update_ticks() - - # draw the connection lines between time series and topoplots - for (time_, freq_), map_ax_ in zip(timefreqs_array, map_ax): + # draw the connection lines between time-frequency image and topoplots + for (time_, freq_), topo_ax in zip(tf_array, topomap_axes): con = ConnectionPatch( xyA=[time_, freq_], xyB=[0.5, 0], coordsA="data", coordsB="axes fraction", - axesA=tf_ax, - axesB=map_ax_, + axesA=image_ax, + axesB=topo_ax, color="grey", linestyle="-", linewidth=1.5, @@ -2204,108 +2444,6 @@ def plot_joint( plt_show(show) return fig - @verbose - def 
_onselect( - self, - eclick, - erelease, - baseline=None, - mode=None, - cmap=None, - source_plot_joint=False, - topomap_args=None, - verbose=None, - ): - """Handle rubber band selector in channel tfr.""" - if abs(eclick.x - erelease.x) < 0.1 or abs(eclick.y - erelease.y) < 0.1: - return - tmin = round(min(eclick.xdata, erelease.xdata), 5) # s - tmax = round(max(eclick.xdata, erelease.xdata), 5) - fmin = round(min(eclick.ydata, erelease.ydata), 5) # Hz - fmax = round(max(eclick.ydata, erelease.ydata), 5) - tmin = min(self.times, key=lambda x: abs(x - tmin)) # find closest - tmax = min(self.times, key=lambda x: abs(x - tmax)) - fmin = min(self.freqs, key=lambda x: abs(x - fmin)) - fmax = min(self.freqs, key=lambda x: abs(x - fmax)) - if tmin == tmax or fmin == fmax: - logger.info( - "The selected area is too small. " - "Select a larger time-frequency window." - ) - return - - types = list() - if "eeg" in self: - types.append("eeg") - if "mag" in self: - types.append("mag") - if "grad" in self: - if ( - len( - _pair_grad_sensors( - self.info, topomap_coords=False, raise_error=False - ) - ) - >= 2 - ): - types.append("grad") - elif len(types) == 0: - return # Don't draw a figure for nothing. - - fig = figure_nobar() - fig.suptitle( - f"{tmin:.2f} s - {tmax:.2f} s, {fmin:.2f} Hz - {fmax:.2f} Hz", - y=0.04, - ) - - if source_plot_joint: - ax = fig.add_subplot(111) - data = _preproc_tfr( - self.data, - self.times, - self.freqs, - tmin, - tmax, - fmin, - fmax, - None, - None, - None, - None, - None, - self.info["sfreq"], - )[0] - data = data.mean(-1).mean(-1) - vmax = np.abs(data).max() - im, _ = plot_topomap( - data, - self.info, - vlim=(-vmax, vmax), - cmap=cmap[0], - axes=ax, - show=False, - **topomap_args, - ) - _add_colorbar(ax, im, cmap, title="AU", pad=0.1) - fig.show() - else: - for idx, ch_type in enumerate(types): - ax = fig.add_subplot(1, len(types), idx + 1) - plot_tfr_topomap( - self, - ch_type=ch_type, - tmin=tmin, - tmax=tmax, - fmin=fmin, - fmax=fmax, - baseline=baseline, - mode=mode, - cmap=None, - vlim=(None, None), - axes=ax, - ) - ax.set_title(ch_type) - @verbose def plot_topo( self, @@ -2316,11 +2454,11 @@ def plot_topo( tmax=None, fmin=None, fmax=None, - vmin=None, + vmin=None, # TODO deprecate in favor of `vlim` (needs helper func refactor) vmax=None, layout=None, cmap="RdBu_r", - title=None, + title=None, # don't deprecate; topo titles aren't standard (color, size, just.) dB=False, colorbar=True, layout_scale=0.945, @@ -2332,88 +2470,38 @@ def plot_topo( yscale="auto", verbose=None, ): - """Plot TFRs in a topography with images. + """Plot a TFR image for each channel in a sensor layout arrangement. Parameters ---------- %(picks_good_data)s - baseline : None (default) or tuple of length 2 - The time interval to apply baseline correction. - If None do not apply it. If baseline is (a, b) - the interval is between "a (s)" and "b (s)". - If a is None the beginning of the data is used - and if b is None then b is set to the end of the interval. - If baseline is equal to (None, None) all the time - interval is used. 
- mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio' - Perform baseline correction by - - - subtracting the mean of baseline values ('mean') - - dividing by the mean of baseline values ('ratio') - - dividing by the mean of baseline values and taking the log - ('logratio') - - subtracting the mean of baseline values followed by dividing by - the mean of baseline values ('percent') - - subtracting the mean of baseline values and dividing by the - standard deviation of baseline values ('zscore') - - dividing by the mean of baseline values, taking the log, and - dividing by the standard deviation of log baseline values - ('zlogratio') + %(baseline_rescale)s - tmin : None | float - The first time instant to display. If None the first time point - available is used. - tmax : None | float - The last time instant to display. If None the last time point - available is used. - fmin : None | float - The first frequency to display. If None the first frequency - available is used. - fmax : None | float - The last frequency to display. If None the last frequency - available is used. - vmin : float | None - The minimum value of the color scale. If vmin is None, the data - minimum value is used. - vmax : float | None - The maximum value of the color scale. If vmax is None, the data - maximum value is used. - layout : Layout | None - Layout instance specifying sensor positions. If possible, the - correct layout is inferred from the data. - cmap : matplotlib colormap | str - The colormap to use. Defaults to 'RdBu_r'. - title : str - Title of the figure. - dB : bool - If True, 10*log10 is applied to the data to get dB. - colorbar : bool - If true, colorbar will be added to the plot. - layout_scale : float - Scaling factor for adjusting the relative size of the layout - on the canvas. - show : bool - Call pyplot.show() at the end. - border : str - Matplotlib borders style to be used for each sensor plot. - fig_facecolor : color - The figure face color. Defaults to black. - fig_background : None | array - A background image for the figure. This must be a valid input to - `matplotlib.pyplot.imshow`. Defaults to None. - font_color : color - The color of tick labels in the colorbar. Defaults to white. - yscale : 'auto' (default) | 'linear' | 'log' - The scale of y (frequency) axis. 'linear' gives linear y axis, - 'log' leads to log-spaced y axis and 'auto' detects if frequencies - are log-spaced and only then sets the y axis to 'log'. + How baseline is computed is determined by the ``mode`` parameter. + %(mode_tfr_plot)s + %(tmin_tmax_psd)s + %(fmin_fmax_tfr)s + %(vmin_vmax_tfr_plot_topo)s + %(layout_spectrum_plot_topo)s + %(cmap_tfr_plot_topo)s + %(title_none)s + %(dB_tfr_plot_topo)s + %(colorbar)s + %(layout_scale)s + %(show)s + %(border_topo)s + %(fig_facecolor)s + %(fig_background)s + %(font_color)s + %(yscale_tfr_plot)s %(verbose)s Returns ------- fig : matplotlib.figure.Figure The figure containing the topography. - """ # noqa: E501 + """ + # convenience vars times = self.times.copy() freqs = self.freqs data = self.data @@ -2422,6 +2510,8 @@ def plot_topo( info, data = _prepare_picks(info, data, picks, axis=0) del picks + # TODO this is the only remaining call to _preproc_tfr; should be refactored + # (to use _prep_data_for_plot?) 
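+        # _preproc_tfr handles baselining, cropping, dB scaling, and vmin/vmax
+        # resolution in one call, hence the long positional argument list below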
data, times, freqs, vmin, vmax = _preproc_tfr( data, times, @@ -2548,160 +2638,1138 @@ def plot_topomap( show=show, ) - def _check_compat(self, tfr): - """Check that self and tfr have the same time-frequency ranges.""" - assert np.all(tfr.times == self.times) - assert np.all(tfr.freqs == self.freqs) - - def __add__(self, tfr): # noqa: D105 - """Add instances.""" - self._check_compat(tfr) - out = self.copy() - out.data += tfr.data - return out - - def __iadd__(self, tfr): # noqa: D105 - self._check_compat(tfr) - self.data += tfr.data - return self + @verbose + def save(self, fname, *, overwrite=False, verbose=None): + """Save time-frequency data to disk (in HDF5 format). - def __sub__(self, tfr): # noqa: D105 - """Subtract instances.""" - self._check_compat(tfr) - out = self.copy() - out.data -= tfr.data - return out + Parameters + ---------- + fname : path-like + Path of file to save to. + %(overwrite)s + %(verbose)s - def __isub__(self, tfr): # noqa: D105 - self._check_compat(tfr) - self.data -= tfr.data - return self + See Also + -------- + mne.time_frequency.read_spectrum + """ + _, write_hdf5 = _import_h5io_funcs() + check_fname(fname, "time-frequency object", (".h5", ".hdf5")) + fname = _check_fname(fname, overwrite=overwrite, verbose=verbose) + out = self.__getstate__() + if "metadata" in out: + out["metadata"] = _prepare_write_metadata(out["metadata"]) + write_hdf5(fname, out, overwrite=overwrite, title="mnepython", slash="replace") - def __truediv__(self, a): # noqa: D105 - """Divide instances.""" - out = self.copy() - out /= a - return out + @verbose + def to_data_frame( + self, + picks=None, + index=None, + long_format=False, + time_format=None, + *, + verbose=None, + ): + """Export data in tabular structure as a pandas DataFrame. - def __itruediv__(self, a): # noqa: D105 - self.data /= a - return self + Channels are converted to columns in the DataFrame. By default, + additional columns ``'time'``, ``'freq'``, ``'epoch'``, and + ``'condition'`` (epoch event description) are added, unless ``index`` + is not ``None`` (in which case the columns specified in ``index`` will + be used to form the DataFrame's index instead). ``'epoch'``, and + ``'condition'`` are not supported for ``AverageTFR``. - def __mul__(self, a): - """Multiply source instances.""" - out = self.copy() - out *= a - return out + Parameters + ---------- + %(picks_all)s + %(index_df_epo)s + Valid string values are ``'time'``, ``'freq'``, ``'epoch'``, and + ``'condition'`` for ``EpochsTFR`` and ``'time'`` and ``'freq'`` + for ``AverageTFR``. + Defaults to ``None``. + %(long_format_df_epo)s + %(time_format_df)s - def __imul__(self, a): # noqa: D105 - self.data *= a - return self + .. 
versionadded:: 0.23 + %(verbose)s - def __repr__(self): # noqa: D105 - s = f"time : [{self.times[0]}, {self.times[-1]}]" - s += f", freq : [{self.freqs[0]}, {self.freqs[-1]}]" - s += ", nave : %d" % self.nave - s += ", channels : %d" % self.data.shape[0] - s += f", ~{sizeof_fmt(self._size)}" - return "" % s + Returns + ------- + %(df_return)s + """ + # check pandas once here, instead of in each private utils function + pd = _check_pandas_installed() # noqa + # arg checking + valid_index_args = ["time", "freq"] + if isinstance(self, EpochsTFR): + valid_index_args.extend(["epoch", "condition"]) + valid_time_formats = ["ms", "timedelta"] + index = _check_pandas_index_arguments(index, valid_index_args) + time_format = _check_time_format(time_format, valid_time_formats) + # get data + picks = _picks_to_idx(self.info, picks, "all", exclude=()) + data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True) + axis = self._dims.index("channel") + if not isinstance(self, EpochsTFR): + data = data[np.newaxis] # add singleton "epochs" axis + axis += 1 + n_epochs, n_picks, n_freqs, n_times = data.shape + # reshape to (epochs*freqs*times) x signals + data = np.moveaxis(data, axis, -1) + data = data.reshape(n_epochs * n_freqs * n_times, n_picks) + # prepare extra columns / multiindex + mindex = list() + times = _convert_times(times, time_format, self.info["meas_date"]) + times = np.tile(times, n_epochs * n_freqs) + freqs = np.tile(np.repeat(freqs, n_times), n_epochs) + mindex.append(("time", times)) + mindex.append(("freq", freqs)) + if isinstance(self, EpochsTFR): + mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs))) + rev_event_id = {v: k for k, v in self.event_id.items()} + conditions = [rev_event_id[k] for k in self.events[:, 2]] + mindex.append(("condition", np.repeat(conditions, n_times * n_freqs))) + assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:]) + # build DataFrame + if isinstance(self, EpochsTFR): + default_index = ["condition", "epoch", "freq", "time"] + else: + default_index = ["freq", "time"] + df = _build_data_frame( + self, data, picks, long_format, mindex, index, default_index=default_index + ) + return df + + +@fill_doc +class AverageTFR(BaseTFR): + """Data object for spectrotemporal representations of averaged data. + + .. warning:: The preferred means of creating AverageTFR objects is via the + instance methods :meth:`mne.Epochs.compute_tfr` and + :meth:`mne.Evoked.compute_tfr`, or via + :meth:`mne.time_frequency.EpochsTFR.average`. Direct class + instantiation is discouraged. + + Parameters + ---------- + %(info_not_none)s + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead, or + use :class:`~mne.time_frequency.AverageTFRArray` which retains the old API. + data : ndarray, shape (n_channels, n_freqs, n_times) + The data. + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead, or + use :class:`~mne.time_frequency.AverageTFRArray` which retains the old API. + times : ndarray, shape (n_times,) + The time values in seconds. + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead and + (optionally) use ``tmin`` and ``tmax`` to restrict the time domain; or use + :class:`~mne.time_frequency.AverageTFRArray` which retains the old API. + freqs : ndarray, shape (n_freqs,) + The frequencies in Hz. + nave : int + The number of averaged TFRs. + + .. 
deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` or :class:`~mne.Evoked` instead; + ``nave`` will be inferred automatically. Or, use + :class:`~mne.time_frequency.AverageTFRArray` which retains the old API. + inst : instance of Evoked | instance of Epochs | dict + The data from which to compute the time-frequency representation. Passing a + :class:`dict` will create the AverageTFR using the ``__setstate__`` interface + and is not recommended for typical use cases. + %(method_tfr)s + %(freqs_tfr)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(decim_tfr)s + %(comment_averagetfr)s + %(n_jobs)s + %(verbose)s + %(method_kw_tfr)s + + Attributes + ---------- + %(baseline_tfr_attr)s + %(ch_names_tfr_attr)s + %(comment_averagetfr_attr)s + %(freqs_tfr_attr)s + %(info_not_none)s + %(method_tfr_attr)s + %(nave_tfr_attr)s + %(sfreq_tfr_attr)s + %(shape_tfr_attr)s + + See Also + -------- + RawTFR + EpochsTFR + AverageTFRArray + mne.Evoked.compute_tfr + mne.time_frequency.EpochsTFR.average + + Notes + ----- + The old API (prior to version 1.7) was:: + + AverageTFR(info, data, times, freqs, nave, comment=None, method=None) + + That API is still available via :class:`~mne.time_frequency.AverageTFRArray` for + cases where the data are precomputed or do not originate from MNE-Python objects. + The preferred new API uses instance methods:: + + evoked.compute_tfr(method, freqs, ...) + epochs.compute_tfr(method, freqs, average=True, ...) + + The new API also supports AverageTFR instantiation from a :class:`dict`, but this + is primarily for save/load and internal purposes, and wraps ``__setstate__``. + During the transition from the old to the new API, it may be expedient to use + :class:`~mne.time_frequency.AverageTFRArray` as a "quick-fix" approach to updating + scripts under active development. + + References + ---------- + .. footbibliography:: + """ + + def __init__( + self, + info=None, + data=None, + times=None, + freqs=None, + nave=None, + *, + inst=None, + method=None, + tmin=None, + tmax=None, + picks=None, + proj=False, + decim=1, + comment=None, + n_jobs=None, + verbose=None, + **method_kw, + ): + from ..epochs import BaseEpochs + from ..evoked import Evoked + from ._stockwell import _check_input_st, _compute_freqs_st + + # deprecations. TODO remove after 1.7 release + depr_params = dict(info=info, data=data, times=times, nave=nave) + bad_params = list() + for name, param in depr_params.items(): + if param is not None: + bad_params.append(name) + if len(bad_params): + _s = _pl(bad_params) + is_are = _pl(bad_params, "is", "are") + bad_params_list = '", "'.join(bad_params) + warn( + f'Parameter{_s} "{bad_params_list}" {is_are} deprecated and will be ' + "removed in version 1.8. For a quick fix, use ``AverageTFRArray`` with " + "the same parameters. For a long-term fix, see the docstring notes.", + FutureWarning, + ) + if inst is not None: + raise ValueError( + "Do not pass `inst` alongside deprecated params " + f'"{bad_params_list}"; see docstring of AverageTFR for guidance.' 
+ ) + inst = depr_params | dict(freqs=freqs, method=method, comment=comment) + # end TODO ↑↑↑↑↑↑ + + # dict is allowed for __setstate__ compatibility, and Epochs.compute_tfr() can + # return an AverageTFR depending on its parameters, so Epochs input is allowed + _validate_type( + inst, (BaseEpochs, Evoked, dict), "object passed to AverageTFR constructor" + ) + # stockwell API is very different from multitaper/morlet + if method == "stockwell" and not isinstance(inst, dict): + if isinstance(freqs, str) and freqs == "auto": + fmin, fmax = None, None + elif len(freqs) == 2: + fmin, fmax = freqs + else: + raise ValueError( + "for Stockwell method, freqs must be a length-2 iterable " + f'or "auto", got {freqs}.' + ) + method_kw.update(fmin=fmin, fmax=fmax) + # Compute freqs. We need a couple lines of code dupe here (also in + # BaseTFR.__init__) to get the subset of times to pass to _check_input_st() + _mask = _time_mask(inst.times, tmin, tmax, sfreq=inst.info["sfreq"]) + _times = inst.times[_mask].copy() + _, default_nfft, _ = _check_input_st(_times, None) + n_fft = method_kw.get("n_fft", default_nfft) + *_, freqs = _compute_freqs_st(fmin, fmax, n_fft, inst.info["sfreq"]) + + # use Evoked.comment or str(Epochs.event_id) as the default comment... + if comment is None: + comment = getattr(inst, "comment", ",".join(getattr(inst, "event_id", ""))) + # ...but don't overwrite if it's coming in with a comment already set + if isinstance(inst, dict): + inst.setdefault("comment", comment) + else: + self._comment = getattr(self, "_comment", comment) + super().__init__( + inst, + method, + freqs, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + decim=decim, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) + + def __getstate__(self): + """Prepare AverageTFR object for serialization.""" + out = super().__getstate__() + out.update(nave=self.nave, comment=self.comment) + # NOTE: self._itc should never exist in the instance returned to the user; it + # is temporarily present in the output from the tfr_array_* function, and is + # split out into a separate AverageTFR object (and deleted from the object + # holding power estimates) before those objects are passed back to the user. + # The following lines are there because we make use of __getstate__ to achieve + # that splitting of objects. + if hasattr(self, "_itc"): + out.update(itc=self._itc) + return out + + def __setstate__(self, state): + """Unpack AverageTFR from serialized format.""" + super().__setstate__(state) + self._comment = state.get("comment", "") + self._nave = state.get("nave", 1) + + @property + def comment(self): + return self._comment + + @comment.setter + def comment(self, comment): + self._comment = comment + + @property + def nave(self): + return self._nave + + @nave.setter + def nave(self, nave): + self._nave = nave + + def _get_instance_data(self, time_mask): + # AverageTFRs can be constructed from Epochs data, so we triage shape here. + # Evoked data get a fake singleton "epoch" axis prepended + dim = slice(None) if _get_instance_type_string(self) == "Epochs" else np.newaxis + data = self.inst.get_data(picks=self._picks)[dim, :, time_mask] + self._nave = getattr(self.inst, "nave", data.shape[0]) + return data + + +@fill_doc +class AverageTFRArray(AverageTFR): + """Data object for *precomputed* spectrotemporal representations of averaged data. + + Parameters + ---------- + %(info_not_none)s + %(data_tfr)s + %(times)s + %(freqs_tfr_array)s + nave : int + The number of averaged TFRs. 
+ %(comment_averagetfr_attr)s + %(method_tfr_array)s + + Attributes + ---------- + %(baseline_tfr_attr)s + %(ch_names_tfr_attr)s + %(comment_averagetfr_attr)s + %(freqs_tfr_attr)s + %(info_not_none)s + %(method_tfr_attr)s + %(nave_tfr_attr)s + %(sfreq_tfr_attr)s + %(shape_tfr_attr)s + + See Also + -------- + AverageTFR + EpochsTFRArray + mne.Epochs.compute_tfr + mne.Evoked.compute_tfr + """ + + def __init__( + self, info, data, times, freqs, *, nave=None, comment=None, method=None + ): + state = dict(info=info, data=data, times=times, freqs=freqs) + for name, optional in dict(nave=nave, comment=comment, method=method).items(): + if optional is not None: + state[name] = optional + self.__setstate__(state) + + +@fill_doc +class EpochsTFR(BaseTFR, GetEpochsMixin): + """Data object for spectrotemporal representations of epoched data. + + .. important:: + The preferred means of creating EpochsTFR objects from :class:`~mne.Epochs` + objects is via the instance method :meth:`~mne.Epochs.compute_tfr`. + To create an EpochsTFR object from pre-computed data (i.e., a NumPy array) use + :class:`~mne.time_frequency.EpochsTFRArray`. + + Parameters + ---------- + %(info_not_none)s + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use + :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. + data : ndarray, shape (n_channels, n_freqs, n_times) + The data. + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use + :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. + times : ndarray, shape (n_times,) + The time values in seconds. + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` as ``inst`` instead and + (optionally) use ``tmin`` and ``tmax`` to restrict the time domain; or use + :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. + %(freqs_tfr_epochs)s + inst : instance of Epochs + The data from which to compute the time-frequency representation. + %(method_tfr_epochs)s + %(comment_tfr_attr)s + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use + :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(decim_tfr)s + %(events_epochstfr)s + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use + :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. + %(event_id_epochstfr)s + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use + :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. + selection : array + List of indices of selected events (not dropped or ignored etc.). For + example, if the original event array had 4 events and the second event + has been dropped, this attribute would be np.array([0, 2, 3]). + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use + :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. + drop_log : tuple of tuple + A tuple of the same length as the event array used to initialize the + ``EpochsTFR`` object. 
If the i-th original event is still part of the + selection, drop_log[i] will be an empty tuple; otherwise it will be + a tuple of the reasons the event is not longer in the selection, e.g.: + + - ``'IGNORED'`` + If it isn't part of the current subset defined by the user + - ``'NO_DATA'`` or ``'TOO_SHORT'`` + If epoch didn't contain enough data names of channels that + exceeded the amplitude threshold + - ``'EQUALIZED_COUNTS'`` + See :meth:`~mne.Epochs.equalize_event_counts` + - ``'USER'`` + For user-defined reasons (see :meth:`~mne.Epochs.drop`). + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use + :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. + %(metadata_epochstfr)s + + .. deprecated:: 1.7 + Pass an instance of :class:`~mne.Epochs` as ``inst`` instead, or use + :class:`~mne.time_frequency.EpochsTFRArray` which retains the old API. + %(n_jobs)s + %(verbose)s + %(method_kw_tfr)s + + Attributes + ---------- + %(baseline_tfr_attr)s + %(ch_names_tfr_attr)s + %(comment_tfr_attr)s + %(drop_log)s + %(event_id_attr)s + %(events_attr)s + %(freqs_tfr_attr)s + %(info_not_none)s + %(metadata_attr)s + %(method_tfr_attr)s + %(selection_attr)s + %(sfreq_tfr_attr)s + %(shape_tfr_attr)s + + See Also + -------- + mne.Epochs.compute_tfr + RawTFR + AverageTFR + EpochsTFRArray + + References + ---------- + .. footbibliography:: + """ + + def __init__( + self, + info=None, + data=None, + times=None, + freqs=None, + *, + inst=None, + method=None, + comment=None, + tmin=None, + tmax=None, + picks=None, + proj=False, + decim=1, + events=None, + event_id=None, + selection=None, + drop_log=None, + metadata=None, + n_jobs=None, + verbose=None, + **method_kw, + ): + from ..epochs import BaseEpochs + + # deprecations. TODO remove after 1.7 release + depr_params = dict(info=info, data=data, times=times, comment=comment) + bad_params = list() + for name, param in depr_params.items(): + if param is not None: + bad_params.append(name) + if len(bad_params): + _s = _pl(bad_params) + is_are = _pl(bad_params, "is", "are") + bad_params_list = '", "'.join(bad_params) + warn( + f'Parameter{_s} "{bad_params_list}" {is_are} deprecated and will be ' + "removed in version 1.8. For a quick fix, use ``EpochsTFRArray`` with " + "the same parameters. For a long-term fix, see the docstring notes.", + FutureWarning, + ) + if inst is not None: + raise ValueError( + "Do not pass `inst` alongside deprecated params " + f'"{bad_params_list}"; see docstring of AverageTFR for guidance.' + ) + # sensible defaults are created in __setstate__ so only pass these through + # if they're user-specified + optional = dict( + freqs=freqs, + method=method, + events=events, + event_id=event_id, + selection=selection, + drop_log=drop_log, + metadata=metadata, + ) + optional_params = { + key: val for key, val in optional.items() if val is not None + } + inst = depr_params | optional_params + # end TODO ↑↑↑↑↑↑ + + # dict is allowed for __setstate__ compatibility + _validate_type( + inst, (BaseEpochs, dict), "object passed to EpochsTFR constructor", "Epochs" + ) + super().__init__( + inst, + method, + freqs, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + decim=decim, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) + + @fill_doc + def __getitem__(self, item): + """Subselect epochs from an EpochsTFR. 
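+
+        For example (a sketch; ``epochs_tfr`` stands for an existing EpochsTFR whose
+        event dictionary includes a condition named ``"auditory"``)::
+
+            epochs_tfr["auditory"]  # select epochs by condition name
+            epochs_tfr[:10]         # select the first ten epochs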
+ + Parameters + ---------- + %(item)s + Access options are the same as for :class:`~mne.Epochs` objects, see the + docstring Notes section of :meth:`mne.Epochs.__getitem__` for explanation. + + Returns + ------- + %(getitem_epochstfr_return)s + """ + return super().__getitem__(item) + + def __getstate__(self): + """Prepare EpochsTFR object for serialization.""" + out = super().__getstate__() + out.update( + metadata=self._metadata, + drop_log=self.drop_log, + event_id=self.event_id, + events=self.events, + selection=self.selection, + raw_times=self._raw_times, + ) + return out + + def __setstate__(self, state): + """Unpack EpochsTFR from serialized format.""" + if state["data"].ndim != 4: + raise ValueError(f"EpochsTFR data should be 4D, got {state['data'].ndim}.") + super().__setstate__(state) + self._metadata = state.get("metadata", None) + n_epochs = self.shape[0] + n_times = self.shape[-1] + fake_samps = np.linspace( + n_times, n_times * (n_epochs + 1), n_epochs, dtype=int, endpoint=False + ) + fake_events = np.dstack( + (fake_samps, np.zeros_like(fake_samps), np.ones_like(fake_samps)) + ).squeeze(axis=0) + self.events = state.get("events", _ensure_events(fake_events)) + self.event_id = state.get("event_id", _check_event_id(None, self.events)) + self.drop_log = state.get("drop_log", tuple()) + self.selection = state.get("selection", np.arange(n_epochs)) + self._bad_dropped = True # always true, need for `equalize_event_counts()` + + def __next__(self, return_event_id=False): + """Iterate over EpochsTFR objects. + + NOTE: __iter__() and _stop_iter() are defined by the GetEpochs mixin. + + Parameters + ---------- + return_event_id : bool + If ``True``, return both the EpochsTFR data and its associated ``event_id``. + + Returns + ------- + epoch : array of shape (n_channels, n_freqs, n_times) + The single-epoch time-frequency data. + event_id : int + The integer event id associated with the epoch. Only returned if + ``return_event_id`` is ``True``. + """ + if self._current >= len(self._data): + self._stop_iter() + epoch = self._data[self._current] + event_id = self.events[self._current][-1] + self._current += 1 + if return_event_id: + return epoch, event_id + return epoch + + def _check_singleton(self): + """Check if self contains only one Epoch, and return it as an AverageTFR.""" + if self.shape[0] > 1: + calling_func = inspect.currentframe().f_back.f_code.co_name + raise NotImplementedError( + f"Cannot call {calling_func}() from EpochsTFR with multiple epochs; " + "please subselect a single epoch before plotting." + ) + return list(self.iter_evoked())[0] + + def _get_instance_data(self, time_mask): + return self.inst.get_data(picks=self._picks)[:, :, time_mask] + + def _update_epoch_attributes(self): + # adjust dims and shape + if self.method != "stockwell": # stockwell consumes epochs dimension + self._dims = ("epoch",) + self._dims + self._shape = (len(self.inst),) + self._shape + # we need these for to_data_frame() + self.event_id = self.inst.event_id.copy() + self.events = self.inst.events.copy() + self.selection = self.inst.selection.copy() + # we need these for __getitem__() + self.drop_log = deepcopy(self.inst.drop_log) + self._metadata = self.inst.metadata + # we need this for compatibility with equalize_event_counts() + self._bad_dropped = True + + def average(self, method="mean", *, dim="epochs", copy=False): + """Aggregate the EpochsTFR across epochs, frequencies, or times. 
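+
+        A sketch of the three aggregation modes (``epochs_tfr`` stands for an
+        existing EpochsTFR instance)::
+
+            epochs_tfr.average()              # collapse epochs -> AverageTFR
+            epochs_tfr.average(dim="freqs")   # collapse freqs  -> EpochsTFR
+            epochs_tfr.average(dim="times")   # collapse times  -> EpochsSpectrum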
+ + Parameters + ---------- + method : "mean" | "median" | callable + How to aggregate the data across the given ``dim``. If callable, + must take a :class:`NumPy array` of shape + ``(n_epochs, n_channels, n_freqs, n_times)`` and return an array + with one fewer dimensions (which dimension is collapsed depends on + the value of ``dim``). Default is ``"mean"``. + dim : "epochs" | "freqs" | "times" + The dimension along which to combine the data. + copy : bool + Whether to return a copy of the modified instance, or modify in place. + Ignored when ``dim="epochs"`` or ``"times"`` because those options return + different types (:class:`~mne.time_frequency.AverageTFR` and + :class:`~mne.time_frequency.EpochsSpectrum`, respectively). + + Returns + ------- + tfr : instance of EpochsTFR | AverageTFR | EpochsSpectrum + The aggregated TFR object. + + Notes + ----- + Passing in ``np.median`` is considered unsafe for complex data; pass + the string ``"median"`` instead to compute the *marginal* median + (i.e. the median of the real and imaginary components separately). + See discussion here: + + https://github.com/scipy/scipy/pull/12676#issuecomment-783370228 + """ + _check_option("dim", dim, ("epochs", "freqs", "times")) + axis = self._dims.index(dim[:-1]) # self._dims entries aren't plural + + func = _check_combine(mode=method, axis=axis) + data = func(self.data) + + n_epochs, n_channels, n_freqs, n_times = self.data.shape + freqs, times = self.freqs, self.times + if dim == "epochs": + expected_shape = self._data.shape[1:] + elif dim == "freqs": + expected_shape = (n_epochs, n_channels, n_times) + freqs = np.mean(self.freqs, keepdims=True) + elif dim == "times": + expected_shape = (n_epochs, n_channels, n_freqs) + times = np.mean(self.times, keepdims=True) + + if data.shape != expected_shape: + raise RuntimeError( + "EpochsTFR.average() got a method that resulted in data of shape " + f"{data.shape}, but it should be {expected_shape}." + ) + # restore singleton freqs axis (not necessary for epochs/times: class changes) + if dim == "freqs": + data = np.expand_dims(data, axis=axis) + state = self.__getstate__() + state["data"] = data + state["info"] = deepcopy(self.info) + state["dims"] = (*state["dims"][:axis], *state["dims"][axis + 1 :]) + state["freqs"] = freqs + state["times"] = times + if dim == "epochs": + state["inst_type_str"] = "Evoked" + state["nave"] = n_epochs + state["comment"] = f"{method} of {n_epochs} EpochsTFR{_pl(n_epochs)}" + out = AverageTFR(inst=state) + out._data_type = "Average Power" + return out + + elif dim == "times": + return EpochsSpectrum( + state, + method=None, + fmin=None, + fmax=None, + tmin=None, + tmax=None, + picks=None, + exclude=None, + proj=None, + remove_dc=None, + n_jobs=None, + ) + # ↓↓↓ these two are for dim == "freqs" + elif copy: + return EpochsTFR(inst=state, method=None, freqs=None) + else: + self._data = np.expand_dims(data, axis=axis) + self._freqs = freqs + return self + + @verbose + def drop(self, indices, reason="USER", verbose=None): + """Drop epochs based on indices or boolean mask. + + .. note:: The indices refer to the current set of undropped epochs + rather than the complete set of dropped and undropped epochs. + They are therefore not necessarily consistent with any + external indices (e.g., behavioral logs). To drop epochs + based on external criteria, do not use the ``preload=True`` + flag when constructing an Epochs object, and call this + method before calling the :meth:`mne.Epochs.drop_bad` or + :meth:`mne.Epochs.load_data` methods. 
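+
+        For example (a sketch), dropping the first and fourth of the remaining
+        epochs::
+
+            epochs_tfr.drop([0, 3], reason="artifact")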
+ + Parameters + ---------- + indices : array of int or bool + Set epochs to remove by specifying indices to remove or a boolean + mask to apply (where True values get removed). Events are + correspondingly modified. + reason : str + Reason for dropping the epochs ('ECG', 'timeout', 'blink' etc). + Default: 'USER'. + %(verbose)s + + Returns + ------- + epochs : instance of Epochs or EpochsTFR + The epochs with indices dropped. Operates in-place. + """ + from ..epochs import BaseEpochs + + BaseEpochs.drop(self, indices=indices, reason=reason, verbose=verbose) + + return self + + def iter_evoked(self, copy=False): + """Iterate over EpochsTFR to yield a sequence of AverageTFR objects. + + The AverageTFR objects will each contain a single epoch (i.e., no averaging is + performed). This method resets the EpochTFR instance's iteration state to the + first epoch. + + Parameters + ---------- + copy : bool + Whether to yield copies of the data and measurement info, or views/pointers. + """ + self.__iter__() + state = self.__getstate__() + state["inst_type_str"] = "Evoked" + state["dims"] = state["dims"][1:] # drop "epochs" + + while True: + try: + data, event_id = self.__next__(return_event_id=True) + except StopIteration: + break + if copy: + state["info"] = deepcopy(self.info) + state["data"] = data.copy() + else: + state["data"] = data + state["nave"] = 1 + yield AverageTFR(inst=state, method=None, freqs=None, comment=str(event_id)) + + @verbose + @copy_doc(BaseTFR.plot) + def plot( + self, + picks=None, + *, + exclude=(), + tmin=None, + tmax=None, + fmin=None, + fmax=None, + baseline=None, + mode="mean", + dB=False, + combine=None, + layout=None, # TODO deprecate; not used in orig implementation + yscale="auto", + vmin=None, + vmax=None, + vlim=(None, None), + cnorm=None, + cmap=None, + colorbar=True, + title=None, # don't deprecate this one; has (useful) option title="auto" + mask=None, + mask_style=None, + mask_cmap="Greys", + mask_alpha=0.1, + axes=None, + show=True, + verbose=None, + ): + singleton_epoch = self._check_singleton() + return singleton_epoch.plot( + picks=picks, + exclude=exclude, + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + baseline=baseline, + mode=mode, + dB=dB, + combine=combine, + layout=layout, + yscale=yscale, + vmin=vmin, + vmax=vmax, + vlim=vlim, + cnorm=cnorm, + cmap=cmap, + colorbar=colorbar, + title=title, + mask=mask, + mask_style=mask_style, + mask_cmap=mask_cmap, + mask_alpha=mask_alpha, + axes=axes, + show=show, + verbose=verbose, + ) + + @verbose + @copy_doc(BaseTFR.plot_topo) + def plot_topo( + self, + picks=None, + baseline=None, + mode="mean", + tmin=None, + tmax=None, + fmin=None, + fmax=None, + vmin=None, # TODO deprecate in favor of `vlim` (needs helper func refactor) + vmax=None, + layout=None, + cmap=None, + title=None, # don't deprecate; topo titles aren't standard (color, size, just.) 
+ dB=False, + colorbar=True, + layout_scale=0.945, + show=True, + border="none", + fig_facecolor="k", + fig_background=None, + font_color="w", + yscale="auto", + verbose=None, + ): + singleton_epoch = self._check_singleton() + return singleton_epoch.plot_topo( + picks=picks, + baseline=baseline, + mode=mode, + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + vmin=vmin, + vmax=vmax, + layout=layout, + cmap=cmap, + title=title, + dB=dB, + colorbar=colorbar, + layout_scale=layout_scale, + show=show, + border=border, + fig_facecolor=fig_facecolor, + fig_background=fig_background, + font_color=font_color, + yscale=yscale, + verbose=verbose, + ) + + @verbose + @copy_doc(BaseTFR.plot_joint) + def plot_joint( + self, + *, + timefreqs=None, + picks=None, + exclude=(), + combine="mean", + tmin=None, + tmax=None, + fmin=None, + fmax=None, + baseline=None, + mode="mean", + dB=False, + yscale="auto", + vmin=None, + vmax=None, + vlim=(None, None), + cnorm=None, + cmap=None, + colorbar=True, + title=None, + show=True, + topomap_args=None, + image_args=None, + verbose=None, + ): + singleton_epoch = self._check_singleton() + return singleton_epoch.plot_joint( + timefreqs=timefreqs, + picks=picks, + exclude=exclude, + combine=combine, + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + baseline=baseline, + mode=mode, + dB=dB, + yscale=yscale, + vmin=vmin, + vmax=vmax, + vlim=vlim, + cnorm=cnorm, + cmap=cmap, + colorbar=colorbar, + title=title, + show=show, + topomap_args=topomap_args, + image_args=image_args, + verbose=verbose, + ) + + @copy_doc(BaseTFR.plot_topomap) + def plot_topomap( + self, + tmin=None, + tmax=None, + fmin=0.0, + fmax=np.inf, + *, + ch_type=None, + baseline=None, + mode="mean", + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=2, + cmap=None, + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%1.1e", + units=None, + axes=None, + show=True, + ): + singleton_epoch = self._check_singleton() + return singleton_epoch.plot_topomap( + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + ch_type=ch_type, + baseline=baseline, + mode=mode, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + show=show, + ) @fill_doc -class EpochsTFR(_BaseTFR, GetEpochsMixin): - """Container for Time-Frequency data on epochs. - - Can for example store induced power at sensor level. +class EpochsTFRArray(EpochsTFR): + """Data object for *precomputed* spectrotemporal representations of epoched data. Parameters ---------- %(info_not_none)s - data : ndarray, shape (n_epochs, n_channels, n_freqs, n_times) - The data. - times : ndarray, shape (n_times,) - The time values in seconds. - freqs : ndarray, shape (n_freqs,) - The frequencies in Hz. - comment : str | None, default None - Comment on the data, e.g., the experimental condition. - method : str | None, default None - Comment on the method used to compute the data, e.g., morlet wavelet. - events : ndarray, shape (n_events, 3) | None - The events as stored in the Epochs class. 
If None (default), all event - values are set to 1 and event time-samples are set to range(n_epochs). - event_id : dict | None - Example: dict(auditory=1, visual=3). They keys can be used to access - associated events. If None, all events will be used and a dict is - created with string integer names corresponding to the event id - integers. - selection : iterable | None - Iterable of indices of selected epochs. If ``None``, will be - automatically generated, corresponding to all non-zero events. - - .. versionadded:: 0.23 - drop_log : tuple | None - Tuple of tuple of strings indicating which epochs have been marked to - be ignored. - - .. versionadded:: 0.23 - metadata : instance of pandas.DataFrame | None - A :class:`pandas.DataFrame` containing pertinent information for each - trial. See :class:`mne.Epochs` for further details. - %(verbose)s + %(data_tfr)s + %(times)s + %(freqs_tfr_array)s + %(comment_tfr_attr)s + %(method_tfr_array)s + %(events_epochstfr)s + %(event_id_epochstfr)s + %(selection)s + %(drop_log)s + %(metadata_epochstfr)s Attributes ---------- + %(baseline_tfr_attr)s + %(ch_names_tfr_attr)s + %(comment_tfr_attr)s + %(drop_log)s + %(event_id_attr)s + %(events_attr)s + %(freqs_tfr_attr)s %(info_not_none)s - ch_names : list - The names of the channels. - data : ndarray, shape (n_epochs, n_channels, n_freqs, n_times) - The data array. - times : ndarray, shape (n_times,) - The time values in seconds. - freqs : ndarray, shape (n_freqs,) - The frequencies in Hz. - comment : string - Comment on dataset. Can be the condition. - method : str | None, default None - Comment on the method used to compute the data, e.g., morlet wavelet. - events : ndarray, shape (n_events, 3) | None - Array containing sample information as event_id - event_id : dict | None - Names of conditions correspond to event_ids - selection : array - List of indices of selected events (not dropped or ignored etc.). For - example, if the original event array had 4 events and the second event - has been dropped, this attribute would be np.array([0, 2, 3]). - drop_log : tuple of tuple - A tuple of the same length as the event array used to initialize the - ``EpochsTFR`` object. If the i-th original event is still part of the - selection, drop_log[i] will be an empty tuple; otherwise it will be - a tuple of the reasons the event is not longer in the selection, e.g.: - - - ``'IGNORED'`` - If it isn't part of the current subset defined by the user - - ``'NO_DATA'`` or ``'TOO_SHORT'`` - If epoch didn't contain enough data names of channels that - exceeded the amplitude threshold - - ``'EQUALIZED_COUNTS'`` - See :meth:`~mne.Epochs.equalize_event_counts` - - ``'USER'`` - For user-defined reasons (see :meth:`~mne.Epochs.drop`). - - metadata : pandas.DataFrame, shape (n_events, n_cols) | None - DataFrame containing pertinent information for each trial + %(metadata_attr)s + %(method_tfr_attr)s + %(selection_attr)s + %(sfreq_tfr_attr)s + %(shape_tfr_attr)s - Notes - ----- - .. versionadded:: 0.13.0 + See Also + -------- + AverageTFR + mne.Epochs.compute_tfr + mne.Evoked.compute_tfr """ - @verbose def __init__( self, info, data, times, freqs, + *, comment=None, method=None, events=None, @@ -2709,206 +3777,204 @@ def __init__( selection=None, drop_log=None, metadata=None, - verbose=None, ): - super().__init__() - self.info = info - if data.ndim != 4: - raise ValueError("data should be 4d. Got %d." 
% data.ndim) - n_epochs, n_channels, n_freqs, n_times = data.shape - if n_channels != len(info["chs"]): - raise ValueError( - "Number of channels and data size don't match" - " (%d != %d)." % (n_channels, len(info["chs"])) - ) - if n_freqs != len(freqs): - raise ValueError( - "Number of frequencies and data size don't match" - " (%d != %d)." % (n_freqs, len(freqs)) - ) - if n_times != len(times): - raise ValueError( - "Number of times and data size don't match" - " (%d != %d)." % (n_times, len(times)) - ) - if events is None: - n_epochs = len(data) - events = _gen_events(n_epochs) - if selection is None: - n_epochs = len(data) - selection = np.arange(n_epochs) - if drop_log is None: - n_epochs_prerejection = max(len(events), max(selection) + 1) - drop_log = tuple( - () if k in selection else ("IGNORED",) - for k in range(n_epochs_prerejection) - ) - else: - drop_log = drop_log - # check consistency: - assert len(selection) == len(events) - assert len(drop_log) >= len(events) - assert len(selection) == sum(len(dl) == 0 for dl in drop_log) - event_id = _check_event_id(event_id, events) - self.data = data - self._set_times(np.array(times, dtype=float)) - self._raw_times = self.times.copy() # needed for decimate - self.freqs = np.array(freqs, dtype=float) - self.events = events - self.event_id = event_id - self.selection = selection - self.drop_log = drop_log - self.comment = comment - self.method = method - self.preload = True - self.metadata = metadata - # we need this to allow equalize_epoch_counts to work with EpochsTFRs - self._bad_dropped = True + state = dict(info=info, data=data, times=times, freqs=freqs) + optional = dict( + comment=comment, + method=method, + events=events, + event_id=event_id, + selection=selection, + drop_log=drop_log, + metadata=metadata, + ) + for name, value in optional.items(): + if value is not None: + state[name] = value + self.__setstate__(state) - @property - def _detrend_picks(self): - return list() - def __repr__(self): # noqa: D105 - s = f"time : [{self.times[0]}, {self.times[-1]}]" - s += f", freq : [{self.freqs[0]}, {self.freqs[-1]}]" - s += ", epochs : %d" % self.data.shape[0] - s += ", channels : %d" % self.data.shape[1] - s += f", ~{sizeof_fmt(self._size)}" - return "" % s +@fill_doc +class RawTFR(BaseTFR): + """Data object for spectrotemporal representations of continuous data. + + .. warning:: The preferred means of creating RawTFR objects from + :class:`~mne.io.Raw` objects is via the instance method + :meth:`~mne.io.Raw.compute_tfr`. Direct class instantiation + is not supported. - def __abs__(self): - """Take the absolute value.""" - epochs = self.copy() - epochs.data = np.abs(self.data) - return epochs + Parameters + ---------- + inst : instance of Raw + The data from which to compute the time-frequency representation. + %(method_tfr)s + %(freqs_tfr)s + %(tmin_tmax_psd)s + %(picks_good_data_noref)s + %(proj_psd)s + %(reject_by_annotation_tfr)s + %(decim_tfr)s + %(n_jobs)s + %(verbose)s + %(method_kw_tfr)s + + Attributes + ---------- + ch_names : list + The channel names. + freqs : array + Frequencies at which the amplitude, power, or fourier coefficients + have been computed. + %(info_not_none)s + method : str + The method used to compute the spectra (``'morlet'``, ``'multitaper'`` + or ``'stockwell'``). + + See Also + -------- + mne.io.Raw.compute_tfr + EpochsTFR + AverageTFR + + References + ---------- + .. 
footbibliography:: + """ + + def __init__( + self, + inst, + method=None, + freqs=None, + *, + tmin=None, + tmax=None, + picks=None, + proj=False, + reject_by_annotation=False, + decim=1, + n_jobs=None, + verbose=None, + **method_kw, + ): + from ..io import BaseRaw + + # dict is allowed for __setstate__ compatibility + _validate_type( + inst, (BaseRaw, dict), "object passed to RawTFR constructor", "Raw" + ) + super().__init__( + inst, + method, + freqs, + tmin=tmin, + tmax=tmax, + picks=picks, + proj=proj, + reject_by_annotation=reject_by_annotation, + decim=decim, + n_jobs=n_jobs, + verbose=verbose, + **method_kw, + ) - def average(self, method="mean", dim="epochs", copy=False): - """Average the data across epochs. + def __getitem__(self, item): + """Get RawTFR data. Parameters ---------- - method : str | callable - How to combine the data. If "mean"/"median", the mean/median - are returned. Otherwise, must be a callable which, when passed - an array of shape (n_epochs, n_channels, n_freqs, n_time) - returns an array of shape (n_channels, n_freqs, n_time). - Note that due to file type limitations, the kind for all - these will be "average". - dim : 'epochs' | 'freqs' | 'times' - The dimension along which to combine the data. - copy : bool - Whether to return a copy of the modified instance, - or modify in place. Ignored when ``dim='epochs'`` - because a new instance must be returned. + item : int | slice | array-like + Indexing is similar to a :class:`NumPy array`; see + Notes. Returns ------- - ave : instance of AverageTFR | EpochsTFR - The averaged data. + %(getitem_tfr_return)s Notes ----- - Passing in ``np.median`` is considered unsafe when there is complex - data because NumPy doesn't compute the marginal median. Numpy currently - sorts the complex values by real part and return whatever value is - computed. Use with caution. We use the marginal median in the - complex case (i.e. the median of each component separately) if - one passes in ``median``. See a discussion in scipy: + The last axis is always time, the next-to-last axis is always + frequency, and the first axis is always channel. If + ``method='multitaper'`` and ``output='complex'`` then the second axis + will be taper index. - https://github.com/scipy/scipy/pull/12676#issuecomment-783370228 - """ - _check_option("dim", dim, ("epochs", "freqs", "times")) - axis = dict(epochs=0, freqs=2, times=self.data.ndim - 1)[dim] + Integer-, list-, and slice-based indexing is possible: - # return a lambda function for computing a combination metric - # over epochs - func = _check_combine(mode=method, axis=axis) - data = func(self.data) + - ``raw_tfr[[0, 2]]`` gives the whole time-frequency plane for the + first and third channels. + - ``raw_tfr[..., :3, :]`` gives the first 3 frequency bins and all + times for all channels (and tapers, if present). + - ``raw_tfr[..., :100]`` gives the first 100 time samples in all + frequency bins for all channels (and tapers). + - ``raw_tfr[(4, 7)]`` is the same as ``raw_tfr[4, 7]``. - n_epochs, n_channels, n_freqs, n_times = self.data.shape - freqs, times = self.freqs, self.times + .. 
note:: - if dim == "freqs": - freqs = np.mean(self.freqs, keepdims=True) - n_freqs = 1 - elif dim == "times": - times = np.mean(self.times, keepdims=True) - n_times = 1 - if dim == "epochs": - expected_shape = self._data.shape[1:] - else: - expected_shape = (n_epochs, n_channels, n_freqs, n_times) - data = np.expand_dims(data, axis=axis) + Unlike :class:`~mne.io.Raw` objects (which returns a tuple of the + requested data values and the corresponding times), accessing + :class:`~mne.time_frequency.RawTFR` values via subscript does + **not** return the corresponding frequency bin values. If you need + them, use ``RawTFR.freqs[freq_indices]`` or + ``RawTFR.get_data(..., return_freqs=True)``. + """ + from ..io import BaseRaw - if data.shape != expected_shape: - raise RuntimeError( - f"You passed a function that resulted in data of shape " - f"{data.shape}, but it should be {expected_shape}." - ) + self._parse_get_set_params = partial(BaseRaw._parse_get_set_params, self) + return BaseRaw._getitem(self, item, return_times=False) - if dim == "epochs": - return AverageTFR( - info=self.info.copy(), - data=data, - times=times, - freqs=freqs, - nave=self.data.shape[0], - method=self.method, - comment=self.comment, - ) - elif copy: - return EpochsTFR( - info=self.info.copy(), - data=data, - times=times, - freqs=freqs, - method=self.method, - comment=self.comment, - metadata=self.metadata, - events=self.events, - event_id=self.event_id, - ) - else: - self.data = data - self._set_times(times) - self.freqs = freqs - return self + def _get_instance_data(self, time_mask, reject_by_annotation): + start, stop = np.where(time_mask)[0][[0, -1]] + rba = "NaN" if reject_by_annotation else None + data = self.inst.get_data( + self._picks, start, stop + 1, reject_by_annotation=rba + ) + # prepend a singleton "epochs" axis + return data[np.newaxis] - @verbose - def drop(self, indices, reason="USER", verbose=None): - """Drop epochs based on indices or boolean mask. - .. note:: The indices refer to the current set of undropped epochs - rather than the complete set of dropped and undropped epochs. - They are therefore not necessarily consistent with any - external indices (e.g., behavioral logs). To drop epochs - based on external criteria, do not use the ``preload=True`` - flag when constructing an Epochs object, and call this - method before calling the :meth:`mne.Epochs.drop_bad` or - :meth:`mne.Epochs.load_data` methods. +@fill_doc +class RawTFRArray(RawTFR): + """Data object for *precomputed* spectrotemporal representations of continuous data. - Parameters - ---------- - indices : array of int or bool - Set epochs to remove by specifying indices to remove or a boolean - mask to apply (where True values get removed). Events are - correspondingly modified. - reason : str - Reason for dropping the epochs ('ECG', 'timeout', 'blink' etc). - Default: 'USER'. - %(verbose)s + Parameters + ---------- + %(info_not_none)s + %(data_tfr)s + %(times)s + %(freqs_tfr_array)s + %(method_tfr_array)s - Returns - ------- - epochs : instance of Epochs or EpochsTFR - The epochs with indices dropped. Operates in-place. 
- """ - from ..epochs import BaseEpochs + Attributes + ---------- + %(baseline_tfr_attr)s + %(ch_names_tfr_attr)s + %(freqs_tfr_attr)s + %(info_not_none)s + %(method_tfr_attr)s + %(sfreq_tfr_attr)s + %(shape_tfr_attr)s - BaseEpochs.drop(self, indices=indices, reason=reason, verbose=verbose) + See Also + -------- + RawTFR + mne.io.Raw.compute_tfr + EpochsTFRArray + AverageTFRArray + """ - return self + def __init__( + self, + info, + data, + times, + freqs, + *, + method=None, + ): + state = dict(info=info, data=data, times=times, freqs=freqs) + if method is not None: + state["method"] = method + self.__setstate__(state) def combine_tfr(all_tfr, weights="nave"): @@ -2972,6 +4038,7 @@ def combine_tfr(all_tfr, weights="nave"): # Utils +# ↓↓↓↓↓↓↓↓↓↓↓ this is still used in _stockwell.py def _get_data(inst, return_itc): """Get data from Epochs or Evoked instance as epochs x ch x time.""" from ..epochs import BaseEpochs @@ -3065,8 +4132,7 @@ def _preproc_tfr( return data, times, freqs, vmin, vmax -# TODO: Name duplication with mne/utils/mixin.py -def _check_decim(decim): +def _ensure_slice(decim): """Aux function checking the decim parameter.""" _validate_type(decim, ("int-like", slice), "decim") if not isinstance(decim, slice): @@ -3088,10 +4154,11 @@ def write_tfrs(fname, tfr, overwrite=False, *, verbose=None): ---------- fname : path-like The file name, which should end with ``-tfr.h5``. - tfr : AverageTFR | list of AverageTFR | EpochsTFR - The TFR dataset, or list of TFR datasets, to save in one file. - Note. If .comment is not None, a name will be generated on the fly, - based on the order in which the TFR objects are passed. + tfr : RawTFR | EpochsTFR | AverageTFR | list of RawTFR | list of EpochsTFR | list of AverageTFR + The (list of) TFR object(s) to save in one file. If ``tfr.comment`` is ``None``, + a sequential numeric string name will be generated on the fly, based on the + order in which the TFR objects are passed. This can be used to selectively load + single TFR objects from the file later. %(overwrite)s %(verbose)s @@ -3102,92 +4169,112 @@ def write_tfrs(fname, tfr, overwrite=False, *, verbose=None): Notes ----- .. versionadded:: 0.9.0 - """ + """ # noqa E501 _, write_hdf5 = _import_h5io_funcs() out = [] if not isinstance(tfr, (list, tuple)): tfr = [tfr] for ii, tfr_ in enumerate(tfr): - comment = ii if tfr_.comment is None else tfr_.comment - out.append(_prepare_write_tfr(tfr_, condition=comment)) + comment = ii if getattr(tfr_, "comment", None) is None else tfr_.comment + state = tfr_.__getstate__() + if "metadata" in state: + state["metadata"] = _prepare_write_metadata(state["metadata"]) + out.append((comment, state)) write_hdf5(fname, out, overwrite=overwrite, title="mnepython", slash="replace") -def _prepare_write_tfr(tfr, condition): - """Aux function.""" - attributes = dict( - times=tfr.times, - freqs=tfr.freqs, - data=tfr.data, - info=tfr.info, - comment=tfr.comment, - method=tfr.method, - ) - if hasattr(tfr, "nave"): # if AverageTFR - attributes["nave"] = tfr.nave - elif hasattr(tfr, "events"): # if EpochsTFR - attributes["events"] = tfr.events - attributes["event_id"] = tfr.event_id - attributes["selection"] = tfr.selection - attributes["drop_log"] = tfr.drop_log - attributes["metadata"] = _prepare_write_metadata(tfr.metadata) - return condition, attributes - - @verbose def read_tfrs(fname, condition=None, *, verbose=None): - """Read TFR datasets from hdf5 file. + """Load a TFR object from disk. 
Parameters ---------- fname : path-like - The file name, which should end with -tfr.h5 . + Path to a TFR file in HDF5 format. condition : int or str | list of int or str | None - The condition to load. If None, all conditions will be returned. - Defaults to None. + The condition to load. If ``None``, all conditions will be returned. + Defaults to ``None``. %(verbose)s Returns ------- - tfr : AverageTFR | list of AverageTFR | EpochsTFR - Depending on ``condition`` either the TFR object or a list of multiple - TFR objects. + tfr : RawTFR | EpochsTFR | AverageTFR | list of RawTFR | list of EpochsTFR | list of AverageTFR + The loaded time-frequency object. See Also -------- + mne.time_frequency.RawTFR.save + mne.time_frequency.EpochsTFR.save + mne.time_frequency.AverageTFR.save write_tfrs Notes ----- .. versionadded:: 0.9.0 - """ - check_fname(fname, "tfr", ("-tfr.h5", "_tfr.h5")) + """ # noqa E501 read_hdf5, _ = _import_h5io_funcs() + fname = _check_fname(fname=fname, overwrite="read", must_exist=False) + valid_fnames = tuple( + f"{sep}tfr.{ext}" for sep in ("-", "_") for ext in ("h5", "hdf5") + ) + check_fname(fname, "tfr", valid_fnames) + logger.info(f"Reading {fname} ...") + hdf5_dict = read_hdf5(fname, title="mnepython", slash="replace") + # single TFR from TFR.save() + if "inst_type_str" in hdf5_dict: + inst_type_str = hdf5_dict["inst_type_str"] + Klass = dict(Epochs=EpochsTFR, Raw=RawTFR, Evoked=AverageTFR)[inst_type_str] + out = Klass(inst=hdf5_dict) + if getattr(out, "metadata", None) is not None: + out.metadata = _prepare_read_metadata(out.metadata) + return out + # maybe multiple TFRs from write_tfrs() + return _read_multiple_tfrs(hdf5_dict, condition=condition, verbose=verbose) - logger.info("Reading %s ..." % fname) - tfr_data = read_hdf5(fname, title="mnepython", slash="replace") - for k, tfr in tfr_data: + +@verbose +def _read_multiple_tfrs(tfr_data, condition=None, *, verbose=None): + """Read (possibly multiple) TFR datasets from an h5 file written by write_tfrs().""" + out = list() + keys = list() + # tfr_data is a list of (comment, tfr_dict) tuples + for key, tfr in tfr_data: + keys.append(str(key)) # auto-assigned keys are ints + is_epochs = tfr["data"].ndim == 4 + is_average = "nave" in tfr + if condition is not None: + if not is_average: + raise NotImplementedError( + "condition is only supported when reading AverageTFRs." + ) + if key != condition: + continue + tfr = dict(tfr) tfr["info"] = Info(tfr["info"]) tfr["info"]._check_consistency() if "metadata" in tfr: tfr["metadata"] = _prepare_read_metadata(tfr["metadata"]) - is_average = "nave" in tfr - if condition is not None: - if not is_average: - raise NotImplementedError( - "condition not supported when reading " "EpochsTFR." - ) - tfr_dict = dict(tfr_data) - if condition not in tfr_dict: - keys = ["%s" % k for k in tfr_dict] - raise ValueError( - 'Cannot find condition ("{}") in this file. 
' - 'The file contains "{}""'.format(condition, " or ".join(keys)) + # additional keys needed for TFR __setstate__ + defaults = dict(baseline=None, data_type="Power Estimates") + if is_epochs: + Klass = EpochsTFR + defaults.update( + inst_type_str="Epochs", dims=("epoch", "channel", "freq", "time") ) - out = AverageTFR(**tfr_dict[condition]) - else: - inst = AverageTFR if is_average else EpochsTFR - out = [inst(**d) for d in list(zip(*tfr_data))[1]] + elif is_average: + Klass = AverageTFR + defaults.update(inst_type_str="Evoked", dims=("channel", "freq", "time")) + else: + Klass = RawTFR + defaults.update(inst_type_str="Raw", dims=("channel", "freq", "time")) + out.append(Klass(inst=defaults | tfr)) + if len(out) == 0: + raise ValueError( + f'Cannot find condition "{condition}" in this file. ' + f'The file contains conditions {", ".join(keys)}' + ) + if len(out) == 1: + out = out[0] return out @@ -3196,7 +4283,7 @@ def _get_timefreqs(tfr, timefreqs): # Input check timefreq_error_msg = ( "Supplied `timefreqs` are somehow malformed. Please supply None, " - "a list of tuple pairs, or a dict of such tuple pairs, not: " + "a list of tuple pairs, or a dict of such tuple pairs, not {}" ) if isinstance(timefreqs, dict): for k, v in timefreqs.items(): @@ -3205,7 +4292,7 @@ def _get_timefreqs(tfr, timefreqs): raise ValueError(timefreq_error_msg, item) elif timefreqs is not None: if not hasattr(timefreqs, "__len__"): - raise ValueError(timefreq_error_msg, timefreqs) + raise ValueError(timefreq_error_msg.format(timefreqs)) if len(timefreqs) == 2 and all(_is_numeric(v) for v in timefreqs): timefreqs = [tuple(timefreqs)] # stick a pair of numbers in a list else: @@ -3217,7 +4304,7 @@ def _get_timefreqs(tfr, timefreqs): ): pass else: - raise ValueError(timefreq_error_msg, item) + raise ValueError(timefreq_error_msg.format(item)) # If None, automatic identification of max peak else: @@ -3244,59 +4331,66 @@ def _get_timefreqs(tfr, timefreqs): return timefreqs -def _preproc_tfr_instance( - tfr, - picks, - tmin, - tmax, - fmin, - fmax, - vmin, - vmax, - dB, - mode, - baseline, - exclude, - copy=True, -): - """Baseline and truncate (times and freqs) a TFR instance.""" - tfr = tfr.copy() if copy else tfr - - exclude = None if picks is None else exclude - picks = _picks_to_idx(tfr.info, picks, exclude="bads") - pick_names = [tfr.info["ch_names"][pick] for pick in picks] - tfr.pick(pick_names) - - if exclude == "bads": - exclude = [ch for ch in tfr.info["bads"] if ch in tfr.info["ch_names"]] - if exclude is not None: - tfr.drop_channels(exclude) - - data, times, freqs, _, _ = _preproc_tfr( - tfr.data, - tfr.times, - tfr.freqs, - tmin, - tmax, - fmin, - fmax, - mode, - baseline, - vmin, - vmax, - dB, - tfr.info["sfreq"], - copy=False, - ) - - tfr._set_times(times) - tfr.freqs = freqs - tfr.data = data - - return tfr - - def _check_tfr_complex(tfr, reason="source space estimation"): """Check that time-frequency epochs or average data is complex.""" if not np.iscomplexobj(tfr.data): raise RuntimeError(f"Time-frequency data must be complex for {reason}") + + +def _merge_if_grads(data, info, ch_type, sphere, combine=None): + if ch_type == "grad": + grad_picks = _pair_grad_sensors(info, topomap_coords=False) + pos = _find_topomap_coords(info, picks=grad_picks[::2], sphere=sphere) + grad_method = combine if isinstance(combine, str) else "rms" + data, _ = _merge_ch_data(data[grad_picks], ch_type, [], method=grad_method) + else: + pos, _ = _get_pos_outlines(info, picks=ch_type, sphere=sphere) + return data, pos + + 
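+# A sketch of how the helper defined below is meant to be called (``tfr`` here is a
+# hypothetical RawTFR / EpochsTFR / AverageTFR instance): it baseline-rescales the
+# data, crops to the requested time/frequency ranges, converts complex amplitudes
+# to power, optionally applies dB, and returns the trimmed arrays::
+#
+#     data, times, freqs = _prep_data_for_plot(
+#         tfr.data, tfr.times, tfr.freqs,
+#         tmin=0.0, tmax=0.5, fmin=4.0, fmax=30.0,
+#         baseline=(None, 0), mode="logratio", dB=False,
+#     )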
+@verbose +def _prep_data_for_plot( + data, + times, + freqs, + *, + tmin=None, + tmax=None, + fmin=None, + fmax=None, + baseline=None, + mode=None, + dB=False, + verbose=None, +): + # baseline + copy = baseline is not None + data = rescale(data, times, baseline, mode, copy=copy, verbose=verbose) + # crop times + time_mask = np.nonzero(_time_mask(times, tmin, tmax))[0] + times = times[time_mask] + # crop freqs + freq_mask = np.nonzero(_time_mask(freqs, fmin, fmax))[0] + freqs = freqs[freq_mask] + # crop data + data = data[..., freq_mask, :][..., time_mask] + # complex amplitude → real power; real-valued data is already power (or ITC) + if np.iscomplexobj(data): + data = (data * data.conj()).real + if dB: + data = 10 * np.log10(data) + return data, times, freqs + + +def _warn_deprecated_vmin_vmax(vlim, vmin, vmax): + if vmin is not None or vmax is not None: + warning = "Parameters `vmin` and `vmax` are deprecated, use `vlim` instead." + if vlim[0] is None and vlim[1] is None: + vlim = (vmin, vmax) + else: + warning += ( + " You've also provided a (non-default) value for `vlim`, " + "so `vmin` and `vmax` will be ignored." + ) + warn(warning, FutureWarning) + return vlim diff --git a/mne/utils/__init__.pyi b/mne/utils/__init__.pyi index 3e4d1292ee2..e22d8f6166c 100644 --- a/mne/utils/__init__.pyi +++ b/mne/utils/__init__.pyi @@ -41,6 +41,7 @@ __all__ = [ "_check_if_nan", "_check_info_inv", "_check_integer_or_list", + "_check_method_kwargs", "_check_on_missing", "_check_one_ch_type", "_check_option", @@ -239,6 +240,7 @@ from .check import ( _check_if_nan, _check_info_inv, _check_integer_or_list, + _check_method_kwargs, _check_on_missing, _check_one_ch_type, _check_option, diff --git a/mne/utils/_testing.py b/mne/utils/_testing.py index b2829917f59..f0e76c70e8a 100644 --- a/mne/utils/_testing.py +++ b/mne/utils/_testing.py @@ -365,3 +365,13 @@ def _click_ch_name(fig, ch_index=0, button=1): x = bbox.intervalx.mean() y = bbox.intervaly.mean() _fake_click(fig, fig.mne.ax_main, (x, y), xform="pix", button=button) + + +def _get_suptitle(fig): + """Get fig suptitle (shim for matplotlib < 3.8.0).""" + # TODO: obsolete when minimum MPL version is 3.8 + if check_version("matplotlib", "3.8"): + return fig.get_suptitle() + else: + # unreliable hack; should work in most tests as we rarely use `sup_{x,y}label` + return fig.texts[0].get_text() diff --git a/mne/utils/check.py b/mne/utils/check.py index b703317f9d0..80d87cafd2b 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -11,6 +11,7 @@ from builtins import input # noqa: UP029 from difflib import get_close_matches from importlib import import_module +from inspect import signature from pathlib import Path import numpy as np @@ -313,10 +314,10 @@ def _check_preload(inst, msg): from ..epochs import BaseEpochs from ..evoked import Evoked from ..source_estimate import _BaseSourceEstimate - from ..time_frequency import _BaseTFR + from ..time_frequency import BaseTFR from ..time_frequency.spectrum import BaseSpectrum - if isinstance(inst, (_BaseTFR, Evoked, BaseSpectrum, _BaseSourceEstimate)): + if isinstance(inst, (BaseTFR, Evoked, BaseSpectrum, _BaseSourceEstimate)): pass else: name = "epochs" if isinstance(inst, BaseEpochs) else "raw" @@ -914,6 +915,7 @@ def _check_all_same_channel_names(instances): def _check_combine(mode, valid=("mean", "median", "std"), axis=0): + # XXX TODO Possibly de-duplicate with _make_combine_callable of mne/viz/utils.py if mode == "mean": def fun(data): @@ -1244,3 +1246,19 @@ def _import_nibabel(why="use MRI files"): 
except ImportError as exp: raise exp.__class__(f"nibabel is required to {why}, got:\n{exp}") from None return nib + + +def _check_method_kwargs(func, kwargs, msg=None): + """Ensure **kwargs are compatible with the function they're passed to.""" + from .misc import _pl + + valid = list(signature(func).parameters) + is_invalid = np.isin(list(kwargs), valid, invert=True) + if is_invalid.any(): + invalid_kw = np.array(list(kwargs))[is_invalid].tolist() + s = _pl(invalid_kw) + if msg is None: + msg = f'function "{func}"' + raise TypeError( + f'Got unexpected keyword argument{s} {", ".join(invalid_kw)} ' f"for {msg}." + ) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index f5d7c4f4669..c82f9d74344 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -64,6 +64,61 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # A +tfr_arithmetics_return_template = """ +Returns +------- +tfr : instance of RawTFR | instance of EpochsTFR | instance of AverageTFR + {} +""" + +tfr_add_sub_template = """ +Parameters +---------- +other : instance of RawTFR | instance of EpochsTFR | instance of AverageTFR + The TFR instance to {}. Must have the same type as ``self``, and matching + ``.times`` and ``.freqs`` attributes. + +{} +""" + +tfr_mul_truediv_template = """ +Parameters +---------- +num : int | float + The number to {} by. + +{} +""" + +tfr_arithmetics_return = tfr_arithmetics_return_template.format( + "A new TFR instance, of the same type as ``self``." +) +tfr_inplace_arithmetics_return = tfr_arithmetics_return_template.format( + "The modified TFR instance." +) + +docdict["__add__tfr"] = tfr_add_sub_template.format("add", tfr_arithmetics_return) +docdict["__iadd__tfr"] = tfr_add_sub_template.format( + "add", tfr_inplace_arithmetics_return +) +docdict["__imul__tfr"] = tfr_mul_truediv_template.format( + "multiply", tfr_inplace_arithmetics_return +) +docdict["__isub__tfr"] = tfr_add_sub_template.format( + "subtract", tfr_inplace_arithmetics_return +) +docdict["__itruediv__tfr"] = tfr_mul_truediv_template.format( + "divide", tfr_inplace_arithmetics_return +) +docdict["__mul__tfr"] = tfr_mul_truediv_template.format( + "multiply", tfr_arithmetics_return +) +docdict["__sub__tfr"] = tfr_add_sub_template.format("subtract", tfr_arithmetics_return) +docdict["__truediv__tfr"] = tfr_mul_truediv_template.format( + "divide", tfr_arithmetics_return +) + + docdict["accept"] = """ accept : bool If True (default False), accept the license terms of this dataset. @@ -303,42 +358,67 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ _axes_base = """\ -{} : instance of Axes | {}None - The axes to plot to. If ``None``, a new :class:`~matplotlib.figure.Figure` - will be created{}. {}Default is ``None``. -""" -_axes_num = ( - "If :class:`~matplotlib.axes.Axes` are provided (either as a " - "single instance or a :class:`list` of axes), the number of axes " - "provided must {}." -) +{param} : instance of Axes | {allowed}None + The axes to plot into. If ``None``, a new :class:`~matplotlib.figure.Figure` + will be created{created}. {list_extra}{extra}Default is ``None``. +""" _axes_list = _axes_base.format( - "{}", "list of Axes | ", " with the correct number of axes", _axes_num + param="{param}", + allowed="list of Axes | ", + created=" with the correct number of axes", + list_extra="""If :class:`~matplotlib.axes.Axes` + are provided (either as a single instance or a :class:`list` of axes), + the number of axes provided must {must}. 
""", + extra="{extra}", +) +_match_chtypes_present_in = "match the number of channel types present in the {}object." +docdict["ax_plot_psd"] = _axes_list.format( + param="ax", must=_match_chtypes_present_in.format(""), extra="" +) +docdict["axes_cov_plot_topomap"] = _axes_list.format( + param="axes", must="be length 1", extra="" ) -_ch_types_present = "match the number of channel types present in the {}" "object." -docdict["ax_plot_psd"] = _axes_list.format("ax", _ch_types_present.format("")) -docdict["axes_cov_plot_topomap"] = _axes_list.format("axes", "be length 1") docdict["axes_evoked_plot_topomap"] = _axes_list.format( - "axes", "match the number of ``times`` provided (unless ``times`` is ``None``)" + param="axes", + must="match the number of ``times`` provided (unless ``times`` is ``None``)", + extra="", ) docdict["axes_montage"] = """ axes : instance of Axes | instance of Axes3D | None Axes to draw the sensors to. If ``kind='3d'``, axes must be an instance - of Axes3D. If None (default), a new axes will be created.""" + of Axes3D. If None (default), a new axes will be created. +""" docdict["axes_plot_projs_topomap"] = _axes_list.format( - "axes", "match the number of projectors" + param="axes", + must="match the number of projectors", + extra="", +) +docdict["axes_plot_topomap"] = _axes_base.format( + param="axes", + allowed="", + created="", + list_extra="", + extra="", ) -docdict["axes_plot_topomap"] = _axes_base.format("axes", "", "", "") docdict["axes_spectrum_plot"] = _axes_list.format( - "axes", _ch_types_present.format(":class:`~mne.time_frequency.Spectrum`") + param="axes", + must=_match_chtypes_present_in.format(":class:`~mne.time_frequency.Spectrum` "), + extra="", ) docdict["axes_spectrum_plot_topo"] = _axes_list.format( - "axes", - "be length 1 (for efficiency, subplots for each channel are simulated " + param="axes", + must="be length 1 (for efficiency, subplots for each channel are simulated " "within a single :class:`~matplotlib.axes.Axes` object)", + extra="", ) docdict["axes_spectrum_plot_topomap"] = _axes_list.format( - "axes", "match the length of ``bands``" + param="axes", must="match the length of ``bands``", extra="" +) +docdict["axes_tfr_plot"] = _axes_list.format( + param="axes", + must="match the number of picks", + extra="""If ``combine`` is not None, + ``axes`` must either be an instance of Axes, or a list of length 1. """, ) docdict["axis_facecolor"] = """\ @@ -396,11 +476,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): If a tuple ``(a, b)``, the interval is between ``a`` and ``b`` (in seconds), including the endpoints. If ``a`` is ``None``, the **beginning** of the data is used; and if ``b`` - is ``None``, it is set to the **end** of the interval. + is ``None``, it is set to the **end** of the data. If ``(None, None)``, the entire time interval is used. - .. note:: The baseline ``(a, b)`` includes both endpoints, i.e. all - timepoints ``t`` such that ``a <= t <= b``. + .. note:: + The baseline ``(a, b)`` includes both endpoints, i.e. all timepoints ``t`` + such that ``a <= t <= b``. """ docdict["baseline_epochs"] = f"""{_baseline_rescale_base} @@ -448,12 +529,21 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ +docdict["baseline_tfr_attr"] = """ +baseline : array-like, shape (2,) + The start and end times of the baseline period, in seconds.""" + + docdict["block"] = """\ block : bool Whether to halt program execution until the figure is closed. May not work on all systems / platforms. 
Defaults to ``False``. """ +docdict["border_topo"] = """ +border : str + Matplotlib border style to be used for each sensor plot. +""" docdict["border_topomap"] = """ border : float | 'mean' Value to extrapolate to on the topomap borders. If ``'mean'`` (default), @@ -560,6 +650,9 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): description=['Start', 'BAD_flux', 'BAD_noise'], ch_names=[[], ['MEG0111', 'MEG2563'], ['MEG1443']]) """ +docdict["ch_names_tfr_attr"] = """ +ch_names : list + The channel names.""" docdict["ch_type_set_eeg_reference"] = """ ch_type : list of str | str @@ -652,13 +745,19 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``pos_lims``, as the surface plot must show the magnitude. """ -docdict["cmap"] = """ -cmap : matplotlib colormap | str | None - The :class:`~matplotlib.colors.Colormap` to use. Defaults to ``None``, which - will use the matplotlib default colormap. +_cmap_template = """ +cmap : matplotlib colormap | str{allowed} + The :class:`~matplotlib.colors.Colormap` to use. If a :class:`str`, must be a + valid Matplotlib colormap name. Default is {default}. """ - -docdict["cmap_topomap"] = """ +docdict["cmap"] = _cmap_template.format( + allowed=" | None", + default="``None``, which will use the Matplotlib default colormap", +) +docdict["cmap_tfr_plot_topo"] = _cmap_template.format( + allowed="", default='``"RdBu_r"``' +) +docdict["cmap_topomap"] = """\ cmap : matplotlib colormap | (colormap, bool) | 'interactive' | None Colormap to use. If :class:`tuple`, the first value indicates the colormap to use and the second value is a boolean defining interactivity. In @@ -707,6 +806,15 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): white. """ +docdict["colorbar"] = """\ +colorbar : bool + Whether to add a colorbar to the plot. Default is ``True``. +""" +docdict["colorbar_tfr_plot_joint"] = """ +colorbar : bool + Whether to add a colorbar to the plot (for the topomap annotations). Not compatible + with user-defined ``axes``. Default is ``True``. +""" docdict["colorbar_topomap"] = """ colorbar : bool Plot a colorbar in the rightmost column of the figure. @@ -720,27 +828,29 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ _combine_template = """ -combine : 'mean' | {literals} | callable | None - How to aggregate across channels. If ``None``, {none}. If a string, +combine : 'mean' | {literals} | callable{none} + How to aggregate across channels. {none_sentence}If a string, ``"mean"`` uses :func:`numpy.mean`, {other_string}. If :func:`callable`, it must operate on an :class:`array ` of shape ``({shape})`` and return an array of shape - ``({return_shape})``. {example} - {notes}Defaults to ``None``. + ``({return_shape})``. {example}{notes}Defaults to {default}. """ _example = """For example:: combine = lambda data: np.median(data, axis=1) -""" + + """ # ← the 4 trailing spaces are intentional here! 
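The ``combine`` templates above pin down the callable's contract: for TFR plots it must map an array of shape ``(n_channels, n_freqs, n_times)`` to one of shape ``(n_freqs, n_times)``. A self-contained sketch of a custom reducer that satisfies this contract; the ``trimmed_mean`` name and the random data are illustrative assumptions, and passing it as ``tfr.plot(combine=trimmed_mean)`` is the intended (hypothetical) usage::

    import numpy as np

    def trimmed_mean(data):
        # drop the smallest and largest value across channels at each
        # (freq, time) point, then average the remaining channels
        lo, hi = np.sort(data, axis=0)[[0, -1]]
        return (data.sum(axis=0) - lo - hi) / (data.shape[0] - 2)

    rng = np.random.default_rng(0)
    data = rng.normal(size=(10, 5, 20))          # (n_channels, n_freqs, n_times)
    assert trimmed_mean(data).shape == (5, 20)   # (n_freqs, n_times), as required
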
_median_std_gfp = """``"median"`` computes the `marginal median `__, ``"std"`` uses :func:`numpy.std`, and ``"gfp"`` computes global field power for EEG channels and RMS amplitude for MEG channels""" +_none_default = dict(none=" | None", default="``None``") docdict["combine_plot_compare_evokeds"] = _combine_template.format( literals="'median' | 'std' | 'gfp'", - none="""channels are combined by + **_none_default, + none_sentence="""If ``None``, channels are combined by computing GFP/RMS, unless ``picks`` is a single channel (not channel type) - or ``axes="topo"``, in which cases no combining is performed""", + or ``axes="topo"``, in which cases no combining is performed. """, other_string=_median_std_gfp, shape="n_evokeds, n_channels, n_times", return_shape="n_evokeds, n_times", @@ -749,16 +859,54 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ) docdict["combine_plot_epochs_image"] = _combine_template.format( literals="'median' | 'std' | 'gfp'", - none="""channels are combined by + **_none_default, + none_sentence="""If ``None``, channels are combined by computing GFP/RMS, unless ``group_by`` is also ``None`` and ``picks`` is a list of specific channels (not channel types), in which case no combining - is performed and each channel gets its own figure""", + is performed and each channel gets its own figure. """, other_string=_median_std_gfp, shape="n_epochs, n_channels, n_times", return_shape="n_epochs, n_times", example=_example, notes="See Notes for further details. ", ) +docdict["combine_tfr_plot"] = _combine_template.format( + literals="'rms'", + **_none_default, + none_sentence="If ``None``, plot one figure per selected channel. ", + shape="n_channels, n_freqs, n_times", + return_shape="n_freqs, n_times", + other_string='``"rms"`` computes the root-mean-square', + example="", + notes="", +) +docdict["combine_tfr_plot_joint"] = _combine_template.format( + literals="'rms'", + none="", + none_sentence="", + shape="n_channels, n_freqs, n_times", + return_shape="n_freqs, n_times", + other_string='``"rms"`` computes the root-mean-square', + example="", + notes="", + default='``"mean"``', +) + +_comment_template = """ +comment : str{or_none} + Comment on the data, e.g., the experimental condition(s){avgd}.{extra}""" +docdict["comment_averagetfr"] = _comment_template.format( + or_none=" | None", + avgd="averaged", + extra="""Default is ``None`` + which is replaced with ``inst.comment`` (for :class:`~mne.Evoked` instances) + or a comma-separated string representation of the keys in ``inst.event_id`` + (for :class:`~mne.Epochs` instances).""", +) +docdict["comment_averagetfr_attr"] = _comment_template.format( + or_none="", avgd=" averaged", extra="" +) +docdict["comment_tfr_attr"] = _comment_template.format(or_none="", avgd="", extra="") docdict["compute_proj_ecg"] = """This function will: @@ -850,11 +998,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # D -_dB = """\ +_dB = """ dB : bool Whether to plot on a decibel-like scale. If ``True``, plots - 10 × log₁₀(spectral power){}.{} + 10 × log₁₀({quantity}){caveat}.{extra} """ +_ignored_if_normalize = " Ignored if ``normalize=True``." +_psd = "spectral power" docdict["dB_plot_psd"] = """\ dB : bool @@ -867,10 +1017,21 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``dB=True`` and ``estimate='amplitude'``. """ docdict["dB_plot_topomap"] = _dB.format( - " following the application of ``agg_fun``", " Ignored if ``normalize=True``." 
+ quantity=_psd, + caveat=" following the application of ``agg_fun``", + extra=_ignored_if_normalize, ) -docdict["dB_spectrum_plot"] = _dB.format("", "") -docdict["dB_spectrum_plot_topo"] = _dB.format("", " Ignored if ``normalize=True``.") +docdict["dB_spectrum_plot"] = _dB.format(quantity=_psd, caveat="", extra="") +docdict["dB_spectrum_plot_topo"] = _dB.format( + quantity=_psd, caveat="", extra=_ignored_if_normalize +) +docdict["dB_tfr_plot_topo"] = _dB.format(quantity="data", caveat="", extra="") + +_data_template = """ +data : ndarray, shape ({}) + The data. +""" +docdict["data_tfr"] = _data_template.format("n_channels, n_freqs, n_times") docdict["daysback_anonymize_info"] = """ daysback : int | None @@ -916,12 +1077,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ docdict["decim_tfr"] = """ -decim : int | slice, default 1 - To reduce memory usage, decimation factor after time-frequency - decomposition. +decim : int | slice + Decimation factor, applied *after* time-frequency decomposition. - - if `int`, returns ``tfr[..., ::decim]``. - - if `slice`, returns ``tfr[..., decim]``. + - if :class:`int`, returns ``tfr[..., ::decim]`` (keep only every Nth + sample along the time axis). + - if :class:`slice`, returns ``tfr[..., decim]`` (keep only the specified + slice along the time axis). .. note:: Decimation is done after convolutions and may create aliasing @@ -1002,8 +1164,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["drop_log"] = """ drop_log : tuple | None Tuple of tuple of strings indicating which epochs have been marked to - be ignored. -""" + be ignored.""" docdict["dtype_applyfun"] = """ dtype : numpy.dtype @@ -1151,11 +1312,21 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): and then the IDs must be the name(s) of the annotations to use. If None, all :term:`events` will be used and a dict is created with string integer names corresponding to the event id integers.""" - +_event_id_template = """ +event_id : dict{or_none} + Mapping from condition descriptions (strings) to integer event codes.{extra}""" +docdict["event_id_attr"] = _event_id_template.format(or_none="", extra="") docdict["event_id_ecg"] = """ event_id : int The index to assign to found ECG events. """ +docdict["event_id_epochstfr"] = _event_id_template.format( + or_none=" | None", + extra="""If ``None``, + all events in ``events`` will be included, and the ``event_id`` attribute + will be a :class:`dict` mapping a string version of each integer event ID + to the corresponding integer.""", +) docdict["event_repeated_epochs"] = """ event_repeated : str @@ -1167,19 +1338,28 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionadded:: 0.19 """ -docdict["events"] = """ -events : array of int, shape (n_events, 3) - The array of :term:`events`. The first column contains the event time in - samples, with :term:`first_samp` included. The third column contains the - event id.""" - -docdict["events_epochs"] = """ -events : array of int, shape (n_events, 3) - The array of :term:`events`. The first column contains the event time in - samples, with :term:`first_samp` included. The third column contains the - event id. 
- If some events don't match the events of interest as specified by - ``event_id``, they will be marked as ``IGNORED`` in the drop log.""" +_events_template = """ +events : ndarray of int, shape (n_events, 3){or_none} + The identity and timing of experimental events, around which the epochs were + created. See :term:`events` for more information.{extra} +""" +docdict["events"] = _events_template.format(or_none="", extra="") +docdict["events_attr"] = """ +events : ndarray of int, shape (n_events, 3) + The events array.""" +docdict["events_epochs"] = _events_template.format( + or_none="", + extra="""Events that don't match + the events of interest as specified by ``event_id`` will be marked as + ``IGNORED`` in the drop log.""", +) +docdict["events_epochstfr"] = _events_template.format( + or_none=" | None", + extra="""If ``None``, all integer + event codes are set to ``1`` (i.e., all epochs are assumed to be of the same + type) and their corresponding sample numbers are set as arbitrary, equally + spaced sample numbers with a step size of ``len(times)``.""", +) docdict["evoked_by_event_type_returns"] = """ evoked : instance of Evoked | list of Evoked @@ -1402,10 +1582,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): and if absent, falls back to ``'estimated'``. """ -docdict["fig_facecolor"] = """\ +docdict["fig_background"] = """ +fig_background : None | array + A background image for the figure. This must be a valid input to + :func:`matplotlib.pyplot.imshow`. Defaults to ``None``. +""" +docdict["fig_facecolor"] = """ fig_facecolor : str | tuple - A matplotlib-compatible color to use for the figure background. - Defaults to black. + A matplotlib-compatible color to use for the figure background. Defaults to black. """ docdict["filter_length"] = """ @@ -1511,6 +1695,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ) docdict["fmin_fmax_psd_topo"] = _fmin_fmax.format("``fmin=0, fmax=100``.") +docdict["fmin_fmax_tfr"] = _fmin_fmax.format( + """``None`` + which is equivalent to ``fmin=0, fmax=np.inf`` (spans all frequencies + present in the data).""" +) docdict["fmin_fmid_fmax"] = """ fmin : float @@ -1560,17 +1749,37 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): mass of the visible bounds. """ +docdict["font_color"] = """ +font_color : color + The color of tick labels in the colorbar. Defaults to white. +""" + docdict["forward_set_eeg_reference"] = """ forward : instance of Forward | None Forward solution to use. Only used with ``ref_channels='REST'``. .. versionadded:: 0.21 """ - -docdict["freqs_tfr"] = """ -freqs : array of float, shape (n_freqs,) - The frequencies of interest in Hz. -""" +_freqs_tfr_template = """ +freqs : array-like |{auto} None + The frequencies at which to compute the power estimates. + {stockwell} be an array of shape (n_freqs,). ``None`` (the + default) only works when using ``__setstate__`` and will raise an error otherwise. +""" +docdict["freqs_tfr"] = _freqs_tfr_template.format(auto="", stockwell="Must") +docdict["freqs_tfr_array"] = """ +freqs : ndarray, shape (n_freqs,) + The frequencies in Hz. +""" +docdict["freqs_tfr_attr"] = """ +freqs : array + Frequencies at which power has been computed.""" +docdict["freqs_tfr_epochs"] = _freqs_tfr_template.format( + auto=" 'auto' | ", + stockwell="""If ``method='stockwell'`` this must be a length 2 iterable specifying lowest + and highest frequencies, or ``'auto'`` (to use all available frequencies). 
+ For other methods, must""", # noqa E501 +) docdict["fullscreen"] = """ fullscreen : bool @@ -1660,17 +1869,28 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): (False, default). """ -_getitem_base = """\ +_getitem_spectrum_base = """ data : ndarray The selected spectral data. Shape will be - ``({}n_channels, n_freqs)`` for normal power spectra, - ``({}n_channels, n_freqs, n_segments)`` for unaggregated - Welch estimates, or ``({}n_channels, n_tapers, n_freqs)`` + ``({n_epo}n_channels, n_freqs)`` for normal power spectra, + ``({n_epo}n_channels, n_freqs, n_segments)`` for unaggregated + Welch estimates, or ``({n_epo}n_channels, n_tapers, n_freqs)`` for unaggregated multitaper estimates. """ -_fill_epochs = ["n_epochs, "] * 3 -docdict["getitem_epochspectrum_return"] = _getitem_base.format(*_fill_epochs) -docdict["getitem_spectrum_return"] = _getitem_base.format("", "", "") +_getitem_tfr_base = """ +data : ndarray + The selected time-frequency data. Shape will be + ``({n_epo}n_channels, n_freqs, n_times)`` for Morlet, Stockwell, and aggregated + (``output='power'``) multitaper methods, or + ``({n_epo}n_channels, n_tapers, n_freqs, n_times)`` for unaggregated + (``output='complex'``) multitaper method. +""" +n_epo = "n_epochs, " +docdict["getitem_epochspectrum_return"] = _getitem_spectrum_base.format(n_epo=n_epo) +docdict["getitem_epochstfr_return"] = _getitem_tfr_base.format(n_epo=n_epo) +docdict["getitem_spectrum_return"] = _getitem_spectrum_base.format(n_epo="") +docdict["getitem_tfr_return"] = _getitem_tfr_base.format(n_epo="") + docdict["group_by_browse"] = """ group_by : str @@ -1822,6 +2042,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): For more information, see :func:`mne.filter.construct_iir_filter`. """ +docdict["image_args"] = """ +image_args : dict | None + Keyword arguments to pass to :meth:`mne.time_frequency.AverageTFR.plot`. ``axes`` + and ``show`` are ignored. Defaults to ``None`` (i.e., and empty :class:`dict`). +""" + docdict["image_format_report"] = """ image_format : 'png' | 'svg' | 'gif' | None The image format to be used for the report, can be ``'png'``, @@ -1889,6 +2115,10 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): (e.g. :class:`mne.io.Raw`). """ +docdict["inst_tfr"] = """ +inst : instance of RawTFR, EpochsTFR, or AverageTFR +""" + docdict["int_order_maxwell"] = """ int_order : int Order of internal component of spherical expansion. @@ -1944,6 +2174,10 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Defaults to ``'matrix'``. """ +docdict["item"] = """ +item : int | slice | array-like | str +""" + # %% # J @@ -2071,12 +2305,15 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. versionchanged:: 0.21.0 Support for volume source estimates. """ - +docdict["layout_scale"] = """ +layout_scale : float + Scaling factor for adjusting the relative size of the layout on the canvas. +""" docdict["layout_spectrum_plot_topo"] = """\ layout : instance of Layout | None Layout instance specifying sensor positions (does not need to be specified for Neuromag data). If ``None`` (default), the layout is - inferred from the data. + inferred from the data (if possible). """ docdict["line_alpha_plot_psd"] = """\ @@ -2157,14 +2394,24 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): with the parameters given in ``mask_params``. Defaults to ``None``, equivalent to an array of all ``False`` elements. 
""" - +docdict["mask_alpha_tfr_plot"] = """ +mask_alpha : float + Relative opacity of the masked region versus the unmasked region, given as a + :class:`float` between 0 and 1 (i.e., 0 means masked areas are not visible at all). + Defaults to ``0.1``. +""" +docdict["mask_cmap_tfr_plot"] = """ +mask_cmap : matplotlib colormap | str | None + Colormap to use for masked areas of the plot. If a :class:`str`, must be a valid + Matplotlib colormap name. If None, ``cmap`` is used for both masked and unmasked + areas. Ignored if ``mask`` is ``None``. Default is ``'Greys'``. +""" docdict["mask_evoked_topomap"] = _mask_base.format( shape="(n_channels, n_times)", shape_appendix="-time combinations", example=" (useful for, e.g. marking which channels at which times a " "statistical test of the data reaches significance)", ) - docdict["mask_params_topomap"] = """ mask_params : dict | None Additional plotting parameters for plotting significant sensors. @@ -2173,11 +2420,25 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): dict(marker='o', markerfacecolor='w', markeredgecolor='k', linewidth=0, markersize=4) """ - docdict["mask_patterns_topomap"] = _mask_base.format( shape="(n_channels, n_patterns)", shape_appendix="-pattern combinations", example="" ) - +docdict["mask_style_tfr_plot"] = """ +mask_style : None | 'both' | 'contour' | 'mask' + How to distinguish the masked/unmasked regions of the plot. If ``"contour"``, a + line is drawn around the areas where the mask is ``True``. If ``"mask"``, areas + where the mask is ``False`` will be (partially) transparent, as determined by + ``mask_alpha``. If ``"both"``, both a contour and transparency are used. Default is + ``None``, which is silently ignored if ``mask`` is ``None`` and is interpreted like + ``"both"`` otherwise. +""" +docdict["mask_tfr_plot"] = """ +mask : ndarray | None + An :class:`array ` of :class:`boolean ` values, of the same + shape as the data. Data that corresponds to ``False`` entries in the mask are + plotted differently, as determined by ``mask_style``, ``mask_alpha``, and + ``mask_cmap``. Useful for, e.g., highlighting areas of statistical significance. +""" docdict["mask_topomap"] = _mask_base.format( shape="(n_channels,)", shape_appendix="(s)", example="" ) @@ -2250,19 +2511,26 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Added support for specifying alpha values as a dict. """ -docdict["metadata_epochs"] = """ +_metadata_attr_template = """ metadata : instance of pandas.DataFrame | None - A :class:`pandas.DataFrame` specifying metadata about each epoch. - If given, ``len(metadata)`` must equal ``len(events)``. The DataFrame - may only contain values of type (str | int | float | bool). - If metadata is given, then pandas-style queries may be used to select - subsets of data, see :meth:`mne.Epochs.__getitem__`. - When a subset of the epochs is created in this (or any other - supported) manner, the metadata object is subsetted accordingly, and - the row indices will be modified to match ``epochs.selection``. - - .. versionadded:: 0.16 -""" + A :class:`pandas.DataFrame` specifying metadata about each epoch{or_none}.{extra} +""" +_metadata_template = _metadata_attr_template.format( + or_none="", + extra=""" + If not ``None``, ``len(metadata)`` must equal ``len(events)``. For + save/load compatibility, the :class:`~pandas.DataFrame` may only contain + :class:`str`, :class:`int`, :class:`float`, and :class:`bool` values. 
+ If not ``None``, then pandas-style queries may be used to select + subsets of data, see :meth:`mne.Epochs.__getitem__`. When the {obj} object + is subsetted, the metadata is subsetted accordingly, and the row indices + will be modified to match ``{obj}.selection``.""", +) +docdict["metadata_attr"] = _metadata_attr_template.format( + or_none=" (or ``None``)", extra="" +) +docdict["metadata_epochs"] = _metadata_template.format(obj="Epochs") +docdict["metadata_epochstfr"] = _metadata_template.format(obj="EpochsTFR") docdict["method_fir"] = """ method : str @@ -2270,6 +2538,20 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): forward-backward filtering (via :func:`~scipy.signal.filtfilt`). """ +_method_kw_tfr_template = """ +**method_kw + Additional keyword arguments passed to the spectrotemporal estimation function + (e.g., ``n_cycles, use_fft, zero_mean`` for Morlet method{stockwell} + or ``n_cycles, use_fft, zero_mean, time_bandwidth`` for multitaper method). + See :func:`~mne.time_frequency.tfr_array_morlet`{stockwell_crossref} + and :func:`~mne.time_frequency.tfr_array_multitaper` for additional details. +""" + +docdict["method_kw_epochs_tfr"] = _method_kw_tfr_template.format( + stockwell=", ``n_fft, width`` for Stockwell method,", + stockwell_crossref=", :func:`~mne.time_frequency.tfr_array_stockwell`,", +) + docdict["method_kw_psd"] = """\ **method_kw Additional keyword arguments passed to the spectral estimation @@ -2280,7 +2562,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): :func:`~mne.time_frequency.psd_array_multitaper` for details. """ -_method_psd = r""" +docdict["method_kw_tfr"] = _method_kw_tfr_template.format( + stockwell="", stockwell_crossref="" +) + +_method_psd = """ method : ``'welch'`` | ``'multitaper'``{} Spectral estimation method. ``'welch'`` uses Welch's method :footcite:p:`Welch1967`, ``'multitaper'`` uses DPSS @@ -2303,6 +2589,29 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): :func:`scipy.signal.resample` and :func:`scipy.signal.resample_poly`, respectively. """ +_method_tfr_template = """ +method : ``'morlet'`` | ``'multitaper'``{literals} | None + Spectrotemporal power estimation method. ``'morlet'`` uses Morlet wavelets, + ``'multitaper'`` uses DPSS tapers :footcite:p:`Slepian1978`{cites}. ``None`` (the + default) only works when using ``__setstate__`` and will raise an error otherwise. +""" +docdict["method_tfr"] = _method_tfr_template.format(literals="", cites="") +docdict["method_tfr_array"] = """ +method : str | None + Comment on the method used to compute the data, e.g., ``"hilbert"``. + Default is ``None``. +""" +docdict["method_tfr_attr"] = """ +method : str + The method used to compute the spectra (e.g., ``"morlet"``, ``"multitaper"`` + or ``"stockwell"``). +""" +docdict["method_tfr_epochs"] = _method_tfr_template.format( + literals=" | ``'stockwell'``", + cites=", and ``'stockwell'`` uses the S-transform " + ":footcite:p:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006`", +) + docdict["mode_eltc"] = """ mode : str Extraction mode, see Notes. @@ -2322,6 +2631,23 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): n_comp first SVD components. 
""" +docdict["mode_tfr_plot"] = """ +mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio' + Perform baseline correction by + + - subtracting the mean of baseline values ('mean') (default) + - dividing by the mean of baseline values ('ratio') + - dividing by the mean of baseline values and taking the log + ('logratio') + - subtracting the mean of baseline values followed by dividing by + the mean of baseline values ('percent') + - subtracting the mean of baseline values and dividing by the + standard deviation of baseline values ('zscore') + - dividing by the mean of baseline values, taking the log, and + dividing by the standard deviation of log baseline values + ('zlogratio') +""" + docdict["montage"] = """ montage : None | str | DigMontage A montage containing channel positions. If a string or @@ -2446,6 +2772,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): names are plotted. """ +docdict["nave_tfr_attr"] = """ +nave : int + The number of epochs that were averaged to yield the result. This may reflect + epochs averaged *before* time-frequency analysis (as in + ``epochs.average(...).compute_tfr(...)``) or *after* time-frequency analysis (as + in ``epochs.compute_tfr(...).average(...)``). +""" docdict["nirx_notes"] = """ This function has only been tested with NIRScout and NIRSport devices, and with the NIRStar software version 15 and above and Aurora software @@ -2537,6 +2870,25 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): other things may also not work or be incorrect). """ +docdict["notes_timefreqs_tfr_plot_joint"] = """ +``timefreqs`` has three different modes: tuples, dicts, and auto. For (list of) tuple(s) +mode, each tuple defines a pair (time, frequency) in s and Hz on the TFR plot. +For example, to look at 10 Hz activity 1 second into the epoch and 3 Hz activity 300 ms +into the epoch, :: + + timefreqs=((1, 10), (.3, 3)) + +If provided as a dictionary, (time, frequency) tuples are keys and (time_window, +frequency_window) tuples are the values — indicating the width of the windows (centered +on the time and frequency indicated by the key) to be averaged over. For example, :: + + timefreqs={(1, 10): (0.1, 2)} + +would translate into a window that spans 0.95 to 1.05 seconds and 9 to 11 Hz. If +``None``, a single topomap will be plotted at the absolute peak across the +time-frequency representation. +""" + docdict["notes_tmax_included_by_default"] = """ Unlike Python slices, MNE time intervals by default include **both** their end points; ``crop(tmin, tmax)`` returns the interval @@ -2743,6 +3095,12 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Defaults to 'head'. """ +docdict["output_compute_tfr"] = """ +output : str + What kind of estimate to return. Allowed values are ``"complex"``, ``"phase"``, + and ``"power"``. Default is ``"power"``. +""" + docdict["overview_mode"] = """ overview_mode : str | None Can be "channels", "empty", or "hidden" to set the overview bar mode @@ -3331,6 +3689,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ ) +docdict["reject_by_annotation_tfr"] = """ +reject_by_annotation : bool + Whether to omit bad spans of data before spectrotemporal power + estimation. If ``True``, spans with annotations whose description + begins with ``bad`` will be represented with ``np.nan`` in the + time-frequency representation. +""" + _reject_common = """\ Reject epochs based on **maximum** peak-to-peak signal amplitude (PTP), i.e. 
the absolute difference between the lowest and the highest signal @@ -3546,6 +3912,10 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Iterable of indices of selected epochs. If ``None``, will be automatically generated, corresponding to all non-zero events. """ +docdict["selection_attr"] = """ +selection : ndarray + Array of indices of *selected* epochs (i.e., epochs that were not rejected, dropped, + or ignored).""" docdict["sensor_colors"] = """ sensor_colors : array-like of color | dict | None @@ -3619,6 +3989,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. footbibliography:: """ +docdict["sfreq_tfr_attr"] = """ +sfreq : int | float + The sampling frequency (read from ``info``).""" +docdict["shape_tfr_attr"] = """ +shape : tuple of int + The shape of the data.""" + docdict["show"] = """\ show : bool Show the figure if ``True``. @@ -4149,12 +4526,27 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ``time_viewer=True`` and ``separate_canvas=False``. """ +docdict["timefreqs"] = """ +timefreqs : None | list of tuple | dict of tuple + The time-frequency point(s) for which topomaps will be plotted. See Notes. +""" + +docdict["times"] = """ +times : ndarray, shape (n_times,) + The time values in seconds. +""" + docdict["title_none"] = """ title : str | None The title of the generated figure. If ``None`` (default), no title is displayed. """ - +docdict["title_tfr_plot"] = """ +title : str | 'auto' | None + Title for the plot. If ``"auto"``, will use the channel name (if ``combine`` is + ``None``) or state the number and method of combined channels used to generate the + plot. If ``None``, no title is shown. Default is ``None``. +""" docdict["tmax_raw"] = """ tmax : float End time of the raw data to use in seconds (cannot exceed data duration). @@ -4210,10 +4602,20 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): same thresholding as :func:`scipy.linalg.orth`. """ -docdict["topomap_kwargs"] = """ -topomap_kwargs : dict | None - Keyword arguments to pass to the topomap-generating functions. +_topomap_args_template = """ +{param} : dict | None + Keyword arguments to pass to {func}.{extra} """ +docdict["topomap_args"] = _topomap_args_template.format( + param="topomap_args", + func=":func:`mne.viz.plot_topomap`", + extra=" ``axes`` and ``show`` are ignored. If ``times`` is not in this dict, " + "automatic peak detection is used. Beyond that, if ``None``, no customizable " + "arguments will be passed. Defaults to ``None`` (i.e., an empty :class:`dict`).", +) +docdict["topomap_kwargs"] = _topomap_args_template.format( + param="topomap_kwargs", func="the topomap-generating functions", extra="" +) _trans_base = """\ If str, the path to the head<->MRI transform ``*-trans.fif`` file produced @@ -4382,46 +4784,87 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): valid string options. """ -_vlim = """ -vlim : tuple of length 2{} - Colormap limits to use. If a :class:`tuple` of floats, specifies the - lower and upper bounds of the colormap (in that order); providing - ``None`` for either entry will set the corresponding boundary at the - min/max of the data{}. {}{}{}Defaults to ``(None, None)``. -""" -_vlim_joint = _vlim.format( - " | 'joint'", - " (separately for each {0})", - "{1}", - "If ``vlim='joint'``, will compute the colormap limits jointly across " - "all {0}s of the same channel type, using the min/max of the data for " - "that channel type. 
", - "{2}", +_vlim = """\ +vlim : tuple of length 2{joint_param} + Lower and upper bounds of the colormap, typically a numeric value in the same + units as the data. {callable} + If both entries are ``None``, the bounds are set at {bounds}. + Providing ``None`` for just one entry will set the corresponding boundary at the + min/max of the data. {extra}Defaults to ``(None, None)``. +""" +_joint_param = ' | "joint"' +_callable_sentence = """Elements of the :class:`tuple` may also be callable functions + which take in a :class:`NumPy array ` and return a scalar. +""" +_bounds_symmetric = """± the maximum absolute value + of the data (yielding a colormap with midpoint at 0)""" +_bounds_minmax = "``(min(data), max(data))``" +_bounds_norm = "``(0, max(abs(data)))``" +_bounds_contingent = f"""{_bounds_symmetric}, or {_bounds_norm} + if the (possibly baselined) data are all-positive""" +_joint_sentence = """If ``vlim="joint"``, will compute the colormap limits + jointly across all {what}s of the same channel type (instead of separately + for each {what}), using the min/max of the data for that channel type. + {joint_extra}""" + +docdict["vlim_plot_topomap"] = _vlim.format( + joint_param="", callable="", bounds=_bounds_minmax, extra="" ) -_vlim_callable = ( - "Elements of the :class:`tuple` may also be callable functions which " - "take in a :class:`NumPy array ` and return a scalar. " +docdict["vlim_plot_topomap_proj"] = _vlim.format( + joint_param=_joint_param, + callable=_callable_sentence, + bounds=_bounds_contingent, + extra=_joint_sentence.format( + what="projector", + joint_extra='If vlim is ``"joint"``, ``info`` must not be ``None``. ', + ), ) - -docdict["vlim_plot_topomap"] = _vlim.format("", "", "", "", "") -docdict["vlim_plot_topomap_proj"] = _vlim_joint.format( - "projector", - _vlim_callable, - "If vlim is ``'joint'``, ``info`` must not be ``None``. ", +docdict["vlim_plot_topomap_psd"] = _vlim.format( + joint_param=_joint_param, + callable=_callable_sentence, + bounds=_bounds_contingent, + extra=_joint_sentence.format(what="topomap", joint_extra=""), +) +docdict["vlim_tfr_plot"] = _vlim.format( + joint_param="", callable="", bounds=_bounds_contingent, extra="" +) +docdict["vlim_tfr_plot_joint"] = _vlim.format( + joint_param="", + callable="", + bounds=_bounds_contingent, + extra="""To specify the colormap separately for the topomap annotations, + see ``topomap_args``. """, ) -docdict["vlim_plot_topomap_psd"] = _vlim_joint.format("topomap", _vlim_callable, "") -docdict["vmin_vmax_topomap"] = """ -vmin, vmax : float | callable | None +_vmin_vmax_template = """ +vmin, vmax : float | {allowed}None Lower and upper bounds of the colormap, in the same units as the data. - If ``vmin`` and ``vmax`` are both ``None``, they are set at ± the - maximum absolute value of the data (yielding a colormap with midpoint - at 0). If only one of ``vmin``, ``vmax`` is ``None``, will use - ``min(data)`` or ``max(data)``, respectively. If callable, should - accept a :class:`NumPy array ` of data and return a - float. + If ``vmin`` and ``vmax`` are both ``None``, the bounds are set at + {bounds}. If only one of ``vmin``, ``vmax`` is ``None``, will use + ``min(data)`` or ``max(data)``, respectively.{extra} """ +docdict["vmin_vmax_tfr_plot"] = """ +vmin, vmax : float | None + Lower and upper bounds of the colormap. See ``vlim``. + + .. deprecated:: 1.7 + ``vmin`` and ``vmax`` will be removed in version 1.8. + Use ``vlim`` parameter instead. 
+""" +# ↓↓↓ this one still used, needs helper func refactor before we can migrate to `vlim` +docdict["vmin_vmax_tfr_plot_topo"] = _vmin_vmax_template.format( + allowed="", bounds=_bounds_symmetric, extra="" +) +# ↓↓↓ this one still used in Evoked.animate_topomap(), should migrate to `vlim` +docdict["vmin_vmax_topomap"] = _vmin_vmax_template.format( + allowed="callable | ", + bounds=_bounds_symmetric, + extra=""" If callable, should accept + a :class:`NumPy array ` of data and return a :class:`float`.""", +) + + # %% # W @@ -4483,6 +4926,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): # %% # Y +docdict["yscale_tfr_plot"] = """ +yscale : 'auto' | 'linear' | 'log' + The scale of the y (frequency) axis. 'linear' gives linear y axis, 'log' gives + log-spaced y axis and 'auto' detects if frequencies are log-spaced and if so sets + the y axis to 'log'. Default is 'auto'. +""" + # %% # Z @@ -4554,12 +5004,12 @@ def copy_doc(source): Parameters ---------- source : function - Function to copy the docstring from + Function to copy the docstring from. Returns ------- wrapper : function - The decorated function + The decorated function. Examples -------- diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index f0fcf94de14..793e399a69f 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -80,13 +80,13 @@ def __getitem__(self, item): Parameters ---------- - item : slice, array-like, str, or list - See below for use cases. + item : int | slice | array-like | str + See Notes for use cases. Returns ------- epochs : instance of Epochs - See below for use cases. + The subset of epochs. Notes ----- @@ -197,10 +197,9 @@ def _getitem( `Epochs` or tuple(Epochs, np.ndarray) if `return_indices` is True subset of epochs (and optionally array with kept epoch indices) """ - data = self._data - self._data = None inst = self.copy() if copy else self - self._data = inst._data = data + if self._data is not None: + np.copyto(inst._data, self._data, casting="no") del self select = inst._item_to_select(item) @@ -681,10 +680,10 @@ def decimate(self, decim, offset=0, *, verbose=None): # appropriately filtered to avoid aliasing from ..epochs import BaseEpochs from ..evoked import Evoked - from ..time_frequency import AverageTFR, EpochsTFR + from ..time_frequency import BaseTFR # This should be the list of classes that inherit - _validate_type(self, (BaseEpochs, Evoked, EpochsTFR, AverageTFR), "inst") + _validate_type(self, (BaseEpochs, Evoked, BaseTFR), "inst") decim, offset, new_sfreq = _check_decim( self.info, decim, offset, check_filter=not hasattr(self, "freqs") ) @@ -755,7 +754,7 @@ def _prepare_write_metadata(metadata): """Convert metadata to JSON for saving.""" if metadata is not None: if not isinstance(metadata, list): - metadata = metadata.to_json(orient="records") + metadata = metadata.reset_index().to_json(orient="records") else: # Pandas DataFrame metadata = json.dumps(metadata) assert isinstance(metadata, str) @@ -772,5 +771,7 @@ def _prepare_read_metadata(metadata): assert isinstance(metadata, list) if pd: metadata = pd.DataFrame.from_records(metadata) + if "index" in metadata.columns: + metadata.set_index("index", inplace=True) assert isinstance(metadata, pd.DataFrame) return metadata diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index 9a7524505e7..2f09689917b 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -29,7 +29,12 @@ svd_flip, ) from ._logging import logger, verbose, warn -from .check import _ensure_int, _validate_type, check_random_state 
+from .check import ( + _check_pandas_installed, + _ensure_int, + _validate_type, + check_random_state, +) from .docs import fill_doc from .misc import _empty_hash @@ -255,9 +260,9 @@ def _get_inst_data(inst): from ..epochs import BaseEpochs from ..evoked import Evoked from ..io import BaseRaw - from ..time_frequency.tfr import _BaseTFR + from ..time_frequency.tfr import BaseTFR - _validate_type(inst, (BaseRaw, BaseEpochs, Evoked, _BaseTFR), "Instance") + _validate_type(inst, (BaseRaw, BaseEpochs, Evoked, BaseTFR), "Instance") if not inst.preload: inst.load_data() return inst._data @@ -776,6 +781,7 @@ def object_diff(a, b, pre="", *, allclose=False): diffs : str A string representation of the differences. """ + pd = _check_pandas_installed(strict=False) out = "" if type(a) != type(b): # Deal with NamedInt and NamedFloat @@ -835,6 +841,11 @@ def object_diff(a, b, pre="", *, allclose=False): c.eliminate_zeros() if c.nnz > 0: out += pre + (" sparse matrix a and b differ on %s " "elements" % c.nnz) + elif pd and isinstance(a, pd.DataFrame): + try: + pd.testing.assert_frame_equal(a, b) + except AssertionError: + out += pre + " DataFrame mismatch\n" elif hasattr(a, "__getstate__") and a.__getstate__() is not None: out += object_diff(a.__getstate__(), b.__getstate__(), pre, allclose=allclose) else: diff --git a/mne/utils/spectrum.py b/mne/utils/spectrum.py index 5abcb7e3378..67a68b344a7 100644 --- a/mne/utils/spectrum.py +++ b/mne/utils/spectrum.py @@ -1,3 +1,5 @@ +"""Utility functions for spectral and spectrotemporal analysis.""" + # License: BSD-3-Clause # Copyright the MNE-Python contributors. from inspect import currentframe, getargvalues, signature @@ -5,6 +7,26 @@ from ..utils import warn +def _get_instance_type_string(inst): + """Get string representation of the originating instance type.""" + from ..epochs import BaseEpochs + from ..evoked import Evoked, EvokedArray + from ..io import BaseRaw + + parent_classes = inst._inst_type.__bases__ + if BaseRaw in parent_classes: + inst_type_str = "Raw" + elif BaseEpochs in parent_classes: + inst_type_str = "Epochs" + elif inst._inst_type in (Evoked, EvokedArray): + inst_type_str = "Evoked" + else: + raise RuntimeError( + f"Unknown instance type {inst._inst_type} in {type(inst).__name__}" + ) + return inst_type_str + + def _pop_with_fallback(mapping, key, fallback_fun): """Pop from a dict and fallback to a function parameter's default value.""" fallback = signature(fallback_fun).parameters[key].default diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py index 5830c647edb..344572dcfc9 100644 --- a/mne/viz/tests/test_topo.py +++ b/mne/viz/tests/test_topo.py @@ -18,7 +18,7 @@ from mne import Epochs, compute_proj_evoked, read_cov, read_events from mne.channels import read_layout from mne.io import read_raw_fif -from mne.time_frequency.tfr import AverageTFR +from mne.time_frequency.tfr import AverageTFRArray from mne.utils import _record_warnings from mne.viz import ( _get_presser, @@ -309,18 +309,20 @@ def test_plot_tfr_topo(): data = np.random.RandomState(0).randn( len(epochs.ch_names), n_freqs, len(epochs.times) ) - tfr = AverageTFR(epochs.info, data, epochs.times, np.arange(n_freqs), nave) - plt.close("all") - fig = tfr.plot_topo( - baseline=(None, 0), mode="ratio", title="Average power", vmin=0.0, vmax=14.0 + tfr = AverageTFRArray( + info=epochs.info, + data=data, + times=epochs.times, + freqs=np.arange(n_freqs), + nave=nave, ) + plt.close("all") + fig = tfr.plot_topo(baseline=(None, 0), mode="ratio", vmin=0.0, vmax=14.0) # test 
complex tfr.data = tfr.data * (1 + 1j) plt.close("all") - fig = tfr.plot_topo( - baseline=(None, 0), mode="ratio", title="Average power", vmin=0.0, vmax=14.0 - ) + fig = tfr.plot_topo(baseline=(None, 0), mode="ratio", vmin=0.0, vmax=14.0) # test opening tfr by clicking num_figures_before = len(plt.get_fignums()) @@ -335,14 +337,23 @@ def test_plot_tfr_topo(): # nonuniform freqs freqs = np.logspace(*np.log10([3, 10]), num=3) - tfr = AverageTFR(epochs.info, data, epochs.times, freqs, nave) - fig = tfr.plot([4], baseline=(None, 0), mode="mean", vmax=14.0, show=False) + tfr = AverageTFRArray( + info=epochs.info, data=data, times=epochs.times, freqs=freqs, nave=nave + ) + fig = tfr.plot([4], baseline=(None, 0), mode="mean", vlim=(None, 14.0), show=False) assert fig[0].axes[0].get_yaxis().get_scale() == "log" # one timesample - tfr = AverageTFR(epochs.info, data[:, :, [0]], epochs.times[[1]], freqs, nave) + tfr = AverageTFRArray( + info=epochs.info, + data=data[:, :, [0]], + times=epochs.times[[1]], + freqs=freqs, + nave=nave, + ) + with _record_warnings(): # matplotlib equal left/right - tfr.plot([4], baseline=None, vmax=14.0, show=False, yscale="linear") + tfr.plot([4], baseline=None, vlim=(None, 14.0), show=False, yscale="linear") # one frequency bin, log scale required: as it doesn't make sense # to plot log scale for one value, we test whether yscale is set to linear diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index 2774e198fe8..3ac6bb108a2 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -44,7 +44,7 @@ from mne.datasets import testing from mne.io import RawArray, read_info, read_raw_fif from mne.preprocessing import compute_bridged_electrodes -from mne.time_frequency.tfr import AverageTFR +from mne.time_frequency.tfr import AverageTFRArray from mne.viz import plot_evoked_topomap, plot_projs_topomap, topomap from mne.viz.tests.test_raw import _proj_status from mne.viz.topomap import ( @@ -578,13 +578,21 @@ def test_plot_tfr_topomap(): data = rng.randn(len(picks), n_freqs, len(times)) # test complex numbers - tfr = AverageTFR(info, data * (1 + 1j), times, np.arange(n_freqs), nave) + tfr = AverageTFRArray( + info=info, + data=data * (1 + 1j), + times=times, + freqs=np.arange(n_freqs), + nave=nave, + ) tfr.plot_topomap( ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 ) # test real numbers - tfr = AverageTFR(info, data, times, np.arange(n_freqs), nave) + tfr = AverageTFRArray( + info=info, data=data, times=times, freqs=np.arange(n_freqs), nave=nave + ) tfr.plot_topomap( ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 ) diff --git a/mne/viz/topo.py b/mne/viz/topo.py index e23e60b9bca..11f6695e834 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -428,7 +428,6 @@ def _imshow_tfr( cnorm=None, ): """Show time-frequency map as two-dimensional image.""" - from matplotlib import pyplot as plt from matplotlib.widgets import RectangleSelector _check_option("yscale", yscale, ["auto", "linear", "log"]) @@ -460,7 +459,7 @@ def _imshow_tfr( if isinstance(colorbar, DraggableColorbar): cbar = colorbar.cbar # this happens with multiaxes case else: - cbar = plt.colorbar(mappable=img, ax=ax) + cbar = ax.get_figure().colorbar(mappable=img, ax=ax) if interactive_cmap: ax.CB = DraggableColorbar(cbar, img, kind="tfr_image", ch_type=None) ax.RS = RectangleSelector(ax, onselect=onselect) # reference must be kept diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index cca239f844d..5a6eac4f1ab 
100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -912,6 +912,7 @@ def _topomap_plot_sensors(pos_x, pos_y, sensors, ax): def _get_pos_outlines(info, picks, sphere, to_sphere=True): from ..channels.layout import _find_topomap_coords + picks = _picks_to_idx(info, picks, "all", exclude=(), allow_empty=False) ch_type = _get_plot_ch_type(pick_info(_simplify_info(info), picks), None) orig_sphere = sphere sphere, clip_origin = _adjust_meg_sphere(sphere, info, ch_type) @@ -1891,7 +1892,6 @@ def plot_tfr_topomap( tfr, ch_type, sphere=sphere ) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) - data = tfr.data[picks, :, :] # merging grads before rescaling makes ERDs visible @@ -1910,7 +1910,6 @@ def plot_tfr_topomap( itmin = idx[0] if tmax is not None: itmax = idx[-1] + 1 - # crop freqs ifmin, ifmax = None, None idx = np.where(_time_mask(tfr.freqs, fmin, fmax))[0] @@ -1918,8 +1917,7 @@ def plot_tfr_topomap( ifmax = idx[-1] + 1 data = data[:, ifmin:ifmax, itmin:itmax] - data = np.mean(np.mean(data, axis=2), axis=1)[:, np.newaxis] - + data = data.mean(axis=(1, 2))[:, np.newaxis] norm = False if np.min(data) < 0 else True vlim = _setup_vmin_vmax(data, *vlim, norm) cmap = _setup_cmap(cmap, norm=norm) diff --git a/mne/viz/utils.py b/mne/viz/utils.py index 9f622a2dd87..5d2f2d95617 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -2138,13 +2138,20 @@ def _set_title_multiple_electrodes( ch_type = _channel_type_prettyprint.get(ch_type, ch_type) if ch_type is None: ch_type = "sensor" - if len(ch_names) > 1: - ch_type += "s" - combine = combine.capitalize() if isinstance(combine, str) else "Combination" + ch_type = f"{ch_type}{_pl(ch_names)}" + if hasattr(combine, "func"): # functools.partial + combine = combine.func + if callable(combine): + combine = getattr(combine, "__name__", str(combine)) + if not isinstance(combine, str): + combine = "Combination" + # mean → Mean, but avoid RMS → Rms and GFP → Gfp + if combine[0].islower(): + combine = combine.capitalize() if all_: title = f"{combine} of {len(ch_names)} {ch_type}" elif len(ch_names) > max_chans and combine != "gfp": - logger.info("More than %i channels, truncating title ...", max_chans) + logger.info(f"More than {max_chans} channels, truncating title ...") title += f", ...\n({combine} of {len(ch_names)} {ch_type})" return title @@ -2373,10 +2380,16 @@ def _make_combine_callable( def _rms(data): return np.sqrt((data**2).mean(**kwargs)) + def _gfp(data): + return data.std(axis=axis, ddof=0) + + # make them play nice with _set_title_multiple_electrodes() + _rms.__name__ = "RMS" + _gfp.__name__ = "GFP" if "rms" in valid: combine_dict["rms"] = _rms if "gfp" in valid and ch_type == "eeg": - combine_dict["gfp"] = lambda data: data.std(axis=axis, ddof=0) + combine_dict["gfp"] = _gfp elif "gfp" in valid: combine_dict["gfp"] = _rms try: diff --git a/pyproject.toml b/pyproject.toml index 7bf34bf3fc8..23a2efeaf4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -140,7 +140,7 @@ test_extra = [ "mne-bids", ] -# Dependencies for building the docuemntation +# Dependencies for building the documentation doc = [ "sphinx>=6", "numpydoc", diff --git a/tutorials/intro/10_overview.py b/tutorials/intro/10_overview.py index 94b659444b3..2c9a68a1baf 100644 --- a/tutorials/intro/10_overview.py +++ b/tutorials/intro/10_overview.py @@ -309,8 +309,8 @@ # frequency content. 
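The tutorial hunks in this patch replace calls to the legacy ``tfr_morlet()`` function with the new ``Epochs.compute_tfr()`` method. A hedged sketch of that workflow on synthetic epochs; the channel names, array sizes, and the ``average=False`` / ``.average()`` round trip are illustrative assumptions, and running it requires a build of MNE-Python that includes these new methods::

    import numpy as np
    import mne

    rng = np.random.default_rng(0)
    info = mne.create_info(["MEG 001", "MEG 002"], sfreq=200.0, ch_types="grad")
    epochs = mne.EpochsArray(rng.normal(size=(5, 2, 200)), info)  # 5 epochs of 1 s
    freqs = np.arange(7, 30, 3)

    # per-epoch power: EpochsTFR with data of shape
    # (n_epochs, n_channels, n_freqs, n_times)
    epo_tfr = epochs.compute_tfr("morlet", freqs=freqs, n_cycles=2, average=False)
    assert epo_tfr.data.shape[:3] == (5, 2, len(freqs))

    # averaging afterwards should agree with computing average=True up front
    avg_tfr = epo_tfr.average()
    assert avg_tfr.data.shape[0] == 2
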
frequencies = np.arange(7, 30, 3) -power = mne.time_frequency.tfr_morlet( - aud_epochs, n_cycles=2, return_itc=False, freqs=frequencies, decim=3 +power = aud_epochs.compute_tfr( + "morlet", n_cycles=2, return_itc=False, freqs=frequencies, decim=3, average=True ) power.plot(["MEG 1332"]) diff --git a/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py b/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py index c32af4bcd97..0e7242e96d5 100644 --- a/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py +++ b/tutorials/stats-sensor-space/40_cluster_1samp_time_freq.py @@ -40,7 +40,6 @@ import mne from mne.datasets import sample from mne.stats import permutation_cluster_1samp_test -from mne.time_frequency import tfr_morlet # %% # Set parameters @@ -92,8 +91,8 @@ freqs = np.arange(8, 40, 2) # run the TFR decomposition -tfr_epochs = tfr_morlet( - epochs, +tfr_epochs = epochs.compute_tfr( + "morlet", freqs, n_cycles=4.0, decim=decim, diff --git a/tutorials/stats-sensor-space/50_cluster_between_time_freq.py b/tutorials/stats-sensor-space/50_cluster_between_time_freq.py index 3ced6a82463..0b4078ec883 100644 --- a/tutorials/stats-sensor-space/50_cluster_between_time_freq.py +++ b/tutorials/stats-sensor-space/50_cluster_between_time_freq.py @@ -32,7 +32,6 @@ import mne from mne.datasets import sample from mne.stats import permutation_cluster_test -from mne.time_frequency import tfr_morlet print(__doc__) @@ -104,24 +103,17 @@ decim = 2 freqs = np.arange(7, 30, 3) # define frequencies of interest n_cycles = 1.5 - -tfr_epochs_1 = tfr_morlet( - epochs_condition_1, - freqs, +tfr_kwargs = dict( + method="morlet", + freqs=freqs, n_cycles=n_cycles, decim=decim, return_itc=False, average=False, ) -tfr_epochs_2 = tfr_morlet( - epochs_condition_2, - freqs, - n_cycles=n_cycles, - decim=decim, - return_itc=False, - average=False, -) +tfr_epochs_1 = epochs_condition_1.compute_tfr(**tfr_kwargs) +tfr_epochs_2 = epochs_condition_2.compute_tfr(**tfr_kwargs) tfr_epochs_1.apply_baseline(mode="ratio", baseline=(None, 0)) tfr_epochs_2.apply_baseline(mode="ratio", baseline=(None, 0)) diff --git a/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py b/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py index 202c660575a..19a90decea8 100644 --- a/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py +++ b/tutorials/stats-sensor-space/70_cluster_rmANOVA_time_freq.py @@ -36,7 +36,6 @@ import mne from mne.datasets import sample from mne.stats import f_mway_rm, f_threshold_mway_rm, fdr_correction -from mne.time_frequency import tfr_morlet print(__doc__) @@ -105,8 +104,8 @@ # --------------------------------------------- epochs_power = list() for condition in [epochs[k] for k in event_id]: - this_tfr = tfr_morlet( - condition, + this_tfr = condition.compute_tfr( + "morlet", freqs, n_cycles=n_cycles, decim=decim, diff --git a/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py b/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py index fedd88a568f..2ba8c55bf3d 100644 --- a/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py +++ b/tutorials/stats-sensor-space/75_cluster_ftest_spatiotemporal.py @@ -41,7 +41,6 @@ from mne.channels import find_ch_adjacency from mne.datasets import sample from mne.stats import combine_adjacency, spatio_temporal_cluster_test -from mne.time_frequency import tfr_morlet from mne.viz import plot_compare_evokeds # %% @@ -269,9 +268,9 @@ epochs_power = list() for condition in [epochs[k] for k in ("Aud/L", "Vis/L")]: - this_tfr = 
tfr_morlet( - condition, - freqs, + this_tfr = condition.compute_tfr( + method="morlet", + freqs=freqs, n_cycles=n_cycles, decim=decim, average=False, diff --git a/tutorials/time-freq/20_sensors_time_frequency.py b/tutorials/time-freq/20_sensors_time_frequency.py index 247fdddfab1..9175e700041 100644 --- a/tutorials/time-freq/20_sensors_time_frequency.py +++ b/tutorials/time-freq/20_sensors_time_frequency.py @@ -10,7 +10,7 @@ We will use this dataset: :ref:`somato-dataset`. It contains so-called event related synchronizations (ERS) / desynchronizations (ERD) in the beta band. -""" +""" # noqa D400 # Authors: Alexandre Gramfort # Stefan Appelhoff # Richard Höchenberger @@ -24,7 +24,6 @@ import mne from mne.datasets import somato -from mne.time_frequency import tfr_morlet # %% # Set parameters @@ -190,14 +189,13 @@ # define frequencies of interest (log-spaced) freqs = np.logspace(*np.log10([6, 35]), num=8) n_cycles = freqs / 2.0 # different number of cycle per frequency -power, itc = tfr_morlet( - epochs, +power, itc = epochs.compute_tfr( + method="morlet", freqs=freqs, n_cycles=n_cycles, - use_fft=True, + average=True, return_itc=True, decim=3, - n_jobs=None, ) # %% @@ -210,7 +208,7 @@ # You can also select a portion in the time-frequency plane to # obtain a topomap for a certain time-frequency region. power.plot_topo(baseline=(-0.5, 0), mode="logratio", title="Average power") -power.plot([82], baseline=(-0.5, 0), mode="logratio", title=power.ch_names[82]) +power.plot(picks=[82], baseline=(-0.5, 0), mode="logratio", title=power.ch_names[82]) fig, axes = plt.subplots(1, 2, figsize=(7, 4), layout="constrained") topomap_kw = dict(