From a07af96be8be106b3ba1d503cf3e74fbf84b5dde Mon Sep 17 00:00:00 2001 From: Daniel McCloy Date: Tue, 30 Jan 2024 15:30:06 -0600 Subject: [PATCH] copy over all the new tests [circle full] --- mne/conftest.py | 19 ++ mne/time_frequency/tests/test_spectrum.py | 8 +- mne/time_frequency/tests/test_tfr.py | 297 +++++++++++++++++++++- 3 files changed, 321 insertions(+), 3 deletions(-) diff --git a/mne/conftest.py b/mne/conftest.py index 787579f990b..658dd19229f 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -390,6 +390,25 @@ def epochs_spectrum(): return _get_epochs().load_data().compute_psd() +@pytest.fixture() +def epochs_tfr(): + """Get an EpochsTFR computed from mne.io.tests.data.""" + epochs = _get_epochs().load_data() + return epochs.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5)) + + +@pytest.fixture() +def average_tfr(full_evoked): + """Get an AverageTFR computed from mne.io.tests.data.""" + return full_evoked.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5)) + + +@pytest.fixture() +def raw_tfr(raw): + """Get a RawTFR computed from mne.io.tests.data.""" + return raw.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5)) + + @pytest.fixture() def epochs_empty(): """Get empty epochs from mne.io.tests.data.""" diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index 26c18529143..7e235c0128f 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -124,11 +124,15 @@ def test_n_welch_windows(raw): ) -def _get_inst(inst, request, evoked): +def _get_inst(inst, request, evoked=None, average_tfr=None): # ↓ XXX workaround: # ↓ parametrized fixtures are not accessible via request.getfixturevalue # ↓ https://github.com/pytest-dev/pytest/issues/4666#issuecomment-456593913 - return evoked if inst == "evoked" else request.getfixturevalue(inst) + if inst == "evoked": + return evoked + elif inst == "average_tfr": + return average_tfr + return request.getfixturevalue(inst) @pytest.mark.parametrize("inst", ("raw", "epochs", "evoked")) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index e6668e0d088..b030124dfa0 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -8,7 +8,12 @@ import matplotlib.pyplot as plt import numpy as np import pytest -from numpy.testing import assert_allclose, assert_array_equal, assert_equal +from numpy.testing import ( + assert_allclose, + assert_array_almost_equal, + assert_array_equal, + assert_equal, +) from scipy.signal import morlet2 import mne @@ -27,6 +32,7 @@ AverageTFR, EpochsSpectrum, EpochsTFR, + RawTFR, tfr_array_morlet, tfr_array_multitaper, ) @@ -45,11 +51,32 @@ from mne.utils import catch_logging, grand_average from mne.viz.utils import _fake_click, _fake_keypress, _fake_scroll +from .test_spectrum import _get_inst + data_path = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_path / "test_raw.fif" event_fname = data_path / "test-eve.fif" raw_ctf_fname = data_path / "test_ctf_raw.fif" +freqs_linspace = np.linspace(20, 40, num=5) +freqs_unsorted_list = [26, 33, 41, 20] +mag_names = [f"MEG 01{n}1" for n in (1, 2, 3)] + +parametrize_methods = pytest.mark.parametrize("method", ("morlet", "multitaper")) +parametrize_tfr_inst = pytest.mark.parametrize( + "inst", ("raw_tfr", "epochs_tfr", "average_tfr") +) +parametrize_inst_and_ch_type = pytest.mark.parametrize( + "inst,picks", + ( + pytest.param("raw_tfr", "mag"), + pytest.param("raw_tfr", "grad"), + pytest.param("epochs_tfr", "mag"), # no grad pairs in epochs fixture + pytest.param("average_tfr", "mag"), + pytest.param("average_tfr", "grad"), + ), +) + def test_tfr_ctf(): """Test that TFRs can be calculated on CTF data.""" @@ -1492,3 +1519,271 @@ def test_to_data_frame_time_format(time_format): df = tfr.to_data_frame(time_format=time_format) dtypes = {None: np.float64, "ms": np.int64, "timedelta": pd.Timedelta} assert isinstance(df["time"].iloc[0], dtypes[time_format]) + + +@parametrize_methods +@pytest.mark.parametrize("picks", ("mag", mag_names, [2, 5, 8])) # all 3 equivalent +def test_raw_compute_tfr(raw, method, picks): + """Test Raw.compute_tfr() and picks handling.""" + full_tfr = raw.compute_tfr(method, freqs=freqs_linspace) + pick_tfr = raw.compute_tfr(method, freqs=freqs_linspace, picks=picks) + assert isinstance(pick_tfr, RawTFR), type(pick_tfr) + # ↓↓↓ can't use [2,5,8] because ch0 is IAS, so indices change between raw and TFR + want = full_tfr.get_data(picks=mag_names) + got = pick_tfr.get_data() + assert_array_equal(want, got) + + +@parametrize_methods +@pytest.mark.parametrize("freqs", (freqs_linspace, freqs_unsorted_list)) +def test_evoked_compute_tfr(full_evoked, method, freqs): + """Test Evoked.compute_tfr(), with a few different ways of specifying freqs.""" + tfr = full_evoked.compute_tfr(method, freqs) + assert isinstance(tfr, AverageTFR), type(tfr) + assert tfr.nave == full_evoked.nave + assert tfr.comment == full_evoked.comment + + +@parametrize_methods +@pytest.mark.parametrize( + "average,return_itc,dim,want_class", + ( + pytest.param(True, False, None, None, id="average,no_itc"), + pytest.param(True, True, None, None, id="average,itc"), + pytest.param(False, False, "freqs", EpochsTFR, id="no_average,agg_freqs"), + pytest.param(False, False, "epochs", AverageTFR, id="no_average,agg_epochs"), + pytest.param(False, False, "times", EpochsSpectrum, id="no_average,agg_times"), + ), +) +def test_epochs_compute_tfr_average_itc( + epochs, method, average, return_itc, dim, want_class +): + """Test Epochs.compute_tfr(), averaging (at call time and afterward), and ITC.""" + tfr = epochs.compute_tfr( + method, freqs=freqs_linspace, average=average, return_itc=return_itc + ) + if return_itc: + tfr, itc = tfr + assert isinstance(itc, AverageTFR), type(itc) + # for single-epoch input, ITC should be (nearly) unity + assert_array_almost_equal(itc.get_data(), 1.0, decimal=15) + # if not averaging initially, make sure the post-facto .average() works too + if average: + assert isinstance(tfr, AverageTFR), type(tfr) + assert tfr.nave == 1 + assert tfr.comment == "1" + else: + assert isinstance(tfr, EpochsTFR), type(tfr) + avg = tfr.average(dim=dim) + assert isinstance(avg, want_class), type(avg) + if dim == "epochs": + assert avg.nave == len(epochs) + assert avg.comment.startswith(f"mean of {len(epochs)} EpochsTFR") + + +@parametrize_methods +def test_epochs_vs_evoked_compute_tfr(epochs, method): + """Compare result of averaging before or after the TFR computation. + + This is mostly a test of object structure / attribute preservation. In normal cases, + the data should not match: + - epochs.compute_tfr().average() is average of squared magnitudes + - epochs.average().compute_tfr() is squared magnitude of average + But the `epochs` fixture has only one epoch, so here data should be identical too. + + The two things that will always end up different are `._comment` and `._inst_type`, + so we ignore those here. + """ + avg_first = epochs.average().compute_tfr(method=method, freqs=freqs_linspace) + avg_second = epochs.compute_tfr(method=method, freqs=freqs_linspace).average() + for attr in ("_comment", "_inst_type"): + assert getattr(avg_first, attr) != getattr(avg_second, attr) + delattr(avg_first, attr) + delattr(avg_second, attr) + assert avg_first == avg_second + + +@pytest.mark.parametrize( + "method,freqs,method_kw", + ( + pytest.param( + "morlet", + freqs_linspace, + dict(n_cycles=freqs_linspace / 4, use_fft=False, zero_mean=True), + id="morlet-nondefaults", + ), + pytest.param( + "multitaper", + freqs_linspace, + dict( + n_cycles=freqs_linspace / 4, + use_fft=False, + zero_mean=False, + time_bandwidth=6, + ), + id="multitaper-nondefaults", + ), + pytest.param( + "stockwell", "auto", dict(n_fft=1024, width=2), id="stockwell-nondefaults" + ), + ), +) +def test_epochs_compute_tfr_method_kw(epochs, method, freqs, method_kw): + """Test Epochs.compute_tfr(**method_kw).""" + tfr = epochs.compute_tfr(method, freqs=freqs, average=True, **method_kw) + assert isinstance(tfr, AverageTFR), type(tfr) + + +@pytest.mark.parametrize( + "freqs", + ( + pytest.param("auto", id="freqauto"), + pytest.param([20, 41], id="fminfmax"), + ), +) +@pytest.mark.parametrize("return_itc", (False, True)) +def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc): + """Test Epochs.compute_tfr(method="stockwell").""" + tfr = epochs.compute_tfr("stockwell", freqs, return_itc=return_itc) + if return_itc: + tfr, itc = tfr + assert isinstance(itc, AverageTFR) + # for single-epoch input, ITC should be (nearly) unity + assert_array_almost_equal(itc.get_data(), 1.0, decimal=15) + assert isinstance(tfr, AverageTFR) + assert tfr.comment == "1" + + +@pytest.mark.parametrize("copy", (False, True)) +def test_epochstfr_iter_evoked(epochs_tfr, copy): + """Test EpochsTFR.iter_evoked().""" + avgs = list(epochs_tfr.iter_evoked(copy=copy)) + assert len(avgs) == len(epochs_tfr) + assert all(avg.nave == 1 for avg in avgs) + assert avgs[0].comment == str(epochs_tfr.events[0, -1]) + + +def test_tfr_copy(average_tfr): + """Test _BaseTFR.copy() method.""" + tfr_copy = average_tfr.copy() + # check that info is independent + tfr_copy.info["bads"] = tfr_copy.ch_names + assert average_tfr.info["bads"] == [] + # check that data is independent + tfr_copy.data = np.inf + assert np.isfinite(average_tfr.get_data()).all() + + +@pytest.mark.parametrize( + "mode", ("mean", "ratio", "logratio", "percent", "zscore", "zlogratio") +) +def test_tfr_apply_baseline(average_tfr, mode): + """Test TFR baselining.""" + average_tfr.apply_baseline((-0.1, -0.05), mode=mode) + + +@pytest.mark.parametrize( + "method,freqs", + ( + pytest.param("morlet", freqs_linspace, id="morlet"), + pytest.param("multitaper", freqs_linspace, id="multitaper"), + pytest.param("stockwell", freqs_linspace[[0, -1]], id="stockwell"), + ), +) +@pytest.mark.parametrize("decim", (4, slice(0, 200), slice(1, 200, 3))) +def test_tfr_decim(epochs, method, freqs, decim): + """Test TFR decimation; slices must be long-ish to be longer than the wavelets.""" + tfr = epochs.compute_tfr(method, freqs=freqs, decim=decim) + if not isinstance(decim, slice): + decim = slice(None, None, decim) + want = len(range(*decim.indices(len(epochs.times)))) + assert tfr.shape[-1] == want + + +def test_tfr_arithmetic(epochs): + """Test TFR arithmetic operations.""" + tfr, itc = epochs.compute_tfr( + "morlet", freqs=freqs_linspace, average=True, return_itc=True + ) + itc_copy = itc.copy() + # addition / subtraction of objects + itc_copy += tfr + assert itc == itc_copy - tfr + # multiplication / division with scalars + bigger_itc = itc * 23 + assert_array_almost_equal(itc.data, (bigger_itc / 23).data, decimal=15) + # multiplication / division with arrays + arr = np.full_like(itc.data, 23) + assert_array_equal(bigger_itc.data, (itc * arr).data) + # check errors + different_tfr = tfr.copy().crop(tmax=0.2) + with pytest.raises(RuntimeError, match="times do not match"): + tfr += different_tfr + different_tfr = tfr.copy().crop(fmax=33) + with pytest.raises(RuntimeError, match="freqs do not match"): + tfr += different_tfr + + +@parametrize_tfr_inst +def test_tfr_save_load(inst, average_tfr, request, tmp_path): + """Test TFR I/O.""" + tfr = _get_inst(inst, request, average_tfr=average_tfr) + fname = tmp_path / "temp_tfr.hdf5" + tfr.save(fname, overwrite=True) + tfr_loaded = read_tfrs(fname) + assert tfr == tfr_loaded + + +@pytest.mark.parametrize( + "method,freqs,match", + ( + ("morlet", None, "EpochsTFR got unsupported parameter value freqs=None."), + (None, freqs_linspace, "got unsupported parameter value method=None."), + (None, None, "got unsupported parameter values method=None and freqs=None."), + ), +) +def test_tfr_errors(epochs, method, freqs, match): + """Test that method and freqs are always passed (if not using __setstate__).""" + with pytest.raises(ValueError, match=match): + epochs.compute_tfr(method=method, freqs=freqs) + + +@parametrize_tfr_inst +@pytest.mark.parametrize( + "picks,combine", + ( + pytest.param("mag", "mean", id="mean_of_mags"), + pytest.param("grad", "rms", id="rms_of_grads"), + pytest.param([1, 2], None, id="two_separate_channels"), + ), +) +def test_tfr_plot(inst, picks, combine, average_tfr, request): + """Test {Raw,Epochs,Average}TFR.plot().""" + tfr = _get_inst(inst, request, average_tfr=average_tfr) + fig = tfr.plot(picks=picks, combine=combine) + assert len(fig) == 1 if isinstance(picks, str) else len(picks) + + +@parametrize_tfr_inst +@pytest.mark.parametrize("picks", ("mag", "grad")) # no EEG in file +def test_tfr_plot_topo(inst, picks, average_tfr, request): + """Test {Raw,Epochs,Average}TFR.plot_topo().""" + tfr = _get_inst(inst, request, average_tfr=average_tfr) + fig = tfr.plot_topo(picks=picks) + assert fig is not None + + +@parametrize_inst_and_ch_type +def test_tfr_plot_topomap(inst, picks, average_tfr, request): + """Test {Raw,Epochs,Average}TFR.plot_topomap().""" + tfr = _get_inst(inst, request, average_tfr=average_tfr) + fig = tfr.plot_topomap(ch_type=picks) + assert fig is not None + + +@parametrize_inst_and_ch_type +def test_tfr_plot_joint(inst, picks, average_tfr, request): + """Test {Raw,Epochs,Average}TFR.plot_joint().""" + tfr = _get_inst(inst, request, average_tfr=average_tfr) + fig = tfr.plot_joint(picks=picks) + assert fig is not None