Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Fix largest culprit in speed issue #28

Merged
merged 9 commits into from
Apr 30, 2024
10 changes: 6 additions & 4 deletions mne_gui_addons/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,11 @@ def _plot_images(self):
plot_x_idx, plot_y_idx = self._xy_idx[axis]
fig = self._figs[axis]
ax = fig.axes[0]
img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T
self._images["base"].append(
ax.imshow(
img_data,
self._base_data[
(slice(None),) * axis + (self._current_slice[axis],)
].T,
cmap="gray",
aspect="auto",
zorder=1,
Expand Down Expand Up @@ -623,8 +624,9 @@ def _draw(self, axis=None):
def _update_base_images(self, axis=None, draw=False):
"""Update the base images."""
for axis in range(3) if axis is None else [axis]:
img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T
self._images["base"][axis].set_data(img_data)
self._images["base"][axis].set_data(
self._base_data[(slice(None),) * axis + (self._current_slice[axis],)].T
)
if draw:
self._draw(axis)

Expand Down
39 changes: 25 additions & 14 deletions mne_gui_addons/_ieeg_locate.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,16 +996,21 @@ def _update_ch_images(self, axis=None, draw=False):
def _update_ct_images(self, axis=None, draw=False):
"""Update the CT image(s)."""
for axis in range(3) if axis is None else [axis]:
ct_data = np.take(self._ct_data, self._current_slice[axis], axis=axis).T
ct_data = (
self._ct_data[(slice(None),) * axis + (self._current_slice[axis],)]
.copy()
.T
)
# Threshold the CT so only bright objects (electrodes) are visible
ct_data[ct_data < self._ct_min_slider.value()] = np.nan
ct_data[ct_data > self._ct_max_slider.value()] = np.nan
self._images["ct"][axis].set_data(ct_data)
if "local_max" in self._images:
ct_max_data = np.take(
self._ct_maxima, self._current_slice[axis], axis=axis
).T
self._images["local_max"][axis].set_data(ct_max_data)
self._images["local_max"][axis].set_data(
self._ct_maxima[
(slice(None),) * axis + (self._current_slice[axis],)
].T
)
if draw:
self._draw(axis)

Expand All @@ -1014,7 +1019,9 @@ def _update_mri_images(self, axis=None, draw=False):
if "mri" in self._images:
for axis in range(3) if axis is None else [axis]:
self._images["mri"][axis].set_data(
np.take(self._mr_data, self._current_slice[axis], axis=axis).T
self._mr_data[
(slice(None),) * axis + (self._current_slice[axis],)
].T
)
if draw:
self._draw(axis)
Expand Down Expand Up @@ -1150,14 +1157,13 @@ def _toggle_show_max(self):
self._update_ct_maxima()
self._images["local_max"] = list()
for axis in range(3):
ct_max_data = np.take(
self._ct_maxima, self._current_slice[axis], axis=axis
).T
self._images["local_max"].append(
self._figs[axis]
.axes[0]
.imshow(
ct_max_data,
self._ct_maxima[
(slice(None),) * axis + (self._current_slice[axis],)
].T,
cmap="autumn",
aspect="auto",
vmin=0,
Expand All @@ -1182,13 +1188,18 @@ def _toggle_show_brain(self):
else:
self._images["mri"] = list()
for axis in range(3):
mri_data = np.take(
self._mr_data, self._current_slice[axis], axis=axis
).T
self._images["mri"].append(
self._figs[axis]
.axes[0]
.imshow(mri_data, cmap="hot", aspect="auto", alpha=0.25, zorder=2)
.imshow(
self._mr_data[
(slice(None),) * axis + (self._current_slice[axis],)
].T,
cmap="hot",
aspect="auto",
alpha=0.25,
zorder=2,
)
)
self._toggle_brain_button.setText("Hide Brain")
self._draw()
Expand Down
14 changes: 9 additions & 5 deletions mne_gui_addons/_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,9 @@ def _update_img_scale(self):
def _update_base_images(self, axis=None, draw=False):
"""Update the CT image(s)."""
for axis in range(3) if axis is None else [axis]:
img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T
img_data = self._base_data[
(slice(None),) * axis + (self._current_slice[axis],)
].T.copy()
img_data[img_data < self._img_min_slider.value()] = np.nan
img_data[img_data > self._img_max_slider.value()] = np.nan
self._images["base"][axis].set_data(img_data)
Expand All @@ -335,10 +337,11 @@ def _plot_vol_images(self):
for axis in range(3):
fig = self._figs[axis]
ax = fig.axes[0]
vol_data = np.take(self._vol_img, self._current_slice[axis], axis=axis).T
self._images["vol"].append(
ax.imshow(
vol_data,
self._vol_img[
(slice(None),) * axis + (self._current_slice[axis],)
].T,
aspect="auto",
zorder=3,
cmap=_CMAP,
Expand Down Expand Up @@ -438,8 +441,9 @@ def _mark_all(self):
def _update_vol_images(self, axis=None, draw=False):
"""Update the volume image(s)."""
for axis in range(3) if axis is None else [axis]:
vol_data = np.take(self._vol_img, self._current_slice[axis], axis=axis).T
self._images["vol"][axis].set_data(vol_data)
self._images["vol"][axis].set_data(
self._vol_img[(slice(None),) * axis + (self._current_slice[axis],)].T
)
if draw:
self._draw(axis)

Expand Down
18 changes: 10 additions & 8 deletions mne_gui_addons/_vol_stc.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ def __init__(
]
src_coord = self._get_src_coord()
for axis in range(3):
stc_slice = np.take(self._stc_img, src_coord[axis], axis=axis).T
x_idx, y_idx = self._xy_idx[axis]
extent = [
corners[0][x_idx],
Expand All @@ -318,7 +317,7 @@ def __init__(
self._figs[axis]
.axes[0]
.imshow(
stc_slice,
self._stc_img[(slice(None),) * axis + (src_coord[axis],)].T,
aspect="auto",
extent=extent,
cmap=self._cmap,
Expand Down Expand Up @@ -507,7 +506,7 @@ def _apply_vector_norm(self, stc_data, axis=1):
# if self._data.dtype in (COMPLEX_DTYPE, BASE_INT_DTYPE):
# stc_data = stc_data.round().astype(BASE_INT_DTYPE)
else:
stc_data = np.take(stc_data, 0, axis=axis)
stc_data = stc_data[(slice(None),) * axis + (0,)]
return stc_data

def _apply_baseline_correction(self, stc_data):
Expand Down Expand Up @@ -541,9 +540,9 @@ def _pick_stc_vertex(self, stc_data):

def _pick_stc_tfr(self, stc_data):
"""Select the frequency and time based on GUI values."""
stc_data = np.take(stc_data, self._t_idx, axis=-1)
stc_data = stc_data[..., self._t_idx]
f_idx = 0 if self._f_idx is None else self._f_idx
stc_data = np.take(stc_data, f_idx, axis=-1)
stc_data = stc_data[..., f_idx]
return stc_data

def _configure_ui(self):
Expand Down Expand Up @@ -1381,10 +1380,13 @@ def _plot_stc_images(self, axis=None, draw=True):
for axis in range(3):
# ensure in bounds
if src_coord[axis] >= 0 and src_coord[axis] < self._stc_img.shape[axis]:
stc_slice = np.take(self._stc_img, src_coord[axis], axis=axis).T
self._images["stc"][axis].set_data(
self._stc_img[(slice(None),) * axis + (src_coord[axis],)].T
)
else:
stc_slice = np.take(self._stc_img, 0, axis=axis).T * np.nan
self._images["stc"][axis].set_data(stc_slice)
self._images["stc"][axis].set_data(
self._stc_img[(slice(None),) * axis + (0,)].copy().T * np.nan
)
if draw and self._update:
self._draw(axis)

Expand Down
7 changes: 4 additions & 3 deletions mne_gui_addons/tests/test_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ def test_segment_display(renderer_interactive_pyvistaqt):

# test no seghead, fsaverage doesn't have seghead
with pytest.warns(RuntimeWarning, match="`seghead` not found"):
gui = VolumeSegmenter(
subject="fsaverage", subjects_dir=subjects_dir, verbose=True
)
with pytest.warns(RuntimeWarning, match="`pial` surface not found"):
gui = VolumeSegmenter(
subject="fsaverage", subjects_dir=subjects_dir, verbose=True
)

# test functions
gui.set_RAS([25.37, 0.00, 34.18])
Expand Down
2 changes: 1 addition & 1 deletion mne_gui_addons/tests/test_vol_stc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _fake_stc(src_type="vol"):
) + 1j * rng.integers(
-1000, 1000, size=(n_epochs, len(info.ch_names), freqs.size, times.size)
)
epochs_tfr = mne.time_frequency.EpochsTFR(info, data, times=times, freqs=freqs)
epochs_tfr = mne.time_frequency.EpochsTFRArray(info, data, times=times, freqs=freqs)
nuse = sum([this_src["nuse"] for this_src in src])
stc_data = rng.integers(
-1000, 1000, size=(n_epochs, nuse, 3, freqs.size, times.size)
Expand Down
Loading