diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index d1b72006f..b93c8e648 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -25,6 +25,8 @@ jobs: run: | python -m pip install --upgrade pip pip install -e .['full'] + # install dev version of mne-connectivity + pip install -U https://github.com/mne-tools/mne-connectivity/archive/main.zip - name: Test with pytest 🔧 run: | diff --git a/.gitignore b/.gitignore index e915b2ef8..68e4cdd39 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ *.py[cod] *$py.class +.vscode # C extensions *.so diff --git a/docs/source/index.rst b/docs/source/index.rst index 4f57de27c..2ca43b6e7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -24,12 +24,14 @@ Frites .. figure:: _static/logo_desc.png :align: center +.. _MNE-Python: https://mne.tools/stable +.. _MNE-Connectivity: https://mne.tools/mne-connectivity/dev/ Description +++++++++++ **Frites** is a Python toolbox for assessing information-based measures on human and animal neurophysiological data (M/EEG, Intracranial). The toolbox also includes directed and undirected connectivity metrics such as group-level statistics on measures of information (information-theory, machine-learning and measures of distance). - +The toolbox builds off the popular `MNE-Python`_ and `MNE-Connectivity`_ packages to perform connectivity analyses. Highlights ++++++++++ diff --git a/frites/conn/conn_spec.py b/frites/conn/conn_spec.py index da68d6744..ceff8c903 100644 --- a/frites/conn/conn_spec.py +++ b/frites/conn/conn_spec.py @@ -7,6 +7,7 @@ """ import numpy as np import xarray as xr +from mne_connectivity import spectral_connectivity_time from frites.conn import conn_io from frites.io import set_log_level, logger, check_attrs @@ -188,12 +189,13 @@ def conn_spec( raise NotImplementedError("Frequency dependent kernel in development" f"only first {sm_times[0]} will be used") - # _________________________________ METHODS _______________________________ - conn_f, f_name = { - 'coh': (_coh, 'Coherence'), - 'plv': (_plv, "Phase-Locking Value"), - 'sxy': (_cs, "Cross-spectrum") - }[metric] + if mode == 'morlet': + mode = 'cwt_morlet' + conn = spectral_connectivity_time( + data=data, names=roi, method=metric, sfreq=sfreq, + foi=foi, sm_times=sm_times, sm_freqs=sm_freqs, sm_kernel=sm_kernel, + mode=mode, n_cycles=n_cycles, mt_bandwidth=mt_bandwidth, freqs=freqs, + block_size=block_size, n_jobs=n_jobs, verbose=verbose) # _________________________________ INPUTS ________________________________ # inputs conversion @@ -201,17 +203,15 @@ def conn_spec( data, times=times, roi=roi, agg_ch=False, win_sample=None, pairs=pairs, sort=True, block_size=block_size, sfreq=sfreq, freqs=freqs, foi=foi, sm_times=sm_times, sm_freqs=sm_freqs, verbose=verbose, - name=f'Sepctral connectivity (metric = {f_name}, mode={mode})', ) # extract variables - x, trials, attrs = data.data, data['y'].data, cfg['attrs'] - times, n_trials = data['times'].data, len(trials) - x_s, x_t, roi_p = cfg['x_s'], cfg['x_t'], cfg['roi_p'] - indices, sfreq = cfg['blocks'], cfg['sfreq'] - freqs, _, foi_idx = cfg['freqs'], cfg['need_foi'], cfg['foi_idx'] + _, trials, attrs = data.data, data['y'].data, cfg['attrs'] + times, _ = data['times'].data, len(trials) + _, _, roi_p = cfg['x_s'], cfg['x_t'], cfg['roi_p'] + _, sfreq = cfg['blocks'], cfg['sfreq'] + freqs, _, _ = cfg['freqs'], cfg['need_foi'], cfg['foi_idx'] f_vec, sm_times, sm_freqs = cfg['f_vec'], cfg['sm_times'], cfg['sm_freqs'] - n_pairs, n_freqs = len(x_s), len(freqs) # temporal decimation if isinstance(decim, int): @@ -219,42 +219,6 @@ def conn_spec( sm_times = int(np.round(sm_times / decim)) sm_times = max(sm_times, 1) - # Create smoothing kernel - kernel = _create_kernel(sm_times, sm_freqs, kernel=sm_kernel) - - # define arguments for parallel computing - mesg = f'Estimating pairwise {f_name} for trials %s' - kw_para = dict(n_jobs=n_jobs, verbose=verbose, total=n_pairs) - - # show info - logger.info(f"Computing pairwise {f_name} (n_pairs={n_pairs}, " - f"n_freqs={n_freqs}, decim={decim}, sm_times={sm_times}, " - f"sm_freqs={sm_freqs})") - - # ______________________ CONTAINER FOR CONNECTIVITY _______________________ - # compute coherence on blocks of trials - conn = np.zeros((n_trials, n_pairs, len(f_vec), len(times)), dtype=dtype) - for tr in indices: - # --------------------------- TIME-FREQUENCY -------------------------- - # time-frequency decomposition - w = _tf_decomp( - x[tr, ...], sfreq, freqs, n_cycles=n_cycles, decim=decim, - mode=mode, mt_bandwidth=mt_bandwidth, kw_cwt=kw_cwt, kw_mt=kw_mt, - n_jobs=n_jobs) - - # ----------------------------- CONN TRIALS --------------------------- - # give indication about computed trials - kw_para['mesg'] = mesg % f"{tr[0]}...{tr[-1]}" - - # computes conn across trials - conn_tr = conn_f(w, kernel, foi_idx, x_s, x_t, kw_para) - - # merge results - conn[tr, ...] = np.stack(conn_tr, axis=1) - - # Call GC - del conn_tr, w - # _________________________________ OUTPUTS _______________________________ # configuration cfg = dict( @@ -264,7 +228,7 @@ def conn_spec( ) # conversion - conn = xr.DataArray(conn, dims=('trials', 'roi', 'freqs', 'times'), + conn = xr.DataArray(conn.get_data(), dims=('trials', 'roi', 'freqs', 'times'), name=metric, coords=(trials, roi_p, f_vec, times), attrs=check_attrs({**attrs, **cfg})) return conn diff --git a/requirements.txt b/requirements.txt index 9fb42c2a0..daa9e1d42 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ joblib xarray netCDF4 h5netcdf +mne-connectivity