Skip to content

Commit

Permalink
copy over all the new tests [circle full]
Browse files Browse the repository at this point in the history
  • Loading branch information
drammock committed Jan 30, 2024
1 parent 0036415 commit a07af96
Show file tree
Hide file tree
Showing 3 changed files with 321 additions and 3 deletions.
19 changes: 19 additions & 0 deletions mne/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,25 @@ def epochs_spectrum():
return _get_epochs().load_data().compute_psd()


@pytest.fixture()
def epochs_tfr():
"""Get an EpochsTFR computed from mne.io.tests.data."""
epochs = _get_epochs().load_data()
return epochs.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5))


@pytest.fixture()
def average_tfr(full_evoked):
"""Get an AverageTFR computed from mne.io.tests.data."""
return full_evoked.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5))


@pytest.fixture()
def raw_tfr(raw):
"""Get a RawTFR computed from mne.io.tests.data."""
return raw.compute_tfr(method="morlet", freqs=np.linspace(20, 40, num=5))


@pytest.fixture()
def epochs_empty():
"""Get empty epochs from mne.io.tests.data."""
Expand Down
8 changes: 6 additions & 2 deletions mne/time_frequency/tests/test_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,15 @@ def test_n_welch_windows(raw):
)


def _get_inst(inst, request, evoked):
def _get_inst(inst, request, evoked=None, average_tfr=None):
# ↓ XXX workaround:
# ↓ parametrized fixtures are not accessible via request.getfixturevalue
# ↓ https://github.com/pytest-dev/pytest/issues/4666#issuecomment-456593913
return evoked if inst == "evoked" else request.getfixturevalue(inst)
if inst == "evoked":
return evoked
elif inst == "average_tfr":
return average_tfr
return request.getfixturevalue(inst)


@pytest.mark.parametrize("inst", ("raw", "epochs", "evoked"))
Expand Down
297 changes: 296 additions & 1 deletion mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal, assert_equal
from numpy.testing import (
assert_allclose,
assert_array_almost_equal,
assert_array_equal,
assert_equal,
)
from scipy.signal import morlet2

import mne
Expand All @@ -27,6 +32,7 @@
AverageTFR,
EpochsSpectrum,
EpochsTFR,
RawTFR,
tfr_array_morlet,
tfr_array_multitaper,
)
Expand All @@ -45,11 +51,32 @@
from mne.utils import catch_logging, grand_average
from mne.viz.utils import _fake_click, _fake_keypress, _fake_scroll

from .test_spectrum import _get_inst

data_path = Path(__file__).parents[2] / "io" / "tests" / "data"
raw_fname = data_path / "test_raw.fif"
event_fname = data_path / "test-eve.fif"
raw_ctf_fname = data_path / "test_ctf_raw.fif"

freqs_linspace = np.linspace(20, 40, num=5)
freqs_unsorted_list = [26, 33, 41, 20]
mag_names = [f"MEG 01{n}1" for n in (1, 2, 3)]

parametrize_methods = pytest.mark.parametrize("method", ("morlet", "multitaper"))
parametrize_tfr_inst = pytest.mark.parametrize(
"inst", ("raw_tfr", "epochs_tfr", "average_tfr")
)
parametrize_inst_and_ch_type = pytest.mark.parametrize(
"inst,picks",
(
pytest.param("raw_tfr", "mag"),
pytest.param("raw_tfr", "grad"),
pytest.param("epochs_tfr", "mag"), # no grad pairs in epochs fixture
pytest.param("average_tfr", "mag"),
pytest.param("average_tfr", "grad"),
),
)


def test_tfr_ctf():
"""Test that TFRs can be calculated on CTF data."""
Expand Down Expand Up @@ -1492,3 +1519,271 @@ def test_to_data_frame_time_format(time_format):
df = tfr.to_data_frame(time_format=time_format)
dtypes = {None: np.float64, "ms": np.int64, "timedelta": pd.Timedelta}
assert isinstance(df["time"].iloc[0], dtypes[time_format])


@parametrize_methods
@pytest.mark.parametrize("picks", ("mag", mag_names, [2, 5, 8])) # all 3 equivalent
def test_raw_compute_tfr(raw, method, picks):
"""Test Raw.compute_tfr() and picks handling."""
full_tfr = raw.compute_tfr(method, freqs=freqs_linspace)
pick_tfr = raw.compute_tfr(method, freqs=freqs_linspace, picks=picks)
assert isinstance(pick_tfr, RawTFR), type(pick_tfr)
# ↓↓↓ can't use [2,5,8] because ch0 is IAS, so indices change between raw and TFR
want = full_tfr.get_data(picks=mag_names)
got = pick_tfr.get_data()
assert_array_equal(want, got)


