Skip to content

Commit

Permalink
fixed transforms for API
Browse files Browse the repository at this point in the history
  • Loading branch information
alexrockhill committed Jul 22, 2021
1 parent 6548510 commit 9405705
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 78 deletions.
136 changes: 76 additions & 60 deletions mne/gui/_ieeg_locate_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
QComboBox, QPlainTextEdit)
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg

from ..channels import make_dig_montage, compute_native_head_t
from ..coreg import get_mni_fiducials
from .._freesurfer import _check_subject_dir
from ..io.constants import FIFF
from ..transforms import apply_trans, invert_transform
from ..utils import logger, _check_fname, requires_nibabel, warn


Expand Down Expand Up @@ -49,39 +52,16 @@
'elec_colors', _UNIQUE_COLORS, N=_N_COLORS)


def _load_electrodes(info):
"""Load previously determined electrode contact locations."""
ch_coords = np.array([ch['loc'][:3] for ch in info['chs']])
chs = dict()
for idx, ch_coord in enumerate(ch_coords):
ch_name = info['chs'][idx]['ch_name']
coord_frame = info['chs'][idx]['coord_frame']
is_eeg = any([info['chs'][idx]['kind'] == kind for kind in
(FIFF.FIFFV_EEG_CH, FIFF.FIFFV_SEEG_CH,
FIFF.FIFFV_DBS_CH, FIFF.FIFFV_ECOG_CH)])
if not is_eeg:
kind = info['chs'][idx]['kind']
warn(f'Skipping {ch_name}, kind {kind} not an EEG, SEEG, '
'DBS or ECOG electrode')
if coord_frame != FIFF.FIFFV_COORD_MRI:
raise RuntimeError(
'Coordinate frame channels in `info` must be "mri", '
f'got {coord_frame} for {ch_name}. The easiest way is to '
'use saved locations from this GUI. You may want to solve '
'previous channel locations having the wrong coordinate '
'frame by removing them with `inst.info["dig"] = []`')
chs[ch_name] = ch_coord
return


def _save_electrodes(info, chs, verbose=True):
"""Save the location of the electrode contacts."""
if verbose:
logger.info('Saving electrode positions to `info`')
for ch_name, ch_coord in chs.items():
idx = info.ch_names.index(ch_name)
info['chs'][idx]['loc'][:3] = ch_coord
return info
def _get_mri_head_trans(subject, subjects_dir):
"""Get the head to surface RAS transform using the Freesurfer recon."""
lpa, nasion, rpa = get_mni_fiducials(subject, subjects_dir)
assert lpa['ident'] == FIFF.FIFFV_POINT_LPA
assert nasion['ident'] == FIFF.FIFFV_POINT_NASION
assert rpa['ident'] == FIFF.FIFFV_POINT_RPA
montage = make_dig_montage(
lpa=lpa['r'], nasion=nasion['r'], rpa=rpa['r'], coord_frame='mri')
trans = compute_native_head_t(montage)
return trans


@requires_nibabel()
Expand Down Expand Up @@ -127,27 +107,29 @@ def __init__(self, parent=None, width=24, height=16, dpi=300):
class IntracranialElectrodeLocator(QMainWindow):
"""Pick electrodes using a coregistered MRI and CT."""

