From 81ef83e3655a2588259e6233750b77df546883c3 Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Sat, 11 Nov 2023 15:28:01 -0600 Subject: [PATCH 01/41] Add ability to reject epochs using functions --- mne/epochs.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index b7afada3d1a..2ae9deb21e6 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -818,11 +818,16 @@ def _reject_setup(self, reject, flat): # check for invalid values for rej, kind in zip((reject, flat), ("Rejection", "Flat")): for key, val in rej.items(): - if val is None or val < 0: + if callable(val): + continue + elif val is not None and val >= 0: + continue + else: raise ValueError( - '%s value must be a number >= 0, not "%s"' % (kind, val) + '%s value must be a number >= 0 or a valid function, not "%s"' % (kind, val) ) + # now check to see if our rejection and flat are getting more # restrictive old_reject = self.reject if self.reject is not None else dict() @@ -3618,6 +3623,8 @@ def _is_good( ): """Test if data segment e is good according to reject and flat. + The reject and flat dictionaries can now accept functions as values. + If full_report=True, it will give True/False as well as a list of all offending channels. """ @@ -3625,18 +3632,29 @@ def _is_good( has_printed = False checkable = np.ones(len(ch_names), dtype=bool) checkable[np.array([c in ignore_chs for c in ch_names], dtype=bool)] = False + for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]): if refl is not None: - for key, thresh in refl.items(): + for key, criterion in refl.items(): idx = channel_type_idx[key] name = key.upper() if len(idx) > 0: e_idx = e[idx] - deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1) checkable_idx = checkable[idx] - idx_deltas = np.where( - np.logical_and(f(deltas, thresh), checkable_idx) - )[0] + + # Check if criterion is a function and apply it + if callable(criterion): + idx_deltas = np.where( + np.logical_and( + criterion(e_idx), + checkable_idx + ) + )[0] + else: + deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1) + idx_deltas = np.where( + np.logical_and(f(deltas, criterion), checkable_idx) + )[0] if len(idx_deltas) > 0: bad_names = [ch_names[idx[i]] for i in idx_deltas] From d8dda070b349213fc5332454b28c9427a43c8330 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 11 Nov 2023 21:34:00 +0000 Subject: [PATCH 02/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/epochs.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index 2ae9deb21e6..a7d4b5198e5 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -824,10 +824,10 @@ def _reject_setup(self, reject, flat): continue else: raise ValueError( - '%s value must be a number >= 0 or a valid function, not "%s"' % (kind, val) + '%s value must be a number >= 0 or a valid function, not "%s"' + % (kind, val) ) - # now check to see if our rejection and flat are getting more # restrictive old_reject = self.reject if self.reject is not None else dict() @@ -3645,10 +3645,7 @@ def _is_good( # Check if criterion is a function and apply it if callable(criterion): idx_deltas = np.where( - np.logical_and( - criterion(e_idx), - checkable_idx - ) + np.logical_and(criterion(e_idx), checkable_idx) )[0] else: deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1) From 867496dd75fa54a4472bee5e9e703b89c0568aad Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Sat, 18 Nov 2023 19:53:10 -0600 Subject: [PATCH 03/41] Update docs --- mne/utils/docs.py | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index e68b839055d..11f1f3a557f 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1701,11 +1701,14 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ _flat_common = """\ - Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP). - Valid **keys** can be any channel type present in the object. The - **values** are floats that set the minimum acceptable PTP. If the PTP - is smaller than this threshold, the epoch will be dropped. If ``None`` - then no rejection is performed based on flatness of the signal.""" + Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP) + or a custom function. Valid **keys** can be any channel type present + in the object. If using PTP, **values** are floats that set the minimum + acceptable PTP. If the PTP is smaller than this threshold, the epoch + will be dropped. If ``None`` then no rejection is performed based on + flatness of the signal. If a custom function is used than ``flat`` can be + used to reject epochs based on any criteria (including maxima and + minima).""" docdict[ "flat" @@ -3793,8 +3796,9 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ) _reject_common = """\ - Reject epochs based on **maximum** peak-to-peak signal amplitude (PTP), - i.e. the absolute difference between the lowest and the highest signal + Reject epochs based on **maximum** peak-to-peak signal amplitude (PTP) + or custom functions. Peak-to-peak signal amplitude is defined as + the absolute difference between the lowest and the highest signal value. In each individual epoch, the PTP is calculated for every channel. If the PTP of any one channel exceeds the rejection threshold, the respective epoch will be dropped. @@ -3810,10 +3814,21 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): eog=250e-6 # unit: V (EOG channels) ) - .. note:: Since rejection is based on a signal **difference** - calculated for each channel separately, applying baseline - correction does not affect the rejection procedure, as the - difference will be preserved. + Custom rejection criteria can be also be used by passing a callable + to the dictionary. + + Example:: + + reject = dict(eeg=lambda x: True if (np.max(x, axis=1) > + 1e-3).any() else False)) + + .. note:: If rejection is based on a signal **difference** + calculated for each channel separately, applying baseline + correction does not affect the rejection procedure, as the + difference will be preserved. + + .. note:: If ``reject`` is a callable, than **any** criteria can be + used to reject epochs (including maxima and minima). """ docdict[ From 1b4f5b3978725c20d12ad25186b9fce52ab4d001 Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Sat, 18 Nov 2023 20:19:01 -0600 Subject: [PATCH 04/41] Add ability to reject based on callables --- mne/epochs.py | 52 +++++++++++++++++++------------- mne/tests/test_epochs.py | 65 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 21 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index 6f1540035d0..87360f55641 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -838,33 +838,43 @@ def _reject_setup(self, reject, flat): "previous ones" ) + # Skip this check if old_reject, reject, old_flat, and flat are + # callables + is_callable = False + for rej in (reject, flat, old_reject, old_flat): + for key, val in rej.items(): + if callable(val): + is_callable = True + # copy thresholds for channel types that were used previously, but not # passed this time for key in set(old_reject) - set(reject): reject[key] = old_reject[key] - # make sure new thresholds are at least as stringent as the old ones - for key in reject: - if key in old_reject and reject[key] > old_reject[key]: - raise ValueError( - bad_msg.format( - kind="reject", - key=key, - new=reject[key], - old=old_reject[key], - op=">", + + if not is_callable: + # make sure new thresholds are at least as stringent as the old ones + for key in reject: + if key in old_reject and reject[key] > old_reject[key]: + raise ValueError( + bad_msg.format( + kind="reject", + key=key, + new=reject[key], + old=old_reject[key], + op=">", + ) ) - ) - # same for flat thresholds - for key in set(old_flat) - set(flat): - flat[key] = old_flat[key] - for key in flat: - if key in old_flat and flat[key] < old_flat[key]: - raise ValueError( - bad_msg.format( - kind="flat", key=key, new=flat[key], old=old_flat[key], op="<" + # same for flat thresholds + for key in set(old_flat) - set(flat): + flat[key] = old_flat[key] + for key in flat: + if key in old_flat and flat[key] < old_flat[key]: + raise ValueError( + bad_msg.format( + kind="flat", key=key, new=flat[key], old=old_flat[key], op="<" + ) ) - ) # after validation, set parameters self._bad_dropped = False @@ -3625,7 +3635,7 @@ def _is_good( ): """Test if data segment e is good according to reject and flat. - The reject and flat dictionaries can now accept functions as values. + The reject and flat dictionaries can accept functions as values. If full_report=True, it will give True/False as well as a list of all offending channels. diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index e5ac3892ca8..7a525c7b1be 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -2127,6 +2127,71 @@ def test_reject_epochs(tmp_path): assert epochs_cleaned.flat == dict(grad=new_flat["grad"], mag=flat["mag"]) +@testing.requires_testing_data +# @pytest.mark.parametrize("fname", (fname_raw_testing)) +def test_callable_reject(): + raw = read_raw_fif(fname_raw_testing, preload=True) + raw.crop(0, 5) + raw.del_proj() + chans = raw.info['ch_names'][-6:-1] + raw.pick(chans) + data = raw.get_data() + + # Multipy 20 points of the first channel by 10 + new_data = data + new_data[0, 180:200] *= 1e7 + new_data[0, 610:880] += 1e-3 + edit_raw = mne.io.RawArray(new_data, raw.info) + + events = mne.make_fixed_length_events(edit_raw, id=1, duration=1.0, start=0) + epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + preload=True + ) + assert len(epochs) == 5 + epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + reject=dict(eeg=lambda x: True if (np.median(x, axis=1) > 1e-3).any() else False), + preload=True + ) + assert epochs.drop_log[2] != () + + epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + reject=dict(eeg=lambda x: True if (np.max(x, axis=1) > 1).any() else False), + preload=True + ) + assert epochs.drop_log[0] != () + + def reject_criteria(x): + max_condition = np.max(x, axis=1) > 1e-2 + median_condition = np.median(x, axis=1) > 1e-4 + return True if max_condition.any() or median_condition.any() else False + + epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + reject=dict(eeg=reject_criteria), + preload=True + ) + assert epochs.drop_log[0] != () and epochs.drop_log[2] != () + + def test_preload_epochs(): """Test preload of epochs.""" raw, events, picks = _get_data() From 2a6604959751e8802d34752c30f47e4bbc0b3b4b Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Sat, 18 Nov 2023 20:19:26 -0600 Subject: [PATCH 05/41] Add tutorial --- .../preprocessing/20_rejecting_bad_data.py | 99 ++++++++++++++++++- 1 file changed, 97 insertions(+), 2 deletions(-) diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py index 99228eb37d7..1f89fa74c24 100644 --- a/tutorials/preprocessing/20_rejecting_bad_data.py +++ b/tutorials/preprocessing/20_rejecting_bad_data.py @@ -23,6 +23,8 @@ import mne +import numpy as np + sample_data_folder = mne.datasets.sample.data_path() sample_data_raw_file = os.path.join( sample_data_folder, "MEG", "sample", "sample_audvis_filt-0-40_raw.fif" @@ -203,8 +205,8 @@ # %% # .. _`tut-reject-epochs-section`: # -# Rejecting Epochs based on channel amplitude -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Rejecting Epochs based on peak-to-peak channel amplitude +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # # Besides "bad" annotations, the :class:`mne.Epochs` class constructor has # another means of rejecting epochs, based on signal amplitude thresholds for @@ -326,6 +328,99 @@ epochs.drop_bad(reject=stronger_reject_criteria) print(epochs.drop_log) +# %% +# .. _`tut-reject-epochs-func-section`: +# +# Rejecting Epochs using callables (functions) +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Sometimes it is useful to reject epochs based criteria other than +# peak-to-peak amplitudes. For example, we might want to reject epochs +# based on the maximum or minimum amplitude of a channel. +# In this case, the :class:`mne.Epochs` class constructor also accepts +# callables (functions) in the ``reject`` and ``flat`` parameters. This +# allows us to define functions to reject epochs based on our desired criteria. +# +# Let's begin by generating Epoch data with large artifacts in one eeg channel +# in order to demonstrate the versatility of this approach. + +raw.crop(0, 5) +raw.del_proj() +chans = raw.info['ch_names'][-5:-1] +raw.pick(chans) +data = raw.get_data() + +new_data = data +new_data[0, 180:200] *= 1e3 +new_data[0, 460:580] += 1e-3 +edit_raw = mne.io.RawArray(new_data, raw.info) + +# Create fixed length epochs of 1 second +events = mne.make_fixed_length_events(edit_raw, id=1, duration=1.0, start=0) +epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, baseline=None) +epochs.plot(scalings=dict(eeg=50e-5)) + +# %% +# As you can see, we have two large artifacts in the first channel. One large +# spike in amplitude and one large increase in amplitude. + +# Let's try to reject the epoch containing the spike in amplitude based on the +# maximum amplitude of the first channel. + +epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + reject=dict(eeg=lambda x: True if (np.max(x, axis=1) > 1e-2).any() else False), + preload=True +) +epochs.plot(scalings=dict(eeg=50e-5)) + +# %% +# Here, the epoch containing the spike in amplitude was rejected for having a +# maximum amplitude greater than 1e-2 Volts. Notice the use of the ``any()`` +# function to check if any of the channels exceeded the threshold. We could +# have also used the ``all()`` function to check if all channels exceeded the +# threshold. + +# Next, let's try to reject the epoch containing the increase in amplitude +# using the median. + +epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + reject=dict(eeg=lambda x: True if (np.median(x, axis=1) > 1e-4).any() else False), + preload=True +) +epochs.plot(scalings=dict(eeg=50e-5)) + +# %% +# Finally, let's try to reject both epochs using a combination of the maximum +# and median. We'll define a custom function and use boolean operators to +# combine the two criteria. + + +def reject_criteria(x): + max_condition = np.max(x, axis=1) > 1e-2 + median_condition = np.median(x, axis=1) > 1e-4 + return True if max_condition.any() or median_condition.any() else False + + +epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + reject=dict(eeg=reject_criteria), + preload=True +) +epochs.plot(events=True) + # %% # Note that a complementary Python module, the `autoreject package`_, uses # machine learning to find optimal rejection criteria, and is designed to From e7084659ce28ef90cbfd47a0b7ebffd44fa1332d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 19 Nov 2023 02:20:34 +0000 Subject: [PATCH 06/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/epochs.py | 6 +++++- mne/tests/test_epochs.py | 21 +++++++------------ .../preprocessing/20_rejecting_bad_data.py | 14 ++++++------- 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index 2a2a41616e8..05cfbcfb84e 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -873,7 +873,11 @@ def _reject_setup(self, reject, flat): if key in old_flat and flat[key] < old_flat[key]: raise ValueError( bad_msg.format( - kind="flat", key=key, new=flat[key], old=old_flat[key], op="<" + kind="flat", + key=key, + new=flat[key], + old=old_flat[key], + op="<", ) ) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index fc379ba0882..35dd3eda58c 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -2134,7 +2134,7 @@ def test_callable_reject(): raw = read_raw_fif(fname_raw_testing, preload=True) raw.crop(0, 5) raw.del_proj() - chans = raw.info['ch_names'][-6:-1] + chans = raw.info["ch_names"][-6:-1] raw.pick(chans) data = raw.get_data() @@ -2145,14 +2145,7 @@ def test_callable_reject(): edit_raw = mne.io.RawArray(new_data, raw.info) events = mne.make_fixed_length_events(edit_raw, id=1, duration=1.0, start=0) - epochs = mne.Epochs( - edit_raw, - events, - tmin=0, - tmax=1, - baseline=None, - preload=True - ) + epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, baseline=None, preload=True) assert len(epochs) == 5 epochs = mne.Epochs( edit_raw, @@ -2160,8 +2153,10 @@ def test_callable_reject(): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: True if (np.median(x, axis=1) > 1e-3).any() else False), - preload=True + reject=dict( + eeg=lambda x: True if (np.median(x, axis=1) > 1e-3).any() else False + ), + preload=True, ) assert epochs.drop_log[2] != () @@ -2172,7 +2167,7 @@ def test_callable_reject(): tmax=1, baseline=None, reject=dict(eeg=lambda x: True if (np.max(x, axis=1) > 1).any() else False), - preload=True + preload=True, ) assert epochs.drop_log[0] != () @@ -2188,7 +2183,7 @@ def reject_criteria(x): tmax=1, baseline=None, reject=dict(eeg=reject_criteria), - preload=True + preload=True, ) assert epochs.drop_log[0] != () and epochs.drop_log[2] != () diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py index c1b5c7cf2e8..dde24ab38b6 100644 --- a/tutorials/preprocessing/20_rejecting_bad_data.py +++ b/tutorials/preprocessing/20_rejecting_bad_data.py @@ -23,10 +23,10 @@ import os -import mne - import numpy as np +import mne + sample_data_folder = mne.datasets.sample.data_path() sample_data_raw_file = os.path.join( sample_data_folder, "MEG", "sample", "sample_audvis_filt-0-40_raw.fif" @@ -347,7 +347,7 @@ raw.crop(0, 5) raw.del_proj() -chans = raw.info['ch_names'][-5:-1] +chans = raw.info["ch_names"][-5:-1] raw.pick(chans) data = raw.get_data() @@ -369,13 +369,13 @@ # maximum amplitude of the first channel. epochs = mne.Epochs( - edit_raw, + edit_raw, events, tmin=0, tmax=1, baseline=None, reject=dict(eeg=lambda x: True if (np.max(x, axis=1) > 1e-2).any() else False), - preload=True + preload=True, ) epochs.plot(scalings=dict(eeg=50e-5)) @@ -396,7 +396,7 @@ tmax=1, baseline=None, reject=dict(eeg=lambda x: True if (np.median(x, axis=1) > 1e-4).any() else False), - preload=True + preload=True, ) epochs.plot(scalings=dict(eeg=50e-5)) @@ -419,7 +419,7 @@ def reject_criteria(x): tmax=1, baseline=None, reject=dict(eeg=reject_criteria), - preload=True + preload=True, ) epochs.plot(events=True) From fbdec770d37cac018f3d197e5602603643cb51dc Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Mon, 20 Nov 2023 10:24:39 -0600 Subject: [PATCH 07/41] Make flake8 compliant --- mne/epochs.py | 6 ++++-- tutorials/preprocessing/20_rejecting_bad_data.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index 05cfbcfb84e..e525ce2f4c5 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -825,7 +825,8 @@ def _reject_setup(self, reject, flat): continue else: raise ValueError( - '%s value must be a number >= 0 or a valid function, not "%s"' + '%s value must be a number >= 0 or a valid function,' + 'not "%s"' % (kind, val) ) @@ -853,7 +854,8 @@ def _reject_setup(self, reject, flat): reject[key] = old_reject[key] if not is_callable: - # make sure new thresholds are at least as stringent as the old ones + # make sure new thresholds are at least as stringent + # as the old ones for key in reject: if key in old_reject and reject[key] > old_reject[key]: raise ValueError( diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py index dde24ab38b6..bc7f4bba796 100644 --- a/tutorials/preprocessing/20_rejecting_bad_data.py +++ b/tutorials/preprocessing/20_rejecting_bad_data.py @@ -374,7 +374,8 @@ tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: True if (np.max(x, axis=1) > 1e-2).any() else False), + reject=dict( + eeg=lambda x: True if (np.max(x, axis=1) > 1e-2).any() else False), preload=True, ) epochs.plot(scalings=dict(eeg=50e-5)) @@ -395,7 +396,8 @@ tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: True if (np.median(x, axis=1) > 1e-4).any() else False), + reject=dict( + eeg=lambda x: True if (np.median(x, axis=1) > 1e-4).any() else False), preload=True, ) epochs.plot(scalings=dict(eeg=50e-5)) From 6e23ecce68edbecee0fad762eddb33dfb06ea1b0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Nov 2023 16:26:12 +0000 Subject: [PATCH 08/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/epochs.py | 5 ++--- tutorials/preprocessing/20_rejecting_bad_data.py | 6 ++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index e525ce2f4c5..a0fcd46de5e 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -825,9 +825,8 @@ def _reject_setup(self, reject, flat): continue else: raise ValueError( - '%s value must be a number >= 0 or a valid function,' - 'not "%s"' - % (kind, val) + "%s value must be a number >= 0 or a valid function," + 'not "%s"' % (kind, val) ) # now check to see if our rejection and flat are getting more diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py index bc7f4bba796..dde24ab38b6 100644 --- a/tutorials/preprocessing/20_rejecting_bad_data.py +++ b/tutorials/preprocessing/20_rejecting_bad_data.py @@ -374,8 +374,7 @@ tmin=0, tmax=1, baseline=None, - reject=dict( - eeg=lambda x: True if (np.max(x, axis=1) > 1e-2).any() else False), + reject=dict(eeg=lambda x: True if (np.max(x, axis=1) > 1e-2).any() else False), preload=True, ) epochs.plot(scalings=dict(eeg=50e-5)) @@ -396,8 +395,7 @@ tmin=0, tmax=1, baseline=None, - reject=dict( - eeg=lambda x: True if (np.median(x, axis=1) > 1e-4).any() else False), + reject=dict(eeg=lambda x: True if (np.median(x, axis=1) > 1e-4).any() else False), preload=True, ) epochs.plot(scalings=dict(eeg=50e-5)) From cdd2843a0b3838dbf8d76bbb752d03abcd5d68bd Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Mon, 20 Nov 2023 10:58:51 -0600 Subject: [PATCH 09/41] Add docstrings and make flake8 compliant --- mne/tests/test_epochs.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 35dd3eda58c..c97f78097b5 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -2129,8 +2129,8 @@ def test_reject_epochs(tmp_path): @testing.requires_testing_data -# @pytest.mark.parametrize("fname", (fname_raw_testing)) def test_callable_reject(): + """Test using a callable for rejection.""" raw = read_raw_fif(fname_raw_testing, preload=True) raw.crop(0, 5) raw.del_proj() @@ -2138,14 +2138,17 @@ def test_callable_reject(): raw.pick(chans) data = raw.get_data() - # Multipy 20 points of the first channel by 10 + # Add some artifacts new_data = data new_data[0, 180:200] *= 1e7 new_data[0, 610:880] += 1e-3 edit_raw = mne.io.RawArray(new_data, raw.info) - events = mne.make_fixed_length_events(edit_raw, id=1, duration=1.0, start=0) - epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, baseline=None, preload=True) + events = mne.make_fixed_length_events( + edit_raw, id=1, duration=1.0, start=0) + epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, + baseline=None, preload=True) + assert len(epochs) == 5 epochs = mne.Epochs( edit_raw, @@ -2153,9 +2156,8 @@ def test_callable_reject(): tmin=0, tmax=1, baseline=None, - reject=dict( - eeg=lambda x: True if (np.median(x, axis=1) > 1e-3).any() else False - ), + reject=dict(eeg=lambda x: True if ( + np.median(x, axis=1) > 1e-3).any() else False), preload=True, ) assert epochs.drop_log[2] != () @@ -2166,7 +2168,8 @@ def test_callable_reject(): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: True if (np.max(x, axis=1) > 1).any() else False), + reject=dict(eeg=lambda x: True if ( + np.max(x, axis=1) > 1).any() else False), preload=True, ) assert epochs.drop_log[0] != () From cd9f7b1d0a701bd5d4b434d3c38801bcb5a696da Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Nov 2023 16:59:57 +0000 Subject: [PATCH 10/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/tests/test_epochs.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index c97f78097b5..ddb51849673 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -2144,10 +2144,8 @@ def test_callable_reject(): new_data[0, 610:880] += 1e-3 edit_raw = mne.io.RawArray(new_data, raw.info) - events = mne.make_fixed_length_events( - edit_raw, id=1, duration=1.0, start=0) - epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, - baseline=None, preload=True) + events = mne.make_fixed_length_events(edit_raw, id=1, duration=1.0, start=0) + epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, baseline=None, preload=True) assert len(epochs) == 5 epochs = mne.Epochs( @@ -2156,8 +2154,9 @@ def test_callable_reject(): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: True if ( - np.median(x, axis=1) > 1e-3).any() else False), + reject=dict( + eeg=lambda x: True if (np.median(x, axis=1) > 1e-3).any() else False + ), preload=True, ) assert epochs.drop_log[2] != () @@ -2168,8 +2167,7 @@ def test_callable_reject(): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: True if ( - np.max(x, axis=1) > 1).any() else False), + reject=dict(eeg=lambda x: True if (np.max(x, axis=1) > 1).any() else False), preload=True, ) assert epochs.drop_log[0] != () From 3f5fd84d4c2c88609c3b8f0d87ec5e17c3581927 Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Fri, 1 Dec 2023 14:28:58 -0600 Subject: [PATCH 11/41] Update mne/epochs.py Co-authored-by: Eric Larson --- mne/epochs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/epochs.py b/mne/epochs.py index a0fcd46de5e..cb4f7d4eb47 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -3641,7 +3641,7 @@ def _is_good( ): """Test if data segment e is good according to reject and flat. - The reject and flat dictionaries can accept functions as values. + The reject and flat parameters can accept functions as values. If full_report=True, it will give True/False as well as a list of all offending channels. From a74ccf6f5fa99ffb21e33ba8b1af0b143a816a36 Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Fri, 1 Dec 2023 14:29:08 -0600 Subject: [PATCH 12/41] Update tutorials/preprocessing/20_rejecting_bad_data.py Co-authored-by: Eric Larson --- tutorials/preprocessing/20_rejecting_bad_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py index dde24ab38b6..da567b6e6ad 100644 --- a/tutorials/preprocessing/20_rejecting_bad_data.py +++ b/tutorials/preprocessing/20_rejecting_bad_data.py @@ -374,7 +374,7 @@ tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: True if (np.max(x, axis=1) > 1e-2).any() else False), + reject=dict(eeg=lambda x: (np.max(x, axis=1) > 1e-2).any()), preload=True, ) epochs.plot(scalings=dict(eeg=50e-5)) From f0cb1b8d499c6f9695a162f69729b3091c3cb6a1 Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Fri, 1 Dec 2023 14:29:22 -0600 Subject: [PATCH 13/41] Update tutorials/preprocessing/20_rejecting_bad_data.py Co-authored-by: Eric Larson --- tutorials/preprocessing/20_rejecting_bad_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py index da567b6e6ad..1690371b930 100644 --- a/tutorials/preprocessing/20_rejecting_bad_data.py +++ b/tutorials/preprocessing/20_rejecting_bad_data.py @@ -395,7 +395,7 @@ tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: True if (np.median(x, axis=1) > 1e-4).any() else False), + reject=dict(eeg=lambda x: (np.median(x, axis=1) > 1e-4).any()), preload=True, ) epochs.plot(scalings=dict(eeg=50e-5)) From 7d02fca98645cb14ffa3e728735522ec6e46d94d Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Fri, 1 Dec 2023 14:30:02 -0600 Subject: [PATCH 14/41] Update mne/utils/docs.py Co-authored-by: Eric Larson --- mne/utils/docs.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index a6e8f5db5b7..de2c21ad455 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -3815,13 +3815,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): eog=250e-6 # unit: V (EOG channels) ) - Custom rejection criteria can be also be used by passing a callable - to the dictionary. + Custom rejection criteria can be also be used by passing a callable, + e.g., to check for 99th percentile of absolute values of any channel + across time being bigger than 1mV:: - Example:: - - reject = dict(eeg=lambda x: True if (np.max(x, axis=1) > - 1e-3).any() else False)) + reject = dict(eeg=lambda x: (np.percentile(np.abs(x), 99, axis=1) > 1e-3).any()) .. note:: If rejection is based on a signal **difference** calculated for each channel separately, applying baseline From f3e88410f2c8621332704dde4b0bd7525ef38224 Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Wed, 6 Dec 2023 09:13:16 -0600 Subject: [PATCH 15/41] Make callable check more fine, doc, add noqa --- mne/epochs.py | 66 +++++++++++++++++++++-------------------------- mne/utils/docs.py | 2 +- 2 files changed, 31 insertions(+), 37 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index cb4f7d4eb47..4bb152dc683 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -839,48 +839,38 @@ def _reject_setup(self, reject, flat): "previous ones" ) - # Skip this check if old_reject, reject, old_flat, and flat are - # callables - is_callable = False - for rej in (reject, flat, old_reject, old_flat): - for key, val in rej.items(): - if callable(val): - is_callable = True - # copy thresholds for channel types that were used previously, but not # passed this time for key in set(old_reject) - set(reject): reject[key] = old_reject[key] - - if not is_callable: - # make sure new thresholds are at least as stringent - # as the old ones - for key in reject: - if key in old_reject and reject[key] > old_reject[key]: - raise ValueError( - bad_msg.format( - kind="reject", - key=key, - new=reject[key], - old=old_reject[key], - op=">", - ) + # make sure new thresholds are at least as stringent as the old ones + for key in reject: + # Skip this check if old_reject and reject are callables + if callable(reject[key]): + continue + if key in old_reject and reject[key] > old_reject[key]: + raise ValueError( + bad_msg.format( + kind="reject", + key=key, + new=reject[key], + old=old_reject[key], + op=">", ) + ) - # same for flat thresholds - for key in set(old_flat) - set(flat): - flat[key] = old_flat[key] - for key in flat: - if key in old_flat and flat[key] < old_flat[key]: - raise ValueError( - bad_msg.format( - kind="flat", - key=key, - new=flat[key], - old=old_flat[key], - op="<", - ) + # same for flat thresholds + for key in set(old_flat) - set(flat): + flat[key] = old_flat[key] + for key in flat: + if callable(flat[key]): + continue + if key in old_flat and flat[key] < old_flat[key]: + raise ValueError( + bad_msg.format( + kind="flat", key=key, new=flat[key], old=old_flat[key], op="<" ) + ) # after validation, set parameters self._bad_dropped = False @@ -1544,7 +1534,7 @@ def drop(self, indices, reason="USER", verbose=None): Set epochs to remove by specifying indices to remove or a boolean mask to apply (where True values get removed). Events are correspondingly modified. - reason : str + reason : list | tuple | str Reason for dropping the epochs ('ECG', 'timeout', 'blink' etc). Default: 'USER'. %(verbose)s @@ -3180,6 +3170,10 @@ class Epochs(BaseEpochs): See :meth:`~mne.Epochs.equalize_event_counts` - 'USER' For user-defined reasons (see :meth:`~mne.Epochs.drop`). + + When dropping based on flat or reject parameters the tuple of + reasons contains a tuple of channels that satisfied the rejection + criteria. filename : str The filename of the object. times : ndarray diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 1a9649226d0..207ba48d840 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -3828,7 +3828,7 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. note:: If ``reject`` is a callable, than **any** criteria can be used to reject epochs (including maxima and minima). -""" +""" # noqa: E501 docdict[ "reject_drop_bad" From fbe4cd256bf1fa46eb74813a376b204fbfd86524 Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Fri, 5 Jan 2024 14:14:52 -0600 Subject: [PATCH 16/41] Update epochs so that adding refl tuple doesnt cause error --- mne/epochs.py | 43 +++++++++++++++++++++++++++++++------------ mne/utils/mixin.py | 9 +++++++-- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index 61d31c73f62..a65781a13ee 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -795,10 +795,6 @@ def _reject_setup(self, reject, flat): reject = deepcopy(reject) if reject is not None else dict() flat = deepcopy(flat) if flat is not None else dict() for rej, kind in zip((reject, flat), ("reject", "flat")): - if not isinstance(rej, dict): - raise TypeError( - "reject and flat must be dict or None, not %s" % type(rej) - ) bads = set(rej.keys()) - set(idx.keys()) if len(bads) > 0: raise KeyError("Unknown channel types found in %s: %s" % (kind, bads)) @@ -818,16 +814,38 @@ def _reject_setup(self, reject, flat): # check for invalid values for rej, kind in zip((reject, flat), ("Rejection", "Flat")): + if not isinstance(rej, dict): + raise TypeError( + "reject and flat must be dict or None, not %s" % type(rej) + ) + + # Check if each reject/flat dict is a tuple that contains a + # callable function and a collection or string for key, val in rej.items(): - if callable(val): - continue - elif val is not None and val >= 0: - continue + if isinstance(val, (list, tuple)): + if callable(val[0]): + continue + elif val[0] is not None and val[0] >= 0: + continue + else: + raise ValueError( + "%s criteria must be a number >= 0 or a valid" + ' callable, not "%s"' % (kind, val) + ) + if isinstance(val[1], (list, tuple, str)): + continue + else: + raise ValueError( + "%s reason must be a collection or string, " + "not %s" % (kind, type(val[1])) + ) else: raise ValueError( - "%s value must be a number >= 0 or a valid function," - 'not "%s"' % (kind, val) - ) + """The dictionary elements in %s must be in the + form of a collection that contains a callable or value + in the first element and a collection or string + in the second element""" % rej + ) # now check to see if our rejection and flat are getting more # restrictive @@ -3647,7 +3665,8 @@ def _is_good( for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]): if refl is not None: - for key, criterion in refl.items(): + for key, refl in refl.items(): + criterion = refl[0] idx = channel_type_idx[key] name = key.upper() if len(idx) > 0: diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index c90121fdfbb..634086480b5 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -209,8 +209,13 @@ def _getitem( key_selection = inst.selection[select] drop_log = list(inst.drop_log) if reason is not None: - for k in np.setdiff1d(inst.selection, key_selection): - drop_log[k] = (reason,) + # Used for multiple reasons + if isinstance(reason, (list, tuple)): + for i, idx in enumerate(np.setdiff1d(inst.selection, key_selection)): + drop_log[idx] = reason[i] + else: + for idx in np.setdiff1d(inst.selection, key_selection): + drop_log[idx] = reason inst.drop_log = tuple(drop_log) inst.selection = key_selection del drop_log From 24e669c0b5b5982a35ec32000a4398b0289f04e6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Jan 2024 20:15:26 +0000 Subject: [PATCH 17/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/epochs.py | 7 ++++--- mne/utils/mixin.py | 4 +++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index d0c7c3a58d5..bb61e0cc62c 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -839,11 +839,12 @@ def _reject_setup(self, reject, flat): ) else: raise ValueError( - """The dictionary elements in %s must be in the + """The dictionary elements in %s must be in the form of a collection that contains a callable or value in the first element and a collection or string - in the second element""" % rej - ) + in the second element""" + % rej + ) # now check to see if our rejection and flat are getting more # restrictive diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index 634086480b5..8d7320fbb8d 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -211,7 +211,9 @@ def _getitem( if reason is not None: # Used for multiple reasons if isinstance(reason, (list, tuple)): - for i, idx in enumerate(np.setdiff1d(inst.selection, key_selection)): + for i, idx in enumerate( + np.setdiff1d(inst.selection, key_selection) + ): drop_log[idx] = reason[i] else: for idx in np.setdiff1d(inst.selection, key_selection): From 8401c92b3d006c86a063e25f47f59df75d552a82 Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Tue, 9 Jan 2024 12:58:04 -0600 Subject: [PATCH 18/41] return callable/reasons --- mne/epochs.py | 106 +++++++++++++++++++++++---------------- mne/tests/test_epochs.py | 47 ++++++++++++++--- mne/utils/mixin.py | 14 ++++-- 3 files changed, 113 insertions(+), 54 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index bb61e0cc62c..fdba9d23cc6 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -700,8 +700,9 @@ def _check_consistency(self): assert hasattr(self, "_times_readonly") assert not self.times.flags["WRITEABLE"] assert isinstance(self.drop_log, tuple) + print("self.drop_log", self.drop_log) assert all(isinstance(log, tuple) for log in self.drop_log) - assert all(isinstance(s, str) for log in self.drop_log for s in log) + assert all(isinstance(s, (str,tuple)) for log in self.drop_log for s in log) def reset_drop_log_selection(self): """Reset the drop_log and selection entries. @@ -793,9 +794,14 @@ def _reject_setup(self, reject, flat): reject = deepcopy(reject) if reject is not None else dict() flat = deepcopy(flat) if flat is not None else dict() for rej, kind in zip((reject, flat), ("reject", "flat")): + if not isinstance(rej, dict): + raise TypeError( + "reject and flat must be dict or None, not %s" % type(rej) + ) bads = set(rej.keys()) - set(idx.keys()) if len(bads) > 0: - raise KeyError("Unknown channel types found in %s: %s" % (kind, bads)) + raise KeyError( + "Unknown channel types found in %s: %s" % (kind, bads)) for key in idx.keys(): # don't throw an error if rejection/flat would do nothing @@ -811,40 +817,30 @@ def _reject_setup(self, reject, flat): ) # check for invalid values - for rej, kind in zip((reject, flat), ("Rejection", "Flat")): - if not isinstance(rej, dict): - raise TypeError( - "reject and flat must be dict or None, not %s" % type(rej) - ) - - # Check if each reject/flat dict is a tuple that contains a - # callable function and a collection or string - for key, val in rej.items(): - if isinstance(val, (list, tuple)): - if callable(val[0]): - continue - elif val[0] is not None and val[0] >= 0: - continue - else: - raise ValueError( - "%s criteria must be a number >= 0 or a valid" - ' callable, not "%s"' % (kind, val) + for rej, kind in zip((reject, flat), ("Rejection", "Flat")): + for key, val in rej.items(): + name = f"{kind} dict value for {key}" + if isinstance(val, (list, tuple)): + _validate_type( + val[0], ("numeric", "callable"), + val[0], "float, int, or callable" ) - if isinstance(val[1], (list, tuple, str)): + if ( + isinstance(val[0], (int, float)) and + (val[0] is None or val[0] < 0) + ): + raise ValueError( + """If using numerical %s criteria, the value + must be >= 0 Not '%s'.""" % (kind, val[0]) + ) + _validate_type(val[1], ("str", "array-like"), val[1]) continue - else: + _validate_type(val, "numeric", name, extra="or callable") + if val is None or val < 0: raise ValueError( - "%s reason must be a collection or string, " - "not %s" % (kind, type(val[1])) + """If using numerical %s criteria, the value + must be >= 0 Not '%s'.""" % (kind, val) ) - else: - raise ValueError( - """The dictionary elements in %s must be in the - form of a collection that contains a callable or value - in the first element and a collection or string - in the second element""" - % rej - ) # now check to see if our rejection and flat are getting more # restrictive @@ -1565,6 +1561,16 @@ def drop(self, indices, reason="USER", verbose=None): if indices.ndim > 1: raise ValueError("indices must be a scalar or a 1-d array") + # Check if indices and reasons are of the same length + # if using collection to drop epochs + if (isinstance(reason, (list, tuple))): + if len(indices) != len(reason): + raise ValueError( + "If using a list or tuple as the reason, " + "indices and reasons must be of the same length, got " + f"{len(indices)} and {len(reason)}" + ) + if indices.dtype == bool: indices = np.where(indices)[0] @@ -1767,7 +1773,7 @@ def _get_data( is_good, bad_tuple = self._is_good_epoch(epoch, verbose=verbose) if not is_good: assert isinstance(bad_tuple, tuple) - assert all(isinstance(x, str) for x in bad_tuple) + assert all(isinstance(x, (str, tuple)) for x in bad_tuple) drop_log[sel] = drop_log[sel] + bad_tuple continue good_idx.append(idx) @@ -3715,7 +3721,10 @@ def _is_good( for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]): if refl is not None: for key, refl in refl.items(): - criterion = refl[0] + if isinstance(refl, (tuple, list)): + criterion = refl[0] + else: + criterion = refl idx = channel_type_idx[key] name = key.upper() if len(idx) > 0: @@ -3734,17 +3743,26 @@ def _is_good( )[0] if len(idx_deltas) > 0: - bad_names = [ch_names[idx[i]] for i in idx_deltas] - if not has_printed: - logger.info( - " Rejecting %s epoch based on %s : " - "%s" % (t, name, bad_names) - ) - has_printed = True - if not full_report: - return False + if isinstance(refl, (tuple, list)): + reasons = list(refl[1]) + for idx, reason in enumerate(reasons): + if isinstance(reason, str): + reasons[idx] = (reason,) + if isinstance(reason, list): + reasons[idx] = tuple(reason) + bad_tuple += tuple(reasons) else: - bad_tuple += tuple(bad_names) + bad_names = [ch_names[idx[i]] for i in idx_deltas] + if not has_printed: + logger.info( + " Rejecting %s epoch based on %s : " + "%s" % (t, name, bad_names) + ) + has_printed = True + if not full_report: + return False + else: + bad_tuple += tuple(bad_names) if not full_report: return True diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 9f9d24d31a7..8d9c13afdd6 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -488,10 +488,11 @@ def test_average_movements(): def _assert_drop_log_types(drop_log): __tracebackhide__ = True - assert isinstance(drop_log, tuple), "drop_log should be tuple" + assert isinstance(drop_log, (tuple, list)), """drop_log should be tuple + or list""" assert all( - isinstance(log, tuple) for log in drop_log - ), "drop_log[ii] should be tuple" + isinstance(log, (tuple, list)) for log in drop_log + ), "drop_log[ii] should be tuple or list" assert all( isinstance(s, str) for log in drop_log for s in log ), "drop_log[ii][jj] should be str" @@ -549,7 +550,7 @@ def test_reject(): preload=False, reject=dict(eeg=np.inf), ) - for val in (None, -1): # protect against older MNE-C types + for val in (-1, (-1, 'Hi')): # protect against older MNE-C types for kwarg in ("reject", "flat"): pytest.raises( ValueError, @@ -563,6 +564,21 @@ def test_reject(): preload=False, **{kwarg: dict(grad=val)}, ) + bad_types = ['Hi', ('Hi' 'Hi'), (1, 1)] + for val in bad_types: # protect against bad types + for kwarg in ("reject", "flat"): + pytest.raises( + TypeError, + Epochs, + raw, + events, + event_id, + tmin, + tmax, + picks=picks_meg, + preload=False, + **{kwarg: dict(grad=val)}, + ) pytest.raises( KeyError, Epochs, @@ -2175,7 +2191,7 @@ def test_callable_reject(): tmax=1, baseline=None, reject=dict( - eeg=lambda x: True if (np.median(x, axis=1) > 1e-3).any() else False + eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median") ), preload=True, ) @@ -2187,7 +2203,7 @@ def test_callable_reject(): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: True if (np.max(x, axis=1) > 1).any() else False), + reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), "eeg max")), preload=True, ) assert epochs.drop_log[0] != () @@ -2195,7 +2211,7 @@ def test_callable_reject(): def reject_criteria(x): max_condition = np.max(x, axis=1) > 1e-2 median_condition = np.median(x, axis=1) > 1e-4 - return True if max_condition.any() or median_condition.any() else False + (max_condition.any() or median_condition.any(), "eeg max or median") epochs = mne.Epochs( edit_raw, @@ -2206,6 +2222,7 @@ def reject_criteria(x): reject=dict(eeg=reject_criteria), preload=True, ) + print(epochs.drop_log) assert epochs.drop_log[0] != () and epochs.drop_log[2] != () @@ -3262,6 +3279,22 @@ def test_drop_epochs(): assert_array_equal(events[epochs[3:].selection], events1[[5, 6]]) assert_array_equal(events[epochs["1"].selection], events1[[0, 1, 3, 5, 6]]) + # Test using tuple to drop epochs + raw, events, picks = _get_data() + epochs_tuple = Epochs( + raw, events, event_id, + tmin, tmax, picks=picks, preload=True + ) + selection_tuple = epochs_tuple.selection.copy() + epochs_tuple.drop((2, 3, 4), reason=([['list'], 'string', ('tuple',)])) + n_events = len(epochs.events) + assert_equal( + [epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]], [ + ['list'], ["string"], ['tuple']] + ) + + + @pytest.mark.parametrize("preload", (True, False)) def test_drop_epochs_mult(preload): diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index 8d7320fbb8d..c7cfa06297e 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -210,14 +210,22 @@ def _getitem( drop_log = list(inst.drop_log) if reason is not None: # Used for multiple reasons - if isinstance(reason, (list, tuple)): + if isinstance(reason, tuple): + reason = list(reason) + + if isinstance(reason, list): for i, idx in enumerate( np.setdiff1d(inst.selection, key_selection) ): - drop_log[idx] = reason[i] + r = reason[i] + if isinstance(r, str): + r = (r,) + if isinstance(r, list): + r = tuple(r) + drop_log[idx] = r else: for idx in np.setdiff1d(inst.selection, key_selection): - drop_log[idx] = reason + drop_log[idx] = (reason,) inst.drop_log = tuple(drop_log) inst.selection = key_selection del drop_log From cf1facf3634054292186a9781561985b38cdf6e0 Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Tue, 9 Jan 2024 15:41:07 -0600 Subject: [PATCH 19/41] allow callables --- mne/_version.py | 16 +++++ mne/epochs.py | 58 +++++++++---------- mne/tests/test_epochs.py | 21 +++++-- mne/utils/docs.py | 5 +- .../preprocessing/20_rejecting_bad_data.py | 15 +++-- 5 files changed, 75 insertions(+), 40 deletions(-) create mode 100644 mne/_version.py diff --git a/mne/_version.py b/mne/_version.py new file mode 100644 index 00000000000..c741fa16728 --- /dev/null +++ b/mne/_version.py @@ -0,0 +1,16 @@ +# file generated by setuptools_scm +# don't change, don't track in version control +TYPE_CHECKING = False +if TYPE_CHECKING: + from typing import Tuple, Union + VERSION_TUPLE = Tuple[Union[int, str], ...] +else: + VERSION_TUPLE = object + +version: str +__version__: str +__version_tuple__: VERSION_TUPLE +version_tuple: VERSION_TUPLE + +__version__ = version = '1.6.0.dev139+gfdaeb8620' +__version_tuple__ = version_tuple = (1, 6, 0, 'dev139', 'gfdaeb8620') diff --git a/mne/epochs.py b/mne/epochs.py index fdba9d23cc6..e60e235f5cd 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -700,9 +700,8 @@ def _check_consistency(self): assert hasattr(self, "_times_readonly") assert not self.times.flags["WRITEABLE"] assert isinstance(self.drop_log, tuple) - print("self.drop_log", self.drop_log) assert all(isinstance(log, tuple) for log in self.drop_log) - assert all(isinstance(s, (str,tuple)) for log in self.drop_log for s in log) + assert all(isinstance(s, str) for log in self.drop_log for s in log) def reset_drop_log_selection(self): """Reset the drop_log and selection entries. @@ -820,20 +819,7 @@ def _reject_setup(self, reject, flat): for rej, kind in zip((reject, flat), ("Rejection", "Flat")): for key, val in rej.items(): name = f"{kind} dict value for {key}" - if isinstance(val, (list, tuple)): - _validate_type( - val[0], ("numeric", "callable"), - val[0], "float, int, or callable" - ) - if ( - isinstance(val[0], (int, float)) and - (val[0] is None or val[0] < 0) - ): - raise ValueError( - """If using numerical %s criteria, the value - must be >= 0 Not '%s'.""" % (kind, val[0]) - ) - _validate_type(val[1], ("str", "array-like"), val[1]) + if callable(val): continue _validate_type(val, "numeric", name, extra="or callable") if val is None or val < 0: @@ -3721,20 +3707,30 @@ def _is_good( for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]): if refl is not None: for key, refl in refl.items(): - if isinstance(refl, (tuple, list)): - criterion = refl[0] - else: - criterion = refl + criterion = refl idx = channel_type_idx[key] name = key.upper() if len(idx) > 0: e_idx = e[idx] checkable_idx = checkable[idx] - # Check if criterion is a function and apply it if callable(criterion): + result = criterion(e_idx) + _validate_type(result, tuple, result, "tuple") + if len(result) != 2: + raise TypeError( + "Function criterion must return a " + "tuple of length 2" + ) + cri_truth, reasons = result + _validate_type(cri_truth, (bool, np.bool_), + cri_truth, "bool") + _validate_type( + reasons, (str, list, tuple), + reasons, "str, list, or tuple" + ) idx_deltas = np.where( - np.logical_and(criterion(e_idx), checkable_idx) + np.logical_and(cri_truth, checkable_idx) )[0] else: deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1) @@ -3743,14 +3739,16 @@ def _is_good( )[0] if len(idx_deltas) > 0: - if isinstance(refl, (tuple, list)): - reasons = list(refl[1]) - for idx, reason in enumerate(reasons): - if isinstance(reason, str): - reasons[idx] = (reason,) - if isinstance(reason, list): - reasons[idx] = tuple(reason) - bad_tuple += tuple(reasons) + # Check to verify that refl is a callable that returns + # (bool, reason). Reason must be a str/list/tuple. + # If using tuple + if callable(refl): + if isinstance(reasons, (tuple, list)): + for idx, reason in enumerate(reasons): + _validate_type(reason, str, reason, "str") + bad_tuple += tuple(reasons) + if isinstance(reasons, str): + bad_tuple += (reasons,) else: bad_names = [ch_names[idx[i]] for i in idx_deltas] if not has_printed: diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 8d9c13afdd6..f71eabda717 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -550,7 +550,7 @@ def test_reject(): preload=False, reject=dict(eeg=np.inf), ) - for val in (-1, (-1, 'Hi')): # protect against older MNE-C types + for val in (-1, -2): # protect against older MNE-C types for kwarg in ("reject", "flat"): pytest.raises( ValueError, @@ -564,7 +564,17 @@ def test_reject(): preload=False, **{kwarg: dict(grad=val)}, ) - bad_types = ['Hi', ('Hi' 'Hi'), (1, 1)] + + def my_reject_1(epoch_data): + bad_idxs = np.where(np.percentile(epoch_data, axis=1) > 1e-35) + return len(bad_idxs) > 0 + + def my_reject_2(epoch_data): + bad_idxs = np.where(np.percentile(epoch_data, axis=1) > 1e-35) + reasons = tuple(epochs.ch_name[bad_idx] for bad_idx in bad_idxs) + return len(bad_idxs), reasons + + bad_types = [my_reject_1, my_reject_2, ('Hi' 'Hi'), (1, 1)] for val in bad_types: # protect against bad types for kwarg in ("reject", "flat"): pytest.raises( @@ -576,7 +586,7 @@ def test_reject(): tmin, tmax, picks=picks_meg, - preload=False, + preload=True, **{kwarg: dict(grad=val)}, ) pytest.raises( @@ -2211,7 +2221,10 @@ def test_callable_reject(): def reject_criteria(x): max_condition = np.max(x, axis=1) > 1e-2 median_condition = np.median(x, axis=1) > 1e-4 - (max_condition.any() or median_condition.any(), "eeg max or median") + return ( + (max_condition.any() or median_condition.any()), + "eeg max or median" + ) epochs = mne.Epochs( edit_raw, diff --git a/mne/utils/docs.py b/mne/utils/docs.py index f64dbf4e66e..79e6bb945e2 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -3291,9 +3291,10 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): Custom rejection criteria can be also be used by passing a callable, e.g., to check for 99th percentile of absolute values of any channel - across time being bigger than 1mV:: + across time being bigger than 1mV. The callable must return a good, reason tuple. + Where good must be bool and reason must be str, list, or tuple where each entry is a str.:: - reject = dict(eeg=lambda x: (np.percentile(np.abs(x), 99, axis=1) > 1e-3).any()) + reject = dict(eeg=lambda x: ((np.percentile(np.abs(x), 99, axis=1) > 1e-3).any(), "> 1mV somewhere")) .. note:: If rejection is based on a signal **difference** calculated for each channel separately, applying baseline diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py index 1690371b930..27f1085f713 100644 --- a/tutorials/preprocessing/20_rejecting_bad_data.py +++ b/tutorials/preprocessing/20_rejecting_bad_data.py @@ -366,7 +366,9 @@ # spike in amplitude and one large increase in amplitude. # Let's try to reject the epoch containing the spike in amplitude based on the -# maximum amplitude of the first channel. +# maximum amplitude of the first channel. Please note that the callable in +# ``reject`` must return a (good, reason) tuple. Where the good must be bool +# and reason must be a str, list, or tuple where each entry is a str. epochs = mne.Epochs( edit_raw, @@ -374,7 +376,7 @@ tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: (np.max(x, axis=1) > 1e-2).any()), + reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1e-2).any(), "max amp")), preload=True, ) epochs.plot(scalings=dict(eeg=50e-5)) @@ -395,7 +397,9 @@ tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: (np.median(x, axis=1) > 1e-4).any()), + reject=dict( + eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp") + ), preload=True, ) epochs.plot(scalings=dict(eeg=50e-5)) @@ -409,7 +413,10 @@ def reject_criteria(x): max_condition = np.max(x, axis=1) > 1e-2 median_condition = np.median(x, axis=1) > 1e-4 - return True if max_condition.any() or median_condition.any() else False + return ( + (max_condition.any() or median_condition.any()), + ["max amp", "median amp"] + ) epochs = mne.Epochs( From e98bee2045644becd2e9dc8c1ff3be9877e58a86 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Jan 2024 21:41:37 +0000 Subject: [PATCH 20/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/_version.py | 5 ++-- mne/epochs.py | 28 ++++++++----------- mne/tests/test_epochs.py | 24 +++++----------- .../preprocessing/20_rejecting_bad_data.py | 9 ++---- 4 files changed, 24 insertions(+), 42 deletions(-) diff --git a/mne/_version.py b/mne/_version.py index c741fa16728..d402c528c09 100644 --- a/mne/_version.py +++ b/mne/_version.py @@ -3,6 +3,7 @@ TYPE_CHECKING = False if TYPE_CHECKING: from typing import Tuple, Union + VERSION_TUPLE = Tuple[Union[int, str], ...] else: VERSION_TUPLE = object @@ -12,5 +13,5 @@ __version_tuple__: VERSION_TUPLE version_tuple: VERSION_TUPLE -__version__ = version = '1.6.0.dev139+gfdaeb8620' -__version_tuple__ = version_tuple = (1, 6, 0, 'dev139', 'gfdaeb8620') +__version__ = version = "1.6.0.dev139+gfdaeb8620" +__version_tuple__ = version_tuple = (1, 6, 0, "dev139", "gfdaeb8620") diff --git a/mne/epochs.py b/mne/epochs.py index e60e235f5cd..9cdb53cd753 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -799,8 +799,7 @@ def _reject_setup(self, reject, flat): ) bads = set(rej.keys()) - set(idx.keys()) if len(bads) > 0: - raise KeyError( - "Unknown channel types found in %s: %s" % (kind, bads)) + raise KeyError("Unknown channel types found in %s: %s" % (kind, bads)) for key in idx.keys(): # don't throw an error if rejection/flat would do nothing @@ -815,7 +814,7 @@ def _reject_setup(self, reject, flat): "%s." % (key.upper(), key.upper()) ) - # check for invalid values + # check for invalid values for rej, kind in zip((reject, flat), ("Rejection", "Flat")): for key, val in rej.items(): name = f"{kind} dict value for {key}" @@ -825,7 +824,8 @@ def _reject_setup(self, reject, flat): if val is None or val < 0: raise ValueError( """If using numerical %s criteria, the value - must be >= 0 Not '%s'.""" % (kind, val) + must be >= 0 Not '%s'.""" + % (kind, val) ) # now check to see if our rejection and flat are getting more @@ -1547,9 +1547,9 @@ def drop(self, indices, reason="USER", verbose=None): if indices.ndim > 1: raise ValueError("indices must be a scalar or a 1-d array") - # Check if indices and reasons are of the same length + # Check if indices and reasons are of the same length # if using collection to drop epochs - if (isinstance(reason, (list, tuple))): + if isinstance(reason, (list, tuple)): if len(indices) != len(reason): raise ValueError( "If using a list or tuple as the reason, " @@ -1557,7 +1557,6 @@ def drop(self, indices, reason="USER", verbose=None): f"{len(indices)} and {len(reason)}" ) - if indices.dtype == bool: indices = np.where(indices)[0] try_idx = np.where(indices < 0, indices + len(self.events), indices) @@ -3719,19 +3718,16 @@ def _is_good( _validate_type(result, tuple, result, "tuple") if len(result) != 2: raise TypeError( - "Function criterion must return a " - "tuple of length 2" + "Function criterion must return a " "tuple of length 2" ) cri_truth, reasons = result - _validate_type(cri_truth, (bool, np.bool_), - cri_truth, "bool") + _validate_type(cri_truth, (bool, np.bool_), cri_truth, "bool") _validate_type( - reasons, (str, list, tuple), - reasons, "str, list, or tuple" + reasons, (str, list, tuple), reasons, "str, list, or tuple" ) - idx_deltas = np.where( - np.logical_and(cri_truth, checkable_idx) - )[0] + idx_deltas = np.where(np.logical_and(cri_truth, checkable_idx))[ + 0 + ] else: deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1) idx_deltas = np.where( diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index f71eabda717..7f54ab12fe7 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -574,7 +574,7 @@ def my_reject_2(epoch_data): reasons = tuple(epochs.ch_name[bad_idx] for bad_idx in bad_idxs) return len(bad_idxs), reasons - bad_types = [my_reject_1, my_reject_2, ('Hi' 'Hi'), (1, 1)] + bad_types = [my_reject_1, my_reject_2, ("Hi" "Hi"), (1, 1)] for val in bad_types: # protect against bad types for kwarg in ("reject", "flat"): pytest.raises( @@ -2200,9 +2200,7 @@ def test_callable_reject(): tmin=0, tmax=1, baseline=None, - reject=dict( - eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median") - ), + reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median")), preload=True, ) assert epochs.drop_log[2] != () @@ -2221,10 +2219,7 @@ def test_callable_reject(): def reject_criteria(x): max_condition = np.max(x, axis=1) > 1e-2 median_condition = np.median(x, axis=1) > 1e-4 - return ( - (max_condition.any() or median_condition.any()), - "eeg max or median" - ) + return ((max_condition.any() or median_condition.any()), "eeg max or median") epochs = mne.Epochs( edit_raw, @@ -3294,21 +3289,16 @@ def test_drop_epochs(): # Test using tuple to drop epochs raw, events, picks = _get_data() - epochs_tuple = Epochs( - raw, events, event_id, - tmin, tmax, picks=picks, preload=True - ) + epochs_tuple = Epochs(raw, events, event_id, tmin, tmax, picks=picks, preload=True) selection_tuple = epochs_tuple.selection.copy() - epochs_tuple.drop((2, 3, 4), reason=([['list'], 'string', ('tuple',)])) + epochs_tuple.drop((2, 3, 4), reason=([["list"], "string", ("tuple",)])) n_events = len(epochs.events) assert_equal( - [epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]], [ - ['list'], ["string"], ['tuple']] + [epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]], + [["list"], ["string"], ["tuple"]], ) - - @pytest.mark.parametrize("preload", (True, False)) def test_drop_epochs_mult(preload): """Test that subselecting epochs or making fewer epochs is similar.""" diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py index 27f1085f713..51f8fa012f8 100644 --- a/tutorials/preprocessing/20_rejecting_bad_data.py +++ b/tutorials/preprocessing/20_rejecting_bad_data.py @@ -397,9 +397,7 @@ tmin=0, tmax=1, baseline=None, - reject=dict( - eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp") - ), + reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp")), preload=True, ) epochs.plot(scalings=dict(eeg=50e-5)) @@ -413,10 +411,7 @@ def reject_criteria(x): max_condition = np.max(x, axis=1) > 1e-2 median_condition = np.median(x, axis=1) > 1e-4 - return ( - (max_condition.any() or median_condition.any()), - ["max amp", "median amp"] - ) + return ((max_condition.any() or median_condition.any()), ["max amp", "median amp"]) epochs = mne.Epochs( From a579491a5e88c8362a2455628ee7690b7988e28b Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Tue, 9 Jan 2024 15:46:14 -0600 Subject: [PATCH 21/41] Delete mne/_version.py Not supposed to be tracked --- mne/_version.py | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 mne/_version.py diff --git a/mne/_version.py b/mne/_version.py deleted file mode 100644 index d402c528c09..00000000000 --- a/mne/_version.py +++ /dev/null @@ -1,17 +0,0 @@ -# file generated by setuptools_scm -# don't change, don't track in version control -TYPE_CHECKING = False -if TYPE_CHECKING: - from typing import Tuple, Union - - VERSION_TUPLE = Tuple[Union[int, str], ...] -else: - VERSION_TUPLE = object - -version: str -__version__: str -__version_tuple__: VERSION_TUPLE -version_tuple: VERSION_TUPLE - -__version__ = version = "1.6.0.dev139+gfdaeb8620" -__version_tuple__ = version_tuple = (1, 6, 0, "dev139", "gfdaeb8620") From c9da4db5ce88942328bb23914a1a5667aa4f8a5b Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Tue, 9 Jan 2024 15:51:38 -0600 Subject: [PATCH 22/41] Add None Check --- mne/tests/test_epochs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 7f54ab12fe7..6774fff9880 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -574,7 +574,7 @@ def my_reject_2(epoch_data): reasons = tuple(epochs.ch_name[bad_idx] for bad_idx in bad_idxs) return len(bad_idxs), reasons - bad_types = [my_reject_1, my_reject_2, ("Hi" "Hi"), (1, 1)] + bad_types = [my_reject_1, my_reject_2, ("Hi" "Hi"), (1, 1), None] for val in bad_types: # protect against bad types for kwarg in ("reject", "flat"): pytest.raises( From 5a7a6182fe76e137c543c4398c893cd00261ce5c Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Wed, 10 Jan 2024 12:18:31 -0600 Subject: [PATCH 23/41] Update mne/tests/test_epochs.py Co-authored-by: Eric Larson --- mne/tests/test_epochs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 6774fff9880..96d6b4231e0 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -2219,7 +2219,7 @@ def test_callable_reject(): def reject_criteria(x): max_condition = np.max(x, axis=1) > 1e-2 median_condition = np.median(x, axis=1) > 1e-4 - return ((max_condition.any() or median_condition.any()), "eeg max or median") + return (max_condition.any() or median_condition.any()), "eeg max or median" epochs = mne.Epochs( edit_raw, From 98b92c4624f352c4c550a4a3db134fa47092b090 Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Wed, 10 Jan 2024 12:18:50 -0600 Subject: [PATCH 24/41] Update mne/utils/mixin.py Co-authored-by: Eric Larson --- mne/utils/mixin.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index c7cfa06297e..bf2b9f6b7a1 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -209,23 +209,14 @@ def _getitem( key_selection = inst.selection[select] drop_log = list(inst.drop_log) if reason is not None: - # Used for multiple reasons - if isinstance(reason, tuple): - reason = list(reason) - - if isinstance(reason, list): - for i, idx in enumerate( - np.setdiff1d(inst.selection, key_selection) - ): - r = reason[i] - if isinstance(r, str): - r = (r,) - if isinstance(r, list): - r = tuple(r) - drop_log[idx] = r - else: - for idx in np.setdiff1d(inst.selection, key_selection): - drop_log[idx] = (reason,) + _validate_type(reason, (list, tuple, str), "reason") + if isinstance(reason, str): + reason = (reason,) + reason = tuple(reason) + for i, idx in enumerate( + np.setdiff1d(inst.selection, key_selection) + ): + drop_log[idx] = reason inst.drop_log = tuple(drop_log) inst.selection = key_selection del drop_log From 44fdf8a438ee4ced6efe2ba5977d95b32a3c2419 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jan 2024 18:19:25 +0000 Subject: [PATCH 25/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/utils/mixin.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index bf2b9f6b7a1..e62b2a57a6a 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -213,9 +213,7 @@ def _getitem( if isinstance(reason, str): reason = (reason,) reason = tuple(reason) - for i, idx in enumerate( - np.setdiff1d(inst.selection, key_selection) - ): + for i, idx in enumerate(np.setdiff1d(inst.selection, key_selection)): drop_log[idx] = reason inst.drop_log = tuple(drop_log) inst.selection = key_selection From 3ff8a9ee739f0ca5f4b35431219ef6e338078df0 Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Wed, 10 Jan 2024 12:21:30 -0600 Subject: [PATCH 26/41] Update mne/epochs.py Co-authored-by: Eric Larson --- mne/epochs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/epochs.py b/mne/epochs.py index 9cdb53cd753..dedf73aa870 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -3718,7 +3718,7 @@ def _is_good( _validate_type(result, tuple, result, "tuple") if len(result) != 2: raise TypeError( - "Function criterion must return a " "tuple of length 2" + "Function criterion must return a tuple of length 2" ) cri_truth, reasons = result _validate_type(cri_truth, (bool, np.bool_), cri_truth, "bool") From e729e81145bbe849fa437efea161fa20149763fe Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Wed, 10 Jan 2024 12:21:51 -0600 Subject: [PATCH 27/41] Update mne/tests/test_epochs.py Co-authored-by: Eric Larson --- mne/tests/test_epochs.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 96d6b4231e0..2919dfea5f8 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -3293,10 +3293,7 @@ def test_drop_epochs(): selection_tuple = epochs_tuple.selection.copy() epochs_tuple.drop((2, 3, 4), reason=([["list"], "string", ("tuple",)])) n_events = len(epochs.events) - assert_equal( - [epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]], - [["list"], ["string"], ["tuple"]], - ) + assert [epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]] == [("list",), ("string",), ("tuple",)] @pytest.mark.parametrize("preload", (True, False)) From fd4c75f2ada0aa7b9a64a9d03e63162f0c8e44ac Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Wed, 10 Jan 2024 12:22:06 -0600 Subject: [PATCH 28/41] Update mne/epochs.py Co-authored-by: Eric Larson --- mne/epochs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index dedf73aa870..0c87de09f7b 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -823,9 +823,8 @@ def _reject_setup(self, reject, flat): _validate_type(val, "numeric", name, extra="or callable") if val is None or val < 0: raise ValueError( - """If using numerical %s criteria, the value - must be >= 0 Not '%s'.""" - % (kind, val) + f"If using numerical {name} criteria, the value " + f"must be >= 0, not {repr(val)}" ) # now check to see if our rejection and flat are getting more From a685b89fce768bb82bc9c15a2b69eabcc6c24deb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jan 2024 18:22:20 +0000 Subject: [PATCH 29/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/tests/test_epochs.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 2919dfea5f8..c8b4280a428 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -3293,7 +3293,11 @@ def test_drop_epochs(): selection_tuple = epochs_tuple.selection.copy() epochs_tuple.drop((2, 3, 4), reason=([["list"], "string", ("tuple",)])) n_events = len(epochs.events) - assert [epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]] == [("list",), ("string",), ("tuple",)] + assert [epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]] == [ + ("list",), + ("string",), + ("tuple",), + ] @pytest.mark.parametrize("preload", (True, False)) From 65778d1b384691354b7d3049a3de84e83368bb2d Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Wed, 10 Jan 2024 12:26:31 -0600 Subject: [PATCH 30/41] Update mne/epochs.py Co-authored-by: Eric Larson --- mne/epochs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne/epochs.py b/mne/epochs.py index 0c87de09f7b..045cd7d32fe 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1757,7 +1757,7 @@ def _get_data( is_good, bad_tuple = self._is_good_epoch(epoch, verbose=verbose) if not is_good: assert isinstance(bad_tuple, tuple) - assert all(isinstance(x, (str, tuple)) for x in bad_tuple) + assert all(isinstance(x, str)) for x in bad_tuple) drop_log[sel] = drop_log[sel] + bad_tuple continue good_idx.append(idx) From 3ece3145619ef834e87f6e7e3354db4089c4f090 Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Wed, 10 Jan 2024 12:27:30 -0600 Subject: [PATCH 31/41] Apply suggestions from code review Co-authored-by: Eric Larson --- mne/epochs.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index 045cd7d32fe..43a48daa38f 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1533,7 +1533,7 @@ def drop(self, indices, reason="USER", verbose=None): mask to apply (where True values get removed). Events are correspondingly modified. reason : list | tuple | str - Reason for dropping the epochs ('ECG', 'timeout', 'blink' etc). + Reason(s) for dropping the epochs ('ECG', 'timeout', 'blink' etc). Default: 'USER'. %(verbose)s @@ -3714,7 +3714,7 @@ def _is_good( # Check if criterion is a function and apply it if callable(criterion): result = criterion(e_idx) - _validate_type(result, tuple, result, "tuple") + _validate_type(result, tuple, "reject/flat output") if len(result) != 2: raise TypeError( "Function criterion must return a tuple of length 2" @@ -3738,12 +3738,11 @@ def _is_good( # (bool, reason). Reason must be a str/list/tuple. # If using tuple if callable(refl): - if isinstance(reasons, (tuple, list)): - for idx, reason in enumerate(reasons): - _validate_type(reason, str, reason, "str") - bad_tuple += tuple(reasons) if isinstance(reasons, str): - bad_tuple += (reasons,) + reasons = (reasons,) + for idx, reason in enumerate(reasons): + _validate_type(reason, str, f"reasons[{idx}]") + bad_tuple += tuple(reasons) else: bad_names = [ch_names[idx[i]] for i in idx_deltas] if not has_printed: From b2686c0a7fee9691886843a2ca137309a41be62c Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Mon, 15 Jan 2024 21:20:25 -0600 Subject: [PATCH 32/41] Apply reason to all dropped epochs --- mne/epochs.py | 14 ++++---------- mne/tests/test_epochs.py | 38 +++++++++++++++++++++++++------------- mne/utils/mixin.py | 5 ++++- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index 2be67869112..fdf0ef3aa12 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1534,6 +1534,7 @@ def drop(self, indices, reason="USER", verbose=None): correspondingly modified. reason : list | tuple | str Reason(s) for dropping the epochs ('ECG', 'timeout', 'blink' etc). + Reason(s) are applied to all indices specified. Default: 'USER'. %(verbose)s @@ -1545,16 +1546,9 @@ def drop(self, indices, reason="USER", verbose=None): indices = np.atleast_1d(indices) if indices.ndim > 1: - raise ValueError("indices must be a scalar or a 1-d array") + raise TypeError("indices must be a scalar or a 1-d array") # Check if indices and reasons are of the same length # if using collection to drop epochs - if isinstance(reason, (list, tuple)): - if len(indices) != len(reason): - raise ValueError( - "If using a list or tuple as the reason, " - "indices and reasons must be of the same length, got " - f"{len(indices)} and {len(reason)}" - ) if indices.dtype == bool: indices = np.where(indices)[0] @@ -1757,7 +1751,7 @@ def _get_data( is_good, bad_tuple = self._is_good_epoch(epoch, verbose=verbose) if not is_good: assert isinstance(bad_tuple, tuple) - assert all(isinstance(x, str)) for x in bad_tuple) + assert all(isinstance(x, str) for x in bad_tuple) drop_log[sel] = drop_log[sel] + bad_tuple continue good_idx.append(idx) @@ -3741,7 +3735,7 @@ def _is_good( if isinstance(reasons, str): reasons = (reasons,) for idx, reason in enumerate(reasons): - _validate_type(reason, str, f"reasons[{idx}]") + _validate_type(reason, str, reason) bad_tuple += tuple(reasons) else: bad_names = [ch_names[idx[i]] for i in idx_deltas] diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index c8b4280a428..02e48295073 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -488,10 +488,10 @@ def test_average_movements(): def _assert_drop_log_types(drop_log): __tracebackhide__ = True - assert isinstance(drop_log, (tuple, list)), """drop_log should be tuple + assert isinstance(drop_log, tuple), """drop_log should be tuple or list""" assert all( - isinstance(log, (tuple, list)) for log in drop_log + isinstance(log, tuple) for log in drop_log ), "drop_log[ii] should be tuple or list" assert all( isinstance(s, str) for log in drop_log for s in log @@ -2192,8 +2192,8 @@ def test_callable_reject(): events = mne.make_fixed_length_events(edit_raw, id=1, duration=1.0, start=0) epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, baseline=None, preload=True) - assert len(epochs) == 5 + epochs = mne.Epochs( edit_raw, events, @@ -2203,7 +2203,7 @@ def test_callable_reject(): reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median")), preload=True, ) - assert epochs.drop_log[2] != () + assert epochs.drop_log[2] == ('eeg median',) epochs = mne.Epochs( edit_raw, @@ -2211,10 +2211,10 @@ def test_callable_reject(): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), "eeg max")), + reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), ("eeg max",))), preload=True, ) - assert epochs.drop_log[0] != () + assert epochs.drop_log[0] == ("eeg max",) def reject_criteria(x): max_condition = np.max(x, axis=1) > 1e-2 @@ -2230,9 +2230,19 @@ def reject_criteria(x): reject=dict(eeg=reject_criteria), preload=True, ) - print(epochs.drop_log) - assert epochs.drop_log[0] != () and epochs.drop_log[2] != () + assert epochs.drop_log[0] == ("eeg max or median",) and epochs.drop_log[2] == ("eeg max or median",) + # Test reasons must be str or tuple of str + with pytest.raises(TypeError): + epochs = mne.Epochs( + edit_raw, + events, + tmin=0, + tmax=1, + baseline=None, + reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), ("eeg median", 2))), + preload=True, + ) def test_preload_epochs(): """Test preload of epochs.""" @@ -3267,7 +3277,9 @@ def test_drop_epochs(): # Bound checks pytest.raises(IndexError, epochs.drop, [len(epochs.events)]) pytest.raises(IndexError, epochs.drop, [-len(epochs.events) - 1]) - pytest.raises(ValueError, epochs.drop, [[1, 2], [3, 4]]) + pytest.raises(TypeError, epochs.drop, [[1, 2], [3, 4]]) + with pytest.raises(TypeError): + epochs.drop([1], reason=('a', 'b', 2)) # Test selection attribute assert_array_equal(epochs.selection, np.where(events[:, 2] == event_id)[0]) @@ -3291,12 +3303,12 @@ def test_drop_epochs(): raw, events, picks = _get_data() epochs_tuple = Epochs(raw, events, event_id, tmin, tmax, picks=picks, preload=True) selection_tuple = epochs_tuple.selection.copy() - epochs_tuple.drop((2, 3, 4), reason=([["list"], "string", ("tuple",)])) + epochs_tuple.drop((2, 3, 4), reason=('a', 'b')) n_events = len(epochs.events) assert [epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]] == [ - ("list",), - ("string",), - ("tuple",), + ("a", 'b'), + ("a", 'b'), + ("a", 'b'), ] diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index e62b2a57a6a..a6de7ed9907 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -178,7 +178,7 @@ def _getitem( ---------- item: slice, array-like, str, or list see `__getitem__` for details. - reason: str + reason: str, list/tuple of str entry in `drop_log` for unselected epochs copy: bool return a copy of the current object @@ -210,6 +210,9 @@ def _getitem( drop_log = list(inst.drop_log) if reason is not None: _validate_type(reason, (list, tuple, str), "reason") + if isinstance(reason, (list, tuple)): + for r in reason: + _validate_type(r, str, r) if isinstance(reason, str): reason = (reason,) reason = tuple(reason) From 45b30f359befdda594ee21c4dc331d5b1a9d7545 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jan 2024 06:24:49 +0000 Subject: [PATCH 33/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/tests/test_epochs.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 16a9abe5b18..208db4d6022 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -2203,7 +2203,7 @@ def test_callable_reject(): reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median")), preload=True, ) - assert epochs.drop_log[2] == ('eeg median',) + assert epochs.drop_log[2] == ("eeg median",) epochs = mne.Epochs( edit_raw, @@ -2230,7 +2230,9 @@ def reject_criteria(x): reject=dict(eeg=reject_criteria), preload=True, ) - assert epochs.drop_log[0] == ("eeg max or median",) and epochs.drop_log[2] == ("eeg max or median",) + assert epochs.drop_log[0] == ("eeg max or median",) and epochs.drop_log[2] == ( + "eeg max or median", + ) # Test reasons must be str or tuple of str with pytest.raises(TypeError): @@ -2240,10 +2242,13 @@ def reject_criteria(x): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), ("eeg median", 2))), + reject=dict( + eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), ("eeg median", 2)) + ), preload=True, ) + def test_preload_epochs(): """Test preload of epochs.""" raw, events, picks = _get_data() @@ -3279,7 +3284,7 @@ def test_drop_epochs(): pytest.raises(IndexError, epochs.drop, [-len(epochs.events) - 1]) pytest.raises(TypeError, epochs.drop, [[1, 2], [3, 4]]) with pytest.raises(TypeError): - epochs.drop([1], reason=('a', 'b', 2)) + epochs.drop([1], reason=("a", "b", 2)) # Test selection attribute assert_array_equal(epochs.selection, np.where(events[:, 2] == event_id)[0]) @@ -3303,12 +3308,12 @@ def test_drop_epochs(): raw, events, picks = _get_data() epochs_tuple = Epochs(raw, events, event_id, tmin, tmax, picks=picks, preload=True) selection_tuple = epochs_tuple.selection.copy() - epochs_tuple.drop((2, 3, 4), reason=('a', 'b')) + epochs_tuple.drop((2, 3, 4), reason=("a", "b")) n_events = len(epochs.events) assert [epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]] == [ - ("a", 'b'), - ("a", 'b'), - ("a", 'b'), + ("a", "b"), + ("a", "b"), + ("a", "b"), ] From e77ae63d6c52daa4c93440936e068bcbbf9c0bf1 Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Mon, 15 Jan 2024 21:39:08 -0600 Subject: [PATCH 34/41] Add check --- mne/epochs.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index 15582f0c6fe..c787ae48635 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -814,7 +814,6 @@ def _reject_setup(self, reject, flat): f"{key.upper()}." ) -<<<<<<< HEAD # check for invalid values for rej, kind in zip((reject, flat), ("Rejection", "Flat")): for key, val in rej.items(): @@ -827,13 +826,6 @@ def _reject_setup(self, reject, flat): f"If using numerical {name} criteria, the value " f"must be >= 0, not {repr(val)}" ) -======= - # check for invalid values - for rej, kind in zip((reject, flat), ("Rejection", "Flat")): - for key, val in rej.items(): - if val is None or val < 0: - raise ValueError(f'{kind} value must be a number >= 0, not "{val}"') ->>>>>>> 2040898ac14e79353b7a23a07e177d1633298c0f # now check to see if our rejection and flat are getting more # restrictive @@ -3736,7 +3728,6 @@ def _is_good( )[0] if len(idx_deltas) > 0: -<<<<<<< HEAD # Check to verify that refl is a callable that returns # (bool, reason). Reason must be a str/list/tuple. # If using tuple @@ -3746,16 +3737,6 @@ def _is_good( for idx, reason in enumerate(reasons): _validate_type(reason, str, reason) bad_tuple += tuple(reasons) -======= - bad_names = [ch_names[idx[i]] for i in idx_deltas] - if not has_printed: - logger.info( - f" Rejecting {t} epoch based on {name} : {bad_names}" - ) - has_printed = True - if not full_report: - return False ->>>>>>> 2040898ac14e79353b7a23a07e177d1633298c0f else: bad_names = [ch_names[idx[i]] for i in idx_deltas] if not has_printed: From 4e4b369c3537bcff18a64152328a145c79ca2695 Mon Sep 17 00:00:00 2001 From: Jacob Woessner Date: Mon, 22 Jan 2024 23:43:32 -0600 Subject: [PATCH 35/41] Apply suggestions from code review Co-authored-by: Eric Larson --- mne/tests/test_epochs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 208db4d6022..a8e1daafa1e 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -488,11 +488,10 @@ def test_average_movements(): def _assert_drop_log_types(drop_log): __tracebackhide__ = True - assert isinstance(drop_log, tuple), """drop_log should be tuple - or list""" + assert isinstance(drop_log, tuple), "drop_log should be tuple" assert all( isinstance(log, tuple) for log in drop_log - ), "drop_log[ii] should be tuple or list" + ), "drop_log[ii] should be tuple" assert all( isinstance(s, str) for log in drop_log for s in log ), "drop_log[ii][jj] should be str" From b7c6a36b80d17799dd481985358fb46317394e38 Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Tue, 23 Jan 2024 02:09:36 -0600 Subject: [PATCH 36/41] Add suggestions --- mne/tests/test_epochs.py | 54 ++++++++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index a8e1daafa1e..35a4b7c91a0 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -565,29 +565,32 @@ def test_reject(): ) def my_reject_1(epoch_data): - bad_idxs = np.where(np.percentile(epoch_data, axis=1) > 1e-35) + bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) return len(bad_idxs) > 0 def my_reject_2(epoch_data): - bad_idxs = np.where(np.percentile(epoch_data, axis=1) > 1e-35) - reasons = tuple(epochs.ch_name[bad_idx] for bad_idx in bad_idxs) + bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) + reasons = 'a' * len(bad_idxs[0]) return len(bad_idxs), reasons bad_types = [my_reject_1, my_reject_2, ("Hi" "Hi"), (1, 1), None] for val in bad_types: # protect against bad types for kwarg in ("reject", "flat"): - pytest.raises( + with pytest.raises( TypeError, - Epochs, - raw, - events, - event_id, - tmin, - tmax, - picks=picks_meg, - preload=True, - **{kwarg: dict(grad=val)}, - ) + match=r".* must be an instance of .* got instead." + ): + Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks_meg, + preload=True, + **{kwarg: dict(grad=val)}, + ) + pytest.raises( KeyError, Epochs, @@ -3279,10 +3282,25 @@ def test_drop_epochs(): events1 = events[events[:, 2] == event_id] # Bound checks - pytest.raises(IndexError, epochs.drop, [len(epochs.events)]) - pytest.raises(IndexError, epochs.drop, [-len(epochs.events) - 1]) - pytest.raises(TypeError, epochs.drop, [[1, 2], [3, 4]]) - with pytest.raises(TypeError): + with pytest.raises( + IndexError, + match=r"Epoch index .* is out of bounds" + ): + epochs.drop([len(epochs.events)]) + with pytest.raises( + IndexError, + match=r"Epoch index .* is out of bounds" + ): + epochs.drop([-len(epochs.events) - 1]) + with pytest.raises( + TypeError, + match="indices must be a scalar or a 1-d array" + ): + epochs.drop([[1, 2], [3, 4]]) + with pytest.raises( + TypeError, + match=r".* must be an instance of .* got instead." + ): epochs.drop([1], reason=("a", "b", 2)) # Test selection attribute From ba45b5aee53e2cfcbf200a18ed899800ebb00568 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Jan 2024 08:10:21 +0000 Subject: [PATCH 37/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/tests/test_epochs.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 35a4b7c91a0..2dce010f01e 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -570,7 +570,7 @@ def my_reject_1(epoch_data): def my_reject_2(epoch_data): bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) - reasons = 'a' * len(bad_idxs[0]) + reasons = "a" * len(bad_idxs[0]) return len(bad_idxs), reasons bad_types = [my_reject_1, my_reject_2, ("Hi" "Hi"), (1, 1), None] @@ -578,7 +578,7 @@ def my_reject_2(epoch_data): for kwarg in ("reject", "flat"): with pytest.raises( TypeError, - match=r".* must be an instance of .* got instead." + match=r".* must be an instance of .* got instead.", ): Epochs( raw, @@ -3282,24 +3282,14 @@ def test_drop_epochs(): events1 = events[events[:, 2] == event_id] # Bound checks - with pytest.raises( - IndexError, - match=r"Epoch index .* is out of bounds" - ): + with pytest.raises(IndexError, match=r"Epoch index .* is out of bounds"): epochs.drop([len(epochs.events)]) - with pytest.raises( - IndexError, - match=r"Epoch index .* is out of bounds" - ): + with pytest.raises(IndexError, match=r"Epoch index .* is out of bounds"): epochs.drop([-len(epochs.events) - 1]) - with pytest.raises( - TypeError, - match="indices must be a scalar or a 1-d array" - ): + with pytest.raises(TypeError, match="indices must be a scalar or a 1-d array"): epochs.drop([[1, 2], [3, 4]]) with pytest.raises( - TypeError, - match=r".* must be an instance of .* got instead." + TypeError, match=r".* must be an instance of .* got instead." ): epochs.drop([1], reason=("a", "b", 2)) From b5829b07268045f680482d923f12b0df94fb36bf Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Tue, 23 Jan 2024 02:18:31 -0600 Subject: [PATCH 38/41] devel --- doc/changes/devel/12195.newfeature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 doc/changes/devel/12195.newfeature.rst diff --git a/doc/changes/devel/12195.newfeature.rst b/doc/changes/devel/12195.newfeature.rst new file mode 100644 index 00000000000..da5e59cbc1a --- /dev/null +++ b/doc/changes/devel/12195.newfeature.rst @@ -0,0 +1 @@ +Add ability reject :class:`mne.Epochs` using callables. \ No newline at end of file From a11a4106a55c0150c8904a004603b55ca3fa0a0d Mon Sep 17 00:00:00 2001 From: withmywoessner Date: Wed, 24 Jan 2024 17:25:43 -0600 Subject: [PATCH 39/41] Remove support for callabes in constructor --- mne/epochs.py | 15 +++-- mne/tests/test_epochs.py | 66 ++++++++++++++----- mne/utils/docs.py | 56 +++++++++++----- .../preprocessing/20_rejecting_bad_data.py | 11 ++-- 4 files changed, 104 insertions(+), 44 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index c787ae48635..3e1c11650fd 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -787,7 +787,7 @@ def apply_baseline(self, baseline=(None, 0), *, verbose=None): self.baseline = baseline return self - def _reject_setup(self, reject, flat): + def _reject_setup(self, reject, flat, *, allow_callable=False): """Set self._reject_time and self._channel_type_idx.""" idx = channel_indices_by_type(self.info) reject = deepcopy(reject) if reject is not None else dict() @@ -818,9 +818,12 @@ def _reject_setup(self, reject, flat): for rej, kind in zip((reject, flat), ("Rejection", "Flat")): for key, val in rej.items(): name = f"{kind} dict value for {key}" - if callable(val): + if callable(val) and allow_callable: continue - _validate_type(val, "numeric", name, extra="or callable") + extra_str = "" + if allow_callable: + extra_str = "or callable" + _validate_type(val, "numeric", name, extra=extra_str) if val is None or val < 0: raise ValueError( f"If using numerical {name} criteria, the value " @@ -844,7 +847,7 @@ def _reject_setup(self, reject, flat): # make sure new thresholds are at least as stringent as the old ones for key in reject: # Skip this check if old_reject and reject are callables - if callable(reject[key]): + if callable(reject[key]) and allow_callable: continue if key in old_reject and reject[key] > old_reject[key]: raise ValueError( @@ -861,7 +864,7 @@ def _reject_setup(self, reject, flat): for key in set(old_flat) - set(flat): flat[key] = old_flat[key] for key in flat: - if callable(flat[key]): + if callable(flat[key]) and allow_callable: continue if key in old_flat and flat[key] < old_flat[key]: raise ValueError( @@ -1416,7 +1419,7 @@ def drop_bad(self, reject="existing", flat="existing", verbose=None): flat = self.flat if any(isinstance(rej, str) and rej != "existing" for rej in (reject, flat)): raise ValueError('reject and flat, if strings, must be "existing"') - self._reject_setup(reject, flat) + self._reject_setup(reject, flat, allow_callable=True) self._get_data(out=False, verbose=verbose) return self diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 2dce010f01e..34218546ffc 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -549,6 +549,19 @@ def test_reject(): preload=False, reject=dict(eeg=np.inf), ) + + # Good function + def my_reject_1(epoch_data): + bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) + reasons = "a" * len(bad_idxs[0]) + return len(bad_idxs) > 0, reasons + + # Bad function + def my_reject_2(epoch_data): + bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) + reasons = "a" * len(bad_idxs[0]) + return len(bad_idxs), reasons + for val in (-1, -2): # protect against older MNE-C types for kwarg in ("reject", "flat"): pytest.raises( @@ -564,23 +577,33 @@ def test_reject(): **{kwarg: dict(grad=val)}, ) - def my_reject_1(epoch_data): - bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) - return len(bad_idxs) > 0 - - def my_reject_2(epoch_data): - bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35) - reasons = "a" * len(bad_idxs[0]) - return len(bad_idxs), reasons + # Check that reject and flat in constructor are not callables + val = my_reject_1 + for kwarg in ("reject", "flat"): + with pytest.raises( + TypeError, + match=r".* must be an instance of numeric, got instead." + ): + Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks_meg, + preload=False, + **{kwarg: dict(grad=val)}, + ) - bad_types = [my_reject_1, my_reject_2, ("Hi" "Hi"), (1, 1), None] + # Check if callable returns a tuple with reasons + bad_types = [my_reject_2, ("Hi" "Hi"), (1, 1), None] for val in bad_types: # protect against bad types for kwarg in ("reject", "flat"): with pytest.raises( TypeError, match=r".* must be an instance of .* got instead.", ): - Epochs( + epochs = Epochs( raw, events, event_id, @@ -588,8 +611,8 @@ def my_reject_2(epoch_data): tmax, picks=picks_meg, preload=True, - **{kwarg: dict(grad=val)}, ) + epochs.drop_bad(**{kwarg: dict(grad=val)}) pytest.raises( KeyError, @@ -2202,9 +2225,10 @@ def test_callable_reject(): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median")), preload=True, ) + epochs.drop_bad(reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median"))) + assert epochs.drop_log[2] == ("eeg median",) epochs = mne.Epochs( @@ -2213,9 +2237,10 @@ def test_callable_reject(): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), ("eeg max",))), preload=True, ) + epochs.drop_bad(reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), ("eeg max",)))) + assert epochs.drop_log[0] == ("eeg max",) def reject_criteria(x): @@ -2229,25 +2254,31 @@ def reject_criteria(x): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=reject_criteria), preload=True, ) + epochs.drop_bad(reject=dict(eeg=reject_criteria)) + assert epochs.drop_log[0] == ("eeg max or median",) and epochs.drop_log[2] == ( "eeg max or median", ) # Test reasons must be str or tuple of str - with pytest.raises(TypeError): + with pytest.raises( + TypeError, + match=r".* must be an instance of str, got instead.", + ): epochs = mne.Epochs( edit_raw, events, tmin=0, tmax=1, baseline=None, + preload=True, + ) + epochs.drop_bad( reject=dict( eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), ("eeg median", 2)) - ), - preload=True, + ) ) @@ -3323,7 +3354,6 @@ def test_drop_epochs(): ("a", "b"), ] - @pytest.mark.parametrize("preload", (True, False)) def test_drop_epochs_mult(preload): """Test that subselecting epochs or making fewer epochs is similar.""" diff --git a/mne/utils/docs.py b/mne/utils/docs.py index efc1046ac44..c3005427ead 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1442,14 +1442,11 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): """ _flat_common = """\ - Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP) - or a custom function. Valid **keys** can be any channel type present - in the object. If using PTP, **values** are floats that set the minimum - acceptable PTP. If the PTP is smaller than this threshold, the epoch - will be dropped. If ``None`` then no rejection is performed based on - flatness of the signal. If a custom function is used than ``flat`` can be - used to reject epochs based on any criteria (including maxima and - minima).""" + Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP). + Valid **keys** can be any channel type present in the object. The + **values** are floats that set the minimum acceptable PTP. If the PTP + is smaller than this threshold, the epoch will be dropped. If ``None`` + then no rejection is performed based on flatness of the signal.""" docdict["flat"] = f""" flat : dict | None @@ -1459,9 +1456,16 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): quality, pass the ``reject_tmin`` and ``reject_tmax`` parameters. """ -docdict["flat_drop_bad"] = f""" +docdict["flat_drop_bad"] = """ flat : dict | str | None -{_flat_common} + Reject epochs based on **minimum** peak-to-peak signal amplitude (PTP) + or a custom function. Valid **keys** can be any channel type present + in the object. If using PTP, **values** are floats that set the minimum + acceptable PTP. If the PTP is smaller than this threshold, the epoch + will be dropped. If ``None`` then no rejection is performed based on + flatness of the signal. If a custom function is used than ``flat`` can be + used to reject epochs based on any criteria (including maxima and + minima). If ``'existing'``, then the flat parameters set during epoch creation are used. """ @@ -3271,6 +3275,31 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): ) _reject_common = """\ + Reject epochs based on **maximum** peak-to-peak signal amplitude (PTP), + i.e. the absolute difference between the lowest and the highest signal + value. In each individual epoch, the PTP is calculated for every channel. + If the PTP of any one channel exceeds the rejection threshold, the + respective epoch will be dropped. + + The dictionary keys correspond to the different channel types; valid + **keys** can be any channel type present in the object. + + Example:: + + reject = dict(grad=4000e-13, # unit: T / m (gradiometers) + mag=4e-12, # unit: T (magnetometers) + eeg=40e-6, # unit: V (EEG channels) + eog=250e-6 # unit: V (EOG channels) + ) + + .. note:: Since rejection is based on a signal **difference** + calculated for each channel separately, applying baseline + correction does not affect the rejection procedure, as the + difference will be preserved. +""" + +docdict["reject_drop_bad"] = """ +reject : dict | str | None Reject epochs based on **maximum** peak-to-peak signal amplitude (PTP) or custom functions. Peak-to-peak signal amplitude is defined as the absolute difference between the lowest and the highest signal @@ -3303,14 +3332,9 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): .. note:: If ``reject`` is a callable, than **any** criteria can be used to reject epochs (including maxima and minima). -""" # noqa: E501 - -docdict["reject_drop_bad"] = f""" -reject : dict | str | None -{_reject_common} If ``reject`` is ``None``, no rejection is performed. If ``'existing'`` (default), then the rejection parameters set at instantiation are used. -""" +""" # noqa: E501 docdict["reject_epochs"] = f""" reject : dict | None diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py index 51f8fa012f8..4883c6bce4c 100644 --- a/tutorials/preprocessing/20_rejecting_bad_data.py +++ b/tutorials/preprocessing/20_rejecting_bad_data.py @@ -338,7 +338,7 @@ # Sometimes it is useful to reject epochs based criteria other than # peak-to-peak amplitudes. For example, we might want to reject epochs # based on the maximum or minimum amplitude of a channel. -# In this case, the :class:`mne.Epochs` class constructor also accepts +# In this case, the `mne.Epochs.drop_bad` function also accepts # callables (functions) in the ``reject`` and ``flat`` parameters. This # allows us to define functions to reject epochs based on our desired criteria. # @@ -376,9 +376,10 @@ tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1e-2).any(), "max amp")), preload=True, ) + +epochs.drop_bad(reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1e-2).any(), "max amp"))) epochs.plot(scalings=dict(eeg=50e-5)) # %% @@ -397,9 +398,10 @@ tmin=0, tmax=1, baseline=None, - reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp")), preload=True, ) + +epochs.drop_bad(reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp"))) epochs.plot(scalings=dict(eeg=50e-5)) # %% @@ -420,9 +422,10 @@ def reject_criteria(x): tmin=0, tmax=1, baseline=None, - reject=dict(eeg=reject_criteria), preload=True, ) + +epochs.drop_bad(reject=dict(eeg=reject_criteria)) epochs.plot(events=True) # %% From ab009bce866fba08aa647a7eaf2e0d29639c8f63 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Jan 2024 23:26:36 +0000 Subject: [PATCH 40/41] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne/tests/test_epochs.py | 11 ++++++++--- tutorials/preprocessing/20_rejecting_bad_data.py | 8 ++++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 34218546ffc..13161665e37 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -582,7 +582,7 @@ def my_reject_2(epoch_data): for kwarg in ("reject", "flat"): with pytest.raises( TypeError, - match=r".* must be an instance of numeric, got instead." + match=r".* must be an instance of numeric, got instead.", ): Epochs( raw, @@ -2227,7 +2227,9 @@ def test_callable_reject(): baseline=None, preload=True, ) - epochs.drop_bad(reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median"))) + epochs.drop_bad( + reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median")) + ) assert epochs.drop_log[2] == ("eeg median",) @@ -2239,7 +2241,9 @@ def test_callable_reject(): baseline=None, preload=True, ) - epochs.drop_bad(reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), ("eeg max",)))) + epochs.drop_bad( + reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), ("eeg max",))) + ) assert epochs.drop_log[0] == ("eeg max",) @@ -3354,6 +3358,7 @@ def test_drop_epochs(): ("a", "b"), ] + @pytest.mark.parametrize("preload", (True, False)) def test_drop_epochs_mult(preload): """Test that subselecting epochs or making fewer epochs is similar.""" diff --git a/tutorials/preprocessing/20_rejecting_bad_data.py b/tutorials/preprocessing/20_rejecting_bad_data.py index 4883c6bce4c..a04005f3532 100644 --- a/tutorials/preprocessing/20_rejecting_bad_data.py +++ b/tutorials/preprocessing/20_rejecting_bad_data.py @@ -379,7 +379,9 @@ preload=True, ) -epochs.drop_bad(reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1e-2).any(), "max amp"))) +epochs.drop_bad( + reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1e-2).any(), "max amp")) +) epochs.plot(scalings=dict(eeg=50e-5)) # %% @@ -401,7 +403,9 @@ preload=True, ) -epochs.drop_bad(reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp"))) +epochs.drop_bad( + reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-4).any(), "median amp")) +) epochs.plot(scalings=dict(eeg=50e-5)) # %% From fb6a5b16fdaf03937238d40df926f4936abf13fc Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 2 Feb 2024 09:32:34 -0500 Subject: [PATCH 41/41] Apply suggestions from code review --- doc/changes/devel/12195.newfeature.rst | 2 +- mne/utils/mixin.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/changes/devel/12195.newfeature.rst b/doc/changes/devel/12195.newfeature.rst index da5e59cbc1a..0c7e044abce 100644 --- a/doc/changes/devel/12195.newfeature.rst +++ b/doc/changes/devel/12195.newfeature.rst @@ -1 +1 @@ -Add ability reject :class:`mne.Epochs` using callables. \ No newline at end of file +Add ability reject :class:`mne.Epochs` using callables, by `Jacob Woessner`_. \ No newline at end of file diff --git a/mne/utils/mixin.py b/mne/utils/mixin.py index a6de7ed9907..87e86aaa315 100644 --- a/mne/utils/mixin.py +++ b/mne/utils/mixin.py @@ -216,7 +216,7 @@ def _getitem( if isinstance(reason, str): reason = (reason,) reason = tuple(reason) - for i, idx in enumerate(np.setdiff1d(inst.selection, key_selection)): + for idx in np.setdiff1d(inst.selection, key_selection): drop_log[idx] = reason inst.drop_log = tuple(drop_log) inst.selection = key_selection