@parametrize_methods
@pytest.mark.parametrize("freqs", (freqs_linspace, freqs_unsorted_list))
def test_evoked_compute_tfr(full_evoked, method, freqs):
"""Test Evoked.compute_tfr(), with a few different ways of specifying freqs."""
tfr = full_evoked.compute_tfr(method, freqs)
assert isinstance(tfr, AverageTFR), type(tfr)
assert tfr.nave == full_evoked.nave
assert tfr.comment == full_evoked.comment


@parametrize_methods
@pytest.mark.parametrize(
"average,return_itc,dim,want_class",
(
pytest.param(True, False, None, None, id="average,no_itc"),
pytest.param(True, True, None, None, id="average,itc"),
pytest.param(False, False, "freqs", EpochsTFR, id="no_average,agg_freqs"),
pytest.param(False, False, "epochs", AverageTFR, id="no_average,agg_epochs"),
pytest.param(False, False, "times", EpochsSpectrum, id="no_average,agg_times"),
),
)
def test_epochs_compute_tfr_average_itc(
epochs, method, average, return_itc, dim, want_class
):
"""Test Epochs.compute_tfr(), averaging (at call time and afterward), and ITC."""
tfr = epochs.compute_tfr(
method, freqs=freqs_linspace, average=average, return_itc=return_itc
)
if return_itc:
tfr, itc = tfr
assert isinstance(itc, AverageTFR), type(itc)
# for single-epoch input, ITC should be (nearly) unity
assert_array_almost_equal(itc.get_data(), 1.0, decimal=15)
# if not averaging initially, make sure the post-facto .average() works too
if average:
assert isinstance(tfr, AverageTFR), type(tfr)
assert tfr.nave == 1
assert tfr.comment == "1"
else:
assert isinstance(tfr, EpochsTFR), type(tfr)
avg = tfr.average(dim=dim)
assert isinstance(avg, want_class), type(avg)
if dim == "epochs":
assert avg.nave == len(epochs)
assert avg.comment.startswith(f"mean of {len(epochs)} EpochsTFR")


@parametrize_methods
def test_epochs_vs_evoked_compute_tfr(epochs, method):
"""Compare result of averaging before or after the TFR computation.
This is mostly a test of object structure / attribute preservation. In normal cases,
the data should not match:
- epochs.compute_tfr().average() is average of squared magnitudes
- epochs.average().compute_tfr() is squared magnitude of average
But the `epochs` fixture has only one epoch, so here data should be identical too.
The two things that will always end up different are `._comment` and `._inst_type`,
so we ignore those here.
"""
avg_first = epochs.average().compute_tfr(method=method, freqs=freqs_linspace)
avg_second = epochs.compute_tfr(method=method, freqs=freqs_linspace).average()
for attr in ("_comment", "_inst_type"):
assert getattr(avg_first, attr) != getattr(avg_second, attr)
delattr(avg_first, attr)
delattr(avg_second, attr)
assert avg_first == avg_second


@pytest.mark.parametrize(
"method,freqs,method_kw",
(
pytest.param(
"morlet",
freqs_linspace,
dict(n_cycles=freqs_linspace / 4, use_fft=False, zero_mean=True),
id="morlet-nondefaults",
),
pytest.param(
"multitaper",
freqs_linspace,
dict(
n_cycles=freqs_linspace / 4,
use_fft=False,
zero_mean=False,
time_bandwidth=6,
),
id="multitaper-nondefaults",
),
pytest.param(
"stockwell", "auto", dict(n_fft=1024, width=2), id="stockwell-nondefaults"
),
),
)
def test_epochs_compute_tfr_method_kw(epochs, method, freqs, method_kw):
"""Test Epochs.compute_tfr(**method_kw)."""
tfr = epochs.compute_tfr(method, freqs=freqs, average=True, **method_kw)
assert isinstance(tfr, AverageTFR), type(tfr)


@pytest.mark.parametrize(
"freqs",
(
pytest.param("auto", id="freqauto"),
pytest.param([20, 41], id="fminfmax"),
),
)
@pytest.mark.parametrize("return_itc", (False, True))
def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc):
"""Test Epochs.compute_tfr(method="stockwell")."""
tfr = epochs.compute_tfr("stockwell", freqs, return_itc=return_itc)
if return_itc:
tfr, itc = tfr
assert isinstance(itc, AverageTFR)
# for single-epoch input, ITC should be (nearly) unity
assert_array_almost_equal(itc.get_data(), 1.0, decimal=15)
assert isinstance(tfr, AverageTFR)
assert tfr.comment == "1"


@pytest.mark.parametrize("copy", (False, True))
def test_epochstfr_iter_evoked(epochs_tfr, copy):
"""Test EpochsTFR.iter_evoked()."""
avgs = list(epochs_tfr.iter_evoked(copy=copy))
assert len(avgs) == len(epochs_tfr)
assert all(avg.nave == 1 for avg in avgs)
assert avgs[0].comment == str(epochs_tfr.events[0, -1])


