diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 23cdb080..186f5c3f 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -164,17 +164,16 @@ def _prepare_connectivity( first_epoch = epoch_block[0] # Sort times - if spectrum_computed and times_in is None: # if Spectrum object passed in as data + if spectrum_computed and times_in is None: # is a Spectrum object n_signals = first_epoch[0].shape[0] times = None n_times = 0 - times_in = None n_times_in = 0 tmin_idx = None tmax_idx = None warn_times = False - else: # if data has a time dimension (timeseries or TFR object) - if spectrum_computed: # if TFR object passed in as data + else: # data has a time dimension (timeseries or TFR object) + if spectrum_computed: # is a TFR object first_epoch = (first_epoch[0][:, 0],) # just take first freq ( n_signals, @@ -190,7 +189,7 @@ def _prepare_connectivity( ) # Sort freqs - if not spectrum_computed: # if timeseries passed in as data + if not spectrum_computed: # is an (ordinary) timeseries # check that fmin corresponds to at least 5 cycles fmin = _check_freqs(sfreq=sfreq, fmin=fmin, n_times=n_times) # compute frequencies to analyze based on number of samples, sampling rate, @@ -746,8 +745,8 @@ def spectral_connectivity_epochs( :class:`~mne.time_frequency.EpochsTFR` object. If timeseries data, the spectral information will be computed according to the spectral estimation mode (see the ``mode`` parameter). If an :class:`~mne.time_frequency.EpochsSpectrum` or - :class:`~mne.time_frequency.EpochsTFR` object, this spectral information will be - used and the ``mode`` parameter will be ignored. + :class:`~mne.time_frequency.EpochsTFR` object, existing spectral information + will be used and the ``mode`` parameter will be ignored. Note that it is also possible to combine multiple timeseries signals by providing a list of tuples, e.g.: :: @@ -1120,10 +1119,7 @@ def spectral_connectivity_epochs( weights = None metadata = None spectrum_computed = False - if isinstance( - data, - BaseEpochs | EpochsSpectrum | EpochsSpectrumArray | EpochsTFR | EpochsTFRArray, - ): + if isinstance(data, BaseEpochs | EpochsSpectrum | EpochsTFR): names = data.ch_names sfreq = data.info["sfreq"] @@ -1144,9 +1140,7 @@ def spectral_connectivity_epochs( data.add_annotations_to_metadata(overwrite=True) metadata = data.metadata - if isinstance( - data, EpochsSpectrum | EpochsSpectrumArray | EpochsTFR | EpochsTFRArray - ): + if isinstance(data, EpochsSpectrum | EpochsTFR): # XXX: Will need to be updated if new Spectrum methods are added if not np.iscomplexobj(data.get_data()): raise TypeError( diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 84d7b240..ae36bf47 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -45,7 +45,7 @@ def _check_rank_input(rank, data, indices): data_arr = _psd_from_mt(data_arr, data.weights) else: data_arr = (data_arr * data_arr.conj()).real - elif isinstance(data, EpochsTFR | EpochsTFRArray): + elif isinstance(data, EpochsTFR): # TFR objs will drop bad channels, so specify picking all channels data_arr = data.get_data(picks=np.arange(data.info["nchan"])) # Convert to power and aggregate over time before computing rank diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index badc9362..251ee11d 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -78,8 +78,8 @@ def spectral_connectivity_time( for each epoch as an :class:`~mne.time_frequency.EpochsTFR` object. If timeseries data, the spectral information will be computed according to the spectral estimation mode (see the ``mode`` parameter). If an - :class:`~mne.time_frequency.EpochsTFR` object, this spectral information will be - used and the ``mode`` parameter will be ignored. + :class:`~mne.time_frequency.EpochsTFR` object, existing spectral information + will be used and the ``mode`` parameter will be ignored. .. versionchanged:: 0.8 Fourier coefficients stored in an :class:`~mne.time_frequency.EpochsTFR` @@ -397,12 +397,12 @@ def spectral_connectivity_time( # extract data from Epochs object _validate_type( data, - (np.ndarray, BaseEpochs, EpochsTFR, EpochsTFRArray), + (np.ndarray, BaseEpochs, EpochsTFR), "`data`", "Epochs, EpochsTFR, or a NumPy array", ) spectrum_computed = False - if isinstance(data, BaseEpochs | EpochsTFR | EpochsTFRArray): + if isinstance(data, BaseEpochs | EpochsTFR): names = data.ch_names sfreq = data.info["sfreq"] events = data.events