Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add ability to reject epochs using callables #12195

Merged
merged 54 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from 53 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
81ef83e
Add ability to reject epochs using functions
withmywoessner Nov 11, 2023
d8dda07
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 11, 2023
e9294de
Merge branch 'main' into epoch_reject
withmywoessner Nov 13, 2023
06e6770
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Nov 16, 2023
867496d
Update docs
withmywoessner Nov 19, 2023
1b4f5b3
Add ability to reject based on callables
withmywoessner Nov 19, 2023
2a66049
Add tutorial
withmywoessner Nov 19, 2023
e16f4a1
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Nov 19, 2023
e708465
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2023
fbdec77
Make flake8 compliant
withmywoessner Nov 20, 2023
6e23ecc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2023
cdd2843
Add docstrings and make flake8 compliant
withmywoessner Nov 20, 2023
cd9f7b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2023
3f5fd84
Update mne/epochs.py
withmywoessner Dec 1, 2023
a74ccf6
Update tutorials/preprocessing/20_rejecting_bad_data.py
withmywoessner Dec 1, 2023
f0cb1b8
Update tutorials/preprocessing/20_rejecting_bad_data.py
withmywoessner Dec 1, 2023
7d02fca
Update mne/utils/docs.py
withmywoessner Dec 1, 2023
1b836ad
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Dec 4, 2023
f3e8841
Make callable check more fine, doc, add noqa
withmywoessner Dec 6, 2023
984c604
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Dec 6, 2023
fbe4cd2
Update epochs so that adding refl tuple doesnt cause error
withmywoessner Jan 5, 2024
ee599a7
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Jan 5, 2024
24e669c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2024
bce6486
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Jan 9, 2024
8401c92
return callable/reasons
withmywoessner Jan 9, 2024
cf1facf
allow callables
withmywoessner Jan 9, 2024
e98bee2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2024
a579491
Delete mne/_version.py
withmywoessner Jan 9, 2024
c9da4db
Add None Check
withmywoessner Jan 9, 2024
5a7a618
Update mne/tests/test_epochs.py
withmywoessner Jan 10, 2024
98b92c4
Update mne/utils/mixin.py
withmywoessner Jan 10, 2024
44fdf8a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
3ff8a9e
Update mne/epochs.py
withmywoessner Jan 10, 2024
e729e81
Update mne/tests/test_epochs.py
withmywoessner Jan 10, 2024
fd4c75f
Update mne/epochs.py
withmywoessner Jan 10, 2024
a685b89
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
65778d1
Update mne/epochs.py
withmywoessner Jan 10, 2024
3ece314
Apply suggestions from code review
withmywoessner Jan 10, 2024
db0cf11
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Jan 10, 2024
b2686c0
Apply reason to all dropped epochs
withmywoessner Jan 16, 2024
3ae37a4
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Jan 16, 2024
45b30f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2024
e77ae63
Add check
withmywoessner Jan 16, 2024
f560ccd
Merge branch 'main' into epoch_reject
withmywoessner Jan 16, 2024
235f6d3
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Jan 23, 2024
4e4b369
Apply suggestions from code review
withmywoessner Jan 23, 2024
b7c6a36
Add suggestions
withmywoessner Jan 23, 2024
ba45b5a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 23, 2024
b5829b0
devel
withmywoessner Jan 23, 2024
a11a410
Remove support for callabes in constructor
withmywoessner Jan 24, 2024
b918ea8
Merge branch 'main' of https://github.com/mne-tools/mne-python into e…
withmywoessner Jan 24, 2024
ab009bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2024
8cf2e49
Merge branch 'main' into epoch_reject
withmywoessner Feb 1, 2024
fb6a5b1
Apply suggestions from code review
larsoner Feb 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/12195.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add ability reject :class:`mne.Epochs` using callables.
larsoner marked this conversation as resolved.
Show resolved Hide resolved
102 changes: 78 additions & 24 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ def apply_baseline(self, baseline=(None, 0), *, verbose=None):
self.baseline = baseline
return self

def _reject_setup(self, reject, flat):
def _reject_setup(self, reject, flat, *, allow_callable=False):
"""Set self._reject_time and self._channel_type_idx."""
idx = channel_indices_by_type(self.info)
reject = deepcopy(reject) if reject is not None else dict()
Expand All @@ -814,11 +814,21 @@ def _reject_setup(self, reject, flat):
f"{key.upper()}."
)

# check for invalid values
for rej, kind in zip((reject, flat), ("Rejection", "Flat")):
for key, val in rej.items():
if val is None or val < 0:
raise ValueError(f'{kind} value must be a number >= 0, not "{val}"')
# check for invalid values
for rej, kind in zip((reject, flat), ("Rejection", "Flat")):
for key, val in rej.items():
name = f"{kind} dict value for {key}"
if callable(val) and allow_callable:
continue
extra_str = ""
if allow_callable:
extra_str = "or callable"
_validate_type(val, "numeric", name, extra=extra_str)
if val is None or val < 0:
raise ValueError(
f"If using numerical {name} criteria, the value "
f"must be >= 0, not {repr(val)}"
)

# now check to see if our rejection and flat are getting more
larsoner marked this conversation as resolved.
Show resolved Hide resolved
# restrictive
Expand All @@ -836,6 +846,9 @@ def _reject_setup(self, reject, flat):
reject[key] = old_reject[key]
# make sure new thresholds are at least as stringent as the old ones
for key in reject:
# Skip this check if old_reject and reject are callables
if callable(reject[key]) and allow_callable:
continue
if key in old_reject and reject[key] > old_reject[key]:
raise ValueError(
bad_msg.format(
Expand All @@ -851,6 +864,8 @@ def _reject_setup(self, reject, flat):
for key in set(old_flat) - set(flat):
flat[key] = old_flat[key]
for key in flat:
if callable(flat[key]) and allow_callable:
continue
if key in old_flat and flat[key] < old_flat[key]:
raise ValueError(
bad_msg.format(
Expand Down Expand Up @@ -1404,7 +1419,7 @@ def drop_bad(self, reject="existing", flat="existing", verbose=None):
flat = self.flat
if any(isinstance(rej, str) and rej != "existing" for rej in (reject, flat)):
raise ValueError('reject and flat, if strings, must be "existing"')
self._reject_setup(reject, flat)
self._reject_setup(reject, flat, allow_callable=True)
self._get_data(out=False, verbose=verbose)
return self

Expand Down Expand Up @@ -1520,8 +1535,9 @@ def drop(self, indices, reason="USER", verbose=None):
Set epochs to remove by specifying indices to remove or a boolean
mask to apply (where True values get removed). Events are
correspondingly modified.
reason : str
Reason for dropping the epochs ('ECG', 'timeout', 'blink' etc).
reason : list | tuple | str
Reason(s) for dropping the epochs ('ECG', 'timeout', 'blink' etc).
Reason(s) are applied to all indices specified.
Default: 'USER'.
%(verbose)s

Expand All @@ -1533,7 +1549,9 @@ def drop(self, indices, reason="USER", verbose=None):
indices = np.atleast_1d(indices)

if indices.ndim > 1:
raise ValueError("indices must be a scalar or a 1-d array")
raise TypeError("indices must be a scalar or a 1-d array")
# Check if indices and reasons are of the same length
# if using collection to drop epochs

if indices.dtype == np.dtype(bool):
indices = np.where(indices)[0]
Expand Down Expand Up @@ -3199,6 +3217,10 @@ class Epochs(BaseEpochs):
See :meth:`~mne.Epochs.equalize_event_counts`
- 'USER'
For user-defined reasons (see :meth:`~mne.Epochs.drop`).

When dropping based on flat or reject parameters the tuple of
reasons contains a tuple of channels that satisfied the rejection
criteria.
filename : str
The filename of the object.
times : ndarray
Expand Down Expand Up @@ -3667,37 +3689,69 @@ def _is_good(
):
"""Test if data segment e is good according to reject and flat.

The reject and flat parameters can accept functions as values.

If full_report=True, it will give True/False as well as a list of all
offending channels.
"""
bad_tuple = tuple()
has_printed = False
checkable = np.ones(len(ch_names), dtype=bool)
checkable[np.array([c in ignore_chs for c in ch_names], dtype=bool)] = False

for refl, f, t in zip([reject, flat], [np.greater, np.less], ["", "flat"]):
if refl is not None:
for key, thresh in refl.items():
for key, refl in refl.items():
criterion = refl
idx = channel_type_idx[key]
name = key.upper()
if len(idx) > 0:
e_idx = e[idx]
deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1)
checkable_idx = checkable[idx]
idx_deltas = np.where(
np.logical_and(f(deltas, thresh), checkable_idx)
)[0]
# Check if criterion is a function and apply it
if callable(criterion):
result = criterion(e_idx)
_validate_type(result, tuple, "reject/flat output")
if len(result) != 2:
raise TypeError(
"Function criterion must return a tuple of length 2"
)
cri_truth, reasons = result
_validate_type(cri_truth, (bool, np.bool_), cri_truth, "bool")
_validate_type(
reasons, (str, list, tuple), reasons, "str, list, or tuple"
)
idx_deltas = np.where(np.logical_and(cri_truth, checkable_idx))[
0
]
else:
deltas = np.max(e_idx, axis=1) - np.min(e_idx, axis=1)
idx_deltas = np.where(
np.logical_and(f(deltas, criterion), checkable_idx)
)[0]

if len(idx_deltas) > 0:
bad_names = [ch_names[idx[i]] for i in idx_deltas]
if not has_printed:
logger.info(
f" Rejecting {t} epoch based on {name} : {bad_names}"
)
has_printed = True
if not full_report:
return False
# Check to verify that refl is a callable that returns
# (bool, reason). Reason must be a str/list/tuple.
# If using tuple
if callable(refl):
if isinstance(reasons, str):
reasons = (reasons,)
for idx, reason in enumerate(reasons):
_validate_type(reason, str, reason)
bad_tuple += tuple(reasons)
else:
bad_tuple += tuple(bad_names)
bad_names = [ch_names[idx[i]] for i in idx_deltas]
if not has_printed:
logger.info(
" Rejecting %s epoch based on %s : "
"%s" % (t, name, bad_names)
)
has_printed = True
if not full_report:
return False
else:
bad_tuple += tuple(bad_names)

if not full_report:
return True
Expand Down
165 changes: 161 additions & 4 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,20 @@ def test_reject():
preload=False,
reject=dict(eeg=np.inf),
)
for val in (None, -1): # protect against older MNE-C types

# Good function
def my_reject_1(epoch_data):
bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35)
reasons = "a" * len(bad_idxs[0])
return len(bad_idxs) > 0, reasons

# Bad function
def my_reject_2(epoch_data):
bad_idxs = np.where(np.percentile(epoch_data, 90, axis=1) > 1e-35)
reasons = "a" * len(bad_idxs[0])
return len(bad_idxs), reasons

for val in (-1, -2): # protect against older MNE-C types
for kwarg in ("reject", "flat"):
pytest.raises(
ValueError,
Expand All @@ -564,6 +577,44 @@ def test_reject():
preload=False,
**{kwarg: dict(grad=val)},
)

# Check that reject and flat in constructor are not callables
val = my_reject_1
for kwarg in ("reject", "flat"):
with pytest.raises(
TypeError,
match=r".* must be an instance of numeric, got <class 'function'> instead.",
):
Epochs(
raw,
events,
event_id,
tmin,
tmax,
picks=picks_meg,
preload=False,
**{kwarg: dict(grad=val)},
)

# Check if callable returns a tuple with reasons
bad_types = [my_reject_2, ("Hi" "Hi"), (1, 1), None]
for val in bad_types: # protect against bad types
for kwarg in ("reject", "flat"):
with pytest.raises(
TypeError,
match=r".* must be an instance of .* got <class '.*'> instead.",
):
epochs = Epochs(
raw,
events,
event_id,
tmin,
tmax,
picks=picks_meg,
preload=True,
)
epochs.drop_bad(**{kwarg: dict(grad=val)})

pytest.raises(
KeyError,
Epochs,
Expand Down Expand Up @@ -2149,6 +2200,93 @@ def test_reject_epochs(tmp_path):
assert epochs_cleaned.flat == dict(grad=new_flat["grad"], mag=flat["mag"])


@testing.requires_testing_data
def test_callable_reject():
"""Test using a callable for rejection."""
raw = read_raw_fif(fname_raw_testing, preload=True)
raw.crop(0, 5)
raw.del_proj()
chans = raw.info["ch_names"][-6:-1]
raw.pick(chans)
data = raw.get_data()

# Add some artifacts
new_data = data
new_data[0, 180:200] *= 1e7
new_data[0, 610:880] += 1e-3
edit_raw = mne.io.RawArray(new_data, raw.info)

events = mne.make_fixed_length_events(edit_raw, id=1, duration=1.0, start=0)
epochs = mne.Epochs(edit_raw, events, tmin=0, tmax=1, baseline=None, preload=True)
assert len(epochs) == 5

epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
preload=True,
)
epochs.drop_bad(
reject=dict(eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), "eeg median"))
)

