Skip to content

Commit

Permalink
Fix Spectrum rank checking
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Aug 2, 2024
1 parent 00c381a commit ddaffbe
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions mne_connectivity/spectral/epochs_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import numpy as np
from mne.epochs import BaseEpochs
from mne.parallel import parallel_func
from mne.time_frequency import EpochsSpectrum, EpochsSpectrumArray
from mne.time_frequency.multitaper import _psd_from_mt
from mne.utils import ProgressBar, _validate_type, logger


Expand All @@ -31,8 +33,14 @@ def _check_rank_input(rank, data, indices):
if "copy" in inspect.getfullargspec(data.get_data).kwonlyargs:
kwargs["copy"] = False
data_arr = data.get_data(**kwargs)
else:
data_arr = data
elif isinstance(data, (EpochsSpectrum, EpochsSpectrumArray)):
# Spectrum objs will drop bad channels, so specify picking all channels
data_arr = data.get_data(picks=np.arange(data.info["nchan"]))
# Convert to power (and aggregate over tapers) before computing rank
if "taper" in data._dims:
data_arr = _psd_from_mt(data_arr, data.weights)
else:
data_arr = (data_arr * data_arr.conj()).real

for group_i in range(2): # seeds and targets
for con_i, con_idcs in enumerate(indices[group_i]):
Expand Down

0 comments on commit ddaffbe

Please sign in to comment.