Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add mne-connectivity dependency and wrap spectral_connectivity_time there #34

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e .['full']
# install dev version of mne-connectivity
pip install -U https://github.com/mne-tools/mne-connectivity/archive/main.zip

- name: Test with pytest 🔧
run: |
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
.vscode

# C extensions
*.so
Expand Down
4 changes: 3 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ Frites
.. figure:: _static/logo_desc.png
:align: center

.. _MNE-Python: https://mne.tools/stable
.. _MNE-Connectivity: https://mne.tools/mne-connectivity/dev/

Description
+++++++++++

**Frites** is a Python toolbox for assessing information-based measures on human and animal neurophysiological data (M/EEG, Intracranial). The toolbox also includes directed and undirected connectivity metrics such as group-level statistics on measures of information (information-theory, machine-learning and measures of distance).

The toolbox builds off the popular `MNE-Python`_ and `MNE-Connectivity`_ packages to perform connectivity analyses.

Highlights
++++++++++
Expand Down
64 changes: 14 additions & 50 deletions frites/conn/conn_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""
import numpy as np
import xarray as xr
from mne_connectivity import spectral_connectivity_time

from frites.conn import conn_io
from frites.io import set_log_level, logger, check_attrs
Expand Down Expand Up @@ -188,73 +189,36 @@ def conn_spec(
raise NotImplementedError("Frequency dependent kernel in development"
f"only first {sm_times[0]} will be used")

# _________________________________ METHODS _______________________________
conn_f, f_name = {
'coh': (_coh, 'Coherence'),
'plv': (_plv, "Phase-Locking Value"),
'sxy': (_cs, "Cross-spectrum")
}[metric]
if mode == 'morlet':
mode = 'cwt_morlet'
conn = spectral_connectivity_time(
data=data, names=roi, method=metric, sfreq=sfreq,
foi=foi, sm_times=sm_times, sm_freqs=sm_freqs, sm_kernel=sm_kernel,
mode=mode, n_cycles=n_cycles, mt_bandwidth=mt_bandwidth, freqs=freqs,
block_size=block_size, n_jobs=n_jobs, verbose=verbose)

# _________________________________ INPUTS ________________________________
# inputs conversion
data, cfg = conn_io(
data, times=times, roi=roi, agg_ch=False, win_sample=None, pairs=pairs,
sort=True, block_size=block_size, sfreq=sfreq, freqs=freqs, foi=foi,
sm_times=sm_times, sm_freqs=sm_freqs, verbose=verbose,
name=f'Sepctral connectivity (metric = {f_name}, mode={mode})',
)

# extract variables
x, trials, attrs = data.data, data['y'].data, cfg['attrs']
times, n_trials = data['times'].data, len(trials)
x_s, x_t, roi_p = cfg['x_s'], cfg['x_t'], cfg['roi_p']
indices, sfreq = cfg['blocks'], cfg['sfreq']
freqs, _, foi_idx = cfg['freqs'], cfg['need_foi'], cfg['foi_idx']
_, trials, attrs = data.data, data['y'].data, cfg['attrs']
times, _ = data['times'].data, len(trials)
_, _, roi_p = cfg['x_s'], cfg['x_t'], cfg['roi_p']
_, sfreq = cfg['blocks'], cfg['sfreq']
freqs, _, _ = cfg['freqs'], cfg['need_foi'], cfg['foi_idx']
f_vec, sm_times, sm_freqs = cfg['f_vec'], cfg['sm_times'], cfg['sm_freqs']
n_pairs, n_freqs = len(x_s), len(freqs)

# temporal decimation
if isinstance(decim, int):
times = times[::decim]
sm_times = int(np.round(sm_times / decim))
sm_times = max(sm_times, 1)

# Create smoothing kernel
kernel = _create_kernel(sm_times, sm_freqs, kernel=sm_kernel)

# define arguments for parallel computing
mesg = f'Estimating pairwise {f_name} for trials %s'
kw_para = dict(n_jobs=n_jobs, verbose=verbose, total=n_pairs)

# show info
logger.info(f"Computing pairwise {f_name} (n_pairs={n_pairs}, "
f"n_freqs={n_freqs}, decim={decim}, sm_times={sm_times}, "
f"sm_freqs={sm_freqs})")

# ______________________ CONTAINER FOR CONNECTIVITY _______________________
# compute coherence on blocks of trials
conn = np.zeros((n_trials, n_pairs, len(f_vec), len(times)), dtype=dtype)
for tr in indices:
# --------------------------- TIME-FREQUENCY --------------------------
# time-frequency decomposition
w = _tf_decomp(
x[tr, ...], sfreq, freqs, n_cycles=n_cycles, decim=decim,
mode=mode, mt_bandwidth=mt_bandwidth, kw_cwt=kw_cwt, kw_mt=kw_mt,
n_jobs=n_jobs)

# ----------------------------- CONN TRIALS ---------------------------
# give indication about computed trials
kw_para['mesg'] = mesg % f"{tr[0]}...{tr[-1]}"

# computes conn across trials
conn_tr = conn_f(w, kernel, foi_idx, x_s, x_t, kw_para)

# merge results
conn[tr, ...] = np.stack(conn_tr, axis=1)

# Call GC
del conn_tr, w

# _________________________________ OUTPUTS _______________________________
# configuration
cfg = dict(
Expand All @@ -264,7 +228,7 @@ def conn_spec(
)

# conversion
conn = xr.DataArray(conn, dims=('trials', 'roi', 'freqs', 'times'),
conn = xr.DataArray(conn.get_data(), dims=('trials', 'roi', 'freqs', 'times'),
name=metric, coords=(trials, roi_p, f_vec, times),
attrs=check_attrs({**attrs, **cfg}))
return conn
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ joblib
xarray
netCDF4
h5netcdf
mne-connectivity