assert epochs.drop_log[2] == ("eeg median",)

epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
preload=True,
)
epochs.drop_bad(
reject=dict(eeg=lambda x: ((np.max(x, axis=1) > 1).any(), ("eeg max",)))
)

assert epochs.drop_log[0] == ("eeg max",)

def reject_criteria(x):
max_condition = np.max(x, axis=1) > 1e-2
median_condition = np.median(x, axis=1) > 1e-4
return (max_condition.any() or median_condition.any()), "eeg max or median"

epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
preload=True,
)
epochs.drop_bad(reject=dict(eeg=reject_criteria))

assert epochs.drop_log[0] == ("eeg max or median",) and epochs.drop_log[2] == (
"eeg max or median",
)

# Test reasons must be str or tuple of str
with pytest.raises(
TypeError,
match=r".* must be an instance of str, got <class 'int'> instead.",
):
epochs = mne.Epochs(
edit_raw,
events,
tmin=0,
tmax=1,
baseline=None,
preload=True,
)
epochs.drop_bad(
reject=dict(
eeg=lambda x: ((np.median(x, axis=1) > 1e-3).any(), ("eeg median", 2))
)
)


def test_preload_epochs():
"""Test preload of epochs."""
raw, events, picks = _get_data()
Expand Down Expand Up @@ -3180,9 +3318,16 @@ def test_drop_epochs():
events1 = events[events[:, 2] == event_id]

