Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updating TFR classes #11282

Merged
merged 53 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
816d75a
overhaul TFR classes
drammock Oct 25, 2022
8e0c76c
fixup after rebase
drammock Feb 13, 2024
412c151
don't test for deprecations where there aren't any
drammock Feb 13, 2024
01c39d3
make test less hacky
drammock Feb 13, 2024
b301d1d
avoid legacy in tuts/examples
drammock Feb 13, 2024
2969301
coverage; simplify parametrize config
drammock Feb 13, 2024
6cdf01c
fix old/compat tests
drammock Feb 13, 2024
b711529
fix doc build errors
drammock Feb 13, 2024
b5819e7
misc typo
drammock Feb 13, 2024
2a8305f
rename test
drammock Feb 13, 2024
762fc5f
remove redundant test
drammock Feb 13, 2024
80494eb
fix metadata round-trip
drammock Feb 13, 2024
d31aeb3
more tutorial/example fixes
drammock Feb 14, 2024
e630745
fix import problem
drammock Feb 14, 2024
212da92
WIP allow specifying output in compute_tfr
drammock Feb 14, 2024
90be26c
test post-decimation
drammock Feb 14, 2024
1e30e47
remove redundant test
drammock Feb 15, 2024
afc478a
handle complex output
drammock Feb 16, 2024
9d2b720
cleanups after self-review
drammock Feb 28, 2024
77137b5
make BaseTFR public
drammock Feb 28, 2024
c0223d1
undo introduced bug
drammock Feb 28, 2024
f610de5
remove cruft
drammock Feb 28, 2024
3fd8b20
docstring/comment fixes
drammock Feb 28, 2024
8a34570
fix backcompat
drammock Feb 28, 2024
f766fc0
move test to minimize diff
drammock Feb 28, 2024
fb4977a
fix decoding test fail
drammock Feb 28, 2024
5ee3263
fix docstring test
drammock Feb 28, 2024
91fa2f1
fix docstrings; param deprecations
drammock Feb 29, 2024
c6d7218
fix test
drammock Feb 29, 2024
3c30d1b
revert deprecation of topo and joint plot titles
drammock Feb 29, 2024
8f7c4a1
Apply suggestions from code review [ci skip]
drammock Mar 1, 2024
a303b38
remove option average="auto"
drammock Mar 1, 2024
e35e79e
obligate kwarg now
drammock Mar 1, 2024
d87aee8
Merge branch 'main' into spectrogram-class
larsoner Mar 4, 2024
e09d3c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
496958f
Merge branch 'main' into spectrogram-class
larsoner Mar 13, 2024
00a7e71
fix docstring: freqs param
drammock Mar 14, 2024
2a51331
fix docstring: method_kw param
drammock Mar 14, 2024
aa8dc6e
remove space before cite marker
drammock Mar 14, 2024
17f1891
add comment re: vlims [ci skip]
drammock Mar 14, 2024
4ad2812
change zero_mean default (with deprecation cycle)
drammock Mar 18, 2024
f5eb407
remove backslash-escape (breaks import nesting test)
drammock Mar 18, 2024
f4c98b7
avoid futurewarning in tests
drammock Mar 18, 2024
69066c9
Merge remote-tracking branch 'upstream/main' into spectrogram-class
drammock Mar 18, 2024
891ffcb
fix docdict order
drammock Mar 18, 2024
1723c2f
update example to use TFRArray.plot() instead of custom MPL code
drammock Mar 18, 2024
7b27c86
warn once, not twice
drammock Mar 20, 2024
2c9dcc6
don't set norm empirically if user passed vlim
drammock Mar 20, 2024
567e7f3
add TODO comment for future PR
drammock Mar 20, 2024
ae4c5ec
fix hilbert plots in example
drammock Mar 20, 2024
f76e6b4
less confusing fallback repr
drammock Mar 20, 2024
5ea9cc7
possible improvement to hilbert example
drammock Mar 20, 2024
714c14b
Merge remote-tracking branch 'upstream/main' into spectrogram-class
drammock Mar 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/api/time_frequency.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ Time-Frequency
:toctree: ../generated/

