Skip to content

Commit

Permalink
topomaps plot modify & bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
colehank committed Oct 23, 2024
1 parent 96ed02d commit 989cb40
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 55 deletions.
38 changes: 7 additions & 31 deletions mne_icalabel/megnet/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
import matplotlib.pyplot as plt
import mne # type: ignore
import numpy as np
from _utils import cart2sph, pol2cart
from ._utils import cart2sph, pol2cart
from mne.io import BaseRaw # type: ignore
from mne.preprocessing import ICA # type: ignore
from mne.utils import warn # type: ignore
from numpy.typing import NDArray
from PIL import Image
from scipy import interpolate # type: ignore
from scipy.spatial import ConvexHull # type: ignore
from ._utils import cart2sph, pol2cart


def get_megnet_features(raw: BaseRaw, ica: ICA):
"""
Extract time series and topomaps for each ICA component.
the main work is focused on making BrainStorm-like topomaps
which trained the MEGnet.
Expand All @@ -34,6 +34,7 @@ def get_megnet_features(raw: BaseRaw, ica: ICA):
The time series for each ICA component.
topomaps : np.ndarray
The topomaps for each ICA component
"""
if "meg" not in raw:
raise RuntimeError(
Expand Down Expand Up @@ -67,9 +68,7 @@ def get_megnet_features(raw: BaseRaw, ica: ICA):


def _make_head_outlines(sphere: NDArray, pos: NDArray, clip_origin: tuple):
"""
Generate head outlines and mask positions for the topomap plot.
"""
"""Generate head outlines and mask positions for the topomap plot."""
x, y, _, radius = sphere
ll = np.linspace(0, 2 * np.pi, 101)
head_x = np.cos(ll) * radius * 1.01 + x
Expand All @@ -88,9 +87,7 @@ def _make_head_outlines(sphere: NDArray, pos: NDArray, clip_origin: tuple):


def _get_topomaps_data(ica: ICA):
"""
Prepare 2D sensor positions and outlines for topomap plotting.
"""
"""Prepare 2D sensor positions and outlines for topomap plotting."""
mags = mne.pick_types(ica.info, meg="mag")
channel_info = ica.info["chs"]
loc_3d = [channel_info[i]["loc"][0:3] for i in mags]
Expand Down Expand Up @@ -133,9 +130,7 @@ def _get_topomaps_data(ica: ICA):


def _get_topomaps(ica: ICA, pos_new: NDArray, outlines: dict):
"""
Generate topomap images for each ICA component.
"""
"""Generate topomap images for each ICA component."""
topomaps = []
data_picks, _, _, _, _, _, _ = mne.viz.topomap._prepare_topomap_plot(
ica, ch_type="mag"
Expand All @@ -153,7 +148,7 @@ def _get_topomaps(ica: ICA, pos_new: NDArray, outlines: dict):
outlines=outlines,
extrapolate="head",
sphere=[0, 0, 0, 1],
contours=10,
contours=0,
res=120,
axes=ax,
show=False,
Expand All @@ -174,22 +169,3 @@ def _get_topomaps(ica: ICA, pos_new: NDArray, outlines: dict):
return np.array(topomaps)


if __name__ == "__main__":
sample_dir = mne.datasets.sample.data_path()
sample_fname = sample_dir / "MEG" / "sample" / "sample_audvis_raw.fif"

raw = mne.io.read_raw_fif(sample_fname).pick_types("mag")
raw.resample(250)
ica = mne.preprocessing.ICA(n_components=20, method="infomax")
ica.fit(raw)

time_series, topomaps = get_megnet_features(raw, ica)

fig, axes = plt.subplots(4, 5)
for i, comp in enumerate(topomaps):
row, col = divmod(i, 5)
ax = axes[row, col]
ax.imshow(comp)
ax.axis("off")
fig.tight_layout()
plt.show()
30 changes: 6 additions & 24 deletions mne_icalabel/megnet/label_componets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from mne.io import BaseRaw
from mne.preprocessing import ICA
from numpy.typing import NDArray

from features import get_megnet_features

def megnet_label_components(
raw: BaseRaw,
ica: ICA,
model_path: str = op.join("assets", "network", "megnet.onnx"),
) -> dict:
"""
Label the provided ICA components with the MEGnet neural network.
Parameters
Expand All @@ -32,6 +33,7 @@ def megnet_label_components(
The predicted probabilities for each component.
- 'labels' : list of str
The predicted labels for each component.
"""
time_series, topomaps = get_megnet_features(raw, ica)

Expand All @@ -50,7 +52,7 @@ def megnet_label_components(
session = ort.InferenceSession(model_path)
predictions_vote = _chunk_predicting(session, time_series, topomaps)

all_labels = ["brain/other", "eye blink", "eye movement", "heart"]
all_labels = ["brain/other", "eye movement", "heart", "eye blink"]
# megnet_labels = ['NA', 'EB', 'SA', 'CA']
result = predictions_vote[:, 0, :]
labels = [all_labels[i] for i in result.argmax(axis=1)]
Expand All @@ -66,9 +68,7 @@ def _chunk_predicting(
chunk_len=15000,
overlap_len=3750,
) -> NDArray:
"""Predict the labels for each component using
MEGnet's chunk volte algorithm.
"""
"""MEGnet's chunk volte algorithm."""
predction_vote = []

for comp_series, comp_map in zip(time_series, spatial_maps):
Expand Down Expand Up @@ -111,9 +111,7 @@ def _chunk_predicting(
def _get_chunk_start(
input_len: int, chunk_len: int = 15000, overlap_len: int = 3750
) -> list:
"""
Calculate start times for time series chunks with overlap.
"""
"""Calculate start times for time series chunks with overlap."""
start_times = []
start_time = 0
while start_time + chunk_len <= input_len:
Expand All @@ -122,19 +120,3 @@ def _get_chunk_start(
return start_times


if __name__ == "__main__":
import mne
from features import get_megnet_features

sample_dir = mne.datasets.sample.data_path()
sample_fname = sample_dir / "MEG" / "sample" / "sample_audvis_raw.fif"
raw = mne.io.read_raw_fif(sample_fname).pick("mag")
raw.resample(250)
raw.filter(1, 100)
ica = mne.preprocessing.ICA(n_components=20, max_iter="auto", method="infomax")
ica.fit(raw)

res = megnet_label_components(raw, ica)
print(res)
ica.exclude = [i for i, label in enumerate(res["labels"]) if label != "brain/other"]
ica.plot_components()

0 comments on commit 989cb40

Please sign in to comment.