From 8409542b8517aab108a7cd317486c9af8abd8a54 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Wed, 15 Dec 2021 21:11:32 -0800 Subject: [PATCH 1/2] Add utility function to enable autoreject cleaning on Epoch level --- doc/api.rst | 1 + mne_connectivity/tests/test_utils.py | 41 ++++++++++++- mne_connectivity/utils/__init__.py | 2 +- mne_connectivity/utils/utils.py | 86 ++++++++++++++++++++++++++++ 4 files changed, 128 insertions(+), 2 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 2fbfddce..98ea4e3b 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -72,6 +72,7 @@ Post-processing on connectivity seed_target_indices check_indices select_order + map_epoch_annotations_to_epoch Visualization functions ======================= diff --git a/mne_connectivity/tests/test_utils.py b/mne_connectivity/tests/test_utils.py index c012481c..bc288445 100644 --- a/mne_connectivity/tests/test_utils.py +++ b/mne_connectivity/tests/test_utils.py @@ -2,8 +2,13 @@ import pytest from numpy.testing import assert_array_equal +from mne.io import RawArray +from mne.epochs import Epochs, make_fixed_length_epochs +from mne.io.meas_info import create_info + from mne_connectivity import Connectivity -from mne_connectivity.utils import degree, seed_target_indices +from mne_connectivity.utils import ( + degree, seed_target_indices, map_epoch_annotations_to_epoch) def test_indices(): @@ -64,3 +69,37 @@ def test_degree(): conn = Connectivity(data=np.zeros((4,)), n_nodes=2) deg = degree(conn) assert_array_equal(deg, [0, 0]) + + +def test_mapping_epochs_to_epochs(): + """Test map_epoch_annotations_to_epoch function.""" + n_times = 1000 + sfreq = 100 + data = np.random.random((2, n_times)) + info = create_info(ch_names=['A1', 'A2'], sfreq=sfreq, + ch_types='mag') + raw = RawArray(data, info) + + # create two different sets of Epochs + # the first one is just a contiguous chunks of 1 seconds + epoch_one = make_fixed_length_epochs(raw, duration=1, overlap=0) + + events = np.zeros((2, 3), dtype=int) + events[:, 0] = [100, 900] + epoch_two = Epochs(raw, events, tmin=-0.5, tmax=0.5) + + # map Epochs from two to one + all_cases = map_epoch_annotations_to_epoch(epoch_one, epoch_two) + assert all_cases.shape == (2, 10) + + # only 1-3 Epochs of epoch_one should overlap with the epoch_two's + # 1st Epoch + assert all(all_cases[0, :2]) + assert all(all_cases[1, -2:]) + + # map Epochs from one to two + all_cases = map_epoch_annotations_to_epoch(epoch_two, epoch_one) + assert all_cases.shape == (10, 2) + + assert all(all_cases[:2, 0]) + assert all(all_cases[-2:, 1]) diff --git a/mne_connectivity/utils/__init__.py b/mne_connectivity/utils/__init__.py index 0b7e1110..fb213109 100644 --- a/mne_connectivity/utils/__init__.py +++ b/mne_connectivity/utils/__init__.py @@ -1,3 +1,3 @@ from .docs import fill_doc from .utils import (check_indices, degree, seed_target_indices, - parallel_loop) + parallel_loop, map_epoch_annotations_to_epoch) diff --git a/mne_connectivity/utils/utils.py b/mne_connectivity/utils/utils.py index 2ec39610..05522ad6 100644 --- a/mne_connectivity/utils/utils.py +++ b/mne_connectivity/utils/utils.py @@ -1,11 +1,97 @@ # Authors: Martin Luessi +# Adam Li # # License: BSD (3-clause) import numpy as np +from mne import BaseEpochs from mne.utils import logger +def map_epoch_annotations_to_epoch(dest_epoch, src_epoch): + """Map Annotations that occur in one Epoch to another Epoch. + + Two different Epochs might occur at different time points. + This function will map Annotations that occur in one Epoch + setting to another Epoch taking into account their onset + samples and window lengths. + + Parameters + ---------- + dest_epoch : instance of Epochs | events array + The reference Epochs that you want to match to. + src_epoch : instance of Epochs | events array + The source Epochs that contain Epochs you want to + see if it overlaps at any point with ``dest_epoch``. + + Returns + ------- + all_cases : np.ndarray of shape (n_src_epochs, n_dest_epochs) + This is an array indicating the overlap of any source epoch + relative to the destination epoch. An overlap is indicated + by a ``True``, whereas if a source Epoch does not overlap + with a destination Epoch, then the element will be ``False``. + + Notes + ----- + This is a useful utility function to enable mapping Autoreject + ``RejectLog`` that occurs over a set of defined Epochs to + another Epoched data structure, such as a ``Epoch*`` connectivity + class, which computes connectivity over Epochs. + """ + if isinstance(dest_epoch, BaseEpochs): + dest_events = dest_epoch.events + dest_times = dest_epoch.times + dest_sfreq = dest_epoch._raw_sfreq + else: + dest_events = dest_epoch + if isinstance(src_epoch, BaseEpochs): + src_events = src_epoch.events + src_times = src_epoch.times + src_sfreq = src_epoch._raw_sfreq + else: + src_events = src_epoch + + # get the sample points of the source Epochs we want + # to map over to the destination sample points + src_onset_sample = src_events[:, 0] + src_epoch_tzeros = src_onset_sample / src_sfreq + dest_onset_sample = dest_events[:, 0] + dest_epoch_tzeros = dest_onset_sample / dest_sfreq + + # get start and stop points of every single source Epoch + src_epoch_starts, src_epoch_stops = np.atleast_2d( + src_epoch_tzeros) + np.atleast_2d(src_times[[0, -1]]).T + + # get start and stop points of every single destination Epoch + dest_epoch_starts, dest_epoch_stops = np.atleast_2d( + dest_epoch_tzeros) + np.atleast_2d(dest_times[[0, -1]]).T + + # get destination Epochs that start within the source Epoch + src_straddles_dest_start = np.logical_and( + np.atleast_2d(dest_epoch_starts) >= np.atleast_2d(src_epoch_starts).T, + np.atleast_2d(dest_epoch_starts) < np.atleast_2d(src_epoch_stops).T) + + # get epochs that end within the annotations + src_straddles_dest_end = np.logical_and( + np.atleast_2d(dest_epoch_stops) > np.atleast_2d(src_epoch_starts).T, + np.atleast_2d(dest_epoch_stops) <= np.atleast_2d(src_epoch_stops).T) + + # get epochs that are fully contained within annotations + src_fully_within_dest = np.logical_and( + np.atleast_2d(dest_epoch_starts) <= np.atleast_2d(src_epoch_starts).T, + np.atleast_2d(dest_epoch_stops) >= np.atleast_2d(src_epoch_stops).T) + + # combine all cases to get array of shape (n_src_epochs, n_dest_epochs). + # Nonzero entries indicate overlap between the corresponding + # annotation (row index) and epoch (column index). + all_cases = (src_straddles_dest_start + + src_straddles_dest_end + + src_fully_within_dest) + + return all_cases + + def parallel_loop(func, n_jobs=1, verbose=1): """run loops in parallel, if joblib is available. From bb7d382446eddeb524ce2851ae17f4efe1e5b44a Mon Sep 17 00:00:00 2001 From: Adam Li Date: Wed, 15 Dec 2021 21:17:15 -0800 Subject: [PATCH 2/2] Add whatsnew --- doc/whats_new.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 0f2d3946..34eb445f 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -27,6 +27,7 @@ Enhancements - Adding symmetric orthogonalization via :func:`mne_connectivity.symmetric_orth`, by `Eric Larson`_ (:gh:`36`) - Improved RAM usage for :func:`mne_connectivity.vector_auto_regression` by leveraging code from ``statsmodels``, by `Adam Li`_ (:gh:`46`) - Added :func:`mne_connectivity.select_order` for helping to select VAR order using information criterion, by `Adam Li`_ (:gh:`46`) +- Adds a utility function :func:`mne_connectivity.utils.map_epoch_annotations_to_epoch` to map arbitrary Epoch windows to another arbitrary Epoch window, by `Adam Li`_ (:gh:`62`) Bug ~~~