AverageTFR
AverageTFRArray
BaseTFR
EpochsTFR
EpochsTFRArray
RawTFR
RawTFRArray
CrossSpectralDensity
Spectrum
SpectrumArray
Expand Down
1 change: 1 addition & 0 deletions doc/changes/devel/11282.apichange.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The default value of the ``zero_mean`` parameter of :func:`mne.time_frequency.tfr_array_morlet` will change from ``False`` to ``True`` in version 1.8, for consistency with related functions. By `Daniel McCloy`_.
1 change: 1 addition & 0 deletions doc/changes/devel/11282.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixes to interactivity in time-frequency objects: the rectangle selector now works on TFR image plots of gradiometer data; and in ``TFR.plot_joint()`` plots, the colormap limits of interactively-generated topomaps match the colormap limits of the main plot. By `Daniel McCloy`_.
1 change: 1 addition & 0 deletions doc/changes/devel/11282.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
New class :class:`mne.time_frequency.RawTFR` and new methods :meth:`mne.io.Raw.compute_tfr`, :meth:`mne.Epochs.compute_tfr`, and :meth:`mne.Evoked.compute_tfr`. These new methods supersede functions :func:`mne.time_frequency.tfr_morlet`, and :func:`mne.time_frequency.tfr_multitaper`, and :func:`mne.time_frequency.tfr_stockwell`, which are now considered "legacy" functions. By `Daniel McCloy`_.
4 changes: 4 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,11 @@
"EvokedArray": "mne.EvokedArray",
"BiHemiLabel": "mne.BiHemiLabel",
"AverageTFR": "mne.time_frequency.AverageTFR",
"AverageTFRArray": "mne.time_frequency.AverageTFRArray",
"EpochsTFR": "mne.time_frequency.EpochsTFR",
"EpochsTFRArray": "mne.time_frequency.EpochsTFRArray",
"RawTFR": "mne.time_frequency.RawTFR",
"RawTFRArray": "mne.time_frequency.RawTFRArray",
"Raw": "mne.io.Raw",
"ICA": "mne.preprocessing.ICA",
"Covariance": "mne.Covariance",
Expand Down
18 changes: 10 additions & 8 deletions examples/decoding/decoding_csp_timefreq.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.io import concatenate_raws, read_raw_edf
from mne.time_frequency import AverageTFR
from mne.time_frequency import AverageTFRArray

# %%
# Set parameters and read data
Expand Down Expand Up @@ -173,13 +173,15 @@
# Plot time-frequency results

