From a9dc973581c4a9faf94de82f908ebcf0de8ac776 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Fri, 6 Sep 2024 18:18:16 +0200 Subject: [PATCH 01/14] Add TFR support spec_conn_epochs --- mne_connectivity/spectral/epochs.py | 122 ++++++++++++------ .../spectral/epochs_multivariate.py | 12 +- .../spectral/tests/test_spectral.py | 50 +++---- 3 files changed, 116 insertions(+), 68 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 7ba0a154..4a5502fe 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -12,6 +12,12 @@ from mne.epochs import BaseEpochs from mne.parallel import parallel_func from mne.source_estimate import _BaseSourceEstimate +from mne.time_frequency import ( + EpochsSpectrum, + EpochsSpectrumArray, + EpochsTFR, + EpochsTFRArray, +) from mne.time_frequency.multitaper import ( _compute_mt_params, _csd_from_mt, @@ -19,12 +25,8 @@ _psd_from_mt, _psd_from_mt_adaptive, ) -from mne.time_frequency.spectrum import ( - BaseSpectrum, - EpochsSpectrum, - EpochsSpectrumArray, -) -from mne.time_frequency.tfr import cwt, morlet +from mne.time_frequency.spectrum import BaseSpectrum +from mne.time_frequency.tfr import BaseTFR, cwt, morlet from mne.utils import _arange_div, _check_option, _time_mask, logger, verbose, warn from ..base import SpectralConnectivity, SpectroTemporalConnectivity @@ -161,17 +163,19 @@ def _prepare_connectivity( """Check and precompute dimensions of results data.""" first_epoch = epoch_block[0] - # Sort times and freqs - if spectrum_computed: + # Sort times + if spectrum_computed and times_in is None: # if Spectrum object passed in as data n_signals = first_epoch[0].shape[0] times = None - n_times = None + n_times = 0 times_in = None - n_times_in = None + n_times_in = 0 tmin_idx = None tmax_idx = None warn_times = False - else: + else: # if data has a time dimension (timeseries or TFR object) + if spectrum_computed: # if TFR object passed in as data + first_epoch = (first_epoch[0][:, 0],) # just take first freq ( n_signals, times, @@ -184,6 +188,9 @@ def _prepare_connectivity( ) = _check_times( data=first_epoch, sfreq=sfreq, times=times_in, tmin=tmin, tmax=tmax ) + + # Sort freqs + if not spectrum_computed: # if timeseries passed in as data # check that fmin corresponds to at least 5 cycles fmin = _check_freqs(sfreq=sfreq, fmin=fmin, n_times=n_times) # compute frequencies to analyze based on number of samples, sampling rate, @@ -511,14 +518,19 @@ def _epoch_spectral_connectivity( # compute tapered spectra if spectrum_computed: # use existing spectral info - # XXX: Will need to distinguish time-resolved spectra here if support added - # Select signals & freqs of interest (flexible indexing for optional tapers dim) - x_t = np.array(data)[:, sig_idx][..., freq_mask] # split dims to avoid np.ix_ - if weights is None: # also assumes no tapers dim - x_t = np.expand_dims(x_t, axis=2) # CSD construction expects a tapers dim - weights = np.ones((1, 1, 1)) + # Select entries of interest (flexible indexing for optional tapers dim) + if tmin_idx is not None and tmax_idx is not None: + x_t = np.asarray(data)[:, sig_idx][..., freq_mask, tmin_idx:tmax_idx] + else: + x_t = np.asarray(data)[:, sig_idx][..., freq_mask] + if weights is None: # assumes no tapers dim + x_t = np.expand_dims(x_t, axis=2) # CSD construction expects tapers dim + weights = np.ones((1, 1, 1)) if accumulate_psd: - this_psd = _psd_from_mt(x_t, weights) + if weights is not None: # only None if mode == 'cwt_morlet' + this_psd = _psd_from_mt(x_t, weights) + else: + this_psd = (x_t * x_t.conj()).real else: # compute spectral info from scratch x_t, this_psd, weights = _compute_spectra( data=data, @@ -727,14 +739,15 @@ def spectral_connectivity_epochs( Parameters ---------- - data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs | ~mne.time_frequency.EpochsSpectrum + data : array-like, shape=(n_epochs, n_signals, n_times) | ~mne.Epochs | ~mne.time_frequency.EpochsSpectrum | ~mne.time_frequency.EpochsTFR The data from which to compute connectivity. Can be epoched timeseries data as an :term:`array-like` or :class:`~mne.Epochs` object, or Fourier coefficients - for each epoch as an :class:`~mne.time_frequency.EpochsSpectrum` object. If - timeseries data, the spectral information will be computed according to the - spectral estimation mode (see the ``mode`` parameter). If an - :class:`~mne.time_frequency.EpochsSpectrum` object, this spectral information - will be used and the ``mode`` parameter will be ignored. + for each epoch as an :class:`~mne.time_frequency.EpochsSpectrum` or + :class:`~mne.time_frequency.EpochsTFR` object. If timeseries data, the spectral + information will be computed according to the spectral estimation mode (see the + ``mode`` parameter). If an :class:`~mne.time_frequency.EpochsSpectrum` or + :class:`~mne.time_frequency.EpochsTFR` object, this spectral information will be + used and the ``mode`` parameter will be ignored. Note that it is also possible to combine multiple timeseries signals by providing a list of tuples, e.g.: :: @@ -748,8 +761,9 @@ def spectral_connectivity_epochs( .. versionchanged:: 0.8 Fourier coefficients stored in an :class:`~mne.time_frequency.EpochsSpectrum` - or :class:`~mne.time_frequency.EpochsSpectrumArray` object can also be passed - in as data. Storing Fourier coefficients requires ``mne >= 1.8``. + or :class:`~mne.time_frequency.EpochsTFR` object can also be passed in as + data. Storing Fourier coefficients in + :class:`~mne.time_frequency.EpochsSpectrum` objects requires ``mne >= 1.8``. %(names)s method : str | list of str Connectivity measure(s) to compute. These can be ``['coh', 'cohy', @@ -1105,7 +1119,10 @@ def spectral_connectivity_epochs( weights = None metadata = None spectrum_computed = False - if isinstance(data, BaseEpochs | EpochsSpectrum | EpochsSpectrumArray): + if isinstance( + data, + BaseEpochs | EpochsSpectrum | EpochsSpectrumArray | EpochsTFR | EpochsTFRArray, + ): names = data.ch_names sfreq = data.info["sfreq"] @@ -1126,28 +1143,47 @@ def spectral_connectivity_epochs( data.add_annotations_to_metadata(overwrite=True) metadata = data.metadata - if isinstance(data, EpochsSpectrum | EpochsSpectrumArray): + if isinstance( + data, EpochsSpectrum | EpochsSpectrumArray | EpochsTFR | EpochsTFRArray + ): # XXX: Will need to be updated if new Spectrum methods are added if not np.iscomplexobj(data.get_data()): raise TypeError( - "if `data` is an EpochsSpectrum object, it must contain " - "complex-valued Fourier coefficients, such as that returned from " - "Epochs.compute_psd(output='complex')" + "if `data` is an EpochsSpectrum or EpochsTFR object, it must " + "contain complex-valued Fourier coefficients, such as that " + "returned from Epochs.compute_psd/tfr() with `output='complex'`" ) if "segment" in data._dims: raise ValueError( "`data` cannot contain Fourier coefficients for individual segments" ) - if isinstance(data, EpochsSpectrum): # mode can be read mode from Spectrum - mode = data.method - mode = "fourier" if mode == "welch" else mode - else: # spectral method is "unknown", so take mode from data dimensions - # Currently, actual mode doesn't matter as long as we handle tapers and - # their weights in the same way as for multitaper spectra - mode = "multitaper" if "taper" in data._dims else "fourier" + mode = data.method + if isinstance(data, EpochsSpectrum | EpochsSpectrumArray): + if isinstance(data, EpochsSpectrum): # read mode from object + mode = "fourier" if mode == "welch" else mode + else: # infer mode from dimensions + # Currently, actual mode doesn't matter as long as we handle tapers + # and their weights in the same way as for multitaper spectra + mode = "multitaper" if "taper" in data._dims else "fourier" + weights = data.weights + else: + if isinstance(data, EpochsTFR): # read mode from object + if mode != "morlet": # FIXME: Add support for other TFR methods + raise ValueError( + "if `data` is an EpochsTFR object, the spectral method " + "must be 'cwt_morlet'" + ) + else: + if "taper" in data._dims: # FIXME: Add support for multitaper TFR + raise ValueError( + "if `data` is an EpochsTFRArray object, it cannot contain " + "Fourier coefficients for individual tapers" + ) + mode = "cwt_morlet" # currently only supported mode here + times_in = data.times + weights = None # no weights stored in TFR objects spectrum_computed = True freqs = data.freqs - weights = data.weights else: times_in = data.times # input times for Epochs input type elif sfreq is None: @@ -1235,7 +1271,7 @@ def spectral_connectivity_epochs( spectral_params = dict( eigvals=None, window_fun=None, wavelets=None, weights=weights ) - n_times_spectrum = 0 + n_times_spectrum = n_times # 0 if no times n_tapers = None if weights is None else weights.size # unique signals for which we actually need to compute PSD etc. @@ -1289,7 +1325,7 @@ def spectral_connectivity_epochs( logger.info(f" the following metrics will be computed: {metrics_str}") # check dimensions and time scale - if not spectrum_computed: # XXX: Can we assume upstream checks sufficient? + if not spectrum_computed: for this_epoch in epoch_block: _, _, _, warn_times = _get_and_verify_data_sizes( this_epoch, @@ -1469,7 +1505,9 @@ def spectral_connectivity_epochs( freqs=freqs, method=_method, n_nodes=n_nodes, - spec_method=mode if not isinstance(data, BaseSpectrum) else data.method, + spec_method=( + mode if not isinstance(data, BaseSpectrum | BaseTFR) else data.method + ), indices=indices, n_epochs_used=n_epochs, freqs_used=freqs_used, diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 44e91213..0ce49856 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -15,7 +15,12 @@ import numpy as np from mne.epochs import BaseEpochs from mne.parallel import parallel_func -from mne.time_frequency import EpochsSpectrum, EpochsSpectrumArray +from mne.time_frequency import ( + EpochsSpectrum, + EpochsSpectrumArray, + EpochsTFR, + EpochsTFRArray, +) from mne.time_frequency.multitaper import _psd_from_mt from mne.utils import ProgressBar, _validate_type, logger @@ -40,6 +45,11 @@ def _check_rank_input(rank, data, indices): data_arr = _psd_from_mt(data_arr, data.weights) else: data_arr = (data_arr * data_arr.conj()).real + elif isinstance(data, EpochsTFR | EpochsTFRArray): + # XXX: need to change when other types of TFR are supported + data_arr = data.get_data(picks=np.arange(data.info["nchan"])) + # Convert to power and aggregate over time before computing rank + data_arr = np.sum((data_arr * data_arr.conj()).real, axis=-1) else: data_arr = data diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index d13cbea4..a89759a3 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -474,9 +474,9 @@ def test_spectral_connectivity(method, mode): not check_version("mne", "1.8"), reason="Requires MNE v1.8.0 or higher" ) @pytest.mark.parametrize("method", ["coh", "cacoh"]) -@pytest.mark.parametrize("mode", ["multitaper", "fourier"]) +@pytest.mark.parametrize("mode", ["multitaper", "fourier", "cwt_morlet"]) def test_spectral_connectivity_epochs_spectrum_input(method, mode): - """Test spec_conn_epochs works with EpochsSpectrum data as input. + """Test spec_conn_epochs works with EpochsSpectrum/TFR data as input. Important to test both bivariate and multivariate methods, as the latter involves additional steps (e.g., rank computation). @@ -489,7 +489,7 @@ def test_spectral_connectivity_epochs_spectrum_input(method, mode): n_epochs = 30 n_times = 200 # samples trans_bandwidth = 1.0 # Hz - delay = 10 # samples + delay = 5 # samples data = make_signals_in_freq_bands( n_seeds=n_seeds, @@ -499,7 +499,7 @@ def test_spectral_connectivity_epochs_spectrum_input(method, mode): n_times=n_times, sfreq=sfreq, trans_bandwidth=trans_bandwidth, - snr=0.5, + snr=0.7, connection_delay=delay, rng_seed=44, ) @@ -512,26 +512,33 @@ def test_spectral_connectivity_epochs_spectrum_input(method, mode): indices = ([np.arange(n_seeds)], [np.arange(n_targets) + n_seeds]) # Compute Fourier coefficients + cwt_freqs = np.arange(10, 50) # similar to Fourier & multitaper modes kwargs = dict() if mode == "fourier": kwargs.update(window="hann") # default is Hamming, but we need Hanning - coeffs = data.compute_psd( - method="welch" if mode == "fourier" else mode, output="complex", **kwargs - ) + spec_mode = "welch" + elif mode == "cwt_morlet": + kwargs.update(freqs=cwt_freqs) + spec_mode = "morlet" + else: + spec_mode = mode + compute_method = data.compute_tfr if mode == "cwt_morlet" else data.compute_psd + coeffs = compute_method(method=spec_mode, output="complex", **kwargs) # Compute connectivity con = spectral_connectivity_epochs(data=coeffs, method=method, indices=indices) - # Check connectivity from Epochs and Spectrum are equivalent; - # Works for multitaper, but Welch of Spectrum and Fourier of spec_conn are slightly - # off (max. abs. diff. ~0.006) even when what should be identical settings are used + # Check connectivity from Epochs and Spectrum/TFR are equivalent con_from_epochs = spectral_connectivity_epochs( - data=data, method=method, indices=indices, mode=mode + data=data, method=method, indices=indices, mode=mode, cwt_freqs=cwt_freqs ) - if mode == "multitaper": - atol = 0 - else: + # Works for multitaper & Morlet, but Welch of Spectrum and Fourier of spec_conn are + # slightly off (max. abs. diff. ~0.006) even when what should be identical settings + # are used + if mode == "fourier": atol = 7e-3 + else: + atol = 0 # spec_conn_epochs excludes freqs without at least 5 cycles, but not Spectrum fstart = con.freqs.index(con_from_epochs.freqs[0]) assert_allclose( @@ -546,25 +553,18 @@ def test_spectral_connectivity_epochs_spectrum_input(method, mode): freqs_noise = (freqs < fband[0] - trans_bandwidth * 2) | ( freqs > fband[1] + trans_bandwidth * 2 ) - - # nothing for CaCoh to optimise, so use same thresholds for CaCoh and Coh - if mode == "multitaper": # lower baseline for multitaper - con_thresh = (0.1, 0.3) - else: # higher baseline for Welch/Fourier - con_thresh = (0.2, 0.4) - # check freqs of simulated interaction show strong connectivity - assert_array_less(con_thresh[1], np.abs(con.get_data()[:, freqs_con].mean())) + assert_array_less(0.6, np.abs(con.get_data()[:, freqs_con].mean())) # check freqs of no simulated interaction (just noise) show weak connectivity - assert_array_less(np.abs(con.get_data()[:, freqs_noise].mean()), con_thresh[0]) + assert_array_less(np.abs(con.get_data()[:, freqs_noise].mean()), 0.3) # TODO: Add general test for error catching for spec_conn_epochs @pytest.mark.skipif( not check_version("mne", "1.8"), reason="Requires MNE v1.8.0 or higher" ) -def test_spectral_connectivity_epochs_spectrum_input_error_catch(): - """Test spec_conn_epochs catches error with EpochsSpectrum data as input.""" +def test_spectral_connectivity_epochs_spectrum_tfr_input_error_catch(): + """Test spec_conn_epochs catches errors with EpochsSpectrum/TFR data as input.""" # Generate data rng = np.random.default_rng(44) n_epochs, n_chans, n_times = (5, 2, 50) From 535db699a94549bf682a9329d9e3cb304a31a1f6 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 9 Sep 2024 16:39:04 +0200 Subject: [PATCH 02/14] Update epochs docstring --- mne_connectivity/spectral/epochs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 4a5502fe..2bf48fc4 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -803,7 +803,8 @@ def spectral_connectivity_epochs( mode : str Spectrum estimation mode can be either: 'multitaper', 'fourier', or 'cwt_morlet'. Ignored if ``data`` is an - :class:`~mne.time_frequency.EpochsSpectrum` object. + :class:`~mne.time_frequency.EpochsSpectrum` or + :class:`~mne.time_frequency.EpochsTFR` object. fmin : float | tuple of float The lower frequency of interest. Multiple bands are defined using a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. From 14428f6b917265107216ec22b35fbf3eff8869eb Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 9 Sep 2024 16:39:27 +0200 Subject: [PATCH 03/14] Update rank check comments --- mne_connectivity/spectral/epochs_multivariate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 0ce49856..84d7b240 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -46,9 +46,10 @@ def _check_rank_input(rank, data, indices): else: data_arr = (data_arr * data_arr.conj()).real elif isinstance(data, EpochsTFR | EpochsTFRArray): - # XXX: need to change when other types of TFR are supported + # TFR objs will drop bad channels, so specify picking all channels data_arr = data.get_data(picks=np.arange(data.info["nchan"])) # Convert to power and aggregate over time before computing rank + # XXX: need to change when other types of TFR are supported data_arr = np.sum((data_arr * data_arr.conj()).real, axis=-1) else: data_arr = data From 1b1456d7b593db921f627b92e319dc1343880251 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 10 Sep 2024 14:36:36 +0200 Subject: [PATCH 04/14] Add TFR support spec_conn_time --- .../spectral/tests/test_spectral.py | 103 ++++++++- mne_connectivity/spectral/time.py | 211 ++++++++++++------ 2 files changed, 243 insertions(+), 71 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index a89759a3..ba95e0a9 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1684,6 +1684,83 @@ def test_multivar_spectral_connectivity_time_shapes( assert np.all(np.array(con.indices) == np.array(([[0, 1, 2]], [[3, 4, -1]]))) +@pytest.mark.parametrize("method", ["coh", "cacoh"]) +@pytest.mark.parametrize("mode", ["multitaper", "cwt_morlet"]) +def test_spectral_connectivity_time_tfr_input(method, mode): + """Test spec_conn_time works with EpochsTFR data as input. + + Important to test both bivariate and multivariate methods, as the latter involves + additional steps (e.g., rank computation). + """ + # Simulation parameters & data generation + n_seeds = 2 + n_targets = 2 + fband = (15, 20) # Hz + trans_bandwidth = 1.0 # Hz + + data = make_signals_in_freq_bands( + n_seeds=n_seeds, + n_targets=n_targets, + freq_band=fband, + n_epochs=30, + n_times=200, + sfreq=100, + trans_bandwidth=trans_bandwidth, + snr=0.7, + connection_delay=5, + rng_seed=44, + ) + + if method == "coh": + indices = seed_target_indices( + seeds=np.arange(n_seeds), targets=np.arange(n_targets) + n_seeds + ) + else: + indices = ([np.arange(n_seeds)], [np.arange(n_targets) + n_seeds]) + + # Compute TFR + freqs = np.arange(10, 50) + n_cycles = 5.0 # non-default value to avoid warning in spec_conn_time + mt_bandwidth = 4.0 + kwargs = dict() + if mode == "cwt_morlet": + kwargs.update(zero_mean=False) # default in spec_conn_time + spec_mode = "morlet" + else: + kwargs.update(time_bandwidth=mt_bandwidth) + spec_mode = mode + coeffs = data.compute_tfr( + method=spec_mode, freqs=freqs, n_cycles=n_cycles, output="complex", **kwargs + ) + + # Compute connectivity + con_kwargs = dict( + method=method, + indices=indices, + mode=mode, + freqs=freqs, + n_cycles=n_cycles, + mt_bandwidth=mt_bandwidth, + average=True, + ) + con = spectral_connectivity_time(data=coeffs, **con_kwargs) + + # Check connectivity from Epochs and EpochsTFR are equivalent + con_from_epochs = spectral_connectivity_time(data=data, **con_kwargs) + assert_allclose(np.abs(con.get_data()), np.abs(con_from_epochs.get_data()), atol=0) + + # Check connectivity values are as expected + freqs_con = (freqs >= fband[0]) & (freqs <= fband[1]) + freqs_noise = (freqs < fband[0] - trans_bandwidth * 2) | ( + freqs > fband[1] + trans_bandwidth * 2 + ) + # check freqs of simulated interaction show strong connectivity + assert_array_less(0.6, np.abs(con.get_data()[:, freqs_con].mean())) + # check freqs of no simulated interaction (just noise) show weak connectivity + assert_array_less(np.abs(con.get_data()[:, freqs_noise].mean()), 0.3) + + +# TODO: Add general test for error catching for spec_conn_time @pytest.mark.parametrize("method", ["cacoh", "mic", "mim", _gc, _gc_tr]) @pytest.mark.parametrize("mode", ["multitaper", "cwt_morlet"]) def test_multivar_spectral_connectivity_time_error_catch(method, mode): @@ -1705,7 +1782,7 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): freqs = np.arange(10, 25 + 1) # test type-checking of data - with pytest.raises(TypeError, match="must be an instance of Epochs or a NumPy arr"): + with pytest.raises(TypeError, match="Epochs, EpochsTFR, or a NumPy arr"): spectral_connectivity_time(data="foo", freqs=freqs) # check bad indices without nested array caught @@ -1822,6 +1899,30 @@ def test_multivar_spectral_connectivity_time_error_catch(method, mode): ) +def test_spectral_connectivity_time_tfr_input_error_catch(): + """Test spec_conn_time catches errors with EpochsTFR data as input.""" + # Generate data + rng = np.random.default_rng(44) + n_epochs, n_chans, n_times = (5, 2, 100) + sfreq = 50 + data = rng.random((n_epochs, n_chans, n_times)) + info = create_info(ch_names=n_chans, sfreq=sfreq, ch_types="eeg") + data = EpochsArray(data=data, info=info) + freqs = np.arange(10, 20) + + # Test not Fourier coefficients caught + with pytest.raises(TypeError, match="must contain complex-valued Fourier coeff"): + tfr = data.compute_tfr(method="morlet", freqs=freqs, output="power") + spectral_connectivity_time(data=tfr, freqs=freqs) + + # Catch default value warning (multitaper only) + tfr = data.compute_tfr(method="multitaper", freqs=freqs, output="complex") + with pytest.warns(RuntimeWarning, match="`mt_bandwidth` is not specified"): + spectral_connectivity_time(data=tfr, freqs=freqs) + with pytest.warns(RuntimeWarning, match="`n_cycles` is the default value"): + spectral_connectivity_time(data=tfr, freqs=freqs, mt_bandwidth=4.0) + + def test_save(tmp_path): """Test saving results of spectral connectivity.""" epochs = make_signals_in_freq_bands( diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index a8bcdfb6..ac3b67f0 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -10,8 +10,14 @@ import xarray as xr from mne.epochs import BaseEpochs from mne.parallel import parallel_func -from mne.time_frequency import dpss_windows, tfr_array_morlet, tfr_array_multitaper -from mne.utils import _validate_type, logger, verbose +from mne.time_frequency import ( + EpochsTFR, + EpochsTFRArray, + dpss_windows, + tfr_array_morlet, + tfr_array_multitaper, +) +from mne.utils import _check_option, _validate_type, logger, verbose, warn from ..base import EpochSpectralConnectivity, SpectralConnectivity from ..utils import _check_multivariate_indices, check_indices, fill_doc @@ -66,12 +72,23 @@ def spectral_connectivity_time( Parameters ---------- - data : array_like, shape (n_epochs, n_signals, n_times) | Epochs - The data from which to compute connectivity. + data : array_like, shape (n_epochs, n_signals, n_times) | ~mne.Epochs | ~mne.time_frequency.EpochsTFR + The data from which to compute connectivity. Can be epoched timeseries data as + an :term:`array-like` or :class:`~mne.Epochs` object, or Fourier coefficients + for each epoch as an :class:`~mne.time_frequency.EpochsTFR` object. If + timeseries data, the spectral information will be computed according to the + spectral estimation mode (see the ``mode`` parameter). If an + :class:`~mne.time_frequency.EpochsTFR` object, this spectral information will be + used and the ``mode`` parameter will be ignored. + + .. versionchanged:: 0.8 + Fourier coefficients stored in an :class:`~mne.time_frequency.EpochsTFR` + object can also be passed in as data. freqs : array_like - Array of frequencies of interest for time-frequency decomposition. - Only the frequencies within the range specified by ``fmin`` and - ``fmax`` are used. + Array of frequencies of interest for time-frequency decomposition. Only the + frequencies within the range specified by ``fmin`` and ``fmax`` are used. If + ``data`` is an :class:`~mne.time_frequency.EpochsTFR` object, ``data.freqs`` is + used and this parameter is ignored. method : str | list of str Connectivity measure(s) to compute. These can be ``['coh', 'cacoh', 'mic', 'mim', 'plv', 'ciplv', 'pli', 'wpli', 'gc', 'gc_tr']``. These @@ -133,21 +150,26 @@ def spectral_connectivity_time( Amount of time to consider as padding at the beginning and end of each epoch in seconds. See Notes for more information. mode : str - Time-frequency decomposition method. Can be either: 'multitaper', or - 'cwt_morlet'. See :func:`mne.time_frequency.tfr_array_multitaper` and - :func:`mne.time_frequency.tfr_array_morlet` for reference. + Time-frequency decomposition method. Can be either: ``'multitaper'``, or + ``'cwt_morlet'``. See :func:`mne.time_frequency.tfr_array_multitaper` and + :func:`mne.time_frequency.tfr_array_morlet` for reference. Ignored if ``data`` + is an :class:`~mne.time_frequency.EpochsTFR` object. mt_bandwidth : float | None - Product between the temporal window length (in seconds) and the full - frequency bandwidth (in Hz). This product can be seen as the surface - of the window on the time/frequency plane and controls the frequency - bandwidth (thus the frequency resolution) and the number of good - tapers. See :func:`mne.time_frequency.tfr_array_multitaper` - documentation. + Product between the temporal window length (in seconds) and the full frequency + bandwidth (in Hz). This product can be seen as the surface of the window on the + time/frequency plane and controls the frequency bandwidth (thus the frequency + resolution) and the number of good tapers. See + :func:`mne.time_frequency.tfr_array_multitaper` documentation. If ``data`` is an + :class:`~mne.time_frequency.EpochsTFR` object computed with the ``multitaper`` + method, this should match the value used to compute the TFR. n_cycles : float | array_like of float - Number of cycles in the wavelet, either a fixed number or one per - frequency. The number of cycles ``n_cycles`` and the frequencies of - interest ``cwt_freqs`` define the temporal window length. For details, - see :func:`mne.time_frequency.tfr_array_morlet` documentation. + Number of cycles in the wavelet, either a fixed number or one per frequency. The + number of cycles ``n_cycles`` and the frequencies of interest ``freqs`` define + the temporal window length. For details, see + :func:`mne.time_frequency.tfr_array_multitaper` and + :func:`mne.time_frequency.tfr_array_morlet` documentation. If ``data`` is an + :class:`~mne.time_frequency.EpochsTFR` object computed with the ``multitaper`` + method, this should match the value used to compute the TFR. gc_n_lags : int Number of lags to use for the vector autoregressive model when computing Granger causality. Higher values increase computational cost, @@ -369,12 +391,18 @@ def spectral_connectivity_time( References ---------- .. footbibliography:: - """ + """ # noqa: E501 events = None event_id = None # extract data from Epochs object - _validate_type(data, (np.ndarray, BaseEpochs), "`data`", "Epochs or a NumPy array") - if isinstance(data, BaseEpochs): + _validate_type( + data, + (np.ndarray, BaseEpochs, EpochsTFR, EpochsTFRArray), + "`data`", + "Epochs, EpochsTFR, or a NumPy array", + ) + spectrum_computed = False + if isinstance(data, BaseEpochs | EpochsTFR | EpochsTFRArray): names = data.ch_names sfreq = data.info["sfreq"] events = data.events @@ -392,12 +420,40 @@ def spectral_connectivity_time( if hasattr(data, "annotations") and not annots_in_metadata: data.add_annotations_to_metadata(overwrite=True) metadata = data.metadata - # XXX: remove logic once support for mne<1.6 is dropped - kwargs = dict() - if "copy" in inspect.getfullargspec(data.get_data).kwonlyargs: - kwargs["copy"] = False - data = data.get_data(**kwargs) - n_epochs, n_signals, n_times = data.shape + if isinstance(data, BaseEpochs): + # XXX: remove logic once support for mne<1.6 is dropped + kwargs = dict() + if "copy" in inspect.getfullargspec(data.get_data).kwonlyargs: + kwargs["copy"] = False + data = data.get_data(**kwargs) + n_epochs, n_signals, n_times = data.shape + else: + freqs = data.freqs # use freqs from EpochsTFR object + # Read mode from the EpochsTFR object + mode = "cwt_morlet" if data.method == "morlet" else data.method + # TFR objs will drop bad channels, so specify picking all channels + data = data.get_data(picks=np.arange(data.info["nchan"])) + if not np.iscomplexobj(data): + raise TypeError( + "if `data` is an EpochsTFR object, it must contain complex-valued " + "Fourier coefficients, such as that returned from " + "Epochs.compute_tfr() with `output='complex'`" + ) + n_epochs, n_signals = data.shape[:2] + n_times = data.shape[-1] + # Warn if mt_bandwidth & n_cycles default-valued (can't be read from TFR) + if mode == "multitaper": # doesn't matter for cwt_morlet + if mt_bandwidth is None: + warn( + "`mt_bandwidth` is not specified; assuming 4.0 Hz was used to " + "compute the TFR" + ) + if n_cycles == 7.0: + warn( + "`n_cycles` is the default value; assuming 7.0 was used to " + "compute the TFR" + ) + spectrum_computed = True else: data = np.asarray(data) n_epochs, n_signals, n_times = data.shape @@ -520,8 +576,7 @@ def spectral_connectivity_time( n_components = _check_n_components_input(n_components, rank) if n_components == 1: # n_components=0 means space for a components dimension is not allocated in - # the results, similar to how n_times_spectrum=0 is used to indicate that - # time is not a dimension in the results + # the results n_components = 0 else: rank = None @@ -620,9 +675,10 @@ def spectral_connectivity_time( padding=padding, kw_cwt={}, kw_mt={}, + multivariate_con=multivariate_con, + spectrum_computed=spectrum_computed, n_jobs=n_jobs, verbose=verbose, - multivariate_con=multivariate_con, ) for epoch_idx in np.arange(n_epochs): @@ -713,16 +769,17 @@ def _spectral_connectivity( padding, kw_cwt, kw_mt, + multivariate_con, + spectrum_computed, n_jobs, verbose, - multivariate_con, ): """Estimate time-resolved connectivity for one epoch. Parameters ---------- - data : array_like, shape (n_channels, n_times) - Time-series data. + data : array_like, shape (channels, [freqs,] [tapers,] times) + Time-series data or time-frequency data. method : list of str List of connectivity metrics to compute. kernel : array_like, shape (n_sm_fres, n_sm_times) @@ -766,6 +823,8 @@ def _spectral_connectivity( epoch in seconds. multivariate_con : bool Whether or not multivariate connectivity is to be computed. + spectrum_computed : bool + Whether or not the time-frequency decomposition has already been computed. Returns ------- @@ -784,51 +843,62 @@ def _spectral_connectivity( and target signals (respectively). ``n_comps`` is present for valid multivariate methods if ``n_components > 0``. """ - n_cons = len(source_idx) - data = np.expand_dims(data, axis=0) - kw_cwt.setdefault("zero_mean", False) # avoid FutureWarning + # check that spectral mode is recognised + _check_option("mode", mode, ("cwt_morlet", "multitaper")) + + # compute time-frequency decomposition + mt_bandwidth = mt_bandwidth if mt_bandwidth else 4 + if not spectrum_computed: + data = np.expand_dims(data, axis=0) + kw_cwt.setdefault("zero_mean", False) # avoid FutureWarning + if mode == "cwt_morlet": + out = tfr_array_morlet( + data, + sfreq, + freqs, + n_cycles=n_cycles, + output="complex", + decim=decim, + n_jobs=n_jobs, + **kw_cwt, + ) + else: + out = tfr_array_multitaper( + data, + sfreq, + freqs, + n_cycles=n_cycles, + time_bandwidth=mt_bandwidth, + output="complex", + decim=decim, + n_jobs=n_jobs, + **kw_mt, + ) + out = np.squeeze(out, axis=0) + else: + out = data + + # give tapers dim to cwt_morlet output if mode == "cwt_morlet": - out = tfr_array_morlet( - data, - sfreq, - freqs, - n_cycles=n_cycles, - output="complex", - decim=decim, - n_jobs=n_jobs, - **kw_cwt, - ) - out = np.expand_dims(out, axis=2) # same dims with multitaper - weights = None - elif mode == "multitaper": - out = tfr_array_multitaper( - data, - sfreq, - freqs, - n_cycles=n_cycles, - time_bandwidth=mt_bandwidth, - output="complex", - decim=decim, - n_jobs=n_jobs, - **kw_mt, - ) + out = np.expand_dims(out, axis=1) + + # compute taper weights + if mode == "multitaper": if isinstance(n_cycles, int | float): n_cycles = [n_cycles] * len(freqs) - mt_bandwidth = mt_bandwidth if mt_bandwidth else 4 - n_tapers = int(np.floor(mt_bandwidth - 1)) - weights = np.zeros((n_tapers, len(freqs), out.shape[-1])) + n_tapers = out.shape[-3] + n_times = out.shape[-1] + half_nbw = mt_bandwidth / 2.0 + weights = np.zeros((n_tapers, len(freqs), n_times)) for i, (f, n_c) in enumerate(zip(freqs, n_cycles)): window_length = np.arange(0.0, n_c / float(f), 1.0 / sfreq).shape[0] - half_nbw = mt_bandwidth / 2.0 - n_tapers = int(np.floor(mt_bandwidth - 1)) _, eigvals = dpss_windows(window_length, half_nbw, n_tapers, sym=False) weights[:, i, :] = np.sqrt(eigvals[:, np.newaxis]) # weights have shape (n_tapers, n_freqs, n_times) else: - raise ValueError("Mode must be 'cwt_morlet' or 'multitaper'.") - - out = np.squeeze(out, axis=0) + weights = None + # pad spectrum and weights if padding: if padding < 0: raise ValueError(f"Padding cannot be negative, got {padding}.") @@ -843,6 +913,7 @@ def _spectral_connectivity( # compute for each connectivity method scores = {} patterns = {} + n_cons = len(source_idx) conn = _parallel_con( out, method, From 6fd086122e632117bb1bba8fd3e7cdc095216c5f Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 10 Sep 2024 14:39:21 +0200 Subject: [PATCH 05/14] Switch tests to custom MNE branch --- .circleci/config.yml | 2 +- .github/workflows/linux_conda.yml | 2 +- .github/workflows/unit_tests.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 62cffbb1..868a3437 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -77,7 +77,7 @@ jobs: - run: name: Get Python running and install dependencies command: | - pip install git+https://github.com/mne-tools/mne-python@main + pip install git+https://github.com/tsbinns/mne-python@fix_tfr_multitaper curl https://raw.githubusercontent.com/mne-tools/mne-python/main/tools/circleci_dependencies.sh -o circleci_dependencies.sh chmod +x circleci_dependencies.sh ./circleci_dependencies.sh diff --git a/.github/workflows/linux_conda.yml b/.github/workflows/linux_conda.yml index d0ec1f34..3ea2d25e 100644 --- a/.github/workflows/linux_conda.yml +++ b/.github/workflows/linux_conda.yml @@ -41,7 +41,7 @@ jobs: source ./get_minimal_commands.sh pip install .[test] name: 'Install dependencies' - - run: pip install git+https://github.com/mne-tools/mne-python@main + - run: pip install git+https://github.com/tsbinns/mne-python@fix_tfr_multitaper - run: pip install -e . - run: | which mne diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 650490d5..84fd27b0 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -60,7 +60,7 @@ jobs: run: pip install --upgrade mne - name: Install MNE (main) if: matrix.mne-version == 'mne-main' - run: pip install git+https://github.com/mne-tools/mne-python@main + run: pip install git+https://github.com/tsbinns/mne-python@fix_tfr_multitaper - run: python -c "import mne; print(mne.datasets.testing.data_path(verbose=True))" if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' - name: Display versions and environment information From a38f8fd16e421de238a9eb90b221507641ebf481 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 10 Sep 2024 14:46:10 +0200 Subject: [PATCH 06/14] Fix failing tfr_error test --- mne_connectivity/spectral/tests/test_spectral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index ba95e0a9..b5481061 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1918,7 +1918,7 @@ def test_spectral_connectivity_time_tfr_input_error_catch(): # Catch default value warning (multitaper only) tfr = data.compute_tfr(method="multitaper", freqs=freqs, output="complex") with pytest.warns(RuntimeWarning, match="`mt_bandwidth` is not specified"): - spectral_connectivity_time(data=tfr, freqs=freqs) + spectral_connectivity_time(data=tfr, freqs=freqs, n_cycles=5.0) with pytest.warns(RuntimeWarning, match="`n_cycles` is the default value"): spectral_connectivity_time(data=tfr, freqs=freqs, mt_bandwidth=4.0) From 9c04f81e49d1eeed1bd5193dc54cddca4c1c81c9 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 10 Sep 2024 14:51:30 +0200 Subject: [PATCH 07/14] Fix time_tfr tolerances --- mne_connectivity/spectral/tests/test_spectral.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index b5481061..23001e4d 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1745,9 +1745,12 @@ def test_spectral_connectivity_time_tfr_input(method, mode): ) con = spectral_connectivity_time(data=coeffs, **con_kwargs) - # Check connectivity from Epochs and EpochsTFR are equivalent + # Check connectivity from Epochs and EpochsTFR are equivalent (small but non-zero + # tolerance given due to some platform-dependent variation) con_from_epochs = spectral_connectivity_time(data=data, **con_kwargs) - assert_allclose(np.abs(con.get_data()), np.abs(con_from_epochs.get_data()), atol=0) + assert_allclose( + np.abs(con.get_data()), np.abs(con_from_epochs.get_data()), atol=1e-7 + ) # Check connectivity values are as expected freqs_con = (freqs >= fband[0]) & (freqs <= fband[1]) @@ -1760,6 +1763,9 @@ def test_spectral_connectivity_time_tfr_input(method, mode): assert_array_less(np.abs(con.get_data()[:, freqs_noise].mean()), 0.3) +test_spectral_connectivity_time_tfr_input("cacoh", "cwt_morlet") + + # TODO: Add general test for error catching for spec_conn_time @pytest.mark.parametrize("method", ["cacoh", "mic", "mim", _gc, _gc_tr]) @pytest.mark.parametrize("mode", ["multitaper", "cwt_morlet"]) From c8614494a783fcec2af060de996e395d63d9910d Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 10 Sep 2024 14:58:14 +0200 Subject: [PATCH 08/14] Fix misleading error message --- mne_connectivity/spectral/epochs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 2bf48fc4..23cdb080 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -1172,7 +1172,7 @@ def spectral_connectivity_epochs( if mode != "morlet": # FIXME: Add support for other TFR methods raise ValueError( "if `data` is an EpochsTFR object, the spectral method " - "must be 'cwt_morlet'" + "must be 'morlet'" ) else: if "taper" in data._dims: # FIXME: Add support for multitaper TFR From 7013e19dd85c5a88ade51bb1a67b51e9380c8459 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 10 Sep 2024 15:56:22 +0200 Subject: [PATCH 09/14] Fix spec_conn_time docstring error --- mne_connectivity/spectral/time.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index ac3b67f0..badc9362 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -121,8 +121,8 @@ def spectral_connectivity_time( connections between all channels are computed, unless a Granger causality method is called, in which case an error is raised. sfreq : float - The sampling frequency. Required if data is not - :class:`Epochs `. + The sampling frequency. Required if ``data`` is not an :class:`~mne.Epochs` or + :class:`~mne.time_frequency.EpochsTFR` object. fmin : float | tuple of float | None The lower frequency of interest. Multiple bands are defined using a tuple, e.g., ``(8., 20.)`` for two bands with 8 Hz and 20 Hz lower From 160bc64d1ae7b1cc6aeba55f2929bcc2c8bc0c66 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 17 Sep 2024 11:48:36 +0200 Subject: [PATCH 10/14] Revert "Switch tests to custom MNE branch" This reverts commit 6fd086122e632117bb1bba8fd3e7cdc095216c5f. --- .circleci/config.yml | 2 +- .github/workflows/linux_conda.yml | 2 +- .github/workflows/unit_tests.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 868a3437..62cffbb1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -77,7 +77,7 @@ jobs: - run: name: Get Python running and install dependencies command: | - pip install git+https://github.com/tsbinns/mne-python@fix_tfr_multitaper + pip install git+https://github.com/mne-tools/mne-python@main curl https://raw.githubusercontent.com/mne-tools/mne-python/main/tools/circleci_dependencies.sh -o circleci_dependencies.sh chmod +x circleci_dependencies.sh ./circleci_dependencies.sh diff --git a/.github/workflows/linux_conda.yml b/.github/workflows/linux_conda.yml index 3ea2d25e..d0ec1f34 100644 --- a/.github/workflows/linux_conda.yml +++ b/.github/workflows/linux_conda.yml @@ -41,7 +41,7 @@ jobs: source ./get_minimal_commands.sh pip install .[test] name: 'Install dependencies' - - run: pip install git+https://github.com/tsbinns/mne-python@fix_tfr_multitaper + - run: pip install git+https://github.com/mne-tools/mne-python@main - run: pip install -e . - run: | which mne diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 84fd27b0..650490d5 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -60,7 +60,7 @@ jobs: run: pip install --upgrade mne - name: Install MNE (main) if: matrix.mne-version == 'mne-main' - run: pip install git+https://github.com/tsbinns/mne-python@fix_tfr_multitaper + run: pip install git+https://github.com/mne-tools/mne-python@main - run: python -c "import mne; print(mne.datasets.testing.data_path(verbose=True))" if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.11' - name: Display versions and environment information From 8e79c9ff0cd7dbf801ef17bb53ee25804233feb2 Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Fri, 20 Sep 2024 14:34:35 +0200 Subject: [PATCH 11/14] Apply suggestions from code review Co-authored-by: Daniel McCloy --- mne_connectivity/spectral/epochs.py | 22 +++++++------------ .../spectral/epochs_multivariate.py | 2 +- mne_connectivity/spectral/time.py | 8 +++---- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 23cdb080..186f5c3f 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -164,17 +164,16 @@ def _prepare_connectivity( first_epoch = epoch_block[0] # Sort times - if spectrum_computed and times_in is None: # if Spectrum object passed in as data + if spectrum_computed and times_in is None: # is a Spectrum object n_signals = first_epoch[0].shape[0] times = None n_times = 0 - times_in = None n_times_in = 0 tmin_idx = None tmax_idx = None warn_times = False - else: # if data has a time dimension (timeseries or TFR object) - if spectrum_computed: # if TFR object passed in as data + else: # data has a time dimension (timeseries or TFR object) + if spectrum_computed: # is a TFR object first_epoch = (first_epoch[0][:, 0],) # just take first freq ( n_signals, @@ -190,7 +189,7 @@ def _prepare_connectivity( ) # Sort freqs - if not spectrum_computed: # if timeseries passed in as data + if not spectrum_computed: # is an (ordinary) timeseries # check that fmin corresponds to at least 5 cycles fmin = _check_freqs(sfreq=sfreq, fmin=fmin, n_times=n_times) # compute frequencies to analyze based on number of samples, sampling rate, @@ -746,8 +745,8 @@ def spectral_connectivity_epochs( :class:`~mne.time_frequency.EpochsTFR` object. If timeseries data, the spectral information will be computed according to the spectral estimation mode (see the ``mode`` parameter). If an :class:`~mne.time_frequency.EpochsSpectrum` or - :class:`~mne.time_frequency.EpochsTFR` object, this spectral information will be - used and the ``mode`` parameter will be ignored. + :class:`~mne.time_frequency.EpochsTFR` object, existing spectral information + will be used and the ``mode`` parameter will be ignored. Note that it is also possible to combine multiple timeseries signals by providing a list of tuples, e.g.: :: @@ -1120,10 +1119,7 @@ def spectral_connectivity_epochs( weights = None metadata = None spectrum_computed = False - if isinstance( - data, - BaseEpochs | EpochsSpectrum | EpochsSpectrumArray | EpochsTFR | EpochsTFRArray, - ): + if isinstance(data, BaseEpochs | EpochsSpectrum | EpochsTFR): names = data.ch_names sfreq = data.info["sfreq"] @@ -1144,9 +1140,7 @@ def spectral_connectivity_epochs( data.add_annotations_to_metadata(overwrite=True) metadata = data.metadata - if isinstance( - data, EpochsSpectrum | EpochsSpectrumArray | EpochsTFR | EpochsTFRArray - ): + if isinstance(data, EpochsSpectrum | EpochsTFR): # XXX: Will need to be updated if new Spectrum methods are added if not np.iscomplexobj(data.get_data()): raise TypeError( diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index 84d7b240..ae36bf47 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -45,7 +45,7 @@ def _check_rank_input(rank, data, indices): data_arr = _psd_from_mt(data_arr, data.weights) else: data_arr = (data_arr * data_arr.conj()).real - elif isinstance(data, EpochsTFR | EpochsTFRArray): + elif isinstance(data, EpochsTFR): # TFR objs will drop bad channels, so specify picking all channels data_arr = data.get_data(picks=np.arange(data.info["nchan"])) # Convert to power and aggregate over time before computing rank diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index badc9362..251ee11d 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -78,8 +78,8 @@ def spectral_connectivity_time( for each epoch as an :class:`~mne.time_frequency.EpochsTFR` object. If timeseries data, the spectral information will be computed according to the spectral estimation mode (see the ``mode`` parameter). If an - :class:`~mne.time_frequency.EpochsTFR` object, this spectral information will be - used and the ``mode`` parameter will be ignored. + :class:`~mne.time_frequency.EpochsTFR` object, existing spectral information + will be used and the ``mode`` parameter will be ignored. .. versionchanged:: 0.8 Fourier coefficients stored in an :class:`~mne.time_frequency.EpochsTFR` @@ -397,12 +397,12 @@ def spectral_connectivity_time( # extract data from Epochs object _validate_type( data, - (np.ndarray, BaseEpochs, EpochsTFR, EpochsTFRArray), + (np.ndarray, BaseEpochs, EpochsTFR), "`data`", "Epochs, EpochsTFR, or a NumPy array", ) spectrum_computed = False - if isinstance(data, BaseEpochs | EpochsTFR | EpochsTFRArray): + if isinstance(data, BaseEpochs | EpochsTFR): names = data.ch_names sfreq = data.info["sfreq"] events = data.events From 23fd89c09f37f7a8dd538932d8212ef5c748ea37 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Sep 2024 12:35:05 +0000 Subject: [PATCH 12/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne_connectivity/spectral/epochs.py | 1 - mne_connectivity/spectral/epochs_multivariate.py | 1 - mne_connectivity/spectral/time.py | 1 - 3 files changed, 3 deletions(-) diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 186f5c3f..3ec0bb43 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -16,7 +16,6 @@ EpochsSpectrum, EpochsSpectrumArray, EpochsTFR, - EpochsTFRArray, ) from mne.time_frequency.multitaper import ( _compute_mt_params, diff --git a/mne_connectivity/spectral/epochs_multivariate.py b/mne_connectivity/spectral/epochs_multivariate.py index ae36bf47..1531d4cb 100644 --- a/mne_connectivity/spectral/epochs_multivariate.py +++ b/mne_connectivity/spectral/epochs_multivariate.py @@ -19,7 +19,6 @@ EpochsSpectrum, EpochsSpectrumArray, EpochsTFR, - EpochsTFRArray, ) from mne.time_frequency.multitaper import _psd_from_mt from mne.utils import ProgressBar, _validate_type, logger diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index 251ee11d..c14923dc 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -12,7 +12,6 @@ from mne.parallel import parallel_func from mne.time_frequency import ( EpochsTFR, - EpochsTFRArray, dpss_windows, tfr_array_morlet, tfr_array_multitaper, From 641fa1d6ef185f070021fa7364b1bc87a64a5126 Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Mon, 23 Sep 2024 16:45:43 +0200 Subject: [PATCH 13/14] Name expected conn values --- mne_connectivity/spectral/tests/test_spectral.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 23001e4d..d2299387 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -553,10 +553,12 @@ def test_spectral_connectivity_epochs_spectrum_input(method, mode): freqs_noise = (freqs < fband[0] - trans_bandwidth * 2) | ( freqs > fband[1] + trans_bandwidth * 2 ) + WEAK_CONN_OR_NOISE = 0.3 # conn values outside of simulated fband should be < this + STRONG_CONN = 0.6 # conn values inside simulated fband should be > this # check freqs of simulated interaction show strong connectivity - assert_array_less(0.6, np.abs(con.get_data()[:, freqs_con].mean())) + assert_array_less(STRONG_CONN, np.abs(con.get_data()[:, freqs_con].mean())) # check freqs of no simulated interaction (just noise) show weak connectivity - assert_array_less(np.abs(con.get_data()[:, freqs_noise].mean()), 0.3) + assert_array_less(np.abs(con.get_data()[:, freqs_noise].mean()), WEAK_CONN_OR_NOISE) # TODO: Add general test for error catching for spec_conn_epochs From 232c0e988410e3eefbbce796ff875105adefe7ca Mon Sep 17 00:00:00 2001 From: Thomas Samuel Binns Date: Tue, 24 Sep 2024 11:09:03 +0200 Subject: [PATCH 14/14] Update Welch-Fourier variation message --- mne_connectivity/spectral/tests/test_spectral.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index d2299387..93da85d0 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -533,8 +533,10 @@ def test_spectral_connectivity_epochs_spectrum_input(method, mode): data=data, method=method, indices=indices, mode=mode, cwt_freqs=cwt_freqs ) # Works for multitaper & Morlet, but Welch of Spectrum and Fourier of spec_conn are - # slightly off (max. abs. diff. ~0.006) even when what should be identical settings - # are used + # slightly off (max. abs. diff. ~0.006). This is due to the Spectrum object using + # scipy.signal.spectrogram to compute the coefficients, while spec_conn uses + # scipy.signal.rfft, which give slightly different outputs even with identical + # settings. if mode == "fourier": atol = 7e-3 else: