Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel McCloy <[email protected]>
  • Loading branch information
tsbinns and drammock authored Sep 20, 2024
1 parent 23dc1f2 commit 8e79c9f
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 19 deletions.
22 changes: 8 additions & 14 deletions mne_connectivity/spectral/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,17 +164,16 @@ def _prepare_connectivity(
first_epoch = epoch_block[0]

# Sort times
if spectrum_computed and times_in is None: # if Spectrum object passed in as data
if spectrum_computed and times_in is None: # is a Spectrum object
n_signals = first_epoch[0].shape[0]
times = None
n_times = 0
times_in = None
n_times_in = 0
tmin_idx = None
tmax_idx = None
warn_times = False
else: # if data has a time dimension (timeseries or TFR object)
if spectrum_computed: # if TFR object passed in as data
else: # data has a time dimension (timeseries or TFR object)
if spectrum_computed: # is a TFR object
first_epoch = (first_epoch[0][:, 0],) # just take first freq
(
n_signals,
Expand All @@ -190,7 +189,7 @@ def _prepare_connectivity(
)

# Sort freqs
if not spectrum_computed: # if timeseries passed in as data
if not spectrum_computed: # is an (ordinary) timeseries
# check that fmin corresponds to at least 5 cycles
fmin = _check_freqs(sfreq=sfreq, fmin=fmin, n_times=n_times)
# compute frequencies to analyze based on number of samples, sampling rate,
Expand Down Expand Up @@ -746,8 +745,8 @@ def spectral_connectivity_epochs(
:class:`~mne.time_frequency.EpochsTFR` object. If timeseries data, the spectral
information will be computed according to the spectral estimation mode (see the
``mode`` parameter). If an :class:`~mne.time_frequency.EpochsSpectrum` or
:class:`~mne.time_frequency.EpochsTFR` object, this spectral information will be
used and the ``mode`` parameter will be ignored.
:class:`~mne.time_frequency.EpochsTFR` object, existing spectral information
will be used and the ``mode`` parameter will be ignored.
Note that it is also possible to combine multiple timeseries signals by
providing a list of tuples, e.g.: ::
Expand Down Expand Up @@ -1120,10 +1119,7 @@ def spectral_connectivity_epochs(
weights = None
metadata = None
spectrum_computed = False
if isinstance(
data,
BaseEpochs | EpochsSpectrum | EpochsSpectrumArray | EpochsTFR | EpochsTFRArray,
):
if isinstance(data, BaseEpochs | EpochsSpectrum | EpochsTFR):
names = data.ch_names
sfreq = data.info["sfreq"]

Expand All @@ -1144,9 +1140,7 @@ def spectral_connectivity_epochs(
data.add_annotations_to_metadata(overwrite=True)
metadata = data.metadata

if isinstance(
data, EpochsSpectrum | EpochsSpectrumArray | EpochsTFR | EpochsTFRArray
):
if isinstance(data, EpochsSpectrum | EpochsTFR):
# XXX: Will need to be updated if new Spectrum methods are added
if not np.iscomplexobj(data.get_data()):
raise TypeError(
Expand Down
2 changes: 1 addition & 1 deletion mne_connectivity/spectral/epochs_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _check_rank_input(rank, data, indices):
data_arr = _psd_from_mt(data_arr, data.weights)
else:
data_arr = (data_arr * data_arr.conj()).real
elif isinstance(data, EpochsTFR | EpochsTFRArray):
elif isinstance(data, EpochsTFR):
# TFR objs will drop bad channels, so specify picking all channels
data_arr = data.get_data(picks=np.arange(data.info["nchan"]))
# Convert to power and aggregate over time before computing rank
Expand Down
8 changes: 4 additions & 4 deletions mne_connectivity/spectral/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def spectral_connectivity_time(
for each epoch as an :class:`~mne.time_frequency.EpochsTFR` object. If
timeseries data, the spectral information will be computed according to the
spectral estimation mode (see the ``mode`` parameter). If an
:class:`~mne.time_frequency.EpochsTFR` object, this spectral information will be
used and the ``mode`` parameter will be ignored.
:class:`~mne.time_frequency.EpochsTFR` object, existing spectral information
will be used and the ``mode`` parameter will be ignored.
.. versionchanged:: 0.8
Fourier coefficients stored in an :class:`~mne.time_frequency.EpochsTFR`
Expand Down Expand Up @@ -397,12 +397,12 @@ def spectral_connectivity_time(
# extract data from Epochs object
_validate_type(
data,
(np.ndarray, BaseEpochs, EpochsTFR, EpochsTFRArray),
(np.ndarray, BaseEpochs, EpochsTFR),
"`data`",
"Epochs, EpochsTFR, or a NumPy array",
)
spectrum_computed = False
if isinstance(data, BaseEpochs | EpochsTFR | EpochsTFRArray):
if isinstance(data, BaseEpochs | EpochsTFR):
names = data.ch_names
sfreq = data.info["sfreq"]
events = data.events
Expand Down

0 comments on commit 8e79c9f

Please sign in to comment.