def __init__(self, info, ct, subject=None, subjects_dir=None,
def __init__(self, info, aligned_ct, subject, subjects_dir=None,
verbose=None):
"""A GUI that locates intracranial electrodes.
.. note :: Images will be displayed using orientation information
obtained from the image header. Images will be resampled to
dimensions [256, 256, 256] for display.
.. note:: Images will be displayed using orientation information
obtained from the image header. Images will be resampled to
dimensions [256, 256, 256] for display.
"""
# initialize QMainWindow class
super(IntracranialElectrodeLocator, self).__init__()

# store info for modification
self._info = info

# load imaging data
self._subject_dir = _check_subject_dir(subject, subjects_dir,
raise_error=False)
self._T1_on = self._subject_dir is not None
self._load_image_data(ct, verbose)
self._subject_dir = _check_subject_dir(subject, subjects_dir)
self._mri_head_t = _get_mri_head_trans(subject, subjects_dir)
self._load_image_data(aligned_ct, verbose)

self._elec_radius = int(np.mean(_ELEC_PLOT_SIZE) // 100)
# initialize electrode data
self._elec_index = 0
self._elecs = _load_electrodes(info)
self._elecs = self._load_electrodes(info)

# GUI design
self._make_slice_plots()
Expand Down Expand Up @@ -181,12 +163,11 @@ def __init__(self, info, ct, subject=None, subjects_dir=None,

def _load_image_data(self, ct, verbose):
# prepare MRI data
if self._subject_dir is not None:
self._img_data = _load_image_data(
op.join(self._subject_dir, 'mri', 'brain.mgz'),
'MRI Image', reorient=True, verbose=verbose)
self._mri_min = self._img_data.min()
self._mri_max = self._img_data.max()
self._mri_data = _load_image_data(
op.join(self._subject_dir, 'mri', 'brain.mgz'),
'MRI Image', reorient=True, verbose=verbose)
self._mri_min = self._mri_data.min()
self._mri_max = self._mri_data.max()

# ready ct
self._ct_data = _load_image_data(ct, 'CT', reorient=True)
Expand Down Expand Up @@ -234,23 +215,58 @@ def color_elec_radius(elec_image, xf, yf, group, radius):
elec_image, x / vx, y / vy, group, r)
return elec_image

def make_slice_plots(self):
def _load_electrodes(self, info):
"""Load previously determined electrode contact locations."""
ch_coords = np.array([ch['loc'][:3] for ch in info['chs']])
chs = dict()
for idx, ch_coord in enumerate(ch_coords):
ch_name = info['chs'][idx]['ch_name']
coord_frame = info['chs'][idx]['coord_frame']
is_eeg = any([info['chs'][idx]['kind'] == kind for kind in
(FIFF.FIFFV_EEG_CH, FIFF.FIFFV_SEEG_CH,
FIFF.FIFFV_DBS_CH, FIFF.FIFFV_ECOG_CH)])
if not is_eeg:
kind = info['chs'][idx]['kind']
warn(f'Skipping {ch_name}, kind {kind} not an EEG, SEEG, '
'DBS or ECOG electrode')
continue
if coord_frame != FIFF.FIFFV_COORD_HEAD: # should never be true
raise RuntimeError(
'Coordinate frame channels in `info` must be "head", got '
f'{coord_frame} for {ch_name}. The easiest way to solve '
'this is to use saved locations from this GUI or you may '
' want to remove previous channel locations with '
'`inst.info["dig"] = []` **will erase previous data**.')
chs[ch_name] = apply_trans(
invert_transform(self._mri_head_t), ch_coord)
return chs

def _save_electrodes(self, info, chs, verbose=True):
"""Save the location of the electrode contacts."""
if verbose:
logger.info('Saving electrode positions to `info`')
for ch_name, ch_coord in chs.items():
idx = info.ch_names.index(ch_name)
info['chs'][idx]['loc'][:3] = ch_coord
return info

def _make_slice_plots(self):
self._plt = SlicePlots(self)
# Plot sagittal (0), coronal (1) or axial (2) view
self._images = dict(mri=list(), ct=dict(), elec=dict(),
cursor=dict(), cursor2=dict())
for axis in range(3):
if self._subject_dir is not None:
img_data = np.take(self.img_data, self.current_slice[axis],
img_data = np.take(self._mri_data, self.current_slice[axis],
axis=axis).T
self._images['mri'].append(self._plt.axes[0, axis].imshow(
self._images['mri'].append(self._plt._axes[0, axis].imshow(
img_data, cmap='gray', aspect='auto'))
ct_data = np.take(self.ct_data, self.current_slice[axis],
ct_data = np.take(self._ct_data, self.current_slice[axis],
axis=axis).T
self._images['ct'][(0, axis)] = self._plt.axes[0, axis].imshow(
self._images['ct'][(0, axis)] = self._plt._axes[0, axis].imshow(
ct_data, cmap='hot', aspect='auto', alpha=0.5,
vmin=_CT_MIN_VAL, vmax=np.nanmax(self.ct_data))
self._images['ct'][(1, axis)] = self._plt.axes[1, axis].imshow(
self._images['ct'][(1, axis)] = self._plt._axes[1, axis].imshow(
ct_data, cmap='gray', aspect='auto')
for axis2 in range(2):
self._images['elec'][(axis2, axis)] = \
Expand All @@ -260,7 +276,7 @@ def make_slice_plots(self):
cmap=_ELECTRODE_CMAP, alpha=1, vmin=0, vmax=_N_COLORS)
self._images['cursor'][(axis2, axis)] = \
self._plt._axes[axis2, axis].plot(
(self.current_slice[1], self.current_slice[1]),
(self._current_slice[1], self._current_slice[1]),
(0, _VOXEL_SIZES[axis]), color=[0, 1, 0],
linewidth=0.25)[0]
self._images['cursor2'][(axis2, axis)] = \
Expand Down Expand Up @@ -310,7 +326,7 @@ def _get_button_bar(self):

RAS_label = QLabel('RAS=')
self._RAS_textbox = QPlainTextEdit(
'{:.2f}, {:.2f}, {:.2f}'.format(*self.cursors_to_RAS()))
'{:.2f}, {:.2f}, {:.2f}'.format(*self._cursors_to_RAS()))
self._RAS_textbox.setMaximumHeight(25)
self._RAS_textbox.setMaximumWidth(200)
self._RAS_textbox.focusOutEvent = self._update_RAS
Expand Down Expand Up @@ -601,7 +617,7 @@ def _mark_elec(self):
self._cursors_to_RAS().tolist() + [self._get_group(), 'n/a']
self._color_list_item()
self._update_elec_images(draw=True)
_save_electrodes(self._elecs)
self._save_electrodes(self._elecs)
self._next_elec()

@pyqtSlot()
Expand All @@ -610,7 +626,7 @@ def _remove_elec(self):
if name in self._elecs:
self._color_list_item(clear=True)
self._elecs.pop(name)
_save_electrodes(self._elecs)
self._save_electrodes(self._elecs)
self._update_elec_images(draw=True)

def _update_elec_images(self, axis_selected=None, draw=False):
Expand All @@ -624,11 +640,11 @@ def _update_elec_images(self, axis_selected=None, draw=False):
def _update_mri_images(self, axis_selected=None, draw=False):
for axis in range(3) if axis_selected is None else [axis_selected]:
if self._T1_on:
img_data = np.take(self._img_data, self._current_slice[axis],
img_data = np.take(self._mri_data, self._current_slice[axis],
axis=axis).T
else:
img_data = np.take(np.zeros(_VOXEL_SIZES),
self.current_slice[axis], axis=axis).T
self._current_slice[axis], axis=axis).T
self._images['mri'][axis].set_data(img_data)
if draw:
self._plt._fig.canvas.draw()
Expand Down
34 changes: 16 additions & 18 deletions mne/gui/tests/test_ieeg_locate_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,41 +19,39 @@
fname_trans = op.join(sample_dir, 'sample_audvis_trunc-trans.fif')


@requires_nibabel()
def _make_fake_CT():
def _make_fake_CT(shape, ch_coords, contact_size):
"""Make somewhat realistic CT data.
Looks like a hollow cube with three contacts in it.
"""
import nibabel as nib
size = (512, 512, 192) # make realistic weird size
center = np.array(size) / 2
contact_size = 2
center = np.array(shape) / 2
# make image
ct_data = np.zeros(size, dtype=np.float32)
ct_data = np.zeros(shape, dtype=np.float32)
# make cube for skull
ct_data[tuple(slice(int(center[i] - s * 0.3), int(center[i] + s * 0.3))
for i, s in enumerate(size))] = 1000
for i, s in enumerate(shape))] = 1000
# hollow out inner section
ct_data[tuple(slice(int(center[i] - s * 0.2), int(center[i] + s * 0.2))
for i, s in enumerate(size))] = 0
ch_coords = [(240, 290, 88), (245, 248, 100), (280, 250, 90)]
for i, s in enumerate(shape))] = 0
for (x, y, z) in ch_coords:
# make sure not in skull
assert np.linalg.norm(center - np.array((x, y, z))) < 50
for i in range(-contact_size, contact_size + 1):
for j in range(-contact_size, contact_size + 1):
for k in range(-contact_size, contact_size + 1):
ct_data[x + i, y + j, z + k] = 1000
affine = np.array([[-0.41499999, -0., 0., 0.],
[0., 0.41499999, 0., -212.06500244],
[0., 0., 0.833, -159.10299683],
[0., 0., 0., 1.]])
ct = nib.MGHImage(ct_data, affine)
return ct, ch_coords
return ct_data


@requires_nibabel()
def test_ieeg_elec_locate_gui_display():
"""Test that the intracranial location GUI displays properly."""
import nibabel as nib
raw = mne.io.read_raw_fif(raw_path)
ct, ch_coords = _make_fake_CT()
raw.info = mne.gui.locate_ieeg(raw.info, ct)
raw.pick_types(eeg=True)
T1 = nib.load(op.join(subjects_dir, subject, 'mri', 'T1.mgz'))
x, y, z = np.array(T1.shape).astype(int) // 2
ch_coords = [(x, y, z), (x - 6, y + 1, z), (x - 12, y + 1, z + 1)]
ct_data = _make_fake_CT(T1.shape, ch_coords, contact_size=1)
ct_aligned = nib.MGHImage(ct_data, T1.affine)
mne.gui.locate_ieeg(raw.info, ct_aligned, subject, subjects_dir)

0 comments on commit 9405705

Please sign in to comment.