def test_tfr_copy(average_tfr):
"""Test _BaseTFR.copy() method."""
tfr_copy = average_tfr.copy()
# check that info is independent
tfr_copy.info["bads"] = tfr_copy.ch_names
assert average_tfr.info["bads"] == []
# check that data is independent
tfr_copy.data = np.inf
assert np.isfinite(average_tfr.get_data()).all()


@pytest.mark.parametrize(
"mode", ("mean", "ratio", "logratio", "percent", "zscore", "zlogratio")
)
def test_tfr_apply_baseline(average_tfr, mode):
"""Test TFR baselining."""
average_tfr.apply_baseline((-0.1, -0.05), mode=mode)


@pytest.mark.parametrize(
"method,freqs",
(
pytest.param("morlet", freqs_linspace, id="morlet"),
pytest.param("multitaper", freqs_linspace, id="multitaper"),
pytest.param("stockwell", freqs_linspace[[0, -1]], id="stockwell"),
),
)
@pytest.mark.parametrize("decim", (4, slice(0, 200), slice(1, 200, 3)))
def test_tfr_decim(epochs, method, freqs, decim):
"""Test TFR decimation; slices must be long-ish to be longer than the wavelets."""
tfr = epochs.compute_tfr(method, freqs=freqs, decim=decim)
if not isinstance(decim, slice):
decim = slice(None, None, decim)
want = len(range(*decim.indices(len(epochs.times))))
assert tfr.shape[-1] == want


def test_tfr_arithmetic(epochs):
"""Test TFR arithmetic operations."""
tfr, itc = epochs.compute_tfr(
"morlet", freqs=freqs_linspace, average=True, return_itc=True
)
itc_copy = itc.copy()
# addition / subtraction of objects
itc_copy += tfr
assert itc == itc_copy - tfr
# multiplication / division with scalars
bigger_itc = itc * 23
assert_array_almost_equal(itc.data, (bigger_itc / 23).data, decimal=15)
# multiplication / division with arrays
arr = np.full_like(itc.data, 23)
assert_array_equal(bigger_itc.data, (itc * arr).data)
# check errors
different_tfr = tfr.copy().crop(tmax=0.2)
with pytest.raises(RuntimeError, match="times do not match"):
tfr += different_tfr
different_tfr = tfr.copy().crop(fmax=33)
with pytest.raises(RuntimeError, match="freqs do not match"):
tfr += different_tfr


@parametrize_tfr_inst
def test_tfr_save_load(inst, average_tfr, request, tmp_path):
"""Test TFR I/O."""
tfr = _get_inst(inst, request, average_tfr=average_tfr)
fname = tmp_path / "temp_tfr.hdf5"
tfr.save(fname, overwrite=True)
tfr_loaded = read_tfrs(fname)
assert tfr == tfr_loaded


@pytest.mark.parametrize(
"method,freqs,match",
(
("morlet", None, "EpochsTFR got unsupported parameter value freqs=None."),
(None, freqs_linspace, "got unsupported parameter value method=None."),
(None, None, "got unsupported parameter values method=None and freqs=None."),
),
)
def test_tfr_errors(epochs, method, freqs, match):
"""Test that method and freqs are always passed (if not using __setstate__)."""
with pytest.raises(ValueError, match=match):
epochs.compute_tfr(method=method, freqs=freqs)


@parametrize_tfr_inst
@pytest.mark.parametrize(
"picks,combine",
(
pytest.param("mag", "mean", id="mean_of_mags"),
pytest.param("grad", "rms", id="rms_of_grads"),
pytest.param([1, 2], None, id="two_separate_channels"),
),
)
def test_tfr_plot(inst, picks, combine, average_tfr, request):
"""Test {Raw,Epochs,Average}TFR.plot()."""
tfr = _get_inst(inst, request, average_tfr=average_tfr)
fig = tfr.plot(picks=picks, combine=combine)
assert len(fig) == 1 if isinstance(picks, str) else len(picks)


@parametrize_tfr_inst
@pytest.mark.parametrize("picks", ("mag", "grad")) # no EEG in file
def test_tfr_plot_topo(inst, picks, average_tfr, request):
"""Test {Raw,Epochs,Average}TFR.plot_topo()."""
tfr = _get_inst(inst, request, average_tfr=average_tfr)
fig = tfr.plot_topo(picks=picks)
assert fig is not None


@parametrize_inst_and_ch_type
def test_tfr_plot_topomap(inst, picks, average_tfr, request):
"""Test {Raw,Epochs,Average}TFR.plot_topomap()."""
tfr = _get_inst(inst, request, average_tfr=average_tfr)
fig = tfr.plot_topomap(ch_type=picks)
assert fig is not None


@parametrize_inst_and_ch_type
def test_tfr_plot_joint(inst, picks, average_tfr, request):
"""Test {Raw,Epochs,Average}TFR.plot_joint()."""
tfr = _get_inst(inst, request, average_tfr=average_tfr)
fig = tfr.plot_joint(picks=picks)
assert fig is not None

0 comments on commit a07af96

Please sign in to comment.