# Set up time frequency object
av_tfr = AverageTFR(
create_info(["freq"], sfreq),
tf_scores[np.newaxis, :],
centered_w_times,
freqs[1:],
1,
av_tfr = AverageTFRArray(
info=create_info(["freq"], sfreq),
data=tf_scores[np.newaxis, :],
times=centered_w_times,
freqs=freqs[1:],
nave=1,
)

chance = np.mean(y) # set chance level to white in the plot
av_tfr.plot([0], vmin=chance, title="Time-Frequency Decoding Scores", cmap=plt.cm.Reds)
av_tfr.plot(
[0], vlim=(chance, None), title="Time-Frequency Decoding Scores", cmap=plt.cm.Reds
)
6 changes: 3 additions & 3 deletions examples/inverse/dics_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import mne
from mne.beamformer import apply_dics_tfr_epochs, make_dics
from mne.datasets import somato
from mne.time_frequency import csd_tfr, tfr_morlet
from mne.time_frequency import csd_tfr

print(__doc__)

Expand Down Expand Up @@ -67,8 +67,8 @@
# decomposition for each epoch. We must pass ``output='complex'`` if we wish to
# use this TFR later with a DICS beamformer. We also pass ``average=False`` to
# compute the TFR for each individual epoch.
epochs_tfr = tfr_morlet(
epochs, freqs, n_cycles=5, return_itc=False, output="complex", average=False
epochs_tfr = epochs.compute_tfr(
"morlet", freqs, n_cycles=5, return_itc=False, output="complex", average=False
)

# crop either side to use a buffer to remove edge artifact
Expand Down
5 changes: 2 additions & 3 deletions examples/time_frequency/time_frequency_erds.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from mne.datasets import eegbci
from mne.io import concatenate_raws, read_raw_edf
from mne.stats import permutation_cluster_1samp_test as pcluster_test
from mne.time_frequency import tfr_multitaper

# %%
# First, we load and preprocess the data. We use runs 6, 10, and 14 from
Expand Down Expand Up @@ -96,8 +95,8 @@

# %%
# Finally, we perform time/frequency decomposition over all epochs.
tfr = tfr_multitaper(
epochs,
tfr = epochs.compute_tfr(
method="multitaper",
freqs=freqs,
n_cycles=freqs,
use_fft=True,
Expand Down
80 changes: 38 additions & 42 deletions examples/time_frequency/time_frequency_simulated.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,8 @@
from matplotlib import pyplot as plt

from mne import Epochs, create_info
from mne.baseline import rescale
from mne.io import RawArray
from mne.time_frequency import (
AverageTFR,
tfr_array_morlet,
tfr_morlet,
tfr_multitaper,
tfr_stockwell,
)
from mne.viz import centers_to_edges
from mne.time_frequency import AverageTFRArray, EpochsTFRArray, tfr_array_morlet

print(__doc__)

Expand Down Expand Up @@ -112,21 +104,21 @@
"Sim: Less time smoothing,\nmore frequency smoothing",
],
):
power = tfr_multitaper(
epochs,
power = epochs.compute_tfr(
method="multitaper",
freqs=freqs,
n_cycles=n_cycles,
time_bandwidth=time_bandwidth,
return_itc=False,
average=True,
)
ax.set_title(title)
# Plot results. Baseline correct based on first 100 ms.
power.plot(
[0],
baseline=(0.0, 0.1),
mode="mean",
vmin=vmin,
vmax=vmax,
vlim=(vmin, vmax),
axes=ax,
show=False,
colorbar=False,
Expand All @@ -146,7 +138,7 @@
fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained")
fmin, fmax = freqs[[0, -1]]
for width, ax in zip((0.2, 0.7, 3.0), axs):
power = tfr_stockwell(epochs, fmin=fmin, fmax=fmax, width=width)
power = epochs.compute_tfr(method="stockwell", freqs=(fmin, fmax), width=width)
power.plot(
[0], baseline=(0.0, 0.1), mode="mean", axes=ax, show=False, colorbar=False
)
Expand All @@ -164,13 +156,14 @@
fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained")
all_n_cycles = [1, 3, freqs / 2.0]
for n_cycles, ax in zip(all_n_cycles, axs):
power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False)
power = epochs.compute_tfr(
method="morlet", freqs=freqs, n_cycles=n_cycles, return_itc=False, average=True
)
power.plot(
[0],
baseline=(0.0, 0.1),
mode="mean",
vmin=vmin,
vmax=vmax,
vlim=(vmin, vmax),
axes=ax,
show=False,
colorbar=False,
Expand All @@ -190,7 +183,9 @@
fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained")
bandwidths = [1.0, 2.0, 4.0]
for bandwidth, ax in zip(bandwidths, axs):
data = np.zeros((len(ch_names), freqs.size, epochs.times.size), dtype=complex)
data = np.zeros(
(len(epochs), len(ch_names), freqs.size, epochs.times.size), dtype=complex
)
for idx, freq in enumerate(freqs):
# Filter raw data and re-epoch to avoid the filter being longer than
# the epoch data for low frequencies and short epochs, such as here.
Expand All @@ -210,17 +205,13 @@
epochs_hilb = Epochs(
raw_filter, events, tmin=0, tmax=n_times / sfreq, baseline=(0, 0.1)
)
tfr_data = epochs_hilb.get_data()
tfr_data = tfr_data * tfr_data.conj() # compute power
tfr_data = np.mean(tfr_data, axis=0) # average over epochs
data[:, idx] = tfr_data
power = AverageTFR(info, data, epochs.times, freqs, nave=n_epochs)
power.plot(
data[:, :, idx] = epochs_hilb.get_data()
power = EpochsTFRArray(epochs.info, data, epochs.times, freqs, method="hilbert")
power.average().plot(
[0],
baseline=(0.0, 0.1),
mode="mean",
vmin=-0.1,
vmax=0.1,
vlim=(0, 0.1),
axes=ax,
show=False,
colorbar=False,
Expand All @@ -241,17 +232,16 @@
# :class:`mne.time_frequency.EpochsTFR` is returned.

n_cycles = freqs / 2.0
power = tfr_morlet(
epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False, average=False
power = epochs.compute_tfr(
method="morlet", freqs=freqs, n_cycles=n_cycles, return_itc=False, average=False
)
print(type(power))
avgpower = power.average()
avgpower.plot(
[0],
baseline=(0.0, 0.1),
mode="mean",
vmin=vmin,
vmax=vmax,
vlim=(vmin, vmax),
title="Using Morlet wavelets and EpochsTFR",
show=False,
)
Expand All @@ -260,23 +250,29 @@
# Operating on arrays
# -------------------
#
# MNE also has versions of the functions above which operate on numpy arrays
# instead of MNE objects. They expect inputs of the shape
# ``(n_epochs, n_channels, n_times)``. They will also return a numpy array
# of shape ``(n_epochs, n_channels, n_freqs, n_times)``.
# MNE-Python also has functions that operate on :class:`NumPy arrays <numpy.ndarray>`
# instead of MNE-Python objects. These are :func:`~mne.time_frequency.tfr_array_morlet`
# and :func:`~mne.time_frequency.tfr_array_multitaper`. They expect inputs of the shape
# ``(n_epochs, n_channels, n_times)`` and return an array of shape
# ``(n_epochs, n_channels, n_freqs, n_times)`` (or optionally, can collapse the epochs
# dimension if you want average power or inter-trial coherence; see ``output`` param).

power = tfr_array_morlet(
epochs.get_data(),
sfreq=epochs.info["sfreq"],
freqs=freqs,
n_cycles=n_cycles,
output="avg_power",
zero_mean=False,
)
# Put it into a TFR container for easy plotting
tfr = AverageTFRArray(
info=epochs.info, data=power, times=epochs.times, freqs=freqs, nave=len(epochs)
)
tfr.plot(
baseline=(0.0, 0.1),
picks=[0],
mode="mean",
vlim=(vmin, vmax),
title="TFR calculated on a NumPy array",
)
# Baseline the output
rescale(power, epochs.times, (0.0, 0.1), mode="mean", copy=False)
fig, ax = plt.subplots(layout="constrained")
x, y = centers_to_edges(epochs.times * 1000, freqs)
mesh = ax.pcolormesh(x, y, power[0], cmap="RdBu_r", vmin=vmin, vmax=vmax)
ax.set_title("TFR calculated on a numpy array")
ax.set(ylim=freqs[[0, -1]], xlabel="Time (ms)")
fig.colorbar(mesh)
4 changes: 2 additions & 2 deletions mne/beamformer/tests/test_dics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from mne.io import read_info
from mne.proj import compute_proj_evoked, make_projector
from mne.surface import _compute_nearest
from mne.time_frequency import CrossSpectralDensity, EpochsTFR, csd_morlet, csd_tfr
from mne.time_frequency import CrossSpectralDensity, EpochsTFRArray, csd_morlet, csd_tfr
from mne.time_frequency.csd import _sym_mat_to_vector
from mne.transforms import apply_trans, invert_transform
from mne.utils import catch_logging, object_diff
Expand Down Expand Up @@ -727,7 +727,7 @@ def test_apply_dics_tfr(return_generator):
data = rng.random((n_epochs, n_chans, len(freqs), n_times))
data *= 1e-6
data = data + data * 1j # add imag. component to simulate phase
epochs_tfr = EpochsTFR(info, data, times=times, freqs=freqs)
epochs_tfr = EpochsTFRArray(info=info, data=data, times=times, freqs=freqs)

# Create a DICS beamformer and convert the EpochsTFR to source space.
csd = csd_tfr(epochs_tfr)
Expand Down
10 changes: 3 additions & 7 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ def equalize_channels(instances, copy=True, verbose=None):
from ..evoked import Evoked
from ..forward import Forward
from ..io import BaseRaw
from ..time_frequency import CrossSpectralDensity, _BaseTFR
from ..time_frequency import BaseTFR, CrossSpectralDensity

# Instances need to have a `ch_names` attribute and a `pick_channels`
# method that supports `ordered=True`.
allowed_types = (
BaseRaw,
BaseEpochs,
Evoked,
_BaseTFR,
BaseTFR,
Forward,
Covariance,
CrossSpectralDensity,
Expand Down Expand Up @@ -607,8 +607,6 @@ def drop_channels(self, ch_names, on_missing="raise"):
def _pick_drop_channels(self, idx, *, verbose=None):
# avoid circular imports
from ..io import BaseRaw
from ..time_frequency import AverageTFR, EpochsTFR
from ..time_frequency.spectrum import BaseSpectrum

msg = "adding, dropping, or reordering channels"
if isinstance(self, BaseRaw):
Expand All @@ -633,10 +631,8 @@ def _pick_drop_channels(self, idx, *, verbose=None):
if mat is not None:
setattr(self, key, mat[idx][:, idx])

if isinstance(self, BaseSpectrum):
if hasattr(self, "_dims"): # Spectrum and "new-style" TFRs
axis = self._dims.index("channel")
elif isinstance(self, (AverageTFR, EpochsTFR)):
axis = -3
else: # All others (Evoked, Epochs, Raw) have chs axis=-2
axis = -2
if hasattr(self, "_data"): # skip non-preloaded Raw
Expand Down
Loading
Loading