diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 7ba0a154..3ec0bb43 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -12,6 +12,11 @@ from mne.epochs import BaseEpochs from mne.parallel import parallel_func from mne.source_estimate import _BaseSourceEstimate +from mne.time_frequency import ( + EpochsSpectrum, + EpochsSpectrumArray, + EpochsTFR, +) from mne.time_frequency.multitaper import ( _compute_mt_params, _csd_from_mt, @@ -19,12 +24,8 @@ _psd_from_mt, _psd_from_mt_adaptive, ) -from mne.time_frequency.spectrum import ( - BaseSpectrum, - EpochsSpectrum, - EpochsSpectrumArray, -) -from mne.time_frequency.tfr import cwt, morlet +from mne.time_frequency.spectrum import BaseSpectrum +from mne.time_frequency.tfr import BaseTFR, cwt, morlet from mne.utils import _arange_div, _check_option, _time_mask, logger, verbose, warn from ..base import SpectralConnectivity, SpectroTemporalConnectivity @@ -161,17 +162,18 @@ def _prepare_connectivity( """Check and precompute dimensions of results data.""" first_epoch = epoch_block[0] - # Sort times and freqs - if spectrum_computed: + # Sort times + if spectrum_computed and times_in is None: # is a Spectrum object n_signals = first_epoch[0].shape[0] times = None - n_times = None - times_in = None - n_times_in = None + n_times = 0 + n_times_in = 0 tmin_idx = None tmax_idx = None warn_times = False - else: + else: # data has a time dimension (timeseries or TFR object) + if spectrum_computed: # is a TFR object + first_epoch = (first_epoch[0][:, 0],) # just take first freq ( n_signals, times, @@ -184,6 +186,9 @@ def _prepare_connectivity( ) = _check_times( data=first_epoch, sfreq=sfreq, times=times_in, tmin=tmin, tmax=tmax ) + + # Sort freqs + if not spectrum_computed: # is an (ordinary) timeseries # check that fmin corresponds to at least 5 cycles fmin = _check_freqs(sfreq=sfreq, fmin=fmin, n_times=n_times) # compute frequencies to analyze based on number of samples, sampling rate, @@ -511,14 +516,19 @@ def _epoch_spectral_connectivity( # compute tapered spectra if spectrum_computed: # use existing spectral info - # XXX: Will need to distinguish time-resolved spectra here if support added - # Select signals & freqs of interest (flexible indexing for optional tapers dim) - x_t = np.array(data)[:, sig_idx][..., freq_mask] # split dims to avoid np.ix_ - if weights is None: # also assumes no tapers dim - x_t = np.expand_dims(x_t, axis=2) # CSD construction expects a tapers dim - weights = np.ones((1, 1, 1)) + # Select entries of interest (flexible indexing for optional tapers dim) + if tmin_idx is not None and tmax_idx is not None: + x_t = np.asarray(data)[:, sig_idx][..., freq_mask, tmin_idx:tmax_idx] + else: + x_t = np.asarray(data)[:, sig_idx][..., freq_mask] + if weights is None: # assumes no tapers dim + x_t = np.expand_dims(x_t, axis=2) # CSD construction expects tapers dim + weights = np.ones((1, 1, 1)) if accumulate_psd: - this_psd = _psd_from_mt(x_t, weights) + if weights is not None: # only None if mode == 'cwt_morlet' + this_psd = _psd_from_mt(x_t, weights) + else: + this_psd = (x_t * x_t.conj()).real else: # compute spectral info from scratch x_t, this_psd, weights = _compute_spectra( data=data, @@ -727,13 +737,14 @@ def spectral_connectivity_epochs( Parameters ---------- - data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs | ~mne.time_frequency.EpochsSpectrum + data : array-like, shape=(n_epochs, n_signals, n_times) | ~mne.Epochs | ~mne.time_frequency.EpochsSpectrum | ~mne.time_frequency.EpochsTFR The data from which to compute connectivity. Can be epoched timeseries data as an :term:`array-like` or :class:`~mne.Epochs` object, or Fourier coefficients - for each epoch as an :class:`~mne.time_frequency.EpochsSpectrum` object. If - timeseries data, the spectral information will be computed according to the - spectral estimation mode (see the ``mode`` parameter). If an - :class:`~mne.time_frequency.EpochsSpectrum` object, this spectral information + for each epoch as an :class:`~mne.time_frequency.EpochsSpectrum` or + :class:`~mne.time_frequency.EpochsTFR` object. If timeseries data, the spectral + information will be computed according to the spectral estimation mode (see the + ``mode`` parameter). If an :class:`~mne.time_frequency.EpochsSpectrum` or + :class:`~mne.time_frequency.EpochsTFR` object, existing spectral information will be used and the ``mode`` parameter will be ignored. Note that it is also possible to combine multiple timeseries signals by @@ -748,8 +759,9 @@ def spectral_connectivity_epochs( .. versionchanged:: 0.8 Fourier coefficients stored in an :class:`~mne.time_frequency.EpochsSpectrum` - or :class:`~mne.time_frequency.EpochsSpectrumArray` object can also be passed - in as data. Storing Fourier coefficients requires ``mne >= 1.8``. + or :class:`~mne.time_frequency.EpochsTFR` object can also be passed in as + data. Storing Fourier coefficients in + :class:`~mne.time_frequency.EpochsSpectrum` objects requires ``mne >= 1.8``. %(names)s method : str | list of str Connectivity measure(s) to compute. These can be ``['coh', 'cohy', @@ -789,7 +801,8 @@ def spectral_connectivity_epochs( mode : str Spectrum estimation mode can be either: 'multitaper', 'fourier', or 'cwt_morlet'. Ignored if ``data`` is an - :class:`~mne.time_frequency.EpochsSpectrum` object. + :class:`~mne.time_frequency.EpochsSpectrum` or + :class:`~mne.time_frequency.EpochsTFR` object. fmin : float | tuple of float The lower frequency of interest. Multiple bands are defined using a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. @@ -1105,7 +1118,7 @@ def spectral_connectivity_epochs( weights = None metadata = None spectrum_computed = False - if isinstance(data, BaseEpochs | EpochsSpectrum | EpochsSpectrumArray): + if isinstance(data, BaseEpochs | EpochsSpectrum | EpochsTFR): names = data.ch_names sfreq = data.info["sfreq"] @@ -1126,28 +1139,45 @@ def spectral_connectivity_epochs( data.add_annotations_to_metadata(overwrite=True) metadata = data.metadata - if isinstance(data, EpochsSpectrum | EpochsSpectrumArray): + if isinstance(data, EpochsSpectrum | EpochsTFR): # XXX: Will need to be updated if new Spectrum methods are added if not np.iscomplexobj(data.get_data()): raise TypeError( - "if `data` is an EpochsSpectrum object, it must contain " - "complex-valued Fourier coefficients, such as that returned from " - "Epochs.compute_psd(output='complex')" + "if `data` is an EpochsSpectrum or EpochsTFR object, it must " + "contain complex-valued Fourier coefficients, such as that " + "returned from Epochs.compute_psd/tfr() with `output='complex'`" ) if "segment" in data._dims: raise ValueError( "`data` cannot contain Fourier coefficients for individual segments" ) - if isinstance(data, EpochsSpectrum): # mode can be read mode from Spectrum - mode = data.method - mode = "fourier" if mode == "welch" else mode - else: # spectral method is "unknown", so take mode from data dimensions - # Currently, actual mode doesn't matter as long as we handle tapers and - # their weights in the same way as for multitaper spectra - mode = "multitaper" if "taper" in data._dims else "fourier" + mode = data.method + if isinstance(data, EpochsSpectrum | EpochsSpectrumArray): + if isinstance(data, EpochsSpectrum): # read mode from object + mode = "fourier" if mode == "welch" else mode + else: # infer mode from dimensions + # Currently, actual mode doesn't matter as long as we handle tapers + # and their weights in the same way as for multitaper spectra + mode = "multitaper" if "taper" in data._dims else "fourier" + weights = data.weights + else: + if isinstance(data, EpochsTFR): # read mode from object + if mode != "morlet": # FIXME: Add support for other TFR methods + raise ValueError( + "if `data` is an EpochsTFR object, the spectral method " + "must be 'morlet'" + ) + else: + if "taper" in data._dims: # FIXME: Add support for multitaper TFR + raise ValueError( + "if `data` is an EpochsTFRArray object, it cannot contain " + "Fourier coefficients for individual tapers" + ) + mode = "cwt_morlet" # currently only supported mode here + times_in = data.times + weights = None # no weights stored in TFR objects spectrum_computed = True freqs = data.freqs - weights = data.weights else: times_in = data.times # input times for Epochs input type elif sfreq is None: @@ -1235,7 +1265,7 @@ def spectral_connectivity_epochs( spectral_params = dict( eigvals=None, window_fun=None, wavelets=None, weights=weights ) - n_times_spectrum = 0 + n_times_spectrum = n_times # 0 if no times n_tapers = None if weights is None else weights.size # unique signals for which we actually need to compute PSD etc. @@ -1289,7 +1319,7 @@ def spectral_connectivity_epochs( logger.info(f" the following metrics will be computed: {metrics_str}") # check dimensions and time scale - if not spectrum_computed: # XXX: Can we assume upstream checks sufficient? + if not spectrum_computed: for this_epoch in epoch_block: _, _, _, warn_times = _get_and_verify_data_sizes( this_epoch, @@ -1469,7 +1499,9 @@ def spectral_connectivity_epochs( freqs=freqs, method=_method, n_nodes=n_nodes, - spec_method=mode if not isinstance(data, BaseSpectrum) else data.method, + spec_method=( + mode if not isinstance(data, BaseSpectrum | BaseTFR) else data.method + ), indices=indices, n_epochs_used=n_epochs, freqs_used=freqs_used, diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 44e91213..1531d4cb 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -15,7 +15,11 @@ import numpy as np from mne.epochs import BaseEpochs from mne.parallel import parallel_func -from mne.time_frequency import EpochsSpectrum, EpochsSpectrumArray +from mne.time_frequency import ( + EpochsSpectrum, + EpochsSpectrumArray, + EpochsTFR, +) from mne.time_frequency.multitaper import _psd_from_mt from mne.utils import ProgressBar, _validate_type, logger @@ -40,6 +44,12 @@ def _check_rank_input(rank, data, indices): data_arr = _psd_from_mt(data_arr, data.weights) else: data_arr = (data_arr * data_arr.conj()).real + elif isinstance(data, EpochsTFR): + # TFR objs will drop bad channels, so specify picking all channels + data_arr = data.get_data(picks=np.arange(data.info["nchan"])) + # Convert to power and aggregate over time before computing rank + # XXX: need to change when other types of TFR are supported + data_arr = np.sum((data_arr * data_arr.conj()).real, axis=-1) else: data_arr = data diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index d13cbea4..93da85d0 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -474,9 +474,9 @@ def test_spectral_connectivity(method, mode): not check_version("mne", "1.8"), reason="Requires MNE v1.8.0 or higher" ) @pytest.mark.parametrize("method", ["coh", "cacoh"]) -@pytest.mark.parametrize("mode", ["multitaper", "fourier"]) +@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) def test_spectral_connectivity_epochs_spectrum_input(method, mode): - """Test spec_conn_epochs works with EpochsSpectrum data as input. + """Test spec_conn_epochs works with EpochsSpectrum/TFR data as input. Important to test both bivariate and multivariate methods, as the latter involves additional steps (e.g., rank computation). @@ -489,7 +489,7 @@ def test_spectral_connectivity_epochs_spectrum_input(method, mode): n_epochs = 30 n_times = 200 # samples trans_bandwidth = 1.0 # Hz - delay = 10 # samples + delay = 5 # samples data = make_signals_in_freq_bands( n_seeds=n_seeds, @@ -499,7 +499,7 @@ def test_spectral_connectivity_epochs_spectrum_input(method, mode): n_times=n_times, sfreq=sfreq, trans_bandwidth=trans_bandwidth, - snr=0.5, + snr=0.7, connection_delay=delay, rng_seed=44, ) @@ -512,26 +512,35 @@ def test_spectral_connectivity_epochs_spectrum_input(method, mode): indices = ([np.arange(n_seeds)], [np.arange(n_targets) + n_seeds]) # Compute Fourier coefficients + cwt_freqs = np.arange(10, 50) # similar to Fourier & multitaper modes kwargs = dict() if mode == "fourier": kwargs.update(window="hann") # default is Hamming, but we need Hanning - coeffs = data.compute_psd( - method="welch" if mode == "fourier" else mode, output="complex", **kwargs - ) + spec_mode = "welch" + elif mode == "cwt_morlet": + kwargs.update(freqs=cwt_freqs) + spec_mode = "morlet" + else: + spec_mode = mode + compute_method = data.compute_tfr if mode == "cwt_morlet" else data.compute_psd + coeffs = compute_method(method=spec_mode, output="complex", **kwargs) # Compute connectivity con = spectral_connectivity_epochs(data=coeffs, method=method, indices=indices) - # Check connectivity from Epochs and Spectrum are equivalent; - # Works for multitaper, but Welch of Spectrum and Fourier of spec_conn are slightly - # off (max. abs. diff. ~0.006) even when what should be identical settings are used + # Check connectivity from Epochs and Spectrum/TFR are equivalent con_from_epochs = spectral_connectivity_epochs( - data=data, method=method, indices=indices, mode=mode + data=data, method=method, indices=indices, mode=mode, cwt_freqs=cwt_freqs ) - if mode == "multitaper": - atol = 0 - else: + # Works for multitaper & Morlet, but Welch of Spectrum and Fourier of spec_conn are + # slightly off (max. abs. diff. ~0.006). This is due to the Spectrum object using + # scipy.signal.spectrogram to compute the coefficients, while spec_conn uses + # scipy.signal.rfft, which give slightly different outputs even with identical + # settings. + if mode == "fourier": atol = 7e-3 + else: + atol = 0 # spec_conn_epochs excludes freqs without at least 5 cycles, but not Spectrum fstart = con.freqs.index(con_from_epochs.freqs[0]) assert_allclose( @@ -546,25 +555,20 @@ def test_spectral_connectivity_epochs_spectrum_input(method, mode): freqs_noise = (freqs < fband[0] - trans_bandwidth * 2) | ( freqs > fband[1] + trans_bandwidth * 2 ) - - # nothing for CaCoh to optimise, so use same thresholds for CaCoh and Coh - if mode == "multitaper": # lower baseline for multitaper - con_thresh = (0.1, 0.3) - else: # higher baseline for Welch/Fourier - con_thresh = (0.2, 0.4) - + WEAK_CONN_OR_NOISE = 0.3 # conn values outside of simulated fband should be < this + STRONG_CONN = 0.6 # conn values inside simulated fband should be > this # check freqs of simulated interaction show strong connectivity - assert_array_less(con_thresh[1], np.abs(con.get_data()[:, freqs_con].mean())) + assert_array_less(STRONG_CONN, np.abs(con.get_data()[:, freqs_con].mean())) # check freqs of no simulated interaction (just noise) show weak connectivity - assert_array_less(np.abs(con.get_data()[:, freqs_noise].mean()), con_thresh[0]) + assert_array_less(np.abs(con.get_data()[:, freqs_noise].mean()), WEAK_CONN_OR_NOISE) # TODO: Add general test for error catching for spec_conn_epochs @pytest.mark.skipif( not check_version("mne", "1.8"), reason="Requires MNE v1.8.0 or higher" ) -def test_spectral_connectivity_epochs_spectrum_input_error_catch(): - """Test spec_conn_epochs catches error with EpochsSpectrum data as input.""" +def test_spectral_connectivity_epochs_spectrum_tfr_input_error_catch(): + """Test spec_conn_epochs catches errors with EpochsSpectrum/TFR data as input.""" # Generate data rng = np.random.default_rng(44) n_epochs, n_chans, n_times = (5, 2, 50) @@ -1684,6 +1688,89 @@ def test_multivar_spectral_connectivity_time_shapes( assert np.all(np.array(con.indices) == np.array(([[0, 1, 2]], [[3, 4, -1]]))) +@pytest.mark.parametrize("method", ["coh", "cacoh"]) +@pytest.mark.parametrize("mode", ["multitaper", "cwt_morlet"]) +def test_spectral_connectivity_time_tfr_input(method, mode): + """Test spec_conn_time works with EpochsTFR data as input. + + Important to test both bivariate and multivariate methods, as the latter involves + additional steps (e.g., rank computation). + """ + # Simulation parameters & data generation + n_seeds = 2 + n_targets = 2 + fband = (15, 20) # Hz + trans_bandwidth = 1.0 # Hz + + data = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=fband, + n_epochs=30, + n_times=200, + sfreq=100, + trans_bandwidth=trans_bandwidth, + snr=0.7, + connection_delay=5, + rng_seed=44, + ) + + if method == "coh": + indices = seed_target_indices( + seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds + ) + else: + indices = ([np.arange(n_seeds)], [np.arange(n_targets) + n_seeds]) + + # Compute TFR + freqs = np.arange(10, 50) + n_cycles = 5.0 # non-default value to avoid warning in spec_conn_time + mt_bandwidth = 4.0 + kwargs = dict() + if mode == "cwt_morlet": + kwargs.update(zero_mean=False) # default in spec_conn_time + spec_mode = "morlet" + else: + kwargs.update(time_bandwidth=mt_bandwidth) + spec_mode = mode + coeffs = data.compute_tfr( + method=spec_mode, freqs=freqs, n_cycles=n_cycles, output="complex", **kwargs + ) + + # Compute connectivity + con_kwargs = dict( + method=method, + indices=indices, + mode=mode, + freqs=freqs, + n_cycles=n_cycles, + mt_bandwidth=mt_bandwidth, + average=True, + ) + con = spectral_connectivity_time(data=coeffs, **con_kwargs) + + # Check connectivity from Epochs and EpochsTFR are equivalent (small but non-zero + # tolerance given due to some platform-dependent variation) + con_from_epochs = spectral_connectivity_time(data=data, **con_kwargs) + assert_allclose( + np.abs(con.get_data()), np.abs(con_from_epochs.get_data()), atol=1e-7 + ) + + # Check connectivity values are as expected + freqs_con = (freqs >= fband[0]) & (freqs <= fband[1]) + freqs_noise = (freqs < fband[0] - trans_bandwidth * 2) | ( + freqs > fband[1] + trans_bandwidth * 2 + ) + # check freqs of simulated interaction show strong connectivity + assert_array_less(0.6, np.abs(con.get_data()[:, freqs_con].mean())) + # check freqs of no simulated interaction (just noise) show weak connectivity + assert_array_less(np.abs(con.get_data()[:, freqs_noise].mean()), 0.3) + + +test_spectral_connectivity_time_tfr_input("cacoh", "cwt_morlet") + + +# TODO: Add general test for error catching for spec_conn_time @pytest.mark.parametrize("method", ["cacoh", "mic", "mim", _gc, _gc_tr]) @pytest.mark.parametrize("mode", ["multitaper", "cwt_morlet"]) def test_multivar_spectral_connectivity_time_error_catch(method, mode): @@ -1705,7 +1792,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): freqs = np.arange(10, 25 + 1) # test type-checking of data - with pytest.raises(TypeError, match="must be an instance of Epochs or a NumPy arr"): + with pytest.raises(TypeError, match="Epochs, EpochsTFR, or a NumPy arr"): spectral_connectivity_time(data="foo", freqs=freqs) # check bad indices without nested array caught @@ -1822,6 +1909,30 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): ) +def test_spectral_connectivity_time_tfr_input_error_catch(): + """Test spec_conn_time catches errors with EpochsTFR data as input.""" + # Generate data + rng = np.random.default_rng(44) + n_epochs, n_chans, n_times = (5, 2, 100) + sfreq = 50 + data = rng.random((n_epochs, n_chans, n_times)) + info = create_info(ch_names=n_chans, sfreq=sfreq, ch_types="eeg") + data = EpochsArray(data=data, info=info) + freqs = np.arange(10, 20) + + # Test not Fourier coefficients caught + with pytest.raises(TypeError, match="must contain complex-valued Fourier coeff"): + tfr = data.compute_tfr(method="morlet", freqs=freqs, output="power") + spectral_connectivity_time(data=tfr, freqs=freqs) + + # Catch default value warning (multitaper only) + tfr = data.compute_tfr(method="multitaper", freqs=freqs, output="complex") + with pytest.warns(RuntimeWarning, match="`mt_bandwidth` is not specified"): + spectral_connectivity_time(data=tfr, freqs=freqs, n_cycles=5.0) + with pytest.warns(RuntimeWarning, match="`n_cycles` is the default value"): + spectral_connectivity_time(data=tfr, freqs=freqs, mt_bandwidth=4.0) + + def test_save(tmp_path): """Test saving results of spectral connectivity.""" epochs = make_signals_in_freq_bands( diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index a8bcdfb6..c14923dc 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -10,8 +10,13 @@ import xarray as xr from mne.epochs import BaseEpochs from mne.parallel import parallel_func -from mne.time_frequency import dpss_windows, tfr_array_morlet, tfr_array_multitaper -from mne.utils import _validate_type, logger, verbose +from mne.time_frequency import ( + EpochsTFR, + dpss_windows, + tfr_array_morlet, + tfr_array_multitaper, +) +from mne.utils import _check_option, _validate_type, logger, verbose, warn from ..base import EpochSpectralConnectivity, SpectralConnectivity from ..utils import _check_multivariate_indices, check_indices, fill_doc @@ -66,12 +71,23 @@ def spectral_connectivity_time( Parameters ---------- - data : array_like, shape (n_epochs, n_signals, n_times) | Epochs - The data from which to compute connectivity. + data : array_like, shape (n_epochs, n_signals, n_times) | ~mne.Epochs | ~mne.time_frequency.EpochsTFR + The data from which to compute connectivity. Can be epoched timeseries data as + an :term:`array-like` or :class:`~mne.Epochs` object, or Fourier coefficients + for each epoch as an :class:`~mne.time_frequency.EpochsTFR` object. If + timeseries data, the spectral information will be computed according to the + spectral estimation mode (see the ``mode`` parameter). If an + :class:`~mne.time_frequency.EpochsTFR` object, existing spectral information + will be used and the ``mode`` parameter will be ignored. + + .. versionchanged:: 0.8 + Fourier coefficients stored in an :class:`~mne.time_frequency.EpochsTFR` + object can also be passed in as data. freqs : array_like - Array of frequencies of interest for time-frequency decomposition. - Only the frequencies within the range specified by ``fmin`` and - ``fmax`` are used. + Array of frequencies of interest for time-frequency decomposition. Only the + frequencies within the range specified by ``fmin`` and ``fmax`` are used. If + ``data`` is an :class:`~mne.time_frequency.EpochsTFR` object, ``data.freqs`` is + used and this parameter is ignored. method : str | list of str Connectivity measure(s) to compute. These can be ``['coh', 'cacoh', 'mic', 'mim', 'plv', 'ciplv', 'pli', 'wpli', 'gc', 'gc_tr']``. These @@ -104,8 +120,8 @@ def spectral_connectivity_time( connections between all channels are computed, unless a Granger causality method is called, in which case an error is raised. sfreq : float - The sampling frequency. Required if data is not - :class:`Epochs `. + The sampling frequency. Required if ``data`` is not an :class:`~mne.Epochs` or + :class:`~mne.time_frequency.EpochsTFR` object. fmin : float | tuple of float | None The lower frequency of interest. Multiple bands are defined using a tuple, e.g., ``(8., 20.)`` for two bands with 8 Hz and 20 Hz lower @@ -133,21 +149,26 @@ def spectral_connectivity_time( Amount of time to consider as padding at the beginning and end of each epoch in seconds. See Notes for more information. mode : str - Time-frequency decomposition method. Can be either: 'multitaper', or - 'cwt_morlet'. See :func:`mne.time_frequency.tfr_array_multitaper` and - :func:`mne.time_frequency.tfr_array_morlet` for reference. + Time-frequency decomposition method. Can be either: ``'multitaper'``, or + ``'cwt_morlet'``. See :func:`mne.time_frequency.tfr_array_multitaper` and + :func:`mne.time_frequency.tfr_array_morlet` for reference. Ignored if ``data`` + is an :class:`~mne.time_frequency.EpochsTFR` object. mt_bandwidth : float | None - Product between the temporal window length (in seconds) and the full - frequency bandwidth (in Hz). This product can be seen as the surface - of the window on the time/frequency plane and controls the frequency - bandwidth (thus the frequency resolution) and the number of good - tapers. See :func:`mne.time_frequency.tfr_array_multitaper` - documentation. + Product between the temporal window length (in seconds) and the full frequency + bandwidth (in Hz). This product can be seen as the surface of the window on the + time/frequency plane and controls the frequency bandwidth (thus the frequency + resolution) and the number of good tapers. See + :func:`mne.time_frequency.tfr_array_multitaper` documentation. If ``data`` is an + :class:`~mne.time_frequency.EpochsTFR` object computed with the ``multitaper`` + method, this should match the value used to compute the TFR. n_cycles : float | array_like of float - Number of cycles in the wavelet, either a fixed number or one per - frequency. The number of cycles ``n_cycles`` and the frequencies of - interest ``cwt_freqs`` define the temporal window length. For details, - see :func:`mne.time_frequency.tfr_array_morlet` documentation. + Number of cycles in the wavelet, either a fixed number or one per frequency. The + number of cycles ``n_cycles`` and the frequencies of interest ``freqs`` define + the temporal window length. For details, see + :func:`mne.time_frequency.tfr_array_multitaper` and + :func:`mne.time_frequency.tfr_array_morlet` documentation. If ``data`` is an + :class:`~mne.time_frequency.EpochsTFR` object computed with the ``multitaper`` + method, this should match the value used to compute the TFR. gc_n_lags : int Number of lags to use for the vector autoregressive model when computing Granger causality. Higher values increase computational cost, @@ -369,12 +390,18 @@ def spectral_connectivity_time( References ---------- .. footbibliography:: - """ + """ # noqa: E501 events = None event_id = None # extract data from Epochs object - _validate_type(data, (np.ndarray, BaseEpochs), "`data`", "Epochs or a NumPy array") - if isinstance(data, BaseEpochs): + _validate_type( + data, + (np.ndarray, BaseEpochs, EpochsTFR), + "`data`", + "Epochs, EpochsTFR, or a NumPy array", + ) + spectrum_computed = False + if isinstance(data, BaseEpochs | EpochsTFR): names = data.ch_names sfreq = data.info["sfreq"] events = data.events @@ -392,12 +419,40 @@ def spectral_connectivity_time( if hasattr(data, "annotations") and not annots_in_metadata: data.add_annotations_to_metadata(overwrite=True) metadata = data.metadata - # XXX: remove logic once support for mne<1.6 is dropped - kwargs = dict() - if "copy" in inspect.getfullargspec(data.get_data).kwonlyargs: - kwargs["copy"] = False - data = data.get_data(**kwargs) - n_epochs, n_signals, n_times = data.shape + if isinstance(data, BaseEpochs): + # XXX: remove logic once support for mne<1.6 is dropped + kwargs = dict() + if "copy" in inspect.getfullargspec(data.get_data).kwonlyargs: + kwargs["copy"] = False + data = data.get_data(**kwargs) + n_epochs, n_signals, n_times = data.shape + else: + freqs = data.freqs # use freqs from EpochsTFR object + # Read mode from the EpochsTFR object + mode = "cwt_morlet" if data.method == "morlet" else data.method + # TFR objs will drop bad channels, so specify picking all channels + data = data.get_data(picks=np.arange(data.info["nchan"])) + if not np.iscomplexobj(data): + raise TypeError( + "if `data` is an EpochsTFR object, it must contain complex-valued " + "Fourier coefficients, such as that returned from " + "Epochs.compute_tfr() with `output='complex'`" + ) + n_epochs, n_signals = data.shape[:2] + n_times = data.shape[-1] + # Warn if mt_bandwidth & n_cycles default-valued (can't be read from TFR) + if mode == "multitaper": # doesn't matter for cwt_morlet + if mt_bandwidth is None: + warn( + "`mt_bandwidth` is not specified; assuming 4.0 Hz was used to " + "compute the TFR" + ) + if n_cycles == 7.0: + warn( + "`n_cycles` is the default value; assuming 7.0 was used to " + "compute the TFR" + ) + spectrum_computed = True else: data = np.asarray(data) n_epochs, n_signals, n_times = data.shape @@ -520,8 +575,7 @@ def spectral_connectivity_time( n_components = _check_n_components_input(n_components, rank) if n_components == 1: # n_components=0 means space for a components dimension is not allocated in - # the results, similar to how n_times_spectrum=0 is used to indicate that - # time is not a dimension in the results + # the results n_components = 0 else: rank = None @@ -620,9 +674,10 @@ def spectral_connectivity_time( padding=padding, kw_cwt={}, kw_mt={}, + multivariate_con=multivariate_con, + spectrum_computed=spectrum_computed, n_jobs=n_jobs, verbose=verbose, - multivariate_con=multivariate_con, ) for epoch_idx in np.arange(n_epochs): @@ -713,16 +768,17 @@ def _spectral_connectivity( padding, kw_cwt, kw_mt, + multivariate_con, + spectrum_computed, n_jobs, verbose, - multivariate_con, ): """Estimate time-resolved connectivity for one epoch. Parameters ---------- - data : array_like, shape (n_channels, n_times) - Time-series data. + data : array_like, shape (channels, [freqs,] [tapers,] times) + Time-series data or time-frequency data. method : list of str List of connectivity metrics to compute. kernel : array_like, shape (n_sm_fres, n_sm_times) @@ -766,6 +822,8 @@ def _spectral_connectivity( epoch in seconds. multivariate_con : bool Whether or not multivariate connectivity is to be computed. + spectrum_computed : bool + Whether or not the time-frequency decomposition has already been computed. Returns ------- @@ -784,51 +842,62 @@ def _spectral_connectivity( and target signals (respectively). ``n_comps`` is present for valid multivariate methods if ``n_components > 0``. """ - n_cons = len(source_idx) - data = np.expand_dims(data, axis=0) - kw_cwt.setdefault("zero_mean", False) # avoid FutureWarning + # check that spectral mode is recognised + _check_option("mode", mode, ("cwt_morlet", "multitaper")) + + # compute time-frequency decomposition + mt_bandwidth = mt_bandwidth if mt_bandwidth else 4 + if not spectrum_computed: + data = np.expand_dims(data, axis=0) + kw_cwt.setdefault("zero_mean", False) # avoid FutureWarning + if mode == "cwt_morlet": + out = tfr_array_morlet( + data, + sfreq, + freqs, + n_cycles=n_cycles, + output="complex", + decim=decim, + n_jobs=n_jobs, + **kw_cwt, + ) + else: + out = tfr_array_multitaper( + data, + sfreq, + freqs, + n_cycles=n_cycles, + time_bandwidth=mt_bandwidth, + output="complex", + decim=decim, + n_jobs=n_jobs, + **kw_mt, + ) + out = np.squeeze(out, axis=0) + else: + out = data + + # give tapers dim to cwt_morlet output if mode == "cwt_morlet": - out = tfr_array_morlet( - data, - sfreq, - freqs, - n_cycles=n_cycles, - output="complex", - decim=decim, - n_jobs=n_jobs, - **kw_cwt, - ) - out = np.expand_dims(out, axis=2) # same dims with multitaper - weights = None - elif mode == "multitaper": - out = tfr_array_multitaper( - data, - sfreq, - freqs, - n_cycles=n_cycles, - time_bandwidth=mt_bandwidth, - output="complex", - decim=decim, - n_jobs=n_jobs, - **kw_mt, - ) + out = np.expand_dims(out, axis=1) + + # compute taper weights + if mode == "multitaper": if isinstance(n_cycles, int | float): n_cycles = [n_cycles] * len(freqs) - mt_bandwidth = mt_bandwidth if mt_bandwidth else 4 - n_tapers = int(np.floor(mt_bandwidth - 1)) - weights = np.zeros((n_tapers, len(freqs), out.shape[-1])) + n_tapers = out.shape[-3] + n_times = out.shape[-1] + half_nbw = mt_bandwidth / 2.0 + weights = np.zeros((n_tapers, len(freqs), n_times)) for i, (f, n_c) in enumerate(zip(freqs, n_cycles)): window_length = np.arange(0.0, n_c / float(f), 1.0 / sfreq).shape[0] - half_nbw = mt_bandwidth / 2.0 - n_tapers = int(np.floor(mt_bandwidth - 1)) _, eigvals = dpss_windows(window_length, half_nbw, n_tapers, sym=False) weights[:, i, :] = np.sqrt(eigvals[:, np.newaxis]) # weights have shape (n_tapers, n_freqs, n_times) else: - raise ValueError("Mode must be 'cwt_morlet' or 'multitaper'.") - - out = np.squeeze(out, axis=0) + weights = None + # pad spectrum and weights if padding: if padding < 0: raise ValueError(f"Padding cannot be negative, got {padding}.") @@ -843,6 +912,7 @@ def _spectral_connectivity( # compute for each connectivity method scores = {} patterns = {} + n_cons = len(source_idx) conn = _parallel_con( out, method,