diff --git a/mne_icalabel/conftest.py b/mne_icalabel/conftest.py
index 0d18a6ba1..4a1fc03a7 100644
--- a/mne_icalabel/conftest.py
+++ b/mne_icalabel/conftest.py
@@ -16,6 +16,8 @@ def pytest_configure(config):
     ignore:Python 3\.14 will, by default, filter extracted tar.*:DeprecationWarning
     # onnxruntime on windows runners
     ignore:Unsupported Windows version.*:UserWarning
+    # Matplotlib deprecation issued in VSCode test debugger
+    ignore:.*interactive_bk.*:matplotlib._api.deprecation.MatplotlibDeprecationWarning
     """
     for warning_line in warnings_lines.split("\n"):
         warning_line = warning_line.strip()
diff --git a/mne_icalabel/megnet/__init__.py b/mne_icalabel/megnet/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mne_icalabel/megnet/_utils.py b/mne_icalabel/megnet/_utils.py
new file mode 100644
index 000000000..caf3cc379
--- /dev/null
+++ b/mne_icalabel/megnet/_utils.py
@@ -0,0 +1,52 @@
+import numpy as np
+from numpy.typing import NDArray
+
+
+def _cart2sph(x, y, z):
+    """Convert cartesian coordinates to spherical (radius, azimuth, elevation)."""
+    xy = np.sqrt(x * x + y * y)
+    r = np.sqrt(x * x + y * y + z * z)
+    theta = np.arctan2(y, x)
+    phi = np.arctan2(z, xy)
+    return r, theta, phi
+
+
+def _make_head_outlines(sphere: NDArray, pos: NDArray, clip_origin: tuple) -> dict:
+    """Generate head outlines for topomap plotting.
+
+    This is a modified version of mne.viz.topomap._make_head_outlines.
+    The difference between this function and the original one is that
+    head_x and head_y here are scaled by a factor of 1.01 to make topomap
+    fit the 120x120 pixel size.
+    Also, removed the ear and nose outlines for not needed in MEGnet.
+
+    Parameters
+    ----------
+    sphere : NDArray
+        The sphere parameters (x, y, z, radius).
+    pos : NDArray
+        The 2D sensor positions.
+    clip_origin : tuple
+        The origin of the clipping circle.
+
+    Returns
+    -------
+    dict
+        Dictionary containing the head outlines and mask positions.
+
+    """
+    x, y, _, radius = sphere
+    ll = np.linspace(0, 2 * np.pi, 101)
+    head_x = np.cos(ll) * radius * 1.01 + x
+    head_y = np.sin(ll) * radius * 1.01 + y
+
+    mask_scale = max(1.0, np.linalg.norm(pos, axis=1).max() * 1.01 / radius)
+    clip_radius = radius * mask_scale
+
+    outlines_dict = {
+        "head": (head_x, head_y),
+        "mask_pos": (mask_scale * head_x, mask_scale * head_y),
+        "clip_radius": (clip_radius,) * 2,
+        "clip_origin": clip_origin,
+    }
+    return outlines_dict
diff --git a/mne_icalabel/megnet/assets/megnet.onnx b/mne_icalabel/megnet/assets/megnet.onnx
new file mode 100644
index 000000000..19d13ac6b
Binary files /dev/null and b/mne_icalabel/megnet/assets/megnet.onnx differ
diff --git a/mne_icalabel/megnet/features.py b/mne_icalabel/megnet/features.py
new file mode 100644
index 000000000..4b959c88e
--- /dev/null
+++ b/mne_icalabel/megnet/features.py
@@ -0,0 +1,217 @@
+import io
+
+import matplotlib.pyplot as plt
+import mne
+import numpy as np
+from mne.io import BaseRaw
+from mne.preprocessing import ICA
+from mne.utils import _validate_type, warn
+from numpy.typing import NDArray
+from PIL import Image
+from scipy import interpolate
+from scipy.spatial import ConvexHull
+
+from mne_icalabel.iclabel._utils import _pol2cart
+
+from ._utils import _cart2sph, _make_head_outlines
+
+
+def get_megnet_features(raw: BaseRaw, ica: ICA):
+    """Extract time series and topomaps for each ICA component.
+
+    MEGNet uses topomaps from BrainStorm exported as 120x120x3 RGB images.
+    Thus, we need to replicate the 'appearance'/'look' of a BrainStorm topomap.
+
+    Parameters
+    ----------
+    raw : Raw
+        Raw MEG recording used to fit the ICA decomposition.
+        The raw instance should be bandpass filtered between
+        1 and 100 Hz and notch filtered at 50 or 60 Hz to
+        remove line noise, and downsampled to 250 Hz.
+    ica : ICA
+        ICA decomposition of the provided instance.
+        The ICA decomposition should use the infomax method.
+
+    Returns
+    -------
+    time_series : array of shape (n_components, n_samples)
+        The time series for each ICA component.
+    topomaps : array of shape (n_components, 120, 120, 3)
+        The topomap RGB images for each ICA component.
+    """
+    _validate_type(raw, BaseRaw, "raw")
+    _validate_type(ica, ICA, "ica")
+    if not any(
+        ch_type in ["mag", "grad"] for ch_type in raw.get_channel_types(unique=True)
+    ):
+        raise RuntimeError(
+            "Could not find MEG channels in the provided Raw instance. "
+            "The MEGnet model was fitted on MEG data and is not "
+            "suited for other types of channels."
+        )
+    if (n_samples := raw.n_times) < 15000:
+        raise RuntimeError(
+            f"The provided raw instance has {n_samples} points. "
+            "MEGnet was designed to classify features extracted "
+            "from an MEG dataset at least 60 seconds long @ 250 Hz, "
+            "corresponding to at least 15 000 samples."
+        )
+    if not np.isclose(raw.info["sfreq"], 250, atol=1e-1):
+        warn(
+            "The provided raw instance is not sampled at 250 Hz "
+            f"(sfreq={raw.info['sfreq']} Hz). "
+            "MEGnet was designed to classify features extracted from "
+            "an MEG dataset sampled at 250 Hz "
+            "(see the 'resample()' method for Raw instances). "
+            "The classification performance might be negatively impacted."
+        )
+    if raw.info["highpass"] != 1 or raw.info["lowpass"] != 100:
+        warn(
+            "The provided raw instance is not filtered between 1 and 100 Hz. "
+            "MEGnet was designed to classify features extracted from an MEG "
+            "dataset bandpass filtered between 1 and 100 Hz"
+            " (see the 'filter()' method for Raw instances)."
+            " The classification performance might be negatively impacted."
+        )
+    if _check_line_noise(raw):
+        warn(
+            "Line noise detected in 50/60 Hz. MEGnet was trained on "
+            "MEG data without line noise. Please remove line noise "
+            "before using MEGnet (see the 'notch_filter()' method "
+            "for Raw instances)."
+        )
+    if ica.method != "infomax":
+        warn(
+            f"The provided ICA instance was fitted with '{ica.method}'. "
+            "MEGnet was designed with infomax method. "
+            "To use it, set mne.preprocessing.ICA instance with "
+            "the arguments ICA(method='infomax')."
+        )
+    if ica.n_components != 20:
+        warn(
+            f"The provided ICA instance has {ica.n_components} components. "
+            "MEGnet was designed with 20 components. "
+            "use mne.preprocessing.ICA instance with "
+            "the arguments ICA(n_components=20)."
+        )
+
+    pos_new, outlines = _get_topomaps_data(ica)
+    topomaps = _get_topomaps(ica, pos_new, outlines)
+    time_series = ica.get_sources(raw).get_data()
+    return time_series, topomaps
+
+
+def _get_topomaps_data(ica: ICA):
+    """Prepare 2D sensor positions and outlines for topomap plotting."""
+    mags = mne.pick_types(ica.info, meg="mag")
+    channel_info = ica.info["chs"]
+    loc_3d = [channel_info[i]["loc"][0:3] for i in mags]
+    channel_locations_3d = np.array(loc_3d)
+
+    # Convert to spherical and then to 2D
+    sph_coords = np.transpose(
+        _cart2sph(
+            channel_locations_3d[:, 0],
+            channel_locations_3d[:, 1],
+            channel_locations_3d[:, 2],
+        )
+    )
+    TH, PHI = sph_coords[:, 1], sph_coords[:, 2]
+    newR = 1 - PHI / np.pi * 2
+    channel_locations_2d = np.transpose(_pol2cart(TH, newR))
+
+    # Adjust coordinates with convex hull interpolation
+    hull = ConvexHull(channel_locations_2d)
+    border_indices = hull.vertices
+    Dborder = 1 / newR[border_indices]
+
+    funcTh = np.hstack(
+        [
+            TH[border_indices] - 2 * np.pi,
+            TH[border_indices],
+            TH[border_indices] + 2 * np.pi,
+        ]
+    )
+    funcD = np.hstack((Dborder, Dborder, Dborder))
+    interp_func = interpolate.interp1d(funcTh, funcD)
+    D = interp_func(TH)
+
+    adjusted_R = np.array([min(newR[i] * D[i], 1) for i in range(len(mags))])
+    Xnew, Ynew = _pol2cart(TH, adjusted_R)
+    pos_new = np.vstack((Xnew, Ynew)).T
+
+    outlines = _make_head_outlines(np.array([0, 0, 0, 1]), pos_new, (0, 0))
+    return pos_new, outlines
+
+
+def _get_topomaps(ica: ICA, pos_new: NDArray, outlines: dict):
+    """Generate topomap images for each ICA component."""
+    topomaps = []
+    data_picks = mne.pick_types(ica.info, meg="mag")
+    components = ica.get_components()
+
+    for comp in range(ica.n_components_):
+        data = components[data_picks, comp]
+        fig = plt.figure(figsize=(1.3, 1.3), dpi=100, facecolor="black")
+        ax = fig.add_subplot(111)
+        mnefig, _ = mne.viz.plot_topomap(
+            data,
+            pos_new,
+            sensors=False,
+            outlines=outlines,
+            extrapolate="head",
+            sphere=[0, 0, 0, 1],
+            contours=0,
+            res=120,
+            axes=ax,
+            show=False,
+            cmap="bwr",
+        )
+        img_buf = io.BytesIO()
+        mnefig.figure.savefig(
+            img_buf, format="png", dpi=120, bbox_inches="tight", pad_inches=0
+        )
+        img_buf.seek(0)
+        rgba_image = Image.open(img_buf)
+        rgb_image = rgba_image.convert("RGB")
+        img_buf.close()
+        plt.close(fig)
+
+        topomaps.append(np.array(rgb_image))
+
+    return np.array(topomaps)
+
+
+def _check_line_noise(
+    raw: BaseRaw, *, neighbor_width: int = 4, threshold_factor: int = 10
+) -> bool:
+    """Check if line noise is present in the MEG/EEG data."""
+    # we don't know the line frequency
+    if raw.info.get("line_freq", None) is None:
+        return False
+    # validate the primary and first harmonic frequencies
+    nyquist_freq = raw.info["sfreq"] / 2.0
+    line_freqs = [raw.info["line_freq"], 2 * raw.info["line_freq"]]
+    if any(nyquist_freq < lf for lf in line_freqs):
+        # not raising because if we get here,
+        # it means that someone provided a raw with
+        # a sampling rate extremely low (100 Hz?) and (1)
+        # either they missed all of the previous warnings
+        # encountered or (2) they know what they are doing.
+        warn("The sampling rate raw.info['sfreq'] is too low to estimate line noise.")
+        return False
+    # compute the power spectrum and retrieve the frequencies of interest
+    spectrum = raw.compute_psd(picks="meg", exclude="bads")
+    data, freqs = spectrum.get_data(
+        fmin=raw.info["line_freq"] - neighbor_width,
+        fmax=raw.info["line_freq"] + neighbor_width,
+        return_freqs=True,
+    )  # array of shape (n_good_channel, n_freqs)
+    idx = np.argmin(np.abs(freqs - raw.info["line_freq"]))
+    mask = np.ones(data.shape[1], dtype=bool)
+    mask[idx] = False
+    background_mean = np.mean(data[:, mask], axis=1)
+    background_std = np.std(data[:, mask], axis=1)
+    threshold = background_mean + threshold_factor * background_std
+    return bool(np.any(data[:, idx] > threshold))
diff --git a/mne_icalabel/megnet/label_components.py b/mne_icalabel/megnet/label_components.py
new file mode 100644
index 000000000..ced660774
--- /dev/null
+++ b/mne_icalabel/megnet/label_components.py
@@ -0,0 +1,116 @@
+from importlib.resources import files
+
+import numpy as np
+import onnxruntime as ort
+from mne.io import BaseRaw
+from mne.preprocessing import ICA
+from numpy.typing import NDArray
+
+from .features import get_megnet_features
+
+_MODEL_PATH: str = str(files("mne_icalabel.megnet") / "assets" / "megnet.onnx")
+
+
+def megnet_label_components(raw: BaseRaw, ica: ICA) -> NDArray:
+    """Label the provided ICA components with the MEGnet neural network.
+
+    Parameters
+    ----------
+    raw : Raw
+        Raw MEG recording used to fit the ICA decomposition.
+        The raw instance should be bandpass filtered between 1 and 100 Hz
+        and notch filtered at 50 or 60 Hz to remove line noise,
+        and downsampled to 250 Hz.
+    ica : ICA
+        ICA decomposition of the provided instance.
+        The ICA decomposition should use the infomax method.
+
+    Returns
+    -------
+    labels_pred_proba : numpy.ndarray of shape (n_components, n_classes)
+        The estimated corresponding predicted probabilities of output classes
+        for each independent component. Columns are ordered with
+        'brain/other', 'eye movement', 'heart beat', 'eye blink'.
+
+    """
+    time_series, topomaps = get_megnet_features(raw, ica)
+
+    # sanity-checks
+    # number of time-series <-> topos
+    assert time_series.shape[0] == topomaps.shape[0]
+    # topos are images of shape 120x120x3
+    assert topomaps.shape[1:] == (120, 120, 3)
+    # minimum time-series length
+    assert 15000 <= time_series.shape[1]
+
+    session = ort.InferenceSession(_MODEL_PATH)
+    labels_pred_proba = _chunk_predicting(session, time_series, topomaps)
+    return labels_pred_proba[:, 0, :]
+
+
+def _chunk_predicting(
+    session: ort.InferenceSession,
+    time_series: NDArray,
+    spatial_maps: NDArray,
+    chunk_len=15000,
+    overlap_len=3750,
+) -> NDArray:
+    """MEGnet's chunk vote algorithm.
+
+    Every component time series is split into overlapping chunks of
+    ``chunk_len`` samples and the network is run on each chunk. A sample
+    covered by ``k`` chunks contributes ``1 / k`` to the weight of each of
+    them, and the weighted chunk predictions are averaged and normalized.
+    """
+    prediction_vote = []
+
+    for comp_series, comp_map in zip(time_series, spatial_maps):
+        time_len = comp_series.shape[0]
+        start_times = _get_chunk_start(time_len, chunk_len, overlap_len)
+
+        # the regular grid may stop before the end of the time series; in that
+        # case add a final chunk aligned on the last sample. A strict comparison
+        # avoids duplicating the last chunk when the grid ends exactly at the end.
+        if start_times[-1] + chunk_len < time_len:
+            start_times.append(time_len - chunk_len)
+
+        chunk_votes = {start: 0 for start in start_times}
+        for t in range(time_len):
+            in_chunks = [start <= t < start + chunk_len for start in start_times]
+            # how many chunks the time point is in
+            num_chunks = np.sum(in_chunks)
+            for start_time, is_in_chunk in zip(start_times, in_chunks):
+                if is_in_chunk:
+                    chunk_votes[start_time] += 1.0 / num_chunks
+
+        weighted_predictions = {}
+        for start_time in chunk_votes:
+            onnx_inputs = {
+                session.get_inputs()[0].name: np.expand_dims(comp_map, 0).astype(
+                    np.float32
+                ),
+                session.get_inputs()[1].name: np.expand_dims(
+                    np.expand_dims(comp_series[start_time : start_time + chunk_len], 0),
+                    -1,
+                ).astype(np.float32),
+            }
+            prediction = session.run(None, onnx_inputs)[0]
+            weighted_predictions[start_time] = prediction * chunk_votes[start_time]
+
+        comp_prediction = np.stack(list(weighted_predictions.values())).mean(axis=0)
+        comp_prediction /= comp_prediction.sum()
+        prediction_vote.append(comp_prediction)
+
+    return np.stack(prediction_vote)
+
+
+def _get_chunk_start(
+    input_len: int, chunk_len: int = 15000, overlap_len: int = 3750
+) -> list:
+    """Calculate start times for time series chunks with overlap."""
+    start_times = []
+    start_time = 0
+    while start_time + chunk_len <= input_len:
+        start_times.append(start_time)
+        start_time += chunk_len - overlap_len
+    return start_times
diff --git a/mne_icalabel/megnet/tests/__init__.py b/mne_icalabel/megnet/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/mne_icalabel/megnet/tests/test_features.py b/mne_icalabel/megnet/tests/test_features.py
new file mode 100644
index 000000000..9ff78a15e
--- /dev/null
+++ b/mne_icalabel/megnet/tests/test_features.py
@@ -0,0 +1,180 @@
+import numpy as np
+import pytest
+from mne import create_info
+from mne.io import RawArray
+from mne.preprocessing import ICA
+
+from mne_icalabel.megnet.features import _check_line_noise, get_megnet_features
+
+
+@pytest.fixture
+def raw_with_line_noise():
+    """Create a Raw instance with line noise."""
+    times = np.arange(0, 2, 1 / 1000)
+    data1 = np.sin(2 * np.pi * 10 * times) + np.sin(2 * np.pi * 30 * times)
+    data2 = np.sin(2 * np.pi * 30 * times) + np.sin(2 * np.pi * 80 * times)
+    data = np.vstack([data1, data2])
+    info = create_info(ch_names=["10-30", "30-80"], sfreq=1000, ch_types="mag")
+    return RawArray(data, info)
+
+
+def test_check_line_noise(raw_with_line_noise):
+    """Check line-noise auto-detection."""
+    assert not _check_line_noise(raw_with_line_noise)
+    # 50 Hz is absent from both channels
+    raw_with_line_noise.info["line_freq"] = 50
+    assert not _check_line_noise(raw_with_line_noise)
+    # 10 and 80 Hz are present on one channel each,
+    # while 30 Hz is present on both
+    raw_with_line_noise.info["line_freq"] = 30
+    assert _check_line_noise(raw_with_line_noise)
+    raw_with_line_noise.info["line_freq"] = 80
+    assert _check_line_noise(raw_with_line_noise)
+    raw_with_line_noise.info["line_freq"] = 10
+    assert _check_line_noise(raw_with_line_noise)
+
+
+def create_raw_ica(
+    n_channels=20,
+    sfreq=250,
+    ch_type="mag",
+    n_components=20,
+    filter_range=(1, 100),
+    ica_method="infomax",
+    ntime=None,
+):
+    """Create a Raw instance and ICA instance for testing."""
+    n_times = sfreq * 60 if ntime is None else ntime
+    rng = np.random.default_rng()
+    data = rng.standard_normal((n_channels, n_times))
+    ch_names = [f"MEG {i+1}" for i in range(n_channels)]
+
+    # Create valid channel loc for feature extraction
+    channel_locs = rng.standard_normal((n_channels, 3))
+    channel_locs[:, 0] += 0.1
+    channel_locs[:, 1] += 0.1
+    channel_locs[:, 2] += 0.1
+
+    info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_type)
+    for i, loc in enumerate(channel_locs):
+        info["chs"][i]["loc"][:3] = loc
+
+    raw = RawArray(data, info)
+    raw.filter(*filter_range)
+
+    # fastica can not converge with the current data
+    # so we use infomax in computation
+    # but set ica_method after fitting for testing
+    ica = ICA(n_components=n_components, method="infomax")
+    ica.fit(raw)
+    if ica_method != "infomax":
+        ica.method = ica_method
+
+    return raw, ica
+
+
+@pytest.fixture
+def raw_ica_valid():
+    """Raw instance with valid parameters."""
+    raw, ica = create_raw_ica()
+    return raw, ica
+
+
+def test_get_megnet_features(raw_ica_valid):
+    """Test whether the function returns the correct features."""
+    time_series, topomaps = get_megnet_features(*raw_ica_valid)
+    n_components = raw_ica_valid[1].n_components
+    n_times = raw_ica_valid[0].times.shape[0]
+
+    assert time_series.shape == (n_components, n_times)
+    assert topomaps.shape == (n_components, 120, 120, 3)
+
+
+@pytest.fixture
+def raw_ica_invalid_channel():
+    """Raw instance with invalid channel type."""
+    raw, ica = create_raw_ica(ch_type="eeg")
+    return raw, ica
+
+
+@pytest.fixture
+def raw_ica_invalid_sfreq():
+    """Raw instance with invalid sampling frequency."""
+    raw, ica = create_raw_ica(sfreq=600)
+    return raw, ica
+
+
+@pytest.fixture
+def raw_ica_invalid_time():
+    """Raw instance with invalid time points."""
+    raw, ica = create_raw_ica(ntime=2500)
+    return raw, ica
+
+
+@pytest.fixture
+def raw_ica_invalid_filter():
+    """Raw instance with invalid filter range."""
+    raw, ica = create_raw_ica(filter_range=(0.1, 100))
+    return raw, ica
+
+
+@pytest.fixture
+def raw_ica_invalid_ncomp():
+    """Raw instance with invalid number of ICA components."""
+    raw, ica = create_raw_ica(n_components=10)
+    return raw, ica
+
+
+@pytest.fixture
+def raw_ica_invalid_method():
+    """Raw instance with invalid ICA method."""
+    raw, ica = create_raw_ica(ica_method="fastica")
+    return raw, ica
+
+
+def test_get_megnet_features_invalid(
+    raw_ica_invalid_channel,
+    raw_ica_invalid_time,
+    raw_ica_invalid_sfreq,
+    raw_ica_invalid_filter,
+    raw_ica_invalid_ncomp,
+    raw_ica_invalid_method,
+):
+    """Test whether the function raises the correct exceptions."""
+    test_cases = [
+        (raw_ica_invalid_channel, RuntimeError, "Could not find MEG channels"),
+        (
+            raw_ica_invalid_time,
+            RuntimeError,
+            "The provided raw instance has 2500 points.",
+        ),
+        (
+            raw_ica_invalid_sfreq,
+            RuntimeWarning,
+            "The provided raw instance is not sampled at 250 Hz",
+        ),
+        (
+            raw_ica_invalid_filter,
+            RuntimeWarning,
+            "The provided raw instance is not filtered between 1 and 100 Hz",
+        ),
+        (
+            raw_ica_invalid_ncomp,
+            RuntimeWarning,
+            "The provided ICA instance has 10 components",
+        ),
+        (
+            raw_ica_invalid_method,
+            RuntimeWarning,
+            "The provided ICA instance was fitted with 'fastica'",
+        ),
+    ]
+
+    for raw_ica_fixture, exc_type, msg in test_cases:
+        raw, ica = raw_ica_fixture
+        if exc_type is RuntimeError:
+            with pytest.raises(exc_type, match=msg):
+                get_megnet_features(raw, ica)
+        elif exc_type is RuntimeWarning:
+            with pytest.warns(exc_type, match=msg):
+                get_megnet_features(raw, ica)
diff --git a/mne_icalabel/megnet/tests/test_label_components.py b/mne_icalabel/megnet/tests/test_label_components.py
new file mode 100644
index 000000000..d91b88234
--- /dev/null
+++ b/mne_icalabel/megnet/tests/test_label_components.py
@@ -0,0 +1,114 @@
+from unittest.mock import MagicMock
+
+import mne
+import numpy as np
+import onnxruntime as ort
+import pytest
+from numpy.testing import assert_allclose
+
+from mne_icalabel.megnet.label_components import (
+    _chunk_predicting,
+    _get_chunk_start,
+    megnet_label_components,
+)
+
+
+@pytest.fixture
+def raw_ica():
+    """Create a Raw instance and ICA instance for testing."""
+    sample_dir = mne.datasets.sample.data_path()
+    sample_fname = sample_dir / "MEG" / "sample" / "sample_audvis_raw.fif"
+
+    raw = mne.io.read_raw_fif(sample_fname).pick("mag")
+    raw.load_data()
+    raw.resample(250)
+    raw.notch_filter(60)
+    raw.filter(1, 100)
+
+    ica = mne.preprocessing.ICA(n_components=20, method="infomax", random_state=88)
+    ica.fit(raw)
+
+    return raw, ica
+
+
+def test_megnet_label_components(raw_ica):
+    """Test whether the function returns the correct artifact index."""
+    real_artifact_idx = [0, 3, 5]  # heart beat, eye movement, heart beat
+    prob = megnet_label_components(*raw_ica)
+    this_artifact_idx = [int(idx) for idx in np.nonzero(prob.argmax(axis=1))[0]]
+    assert set(real_artifact_idx) == set(this_artifact_idx)
+
+
+def test_get_chunk_start():
+    """Test whether the function returns the correct start times."""
+    input_len = 10000
+    chunk_len = 3000
+    overlap_len = 750
+
+    start_times = _get_chunk_start(input_len, chunk_len, overlap_len)
+
+    assert len(start_times) == 4
+    assert start_times == [0, 2250, 4500, 6750]
+
+
+def test_chunk_predicting():
+    """Test whether MEGnet's chunk vote algorithm returns the correct shape."""
+    rng = np.random.default_rng()
+    time_series = rng.random((5, 10000))
+    spatial_maps = rng.random((5, 120, 120, 3))
+
+    mock_session = MagicMock(spec=ort.InferenceSession)
+    mock_session.run.return_value = [rng.random(4)]
+
+    predictions = _chunk_predicting(
+        mock_session, time_series, spatial_maps, chunk_len=3000, overlap_len=750
+    )
+
+    assert predictions.shape == (5, 4)
+    assert isinstance(predictions, np.ndarray)
+
+
+def test_ica(raw_ica):
+    """Test whether the ICA instances are the same."""
+    raw1, ica1 = raw_ica
+    raw2 = raw1.copy()
+    ica = mne.preprocessing.ICA(n_components=20, method="infomax", random_state=88)
+    ica2 = ica.fit(raw2)
+    assert_allclose(
+        raw1.get_data(),
+        raw2.get_data(),
+        atol=1e-6,
+        err_msg="Raw data should be the same!",
+    )
+
+    assert_allclose(
+        ica1.mixing_matrix_,
+        ica2.mixing_matrix_,
+        atol=1e-6,
+        err_msg="ICA mixing matrices should be the same!",
+    )
+    assert_allclose(
+        ica1.unmixing_matrix_,
+        ica2.unmixing_matrix_,
+        atol=1e-6,
+        err_msg="ICA unmixing matrices should be the same!",
+    )
+
+    ica1_data = ica1.get_sources(raw1).get_data()
+    ica2_data = ica2.get_sources(raw2).get_data()
+    assert_allclose(
+        ica1_data,
+        ica2_data,
+        atol=1e-6,
+        err_msg="ICA transformed data should be the same!",
+    )
+
+
+def test_megnet(raw_ica):
+    """Test whether the MEGnet predictions are the same."""
+    raw, ica = raw_ica
+    prob1 = megnet_label_components(raw, ica)
+    prob2 = megnet_label_components(raw, ica)
+    assert_allclose(
+        prob1, prob2, atol=1e-6, err_msg="MEGnet predictions should be the same!"
+    )
diff --git a/pyproject.toml b/pyproject.toml
index 2fa23b309..803fa9ee5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -203,6 +203,7 @@ include-package-data = false
 
 [tool.setuptools.package-data]
 'mne_icalabel.iclabel.network' = ['assets/*']
+'mne_icalabel.megnet' = ['assets/*']
 
 [tool.setuptools.packages.find]
 exclude = ['mne_icalabel*tests']