# Bound checks
pytest.raises(IndexError, epochs.drop, [len(epochs.events)])
pytest.raises(IndexError, epochs.drop, [-len(epochs.events) - 1])
pytest.raises(ValueError, epochs.drop, [[1, 2], [3, 4]])
with pytest.raises(IndexError, match=r"Epoch index .* is out of bounds"):
epochs.drop([len(epochs.events)])
with pytest.raises(IndexError, match=r"Epoch index .* is out of bounds"):
epochs.drop([-len(epochs.events) - 1])
with pytest.raises(TypeError, match="indices must be a scalar or a 1-d array"):
epochs.drop([[1, 2], [3, 4]])
with pytest.raises(
TypeError, match=r".* must be an instance of .* got <class '.*'> instead."
):
epochs.drop([1], reason=("a", "b", 2))

# Test selection attribute
assert_array_equal(epochs.selection, np.where(events[:, 2] == event_id)[0])
Expand All @@ -3202,6 +3347,18 @@ def test_drop_epochs():
assert_array_equal(events[epochs[3:].selection], events1[[5, 6]])
assert_array_equal(events[epochs["1"].selection], events1[[0, 1, 3, 5, 6]])

# Test using tuple to drop epochs
raw, events, picks = _get_data()
epochs_tuple = Epochs(raw, events, event_id, tmin, tmax, picks=picks, preload=True)
selection_tuple = epochs_tuple.selection.copy()
epochs_tuple.drop((2, 3, 4), reason=("a", "b"))
n_events = len(epochs.events)
assert [epochs_tuple.drop_log[k] for k in selection_tuple[[2, 3, 4]]] == [
("a", "b"),
("a", "b"),
("a", "b"),
]


@pytest.mark.parametrize("preload", (True, False))
def test_drop_epochs_mult(preload):
Expand Down
Loading
Loading