From 05eb0a255034388c81584d7b4616adf87c7c67ed Mon Sep 17 00:00:00 2001 From: jadrew43 Date: Tue, 14 Jun 2022 16:42:18 -0700 Subject: [PATCH 01/17] implement example of state-space model for connectivity --- examples/mne_util.py | 289 +++++++++++++++++++++++++++ examples/state_space_connectivity.py | 59 ++++++ 2 files changed, 348 insertions(+) create mode 100644 examples/mne_util.py create mode 100644 examples/state_space_connectivity.py diff --git a/examples/mne_util.py b/examples/mne_util.py new file mode 100644 index 00000000..4dd8e298 --- /dev/null +++ b/examples/mne_util.py @@ -0,0 +1,289 @@ +""" MNE-Python utility functions for preprocessing data and constructing + matrices necessary for MEGLDS analysis """ + +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import + +import mne +import numpy as np +import os.path as op + +from mne.io.pick import pick_types +from mne.utils import logger +from mne import label_sign_flip + +from scipy.sparse import csc_matrix, csr_matrix, diags + + +class ROIToSourceMap(object): + """ class for computing ROI-to-source space mapping matrix """ + + def __init__(self, fwd, labels, label_flip=False): + + src = fwd['src'] + + roiidx = list() + vertidx = list() + + n_lhverts = len(src[0]['vertno']) + n_rhverts = len(src[1]['vertno']) + n_verts = n_lhverts + n_rhverts + offsets = {'lh': 0, 'rh': n_lhverts} + + hemis = {'lh': 0, 'rh': 1} + + # index vector of which ROI a source point belongs to + Q_J = np.zeros(n_verts, dtype=np.int64) + + data = [] + for li, lab in enumerate(labels): + + this_data = np.round(label_sign_flip(lab, src)) + if not label_flip: + this_data.fill(1.) + data.append(this_data) + if isinstance(lab, mne.Label): + comp_labs = [lab] + elif isinstance(lab, mne.BiHemiLabel): + comp_labs = [lab.lh, lab.rh] + + for clab in comp_labs: + hemi = clab.hemi + hi = 0 if hemi == 'lh' else 1 + + lverts = clab.get_vertices_used(vertices=src[hi]['vertno']) + + # gets the indices in the source space vertex array, not the huge + # array. + # use `src[hi]['vertno'][lverts]` to get surface vertex indices to + # plot. 
+ lverts = np.searchsorted(src[hi]['vertno'], lverts) + lverts += offsets[hemi] + vertidx.extend(lverts) + roiidx.extend(np.full(lverts.size, li, dtype=np.int64)) + + # add 1 b/c 0 corresponds to unassigned variance + Q_J[lverts] = li + 1 + + N = len(labels) + M = n_verts + + # construct sparse L matrix + data = np.concatenate(data) + vertidx = np.array(vertidx, int) + roiidx = np.array(roiidx, int) + assert data.shape == vertidx.shape == roiidx.shape + L = csc_matrix((data, (vertidx, roiidx)), shape=(M, N)) + + self.fwd = fwd + self.L = L + self.Q_J = Q_J + self.offsets = offsets + self.n_lhverts = n_lhverts + self.n_rhverts = n_rhverts + self.labels = labels + + return + + @property + def G(self): + return self.fwd['sol']['data'] + + @property + def L(self): + return self._L + + @L.setter + def L(self, val): + self._L = val + + @property + def Q_J(self): + return self._Q_J + + @Q_J.setter + def Q_J(self, val): + self._Q_J = val + + @property + def GL(self): + from util import Carray + return Carray(csr_matrix.dot(self.L.T, self.G.T).T) + + def get_label_vinds(self, label): + li = self.labels.index(label) + if isinstance(label, mne.Label): + label_vert_idx = self.L[:, li].nonzero()[0] + label_vert_idx -= self.offsets[label.hemi] + return label_vert_idx + elif isinstance(label, mne.BiHemiLabel): + # these labels store both hemispheres so subtract the rh offset + # from that part of the vertex array + lh_label_vert_idx = self.L[:self.n_lhverts, li].nonzero()[0] + rh_label_vert_idx = self.L[self.n_lhverts:, li].nonzero()[0] + rh_label_vert_idx[self.n_lhverts:] -= self.offsets['rh'] + return [lh_label_vert_idx, rh_label_vert_idx] + + def get_label_verts(self, label, src): + # if you're thinking of using this to plot, why not just use + # brain.add_label from pysurfer? 
+        if isinstance(label, mne.Label):
+            hi = 0 if label.hemi == 'lh' else 1
+            label_vert_idx = self.get_label_vinds(label)
+            varray = src[hi]['vertno'][label_vert_idx]
+        elif isinstance(label, mne.BiHemiLabel):
+            lh_label_vert_idx, rh_label_vert_idx = self.get_label_vinds(label)
+            varray = [src[0]['vertno'][lh_label_vert_idx],
+                      src[1]['vertno'][rh_label_vert_idx]]
+        return varray
+
+    def get_hemi_idx(self, label):
+        if isinstance(label, mne.Label):
+            return 0 if label.hemi == 'lh' else 1
+        elif isinstance(label, mne.BiHemiLabel):
+            hemis = [None] * 2
+            for i, lab in enumerate([label.lh, label.rh]):
+                hemis[i] = 0 if lab.hemi == 'lh' else 1
+            return hemis
+
+
+def morph_labels(labels, subject_to, subjects_dir=None):
+    """ morph labels from fsaverage to specified subject """
+
+    if subjects_dir is None:
+        subjects_dir = mne.utils.get_subjects_dir()
+
+    if isinstance(labels, mne.Label):
+        labels = [labels]
+
+    labels_morphed = list()
+    for lab in labels:
+        if isinstance(lab, mne.Label):
+            labels_morphed.append(lab.copy())
+        elif isinstance(lab, mne.BiHemiLabel):
+            labels_morphed.append(lab.lh.copy() + lab.rh.copy())
+
+    for i, l in enumerate(labels_morphed):
+        if l.subject == subject_to:
+            continue
+        elif l.subject == 'unknown':
+            print("unknown subject for label %s" % l.name,
+                  "assuming it is 'fsaverage' and morphing")
+            l.subject = 'fsaverage'
+
+        if isinstance(l, mne.Label):
+            l.values.fill(1.0)
+            labels_morphed[i] = l.morph(subject_to=subject_to,
+                                        subjects_dir=subjects_dir)
+        elif isinstance(l, mne.BiHemiLabel):
+            l.lh.values.fill(1.0)
+            l.rh.values.fill(1.0)
+            labels_morphed[i].lh = l.lh.morph(subject_to=subject_to,
+                                              subjects_dir=subjects_dir)
+            labels_morphed[i].rh = l.rh.morph(subject_to=subject_to,
+                                              subjects_dir=subjects_dir)
+
+    # make sure there are no duplicate labels
+    labels_morphed = sorted(list(set(labels_morphed)), key=lambda x: x.name)
+
+    return labels_morphed
+
+
+def apply_projs(epochs, fwd, cov):
+    """ apply projection operators to fwd and cov """
+    proj, _ = mne.io.proj.setup_proj(epochs.info, activate=False)
+    G = fwd['sol']['data']
+    fwd['sol']['data'] = np.dot(proj, G)
+
+    Q = cov.data
+    if not np.allclose(np.dot(proj, Q), Q):
+        Q = np.dot(proj, np.dot(Q, proj.T))
+        cov.data = Q
+
+    return fwd, cov
+
+
+def scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., mag_scale=1.,
+                      grad_scale=1.):
+    """ apply per-channel-type scaling to epochs, forward, and covariance """
+    # from util import Carray ##skip import
+    Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C')
+    Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C')
+    Carray = Carray64
+
+    # get indices for each channel type
+    ch_names = cov['names']  # same as self.fwd['info']['ch_names']
+    sel_eeg = pick_types(fwd['info'], meg=False, eeg=True, ref_meg=False)
+    sel_mag = pick_types(fwd['info'], meg='mag', eeg=False, ref_meg=False)
+    sel_grad = pick_types(fwd['info'], meg='grad', eeg=False, ref_meg=False)
+    idx_eeg = [ch_names.index(ch_names[c]) for c in sel_eeg]
+    idx_mag = [ch_names.index(ch_names[c]) for c in sel_mag]
+    idx_grad = [ch_names.index(ch_names[c]) for c in sel_grad]
+
+    # retrieve forward and sensor covariance
+    G = fwd['sol']['data'].copy()
+    Q = cov.data.copy()
+
+    # scale forward matrix
+    G[idx_eeg,:] *= eeg_scale
+    G[idx_mag,:] *= mag_scale
+    G[idx_grad,:] *= grad_scale
+
+    # construct GL matrix
+    GL = Carray(csr_matrix.dot(roi_to_src.L.T, G.T).T)
+
+    # scale sensor covariance
+    Q[np.ix_(idx_eeg, idx_eeg)] *= eeg_scale**2
+    Q[np.ix_(idx_mag, idx_mag)] *= mag_scale**2
+    Q[np.ix_(idx_grad, idx_grad)] *= grad_scale**2
+
+    # scale epochs
+    info = epochs.info.copy()
+    data = epochs.get_data().copy()
+
+    data[:,idx_eeg,:] *= eeg_scale
+    data[:,idx_mag,:] *= mag_scale
+    data[:,idx_grad,:] *= grad_scale
+
+    epochs = mne.EpochsArray(data, info)
+
+    return G, GL, Q, epochs
+
+
+def combine_medial_labels(labels, subject='fsaverage', surf='white',
+                          dist_limit=0.02):
+    """ combine each hemi pair of labels on medial wall into single label """
+    subjects_dir = mne.get_config('SUBJECTS_DIR')
+    rrs = dict((hemi, mne.read_surface(op.join(subjects_dir, subject, 'surf',
+                                               '%s.%s' % (hemi, surf)))[0] / 1000.)
+               for hemi in ('lh', 'rh'))
+    use_labels = list()
+    used = np.zeros(len(labels), bool)
+
+    logger.info('Matching medial regions for %s labels on %s %s, d=%0.1f mm'
+                % (len(labels), subject, surf, 1000 * dist_limit))
+
+    for li1, l1 in enumerate(labels):
+        if used[li1]:
+            continue
+        used[li1] = True
+        use_label = l1.copy()
+        rr1 = rrs[l1.hemi][l1.vertices]
+        for li2 in np.where(~used)[0]:
+            l2 = labels[li2]
+            same_name = (l2.name.replace(l2.hemi, '') ==
+                         l1.name.replace(l1.hemi, ''))
+            if l2.hemi != l1.hemi and same_name:
+                rr2 = rrs[l2.hemi][l2.vertices]
+                mean_min = np.mean(mne.surface._compute_nearest(
+                    rr1, rr2, return_dists=True)[1])
+                if mean_min <= dist_limit:
+                    use_label += l2
+                    used[li2] = True
+                    logger.info('  Matched: ' + l1.name)
+        use_labels.append(use_label)
+
+    logger.info('Total %d labels' % (len(use_labels),))
+
+    return use_labels
diff --git a/examples/state_space_connectivity.py b/examples/state_space_connectivity.py
new file mode 100644
index 00000000..65facd56
--- /dev/null
+++ b/examples/state_space_connectivity.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Jun 14 13:15:43 2022
+
+@author: jordandrew
+
+For 'mne-connectivity/examples/' to show usage of MEGLDS
+Use MNE-sample-data for auditory/left
+"""
+import mne
+import numpy as np
+import matplotlib.pyplot as plt
+from mne_util import ROIToSourceMap, scale_sensor_data
+
+# data_path = mne.datasets.sample.data_path()
+data_path = '/Users/jordandrew/Documents/MEG/mne_data/MNE-sample-data'
+sample_folder = '/MEG/sample'
+raw_fname = data_path + sample_folder + '/sample_audvis_raw.fif' #how many subjects?
+subjects_dir = data_path + '/subjects'
+
+raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60)
+events = mne.find_events(raw, stim_channel='STI 014')
+""" OR
+raw_events_fname = data_path + sample_folder + '/sample_audvis_raw-eve.fif'
+events = mne.read_events(raw_events_fname)
+"""
+
+
+## compute forward solution
+sphere = mne.make_sphere_model('auto', 'auto', raw.info)
+src = mne.setup_volume_source_space(sphere=sphere, exclude=30., pos=15.)
+fwd = mne.make_forward_solution(raw.info, trans=None, src=src, bem=sphere)
+fwd['src'].append( fwd['src'][0]) #fwd['src'] needs lh and rh; duplicated here
+
+
+#event_id = 1
+event_dict = {'auditory/left': 1, 'auditory/right': 2, 'visual/left': 3,
+              'visual/right': 4, 'face': 5, 'buttonpress': 32}
+epochs = mne.Epochs(raw, events, tmin=-0.3, tmax=0.7, event_id=event_dict,
+                    preload=True)
+# del raw
+
+## compute covariance
+noise_cov = mne.compute_covariance(epochs, tmax=0) #tmax=0 assuming no activity from tmin to tmax?
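+# (editor's note: tmax=0 keeps only the pre-stimulus baseline samples, the
+# usual stand-in for sensor noise in MNE-Python, so the covariance above is
+# a noise estimate rather than an estimate of evoked activity)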
+labels = mne.read_labels_from_annot('sample', subjects_dir=subjects_dir) +roi_to_src = ROIToSourceMap(fwd, labels) # compute ROI-to-source map +scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1} +fwd_sr_sn, fwd_roi_sn, snsr_cov, epochs = \ + scale_sensor_data(epochs, fwd, noise_cov, roi_to_src, **scales) + + + + + + + +# model = MEGLDS(fwd, labels, noise_cov) # only needs the forward, labels, and noise_cov to be initialized +# model.fit(epochs) # now only needs epochs to fit \ No newline at end of file From fb40a1da57f81388179459c7e81aaf0c7c6a8278 Mon Sep 17 00:00:00 2001 From: jadrew43 Date: Thu, 16 Jun 2022 17:49:44 -0700 Subject: [PATCH 02/17] mne_util imported from MEGLDS repo; some incompatibility with sample-data --- examples/mne_util.py | 13 +++++++++---- examples/state_space_connectivity.py | 16 ++++++++++------ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/examples/mne_util.py b/examples/mne_util.py index 4dd8e298..04b1c796 100644 --- a/examples/mne_util.py +++ b/examples/mne_util.py @@ -207,7 +207,7 @@ def apply_projs(epochs, fwd, cov): def scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., mag_scale=1., grad_scale=1.): """ apply per-channel-type scaling to epochs, forward, and covariance """ - # from util import Carray ##skip import + # from util import Carray ##skip import just pasted; util also from MEGLDS repo Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C') Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C') Carray = Carray64 @@ -217,9 +217,14 @@ def scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., mag_scale=1., sel_eeg = pick_types(fwd['info'], meg=False, eeg=True, ref_meg=False) sel_mag = pick_types(fwd['info'], meg='mag', eeg=False, ref_meg=False) sel_grad = pick_types(fwd['info'], meg='grad', eeg=False, ref_meg=False) - idx_eeg = [ch_names.index(ch_names[c]) for c in sel_eeg] - idx_mag = [ch_names.index(ch_names[c]) for c in sel_mag] - idx_grad = [ch_names.index(ch_names[c]) for c in sel_grad] + #2 channels are removed so idx != ch_name + #can we do idx = c for c in sel?? + #idx_eeg = [ch_names.index(ch_names[c]) for c in sel_eeg] + #idx_mag = [ch_names.index(ch_names[c]) for c in sel_mag] + #idx_grad = [ch_names.index(ch_names[c]) for c in sel_grad] + idx_eeg = [c for c in sel_eeg] + idx_mag = [c for c in sel_mag] + idx_grad = [c for c in sel_grad] # retrieve forward and sensor covariance G = fwd['sol']['data'].copy() diff --git a/examples/state_space_connectivity.py b/examples/state_space_connectivity.py index 65facd56..da294887 100644 --- a/examples/state_space_connectivity.py +++ b/examples/state_space_connectivity.py @@ -8,17 +8,21 @@ For 'mne-connectivity/examples/' to show usage of MEGLDS Use MNE-sample-data for auditory/left """ + +## import necessary libraries import mne import numpy as np import matplotlib.pyplot as plt -from mne_util import ROIToSourceMap, scale_sensor_data +from mne_util import ROIToSourceMap, scale_sensor_data #mne_util is from MEGLDS repo -# data_path = mne.datasets.sample.data_path() +## define paths to sample data data_path = '/Users/jordandrew/Documents/MEG/mne_data/MNE-sample-data' +# data_path = mne.datasets.sample.data_path() sample_folder = '/MEG/sample' raw_fname = data_path + sample_folder + '/sample_audvis_raw.fif' #how many subjects? 
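+# (editor's note: the MNE sample dataset ships a single subject, 'sample')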
 subjects_dir = data_path + '/subjects'
 
+## import raw data and find events
 raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60)
 events = mne.find_events(raw, stim_channel='STI 014')
 """ OR
@@ -26,15 +30,15 @@ events = mne.read_events(raw_events_fname)
 """
 
-
 ## compute forward solution
 sphere = mne.make_sphere_model('auto', 'auto', raw.info)
 src = mne.setup_volume_source_space(sphere=sphere, exclude=30., pos=15.)
 fwd = mne.make_forward_solution(raw.info, trans=None, src=src, bem=sphere)
 fwd['src'].append( fwd['src'][0]) #fwd['src'] needs lh and rh; duplicated here
+#is there a reason the sample data only has 1 hemisphere of data?
 
-#event_id = 1
+## define epochs using event_dict
+# event_id = 1
 event_dict = {'auditory/left': 1, 'auditory/right': 2, 'visual/left': 3,
               'visual/right': 4, 'face': 5, 'buttonpress': 32}
 epochs = mne.Epochs(raw, events, tmin=-0.3, tmax=0.7, event_id=event_dict,
                     preload=True)
 # del raw
 
 ## compute covariance
-noise_cov = mne.compute_covariance(epochs, tmax=0) #tmax=0 assuming no activity from tmin to tmax?
+noise_cov = mne.compute_covariance(epochs, tmax=0)
 labels = mne.read_labels_from_annot('sample', subjects_dir=subjects_dir)
 roi_to_src = ROIToSourceMap(fwd, labels) # compute ROI-to-source map
 scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1}
 fwd_sr_sn, fwd_roi_sn, snsr_cov, epochs = \

From c65356487b42da61c5bda03a85825d51083c1f9d Mon Sep 17 00:00:00 2001
From: jadrew43
Date: Thu, 23 Jun 2022 15:15:55 -0700
Subject: [PATCH 03/17] added Yang et al 2016 to bib

---
 doc/references.bib | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/doc/references.bib b/doc/references.bib
index 24b48327..aa79d595 100644
--- a/doc/references.bib
+++ b/doc/references.bib
@@ -238,4 +238,15 @@ @article{StamEtAl2012
   year={2012},
   month={Sep},
   pages={1415–1428}
+}
+
+@inproceedings{yang_state-space_2016,
+  title = {A state-space model of cross-region dynamic connectivity in {MEG}/{EEG}},
+  volume = {29},
+  url = {https://proceedings.neurips.cc/paper/2016/hash/9f396fe44e7c05c16873b05ec425cbad-Abstract.html},
+  urldate = {2021-11-21},
+  booktitle = {Advances in {Neural} {Information} {Processing} {Systems}},
+  publisher = {Curran Associates, Inc.},
+  author = {Yang, Ying and Aminoff, Elissa and Tarr, Michael and Kass, Robert E.},
+  year = {2016}
 }
\ No newline at end of file

From 60a04adfc15a4adbb877ba319a6978571426c50f Mon Sep 17 00:00:00 2001
From: jadrew43
Date: Fri, 24 Jun 2022 13:18:51 -0700
Subject: [PATCH 04/17] changed var names (not in models.py), added folder for
 LDS execution, moved functions into LDS()

---
 examples/megssm/__init__.py          |   0
 examples/megssm/message_passing.py   | 732 ++++++++++++++++++++++
 examples/megssm/mne_util.py          | 217 +++++++
 examples/megssm/models.py            | 867 +++++++++++++++++++++++++++
 examples/megssm/numpy_numthreads.py  |  91 +++
 examples/megssm/util.py              | 117 ++++
 examples/state_space_connectivity.py |  70 ++-
 7 files changed, 2069 insertions(+), 25 deletions(-)
 create mode 100755 examples/megssm/__init__.py
 create mode 100755 examples/megssm/message_passing.py
 create mode 100644 examples/megssm/mne_util.py
 create mode 100755 examples/megssm/models.py
 create mode 100755 examples/megssm/numpy_numthreads.py
 create mode 100755 examples/megssm/util.py

diff --git a/examples/megssm/__init__.py b/examples/megssm/__init__.py
new file mode 100755
index 00000000..e69de29b
diff --git a/examples/megssm/message_passing.py b/examples/megssm/message_passing.py
new file mode 100755
index 00000000..21f08a76
--- /dev/null
+++ b/examples/megssm/message_passing.py
@@ -0,0 +1,732 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import autograd.numpy as np
+from autograd.scipy.linalg import block_diag
+
+from .util import T_, sym, dot3, _ensure_ndim, component_matrix, hs
+
+try:
+    from autograd_linalg import solve_triangular
+except ImportError:
+    raise RuntimeError("must install `autograd_linalg` package")
+
+# einsum2 is a parallel version of einsum that works for two arguments
+try:
+    from einsum2 import einsum2
+except ImportError:
+    # rename standard numpy function if we don't have einsum2
+    print("=> WARNING: using standard numpy.einsum,",
+          "consider installing einsum2 package")
+    from numpy import einsum as einsum2
+
+
+def kalman_filter(Y, A, C, Q, R, mu0, Q0, store_St=True, sum_logliks=True):
+    """ Kalman filter that broadcasts over the first dimension.
+        Handles multiple lag dependence using component form.
+
+        Note: This function doesn't handle control inputs (yet).
+
+        Y : ndarray, shape=(N, T, D)
+          Observations
+
+        A : ndarray, shape=(T, D*nlag, D*nlag)
+          Time-varying dynamics matrices
+
+        C : ndarray, shape=(p, D)
+          Observation matrix
+
+        mu0: ndarray, shape=(D,)
+          mean of initial state variable
+
+        Q0 : ndarray, shape=(D, D)
+          Covariance of initial state variable
+
+        Q : ndarray, shape=(D, D)
+          Covariance of latent states
+
+        R : ndarray, shape=(D, D)
+          Covariance of observations
+    """
+
+    N = Y.shape[0]
+    T, D, Dnlags = A.shape
+    nlags = Dnlags // D
+    AA = np.stack([component_matrix(At, nlags) for At in A], axis=0)
+
+    p = C.shape[0]
+    CC = hs([C, np.zeros((p, D*(nlags-1)))])
+
+    QQ = np.zeros((T, Dnlags, Dnlags))
+    QQ[:,:D,:D] = Q
+
+    QQ0 = block_diag(*[Q0 for _ in range(nlags)])
+
+    mu_predict = np.stack([np.tile(mu0, nlags) for _ in range(N)], axis=0)
+    sigma_predict = np.stack([QQ0 for _ in range(N)], axis=0)
+
+    St = np.empty((N, T, p, p)) if store_St else None
+
+    mus_filt = np.zeros((N, T, Dnlags))
+    sigmas_filt = np.zeros((N, T, Dnlags, Dnlags))
+
+    ll = np.zeros(T)
+
+    for t in range(T):
+
+        # condition
+        # dot3(CC, sigma_predict, CC.T) + R
+        tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict)
+        sigma_pred = np.dot(tmp1, CC.T) + R
+        sigma_pred = sym(sigma_pred)
+
+        if St is not None:
+            St[...,t,:,:] = sigma_pred
+
+        res = Y[...,t,:] - np.dot(mu_predict, CC.T)
+
+        L = np.linalg.cholesky(sigma_pred)
+        v = solve_triangular(L, res, lower=True)
+
+        # log-likelihood over all trials
+        ll[t] = -0.5*(2.*np.sum(np.log(np.diagonal(L, axis1=1, axis2=2)))
+                      + np.sum(v*v)
+                      + N*p*np.log(2.*np.pi))
+
+        mus_filt[...,t,:] = mu_predict + einsum2('nki,nk->ni', tmp1,
+                                                 solve_triangular(L, v, 'T', lower=True))
+
+        tmp2 = solve_triangular(L, tmp1, lower=True)
+        sigmas_filt[...,t,:,:] = sym(sigma_predict - einsum2('nki,nkj->nij', tmp2, tmp2))
+
+        # prediction
+        mu_predict = einsum2('ik,nk->ni', AA[t], mus_filt[...,t,:])
+
+        sigma_predict = einsum2('ik,nkl->nil', AA[t], sigmas_filt[...,t,:,:])
+        sigma_predict = sym(einsum2('nil,jl->nij', sigma_predict, AA[t]) + QQ[t])
+
+    if sum_logliks:
+        ll = np.sum(ll)
+    return ll, mus_filt, sigmas_filt, St
+
+
+def rts_smooth(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False,
+               store_St=True):
+    """ RTS smoother that broadcasts over the first dimension.
+        Handles multiple lag dependence using component form.
+
+        Note: This function doesn't handle control inputs (yet).
+
+ + Y : ndarray, shape=(N, T, D) + Observations + + A : ndarray, shape=(T, D*nlag, D*nlag) + Time-varying dynamics matrices + + C : ndarray, shape=(p, D) + Observation matrix + + mu0: ndarray, shape=(D,) + mean of initial state variable + + Q0 : ndarray, shape=(D, D) + Covariance of initial state variable + + Q : ndarray, shape=(D, D) + Covariance of latent states + + R : ndarray, shape=(D, D) + Covariance of observations + """ + + N, T, _ = Y.shape + _, D, Dnlags = A.shape + nlags = Dnlags // D + AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) + + p = C.shape[0] + CC = hs([C, np.zeros((p, D*(nlags-1)))]) + + QQ = np.zeros((T, Dnlags, Dnlags)) + QQ[:,:D,:D] = Q + + QQ0 = block_diag(*[Q0 for _ in range(nlags)]) + + mu_predict = np.empty((N, T+1, Dnlags)) + sigma_predict = np.empty((N, T+1, Dnlags, Dnlags)) + + mus_smooth = np.empty((N, T, Dnlags)) + sigmas_smooth = np.empty((N, T, Dnlags, Dnlags)) + + St = np.empty((N, T, p, p)) if store_St else None + + if compute_lag1_cov: + sigmas_smooth_tnt = np.empty((N, T-1, Dnlags, Dnlags)) + else: + sigmas_smooth_tnt = None + + ll = 0. + mu_predict[:,0,:] = np.tile(mu0, nlags) + sigma_predict[:,0,:,:] = QQ0.copy() + + for t in range(T): + + # condition + # sigma_x = dot3(C, sigma_predict, C.T) + R + tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) + sigma_x = einsum2('nik,jk->nij', tmp1, CC) + R + sigma_x = sym(sigma_x) + + if St is not None: + St[...,t,:,:] = sigma_x + + L = np.linalg.cholesky(sigma_x) + # res[n] = Y[n,t,:] = np.dot(C, mu_predict[n,t,:]) + res = Y[...,t,:] - einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) + v = solve_triangular(L, res, lower=True) + + # log-likelihood over all trials + ll += -0.5*(2.*np.sum(np.log(np.diagonal(L, axis1=1, axis2=2))) + + np.sum(v*v) + + N*p*np.log(2.*np.pi)) + + mus_smooth[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', + tmp1, + solve_triangular(L, v, trans='T', lower=True)) + + # tmp2 = L^{-1}*C*sigma_predict + tmp2 = solve_triangular(L, tmp1, lower=True) + sigmas_smooth[:,t,:,:] = sym(sigma_predict[:,t,:,:] - einsum2('nki,nkj->nij', tmp2, tmp2)) + + # prediction + #mu_predict = np.dot(A[t], mus_smooth[t]) + mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_smooth[:,t,:]) + + #sigma_predict = dot3(A[t], sigmas_smooth[t], A[t].T) + Q[t] + tmp = einsum2('ik,nkl->nil', AA[t], sigmas_smooth[:,t,:,:]) + sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) + + for t in range(T-2, -1, -1): + + # these names are stolen from mattjj and slinderman + #temp_nn = np.dot(A[t], sigmas_smooth[n,t,:,:]) + temp_nn = einsum2('ik,nkj->nij', AA[t], sigmas_smooth[:,t,:,:]) + + L = np.linalg.cholesky(sigma_predict[:,t+1,:,:]) + v = solve_triangular(L, temp_nn, lower=True) + # Look in Saarka for dfn of Gt_T + Gt_T = solve_triangular(L, v, trans='T', lower=True) + + # {mus,sigmas}_smooth[n,t] contain the filtered estimates so we're + # overwriting them on purpose + #mus_smooth[n,t,:] = mus_smooth[n,t,:] + np.dot(T_(Gt_T), mus_smooth[n,t+1,:] - mu_predict[t+1,:]) + mus_smooth[:,t,:] = mus_smooth[:,t,:] + einsum2('nki,nk->ni', Gt_T, mus_smooth[:,t+1,:] - mu_predict[:,t+1,:]) + + #sigmas_smooth[n,t,:,:] = sigmas_smooth[n,t,:,:] + dot3(T_(Gt_T), sigmas_smooth[n,t+1,:,:] - temp_nn, Gt_T) + tmp = einsum2('nki,nkj->nij', Gt_T, sigmas_smooth[:,t+1,:,:] - sigma_predict[:,t+1,:,:]) + tmp = einsum2('nik,nkj->nij', tmp, Gt_T) + sigmas_smooth[:,t,:,:] = sym(sigmas_smooth[:,t,:,:] + tmp) + + if compute_lag1_cov: + # This matrix is NOT symmetric, so don't symmetrize! 
+ #sigmas_smooth_tnt[n,t,:,:] = np.dot(sigmas_smooth[n,t+1,:,:], Gt_T) + sigmas_smooth_tnt[:,t,:,:] = einsum2('nik,nkj->nij', sigmas_smooth[:,t+1,:,:], Gt_T) + + return ll, mus_smooth, sigmas_smooth, sigmas_smooth_tnt, St + + +def rts_smooth_fast(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False): + """ RTS smoother that broadcasts over the first dimension. + Handles multiple lag dependence using component form. + + Note: This function doesn't handle control inputs (yet). + + Y : ndarray, shape=(N, T, D) + Observations + + A : ndarray, shape=(T, D*nlag, D*nlag) + Time-varying dynamics matrices + + C : ndarray, shape=(p, D) + Observation matrix + + mu0: ndarray, shape=(D,) + mean of initial state variable + + Q0 : ndarray, shape=(D, D) + Covariance of initial state variable + + Q : ndarray, shape=(D, D) + Covariance of latent states + + R : ndarray, shape=(D, D) + Covariance of observations + """ + + N, T, _ = Y.shape + _, D, Dnlags = A.shape + nlags = Dnlags // D + AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) + + L_R = np.linalg.cholesky(R) + + p = C.shape[0] + CC = hs([C, np.zeros((p, D*(nlags-1)))]) + tmp = solve_triangular(L_R, CC, lower=True) + Rinv_CC = solve_triangular(L_R, tmp, trans='T', lower=True) + CCT_Rinv_CC = einsum2('ki,kj->ij', CC, Rinv_CC) + + # tile L_R across number of trials so solve_triangular + # can broadcast over trials properly + L_R = np.tile(L_R, (N, 1, 1)) + + QQ = np.zeros((T, Dnlags, Dnlags)) + QQ[:,:D,:D] = Q + + QQ0 = block_diag(*[Q0 for _ in range(nlags)]) + + mu_predict = np.empty((N, T+1, Dnlags)) + sigma_predict = np.empty((N, T+1, Dnlags, Dnlags)) + + mus_smooth = np.empty((N, T, Dnlags)) + sigmas_smooth = np.empty((N, T, Dnlags, Dnlags)) + + if compute_lag1_cov: + sigmas_smooth_tnt = np.empty((N, T-1, Dnlags, Dnlags)) + else: + sigmas_smooth_tnt = None + + ll = 0. 
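+    # (editor's note: the conditioning steps in the loop below never factor
+    # the p x p innovation covariance S = C P C^T + R directly; they apply
+    # the Woodbury identity
+    #   S^{-1} = R^{-1} - R^{-1} C (P^{-1} + C^T R^{-1} C)^{-1} C^T R^{-1},
+    # with R factored once above the loop, so the per-step Cholesky
+    # factorizations involve only state-sized (D*nlags) matrices)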
+ mu_predict[:,0,:] = np.tile(mu0, nlags) + sigma_predict[:,0,:,:] = QQ0.copy() + + I_tiled = np.tile(np.eye(Dnlags), (N, 1, 1)) + + for t in range(T): + + # condition + tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) + + res = Y[...,t,:] - einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) + + # Rinv * res + tmp2 = solve_triangular(L_R, res, lower=True) + tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) + + # C^T Rinv * res + tmp3 = einsum2('ki,nk->ni', Rinv_CC, res) + + # (Pinv + C^T Rinv C)_inv * tmp3 + L_P = np.linalg.cholesky(sigma_predict[:,t,:,:]) + tmp = solve_triangular(L_P, I_tiled, lower=True) + Pinv = solve_triangular(L_P, tmp, trans='T', lower=True) + tmp4 = sym(Pinv + CCT_Rinv_CC) + L_tmp4 = np.linalg.cholesky(tmp4) + tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) + tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) + + # Rinv C * tmp3 + tmp3 = einsum2('ik,nk->ni', Rinv_CC, tmp3) + + # add the two Woodbury * res terms together + tmp = tmp2 - tmp3 + + mus_smooth[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', tmp1, tmp) + + # Rinv * tmp1 + tmp2 = solve_triangular(L_R, tmp1, lower=True) + tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) + + # C^T Rinv * tmp1 + tmp3 = einsum2('ki,nkj->nij', Rinv_CC, tmp1) + + # (Pinv + C^T Rinv C)_inv * tmp3 + tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) + tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) + + # Rinv C * tmp3 + tmp3 = einsum2('ik,nkj->nij', Rinv_CC, tmp3) + + # add the two Woodbury * tmp1 terms together, left-multiply by tmp1 + tmp = einsum2('nki,nkj->nij', tmp1, tmp2 - tmp3) + + sigmas_smooth[:,t,:,:] = sym(sigma_predict[:,t,:,:] - tmp) + + # prediction + mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_smooth[:,t,:]) + + #sigma_predict = dot3(A[t], sigmas_smooth[t], A[t].T) + Q[t] + tmp = einsum2('ik,nkl->nil', AA[t], sigmas_smooth[:,t,:,:]) + sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) + + for t in range(T-2, -1, -1): + + # these names are stolen from mattjj and slinderman + #temp_nn = np.dot(A[t], sigmas_smooth[n,t,:,:]) + temp_nn = einsum2('ik,nkj->nij', AA[t], sigmas_smooth[:,t,:,:]) + + L = np.linalg.cholesky(sigma_predict[:,t+1,:,:]) + v = solve_triangular(L, temp_nn, lower=True) + # Look in Saarka for dfn of Gt_T + Gt_T = solve_triangular(L, v, trans='T', lower=True) + + # {mus,sigmas}_smooth[n,t] contain the filtered estimates so we're + # overwriting them on purpose + #mus_smooth[n,t,:] = mus_smooth[n,t,:] + np.dot(T_(Gt_T), mus_smooth[n,t+1,:] - mu_predict[t+1,:]) + mus_smooth[:,t,:] = mus_smooth[:,t,:] + einsum2('nki,nk->ni', Gt_T, mus_smooth[:,t+1,:] - mu_predict[:,t+1,:]) + + #sigmas_smooth[n,t,:,:] = sigmas_smooth[n,t,:,:] + dot3(T_(Gt_T), sigmas_smooth[n,t+1,:,:] - temp_nn, Gt_T) + tmp = einsum2('nki,nkj->nij', Gt_T, sigmas_smooth[:,t+1,:,:] - sigma_predict[:,t+1,:,:]) + tmp = einsum2('nik,nkj->nij', tmp, Gt_T) + sigmas_smooth[:,t,:,:] = sym(sigmas_smooth[:,t,:,:] + tmp) + + if compute_lag1_cov: + # This matrix is NOT symmetric, so don't symmetrize! 
+ #sigmas_smooth_tnt[n,t,:,:] = np.dot(sigmas_smooth[n,t+1,:,:], Gt_T) + sigmas_smooth_tnt[:,t,:,:] = einsum2('nik,nkj->nij', sigmas_smooth[:,t+1,:,:], Gt_T) + + return ll, mus_smooth, sigmas_smooth, sigmas_smooth_tnt + + + +def predict(Y, A, C, Q, R, mu0, Q0, pred_var=False): + if pred_var: + return _predict_mean_var(Y, A, C, Q, R, mu0, Q0) + else: + return _predict_mean(Y, A, C, Q, R, mu0, Q0) + + +def _predict_mean_var(Y, A, C, Q, R, mu0, Q0): + """ Model predictions for Y given model parameters. + + Handles multiple lag dependence using component form. + + Note: This function doesn't handle control inputs (yet). + + Y : ndarray, shape=(N, T, D) + Observations + + A : ndarray, shape=(T, D*nlag, D*nlag) + Time-varying dynamics matrices + + C : ndarray, shape=(p, D) + Observation matrix + + mu0: ndarray, shape=(D,) + mean of initial state variable + + Q0 : ndarray, shape=(D, D) + Covariance of initial state variable + + Q : ndarray, shape=(D, D) + Covariance of latent states + + R : ndarray, shape=(D, D) + Covariance of observations + """ + + N, T, _ = Y.shape + _, D, Dnlags = A.shape + nlags = Dnlags // D + AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) + + L_R = np.linalg.cholesky(R) + + p = C.shape[0] + CC = hs([C, np.zeros((p, D*(nlags-1)))]) + tmp = solve_triangular(L_R, CC, lower=True) + Rinv_CC = solve_triangular(L_R, tmp, trans='T', lower=True) + CCT_Rinv_CC = einsum2('ki,kj->ij', CC, Rinv_CC) + + # tile L_R across number of trials so solve_triangular + # can broadcast over trials properly + L_R = np.tile(L_R, (N, 1, 1)) + + QQ = np.zeros((T, Dnlags, Dnlags)) + QQ[:,:D,:D] = Q + + QQ0 = block_diag(*[Q0 for _ in range(nlags)]) + + mu_predict = np.empty((N, T+1, Dnlags)) + sigma_predict = np.empty((N, T+1, Dnlags, Dnlags)) + + mus_filt = np.empty((N, T, Dnlags)) + sigmas_filt = np.empty((N, T, Dnlags, Dnlags)) + + mu_predict[:,0,:] = np.tile(mu0, nlags) + sigma_predict[:,0,:,:] = QQ0.copy() + + I_tiled = np.tile(np.eye(Dnlags), (N, 1, 1)) + + Yhat = np.empty_like(Y) + St = np.empty((N, T, p, p)) + + for t in range(T): + + # condition + # sigma_x = dot3(C, sigma_predict, C.T) + R + tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) + sigma_x = einsum2('nik,jk->nij', tmp1, CC) + R + sigma_x = sym(sigma_x) + + St[...,t,:,:] = sigma_x + + L = np.linalg.cholesky(sigma_x) + Yhat[...,t,:] = einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) + res = Y[...,t,:] - Yhat[...,t,:] + + v = solve_triangular(L, res, lower=True) + + mus_filt[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', + tmp1, + solve_triangular(L, v, trans='T', lower=True)) + + # tmp2 = L^{-1}*C*sigma_predict + tmp2 = solve_triangular(L, tmp1, lower=True) + sigmas_filt[:,t,:,:] = sym(sigma_predict[:,t,:,:] - einsum2('nki,nkj->nij', tmp2, tmp2)) + + # prediction + #mu_predict = np.dot(A[t], mus_filt[t]) + mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_filt[:,t,:]) + + #sigma_predict = dot3(A[t], sigmas_filt[t], A[t].T) + Q[t] + tmp = einsum2('ik,nkl->nil', AA[t], sigmas_filt[:,t,:,:]) + sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) + + # just return the diagonal of the St matrices for marginal predictive + # variances + return Yhat, np.diagonal(St, axis1=-2, axis2=-1) + + +def _predict_mean(Y, A, C, Q, R, mu0, Q0): + """ Model predictions for Y given model parameters. + Handles multiple lag dependence using component form. + + Note: This function doesn't handle control inputs (yet). 
+ + Y : ndarray, shape=(N, T, D) + Observations + + A : ndarray, shape=(T, D*nlag, D*nlag) + Time-varying dynamics matrices + + C : ndarray, shape=(p, D) + Observation matrix + + mu0: ndarray, shape=(D,) + mean of initial state variable + + Q0 : ndarray, shape=(D, D) + Covariance of initial state variable + + Q : ndarray, shape=(D, D) + Covariance of latent states + + R : ndarray, shape=(D, D) + Covariance of observations + """ + + N, T, _ = Y.shape + _, D, Dnlags = A.shape + nlags = Dnlags // D + AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) + + L_R = np.linalg.cholesky(R) + + p = C.shape[0] + CC = hs([C, np.zeros((p, D*(nlags-1)))]) + tmp = solve_triangular(L_R, CC, lower=True) + Rinv_CC = solve_triangular(L_R, tmp, trans='T', lower=True) + CCT_Rinv_CC = einsum2('ki,kj->ij', CC, Rinv_CC) + + # tile L_R across number of trials so solve_triangular + # can broadcast over trials properly + L_R = np.tile(L_R, (N, 1, 1)) + + QQ = np.zeros((T, Dnlags, Dnlags)) + QQ[:,:D,:D] = Q + + QQ0 = block_diag(*[Q0 for _ in range(nlags)]) + + mu_predict = np.empty((N, T+1, Dnlags)) + sigma_predict = np.empty((N, T+1, Dnlags, Dnlags)) + + mus_filt = np.empty((N, T, Dnlags)) + sigmas_filt = np.empty((N, T, Dnlags, Dnlags)) + + mu_predict[:,0,:] = np.tile(mu0, nlags) + sigma_predict[:,0,:,:] = QQ0.copy() + + I_tiled = np.tile(np.eye(Dnlags), (N, 1, 1)) + + Yhat = np.empty_like(Y) + + for t in range(T): + + # condition + tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) + + Yhat[...,t,:] = einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) + res = Y[...,t,:] - Yhat[...,t,:] + + # Rinv * res + tmp2 = solve_triangular(L_R, res, lower=True) + tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) + + # C^T Rinv * res + tmp3 = einsum2('ki,nk->ni', Rinv_CC, res) + + # (Pinv + C^T Rinv C)_inv * tmp3 + L_P = np.linalg.cholesky(sigma_predict[:,t,:,:]) + tmp = solve_triangular(L_P, I_tiled, lower=True) + Pinv = solve_triangular(L_P, tmp, trans='T', lower=True) + tmp4 = sym(Pinv + CCT_Rinv_CC) + L_tmp4 = np.linalg.cholesky(tmp4) + tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) + tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) + + # Rinv C * tmp3 + tmp3 = einsum2('ik,nk->ni', Rinv_CC, tmp3) + + # add the two Woodbury * res terms together + tmp = tmp2 - tmp3 + + mus_filt[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', tmp1, tmp) + + # Rinv * tmp1 + tmp2 = solve_triangular(L_R, tmp1, lower=True) + tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) + + # C^T Rinv * tmp1 + tmp3 = einsum2('ki,nkj->nij', Rinv_CC, tmp1) + + # (Pinv + C^T Rinv C)_inv * tmp3 + tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) + tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) + + # Rinv C * tmp3 + tmp3 = einsum2('ik,nkj->nij', Rinv_CC, tmp3) + + # add the two Woodbury * tmp1 terms together, left-multiply by tmp1 + tmp = einsum2('nki,nkj->nij', tmp1, tmp2 - tmp3) + + sigmas_filt[:,t,:,:] = sym(sigma_predict[:,t,:,:] - tmp) + + # prediction + mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_filt[:,t,:]) + + #sigma_predict = dot3(A[t], sigmas_filt[t], A[t].T) + Q[t] + tmp = einsum2('ik,nkl->nil', AA[t], sigmas_filt[:,t,:,:]) + sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) + + return Yhat + + +def predict_step(mu_filt, sigma_filt, A, Q): + mu_predict = einsum2('ik,nk->ni', A, mu_filt) + tmp = einsum2('ik,nkl->nil', A, sigma_filt) + sigma_predict = sym(einsum2('nil,jl->nij', tmp, A) + Q) + + return mu_predict, sigma_predict + + +def condition(y, C, R, 
mu_predict, sigma_predict): + # dot3(C, sigma_predict, C.T) + R + tmp1 = einsum2('ik,nkj->nij', C, sigma_predict) + sigma_pred = einsum2('nik,jk->nij', tmp1, C) + R + sigma_pred = sym(sigma_pred) + + L = np.linalg.cholesky(sigma_pred) + # the transpose works b/c of how dot broadcasts + #y_hat = np.dot(mu_predict, C.T) + y_hat = einsum2('ik,nk->ni', C, mu_predict) + res = y - y_hat + v = solve_triangular(L, res, lower=True) + + mu_filt = mu_predict + einsum2('nki,nk->ni', tmp1, solve_triangular(L, v, trans='T', lower=True)) + + tmp2 = solve_triangular(L, tmp1, lower=True) + sigma_filt = sym(sigma_predict - einsum2('nki,nkj->nij', tmp2, tmp2)) + + return y_hat, mu_filt, sigma_filt + + +def logZ(Y, A, C, Q, R, mu0, Q0): + """ Log marginal likelihood using the Kalman filter. + + The algorithm broadcasts over the first dimension which are considered + to be independent realizations. + + Note: This function doesn't handle control inputs (yet). + + Y : ndarray, shape=(N, T, D) + Observations + + A : ndarray, shape=(T, D, D) + Time-varying dynamics matrices + + C : ndarray, shape=(p, D) + Observation matrix + + mu0: ndarray, shape=(D,) + mean of initial state variable + + Q0 : ndarray, shape=(D, D) + Covariance of initial state variable + + Q : ndarray, shape=(D, D) + Covariance of latent states + + R : ndarray, shape=(D, D) + Covariance of observations + """ + + N = Y.shape[0] + T, D, _ = A.shape + p = C.shape[0] + + mu_predict = np.stack([np.copy(mu0) for _ in range(N)], axis=0) + sigma_predict = np.stack([np.copy(Q0) for _ in range(N)], axis=0) + + ll = 0. + + for t in range(T): + + # condition + # sigma_x = dot3(C, sigma_predict, C.T) + R + tmp1 = einsum2('ik,nkj->nij', C, sigma_predict) + sigma_x = einsum2('nik,jk->nij', tmp1, C) + R + sigma_x = sym(sigma_x) + + # res[n] = Y[n,t,:] = np.dot(C, mu_predict[n]) + res = Y[...,t,:] - einsum2('ik,nk->ni', C, mu_predict) + + L = np.linalg.cholesky(sigma_x) + v = solve_triangular(L, res, lower=True) + + # log-likelihood over all trials + ll += -0.5*(2.*np.sum(np.log(np.diagonal(L, axis1=1, axis2=2))) + + np.sum(v*v) + + N*p*np.log(2.*np.pi)) + + mus_filt = mu_predict + einsum2('nki,nk->ni', + tmp1, + solve_triangular(L, v, trans='T', lower=True)) + + # tmp2 = L^{-1}*C*sigma_predict + tmp2 = solve_triangular(L, tmp1, lower=True) + sigmas_filt = sigma_predict - einsum2('nki,nkj->nij', tmp2, tmp2) + sigmas_filt = sym(sigmas_filt) + + # prediction + #mu_predict = np.dot(A[t], mus_filt[t]) + mu_predict = einsum2('ik,nk->ni', A[t], mus_filt) + + # originally this worked with time-varying Q, but now it's fixed + #sigma_predict = dot3(A[t], sigmas_filt[t], A[t].T) + Q[t] + sigma_predict = einsum2('ik,nkl->nil', A[t], sigmas_filt) + sigma_predict = einsum2('nil,jl->nij', sigma_predict, A[t]) + Q + sigma_predict = sym(sigma_predict) + + return np.sum(ll) diff --git a/examples/megssm/mne_util.py b/examples/megssm/mne_util.py new file mode 100644 index 00000000..4bbf63b3 --- /dev/null +++ b/examples/megssm/mne_util.py @@ -0,0 +1,217 @@ +""" MNE-Python utility functions for preprocessing data and constructing + matrices necessary for MEGLDS analysis """ + +import mne +import numpy as np +import os.path as op + +from mne.io.pick import pick_types +from mne.utils import logger +from mne import label_sign_flip + +from scipy.sparse import csc_matrix, csr_matrix, diags + + +class ROIToSourceMap(object): + """ class for computing ROI-to-source space mapping matrix + + Notes + ----- + The following variables defined here correspond to various matrices + defined in 
:footcite:`yang_state-space_2016`: + - fwd_src_snsr : G + - fwd_roi_snsr : C + - fwd_src_roi : L + - snsr_src_cov : R + - roi_cov : Q + - roi_cov_0 : Q0 """ + + def __init__(self, fwd, labels, label_flip=False): + + src = fwd['src'] + + roiidx = list() + vertidx = list() + + n_lhverts = len(src[0]['vertno']) + n_rhverts = len(src[1]['vertno']) + n_verts = n_lhverts + n_rhverts + offsets = {'lh': 0, 'rh': n_lhverts} + + hemis = {'lh': 0, 'rh': 1} + + # index vector of which ROI a source point belongs to + which_roi = np.zeros(n_verts, dtype=np.int64) + + data = [] + for li, lab in enumerate(labels): + + this_data = np.round(label_sign_flip(lab, src)) + if not label_flip: + this_data.fill(1.) + data.append(this_data) + if isinstance(lab, mne.Label): + comp_labs = [lab] + elif isinstance(lab, mne.BiHemiLabel): + comp_labs = [lab.lh, lab.rh] + + for clab in comp_labs: + hemi = clab.hemi + hi = 0 if hemi == 'lh' else 1 + + lverts = clab.get_vertices_used(vertices=src[hi]['vertno']) + + # gets the indices in the source space vertex array, not the huge + # array. + # use `src[hi]['vertno'][lverts]` to get surface vertex indices to + # plot. + lverts = np.searchsorted(src[hi]['vertno'], lverts) + lverts += offsets[hemi] + vertidx.extend(lverts) + roiidx.extend(np.full(lverts.size, li, dtype=np.int64)) + + # add 1 b/c 0 corresponds to unassigned variance + which_roi[lverts] = li + 1 + + N = len(labels) + M = n_verts + + # construct sparse fwd_src_roi matrix + data = np.concatenate(data) + vertidx = np.array(vertidx, int) + roiidx = np.array(roiidx, int) + assert data.shape == vertidx.shape == roiidx.shape + fwd_src_roi = csc_matrix((data, (vertidx, roiidx)), shape=(M, N)) + + self.fwd = fwd + self.fwd_src_roi = fwd_src_roi + self.which_roi = which_roi + self.offsets = offsets + self.n_lhverts = n_lhverts + self.n_rhverts = n_rhverts + self.labels = labels + + return + + @property + def fwd_src_sn(self): + return self.fwd['sol']['data'] + + @property + def fwd_src_roi(self): + return self._fwd_src_roi + + @fwd_src_roi.setter + def fwd_src_roi(self, val): + self._fwd_src_roi = val + + @property + def which_roi(self): + return self._which_roi + + @which_roi.setter + def which_roi(self, val): + self._which_roi = val + + @property + def fwd_roi_sn(self): + from util import Carray + return Carray(csr_matrix.dot(self.fwd_src_roi.T, self.fwd_src_sn.T).T) + + def get_label_vinds(self, label): + li = self.labels.index(label) + if isinstance(label, mne.Label): + label_vert_idx = self.fwd_src_roi[:, li].nonzero()[0] + label_vert_idx -= self.offsets[label.hemi] + return label_vert_idx + elif isinstance(label, mne.BiHemiLabel): + # these labels store both hemispheres so subtract the rh offset + # from that part of the vertex array + lh_label_vert_idx = self.fwd_src_roi[:self.n_lhverts, li].nonzero()[0] + rh_label_vert_idx = self.fwd_src_roi[self.n_lhverts:, li].nonzero()[0] + rh_label_vert_idx[self.n_lhverts:] -= self.offsets['rh'] + return [lh_label_vert_idx, rh_label_vert_idx] + + def get_label_verts(self, label, src): + # if you're thinking of using this to plot, why not just use + # brain.add_label from pysurfer? 
+ if isinstance(label, mne.Label): + hi = 0 if label.hemi == 'lh' else 1 + label_vert_idx = self.get_label_vinds(label) + varray = src[hi]['vertno'][label_vert_idx] + elif isinstance(label, mne.BiHemiLabel): + lh_label_vert_idx, rh_label_vert_idx = self.get_label_vinds(label) + varray = [src[0]['vertno'][lh_label_vert_idx], + src[1]['vertno'][rh_label_vert_idx]] + return varray + + def get_hemi_idx(self, label): + if isinstance(label, mne.Label): + return 0 if label.hemi == 'lh' else 1 + elif isinstance(label, mne.BiHemiLabel): + hemis = [None] * 2 + for i, lab in enumerate([label.lh, label.rh]): + hemis[i] = 0 if lab.hemi == 'lh' else 1 + return hemis + +def apply_projs(epochs, fwd, cov): + """ apply projection operators to fwd and cov """ + proj, _ = mne.io.proj.setup_proj(epochs.info, activate=False) + fwd_src_sn = fwd['sol']['data'] + fwd['sol']['data'] = np.dot(proj, fwd_src_sn) + + roi_cov = cov.data + if not np.allclose(np.dot(proj, roi_cov), roi_cov): + roi_cov = np.dot(proj, np.dot(roi_cov, proj.T)) + cov.data = roi_cov + + return fwd, cov + + +def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., mag_scale=1., + grad_scale=1.): + """ apply per-channel-type scaling to epochs, forward, and covariance """ + # from util import Carray ##skip import just pasted; util also from MEGLDS repo + Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C') + Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C') + Carray = Carray64 + + # get indices for each channel type + ch_names = cov['names'] # same as self.fwd['info']['ch_names'] + sel_eeg = pick_types(fwd['info'], meg=False, eeg=True, ref_meg=False) + sel_mag = pick_types(fwd['info'], meg='mag', eeg=False, ref_meg=False) + sel_grad = pick_types(fwd['info'], meg='grad', eeg=False, ref_meg=False) + idx_eeg = [ch_names.index(ch_names[c]) for c in sel_eeg] + idx_mag = [ch_names.index(ch_names[c]) for c in sel_mag] + idx_grad = [ch_names.index(ch_names[c]) for c in sel_grad] + + # retrieve forward and sensor covariance + fwd_src_snsr = fwd['sol']['data'].copy() + snsr_src_cov = cov.data.copy() + + # scale forward matrix + fwd_src_snsr[idx_eeg,:] *= eeg_scale + fwd_src_snsr[idx_mag,:] *= mag_scale + fwd_src_snsr[idx_grad,:] *= grad_scale + + # construct fwd_roi_sn matrix + fwd_roi_snsr = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, fwd_src_snsr.T).T) + + # scale sensor covariance + snsr_src_cov[np.ix_(idx_eeg, idx_eeg)] *= eeg_scale**2 + snsr_src_cov[np.ix_(idx_mag, idx_mag)] *= mag_scale**2 + snsr_src_cov[np.ix_(idx_grad, idx_grad)] *= grad_scale**2 + + # scale epochs + info = epochs.info.copy() + data = epochs.get_data().copy() + + data[:,idx_eeg,:] *= eeg_scale + data[:,idx_mag,:] *= mag_scale + data[:,idx_grad,:] *= grad_scale + + epochs = mne.EpochsArray(data, info) + + return fwd_src_snsr, fwd_roi_snsr, snsr_src_cov, epochs + + diff --git a/examples/megssm/models.py b/examples/megssm/models.py new file mode 100755 index 00000000..fbbaebbc --- /dev/null +++ b/examples/megssm/models.py @@ -0,0 +1,867 @@ +# from __future__ import division +# from __future__ import print_function +# from __future__ import absolute_import + +import sys + +import autograd.numpy as np +import scipy.optimize as spopt + +from autograd import grad #autograd --> jax +from autograd import value_and_grad as vgrad +from scipy.linalg import LinAlgError + +from .util import _ensure_ndim, rand_stable, rand_psd +from .util import linesearch, soft_thresh_At, block_thresh_At +from .util import relnormdiff +from .message_passing 
import kalman_filter, rts_smooth, rts_smooth_fast +from .message_passing import predict_step, condition +from .numpy_numthreads import numpy_num_threads + +from .mne_util import ROIToSourceMap, _scale_sensor_data + +try: + from autograd_linalg import solve_triangular +except ImportError: + raise RuntimeError("must install `autograd_linalg` package") + +# einsum2 is a parallel version of einsum that works for two arguments +try: + from einsum2 import einsum2 +except ImportError: + # rename standard numpy function if don't have einsum2 + print("=> WARNING: using standard numpy.einsum,", + "consider installing einsum2 package") + from autograd.numpy import einsum as einsum2 + +from datetime import datetime + + +# TODO: add documentation to all methods +class _MEGModel(object): + """ Base class for any model applied to MEG data that handles storing and + unpacking data from tuples. """ + + def __init__(self): + self._subjectdata = None + self._T = 0 + self._ntrials_all = 0 + self._nsubjects = 0 + + def set_data(self, subjectdata): + T_lst = [self.unpack_subject_data(e)[0].shape[1] for e in subjectdata] + assert len(list(set(T_lst))) == 1 + self._T = T_lst[0] + ntrials_lst = [self.unpack_subject_data(e)[0].shape[0] for e in \ + subjectdata] + self._ntrials_all = np.sum(ntrials_lst) + self._nsubjects = len(subjectdata) + self._subjectdata = subjectdata + + def unpack_all_subject_data(self): + if self._subjectdata is None: + raise ValueError("use set_data to add subject data") + return map(self.unpack_subject_data, self._subjectdata) + + @classmethod + def unpack_subject_data(cls, sdata): + obs, GL, G, Q_snsr, Q_J = sdata + Y = obs + w_s = 1. + if isinstance(obs, tuple): + if len(obs) == 2: + Y, w_s = obs + else: + raise ValueError("invalid format for subject data") + else: + Y = obs + w_s = 1. + + return Y, w_s, GL, G, Q_snsr, Q_J + + +# TODO: add documentation to all methods +# TODO: make some methods "private" (leading underscore) if necessary +class MEGLDS(_MEGModel): + """ State-space model for MEG data, as described in "A state-space model of + cross-region dynamic connectivity in MEG/EEG", Yang et al., NIPS 2016. + """ + + def __init__(self, D_roi, T, fwd, labels, noise_cov, A=None, Q=None, + mu0=None, Q0=None, log_sigsq_lst=None, lam0=0., lam1=0., + penalty='ridge', store_St=True): + + super().__init__() + + set_default = \ + lambda prm, val, deflt: \ + self.__setattr__(prm, val.copy() if val is not None else deflt) + + # initialize parameters + set_default("A", A, + np.stack([rand_stable(D_roi, maxew=0.7) for _ in range(T)], + axis=0)) + set_default("Q", Q, rand_psd(D_roi)) + set_default("mu0", mu0, np.zeros(D_roi)) + set_default("Q0", Q0, rand_psd(D_roi)) + set_default("log_sigsq_lst", log_sigsq_lst, + [np.log(np.random.gamma(2, 1, size=D_roi+1))]) + + self.lam0 = lam0 + self.lam1 = lam1 + + if penalty not in ('ridge', 'lasso', 'group-lasso'): + raise ValueError('penalty must be one of: ridge, lasso,' \ + + ' group-lasso') + self._penalty = penalty + + # initialize lists of smoothed estimates + self._mus_smooth_lst = None + self._sigmas_smooth_lst = None + self._sigmas_tnt_smooth_lst = None + self._loglik = None + self._store_St = bool(store_St) + + # initialize sufficient statistics + T, D, _ = self.A.shape + self._B0 = np.zeros((D, D)) + self._B1 = np.zeros((T-1, D, D)) + self._B3 = np.zeros((T-1, D, D)) + self._B2 = np.zeros((T-1, D, D)) + self._B4 = list() + + #will these pass to other functions? 
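+        # (editor's note: yes -- em() below reads self.fwd, self.labels and
+        # self.noise_cov back off the instance when it builds the
+        # ROI-to-source map and rescales the sensor data)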
+ self.labels = labels + self.fwd = fwd + self.noise_cov = noise_cov + + + + def set_data(self, subjectdata): + # add subject data, re-generate log_sigsq_lst if necessary + super().set_data(subjectdata) + if len(self.log_sigsq_lst) != self._nsubjects: + D_roi = self.log_sigsq_lst[0].shape[0] + self.log_sigsq_lst = [np.log(np.random.gamma(2, 1, size=D_roi)) + for _ in range(self._nsubjects)] + + # reset smoothed estimates and log-likelihood (no longer valid if + # new data was added) + self._mus_smooth_lst = None + self._sigmas_smooth_lst = None + self._sigmas_tnt_smooth_lst = None + self._loglik = None + self._B4 = [None] * self._nsubjects + + # TODO: figure out how to initialize smoothed parameters so this doesn't + # break, e.g. if em_objective is called before em for some reason + def em_objective(self): + + _, D, _ = self.A.shape + + L_Q0 = np.linalg.cholesky(self.Q0) + L_Q = np.linalg.cholesky(self.Q) + + L1 = 0. + L2 = 0. + L3 = 0. + + obj = 0. + for s, sdata in enumerate(self.unpack_all_subject_data()): + + Y, w_s, GL, G, Q_snsr, Q_J = sdata + + ntrials, T, _ = Y.shape + + sigsq_vals = np.exp(self.log_sigsq_lst[s]) + R = MEGLDS.R_(Q_snsr, G, sigsq_vals, Q_J) + L_R = np.linalg.cholesky(R) + + if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None + or self._sigmas_tnt_smooth_lst is None): + Qt = _ensure_ndim(self.Q, T, 3) + with numpy_num_threads(1): + _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ + rts_smooth_fast(Y, self.A, GL, Qt, R, self.mu0, + self.Q0, compute_lag1_cov=True) + + else: + mus_smooth = self._mus_smooth_lst[s] + sigmas_smooth = self._sigmas_smooth_lst[s] + sigmas_tnt_smooth = self._sigmas_tnt_smooth_lst[s] + + x_smooth_0_outer = einsum2('ri,rj->rij', mus_smooth[:,0,:D], + mus_smooth[:,0,:D]) + B0 = w_s*np.sum(sigmas_smooth[:,0,:D,:D] + x_smooth_0_outer, + axis=0) + + x_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,1:,:D], + mus_smooth[:,1:,:D]) + B1 = w_s*np.sum(sigmas_smooth[:,1:,:D,:D] + x_smooth_outer, axis=0) + + z_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,:-1,:], + mus_smooth[:,:-1,:]) + B3 = w_s*np.sum(sigmas_smooth[:,:-1,:,:] + z_smooth_outer, axis=0) + + mus_smooth_outer_l1 = einsum2('rti,rtj->rtij', + mus_smooth[:,1:,:D], + mus_smooth[:,:-1,:]) + B2 = w_s*np.sum(sigmas_tnt_smooth[:,:,:D,:] + mus_smooth_outer_l1, + axis=0) + + # obj += L1(Q0) + L_Q0_inv_B0 = solve_triangular(L_Q0, B0, lower=True) + L1 += (ntrials*2.*np.sum(np.log(np.diag(L_Q0))) + + np.trace(solve_triangular(L_Q0, L_Q0_inv_B0, lower=True, + trans='T'))) + + At = self.A[:-1] + AtB2T = einsum2('tik,tjk->tij', At, B2) + B2AtT = einsum2('tik,tjk->tij', B2, At) + tmp = einsum2('tik,tkl->til', At, B3) + AtB3AtT = einsum2('tik,tjk->tij', tmp, At) + + tmp = np.sum(B1 - AtB2T - B2AtT + AtB3AtT, axis=0) + + # obj += L2(Q, At) + L_Q_inv_tmp = solve_triangular(L_Q, tmp, lower=True) + L2 += (ntrials*(T-1)*2.*np.sum(np.log(np.diag(L_Q))) + + np.trace(solve_triangular(L_Q, L_Q_inv_tmp, lower=True, + trans='T'))) + + res = Y - einsum2('ik,ntk->nti', GL, mus_smooth[:,:,:D]) + CP_smooth = einsum2('ik,ntkj->ntij', GL, sigmas_smooth[:,:,:D,:D]) + + # TODO: np.sum does not parallelize over the accumulators, possible + # bottleneck. 
+ B4 = w_s*(np.sum(einsum2('nti,ntj->ntij', res, res), axis=(0,1)) + + np.sum(einsum2('ntik,jk->ntij', CP_smooth, GL), + axis=(0,1))) + self._B4[s] = B4 + + # obj += L3(sigsq_vals) + L_R_inv_B4 = solve_triangular(L_R, B4, lower=True) + L3 += (ntrials*T*2*np.sum(np.log(np.diag(L_R))) + + np.trace(solve_triangular(L_R, L_R_inv_B4, lower=True, + trans='T'))) + + obj = (L1 + L2 + L3) / self._ntrials_all + + # obj += penalty + if self.lam0 > 0.: + if self._penalty == 'ridge': + obj += self.lam0*np.sum(At**2) + elif self._penalty == 'lasso': + At_diag = np.diagonal(At, axis1=-2, axis2=-1) + sum_At_diag = np.sum(np.abs(At_diag)) + obj += self.lam0*(np.sum(np.abs(At)) - sum_At_diag) + elif self._penalty == 'group-lasso': + At_diag = np.diagonal(At, axis1=-2, axis2=-1) + norm_At_diag = np.sum(np.linalg.norm(At_diag, axis=0)) + norm_At = np.sum(np.linalg.norm(At, axis=0)) + obj += self.lam1*(norm_At - norm_At_diag) + if self.lam1 > 0.: + AtmAtm1_2 = (At[1:] - At[:-1])**2 + obj += self.lam1*np.sum(AtmAtm1_2) + + return obj + + def em(self, epochs, niter=100, tol=1e-6, A_Q_niter=100, A_Q_tol=1e-6, verbose=0, + update_A=True, update_Q=True, update_Q0=True, stationary_A=False, + diag_Q=False, update_sigsq=True, do_final_smoothing=True, + average_mus_smooth=True, Atrue=None, tau=0.1, c1=1e-4): + + fxn_start = datetime.now() + + + #compute roi to source map and scale the sensor data + fwd = self.fwd + labels = self.labels + noise_cov = self.noise_cov + roi_to_src = ROIToSourceMap(fwd, labels) # compute ROI-to-source map + scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1} + fwd_src_snsr, fwd_roi_snsr, snsr_src_cov, epochs = \ + _scale_sensor_data(epochs, fwd, noise_cov, roi_to_src, **scales) + + + + + T, D, _ = self.A.shape + + # make initial A stationary if stationary_A option specified + if stationary_A: + self.A[:] = np.mean(self.A, axis=0) + + # set parameters for (A, Q) optimization + self._A_Q_niter = A_Q_niter + self._A_Q_tol = A_Q_tol + + # make initial Q, Q0 diagonal if diag_Q specified + if diag_Q: + self.Q0 = np.diag(np.diag(self.Q0)) + self.Q = np.diag(np.diag(self.Q)) + + + # keeping track of objective value and best parameters + objvals = np.zeros(niter+1) + converged = False + best_objval = np.finfo('float').max + best_params = (self.A.copy(), self.Q.copy(), self.mu0.copy(), + self.Q0.copy(), [l.copy() for l in self.log_sigsq_lst]) + + # previous parameter values (for checking convergence) + At_prev = None + Q_prev = None + Q0_prev = None + log_sigsq_lst_prev = None + + if Atrue is not None: + import matplotlib.pyplot as plt + fig_A, ax_A = plt.subplots(D, D, sharex=True, sharey=True) + plt.ion() + + # calculate initial objective value, check for updated best iterate + # have to do e-step here to initialize suff stats for m_step + if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None + or self._sigmas_tnt_smooth_lst is None): + self.e_step(verbose=verbose-1) + + objval = self.em_objective() + objvals[0] = objval + + for it in range(1, niter+1): + + iter_start = datetime.now() + + if verbose > 0: + print("em: it %d / %d" % (it, niter)) + sys.stdout.flush() + sys.stderr.flush() + + # record values from previous M-step + At_prev = self.A[:-1].copy() + Q_prev = self.Q.copy() + Q0_prev = self.Q0.copy() + log_sigsq_lst_prev = np.array(self.log_sigsq_lst).copy() + + self.m_step(update_A=update_A, update_Q=update_Q, + update_Q0=update_Q0, stationary_A=stationary_A, + diag_Q=diag_Q, update_sigsq=update_sigsq, + tau=tau, c1=c1, verbose=verbose) + + if Atrue is not None: + for 
i in range(D): + for j in range(D): + ax_A[i, j].cla() + ax_A[i, j].plot(Atrue[:-1, i, j], color='green') + ax_A[i, j].plot(self.A[:-1, i, j], color='red', + alpha=0.7) + fig_A.tight_layout() + fig_A.canvas.draw() + plt.pause(1. / 60.) + + self.e_step(verbose=verbose-1) + + # calculate objective value, check for updated best iterate + objval = self.em_objective() + objvals[it] = objval + + if verbose > 0: + print(" objective: %.4e" % objval) + At = self.A[:-1] + maxAt = np.max(np.abs(np.triu(At, k=1) + np.tril(At, k=-1))) + print(" max |A_t|: %.4e" % (maxAt,)) + sys.stdout.flush() + sys.stderr.flush() + + if objval < best_objval: + best_objval = objval + best_params = (self.A.copy(), self.Q.copy(), self.mu0.copy(), + self.Q0.copy(), + [l.copy() for l in self.log_sigsq_lst]) + + # check for convergence + if it >= 1: + relnormdiff_At = relnormdiff(self.A[:-1], At_prev) + relnormdiff_Q = relnormdiff(self.Q, Q_prev) + relnormdiff_Q0 = relnormdiff(self.Q0, Q0_prev) + relnormdiff_log_sigsq_lst = \ + np.array( + [relnormdiff(self.log_sigsq_lst[s], + log_sigsq_lst_prev[s]) + for s in range(len(self.log_sigsq_lst))]) + params_converged = (relnormdiff_At <= tol) and \ + (relnormdiff_Q <= tol) and \ + (relnormdiff_Q0 <= tol) and \ + np.all(relnormdiff_log_sigsq_lst <= tol) + + relobjdiff = np.abs((objval - objvals[it-1]) / objval) + + if verbose > 0: + print(" relnormdiff_At: %.3e" % relnormdiff_At) + print(" relnormdiff_Q: %.3e" % relnormdiff_Q) + print(" relnormdiff_Q0: %.3e" % relnormdiff_Q0) + print(" relnormdiff_log_sigsq_lst:", + relnormdiff_log_sigsq_lst) + print(" relobjdiff: %.3e" % relobjdiff) + + objdiff = objval - objvals[it-1] + if objdiff > 0: + print(" \033[0;31mEM objective increased\033[0m") + + sys.stdout.flush() + sys.stderr.flush() + + if params_converged or relobjdiff <= tol: + if verbose > 0: + print("EM objective converged") + sys.stdout.flush() + sys.stderr.flush() + converged = True + objvals = objvals[:it+1] + break + + # retrieve best parameters and load into instance variables. + A, Q, mu0, Q0, log_sigsq_lst = best_params + self.A = A.copy() + self.Q = Q.copy() + self.mu0 = mu0.copy() + self.Q0 = Q0.copy() + self.log_sigsq_lst = [l.copy() for l in log_sigsq_lst] + + if verbose > 0: + print() + print("elapsed, iteration:", datetime.now() - iter_start) + print("=" * 34) + print() + + # perform final smoothing + mus_smooth_lst = None + St_lst = None + if do_final_smoothing: + if verbose >= 1: + print("performing final smoothing") + + mus_smooth_lst = list() + self._loglik = 0. 
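+            # a final pass of the RTS smoother under the best parameters found
+            # by EM; the smoothed ROI means (and, if requested, the diagonals
+            # of the St covariances) are what the downstream connectivity
+            # analysis consumes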
+ if self._store_St: + St_lst = list() + for s, sdata in enumerate(self.unpack_all_subject_data()): + Y, w_s, GL, G, Q_snsr, Q_J = sdata + sigsq_vals = np.exp(self.log_sigsq_lst[s]) + R = MEGLDS.R_(Q_snsr, G, sigsq_vals, Q_J) + Qt = _ensure_ndim(self.Q, self._T, 3) + with numpy_num_threads(1): + loglik_subject, mus_smooth, _, _, St = \ + rts_smooth(Y, self.A, GL, Qt, R, self.mu0, self.Q0, + compute_lag1_cov=False, + store_St=self._store_St) + # just save the mean of the smoothed trials + if average_mus_smooth: + mus_smooth_lst.append(np.mean(mus_smooth, axis=0)) + else: + mus_smooth_lst.append(mus_smooth) + self._loglik += loglik_subject + # just save the diagonals of St b/c that's what we need for + # connectivity + if self._store_St: + St_lst.append(np.diagonal(St, axis1=-2, axis2=-1)) + + if verbose > 0: + print() + print("elapsed, function:", datetime.now() - fxn_start) + print("=" * 34) + print() + + return objvals, converged, mus_smooth_lst, self._loglik, St_lst + + def e_step(self, verbose=0): + + T, D, _ = self.A.shape + + # reset accumulation arrays + self._B0[:] = 0. + self._B1[:] = 0. + self._B3[:] = 0. + self._B2[:] = 0. + + self._mus_smooth_lst = list() + self._sigmas_smooth_lst = list() + self._sigmas_tnt_smooth_lst = list() + + if verbose > 0: + print(" e-step") + print(" subject", end="") + + for s, sdata in enumerate(self.unpack_all_subject_data()): + + if verbose > 0: + print(" %d" % (s+1,), end="") + sys.stdout.flush() + sys.stderr.flush() + + Y, w_s, GL, G, Q_snsr, Q_J = sdata + + sigsq_vals = np.exp(self.log_sigsq_lst[s]) + R = MEGLDS.R_(Q_snsr, G, sigsq_vals, Q_J) + L_R = np.linalg.cholesky(R) + Qt = _ensure_ndim(self.Q, self._T, 3) + + with numpy_num_threads(1): + _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ + rts_smooth_fast(Y, self.A, GL, Qt, R, self.mu0, + self.Q0, compute_lag1_cov=True) + + self._mus_smooth_lst.append(mus_smooth) + self._sigmas_smooth_lst.append(sigmas_smooth) + self._sigmas_tnt_smooth_lst.append(sigmas_tnt_smooth) + + x_smooth_0_outer = einsum2('ri,rj->rij', mus_smooth[:,0,:D], + mus_smooth[:,0,:D]) + self._B0 += w_s*np.sum(sigmas_smooth[:,0,:D,:D] + x_smooth_0_outer, + axis=0) + + x_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,1:,:D], + mus_smooth[:,1:,:D]) + self._B1 += w_s*np.sum(sigmas_smooth[:,1:,:D,:D] + x_smooth_outer, + axis=0) + + z_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,:-1,:], + mus_smooth[:,:-1,:]) + self._B3 += w_s*np.sum(sigmas_smooth[:,:-1,:,:] + z_smooth_outer, + axis=0) + + mus_smooth_outer_l1 = einsum2('rti,rtj->rtij', + mus_smooth[:,1:,:D], + mus_smooth[:,:-1,:]) + self._B2 += w_s*np.sum(sigmas_tnt_smooth[:,:,:D,:] + + mus_smooth_outer_l1, axis=0) + + if verbose > 0: + print("\n done") + + def m_step(self, update_A=True, update_Q=True, update_Q0=True, + stationary_A=False, diag_Q=False, update_sigsq=True, tau=0.1, c1=1e-4, + verbose=0): + self._loglik = None + if verbose > 0: + print(" m-step") + if update_Q0: + self.Q0 = (1. 
/ self._ntrials_all) * self._B0 + if diag_Q: + self.Q0 = np.diag(np.diag(self.Q0)) + self.update_A_and_Q(update_A=update_A, update_Q=update_Q, + stationary_A=stationary_A, diag_Q=diag_Q, + tau=tau, c1=c1, verbose=verbose) + if update_sigsq: + self.update_log_sigsq_lst(verbose=verbose) + + def update_A_and_Q(self, update_A=True, update_Q=True, stationary_A=False, + diag_Q=False, tau=0.1, c1=1e-4, verbose=0): + + if verbose > 1: + print(" update A and Q") + + # gradient descent + At = self.A[:-1] + At_init = At.copy() + L_Q = np.linalg.cholesky(self.Q) + At_L_Q_obj = lambda x, y: self.L2_obj(x, y) + At_obj = lambda x: self.L2_obj(x, L_Q) + grad_At_obj = grad(At_obj) + obj_diff = np.finfo('float').max + obj = At_L_Q_obj(At, L_Q) + inner_it = 0 + + # specify proximal operator to use + if self._penalty == 'ridge': + prox_op = lambda x, y: x + elif self._penalty == 'lasso': + prox_op = soft_thresh_At + elif self._penalty == 'group-lasso': + prox_op = block_thresh_At + + while np.abs(obj_diff / obj) > self._A_Q_tol: + + if inner_it > self._A_Q_niter: + break + + obj_start = At_L_Q_obj(At, L_Q) + + # update At using gradient descent with backtracking line search + if update_A: + if stationary_A: + B2_sum = np.sum(self._B2, axis=0) + B3_sum = np.sum(self._B3, axis=0) + At[:] = np.linalg.solve(B3_sum.T, B2_sum.T).T + else: + grad_At = grad_At_obj(At) + step_size = linesearch(At_obj, grad_At_obj, At, grad_At, + prox_op=prox_op, lam=self.lam0, + tau=tau, c1=c1) + At[:] = prox_op(At - step_size * grad_At, + self.lam0 * step_size) + + # update Q using closed form + if update_Q: + AtB2T = einsum2('tik,tjk->tij', At, self._B2) + B2AtT = einsum2('tik,tjk->tij', self._B2, At) + tmp = einsum2('tik,tkl->til', At, self._B3) + AtB3AtT = einsum2('til,tjl->tij', tmp, At) + elbo_2 = np.sum(self._B1 - AtB2T - B2AtT + AtB3AtT, axis=0) + self.Q = (1. 
/ (self._ntrials_all * self._T)) * elbo_2 + if diag_Q: + self.Q = np.diag(np.diag(self.Q)) + L_Q = np.linalg.cholesky(self.Q) + + obj = At_L_Q_obj(At, L_Q) + obj_diff = obj_start - obj + inner_it += 1 + + if verbose > 1: + if not stationary_A and update_A: + grad_norm = np.linalg.norm(grad_At) + norm_change = np.linalg.norm(At - At_init) + print(" last step size: %.3e" % step_size) + print(" last gradient norm: %.3e" % grad_norm) + print(" norm of total change: %.3e" % norm_change) + print(" number of iterations: %d" % inner_it) + print(" done") + + def update_log_sigsq_lst(self, verbose=0): + + if verbose > 1: + print(" update subject log-sigmasq") + + T, D, _ = self.A.shape + + # update log_sigsq_vals for each subject and ROI + for s, sdata in enumerate(self.unpack_all_subject_data()): + + Y, w_s, GL, G, Q_snsr, Q_J = sdata + ntrials, T, _ = Y.shape + mus_smooth = self._mus_smooth_lst[s] + sigmas_smooth = self._sigmas_smooth_lst[s] + B4 = self._B4[s] + + log_sigsq = self.log_sigsq_lst[s].copy() + log_sigsq_obj = lambda x: \ + MEGLDS.L3_obj(x, Q_snsr, G, Q_J, B4, ntrials, T) + log_sigsq_val_and_grad = vgrad(log_sigsq_obj) + + options = {'maxiter': 500} + opt_res = spopt.minimize(log_sigsq_val_and_grad, log_sigsq, + method='L-BFGS-B', jac=True, + options=options) + if verbose > 1: + print(" subject %d - %d iterations" % (s+1, opt_res.nit)) + + if not opt_res.success: + print(" log_sigsq opt") + print(" %s" % opt_res.message) + + self.log_sigsq_lst[s] = opt_res.x + + if verbose > 1: + print("\n done") + + def calculate_smoothed_estimates(self): + """ recalculate smoothed estimates with current model parameters """ + + self._mus_smooth_lst = list() + self._sigmas_smooth_lst = list() + self._sigmas_tnt_smooth_lst = list() + self._St_lst = list() + self._loglik = 0. + + for s, sdata in enumerate(self.unpack_all_subject_data()): + Y, w_s, GL, G, Q_snsr, Q_J = sdata + sigsq_vals = np.exp(self.log_sigsq_lst[s]) + R = MEGLDS.R_(Q_snsr, G, sigsq_vals, Q_J) + Qt = _ensure_ndim(self.Q, self._T, 3) + with numpy_num_threads(1): + ll, mus_smooth, sigmas_smooth, sigmas_tnt_smooth, _ = \ + rts_smooth(Y, self.A, GL, Qt, R, self.mu0, self.Q0, + compute_lag1_cov=True, store_St=False) + self._mus_smooth_lst.append(mus_smooth) + self._sigmas_smooth_lst.append(sigmas_smooth) + self._sigmas_tnt_smooth_lst.append(sigmas_tnt_smooth) + #self._St_lst.append(np.diagonal(St, axis1=-2, axis2=-1)) + self._loglik += ll + + def log_likelihood(self): + """ calculate log marginal likelihood using the Kalman filter """ + + #if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None \ + # or self._sigmas_tnt_smooth_lst is None): + # self.calculate_smoothed_estimates() + # return self._loglik + if self._loglik is not None: + return self._loglik + + self._loglik = 0. 
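+        # sum the per-subject marginal log-likelihoods from the Kalman filter
+        # under the observation model y_t = GL x_t + v_t, v_t ~ N(0, R), with
+        # R = MEGLDS.R_(Q_snsr, G, sigsq_vals, Q_J); AIC/BIC below reuse this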
+ for s, sdata in enumerate(self.unpack_all_subject_data()): + Y, w_s, GL, G, Q_snsr, Q_J = sdata + sigsq_vals = np.exp(self.log_sigsq_lst[s]) + R = MEGLDS.R_(Q_snsr, G, sigsq_vals, Q_J) + Qt = _ensure_ndim(self.Q, self._T, 3) + ll, _, _, _ = kalman_filter(Y, self.A, GL, Qt, R, self.mu0, + self.Q0, store_St=False) + self._loglik += ll + + return self._loglik + + def nparams(self): + T, p, _ = self.A.shape + + # this should equal (T-1)*p*p unless some shrinkage is used on At + nparams_At = np.sum(np.abs(self.A[:-1]) > 0) + + # nparams = nparams(At) + nparams(Q) + nparams(Q0) + # + nparams(log_sigsq_lst) + return nparams_At + p*(p+1)/2 + p*(p+1)/2 \ + + np.sum([p+1 for _ in range(len(self.log_sigsq_lst))]) + + def AIC(self): + return -2*self.log_likelihood() + 2*self.nparams() + + def BIC(self): + if self._ntrials_all == 0: + raise RuntimeError("use set_data to add subject data before" \ + + " computing BIC") + return -2*self.log_likelihood() \ + + np.log(self._ntrials_all)*self.nparams() + + def save(self, filename, **kwargs): + savedict = { 'A' : self.A, 'Q' : self.Q, 'mu0' : self.mu0, + 'Q0' : self.Q0, 'log_sigsq_lst' : self.log_sigsq_lst, + 'lam0' : self.lam0, 'lam1' : self.lam1} + savedict.update(kwargs) + np.savez_compressed(filename, **savedict) + + def load(self, filename): + loaddict = np.load(filename) + param_names = ['A', 'Q', 'mu0', 'Q0', 'log_sigsq_lst', 'lam0', 'lam1'] + for name in param_names: + if name not in loaddict.keys(): + raise RuntimeError('specified file is not a saved model:\n%s' + % (filename,)) + for name in param_names: + if name == 'log_sigsq_lst': + self.log_sigsq_lst = [l.copy() for l in loaddict[name]] + elif name in ('lam0', 'lam1'): + self.__setattr__(name, float(loaddict[name])) + else: + self.__setattr__(name, loaddict[name].copy()) + + # return remaining saved items, if there are any + others = {key : loaddict[key] for key in loaddict.keys() \ + if key not in param_names} + if len(others.keys()) > 0: + return others + + @staticmethod + def R_(Q_snsr, G, sigsq_vals, Q_J): + return Q_snsr + np.dot(G, sigsq_vals[Q_J][:,None]*G.T) + + def L2_obj(self, At, L_Q): + + # import autograd.numpy + # if isinstance(At,autograd.numpy.numpy_boxes.ArrayBox): + # At = At._value + + AtB2T = einsum2('tik,tjk->tij', At, self._B2) + B2AtT = einsum2('tik,tjk->tij', self._B2, At) + tmp = einsum2('tik,tkl->til', At, self._B3) + AtB3AtT = einsum2('til,tjl->tij', tmp, At) + elbo_2 = np.sum(self._B1 - AtB2T - B2AtT + AtB3AtT, axis=0) + + L_Q_inv_elbo_2 = solve_triangular(L_Q, elbo_2, lower=True) + obj = np.trace(solve_triangular(L_Q, L_Q_inv_elbo_2, lower=True, + trans='T')) + obj = obj / self._ntrials_all + + if self._penalty == 'ridge': + obj += self.lam0*np.sum(At**2) + AtmAtm1_2 = (At[1:] - At[:-1])**2 + obj += self.lam1*np.sum(AtmAtm1_2) + + return obj + + # TODO: convert to instance method + @staticmethod + def L3_obj(log_sigsq_vals, Q_snsr, G, Q_J, B4, ntrials, T): + R = MEGLDS.R_(Q_snsr, G, np.exp(log_sigsq_vals), Q_J) + try: + L_R = np.linalg.cholesky(R) + except LinAlgError: + return np.finfo('float').max + L_R_inv_B4 = solve_triangular(L_R, B4, lower=True) + return (ntrials*T*2.*np.sum(np.log(np.diag(L_R))) + + np.trace(solve_triangular(L_R, L_R_inv_B4, lower=True, + trans='T'))) + + + @property + def A(self): + return self._A + + @A.setter + def A(self, A): + self._A = A + + @property + def Q(self): + return self._Q + + @Q.setter + def Q(self, Q): + self._Q = Q + + @property + def mu0(self): + return self._mu0 + + @mu0.setter + def mu0(self, mu0): + self._mu0 = mu0 + 
+ @property + def Q0(self): + return self._Q0 + + @Q0.setter + def Q0(self, Q0): + self._Q0 = Q0 + + @property + def log_sigsq_lst(self): + return self._log_sigsq_lst + + @log_sigsq_lst.setter + def log_sigsq_lst(self, log_sigsq_lst): + self._log_sigsq_lst = log_sigsq_lst + + @property + def D_roi(self): + return self.A.shape[1] + + @property + def T(self): + return self._T + + @property + def lam0(self): + return self._lam0 + + @lam0.setter + def lam0(self, lam0): + self._lam0 = lam0 + + @property + def lam1(self): + return self._lam1 + + @lam1.setter + def lam1(self, lam1): + self._lam1 = lam1 diff --git a/examples/megssm/numpy_numthreads.py b/examples/megssm/numpy_numthreads.py new file mode 100755 index 00000000..550aa235 --- /dev/null +++ b/examples/megssm/numpy_numthreads.py @@ -0,0 +1,91 @@ +import contextlib +import ctypes +from ctypes.util import find_library + +# heavily based on: +# https://stackoverflow.com/questions/29559338/set-max-number-of-threads-at-runtime-on-numpy-openblas + +# Prioritize hand-compiled OpenBLAS library over version in /usr/lib/ +# from Ubuntu repos +try_paths = [find_library('openblas')] +openblas_lib = None +for libpath in try_paths: + try: + openblas_lib = ctypes.cdll.LoadLibrary(libpath) + break + except Exception: #OSError: + continue +#if openblas_lib is None: + #raise EnvironmentError('Could not locate an OpenBLAS shared library', 2) + +try: + mkl_rt_path = find_library('mkl_rt') + mkl_rt = ctypes.cdll.LoadLibrary(mkl_rt_path) + # print(mkl_rt) +except OSError: + mkl_rt = None + pass + + +def set_num_threads(n): + """Set the current number of threads used by the OpenBLAS server.""" + if mkl_rt: + pass + #mkl_rt.mkl_set_num_threads(ctypes.byref(ctypes.c_int(n))) + elif openblas_lib: + openblas_lib.openblas_set_num_threads(int(n)) + + +# At the time of writing these symbols were very new: +# https://github.com/xianyi/OpenBLAS/commit/65a847c +try: + if mkl_rt: #False: #mkl_rt: + def get_num_threads(): + return mkl_rt.mkl_get_max_threads() + elif openblas_lib: + # do this to throw exception if it doesn't exist + openblas_lib.openblas_get_num_threads() + def get_num_threads(): + """Get the current number of threads used by the OpenBLAS server.""" + return openblas_lib.openblas_get_num_threads() +except AttributeError: + def get_num_threads(): + """Dummy function (symbol not present in %s), returns -1.""" + return -1 + pass + +try: + if False: #mkl_rt: + def get_num_procs(): + # this returns number of procs + return mkl_rt.mkl_get_max_threads() + elif openblas_lib: + # do this to throw exception if it doesn't exist + openblas_lib.openblas_get_num_procs() + def get_num_procs(): + """Get the total number of physical processors""" + return openblas_lib.openblas_get_num_procs() +except AttributeError: + def get_num_procs(): + """Dummy function (symbol not present), returns -1.""" + return -1 + pass + + +@contextlib.contextmanager +def numpy_num_threads(n): + """Temporarily changes the number of OpenBLAS threads. 
+ + Example usage: + + print("Before: {}".format(get_num_threads())) + with num_threads(n): + print("In thread context: {}".format(get_num_threads())) + print("After: {}".format(get_num_threads())) + """ + old_n = get_num_threads() + set_num_threads(n) + try: + yield + finally: + set_num_threads(old_n) diff --git a/examples/megssm/util.py b/examples/megssm/util.py new file mode 100755 index 00000000..63898be3 --- /dev/null +++ b/examples/megssm/util.py @@ -0,0 +1,117 @@ +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import + +import autograd.numpy as np +from numpy.lib.stride_tricks import as_strided as ast + + +hs = lambda *args: np.concatenate(*args, axis=-1) + +def T_(X): + return np.swapaxes(X, -1, -2) + +def sym(X): + return 0.5*(X + T_(X)) + +def dot3(A, B, C): + return np.dot(A, np.dot(B, C)) + +def relnormdiff(A, B, min_denom=1e-9): + return np.linalg.norm(A - B) / np.maximum(np.linalg.norm(A), min_denom) + +def _ensure_ndim(X, T, ndim): + X = np.require(X, dtype=np.float64, requirements='C') + assert ndim-1 <= X.ndim <= ndim + if X.ndim == ndim: + assert X.shape[0] == T + return X + else: + return ast(X, shape=(T,) + X.shape, strides=(0,) + X.strides) + +def rand_psd(n, minew=0.1, maxew=1.): + # maxew is badly named + if n == 1: + return maxew * np.eye(n) + X = np.random.randn(n,n) + S = np.dot(T_(X), X) + S = sym(S) + ew, ev = np.linalg.eigh(S) + ew -= np.min(ew) + ew /= np.max(ew) + ew *= (maxew - minew) + ew += minew + return dot3(ev, np.diag(ew), T_(ev)) + +def rand_stable(n, maxew=0.9): + A = np.random.randn(n, n) + A *= maxew / np.max(np.abs(np.linalg.eigvals(A))) + return A + +def component_matrix(As, nlags): + """ compute component form of latent VAR process + + [A_1 A_2 ... A_p] + [ I 0 ... 0 ] + [ 0 I 0 0 ] + [ 0 ... I 0 ] + + """ + + d = As.shape[0] + res = np.zeros((d*nlags, d*nlags)) + res[:d] = As + + if nlags > 1: + res[np.arange(d,d*nlags), np.arange(d*nlags-d)] = 1 + + return res + +def linesearch(f, grad_f, xk, pk, step_size=1., tau=0.1, c1=1e-4, + prox_op=None, lam=1.): + """ find a step size via backtracking line search with armijo condition """ + obj_start = f(xk) + grad_xk = grad_f(xk) + obj_new = np.finfo('float').max + armijo_condition = 0 + + if prox_op is None: + prox_op = lambda x, y: x + + while obj_new > armijo_condition: + x_new = prox_op(xk - step_size * pk, lam*step_size) + armijo_condition = obj_start - c1*step_size*(np.sum(pk*grad_xk)) + obj_new = f(x_new) + step_size *= tau + + return step_size/tau + +def soft_thresh_At(At, lam): + At = At.copy() + diag_inds = np.diag_indices(At.shape[1]) + At_diag = np.diagonal(At, axis1=-2, axis2=-1) + + At = np.sign(At) * np.maximum(np.abs(At) - lam, 0.) + + # fill in diagonal with originally updated entries as we're not + # going to penalize them + for tt in range(At.shape[0]): + At[tt][diag_inds] = At_diag[tt] + return At + +def block_thresh_At(At, lam, min_norm=1e-16): + At = At.copy() + diag_inds = np.diag_indices(At.shape[1]) + At_diag = np.diagonal(At, axis1=-2, axis2=-1) + + norms = np.linalg.norm(At, axis=0, keepdims=True) + norms = np.maximum(norms, min_norm) + scales = np.maximum(norms - lam, 0.) 
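+    # group soft-threshold: each connection trajectory a_ij = At[:, i, j] is
+    # scaled by max(0, 1 - lam / ||a_ij||_2), so an entire cross-region time
+    # course is zeroed out at once (the proximal operator of the group lasso)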
+ At = scales * (At / norms) + + # fill in diagonal with originally updated entries as we're not + # going to penalize them + for tt in range(At.shape[0]): + At[tt][diag_inds] = At_diag[tt] + return At + diff --git a/examples/state_space_connectivity.py b/examples/state_space_connectivity.py index da294887..5ede30a2 100644 --- a/examples/state_space_connectivity.py +++ b/examples/state_space_connectivity.py @@ -1,57 +1,79 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Created on Tue Jun 14 13:15:43 2022 +Authors: Jordan Drew -@author: jordandrew +""" -For 'mne-connectivity/examples/' to show usage of MEGLDS +''' +For 'mne-connectivity/examples/' to show usage of LDS Use MNE-sample-data for auditory/left -""" +''' ## import necessary libraries import mne import numpy as np import matplotlib.pyplot as plt -from mne_util import ROIToSourceMap, scale_sensor_data #mne_util is from MEGLDS repo + +#where should these files live within mne-connectivity repo? +from megssm.mne_util import ROIToSourceMap, _scale_sensor_data #mne_util is from MEGLDS repo +from megssm.models import MEGLDS as LDS ## define paths to sample data -data_path = '/Users/jordandrew/Documents/MEG/mne_data/MNE-sample-data' -# data_path = mne.datasets.sample.data_path() +path = None +path = '/Users/jordandrew/Documents/MEG/mne_data'#'/MNE-sample-data' +data_path = mne.datasets.sample.data_path(path=path) sample_folder = '/MEG/sample' raw_fname = data_path + sample_folder + '/sample_audvis_raw.fif' #how many subjects? subjects_dir = data_path + '/subjects' +fwd_fname = data_path + sample_folder + '/sample_audvis-meg-eeg-oct-6-fwd.fif' #EEG ONLY ## import raw data and find events raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60) events = mne.find_events(raw, stim_channel='STI 014') -""" OR -raw_events_fname = data_path + sample_folder + '/sample_audvis_raw-eve.fif' -events = mne.read_events(raw_events_fname) -""" -## compute forward solution -sphere = mne.make_sphere_model('auto', 'auto', raw.info) -src = mne.setup_volume_source_space(sphere=sphere, exclude=30., pos=15.) -fwd = mne.make_forward_solution(raw.info, trans=None, src=src, bem=sphere) -fwd['src'].append( fwd['src'][0]) #fwd['src'] needs lh and rh; duplicated here -#is there a reason the sample data only has 1 hemisphere of data? +## read forward solution, remove bad channels +fwd = mne.read_forward_solution(fwd_fname,exclude=raw.info['bads']) +fwd = mne.convert_forward_solution(fwd, force_fixed=True) ## define epochs using event_dict -# event_id = 1 event_dict = {'auditory/left': 1, 'auditory/right': 2, 'visual/left': 3, 'visual/right': 4, 'face': 5, 'buttonpress': 32} epochs = mne.Epochs(raw, events, tmin=-0.3, tmax=0.7, event_id=event_dict, preload=True) -# del raw ## compute covariance noise_cov = mne.compute_covariance(epochs, tmax=0) labels = mne.read_labels_from_annot('sample', subjects_dir=subjects_dir) -roi_to_src = ROIToSourceMap(fwd, labels) # compute ROI-to-source map -scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1} -fwd_sr_sn, fwd_roi_sn, snsr_cov, epochs = \ - scale_sensor_data(epochs, fwd, noise_cov, roi_to_src, **scales) +#when to select specific labels/ROIs for processing? + +#make internal to LDS +# roi_to_src = ROIToSourceMap(fwd, labels) # compute ROI-to-source map +# scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1} +# fwd_src_snsr, fwd_roi_snsr, snsr_src_cov, epochs = \ +# _scale_sensor_data(epochs, fwd, noise_cov, roi_to_src, **scales) + +# snsr_Q_J = mne.make_ad_hoc_cov(raw.info) #why not square matrix? 
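+# note: make_ad_hoc_cov returns a *diagonal* Covariance, so its .data is a 1-D
+# array of per-channel variances rather than a dense square matrix; np.diag()
+# would recover the square form if one were needed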
+ + + + +num_rois = len(labels) +timepts = len(epochs.times) +model = LDS(num_rois, timepts, fwd, labels, noise_cov) # only needs the forward, labels, and noise_cov to be initialized +# subjectdata = [(epochs, fwd_roi_snsr(C), fwd_src_snsr(G), snsr_src_cov(R,Q_snsr), roi_idx(Q_J))] +# model.set_data(subject_data) +# model.fit(epochs) # now only needs epochs to fit + + + + + + + + + + @@ -59,5 +81,3 @@ -# model = MEGLDS(fwd, labels, noise_cov) # only needs the forward, labels, and noise_cov to be initialized -# model.fit(epochs) # now only needs epochs to fit \ No newline at end of file From 97b982693632f22b9ae87977ad6f1bc831f93f05 Mon Sep 17 00:00:00 2001 From: jadrew43 Date: Mon, 27 Jun 2022 16:25:32 -0700 Subject: [PATCH 05/17] does scale_sensor_data() == whitening? --- examples/megssm/mne_util.py | 16 +++++------ examples/state_space_connectivity.py | 42 ++++++++++++++++++---------- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/examples/megssm/mne_util.py b/examples/megssm/mne_util.py index 4bbf63b3..5e53a329 100644 --- a/examples/megssm/mne_util.py +++ b/examples/megssm/mne_util.py @@ -22,7 +22,7 @@ class ROIToSourceMap(object): - fwd_src_snsr : G - fwd_roi_snsr : C - fwd_src_roi : L - - snsr_src_cov : R + - snsr_cov : Q_e - roi_cov : Q - roi_cov_0 : Q0 """ @@ -114,7 +114,7 @@ def which_roi(self, val): self._which_roi = val @property - def fwd_roi_sn(self): + def fwd_roi_snsr(self): from util import Carray return Carray(csr_matrix.dot(self.fwd_src_roi.T, self.fwd_src_sn.T).T) @@ -187,20 +187,20 @@ def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., mag_scale=1., # retrieve forward and sensor covariance fwd_src_snsr = fwd['sol']['data'].copy() - snsr_src_cov = cov.data.copy() + snsr_cov = cov.data.copy() # scale forward matrix fwd_src_snsr[idx_eeg,:] *= eeg_scale fwd_src_snsr[idx_mag,:] *= mag_scale fwd_src_snsr[idx_grad,:] *= grad_scale - # construct fwd_roi_sn matrix + # construct fwd_roi_snsr matrix fwd_roi_snsr = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, fwd_src_snsr.T).T) # scale sensor covariance - snsr_src_cov[np.ix_(idx_eeg, idx_eeg)] *= eeg_scale**2 - snsr_src_cov[np.ix_(idx_mag, idx_mag)] *= mag_scale**2 - snsr_src_cov[np.ix_(idx_grad, idx_grad)] *= grad_scale**2 + snsr_cov[np.ix_(idx_eeg, idx_eeg)] *= eeg_scale**2 + snsr_cov[np.ix_(idx_mag, idx_mag)] *= mag_scale**2 + snsr_cov[np.ix_(idx_grad, idx_grad)] *= grad_scale**2 # scale epochs info = epochs.info.copy() @@ -212,6 +212,6 @@ def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., mag_scale=1., epochs = mne.EpochsArray(data, info) - return fwd_src_snsr, fwd_roi_snsr, snsr_src_cov, epochs + return fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs diff --git a/examples/state_space_connectivity.py b/examples/state_space_connectivity.py index 5ede30a2..384cf51b 100644 --- a/examples/state_space_connectivity.py +++ b/examples/state_space_connectivity.py @@ -14,11 +14,17 @@ import mne import numpy as np import matplotlib.pyplot as plt +from scipy.sparse import csr_matrix #where should these files live within mne-connectivity repo? 
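 #(in this PR they live under examples/megssm/, imported just below)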
from megssm.mne_util import ROIToSourceMap, _scale_sensor_data #mne_util is from MEGLDS repo from megssm.models import MEGLDS as LDS +# from util import Carray ##skip import just pasted; util also from MEGLDS repo +Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C') +Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C') +Carray = Carray64 + ## define paths to sample data path = None path = '/Users/jordandrew/Documents/MEG/mne_data'#'/MNE-sample-data' @@ -47,21 +53,29 @@ labels = mne.read_labels_from_annot('sample', subjects_dir=subjects_dir) #when to select specific labels/ROIs for processing? +prepochs = epochs #make internal to LDS -# roi_to_src = ROIToSourceMap(fwd, labels) # compute ROI-to-source map -# scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1} -# fwd_src_snsr, fwd_roi_snsr, snsr_src_cov, epochs = \ -# _scale_sensor_data(epochs, fwd, noise_cov, roi_to_src, **scales) - -# snsr_Q_J = mne.make_ad_hoc_cov(raw.info) #why not square matrix? - - - - -num_rois = len(labels) -timepts = len(epochs.times) -model = LDS(num_rois, timepts, fwd, labels, noise_cov) # only needs the forward, labels, and noise_cov to be initialized -# subjectdata = [(epochs, fwd_roi_snsr(C), fwd_src_snsr(G), snsr_src_cov(R,Q_snsr), roi_idx(Q_J))] +roi_to_src = ROIToSourceMap(fwd, labels) # compute ROI-to-source map +scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1} +fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs = \ + _scale_sensor_data(epochs, fwd, noise_cov, roi_to_src, **scales) +#return G/fwd_src_snsr, GL/fwd_roi_snsr, Q_snsr/snsr_cov, epochs +#without scale_sensor_data()...equivalent? no +epochs_cov = mne.make_ad_hoc_cov(prepochs.info) +W, _ = mne.cov.compute_whitener(epochs_cov, prepochs.info) +fwd_data = fwd['sol']['data'].copy() +data = prepochs.get_data().copy() +#------------ +G = np.dot(W,fwd_data) #fwd_src_snsr +GL = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, fwd_data.T).T) #fwd_roi_snsr +cov = np.dot(W, np.dot(noise_cov.data.copy(), W.T)) #snsr_cov +# postpochs = np.dot(W, data) + + +# num_rois = len(labels) +# timepts = len(epochs.times) +# model = LDS(num_rois, timepts, fwd, labels, noise_cov) # only needs the forward, labels, and noise_cov to be initialized +# subjectdata = [(epochs, fwd_roi_snsr(C), fwd_src_snsr(G), snsr_cov(Q_e,Q_snsr), roi_idx(Q_J))] # model.set_data(subject_data) # model.fit(epochs) # now only needs epochs to fit From 87269d74cd6ab59332a4b9f8cd20649ff2d8f143 Mon Sep 17 00:00:00 2001 From: jadrew43 Date: Wed, 6 Jul 2022 16:21:25 -0700 Subject: [PATCH 06/17] working on suggested API` --- examples/state_space_connectivity.py | 81 ++++++++++++---------------- 1 file changed, 34 insertions(+), 47 deletions(-) diff --git a/examples/state_space_connectivity.py b/examples/state_space_connectivity.py index 384cf51b..dc4497b4 100644 --- a/examples/state_space_connectivity.py +++ b/examples/state_space_connectivity.py @@ -12,72 +12,59 @@ ## import necessary libraries import mne -import numpy as np -import matplotlib.pyplot as plt -from scipy.sparse import csr_matrix +# import numpy as np +# import matplotlib.pyplot as plt +# from scipy.sparse import csr_matrix #where should these files live within mne-connectivity repo? 
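 # note on the previous patch's question: unit scales in _scale_sensor_data do
 # not whiten the data; whitening multiplies the data, forward matrix, and
 # covariance by W = Lambda^(-1/2) U.T from the noise-covariance
 # eigendecomposition (mne.cov.compute_whitener), as the removed experiment
 # below confirmed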
-from megssm.mne_util import ROIToSourceMap, _scale_sensor_data #mne_util is from MEGLDS repo
+# from megssm.mne_util import ROIToSourceMap, _scale_sensor_data #mne_util is from MEGLDS repo
 from megssm.models import MEGLDS as LDS
 
-# from util import Carray ##skip import just pasted; util also from MEGLDS repo
-Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C')
-Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C')
-Carray = Carray64
-
 ## define paths to sample data
 path = None
 path = '/Users/jordandrew/Documents/MEG/mne_data'#'/MNE-sample-data'
 data_path = mne.datasets.sample.data_path(path=path)
-sample_folder = '/MEG/sample'
-raw_fname = data_path + sample_folder + '/sample_audvis_raw.fif' #how many subjects?
-subjects_dir = data_path + '/subjects'
-fwd_fname = data_path + sample_folder + '/sample_audvis-meg-eeg-oct-6-fwd.fif' #EEG ONLY
+sample_folder = data_path / 'MEG/sample'
+subjects_dir = data_path / 'subjects'
 
 ## import raw data and find events
+raw_fname = sample_folder / 'sample_audvis_raw.fif'
 raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60)
 events = mne.find_events(raw, stim_channel='STI 014')
 
-## read forward solution, remove bad channels
-fwd = mne.read_forward_solution(fwd_fname,exclude=raw.info['bads'])
-fwd = mne.convert_forward_solution(fwd, force_fixed=True)
-
 ## define epochs using event_dict
 event_dict = {'auditory/left': 1, 'auditory/right': 2, 'visual/left': 3,
               'visual/right': 4, 'face': 5, 'buttonpress': 32}
 epochs = mne.Epochs(raw, events, tmin=-0.3, tmax=0.7, event_id=event_dict,
-                    preload=True)
+                    preload=True).pick_types(meg=True, eeg=True, exclude='bads')
 
-## compute covariance
+## read forward solution, remove bad channels
+fwd_fname = sample_folder / 'sample_audvis-meg-eeg-oct-6-fwd.fif'
+fwd = mne.read_forward_solution(fwd_fname, exclude=raw.info['bads'])
+fwd = mne.convert_forward_solution(fwd, force_fixed=True)
+
+## read in covariance OR compute noise covariance? noise_cov drops bad chs
+cov_fname = sample_folder / 'sample_audvis-cov.fif'
+cov = mne.read_cov(cov_fname) #has all 366 channels; drop 2?
 noise_cov = mne.compute_covariance(epochs, tmax=0)
 
-labels = mne.read_labels_from_annot('sample', subjects_dir=subjects_dir)
-#when to select specific labels/ROIs for processing?
-
-prepochs = epochs
-#make internal to LDS
-roi_to_src = ROIToSourceMap(fwd, labels) # compute ROI-to-source map
-scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1}
-fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs = \
-    _scale_sensor_data(epochs, fwd, noise_cov, roi_to_src, **scales)
-#return G/fwd_src_snsr, GL/fwd_roi_snsr, Q_snsr/snsr_cov, epochs
-#without scale_sensor_data()...equivalent? 
no
-epochs_cov = mne.make_ad_hoc_cov(prepochs.info)
-W, _ = mne.cov.compute_whitener(epochs_cov, prepochs.info)
-fwd_data = fwd['sol']['data'].copy()
-data = prepochs.get_data().copy()
-#------------
-G = np.dot(W,fwd_data) #fwd_src_snsr
-GL = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, fwd_data.T).T) #fwd_roi_snsr
-cov = np.dot(W, np.dot(noise_cov.data.copy(), W.T)) #snsr_cov
-# postpochs = np.dot(W, data)
-
-
-# num_rois = len(labels)
-# timepts = len(epochs.times)
-# model = LDS(num_rois, timepts, fwd, labels, noise_cov) # only needs the forward, labels, and noise_cov to be initialized
-# subjectdata = [(epochs, fwd_roi_snsr(C), fwd_src_snsr(G), snsr_cov(Q_e,Q_snsr), roi_idx(Q_J))]
-# model.set_data(subject_data)
-# model.fit(epochs) # now only needs epochs to fit
+
+## read labels for analysis
+label_names = ['Aud-lh', 'Aud-rh', 'Vis-lh', 'Vis-rh']
+labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label')
+          for label in label_names]
+
+## initiate model
+num_rois = len(labels)
+timepts = len(epochs.times)
+model = LDS(num_rois, timepts, lam0=0, lam1=100) # initialized with dimensions and penalty weights only; data enters via add_subject
+
+model.add_subject('sample', subjects_dir, epochs, labels, fwd, noise_cov)
+#when to use compute_cov vs read_cov?
+#should remove epochs
+
+
+
+model.em(niter=2)

From 93755acd5669fc503d1402f473895f70a5922e89 Mon Sep 17 00:00:00 2001
From: jadrew43 
Date: Thu, 7 Jul 2022 14:29:21 -0700
Subject: [PATCH 07/17] API functioning properly, please test
 examples/state_space_connectivity.py

---
 examples/megssm/message_passing.py   |   4 +-
 examples/megssm/mne_util.py          |  81 +++++-
 examples/megssm/models.py            | 387 ++++++++++++++-------------
 examples/megssm/plotting.py          | 107 ++++++++
 examples/state_space_connectivity.py |  27 +-
 5 files changed, 397 insertions(+), 209 deletions(-)
 create mode 100644 examples/megssm/plotting.py

diff --git a/examples/megssm/message_passing.py b/examples/megssm/message_passing.py
index 21f08a76..e560ff49 100755
--- a/examples/megssm/message_passing.py
+++ b/examples/megssm/message_passing.py
@@ -304,7 +304,7 @@ def rts_smooth_fast(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False):
     I_tiled = np.tile(np.eye(Dnlags), (N, 1, 1))
 
     for t in range(T):
-
+        
         # condition
         tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:])
 
@@ -361,7 +361,7 @@ def rts_smooth_fast(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False):
             sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t])
 
     for t in range(T-2, -1, -1):
-
+        
         # these names are stolen from mattjj and slinderman
         #temp_nn = np.dot(A[t], sigmas_smooth[n,t,:,:])
         temp_nn = einsum2('ik,nkj->nij', AA[t], sigmas_smooth[:,t,:,:])
diff --git a/examples/megssm/mne_util.py b/examples/megssm/mne_util.py
index 5e53a329..7e9edd90 100644
--- a/examples/megssm/mne_util.py
+++ b/examples/megssm/mne_util.py
@@ -10,6 +10,12 @@
 from mne import label_sign_flip
 
 from scipy.sparse import csc_matrix, csr_matrix, diags
+from sklearn.decomposition import PCA
+
+# from util import Carray ##skip import just pasted; util also from MEGLDS repo
+Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C')
+Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C')
+Carray = Carray64
 
 
 class ROIToSourceMap(object):
@@ -171,10 +177,10 @@ def apply_projs(epochs, fwd, cov):
 def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., mag_scale=1.,
                        grad_scale=1.):
     """ apply per-channel-type scaling to epochs, forward, and covariance """
-    # from util import Carray ##skip import just pasted; util also from MEGLDS repo
-    Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C')
-    Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C')
-    Carray = Carray64
+    # # from util import Carray ##skip import just pasted; util also from MEGLDS repo
+    # Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C')
+    # Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C')
+    # Carray = Carray64
 
     # get indices for each channel type
     ch_names = cov['names'] # same as self.fwd['info']['ch_names']
@@ -215,3 +221,70 @@ def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., mag_scale=1.,
 
     return fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs
 
+def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank',
+                       pctvar=0.99, mean_center=False, label_flip=False):
+    """ apply sensor scaling, PCA dimensionality reduction with/without
+        whitening, and mean-centering to subject data """
+
+    if dim_mode not in ['rank', 'pctvar', 'whiten']:
+        raise ValueError("dim_mode must be in {'rank', 'pctvar', 'whiten'}")
+
+    print("running pca for subject %s" % subject_name)
+
+    # compute ROI-to-source map
+    roi_to_src = ROIToSourceMap(fwd, labels, label_flip)
+    scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1}
+    if dim_mode == 'whiten':
+
+        # scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1}
+        G, GL, Q_snsr, epochs = \
+            _scale_sensor_data(epochs, fwd, cov, roi_to_src, **scales)
+        dat = epochs.get_data()
+        dat = Carray(np.swapaxes(dat, -1, -2))
+
+        if mean_center:
+            dat -= np.mean(dat, axis=1, keepdims=True)
+
+        dat_stacked = np.reshape(dat, (-1, dat.shape[-1]))
+
+        W, _ = mne.cov.compute_whitener(cov, info=epochs.info, pca=True)
+        print("whitener for subject %s using %d principal components" %
+              (subject_name, W.shape[0]))
+
+    else:
+
+        G, GL, Q_snsr, epochs = \
+            _scale_sensor_data(epochs, fwd, cov, roi_to_src, **scales) #info.scales
+        dat = epochs.get_data()
+        dat = Carray(np.swapaxes(dat, -1, -2))
+
+        if mean_center:
+            dat -= np.mean(dat, axis=1, keepdims=True)
+
+        dat_stacked = np.reshape(dat, (-1, dat.shape[-1]))
+
+        pca = PCA()
+        pca.fit(dat_stacked)
+
+        if dim_mode == 'rank':
+            idx = np.linalg.matrix_rank(np.cov(dat_stacked, rowvar=False))
+        else:
+            idx = np.where(np.cumsum(pca.explained_variance_ratio_) > pctvar)[0][0]
+
+        idx = np.maximum(idx, len(labels))
+        W = pca.components_[:idx]
+        print("subject %s using %d principal components" % (subject_name, idx))
+
+    ntrials, T, _ = dat.shape
+    dat_pca = np.dot(dat_stacked, W.T)
+    dat_pca = np.reshape(dat_pca, (ntrials, T, -1))
+
+    G_pca = np.dot(W, G)
+    GL_pca = np.dot(W, GL)
+    Q_snsr_pca = np.dot(W, np.dot(Q_snsr, W.T)) #dot3(W, Q_snsr, W.T)
+
+    data = dat_pca
+
+    return data, GL_pca, G_pca, Q_snsr_pca, roi_to_src.which_roi
\ No newline at end of file
diff --git a/examples/megssm/models.py b/examples/megssm/models.py
index fbbaebbc..64d606bc 100755
--- a/examples/megssm/models.py
+++ b/examples/megssm/models.py
@@ -1,7 +1,3 @@
-# from __future__ import division
-# from __future__ import print_function
-# from __future__ import absolute_import
-
 import sys
 
 import autograd.numpy as np
@@ -18,7 +14,7 @@
 from .message_passing import predict_step, condition
 from .numpy_numthreads import numpy_num_threads
 
-from .mne_util import ROIToSourceMap, _scale_sensor_data
+from .mne_util import ROIToSourceMap, _scale_sensor_data, run_pca_on_subject
 
 try:
     from autograd_linalg import solve_triangular
@@ -44,14 +40,14 @@ class _MEGModel(object):
 
     def __init__(self):
self._subjectdata = None - self._T = 0 + self._timepts = 0 self._ntrials_all = 0 self._nsubjects = 0 def set_data(self, subjectdata): - T_lst = [self.unpack_subject_data(e)[0].shape[1] for e in subjectdata] - assert len(list(set(T_lst))) == 1 - self._T = T_lst[0] + timepts_lst = [self.unpack_subject_data(e)[0].shape[1] for e in subjectdata] + assert len(list(set(timepts_lst))) == 1 + self._timepts = timepts_lst[0] ntrials_lst = [self.unpack_subject_data(e)[0].shape[0] for e in \ subjectdata] self._ntrials_all = np.sum(ntrials_lst) @@ -65,7 +61,7 @@ def unpack_all_subject_data(self): @classmethod def unpack_subject_data(cls, sdata): - obs, GL, G, Q_snsr, Q_J = sdata + obs, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata Y = obs w_s = 1. if isinstance(obs, tuple): @@ -77,7 +73,7 @@ def unpack_subject_data(cls, sdata): Y = obs w_s = 1. - return Y, w_s, GL, G, Q_snsr, Q_J + return Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi # TODO: add documentation to all methods @@ -87,9 +83,9 @@ class MEGLDS(_MEGModel): cross-region dynamic connectivity in MEG/EEG", Yang et al., NIPS 2016. """ - def __init__(self, D_roi, T, fwd, labels, noise_cov, A=None, Q=None, - mu0=None, Q0=None, log_sigsq_lst=None, lam0=0., lam1=0., - penalty='ridge', store_St=True): + def __init__(self, num_roi, timepts, A=None, roi_cov=None, mu0=None, roi_cov_0=None, + log_sigsq_lst=None, lam0=0., lam1=0., penalty='ridge', + store_St=True): super().__init__() @@ -99,13 +95,13 @@ def __init__(self, D_roi, T, fwd, labels, noise_cov, A=None, Q=None, # initialize parameters set_default("A", A, - np.stack([rand_stable(D_roi, maxew=0.7) for _ in range(T)], + np.stack([rand_stable(num_roi, maxew=0.7) for _ in range(timepts)], axis=0)) - set_default("Q", Q, rand_psd(D_roi)) - set_default("mu0", mu0, np.zeros(D_roi)) - set_default("Q0", Q0, rand_psd(D_roi)) + set_default("roi_cov", roi_cov, rand_psd(num_roi)) + set_default("mu0", mu0, np.zeros(num_roi)) + set_default("roi_cov_0", roi_cov_0, rand_psd(num_roi)) set_default("log_sigsq_lst", log_sigsq_lst, - [np.log(np.random.gamma(2, 1, size=D_roi+1))]) + [np.log(np.random.gamma(2, 1, size=num_roi+1))]) self.lam0 = lam0 self.lam1 = lam1 @@ -123,26 +119,47 @@ def __init__(self, D_roi, T, fwd, labels, noise_cov, A=None, Q=None, self._store_St = bool(store_St) # initialize sufficient statistics - T, D, _ = self.A.shape - self._B0 = np.zeros((D, D)) - self._B1 = np.zeros((T-1, D, D)) - self._B3 = np.zeros((T-1, D, D)) - self._B2 = np.zeros((T-1, D, D)) + timepts, num_roi, _ = self.A.shape + self._B0 = np.zeros((num_roi, num_roi)) + self._B1 = np.zeros((timepts-1, num_roi, num_roi)) + self._B3 = np.zeros((timepts-1, num_roi, num_roi)) + self._B2 = np.zeros((timepts-1, num_roi, num_roi)) self._B4 = list() - #will these pass to other functions? 
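+        # forward, labels, and noise covariance are no longer stored on the
+        # model; add_subject() below consumes them per subject and registers
+        # the resulting (data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi)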
- self.labels = labels - self.fwd = fwd - self.noise_cov = noise_cov + self._subject_data = dict() + + def add_subject(self, subject, subject_dir, epochs,labels, fwd, cov): + roi_to_src = ROIToSourceMap(fwd, labels) # compute ROI-to-source map + scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1} #scale=1 has no effect + fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs = \ + _scale_sensor_data(epochs, fwd, cov, roi_to_src, **scales) + + # which_roi = roi_to_src.which_roi # array of len(sources); val = ROI of source + # data = epochs._data + # data = np.swapaxes(data,-1,-2) + # subjectdata = [(data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi)] + + sdata = run_pca_on_subject(subject,epochs, fwd, cov, labels) + data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + subjectdata = [(data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi)] + self.set_data(subjectdata) + # epochs, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = subjectdata + self._subject_data[subject] = dict() + self._subject_data[subject]['epochs'] = epochs + self._subject_data[subject]['fwd_src_snsr'] = fwd_src_snsr + self._subject_data[subject]['fwd_roi_snsr'] = fwd_roi_snsr + self._subject_data[subject]['snsr_cov'] = snsr_cov + self._subject_data[subject]['labels'] = labels + self._subject_data[subject]['which_roi'] = which_roi def set_data(self, subjectdata): # add subject data, re-generate log_sigsq_lst if necessary super().set_data(subjectdata) if len(self.log_sigsq_lst) != self._nsubjects: - D_roi = self.log_sigsq_lst[0].shape[0] - self.log_sigsq_lst = [np.log(np.random.gamma(2, 1, size=D_roi)) + num_roi = self.log_sigsq_lst[0].shape[0] + self.log_sigsq_lst = [np.log(np.random.gamma(2, 1, size=num_roi)) for _ in range(self._nsubjects)] # reset smoothed estimates and log-likelihood (no longer valid if @@ -156,11 +173,11 @@ def set_data(self, subjectdata): # TODO: figure out how to initialize smoothed parameters so this doesn't # break, e.g. if em_objective is called before em for some reason def em_objective(self): + + _, num_roi, _ = self.A.shape - _, D, _ = self.A.shape - - L_Q0 = np.linalg.cholesky(self.Q0) - L_Q = np.linalg.cholesky(self.Q) + L_roi_cov_0 = np.linalg.cholesky(self.roi_cov_0) + L_roi_cov = np.linalg.cholesky(self.roi_cov) L1 = 0. L2 = 0. @@ -169,50 +186,50 @@ def em_objective(self): obj = 0. 
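         # the objective decomposes as obj = (L1 + L2 + L3) / ntrials + penalty:
         #   L1: initial-state term under roi_cov_0
         #   L2: state-transition term under (A_t, roi_cov)
         #   L3: observation term under R built from snsr_cov and sigsq_vals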
for s, sdata in enumerate(self.unpack_all_subject_data()): - Y, w_s, GL, G, Q_snsr, Q_J = sdata + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - ntrials, T, _ = Y.shape + ntrials, timepts, _ = Y.shape sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(Q_snsr, G, sigsq_vals, Q_J) + R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) L_R = np.linalg.cholesky(R) if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None or self._sigmas_tnt_smooth_lst is None): - Qt = _ensure_ndim(self.Q, T, 3) + roi_cov_t = _ensure_ndim(self.roi_cov, timepts, 3) with numpy_num_threads(1): _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ - rts_smooth_fast(Y, self.A, GL, Qt, R, self.mu0, - self.Q0, compute_lag1_cov=True) + rts_smooth_fast(Y, self.A, fwd_roi_snsr, roi_cov_t, R, self.mu0, + self.roi_cov_0, compute_lag1_cov=True) else: mus_smooth = self._mus_smooth_lst[s] sigmas_smooth = self._sigmas_smooth_lst[s] sigmas_tnt_smooth = self._sigmas_tnt_smooth_lst[s] - x_smooth_0_outer = einsum2('ri,rj->rij', mus_smooth[:,0,:D], - mus_smooth[:,0,:D]) - B0 = w_s*np.sum(sigmas_smooth[:,0,:D,:D] + x_smooth_0_outer, + x_smooth_0_outer = einsum2('ri,rj->rij', mus_smooth[:,0,:num_roi], + mus_smooth[:,0,:num_roi]) + B0 = w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + x_smooth_0_outer, axis=0) - x_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,1:,:D], - mus_smooth[:,1:,:D]) - B1 = w_s*np.sum(sigmas_smooth[:,1:,:D,:D] + x_smooth_outer, axis=0) + x_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], + mus_smooth[:,1:,:num_roi]) + B1 = w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + x_smooth_outer, axis=0) z_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,:-1,:], mus_smooth[:,:-1,:]) B3 = w_s*np.sum(sigmas_smooth[:,:-1,:,:] + z_smooth_outer, axis=0) mus_smooth_outer_l1 = einsum2('rti,rtj->rtij', - mus_smooth[:,1:,:D], + mus_smooth[:,1:,:num_roi], mus_smooth[:,:-1,:]) - B2 = w_s*np.sum(sigmas_tnt_smooth[:,:,:D,:] + mus_smooth_outer_l1, + B2 = w_s*np.sum(sigmas_tnt_smooth[:,:,:num_roi,:] + mus_smooth_outer_l1, axis=0) - # obj += L1(Q0) - L_Q0_inv_B0 = solve_triangular(L_Q0, B0, lower=True) - L1 += (ntrials*2.*np.sum(np.log(np.diag(L_Q0))) - + np.trace(solve_triangular(L_Q0, L_Q0_inv_B0, lower=True, + # obj += L1(roi_cov_0) + L_roi_cov_0_inv_B0 = solve_triangular(L_roi_cov_0, B0, lower=True) + L1 += (ntrials*2.*np.sum(np.log(np.diag(L_roi_cov_0))) + + np.trace(solve_triangular(L_roi_cov_0, L_roi_cov_0_inv_B0, lower=True, trans='T'))) At = self.A[:-1] @@ -223,25 +240,25 @@ def em_objective(self): tmp = np.sum(B1 - AtB2T - B2AtT + AtB3AtT, axis=0) - # obj += L2(Q, At) - L_Q_inv_tmp = solve_triangular(L_Q, tmp, lower=True) - L2 += (ntrials*(T-1)*2.*np.sum(np.log(np.diag(L_Q))) - + np.trace(solve_triangular(L_Q, L_Q_inv_tmp, lower=True, + # obj += L2(roi_cov, At) + L_roi_cov_inv_tmp = solve_triangular(L_roi_cov, tmp, lower=True) + L2 += (ntrials*(timepts-1)*2.*np.sum(np.log(np.diag(L_roi_cov))) + + np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_tmp, lower=True, trans='T'))) - res = Y - einsum2('ik,ntk->nti', GL, mus_smooth[:,:,:D]) - CP_smooth = einsum2('ik,ntkj->ntij', GL, sigmas_smooth[:,:,:D,:D]) + res = Y - einsum2('ik,ntk->nti', fwd_roi_snsr, mus_smooth[:,:,:num_roi]) + CP_smooth = einsum2('ik,ntkj->ntij', fwd_roi_snsr, sigmas_smooth[:,:,:num_roi,:num_roi]) # TODO: np.sum does not parallelize over the accumulators, possible # bottleneck. 
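             # B4 is the expected observation residual
             #   sum_{n,t} E[(y_t - C x_t)(y_t - C x_t)^T]
             #   = sum_{n,t} (res res^T + C P_t C^T),  with C = fwd_roi_snsr
             # which L3 evaluates against R via its Cholesky factor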
B4 = w_s*(np.sum(einsum2('nti,ntj->ntij', res, res), axis=(0,1)) - + np.sum(einsum2('ntik,jk->ntij', CP_smooth, GL), + + np.sum(einsum2('ntik,jk->ntij', CP_smooth, fwd_roi_snsr), axis=(0,1))) self._B4[s] = B4 # obj += L3(sigsq_vals) L_R_inv_B4 = solve_triangular(L_R, B4, lower=True) - L3 += (ntrials*T*2*np.sum(np.log(np.diag(L_R))) + L3 += (ntrials*timepts*2*np.sum(np.log(np.diag(L_R))) + np.trace(solve_triangular(L_R, L_R_inv_B4, lower=True, trans='T'))) @@ -266,58 +283,45 @@ def em_objective(self): return obj - def em(self, epochs, niter=100, tol=1e-6, A_Q_niter=100, A_Q_tol=1e-6, verbose=0, - update_A=True, update_Q=True, update_Q0=True, stationary_A=False, - diag_Q=False, update_sigsq=True, do_final_smoothing=True, + def fit(self, niter=100, tol=1e-6, A_roi_cov_niter=100, A_roi_cov_tol=1e-6, verbose=0, + update_A=True, update_roi_cov=True, update_roi_cov_0=True, stationary_A=False, + diag_roi_cov=False, update_sigsq=True, do_final_smoothing=True, average_mus_smooth=True, Atrue=None, tau=0.1, c1=1e-4): fxn_start = datetime.now() - - - #compute roi to source map and scale the sensor data - fwd = self.fwd - labels = self.labels - noise_cov = self.noise_cov - roi_to_src = ROIToSourceMap(fwd, labels) # compute ROI-to-source map - scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1} - fwd_src_snsr, fwd_roi_snsr, snsr_src_cov, epochs = \ - _scale_sensor_data(epochs, fwd, noise_cov, roi_to_src, **scales) - - - - - T, D, _ = self.A.shape + + timepts, num_roi, _ = self.A.shape # make initial A stationary if stationary_A option specified if stationary_A: self.A[:] = np.mean(self.A, axis=0) - # set parameters for (A, Q) optimization - self._A_Q_niter = A_Q_niter - self._A_Q_tol = A_Q_tol + # set parameters for (A, roi_cov) optimization + self._A_roi_cov_niter = A_roi_cov_niter + self._A_roi_cov_tol = A_roi_cov_tol - # make initial Q, Q0 diagonal if diag_Q specified - if diag_Q: - self.Q0 = np.diag(np.diag(self.Q0)) - self.Q = np.diag(np.diag(self.Q)) + # make initial roi_cov, roi_cov_0 diagonal if diag_roi_cov specified + if diag_roi_cov: + self.roi_cov_0 = np.diag(np.diag(self.roi_cov_0)) + self.roi_cov = np.diag(np.diag(self.roi_cov)) # keeping track of objective value and best parameters objvals = np.zeros(niter+1) converged = False best_objval = np.finfo('float').max - best_params = (self.A.copy(), self.Q.copy(), self.mu0.copy(), - self.Q0.copy(), [l.copy() for l in self.log_sigsq_lst]) + best_params = (self.A.copy(), self.roi_cov.copy(), self.mu0.copy(), + self.roi_cov_0.copy(), [l.copy() for l in self.log_sigsq_lst]) # previous parameter values (for checking convergence) At_prev = None - Q_prev = None - Q0_prev = None + roi_cov_prev = None + roi_cov_0_prev = None log_sigsq_lst_prev = None if Atrue is not None: import matplotlib.pyplot as plt - fig_A, ax_A = plt.subplots(D, D, sharex=True, sharey=True) + fig_A, ax_A = plt.subplots(num_roi, num_roi, sharex=True, sharey=True) plt.ion() # calculate initial objective value, check for updated best iterate @@ -340,18 +344,18 @@ def em(self, epochs, niter=100, tol=1e-6, A_Q_niter=100, A_Q_tol=1e-6, verbose=0 # record values from previous M-step At_prev = self.A[:-1].copy() - Q_prev = self.Q.copy() - Q0_prev = self.Q0.copy() + roi_cov_prev = self.roi_cov.copy() + roi_cov_0_prev = self.roi_cov_0.copy() log_sigsq_lst_prev = np.array(self.log_sigsq_lst).copy() - self.m_step(update_A=update_A, update_Q=update_Q, - update_Q0=update_Q0, stationary_A=stationary_A, - diag_Q=diag_Q, update_sigsq=update_sigsq, + self.m_step(update_A=update_A, 
update_roi_cov=update_roi_cov, + update_roi_cov_0=update_roi_cov_0, stationary_A=stationary_A, + diag_roi_cov=diag_roi_cov, update_sigsq=update_sigsq, tau=tau, c1=c1, verbose=verbose) if Atrue is not None: - for i in range(D): - for j in range(D): + for i in range(num_roi): + for j in range(num_roi): ax_A[i, j].cla() ax_A[i, j].plot(Atrue[:-1, i, j], color='green') ax_A[i, j].plot(self.A[:-1, i, j], color='red', @@ -376,31 +380,31 @@ def em(self, epochs, niter=100, tol=1e-6, A_Q_niter=100, A_Q_tol=1e-6, verbose=0 if objval < best_objval: best_objval = objval - best_params = (self.A.copy(), self.Q.copy(), self.mu0.copy(), - self.Q0.copy(), + best_params = (self.A.copy(), self.roi_cov.copy(), self.mu0.copy(), + self.roi_cov_0.copy(), [l.copy() for l in self.log_sigsq_lst]) # check for convergence if it >= 1: relnormdiff_At = relnormdiff(self.A[:-1], At_prev) - relnormdiff_Q = relnormdiff(self.Q, Q_prev) - relnormdiff_Q0 = relnormdiff(self.Q0, Q0_prev) + relnormdiff_roi_cov = relnormdiff(self.roi_cov, roi_cov_prev) + relnormdiff_roi_cov_0 = relnormdiff(self.roi_cov_0, roi_cov_0_prev) relnormdiff_log_sigsq_lst = \ np.array( [relnormdiff(self.log_sigsq_lst[s], log_sigsq_lst_prev[s]) for s in range(len(self.log_sigsq_lst))]) params_converged = (relnormdiff_At <= tol) and \ - (relnormdiff_Q <= tol) and \ - (relnormdiff_Q0 <= tol) and \ + (relnormdiff_roi_cov <= tol) and \ + (relnormdiff_roi_cov_0 <= tol) and \ np.all(relnormdiff_log_sigsq_lst <= tol) relobjdiff = np.abs((objval - objvals[it-1]) / objval) if verbose > 0: print(" relnormdiff_At: %.3e" % relnormdiff_At) - print(" relnormdiff_Q: %.3e" % relnormdiff_Q) - print(" relnormdiff_Q0: %.3e" % relnormdiff_Q0) + print(" relnormdiff_roi_cov: %.3e" % relnormdiff_roi_cov) + print(" relnormdiff_roi_cov_0: %.3e" % relnormdiff_roi_cov_0) print(" relnormdiff_log_sigsq_lst:", relnormdiff_log_sigsq_lst) print(" relobjdiff: %.3e" % relobjdiff) @@ -422,11 +426,11 @@ def em(self, epochs, niter=100, tol=1e-6, A_Q_niter=100, A_Q_tol=1e-6, verbose=0 break # retrieve best parameters and load into instance variables. - A, Q, mu0, Q0, log_sigsq_lst = best_params + A, roi_cov, mu0, roi_cov_0, log_sigsq_lst = best_params self.A = A.copy() - self.Q = Q.copy() + self.roi_cov = roi_cov.copy() self.mu0 = mu0.copy() - self.Q0 = Q0.copy() + self.roi_cov_0 = roi_cov_0.copy() self.log_sigsq_lst = [l.copy() for l in log_sigsq_lst] if verbose > 0: @@ -447,13 +451,13 @@ def em(self, epochs, niter=100, tol=1e-6, A_Q_niter=100, A_Q_tol=1e-6, verbose=0 if self._store_St: St_lst = list() for s, sdata in enumerate(self.unpack_all_subject_data()): - Y, w_s, GL, G, Q_snsr, Q_J = sdata + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(Q_snsr, G, sigsq_vals, Q_J) - Qt = _ensure_ndim(self.Q, self._T, 3) + R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) + roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) with numpy_num_threads(1): loglik_subject, mus_smooth, _, _, St = \ - rts_smooth(Y, self.A, GL, Qt, R, self.mu0, self.Q0, + rts_smooth(Y, self.A, fwd_roi_snsr, roi_cov_t, R, self.mu0, self.roi_cov_0, compute_lag1_cov=False, store_St=self._store_St) # just save the mean of the smoothed trials @@ -477,7 +481,7 @@ def em(self, epochs, niter=100, tol=1e-6, A_Q_niter=100, A_Q_tol=1e-6, verbose=0 def e_step(self, verbose=0): - T, D, _ = self.A.shape + timepts, num_roi, _ = self.A.shape # reset accumulation arrays self._B0[:] = 0. 
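         # B0..B3 are smoothed sufficient statistics accumulated over trials:
         # B0 ~ E[x_0 x_0'], B1 ~ sum_{t>=1} E[x_t x_t'],
         # B3 ~ sum_{t<T} E[x_t x_t'], B2 ~ sum_t E[x_{t+1} x_t']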
@@ -500,30 +504,30 @@ def e_step(self, verbose=0): sys.stdout.flush() sys.stderr.flush() - Y, w_s, GL, G, Q_snsr, Q_J = sdata + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(Q_snsr, G, sigsq_vals, Q_J) + R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) L_R = np.linalg.cholesky(R) - Qt = _ensure_ndim(self.Q, self._T, 3) + roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) with numpy_num_threads(1): _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ - rts_smooth_fast(Y, self.A, GL, Qt, R, self.mu0, - self.Q0, compute_lag1_cov=True) + rts_smooth_fast(Y, self.A, fwd_roi_snsr, roi_cov_t, R, self.mu0, + self.roi_cov_0, compute_lag1_cov=True) self._mus_smooth_lst.append(mus_smooth) self._sigmas_smooth_lst.append(sigmas_smooth) self._sigmas_tnt_smooth_lst.append(sigmas_tnt_smooth) - x_smooth_0_outer = einsum2('ri,rj->rij', mus_smooth[:,0,:D], - mus_smooth[:,0,:D]) - self._B0 += w_s*np.sum(sigmas_smooth[:,0,:D,:D] + x_smooth_0_outer, + x_smooth_0_outer = einsum2('ri,rj->rij', mus_smooth[:,0,:num_roi], + mus_smooth[:,0,:num_roi]) + self._B0 += w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + x_smooth_0_outer, axis=0) - x_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,1:,:D], - mus_smooth[:,1:,:D]) - self._B1 += w_s*np.sum(sigmas_smooth[:,1:,:D,:D] + x_smooth_outer, + x_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], + mus_smooth[:,1:,:num_roi]) + self._B1 += w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + x_smooth_outer, axis=0) z_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,:-1,:], @@ -532,45 +536,45 @@ def e_step(self, verbose=0): axis=0) mus_smooth_outer_l1 = einsum2('rti,rtj->rtij', - mus_smooth[:,1:,:D], + mus_smooth[:,1:,:num_roi], mus_smooth[:,:-1,:]) - self._B2 += w_s*np.sum(sigmas_tnt_smooth[:,:,:D,:] + + self._B2 += w_s*np.sum(sigmas_tnt_smooth[:,:,:num_roi,:] + mus_smooth_outer_l1, axis=0) if verbose > 0: print("\n done") - def m_step(self, update_A=True, update_Q=True, update_Q0=True, - stationary_A=False, diag_Q=False, update_sigsq=True, tau=0.1, c1=1e-4, + def m_step(self, update_A=True, update_roi_cov=True, update_roi_cov_0=True, + stationary_A=False, diag_roi_cov=False, update_sigsq=True, tau=0.1, c1=1e-4, verbose=0): self._loglik = None if verbose > 0: print(" m-step") - if update_Q0: - self.Q0 = (1. / self._ntrials_all) * self._B0 - if diag_Q: - self.Q0 = np.diag(np.diag(self.Q0)) - self.update_A_and_Q(update_A=update_A, update_Q=update_Q, - stationary_A=stationary_A, diag_Q=diag_Q, + if update_roi_cov_0: + self.roi_cov_0 = (1. 
/ self._ntrials_all) * self._B0 + if diag_roi_cov: + self.roi_cov_0 = np.diag(np.diag(self.roi_cov_0)) + self.update_A_and_roi_cov(update_A=update_A, update_roi_cov=update_roi_cov, + stationary_A=stationary_A, diag_roi_cov=diag_roi_cov, tau=tau, c1=c1, verbose=verbose) if update_sigsq: self.update_log_sigsq_lst(verbose=verbose) - def update_A_and_Q(self, update_A=True, update_Q=True, stationary_A=False, - diag_Q=False, tau=0.1, c1=1e-4, verbose=0): + def update_A_and_roi_cov(self, update_A=True, update_roi_cov=True, stationary_A=False, + diag_roi_cov=False, tau=0.1, c1=1e-4, verbose=0): if verbose > 1: - print(" update A and Q") + print(" update A and roi_cov") # gradient descent At = self.A[:-1] At_init = At.copy() - L_Q = np.linalg.cholesky(self.Q) - At_L_Q_obj = lambda x, y: self.L2_obj(x, y) - At_obj = lambda x: self.L2_obj(x, L_Q) + L_roi_cov = np.linalg.cholesky(self.roi_cov) + At_L_roi_cov_obj = lambda x, y: self.L2_obj(x, y) + At_obj = lambda x: self.L2_obj(x, L_roi_cov) grad_At_obj = grad(At_obj) obj_diff = np.finfo('float').max - obj = At_L_Q_obj(At, L_Q) + obj = At_L_roi_cov_obj(At, L_roi_cov) inner_it = 0 # specify proximal operator to use @@ -581,12 +585,12 @@ def update_A_and_Q(self, update_A=True, update_Q=True, stationary_A=False, elif self._penalty == 'group-lasso': prox_op = block_thresh_At - while np.abs(obj_diff / obj) > self._A_Q_tol: + while np.abs(obj_diff / obj) > self._A_roi_cov_tol: - if inner_it > self._A_Q_niter: + if inner_it > self._A_roi_cov_niter: break - obj_start = At_L_Q_obj(At, L_Q) + obj_start = At_L_roi_cov_obj(At, L_roi_cov) # update At using gradient descent with backtracking line search if update_A: @@ -602,19 +606,20 @@ def update_A_and_Q(self, update_A=True, update_Q=True, stationary_A=False, At[:] = prox_op(At - step_size * grad_At, self.lam0 * step_size) - # update Q using closed form - if update_Q: + # update roi_cov using closed form + if update_roi_cov: AtB2T = einsum2('tik,tjk->tij', At, self._B2) B2AtT = einsum2('tik,tjk->tij', self._B2, At) tmp = einsum2('tik,tkl->til', At, self._B3) AtB3AtT = einsum2('til,tjl->tij', tmp, At) elbo_2 = np.sum(self._B1 - AtB2T - B2AtT + AtB3AtT, axis=0) - self.Q = (1. / (self._ntrials_all * self._T)) * elbo_2 - if diag_Q: - self.Q = np.diag(np.diag(self.Q)) - L_Q = np.linalg.cholesky(self.Q) + self.roi_cov = (1. / (self._ntrials_all * self._timepts + )) * elbo_2 + if diag_roi_cov: + self.roi_cov = np.diag(np.diag(self.roi_cov)) + L_roi_cov = np.linalg.cholesky(self.roi_cov) - obj = At_L_Q_obj(At, L_Q) + obj = At_L_roi_cov_obj(At, L_roi_cov) obj_diff = obj_start - obj inner_it += 1 @@ -633,20 +638,20 @@ def update_log_sigsq_lst(self, verbose=0): if verbose > 1: print(" update subject log-sigmasq") - T, D, _ = self.A.shape + timepts, num_roi, _ = self.A.shape # update log_sigsq_vals for each subject and ROI for s, sdata in enumerate(self.unpack_all_subject_data()): - Y, w_s, GL, G, Q_snsr, Q_J = sdata - ntrials, T, _ = Y.shape + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + ntrials, timepts, _ = Y.shape mus_smooth = self._mus_smooth_lst[s] sigmas_smooth = self._sigmas_smooth_lst[s] B4 = self._B4[s] log_sigsq = self.log_sigsq_lst[s].copy() log_sigsq_obj = lambda x: \ - MEGLDS.L3_obj(x, Q_snsr, G, Q_J, B4, ntrials, T) + MEGLDS.L3_obj(x, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, timepts) log_sigsq_val_and_grad = vgrad(log_sigsq_obj) options = {'maxiter': 500} @@ -675,13 +680,13 @@ def calculate_smoothed_estimates(self): self._loglik = 0. 
for s, sdata in enumerate(self.unpack_all_subject_data()): - Y, w_s, GL, G, Q_snsr, Q_J = sdata + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(Q_snsr, G, sigsq_vals, Q_J) - Qt = _ensure_ndim(self.Q, self._T, 3) + R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) + roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) with numpy_num_threads(1): ll, mus_smooth, sigmas_smooth, sigmas_tnt_smooth, _ = \ - rts_smooth(Y, self.A, GL, Qt, R, self.mu0, self.Q0, + rts_smooth(Y, self.A, fwd_roi_snsr, roi_cov_t, R, self.mu0, self.roi_cov_0, compute_lag1_cov=True, store_St=False) self._mus_smooth_lst.append(mus_smooth) self._sigmas_smooth_lst.append(sigmas_smooth) @@ -701,23 +706,23 @@ def log_likelihood(self): self._loglik = 0. for s, sdata in enumerate(self.unpack_all_subject_data()): - Y, w_s, GL, G, Q_snsr, Q_J = sdata + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(Q_snsr, G, sigsq_vals, Q_J) - Qt = _ensure_ndim(self.Q, self._T, 3) - ll, _, _, _ = kalman_filter(Y, self.A, GL, Qt, R, self.mu0, - self.Q0, store_St=False) + R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) + roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) + ll, _, _, _ = kalman_filter(Y, self.A, fwd_roi_snsr, roi_cov_t, R, self.mu0, + self.roi_cov_0, store_St=False) self._loglik += ll return self._loglik def nparams(self): - T, p, _ = self.A.shape + timepts, p, _ = self.A.shape - # this should equal (T-1)*p*p unless some shrinkage is used on At + # this should equal (timepts-1)*p*p unless some shrinkage is used on At nparams_At = np.sum(np.abs(self.A[:-1]) > 0) - # nparams = nparams(At) + nparams(Q) + nparams(Q0) + # nparams = nparams(At) + nparams(roi_cov) + nparams(roi_cov_0) # + nparams(log_sigsq_lst) return nparams_At + p*(p+1)/2 + p*(p+1)/2 \ + np.sum([p+1 for _ in range(len(self.log_sigsq_lst))]) @@ -733,15 +738,15 @@ def BIC(self): + np.log(self._ntrials_all)*self.nparams() def save(self, filename, **kwargs): - savedict = { 'A' : self.A, 'Q' : self.Q, 'mu0' : self.mu0, - 'Q0' : self.Q0, 'log_sigsq_lst' : self.log_sigsq_lst, + savedict = { 'A' : self.A, 'roi_cov' : self.roi_cov, 'mu0' : self.mu0, + 'roi_cov_0' : self.roi_cov_0, 'log_sigsq_lst' : self.log_sigsq_lst, 'lam0' : self.lam0, 'lam1' : self.lam1} savedict.update(kwargs) np.savez_compressed(filename, **savedict) def load(self, filename): loaddict = np.load(filename) - param_names = ['A', 'Q', 'mu0', 'Q0', 'log_sigsq_lst', 'lam0', 'lam1'] + param_names = ['A', 'roi_cov', 'mu0', 'roi_cov_0', 'log_sigsq_lst', 'lam0', 'lam1'] for name in param_names: if name not in loaddict.keys(): raise RuntimeError('specified file is not a saved model:\n%s' @@ -761,10 +766,10 @@ def load(self, filename): return others @staticmethod - def R_(Q_snsr, G, sigsq_vals, Q_J): - return Q_snsr + np.dot(G, sigsq_vals[Q_J][:,None]*G.T) + def R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi): + return snsr_cov + np.dot(fwd_src_snsr, sigsq_vals[which_roi][:,None]*fwd_src_snsr.T) - def L2_obj(self, At, L_Q): + def L2_obj(self, At, L_roi_cov): # import autograd.numpy # if isinstance(At,autograd.numpy.numpy_boxes.ArrayBox): @@ -776,8 +781,8 @@ def L2_obj(self, At, L_Q): AtB3AtT = einsum2('til,tjl->tij', tmp, At) elbo_2 = np.sum(self._B1 - AtB2T - B2AtT + AtB3AtT, axis=0) - L_Q_inv_elbo_2 = solve_triangular(L_Q, elbo_2, lower=True) - obj = np.trace(solve_triangular(L_Q, L_Q_inv_elbo_2, lower=True, + L_roi_cov_inv_elbo_2 = 
solve_triangular(L_roi_cov, elbo_2, lower=True) + obj = np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_elbo_2, lower=True, trans='T')) obj = obj / self._ntrials_all @@ -790,14 +795,14 @@ def L2_obj(self, At, L_Q): # TODO: convert to instance method @staticmethod - def L3_obj(log_sigsq_vals, Q_snsr, G, Q_J, B4, ntrials, T): - R = MEGLDS.R_(Q_snsr, G, np.exp(log_sigsq_vals), Q_J) + def L3_obj(log_sigsq_vals, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, timepts): + R = MEGLDS.R_(snsr_cov, fwd_src_snsr, np.exp(log_sigsq_vals), which_roi) try: L_R = np.linalg.cholesky(R) except LinAlgError: return np.finfo('float').max L_R_inv_B4 = solve_triangular(L_R, B4, lower=True) - return (ntrials*T*2.*np.sum(np.log(np.diag(L_R))) + return (ntrials*timepts*2.*np.sum(np.log(np.diag(L_R))) + np.trace(solve_triangular(L_R, L_R_inv_B4, lower=True, trans='T'))) @@ -811,12 +816,12 @@ def A(self, A): self._A = A @property - def Q(self): - return self._Q + def roi_cov(self): + return self._roi_cov - @Q.setter - def Q(self, Q): - self._Q = Q + @roi_cov.setter + def roi_cov(self, roi_cov): + self._roi_cov = roi_cov @property def mu0(self): @@ -827,12 +832,12 @@ def mu0(self, mu0): self._mu0 = mu0 @property - def Q0(self): - return self._Q0 + def roi_cov_0(self): + return self._roi_cov_0 - @Q0.setter - def Q0(self, Q0): - self._Q0 = Q0 + @roi_cov_0.setter + def roi_cov_0(self, roi_cov_0): + self._roi_cov_0 = roi_cov_0 @property def log_sigsq_lst(self): @@ -843,12 +848,12 @@ def log_sigsq_lst(self, log_sigsq_lst): self._log_sigsq_lst = log_sigsq_lst @property - def D_roi(self): + def num_roi(self): return self.A.shape[1] @property - def T(self): - return self._T + def timepts(self): + return self._timepts @property def lam0(self): diff --git a/examples/megssm/plotting.py b/examples/megssm/plotting.py new file mode 100644 index 00000000..4e435de6 --- /dev/null +++ b/examples/megssm/plotting.py @@ -0,0 +1,107 @@ +""" plotting functions """ + +import numpy as np +import matplotlib.pyplot as plt + +def plot_At(A, ci='sd', times=None, ax=None, skipdiag=False, labels=None, + showticks=True, **kwargs): + """ plot traces of each entry of dynamics A in square grid of subplots """ + if A.ndim == 3: + T, d, _ = A.shape + elif A.ndim == 4: + _, T, d, _ = A.shape + + if times is None: + times = np.arange(T) + + if ax is None or ax.shape != (d, d): + fig, ax = plt.subplots(d, d, sharex=True, sharey=True, squeeze=False) + else: + fig = ax[0, 0].figure + + for i in range(d): + for j in range(d): + + # skip and hide subplots on diagonal + if skipdiag and i == j: + ax[i, j].set_visible(False) + continue + + # plot A entry as trace with/without error band + if A.ndim == 3: + ax[i, j].plot(times[:-1], A[:-1, i, j], **kwargs) + elif A.ndim == 4: + plot_fill(A[:, :-1, i, j], ci=ci, times=times[:-1], + ax=ax[i, j], **kwargs) + + # add labels above first row and to the left of the first column + if labels is not None: + if i == 0 or (skipdiag and (i, j) == (1, 0)): + ax[i, j].set_title(labels[j], fontsize=12) + if j == 0 or (skipdiag and (i, j) == (0, 1)): + ax[i, j].set_ylabel(labels[i], fontsize=12) + + # remove x- and y-ticks on subplot + if not showticks: + ax[i, j].set_xticks([]) + ax[i, j].set_yticks([]) + + diag_lims = [0, 1] + off_lims = [-0.25, 0.25] + for ri, row in enumerate(ax): + for ci, a in enumerate(row): + ylim = diag_lims if ri == ci else off_lims + a.set(ylim=ylim, xlim=times[[0, -1]]) + if ri == 0: + a.set_title(a.get_title(), fontsize='small') + if ci == 0: + a.set_ylabel(a.get_ylabel(), fontsize='small') + for 
line in a.lines: + line.set_clip_on(False) + line.set(lw=1.) + if ci != 0: + a.yaxis.set_major_formatter(plt.NullFormatter()) + if ri != len(labels) - 1: + a.xaxis.set_major_formatter(plt.NullFormatter()) + if ri == ci: + for spine in a.spines.values(): + spine.set(lw=2) + else: + a.axhline(0, color='k', ls=':', lw=1.) + + return fig, ax + +def plot_fill(X, times=None, ax=None, ci='sd', **kwargs): + """ plot mean and error band across first axis of X """ + N, T = X.shape + + if times is None: + times = np.arange(T) + if ax is None: + fig, ax = plt.subplots(1, 1) + + mu = np.mean(X, axis=0) + + # define lower and upper band limits based on ci + if ci == 'sd': # standard deviation + sigma = np.std(X, axis=0) + lower, upper = mu - sigma, mu + sigma + elif ci == 'se': # standard error + stderr = np.std(X, axis=0) / np.sqrt(X.shape[0]) + lower, upper = mu - stderr, mu + stderr + elif ci == '2sd': # 2 standard deviations + sigma = np.std(X, axis=0) + lower, upper = mu - 2 * sigma, mu + 2 * sigma + elif ci == 'max': # range (min to max) + lower, upper = np.min(X, axis=0), np.max(X, axis=0) + elif type(ci) is float and 0 < ci < 1: + # quantile-based confidence interval + a = 1 - ci + lower, upper = np.quantile(X, [a / 2, 1 - a / 2], axis=0) + else: + raise ValueError("ci must be in ('sd', 'se', '2sd', 'max') " + "or float in (0, 1)") + + lines = ax.plot(times, mu, **kwargs) + c = lines[0].get_color() + ax.fill_between(times, lower, upper, color=c, alpha=0.3, lw=0) diff --git a/examples/state_space_connectivity.py b/examples/state_space_connectivity.py index dc4497b4..79d13b5f 100644 --- a/examples/state_space_connectivity.py +++ b/examples/state_space_connectivity.py @@ -12,13 +12,11 @@ ## import necessary libraries import mne -# import numpy as np -# import matplotlib.pyplot as plt -# from scipy.sparse import csr_matrix +import matplotlib.pyplot as plt #where should these files live within mne-connectivity repo? -# from megssm.mne_util import ROIToSourceMap, _scale_sensor_data #mne_util is from MEGLDS repo from megssm.models import MEGLDS as LDS +from megssm.plotting import plot_At ## define paths to sample data path = None @@ -28,7 +26,7 @@ subjects_dir = data_path / 'subjects' ## import raw data and find events -raw_fname = sample_folder + '/sample_audvis_raw.fif' +raw_fname = sample_folder / 'sample_audvis_raw.fif' raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60) events = mne.find_events(raw, stim_channel='STI 014') @@ -37,20 +35,21 @@ 'visual/right': 4, 'face': 5, 'buttonpress': 32} epochs = mne.Epochs(raw, events, tmin=-0.3, tmax=0.7, event_id=event_dict, preload=True).pick_types(meg=True,eeg=True,exclude='bads') +epochs = epochs['auditory/left'] # choose condition for analysis ## read forward solution, remove bad channels -fwd_fname = sample_folder + '/sample_audvis-meg-eeg-oct-6-fwd.fif' +fwd_fname = sample_folder / 'sample_audvis-meg-eeg-oct-6-fwd.fif' fwd = mne.read_forward_solution(fwd_fname,exclude=raw.info['bads']) fwd = mne.convert_forward_solution(fwd, force_fixed=True) ## read in covariance OR compute noise covariance? noise_cov drops bad chs -cov_fname = sample_folder + '/sample_audvis-cov.fif' +cov_fname = sample_folder / 'sample_audvis-cov.fif' cov = mne.read_cov(cov_fname) #has all 366 channels; drop 2? 
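The script's own comments ask when to read a precomputed covariance versus computing one. A short sketch of the two options side by side, using the `cov_fname` and `epochs` objects defined earlier in the script (standard MNE-Python calls, not new code from the patch); the script itself proceeds with the baseline estimate on the next line:

import mne  # assumes `cov_fname` and `epochs` from earlier in this script

cov = mne.read_cov(cov_fname)                       # precomputed covariance shipped with the sample data
noise_cov = mne.compute_covariance(epochs, tmax=0)  # estimated from the pre-stimulus baseline

Estimating from `epochs` also keeps the covariance consistent with the channel selection, which is why the comment above notes that `noise_cov` drops bad channels.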
noise_cov = mne.compute_covariance(epochs, tmax=0) ## read labels for analysis label_names = ['AUD-lh', 'AUD-rh', 'Vis-lh', 'Vis-rh'] -labels = [mne.read_label(sample_folder + 'labels/' + f'{label}.label') +labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label') for label in label_names] ## initiate model @@ -60,12 +59,16 @@ model.add_subject('sample', subjects_dir, epochs, labels, fwd, noise_cov) #when to use compute_cov vs read_cov? -#should remove epochs +model.fit(niter=100, verbose=1) +At = model.A +assert At.shape == (timepts, num_rois, num_rois) - -model.em(niter=2) - +plt.rcParams.update( + {'xtick.labelsize': 'x-small', 'ytick.labelsize': 'x-small'}) +fig, ax = plt.subplots(num_rois, num_rois, constrained_layout=True, squeeze=False, + figsize=(12, 10)) +plot_At(At, labels=label_names, times=epochs.times, ax=ax) From aca431c773272d1f95f87465a8965d04ec0ebbeb Mon Sep 17 00:00:00 2001 From: jadrew43 Date: Thu, 21 Jul 2022 17:02:23 -0700 Subject: [PATCH 08/17] files moved to separate folder, descriptive variable names used, mne functions for scaling data --- state_space/megssm/__init__.py | 0 state_space/megssm/label_util.py | 449 ++++++++++ state_space/megssm/message_passing.py | 732 ++++++++++++++++ state_space/megssm/mne_util.py | 305 +++++++ state_space/megssm/models.py | 867 +++++++++++++++++++ state_space/megssm/numpy_numthreads.py | 91 ++ state_space/megssm/plotting.py | 107 +++ state_space/megssm/util.py | 117 +++ state_space/state_space_connectivity.py | 88 ++ state_space/state_space_connectivity_test.py | 170 ++++ 10 files changed, 2926 insertions(+) create mode 100755 state_space/megssm/__init__.py create mode 100644 state_space/megssm/label_util.py create mode 100755 state_space/megssm/message_passing.py create mode 100644 state_space/megssm/mne_util.py create mode 100755 state_space/megssm/models.py create mode 100755 state_space/megssm/numpy_numthreads.py create mode 100644 state_space/megssm/plotting.py create mode 100755 state_space/megssm/util.py create mode 100644 state_space/state_space_connectivity.py create mode 100644 state_space/state_space_connectivity_test.py diff --git a/state_space/megssm/__init__.py b/state_space/megssm/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/state_space/megssm/label_util.py b/state_space/megssm/label_util.py new file mode 100644 index 00000000..c56484ab --- /dev/null +++ b/state_space/megssm/label_util.py @@ -0,0 +1,449 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import mne +import numpy as np +import os + +from megssm.mne_util import combine_medial_labels + +subjects_dir = mne.utils.get_subjects_dir() +rtpj_modes = ('hcp', 'labsn', 'intersect') + +label_shortnames = {'Early Auditory Cortex-lh': 'AUD-lh', + 'Early Auditory Cortex-rh': 'AUD-rh', + 'Premotor Cortex-lh': 'FEF-lh', + 'Premotor Cortex-rh': 'FEF-rh', + 'lh.IPS-labsn-lh': 'IPS-lh', + 'rh.IPS-labsn-rh': 'IPS-rh', + 'lh.LIPSP-lh': 'LIPSP', + 'rh.RTPJ-rh': 'RTPJ', + 'rh.RTPJIntersect-rh-rh': 'RTPJ-intersect', + 'Primary Visual Cortex (V1)-lh + Primary Visual Cortex (V1)-rh + Early Visual Cortex-lh + Early Visual Cortex-rh': 'Vis', + 'Anterior Cingulate and Medial Prefrontal Cortex-lh + Anterior Cingulate and Medial Prefrontal Cortex-rh': 'ACC', + 'DorsoLateral Prefrontal Cortex-lh': 'DLPFC-lh', + 'DorsoLateral Prefrontal Cortex-rh': 'DLPFC-rh', + 'Temporo-Parieto-Occipital Junction-lh': 'TPOJ-lh', + 'Temporo-Parieto-Occipital Junction-rh': 'TPOJ-rh' + } 
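A minimal usage sketch for the `label_shortnames` mapping above, with a hypothetical helper that is not part of the patch: look up a short display name and fall back to the full HCPMMP1 name when no entry exists.

def shorten(label_name):
    # hypothetical helper, not in the patch; relies on label_shortnames above
    return label_shortnames.get(label_name, label_name)

print(shorten('Early Auditory Cortex-lh'))  # -> 'AUD-lh'
print(shorten('Some Unlisted Region-lh'))   # -> unchanged full name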
+ + +def _sps_meglds_base(): + hcp_mmp1_labels = mne.read_labels_from_annot('fsaverage', + parc='HCPMMP1_combined') + hcp_mmp1_labels = combine_medial_labels(hcp_mmp1_labels) + label_names = [l.name for l in hcp_mmp1_labels] + + ips_str = os.path.join(subjects_dir, "fsaverage/label/*.IPS-labsn.label") + ips_fnames = glob.glob(ips_str) + assert len(ips_fnames) == 2, ips_fnames + ips_labels = [mne.read_label(fn, subject='fsaverage') for fn in ips_fnames] + + pmc_labs = [l for l in hcp_mmp1_labels if 'Premotor Cortex' in l.name] + eac_labs = [l for l in hcp_mmp1_labels if 'Early Auditory Cortex' in l.name] + + labels = list() + labels.extend(pmc_labs) + labels.extend(eac_labs) + labels.extend(ips_labels) + + rtpj_str = os.path.join(subjects_dir, 'fsaverage/label/rh.RTPJ.label') + rtpj = mne.read_label(rtpj_str, subject='fsaverage') + labels.append(rtpj) + + lipsp_str = os.path.join(subjects_dir, 'fsaverage/label/lh.LIPSP.label') + lipsp = mne.read_label(lipsp_str, subject='fsaverage') + labels.append(lipsp) + + return sorted(labels, key=lambda x: x.name), hcp_mmp1_labels + + +def sps_meglds_base(): + return _sps_meglds_base()[0] + + +def _sps_meglds_base_vision(): + + labels, hcp_mmp1_labels = _sps_meglds_base() + + prim_visual = [l for l in hcp_mmp1_labels if 'Primary Visual Cortex' in l.name] + + # there should be only one b/c of medial merge + prim_visual = prim_visual[0] + + label_names = [l.name for l in hcp_mmp1_labels] + early_visual_lh = label_names.index("Early Visual Cortex-lh") + early_visual_rh = label_names.index("Early Visual Cortex-rh") + early_visual_lh = hcp_mmp1_labels[early_visual_lh] + early_visual_rh = hcp_mmp1_labels[early_visual_rh] + visual = prim_visual + early_visual_lh + early_visual_rh + + labels.append(visual) + + return sorted(labels, key=lambda x: x.name), hcp_mmp1_labels + + +def sps_meglds_base_vision(): + return _sps_meglds_base_vision()[0] + + +def sps_meglds_base_vision_extra(): + + labels, hcp_mmp1_labels = _sps_meglds_base_vision() + + # glasser 19 + ac_mpc_labs = [l for l in hcp_mmp1_labels if 'Anterior Cingulate and Medial Prefrontal Cortex' in l.name] + labels.extend(ac_mpc_labs) + + # glasser 22 + dpc_labs = [l for l in hcp_mmp1_labels if 'DorsoLateral Prefrontal Cortex' in l.name] + labels.extend(dpc_labs) + + return sorted(labels, key=lambda x: x.name) + + +def sps_meglds_base_extra(): + labels, hcp_mmp1_labels = _sps_meglds_base() + + # glasser 19 + ac_mpc_labs = [l for l in hcp_mmp1_labels if 'Anterior Cingulate and Medial Prefrontal Cortex' in l.name] + labels.extend(ac_mpc_labs) + + # glasser 22 + dpc_labs = [l for l in hcp_mmp1_labels if 'DorsoLateral Prefrontal Cortex' in l.name] + labels.extend(dpc_labs) + + + return sorted(labels, key=lambda x: x.name) + + + +def load_labsn_7_labels(): + label_str = os.path.join(subjects_dir, "fsaverage/label/*labsn*") + rtpj_str = os.path.join(subjects_dir, 'fsaverage/label/rh.RTPJ.label') + label_fnames = glob.glob(label_str) + assert len(label_fnames) == 6 + label_fnames.insert(0, rtpj_str) + labels = [mne.read_label(fn, subject='fsaverage') for fn in label_fnames] + labels = sorted(labels, key=lambda x: x.name) + + return labels + + +def load_hcpmmp1_combined(): + + labels = mne.read_labels_from_annot('fsaverage', parc='HCPMMP1_combined') + labels = sorted(labels, key=lambda x: x.name) + labels = combine_medial_labels(labels) + + return labels + + +def load_labsn_hcpmmp1_7_labels(include_visual=False, rtpj_mode='intersect'): + + if rtpj_mode not in rtpj_modes: + raise ValueError("rtpj must be one of", 
rtpj_modes) + + hcp_mmp1_labels = mne.read_labels_from_annot('fsaverage', + parc='HCPMMP1_combined') + hcp_mmp1_labels = combine_medial_labels(hcp_mmp1_labels) + label_names = [l.name for l in hcp_mmp1_labels] + + ips_str = os.path.join(subjects_dir, "fsaverage/label/*.IPS-labsn.label") + ips_fnames = glob.glob(ips_str) + assert len(ips_fnames) == 2 + ips_labels = [mne.read_label(fn, subject='fsaverage') for fn in ips_fnames] + + pmc_labs = [l for l in hcp_mmp1_labels if 'Premotor Cortex' in l.name] + eac_labs = [l for l in hcp_mmp1_labels if 'Early Auditory Cortex' in l.name] + + labels = list() + labels.extend(pmc_labs) + labels.extend(eac_labs) + labels.extend(ips_labels) + + # this is in place of original rtpj + #ipc_labs = [l for l in hcp_mmp1_labels if 'Inferior Parietal Cortex' in l.name] + if rtpj_mode == 'hcp': + rtpj = [l for l in hcp_mmp1_labels + if 'Inferior Parietal Cortex' in l.name and l.hemi == 'rh'] + rtpj = rtpj[0] + elif rtpj_mode == 'labsn': + #rtpj_str = os.path.join(subjects_dir, 'fsaverage/label/rh.RTPJAnatomical-rh.label') + rtpj_str = os.path.join(subjects_dir, 'fsaverage/label/rh.RTPJ.label') + rtpj = mne.read_label(rtpj_str, subject='fsaverage') + #tmp = [l for l in ipc_labs if l.hemi == 'lh'] + [rtpj] + #ipc_labs = tmp + elif rtpj_mode == 'intersect': + rtpj_str = os.path.join(subjects_dir, 'fsaverage/label/rh.RTPJIntersect-rh.label') + rtpj = mne.read_label(rtpj_str, subject='fsaverage') + + #tmp = [l for l in ipc_labs if l.hemi == 'lh'] + [rtpj_hcp] + #ipc_labs = tmp + + labels.append(rtpj) + + #labels.extend(ipc_labs) + + # optionally include early visual regions as controls + if include_visual: + prim_visual = [l for l in hcp_mmp1_labels if 'Primary Visual Cortex' in l.name] + + # there should be only one b/c of medial merge + prim_visual = prim_visual[0] + + early_visual_lh = label_names.index("Early Visual Cortex-lh") + early_visual_rh = label_names.index("Early Visual Cortex-rh") + early_visual_lh = hcp_mmp1_labels[early_visual_lh] + early_visual_rh = hcp_mmp1_labels[early_visual_rh] + visual = prim_visual + early_visual_lh + early_visual_rh + + labels.append(visual) + + return labels + + +def load_labsn_hcpmmp1_7_rtpj_hcp_plus_vision_labels(): + return load_labsn_hcpmmp1_7_labels(include_visual=True, rtpj_mode='hcp') + + +def load_labsn_hcpmmp1_7_rtpj_intersect_plus_vision_labels(): + return load_labsn_hcpmmp1_7_labels(include_visual=True, rtpj_mode='intersect') + + +def load_labsn_hcpmmp1_7_rtpj_sphere_plus_vision_labels(): + return load_labsn_hcpmmp1_7_labels(include_visual=True, rtpj_mode='labsn') + + +def load_labsn_hcpmmp1_av_rois_small(): + + hcp_mmp1_labels = mne.read_labels_from_annot('fsaverage', + parc='HCPMMP1_combined') + hcp_mmp1_labels = combine_medial_labels(hcp_mmp1_labels) + label_names = [l.name for l in hcp_mmp1_labels] + + #prim_visual_lh = label_names.index("Primary Visual Cortex (V1)-lh") + #prim_visual_rh = label_names.index("Primary Visual Cortex (V1)-rh") + #prim_visual_lh = hcp_mmp1_labels[prim_visual_lh] + #prim_visual_rh = hcp_mmp1_labels[prim_visual_rh] + prim_visual = [l for l in hcp_mmp1_labels if 'Primary Visual Cortex' in l.name] + + # there should be only one b/c of medial merge + prim_visual = prim_visual[0] + + early_visual_lh = label_names.index("Early Visual Cortex-lh") + early_visual_rh = label_names.index("Early Visual Cortex-rh") + early_visual_lh = hcp_mmp1_labels[early_visual_lh] + early_visual_rh = hcp_mmp1_labels[early_visual_rh] + + #visual_lh = prim_visual_lh + early_visual_lh + #visual_rh = prim_visual_rh 
+ early_visual_rh + + visual = prim_visual + early_visual_lh + early_visual_rh + labels = [visual] + + #labels = [visual_lh, visual_rh] + + eac_labs = [l for l in hcp_mmp1_labels if 'Early Auditory Cortex' in l.name] + labels.extend(eac_labs) + + tpo_labs = [l for l in hcp_mmp1_labels if 'Temporo-Parieto-Occipital Junction' in l.name] + labels.extend(tpo_labs) + + dpc_labs = [l for l in hcp_mmp1_labels if 'DorsoLateral Prefrontal Cortex' in l.name] + labels.extend(dpc_labs) + + ## extra labels KC wanted + #pmc_labs = [l for l in hcp_mmp1_labels if 'Premotor Cortex' in l.name] + #labels.extend(pmc_labs) + + #ips_str = glob.glob(os.path.join(subjects_dir, "fsaverage/label/*IPS*labsn*")) + #ips_labs = [mne.read_label(fn, subject='fsaverage') for fn in ips_str] + #labels.extend(ips_labs) + + #rtpj_labs = [l for l in hcp_mmp1_labels if 'Inferior Parietal Cortex-rh' in l.name] + #labels.extend(rtpj_labs) + + return labels + + +def load_labsn_hcpmmp1_av_rois_large(): + + hcp_mmp1_labels = mne.read_labels_from_annot('fsaverage', + parc='HCPMMP1_combined') + hcp_mmp1_labels = combine_medial_labels(hcp_mmp1_labels) + label_names = [l.name for l in hcp_mmp1_labels] + + #prim_visual_lh = label_names.index("Primary Visual Cortex (V1)-lh") + #prim_visual_rh = label_names.index("Primary Visual Cortex (V1)-rh") + #prim_visual_lh = hcp_mmp1_labels[prim_visual_lh] + #prim_visual_rh = hcp_mmp1_labels[prim_visual_rh] + prim_visual = [l for l in hcp_mmp1_labels if 'Primary Visual Cortex' in l.name] + + # there should be only one b/c of medial merge + prim_visual = prim_visual[0] + + early_visual_lh = label_names.index("Early Visual Cortex-lh") + early_visual_rh = label_names.index("Early Visual Cortex-rh") + early_visual_lh = hcp_mmp1_labels[early_visual_lh] + early_visual_rh = hcp_mmp1_labels[early_visual_rh] + + #visual_lh = prim_visual_lh + early_visual_lh + #visual_rh = prim_visual_rh + early_visual_rh + + visual = prim_visual + early_visual_lh + early_visual_rh + labels = [visual] + + #labels = [visual_lh, visual_rh] + + eac_labs = [l for l in hcp_mmp1_labels if 'Early Auditory Cortex' in l.name] + labels.extend(eac_labs) + + tpo_labs = [l for l in hcp_mmp1_labels if 'Temporo-Parieto-Occipital Junction' in l.name] + labels.extend(tpo_labs) + + dpc_labs = [l for l in hcp_mmp1_labels if 'DorsoLateral Prefrontal Cortex' in l.name] + labels.extend(dpc_labs) + + # extra labels KC wanted + pmc_labs = [l for l in hcp_mmp1_labels if 'Premotor Cortex' in l.name] + labels.extend(pmc_labs) + + #ips_str = glob.glob(os.path.join(subjects_dir, "fsaverage/label/*IPS*labsn*")) + #ips_labs = [mne.read_label(fn, subject='fsaverage') for fn in ips_str] + #labels.extend(ips_labs) + + #rtpj_labs = [l for l in hcp_mmp1_labels if 'Inferior Parietal Cortex-rh' in l.name] + #labels.extend(rtpj_labs) + + # glasser 19 + ac_mpc_labs = [l for l in hcp_mmp1_labels if 'Anterior Cingulate and Medial Prefrontal Cortex' in l.name] + labels.extend(ac_mpc_labs) + + return labels + + +def load_labsn_hcpmmp1_av_rois_large_plus_IPS(): + + hcp_mmp1_labels = mne.read_labels_from_annot('fsaverage', + parc='HCPMMP1_combined') + hcp_mmp1_labels = combine_medial_labels(hcp_mmp1_labels) + label_names = [l.name for l in hcp_mmp1_labels] + + #prim_visual_lh = label_names.index("Primary Visual Cortex (V1)-lh") + #prim_visual_rh = label_names.index("Primary Visual Cortex (V1)-rh") + #prim_visual_lh = hcp_mmp1_labels[prim_visual_lh] + #prim_visual_rh = hcp_mmp1_labels[prim_visual_rh] + prim_visual = [l for l in hcp_mmp1_labels if 'Primary Visual Cortex' 
in l.name] + + # there should be only one b/c of medial merge + prim_visual = prim_visual[0] + + early_visual_lh = label_names.index("Early Visual Cortex-lh") + early_visual_rh = label_names.index("Early Visual Cortex-rh") + early_visual_lh = hcp_mmp1_labels[early_visual_lh] + early_visual_rh = hcp_mmp1_labels[early_visual_rh] + + #visual_lh = prim_visual_lh + early_visual_lh + #visual_rh = prim_visual_rh + early_visual_rh + + visual = prim_visual + early_visual_lh + early_visual_rh + labels = [visual] + + #labels = [visual_lh, visual_rh] + + eac_labs = [l for l in hcp_mmp1_labels if 'Early Auditory Cortex' in l.name] + labels.extend(eac_labs) + + tpo_labs = [l for l in hcp_mmp1_labels if 'Temporo-Parieto-Occipital Junction' in l.name] + labels.extend(tpo_labs) + + dpc_labs = [l for l in hcp_mmp1_labels if 'DorsoLateral Prefrontal Cortex' in l.name] + labels.extend(dpc_labs) + + # extra labels KC wanted + pmc_labs = [l for l in hcp_mmp1_labels if 'Premotor Cortex' in l.name] + labels.extend(pmc_labs) + + ips_str = os.path.join(subjects_dir, "fsaverage/label/*.IPS-labsn.label") + ips_fnames = glob.glob(ips_str) + ips_labels = [mne.read_label(fn, subject='fsaverage') for fn in ips_fnames] + labels.extend(ips_labels) + + #rtpj_labs = [l for l in hcp_mmp1_labels if 'Inferior Parietal Cortex-rh' in l.name] + #labels.extend(rtpj_labs) + + # glasser 19 + ac_mpc_labs = [l for l in hcp_mmp1_labels if 'Anterior Cingulate and Medial Prefrontal Cortex' in l.name] + labels.extend(ac_mpc_labs) + + return labels + + +def make_rtpj_intersect(): + labels = mne.read_labels_from_annot('fsaverage', 'HCPMMP1', 'rh', + subjects_dir=subjects_dir) + + rtpj_str = os.path.join(subjects_dir, 'fsaverage/label/rh.RTPJAnatomical-rh.label') + rtpj = mne.read_label(rtpj_str, subject='fsaverage') + src = mne.read_source_spaces(subjects_dir + '/fsaverage/bem/fsaverage-5-src.fif') + rtpj = rtpj.fill(src) + + mne.write_label(os.path.join(subjects_dir, + 'fsaverage/label/rh.RTPJ.label'), + rtpj) + + props = np.zeros((len(labels), 2)) + for li, label in enumerate(labels): + props[li] = [np.in1d(rtpj.vertices, label.vertices).mean(), + np.in1d(label.vertices, rtpj.vertices).mean()] + order = np.argsort(props[:, 0])[::-1] + for oi in order: + if props[oi, 0] > 0: + name = labels[oi].name.rstrip('-rh').lstrip('R_') + print('%4.1f%% RTPJ vertices cover %4.1f%% of %s' + % (100*props[oi,0], 100*props[oi,1], name)) + + for ii, oi in enumerate(order[:4]): + if ii == 0: + rtpj = labels[oi].copy() + else: + rtpj += labels[oi] + + mne.write_label(os.path.join(subjects_dir, + 'fsaverage/label/rh.RTPJIntersect-rh.label'), + rtpj) + + +def fixup_lipsp(): + labels = mne.read_labels_from_annot('fsaverage', 'HCPMMP1', 'rh', + subjects_dir=subjects_dir) + + lipsp_str = os.path.join(subjects_dir, 'fsaverage/label/lh.LIPSP_tf.label') + lipsp = mne.read_label(lipsp_str, subject='fsaverage') + lipsp.vertices = lipsp.vertices[lipsp.vertices < 10242] + + src = mne.read_source_spaces(subjects_dir + '/fsaverage/bem/fsaverage-5-src.fif') + lipsp = lipsp.fill(src) + + + mne.write_label(os.path.join(subjects_dir, 'fsaverage/label/lh.LIPSP.label'), + lipsp) + + return lipsp + + +#if __name__ == "__main__": +# +# from surfer import Brain +# labels = sps_meglds_base() +# +# subject_id = 'fsaverage' +# hemi = 'both' +# surf = 'inflated' +# +# brain = Brain(subject_id, hemi, surf) +# for l in labels: +# brain.add_label(l) diff --git a/state_space/megssm/message_passing.py b/state_space/megssm/message_passing.py new file mode 100755 index 00000000..e560ff49 --- 
/dev/null +++ b/state_space/megssm/message_passing.py @@ -0,0 +1,732 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import autograd.numpy as np +from autograd.scipy.linalg import block_diag + +from .util import T_, sym, dot3, _ensure_ndim, component_matrix, hs + +try: + from autograd_linalg import solve_triangular +except ImportError: + raise RuntimeError("must install `autograd_linalg` package") + +# einsum2 is a parallel version of einsum that works for two arguments +try: + from einsum2 import einsum2 +except ImportError: + # rename standard numpy function if don't have einsum2 + print("=> WARNING: using standard numpy.einsum,", + "consider installing einsum2 package") + from numpy import einsum as einsum2 + + +def kalman_filter(Y, A, C, Q, R, mu0, Q0, store_St=True, sum_logliks=True): + """ Kalman filter that broadcasts over the first dimension. + Handles multiple lag dependence using component form. + + Note: This function doesn't handle control inputs (yet). + + Y : ndarray, shape=(N, T, D) + Observations + + A : ndarray, shape=(T, D*nlag, D*nlag) + Time-varying dynamics matrices + + C : ndarray, shape=(p, D) + Observation matrix + + mu0: ndarray, shape=(D,) + mean of initial state variable + + Q0 : ndarray, shape=(D, D) + Covariance of initial state variable + + Q : ndarray, shape=(D, D) + Covariance of latent states + + R : ndarray, shape=(D, D) + Covariance of observations + """ + + N = Y.shape[0] + T, D, Dnlags = A.shape + nlags = Dnlags // D + AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) + + p = C.shape[0] + CC = hs([C, np.zeros((p, D*(nlags-1)))]) + + QQ = np.zeros((T, Dnlags, Dnlags)) + QQ[:,:D,:D] = Q + + QQ0 = block_diag(*[Q0 for _ in range(nlags)]) + + mu_predict = np.stack([np.tile(mu0, nlags) for _ in range(N)], axis=0) + sigma_predict = np.stack([QQ0 for _ in range(N)], axis=0) + + St = np.empty((N, T, p, p)) if store_St else None + + mus_filt = np.zeros((N, T, Dnlags)) + sigmas_filt = np.zeros((N, T, Dnlags, Dnlags)) + + ll = np.zeros(T) + + for t in range(T): + + # condition + # dot3(CC, sigma_predict, CC.T) + R + tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict) + sigma_pred = np.dot(tmp1, CC.T) + R + sigma_pred = sym(sigma_pred) + + if St is not None: + St[...,t,:,:] = sigma_pred + + res = Y[...,t,:] - np.dot(mu_predict, CC.T) + + L = np.linalg.cholesky(sigma_pred) + v = solve_triangular(L, res, lower=True) + + # log-likelihood over all trials + ll[t] = -0.5*(2.*np.sum(np.log(np.diagonal(L, axis1=1, axis2=2))) + + np.sum(v*v) + + N*p*np.log(2.*np.pi)) + + mus_filt[...,t,:] = mu_predict + einsum2('nki,nk->ni', tmp1, + solve_triangular(L, v, 'T', lower=True)) + + tmp2 = solve_triangular(L, tmp1, lower=True) + sigmas_filt[...,t,:,:] = sym(sigma_predict - einsum2('nki,nkj->nij', tmp2, tmp2)) + + # prediction + mu_predict = einsum2('ik,nk->ni', AA[t], mus_filt[...,t,:]) + + sigma_predict = einsum2('ik,nkl->nil', AA[t], sigmas_filt[...,t,:,:]) + sigma_predict = sym(einsum2('nil,jl->nij', sigma_predict, AA[t]) + QQ[t]) + + if sum_logliks: + ll = np.sum(ll) + return ll, mus_filt, sigmas_filt, St + + +def rts_smooth(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False, + store_St=True): + """ RTS smoother that broadcasts over the first dimension. + Handles multiple lag dependence using component form. + + Note: This function doesn't handle control inputs (yet). 
+ + Y : ndarray, shape=(N, T, D) + Observations + + A : ndarray, shape=(T, D*nlag, D*nlag) + Time-varying dynamics matrices + + C : ndarray, shape=(p, D) + Observation matrix + + mu0: ndarray, shape=(D,) + mean of initial state variable + + Q0 : ndarray, shape=(D, D) + Covariance of initial state variable + + Q : ndarray, shape=(D, D) + Covariance of latent states + + R : ndarray, shape=(D, D) + Covariance of observations + """ + + N, T, _ = Y.shape + _, D, Dnlags = A.shape + nlags = Dnlags // D + AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) + + p = C.shape[0] + CC = hs([C, np.zeros((p, D*(nlags-1)))]) + + QQ = np.zeros((T, Dnlags, Dnlags)) + QQ[:,:D,:D] = Q + + QQ0 = block_diag(*[Q0 for _ in range(nlags)]) + + mu_predict = np.empty((N, T+1, Dnlags)) + sigma_predict = np.empty((N, T+1, Dnlags, Dnlags)) + + mus_smooth = np.empty((N, T, Dnlags)) + sigmas_smooth = np.empty((N, T, Dnlags, Dnlags)) + + St = np.empty((N, T, p, p)) if store_St else None + + if compute_lag1_cov: + sigmas_smooth_tnt = np.empty((N, T-1, Dnlags, Dnlags)) + else: + sigmas_smooth_tnt = None + + ll = 0. + mu_predict[:,0,:] = np.tile(mu0, nlags) + sigma_predict[:,0,:,:] = QQ0.copy() + + for t in range(T): + + # condition + # sigma_x = dot3(C, sigma_predict, C.T) + R + tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) + sigma_x = einsum2('nik,jk->nij', tmp1, CC) + R + sigma_x = sym(sigma_x) + + if St is not None: + St[...,t,:,:] = sigma_x + + L = np.linalg.cholesky(sigma_x) + # res[n] = Y[n,t,:] = np.dot(C, mu_predict[n,t,:]) + res = Y[...,t,:] - einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) + v = solve_triangular(L, res, lower=True) + + # log-likelihood over all trials + ll += -0.5*(2.*np.sum(np.log(np.diagonal(L, axis1=1, axis2=2))) + + np.sum(v*v) + + N*p*np.log(2.*np.pi)) + + mus_smooth[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', + tmp1, + solve_triangular(L, v, trans='T', lower=True)) + + # tmp2 = L^{-1}*C*sigma_predict + tmp2 = solve_triangular(L, tmp1, lower=True) + sigmas_smooth[:,t,:,:] = sym(sigma_predict[:,t,:,:] - einsum2('nki,nkj->nij', tmp2, tmp2)) + + # prediction + #mu_predict = np.dot(A[t], mus_smooth[t]) + mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_smooth[:,t,:]) + + #sigma_predict = dot3(A[t], sigmas_smooth[t], A[t].T) + Q[t] + tmp = einsum2('ik,nkl->nil', AA[t], sigmas_smooth[:,t,:,:]) + sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) + + for t in range(T-2, -1, -1): + + # these names are stolen from mattjj and slinderman + #temp_nn = np.dot(A[t], sigmas_smooth[n,t,:,:]) + temp_nn = einsum2('ik,nkj->nij', AA[t], sigmas_smooth[:,t,:,:]) + + L = np.linalg.cholesky(sigma_predict[:,t+1,:,:]) + v = solve_triangular(L, temp_nn, lower=True) + # Look in Saarka for dfn of Gt_T + Gt_T = solve_triangular(L, v, trans='T', lower=True) + + # {mus,sigmas}_smooth[n,t] contain the filtered estimates so we're + # overwriting them on purpose + #mus_smooth[n,t,:] = mus_smooth[n,t,:] + np.dot(T_(Gt_T), mus_smooth[n,t+1,:] - mu_predict[t+1,:]) + mus_smooth[:,t,:] = mus_smooth[:,t,:] + einsum2('nki,nk->ni', Gt_T, mus_smooth[:,t+1,:] - mu_predict[:,t+1,:]) + + #sigmas_smooth[n,t,:,:] = sigmas_smooth[n,t,:,:] + dot3(T_(Gt_T), sigmas_smooth[n,t+1,:,:] - temp_nn, Gt_T) + tmp = einsum2('nki,nkj->nij', Gt_T, sigmas_smooth[:,t+1,:,:] - sigma_predict[:,t+1,:,:]) + tmp = einsum2('nik,nkj->nij', tmp, Gt_T) + sigmas_smooth[:,t,:,:] = sym(sigmas_smooth[:,t,:,:] + tmp) + + if compute_lag1_cov: + # This matrix is NOT symmetric, so don't symmetrize! 
+ #sigmas_smooth_tnt[n,t,:,:] = np.dot(sigmas_smooth[n,t+1,:,:], Gt_T) + sigmas_smooth_tnt[:,t,:,:] = einsum2('nik,nkj->nij', sigmas_smooth[:,t+1,:,:], Gt_T) + + return ll, mus_smooth, sigmas_smooth, sigmas_smooth_tnt, St + + +def rts_smooth_fast(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False): + """ RTS smoother that broadcasts over the first dimension. + Handles multiple lag dependence using component form. + + Note: This function doesn't handle control inputs (yet). + + Y : ndarray, shape=(N, T, D) + Observations + + A : ndarray, shape=(T, D*nlag, D*nlag) + Time-varying dynamics matrices + + C : ndarray, shape=(p, D) + Observation matrix + + mu0: ndarray, shape=(D,) + mean of initial state variable + + Q0 : ndarray, shape=(D, D) + Covariance of initial state variable + + Q : ndarray, shape=(D, D) + Covariance of latent states + + R : ndarray, shape=(D, D) + Covariance of observations + """ + + N, T, _ = Y.shape + _, D, Dnlags = A.shape + nlags = Dnlags // D + AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) + + L_R = np.linalg.cholesky(R) + + p = C.shape[0] + CC = hs([C, np.zeros((p, D*(nlags-1)))]) + tmp = solve_triangular(L_R, CC, lower=True) + Rinv_CC = solve_triangular(L_R, tmp, trans='T', lower=True) + CCT_Rinv_CC = einsum2('ki,kj->ij', CC, Rinv_CC) + + # tile L_R across number of trials so solve_triangular + # can broadcast over trials properly + L_R = np.tile(L_R, (N, 1, 1)) + + QQ = np.zeros((T, Dnlags, Dnlags)) + QQ[:,:D,:D] = Q + + QQ0 = block_diag(*[Q0 for _ in range(nlags)]) + + mu_predict = np.empty((N, T+1, Dnlags)) + sigma_predict = np.empty((N, T+1, Dnlags, Dnlags)) + + mus_smooth = np.empty((N, T, Dnlags)) + sigmas_smooth = np.empty((N, T, Dnlags, Dnlags)) + + if compute_lag1_cov: + sigmas_smooth_tnt = np.empty((N, T-1, Dnlags, Dnlags)) + else: + sigmas_smooth_tnt = None + + ll = 0. 
+ mu_predict[:,0,:] = np.tile(mu0, nlags) + sigma_predict[:,0,:,:] = QQ0.copy() + + I_tiled = np.tile(np.eye(Dnlags), (N, 1, 1)) + + for t in range(T): + + # condition + tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) + + res = Y[...,t,:] - einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) + + # Rinv * res + tmp2 = solve_triangular(L_R, res, lower=True) + tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) + + # C^T Rinv * res + tmp3 = einsum2('ki,nk->ni', Rinv_CC, res) + + # (Pinv + C^T Rinv C)_inv * tmp3 + L_P = np.linalg.cholesky(sigma_predict[:,t,:,:]) + tmp = solve_triangular(L_P, I_tiled, lower=True) + Pinv = solve_triangular(L_P, tmp, trans='T', lower=True) + tmp4 = sym(Pinv + CCT_Rinv_CC) + L_tmp4 = np.linalg.cholesky(tmp4) + tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) + tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) + + # Rinv C * tmp3 + tmp3 = einsum2('ik,nk->ni', Rinv_CC, tmp3) + + # add the two Woodbury * res terms together + tmp = tmp2 - tmp3 + + mus_smooth[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', tmp1, tmp) + + # Rinv * tmp1 + tmp2 = solve_triangular(L_R, tmp1, lower=True) + tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) + + # C^T Rinv * tmp1 + tmp3 = einsum2('ki,nkj->nij', Rinv_CC, tmp1) + + # (Pinv + C^T Rinv C)_inv * tmp3 + tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) + tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) + + # Rinv C * tmp3 + tmp3 = einsum2('ik,nkj->nij', Rinv_CC, tmp3) + + # add the two Woodbury * tmp1 terms together, left-multiply by tmp1 + tmp = einsum2('nki,nkj->nij', tmp1, tmp2 - tmp3) + + sigmas_smooth[:,t,:,:] = sym(sigma_predict[:,t,:,:] - tmp) + + # prediction + mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_smooth[:,t,:]) + + #sigma_predict = dot3(A[t], sigmas_smooth[t], A[t].T) + Q[t] + tmp = einsum2('ik,nkl->nil', AA[t], sigmas_smooth[:,t,:,:]) + sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) + + for t in range(T-2, -1, -1): + + # these names are stolen from mattjj and slinderman + #temp_nn = np.dot(A[t], sigmas_smooth[n,t,:,:]) + temp_nn = einsum2('ik,nkj->nij', AA[t], sigmas_smooth[:,t,:,:]) + + L = np.linalg.cholesky(sigma_predict[:,t+1,:,:]) + v = solve_triangular(L, temp_nn, lower=True) + # Look in Saarka for dfn of Gt_T + Gt_T = solve_triangular(L, v, trans='T', lower=True) + + # {mus,sigmas}_smooth[n,t] contain the filtered estimates so we're + # overwriting them on purpose + #mus_smooth[n,t,:] = mus_smooth[n,t,:] + np.dot(T_(Gt_T), mus_smooth[n,t+1,:] - mu_predict[t+1,:]) + mus_smooth[:,t,:] = mus_smooth[:,t,:] + einsum2('nki,nk->ni', Gt_T, mus_smooth[:,t+1,:] - mu_predict[:,t+1,:]) + + #sigmas_smooth[n,t,:,:] = sigmas_smooth[n,t,:,:] + dot3(T_(Gt_T), sigmas_smooth[n,t+1,:,:] - temp_nn, Gt_T) + tmp = einsum2('nki,nkj->nij', Gt_T, sigmas_smooth[:,t+1,:,:] - sigma_predict[:,t+1,:,:]) + tmp = einsum2('nik,nkj->nij', tmp, Gt_T) + sigmas_smooth[:,t,:,:] = sym(sigmas_smooth[:,t,:,:] + tmp) + + if compute_lag1_cov: + # This matrix is NOT symmetric, so don't symmetrize! 
+ #sigmas_smooth_tnt[n,t,:,:] = np.dot(sigmas_smooth[n,t+1,:,:], Gt_T) + sigmas_smooth_tnt[:,t,:,:] = einsum2('nik,nkj->nij', sigmas_smooth[:,t+1,:,:], Gt_T) + + return ll, mus_smooth, sigmas_smooth, sigmas_smooth_tnt + + + +def predict(Y, A, C, Q, R, mu0, Q0, pred_var=False): + if pred_var: + return _predict_mean_var(Y, A, C, Q, R, mu0, Q0) + else: + return _predict_mean(Y, A, C, Q, R, mu0, Q0) + + +def _predict_mean_var(Y, A, C, Q, R, mu0, Q0): + """ Model predictions for Y given model parameters. + + Handles multiple lag dependence using component form. + + Note: This function doesn't handle control inputs (yet). + + Y : ndarray, shape=(N, T, D) + Observations + + A : ndarray, shape=(T, D*nlag, D*nlag) + Time-varying dynamics matrices + + C : ndarray, shape=(p, D) + Observation matrix + + mu0: ndarray, shape=(D,) + mean of initial state variable + + Q0 : ndarray, shape=(D, D) + Covariance of initial state variable + + Q : ndarray, shape=(D, D) + Covariance of latent states + + R : ndarray, shape=(D, D) + Covariance of observations + """ + + N, T, _ = Y.shape + _, D, Dnlags = A.shape + nlags = Dnlags // D + AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) + + L_R = np.linalg.cholesky(R) + + p = C.shape[0] + CC = hs([C, np.zeros((p, D*(nlags-1)))]) + tmp = solve_triangular(L_R, CC, lower=True) + Rinv_CC = solve_triangular(L_R, tmp, trans='T', lower=True) + CCT_Rinv_CC = einsum2('ki,kj->ij', CC, Rinv_CC) + + # tile L_R across number of trials so solve_triangular + # can broadcast over trials properly + L_R = np.tile(L_R, (N, 1, 1)) + + QQ = np.zeros((T, Dnlags, Dnlags)) + QQ[:,:D,:D] = Q + + QQ0 = block_diag(*[Q0 for _ in range(nlags)]) + + mu_predict = np.empty((N, T+1, Dnlags)) + sigma_predict = np.empty((N, T+1, Dnlags, Dnlags)) + + mus_filt = np.empty((N, T, Dnlags)) + sigmas_filt = np.empty((N, T, Dnlags, Dnlags)) + + mu_predict[:,0,:] = np.tile(mu0, nlags) + sigma_predict[:,0,:,:] = QQ0.copy() + + I_tiled = np.tile(np.eye(Dnlags), (N, 1, 1)) + + Yhat = np.empty_like(Y) + St = np.empty((N, T, p, p)) + + for t in range(T): + + # condition + # sigma_x = dot3(C, sigma_predict, C.T) + R + tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) + sigma_x = einsum2('nik,jk->nij', tmp1, CC) + R + sigma_x = sym(sigma_x) + + St[...,t,:,:] = sigma_x + + L = np.linalg.cholesky(sigma_x) + Yhat[...,t,:] = einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) + res = Y[...,t,:] - Yhat[...,t,:] + + v = solve_triangular(L, res, lower=True) + + mus_filt[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', + tmp1, + solve_triangular(L, v, trans='T', lower=True)) + + # tmp2 = L^{-1}*C*sigma_predict + tmp2 = solve_triangular(L, tmp1, lower=True) + sigmas_filt[:,t,:,:] = sym(sigma_predict[:,t,:,:] - einsum2('nki,nkj->nij', tmp2, tmp2)) + + # prediction + #mu_predict = np.dot(A[t], mus_filt[t]) + mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_filt[:,t,:]) + + #sigma_predict = dot3(A[t], sigmas_filt[t], A[t].T) + Q[t] + tmp = einsum2('ik,nkl->nil', AA[t], sigmas_filt[:,t,:,:]) + sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) + + # just return the diagonal of the St matrices for marginal predictive + # variances + return Yhat, np.diagonal(St, axis1=-2, axis2=-1) + + +def _predict_mean(Y, A, C, Q, R, mu0, Q0): + """ Model predictions for Y given model parameters. + Handles multiple lag dependence using component form. + + Note: This function doesn't handle control inputs (yet). 
+ + Y : ndarray, shape=(N, T, D) + Observations + + A : ndarray, shape=(T, D*nlag, D*nlag) + Time-varying dynamics matrices + + C : ndarray, shape=(p, D) + Observation matrix + + mu0: ndarray, shape=(D,) + mean of initial state variable + + Q0 : ndarray, shape=(D, D) + Covariance of initial state variable + + Q : ndarray, shape=(D, D) + Covariance of latent states + + R : ndarray, shape=(D, D) + Covariance of observations + """ + + N, T, _ = Y.shape + _, D, Dnlags = A.shape + nlags = Dnlags // D + AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) + + L_R = np.linalg.cholesky(R) + + p = C.shape[0] + CC = hs([C, np.zeros((p, D*(nlags-1)))]) + tmp = solve_triangular(L_R, CC, lower=True) + Rinv_CC = solve_triangular(L_R, tmp, trans='T', lower=True) + CCT_Rinv_CC = einsum2('ki,kj->ij', CC, Rinv_CC) + + # tile L_R across number of trials so solve_triangular + # can broadcast over trials properly + L_R = np.tile(L_R, (N, 1, 1)) + + QQ = np.zeros((T, Dnlags, Dnlags)) + QQ[:,:D,:D] = Q + + QQ0 = block_diag(*[Q0 for _ in range(nlags)]) + + mu_predict = np.empty((N, T+1, Dnlags)) + sigma_predict = np.empty((N, T+1, Dnlags, Dnlags)) + + mus_filt = np.empty((N, T, Dnlags)) + sigmas_filt = np.empty((N, T, Dnlags, Dnlags)) + + mu_predict[:,0,:] = np.tile(mu0, nlags) + sigma_predict[:,0,:,:] = QQ0.copy() + + I_tiled = np.tile(np.eye(Dnlags), (N, 1, 1)) + + Yhat = np.empty_like(Y) + + for t in range(T): + + # condition + tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) + + Yhat[...,t,:] = einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) + res = Y[...,t,:] - Yhat[...,t,:] + + # Rinv * res + tmp2 = solve_triangular(L_R, res, lower=True) + tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) + + # C^T Rinv * res + tmp3 = einsum2('ki,nk->ni', Rinv_CC, res) + + # (Pinv + C^T Rinv C)_inv * tmp3 + L_P = np.linalg.cholesky(sigma_predict[:,t,:,:]) + tmp = solve_triangular(L_P, I_tiled, lower=True) + Pinv = solve_triangular(L_P, tmp, trans='T', lower=True) + tmp4 = sym(Pinv + CCT_Rinv_CC) + L_tmp4 = np.linalg.cholesky(tmp4) + tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) + tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) + + # Rinv C * tmp3 + tmp3 = einsum2('ik,nk->ni', Rinv_CC, tmp3) + + # add the two Woodbury * res terms together + tmp = tmp2 - tmp3 + + mus_filt[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', tmp1, tmp) + + # Rinv * tmp1 + tmp2 = solve_triangular(L_R, tmp1, lower=True) + tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) + + # C^T Rinv * tmp1 + tmp3 = einsum2('ki,nkj->nij', Rinv_CC, tmp1) + + # (Pinv + C^T Rinv C)_inv * tmp3 + tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) + tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) + + # Rinv C * tmp3 + tmp3 = einsum2('ik,nkj->nij', Rinv_CC, tmp3) + + # add the two Woodbury * tmp1 terms together, left-multiply by tmp1 + tmp = einsum2('nki,nkj->nij', tmp1, tmp2 - tmp3) + + sigmas_filt[:,t,:,:] = sym(sigma_predict[:,t,:,:] - tmp) + + # prediction + mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_filt[:,t,:]) + + #sigma_predict = dot3(A[t], sigmas_filt[t], A[t].T) + Q[t] + tmp = einsum2('ik,nkl->nil', AA[t], sigmas_filt[:,t,:,:]) + sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) + + return Yhat + + +def predict_step(mu_filt, sigma_filt, A, Q): + mu_predict = einsum2('ik,nk->ni', A, mu_filt) + tmp = einsum2('ik,nkl->nil', A, sigma_filt) + sigma_predict = sym(einsum2('nil,jl->nij', tmp, A) + Q) + + return mu_predict, sigma_predict + + +def condition(y, C, R, 
mu_predict, sigma_predict): + # dot3(C, sigma_predict, C.T) + R + tmp1 = einsum2('ik,nkj->nij', C, sigma_predict) + sigma_pred = einsum2('nik,jk->nij', tmp1, C) + R + sigma_pred = sym(sigma_pred) + + L = np.linalg.cholesky(sigma_pred) + # the transpose works b/c of how dot broadcasts + #y_hat = np.dot(mu_predict, C.T) + y_hat = einsum2('ik,nk->ni', C, mu_predict) + res = y - y_hat + v = solve_triangular(L, res, lower=True) + + mu_filt = mu_predict + einsum2('nki,nk->ni', tmp1, solve_triangular(L, v, trans='T', lower=True)) + + tmp2 = solve_triangular(L, tmp1, lower=True) + sigma_filt = sym(sigma_predict - einsum2('nki,nkj->nij', tmp2, tmp2)) + + return y_hat, mu_filt, sigma_filt + + +def logZ(Y, A, C, Q, R, mu0, Q0): + """ Log marginal likelihood using the Kalman filter. + + The algorithm broadcasts over the first dimension which are considered + to be independent realizations. + + Note: This function doesn't handle control inputs (yet). + + Y : ndarray, shape=(N, T, D) + Observations + + A : ndarray, shape=(T, D, D) + Time-varying dynamics matrices + + C : ndarray, shape=(p, D) + Observation matrix + + mu0: ndarray, shape=(D,) + mean of initial state variable + + Q0 : ndarray, shape=(D, D) + Covariance of initial state variable + + Q : ndarray, shape=(D, D) + Covariance of latent states + + R : ndarray, shape=(D, D) + Covariance of observations + """ + + N = Y.shape[0] + T, D, _ = A.shape + p = C.shape[0] + + mu_predict = np.stack([np.copy(mu0) for _ in range(N)], axis=0) + sigma_predict = np.stack([np.copy(Q0) for _ in range(N)], axis=0) + + ll = 0. + + for t in range(T): + + # condition + # sigma_x = dot3(C, sigma_predict, C.T) + R + tmp1 = einsum2('ik,nkj->nij', C, sigma_predict) + sigma_x = einsum2('nik,jk->nij', tmp1, C) + R + sigma_x = sym(sigma_x) + + # res[n] = Y[n,t,:] = np.dot(C, mu_predict[n]) + res = Y[...,t,:] - einsum2('ik,nk->ni', C, mu_predict) + + L = np.linalg.cholesky(sigma_x) + v = solve_triangular(L, res, lower=True) + + # log-likelihood over all trials + ll += -0.5*(2.*np.sum(np.log(np.diagonal(L, axis1=1, axis2=2))) + + np.sum(v*v) + + N*p*np.log(2.*np.pi)) + + mus_filt = mu_predict + einsum2('nki,nk->ni', + tmp1, + solve_triangular(L, v, trans='T', lower=True)) + + # tmp2 = L^{-1}*C*sigma_predict + tmp2 = solve_triangular(L, tmp1, lower=True) + sigmas_filt = sigma_predict - einsum2('nki,nkj->nij', tmp2, tmp2) + sigmas_filt = sym(sigmas_filt) + + # prediction + #mu_predict = np.dot(A[t], mus_filt[t]) + mu_predict = einsum2('ik,nk->ni', A[t], mus_filt) + + # originally this worked with time-varying Q, but now it's fixed + #sigma_predict = dot3(A[t], sigmas_filt[t], A[t].T) + Q[t] + sigma_predict = einsum2('ik,nkl->nil', A[t], sigmas_filt) + sigma_predict = einsum2('nil,jl->nij', sigma_predict, A[t]) + Q + sigma_predict = sym(sigma_predict) + + return np.sum(ll) diff --git a/state_space/megssm/mne_util.py b/state_space/megssm/mne_util.py new file mode 100644 index 00000000..1402b648 --- /dev/null +++ b/state_space/megssm/mne_util.py @@ -0,0 +1,305 @@ +""" MNE-Python utility functions for preprocessing data and constructing + matrices necessary for MEGLDS analysis """ + +import mne +import numpy as np +import os.path as op + +from mne.io.pick import pick_types +from mne.utils import logger +from mne import label_sign_flip + +from scipy.sparse import csc_matrix, csr_matrix, diags +from sklearn.decomposition import PCA + +# from util import Carray ##skip import just pasted; util also from MEGLDS repo +Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C') 
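`Carray64` above (and its float32 twin on the next line) wrap `np.require` so that downstream einsum2 and triangular-solve code always sees C-contiguous arrays of the expected dtype. A quick self-contained check of that behavior (a sketch of standard NumPy semantics, not code from the patch):

import numpy as np

Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C')

x = np.asfortranarray(np.ones((4, 4), dtype=np.float32))  # wrong order and dtype
y = Carray64(x)                                           # copied into a C-ordered float64 array
assert y.flags['C_CONTIGUOUS'] and y.dtype == np.float64
assert Carray64(y) is y                                   # compliant input is returned as-is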
+Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C') +Carray = Carray64 + + +class ROIToSourceMap(object): + """ class for computing ROI-to-source space mapping matrix + + Notes + ----- + The following variables defined here correspond to various matrices + defined in :footcite:`yang_state-space_2016`: + - fwd_src_snsr : G + - fwd_roi_snsr : C + - fwd_src_roi : L + - snsr_cov : Q_e + - roi_cov : Q + - roi_cov_0 : Q0 """ + + def __init__(self, fwd, labels, label_flip=False): + + src = fwd['src'] + + roiidx = list() + vertidx = list() + + n_lhverts = len(src[0]['vertno']) + n_rhverts = len(src[1]['vertno']) + n_verts = n_lhverts + n_rhverts + offsets = {'lh': 0, 'rh': n_lhverts} + + hemis = {'lh': 0, 'rh': 1} + + # index vector of which ROI a source point belongs to + which_roi = np.zeros(n_verts, dtype=np.int64) + + data = [] + for li, lab in enumerate(labels): + + this_data = np.round(label_sign_flip(lab, src)) + if not label_flip: + this_data.fill(1.) + data.append(this_data) + if isinstance(lab, mne.Label): + comp_labs = [lab] + elif isinstance(lab, mne.BiHemiLabel): + comp_labs = [lab.lh, lab.rh] + + for clab in comp_labs: + hemi = clab.hemi + hi = 0 if hemi == 'lh' else 1 + + lverts = clab.get_vertices_used(vertices=src[hi]['vertno']) + + # gets the indices in the source space vertex array, not the huge + # array. + # use `src[hi]['vertno'][lverts]` to get surface vertex indices to + # plot. + lverts = np.searchsorted(src[hi]['vertno'], lverts) + lverts += offsets[hemi] + vertidx.extend(lverts) + roiidx.extend(np.full(lverts.size, li, dtype=np.int64)) + + # add 1 b/c 0 corresponds to unassigned variance + which_roi[lverts] = li + 1 + + N = len(labels) + M = n_verts + + # construct sparse fwd_src_roi matrix + data = np.concatenate(data) + vertidx = np.array(vertidx, int) + roiidx = np.array(roiidx, int) + assert data.shape == vertidx.shape == roiidx.shape + fwd_src_roi = csc_matrix((data, (vertidx, roiidx)), shape=(M, N)) + + self.fwd = fwd + self.fwd_src_roi = fwd_src_roi + self.which_roi = which_roi + self.offsets = offsets + self.n_lhverts = n_lhverts + self.n_rhverts = n_rhverts + self.labels = labels + + return + + # @property + # def fwd_src_sn(self): + # return self.fwd['sol']['data'] + + # @property + # def fwd_src_roi(self): + # return self._fwd_src_roi + + # @fwd_src_roi.setter + # def fwd_src_roi(self, val): + # self._fwd_src_roi = val + + # @property + # def which_roi(self): + # return self._which_roi + + # @which_roi.setter + # def which_roi(self, val): + # self._which_roi = val + + # @property + # def fwd_roi_snsr(self): + # from util import Carray + # return Carray(csr_matrix.dot(self.fwd_src_roi.T, self.fwd_src_sn.T).T) + + # def get_label_vinds(self, label): + # li = self.labels.index(label) + # if isinstance(label, mne.Label): + # label_vert_idx = self.fwd_src_roi[:, li].nonzero()[0] + # label_vert_idx -= self.offsets[label.hemi] + # return label_vert_idx + # elif isinstance(label, mne.BiHemiLabel): + # # these labels store both hemispheres so subtract the rh offset + # # from that part of the vertex array + # lh_label_vert_idx = self.fwd_src_roi[:self.n_lhverts, li].nonzero()[0] + # rh_label_vert_idx = self.fwd_src_roi[self.n_lhverts:, li].nonzero()[0] + # rh_label_vert_idx[self.n_lhverts:] -= self.offsets['rh'] + # return [lh_label_vert_idx, rh_label_vert_idx] + + # def get_label_verts(self, label, src): + # # if you're thinking of using this to plot, why not just use + # # brain.add_label from pysurfer? 
+ # if isinstance(label, mne.Label): + # hi = 0 if label.hemi == 'lh' else 1 + # label_vert_idx = self.get_label_vinds(label) + # varray = src[hi]['vertno'][label_vert_idx] + # elif isinstance(label, mne.BiHemiLabel): + # lh_label_vert_idx, rh_label_vert_idx = self.get_label_vinds(label) + # varray = [src[0]['vertno'][lh_label_vert_idx], + # src[1]['vertno'][rh_label_vert_idx]] + # return varray + + # def get_hemi_idx(self, label): + # if isinstance(label, mne.Label): + # return 0 if label.hemi == 'lh' else 1 + # elif isinstance(label, mne.BiHemiLabel): + # hemis = [None] * 2 + # for i, lab in enumerate([label.lh, label.rh]): + # hemis[i] = 0 if lab.hemi == 'lh' else 1 + # return hemis + +def apply_projs(epochs, fwd, cov): + """ apply projection operators to fwd and cov """ + proj, _ = mne.io.proj.setup_proj(epochs.info, activate=False) + fwd_src_sn = fwd['sol']['data'] + fwd['sol']['data'] = np.dot(proj, fwd_src_sn) + + roi_cov = cov.data + if not np.allclose(np.dot(proj, roi_cov), roi_cov): + roi_cov = np.dot(proj, np.dot(roi_cov, proj.T)) + cov.data = roi_cov + + return fwd, cov + + +def _scale_sensor_data(epochs, fwd, cov, roi_to_src): + """ apply per-channel-type scaling to epochs, forward, and covariance """ + + epochs = epochs.copy().pick('data', exclude='bads') + info = epochs.info.copy() + data = epochs.get_data().copy() + snsr_cov = cov.pick_channels(epochs.ch_names, ordered=True).data + fwd = mne.convert_forward_solution(fwd, force_fixed=True) + fwd_src_snsr = fwd.pick_channels(epochs.ch_names, ordered=True)['sol']['data'] + del cov, fwd, epochs #neccessary? + + # rescale data according to covariance whitener? + # rescale_cov = mne.make_ad_hoc_cov(info) + # scaler = mne.cov.compute_whitener(rescale_cov, info) + # del rescale_cov + # fwd_src_snsr = scaler[0] @ fwd_src_snsr + # snsr_cov = scaler[0] @ snsr_cov + # data = scaler[0] @ data + + fwd_roi_snsr = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, fwd_src_snsr.T).T) + + epochs = mne.EpochsArray(data, info) + + return fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs + + +def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', + pctvar=0.99, mean_center=False, label_flip=False): + """ apply sensor scaling, PCA dimensionality reduction with/without + whitening, and mean-centering to subject data """ + + if dim_mode not in ['rank', 'pctvar', 'whiten']: + raise ValueError("dim_mode must be in {'rank', 'pctvar', 'whiten'}") + + print("running pca for subject %s" % subject_name) + + # compute ROI-to-source map + roi_to_src = ROIToSourceMap(fwd, labels, label_flip) + if dim_mode == 'whiten': + + fwd_src_snsr, fwd_roi_snsr, Q_snsr, epochs = \ + _scale_sensor_data(epochs, fwd, cov, roi_to_src) + dat = epochs.get_data() + dat = Carray(np.swapaxes(dat, -1, -2)) + + if mean_center: + dat -= np.mean(dat, axis=1, keepdims=True) + + dat_stacked = np.reshape(dat, (-1, dat.shape[-1])) + + W, _ = mne.cov.compute_whitener(subject.sensor_cov, + info=subject.epochs_list[0].info, + pca=True) + print("whitener for subject %s using %d principal components" % + (subject_name, W.shape[0])) + + else: + + fwd_src_snsr, fwd_roi_snsr, Q_snsr, epochs = \ + _scale_sensor_data(epochs, fwd, cov, roi_to_src) + dat = epochs.get_data() + dat = Carray(np.swapaxes(dat, -1, -2)) + + if mean_center: + dat -= np.mean(dat, axis=1, keepdims=True) + + dat_stacked = np.reshape(dat, (-1, dat.shape[-1])) + + pca = PCA() + pca.fit(dat_stacked) + + if dim_mode == 'rank': + idx = np.linalg.matrix_rank(np.cov(dat_stacked, rowvar=False)) + else: + idx = 
np.where(np.cumsum(pca.explained_variance_ratio_) > pctvar)[0][0] + + idx = np.maximum(idx, len(labels)) + W = pca.components_[:idx] + print("subject %s using %d principal components" % (subject_name, idx)) + + ntrials, T, _ = dat.shape + dat_pca = np.dot(dat_stacked, W.T) + dat_pca = np.reshape(dat_pca, (ntrials, T, -1)) + + fwd_src_snsr_pca = np.dot(W, fwd_src_snsr) + fwd_roi_snsr_pca = np.dot(W, fwd_roi_snsr) + Q_snsr_pca = np.dot(W,np.dot(Q_snsr, W.T)) + + data = dat_pca + + return data, fwd_roi_snsr_pca, fwd_src_snsr_pca, Q_snsr_pca, roi_to_src.which_roi + + +def combine_medial_labels(labels, subject='fsaverage', surf='white', + dist_limit=0.02): + """ combine each hemi pair of labels on medial wall into single label """ + subjects_dir = mne.get_config('SUBJECTS_DIR') + rrs = dict((hemi, mne.read_surface(op.join(subjects_dir, subject, 'surf', + '%s.%s' % (hemi, surf)))[0] / 1000.) + for hemi in ('lh', 'rh')) + use_labels = list() + used = np.zeros(len(labels), bool) + + logger.info('Matching medial regions for %s labels on %s %s, d=%0.1f mm' + % (len(labels), subject, surf, 1000 * dist_limit)) + + for li1, l1 in enumerate(labels): + if used[li1]: + continue + used[li1] = True + use_label = l1.copy() + rr1 = rrs[l1.hemi][l1.vertices] + for li2 in np.where(~used)[0]: + l2 = labels[li2] + same_name = (l2.name.replace(l2.hemi, '') == + l1.name.replace(l1.hemi, '')) + if l2.hemi != l1.hemi and same_name: + rr2 = rrs[l2.hemi][l2.vertices] + mean_min = np.mean(mne.surface._compute_nearest( + rr1, rr2, return_dists=True)[1]) + if mean_min <= dist_limit: + use_label += l2 + used[li2] = True + logger.info(' Matched: ' + l1.name) + use_labels.append(use_label) + + logger.info('Total %d labels' % (len(use_labels),)) + + return use_labels \ No newline at end of file diff --git a/state_space/megssm/models.py b/state_space/megssm/models.py new file mode 100755 index 00000000..2ec1d104 --- /dev/null +++ b/state_space/megssm/models.py @@ -0,0 +1,867 @@ +import sys + +import autograd.numpy as np +import scipy.optimize as spopt + +from autograd import grad #autograd --> jax +from autograd import value_and_grad as vgrad +from scipy.linalg import LinAlgError + +from .util import _ensure_ndim, rand_stable, rand_psd +from .util import linesearch, soft_thresh_At, block_thresh_At +from .util import relnormdiff +from .message_passing import kalman_filter, rts_smooth, rts_smooth_fast +from .message_passing import predict_step, condition +from .numpy_numthreads import numpy_num_threads + +from .mne_util import ROIToSourceMap, _scale_sensor_data, run_pca_on_subject + +try: + from autograd_linalg import solve_triangular +except ImportError: + raise RuntimeError("must install `autograd_linalg` package") + +# einsum2 is a parallel version of einsum that works for two arguments +try: + from einsum2 import einsum2 +except ImportError: + # rename standard numpy function if don't have einsum2 + print("=> WARNING: using standard numpy.einsum,", + "consider installing einsum2 package") + from autograd.numpy import einsum as einsum2 + +from datetime import datetime + + +# TODO: add documentation to all methods +class _MEGModel(object): + """ Base class for any model applied to MEG data that handles storing and + unpacking data from tuples. 
""" + + def __init__(self): + self._subjectdata = None + self._timepts = 0 + self._ntrials_all = 0 + self._nsubjects = 0 + + def set_data(self, subjectdata): + timepts_lst = [self.unpack_subject_data(e)[0].shape[1] for e in subjectdata] + assert len(list(set(timepts_lst))) == 1 + self._timepts = timepts_lst[0] + ntrials_lst = [self.unpack_subject_data(e)[0].shape[0] for e in \ + subjectdata] + self._ntrials_all = np.sum(ntrials_lst) + self._nsubjects = len(subjectdata) + self._subjectdata = subjectdata + + def unpack_all_subject_data(self): + if self._subjectdata is None: + raise ValueError("use set_data to add subject data") + return map(self.unpack_subject_data, self._subjectdata) + + @classmethod + def unpack_subject_data(cls, sdata): + obs, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + Y = obs + w_s = 1. + if isinstance(obs, tuple): + if len(obs) == 2: + Y, w_s = obs + else: + raise ValueError("invalid format for subject data") + else: + Y = obs + w_s = 1. + + return Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi + + +# TODO: add documentation to all methods +# TODO: make some methods "private" (leading underscore) if necessary +class MEGLDS(_MEGModel): + """ State-space model for MEG data, as described in "A state-space model of + cross-region dynamic connectivity in MEG/EEG", Yang et al., NIPS 2016. + """ + + def __init__(self, num_roi, timepts, A_t_=None, roi_cov=None, mu0=None, roi_cov_0=None, + log_sigsq_lst=None, lam0=0., lam1=0., penalty='ridge', + store_St=True): + + super().__init__() + + set_default = \ + lambda prm, val, deflt: \ + self.__setattr__(prm, val.copy() if val is not None else deflt) + + # initialize parameters + set_default("A_t_", A_t_, + np.stack([rand_stable(num_roi, maxew=0.7) for _ in range(timepts)], + axis=0)) + set_default("roi_cov", roi_cov, rand_psd(num_roi)) + set_default("mu0", mu0, np.zeros(num_roi)) + set_default("roi_cov_0", roi_cov_0, rand_psd(num_roi)) + set_default("log_sigsq_lst", log_sigsq_lst, + [np.log(np.random.gamma(2, 1, size=num_roi+1))]) + + self.lam0 = lam0 + self.lam1 = lam1 + + if penalty not in ('ridge', 'lasso', 'group-lasso'): + raise ValueError('penalty must be one of: ridge, lasso,' \ + + ' group-lasso') + self._penalty = penalty + + # initialize lists of smoothed estimates + self._mus_smooth_lst = None + self._sigmas_smooth_lst = None + self._sigmas_tnt_smooth_lst = None + self._loglik = None + self._store_St = bool(store_St) + + # initialize sufficient statistics + timepts, num_roi, _ = self.A_t_.shape + self._B0 = np.zeros((num_roi, num_roi)) + self._B1 = np.zeros((timepts-1, num_roi, num_roi)) + self._B3 = np.zeros((timepts-1, num_roi, num_roi)) + self._B2 = np.zeros((timepts-1, num_roi, num_roi)) + self._B4 = list() + + self._subject_data = dict() + + def add_subject(self, subject, subject_dir, epochs,labels, fwd, cov): + roi_to_src = ROIToSourceMap(fwd, labels) # compute ROI-to-source map + fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs = \ + _scale_sensor_data(epochs, fwd, cov, roi_to_src) + + # cov = cov.pick_channels(epochs.ch_names) + sdata = run_pca_on_subject(subject, epochs, fwd, cov, labels) #check for channel mismatch + data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + subjectdata = [(data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi)] + + self.set_data(subjectdata) + + # epochs, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = subjectdata + self._subject_data[subject] = dict() + self._subject_data[subject]['epochs'] = epochs + self._subject_data[subject]['fwd_src_snsr'] = 
fwd_src_snsr + self._subject_data[subject]['fwd_roi_snsr'] = fwd_roi_snsr + self._subject_data[subject]['snsr_cov'] = snsr_cov + self._subject_data[subject]['labels'] = labels + self._subject_data[subject]['which_roi'] = which_roi + + def set_data(self, subjectdata): + # add subject data, re-generate log_sigsq_lst if necessary + super().set_data(subjectdata) + if len(self.log_sigsq_lst) != self._nsubjects: + num_roi = self.log_sigsq_lst[0].shape[0] + self.log_sigsq_lst = [np.log(np.random.gamma(2, 1, size=num_roi)) + for _ in range(self._nsubjects)] + + # reset smoothed estimates and log-likelihood (no longer valid if + # new data was added) + self._mus_smooth_lst = None + self._sigmas_smooth_lst = None + self._sigmas_tnt_smooth_lst = None + self._loglik = None + self._B4 = [None] * self._nsubjects + + # TODO: figure out how to initialize smoothed parameters so this doesn't + # break, e.g. if _em_objective is called before em for some reason + def _em_objective(self): + + _, num_roi, _ = self.A_t_.shape + + L_roi_cov_0 = np.linalg.cholesky(self.roi_cov_0) + L_roi_cov = np.linalg.cholesky(self.roi_cov) + + L1 = 0. + L2 = 0. + L3 = 0. + + obj = 0. + for s, sdata in enumerate(self.unpack_all_subject_data()): + + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + + ntrials, timepts, _ = Y.shape + + sigsq_vals = np.exp(self.log_sigsq_lst[s]) + R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) + L_R = np.linalg.cholesky(R) + + if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None + or self._sigmas_tnt_smooth_lst is None): + roi_cov_t = _ensure_ndim(self.roi_cov, timepts, 3) + with numpy_num_threads(1): + _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ + rts_smooth_fast(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, + self.roi_cov_0, compute_lag1_cov=True) + + else: + mus_smooth = self._mus_smooth_lst[s] + sigmas_smooth = self._sigmas_smooth_lst[s] + sigmas_tnt_smooth = self._sigmas_tnt_smooth_lst[s] + + x_smooth_0_outer = einsum2('ri,rj->rij', mus_smooth[:,0,:num_roi], + mus_smooth[:,0,:num_roi]) + B0 = w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + x_smooth_0_outer, + axis=0) + + x_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], + mus_smooth[:,1:,:num_roi]) + B1 = w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + x_smooth_outer, axis=0) + + z_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,:-1,:], + mus_smooth[:,:-1,:]) + B3 = w_s*np.sum(sigmas_smooth[:,:-1,:,:] + z_smooth_outer, axis=0) + + mus_smooth_outer_l1 = einsum2('rti,rtj->rtij', + mus_smooth[:,1:,:num_roi], + mus_smooth[:,:-1,:]) + B2 = w_s*np.sum(sigmas_tnt_smooth[:,:,:num_roi,:] + mus_smooth_outer_l1, + axis=0) + + # obj += L1(roi_cov_0) + L_roi_cov_0_inv_B0 = solve_triangular(L_roi_cov_0, B0, lower=True) + L1 += (ntrials*2.*np.sum(np.log(np.diag(L_roi_cov_0))) + + np.trace(solve_triangular(L_roi_cov_0, L_roi_cov_0_inv_B0, lower=True, + trans='T'))) + + At = self.A_t_[:-1] + AtB2T = einsum2('tik,tjk->tij', At, B2) + B2AtT = einsum2('tik,tjk->tij', B2, At) + tmp = einsum2('tik,tkl->til', At, B3) + AtB3AtT = einsum2('tik,tjk->tij', tmp, At) + + tmp = np.sum(B1 - AtB2T - B2AtT + AtB3AtT, axis=0) + + # obj += L2(roi_cov, At) + L_roi_cov_inv_tmp = solve_triangular(L_roi_cov, tmp, lower=True) + L2 += (ntrials*(timepts-1)*2.*np.sum(np.log(np.diag(L_roi_cov))) + + np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_tmp, lower=True, + trans='T'))) + + res = Y - einsum2('ik,ntk->nti', fwd_roi_snsr, mus_smooth[:,:,:num_roi]) + CP_smooth = einsum2('ik,ntkj->ntij', 
fwd_roi_snsr, sigmas_smooth[:,:,:num_roi,:num_roi]) + + # TODO: np.sum does not parallelize over the accumulators, possible + # bottleneck. + B4 = w_s*(np.sum(einsum2('nti,ntj->ntij', res, res), axis=(0,1)) + + np.sum(einsum2('ntik,jk->ntij', CP_smooth, fwd_roi_snsr), + axis=(0,1))) + self._B4[s] = B4 + + # obj += L3(sigsq_vals) + L_R_inv_B4 = solve_triangular(L_R, B4, lower=True) + L3 += (ntrials*timepts*2*np.sum(np.log(np.diag(L_R))) + + np.trace(solve_triangular(L_R, L_R_inv_B4, lower=True, + trans='T'))) + + obj = (L1 + L2 + L3) / self._ntrials_all + + # obj += penalty + if self.lam0 > 0.: + if self._penalty == 'ridge': + obj += self.lam0*np.sum(At**2) + elif self._penalty == 'lasso': + At_diag = np.diagonal(At, axis1=-2, axis2=-1) + sum_At_diag = np.sum(np.abs(At_diag)) + obj += self.lam0*(np.sum(np.abs(At)) - sum_At_diag) + elif self._penalty == 'group-lasso': + At_diag = np.diagonal(At, axis1=-2, axis2=-1) + norm_At_diag = np.sum(np.linalg.norm(At_diag, axis=0)) + norm_At = np.sum(np.linalg.norm(At, axis=0)) + obj += self.lam1*(norm_At - norm_At_diag) + if self.lam1 > 0.: + AtmAtm1_2 = (At[1:] - At[:-1])**2 + obj += self.lam1*np.sum(AtmAtm1_2) + + return obj + + def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, A_t_roi_cov_tol=1e-6, verbose=0, + update_A_t_=True, update_roi_cov=True, update_roi_cov_0=True, stationary_A_t_=False, + diag_roi_cov=False, update_sigsq=True, do_final_smoothing=True, + average_mus_smooth=True, Atrue=None, tau=0.1, c1=1e-4): + + fxn_start = datetime.now() + + timepts, num_roi, _ = self.A_t_.shape + + # make initial A_t_ stationary if stationary_A_t_ option specified + if stationary_A_t_: + self.A_t_[:] = np.mean(self.A_t_, axis=0) + + # set parameters for (A_t_, roi_cov) optimization + self._A_t_roi_cov_niter = A_t_roi_cov_niter + self._A_t_roi_cov_tol = A_t_roi_cov_tol + + # make initial roi_cov, roi_cov_0 diagonal if diag_roi_cov specified + if diag_roi_cov: + self.roi_cov_0 = np.diag(np.diag(self.roi_cov_0)) + self.roi_cov = np.diag(np.diag(self.roi_cov)) + + + # keeping track of objective value and best parameters + objvals = np.zeros(niter+1) + converged = False + best_objval = np.finfo('float').max + best_params = (self.A_t_.copy(), self.roi_cov.copy(), self.mu0.copy(), + self.roi_cov_0.copy(), [l.copy() for l in self.log_sigsq_lst]) + + # previous parameter values (for checking convergence) + At_prev = None + roi_cov_prev = None + roi_cov_0_prev = None + log_sigsq_lst_prev = None + + if Atrue is not None: + import matplotlib.pyplot as plt + fig_A_t_, ax_A_t_ = plt.subplots(num_roi, num_roi, sharex=True, sharey=True) + plt.ion() + + # calculate initial objective value, check for updated best iterate + # have to do e-step here to initialize suff stats for _m_step + if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None + or self._sigmas_tnt_smooth_lst is None): + self._e_step(verbose=verbose-1) + + objval = self._em_objective() + objvals[0] = objval + + for it in range(1, niter+1): + + iter_start = datetime.now() + + if verbose > 0: + print("em: it %d / %d" % (it, niter)) + sys.stdout.flush() + sys.stderr.flush() + + # record values from previous M-step + At_prev = self.A_t_[:-1].copy() + roi_cov_prev = self.roi_cov.copy() + roi_cov_0_prev = self.roi_cov_0.copy() + log_sigsq_lst_prev = np.array(self.log_sigsq_lst).copy() + + self._m_step(update_A_t_=update_A_t_, update_roi_cov=update_roi_cov, + update_roi_cov_0=update_roi_cov_0, stationary_A_t_=stationary_A_t_, + diag_roi_cov=diag_roi_cov, update_sigsq=update_sigsq, + tau=tau, 
c1=c1, verbose=verbose) + + if Atrue is not None: + for i in range(num_roi): + for j in range(num_roi): + ax_A_t_[i, j].cla() + ax_A_t_[i, j].plot(Atrue[:-1, i, j], color='green') + ax_A_t_[i, j].plot(self.A_t_[:-1, i, j], color='red', + alpha=0.7) + fig_A_t_.tight_layout() + fig_A_t_.canvas.draw() + plt.pause(1. / 60.) + + self._e_step(verbose=verbose-1) + + # calculate objective value, check for updated best iterate + objval = self._em_objective() + objvals[it] = objval + + if verbose > 0: + print(" objective: %.4e" % objval) + At = self.A_t_[:-1] + maxAt = np.max(np.abs(np.triu(At, k=1) + np.tril(At, k=-1))) + print(" max |A_t|: %.4e" % (maxAt,)) + sys.stdout.flush() + sys.stderr.flush() + + if objval < best_objval: + best_objval = objval + best_params = (self.A_t_.copy(), self.roi_cov.copy(), self.mu0.copy(), + self.roi_cov_0.copy(), + [l.copy() for l in self.log_sigsq_lst]) + + # check for convergence + if it >= 1: + relnormdiff_At = relnormdiff(self.A_t_[:-1], At_prev) + relnormdiff_roi_cov = relnormdiff(self.roi_cov, roi_cov_prev) + relnormdiff_roi_cov_0 = relnormdiff(self.roi_cov_0, roi_cov_0_prev) + relnormdiff_log_sigsq_lst = \ + np.array( + [relnormdiff(self.log_sigsq_lst[s], + log_sigsq_lst_prev[s]) + for s in range(len(self.log_sigsq_lst))]) + params_converged = (relnormdiff_At <= tol) and \ + (relnormdiff_roi_cov <= tol) and \ + (relnormdiff_roi_cov_0 <= tol) and \ + np.all(relnormdiff_log_sigsq_lst <= tol) + + relobjdiff = np.abs((objval - objvals[it-1]) / objval) + + if verbose > 0: + print(" relnormdiff_At: %.3e" % relnormdiff_At) + print(" relnormdiff_roi_cov: %.3e" % relnormdiff_roi_cov) + print(" relnormdiff_roi_cov_0: %.3e" % relnormdiff_roi_cov_0) + print(" relnormdiff_log_sigsq_lst:", + relnormdiff_log_sigsq_lst) + print(" relobjdiff: %.3e" % relobjdiff) + + objdiff = objval - objvals[it-1] + if objdiff > 0: + print(" \033[0;31mEM objective increased\033[0m") + + sys.stdout.flush() + sys.stderr.flush() + + if params_converged or relobjdiff <= tol: + if verbose > 0: + print("EM objective converged") + sys.stdout.flush() + sys.stderr.flush() + converged = True + objvals = objvals[:it+1] + break + + # retrieve best parameters and load into instance variables. + A_t_, roi_cov, mu0, roi_cov_0, log_sigsq_lst = best_params + self.A_t_ = A_t_.copy() + self.roi_cov = roi_cov.copy() + self.mu0 = mu0.copy() + self.roi_cov_0 = roi_cov_0.copy() + self.log_sigsq_lst = [l.copy() for l in log_sigsq_lst] + + if verbose > 0: + print() + print("elapsed, iteration:", datetime.now() - iter_start) + print("=" * 34) + print() + + # perform final smoothing + mus_smooth_lst = None + St_lst = None + if do_final_smoothing: + if verbose >= 1: + print("performing final smoothing") + + mus_smooth_lst = list() + self._loglik = 0. 
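+            # Shapes assumed from the calls below (rts_smooth's API is not
+            # spelled out here): it returns the subject log-likelihood, the
+            # smoothed means with shape (ntrials, timepts, D), two smoothed
+            # covariance arrays, and St, whose diagonal is all that is kept
+            # for the connectivity summaries.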
+ if self._store_St: + St_lst = list() + for s, sdata in enumerate(self.unpack_all_subject_data()): + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + sigsq_vals = np.exp(self.log_sigsq_lst[s]) + R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) + roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) + with numpy_num_threads(1): + loglik_subject, mus_smooth, _, _, St = \ + rts_smooth(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, self.roi_cov_0, + compute_lag1_cov=False, + store_St=self._store_St) + # just save the mean of the smoothed trials + if average_mus_smooth: + mus_smooth_lst.append(np.mean(mus_smooth, axis=0)) + else: + mus_smooth_lst.append(mus_smooth) + self._loglik += loglik_subject + # just save the diagonals of St b/c that's what we need for + # connectivity + if self._store_St: + St_lst.append(np.diagonal(St, axis1=-2, axis2=-1)) + + if verbose > 0: + print() + print("elapsed, function:", datetime.now() - fxn_start) + print("=" * 34) + print() + + return objvals, converged, mus_smooth_lst, self._loglik, St_lst + + def _e_step(self, verbose=0): + + timepts, num_roi, _ = self.A_t_.shape + + # reset accumulation arrays + self._B0[:] = 0. + self._B1[:] = 0. + self._B3[:] = 0. + self._B2[:] = 0. + + self._mus_smooth_lst = list() + self._sigmas_smooth_lst = list() + self._sigmas_tnt_smooth_lst = list() + + if verbose > 0: + print(" e-step") + print(" subject", end="") + + for s, sdata in enumerate(self.unpack_all_subject_data()): + + if verbose > 0: + print(" %d" % (s+1,), end="") + sys.stdout.flush() + sys.stderr.flush() + + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + + sigsq_vals = np.exp(self.log_sigsq_lst[s]) + R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) + L_R = np.linalg.cholesky(R) + roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) + + with numpy_num_threads(1): + _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ + rts_smooth_fast(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, + self.roi_cov_0, compute_lag1_cov=True) + + self._mus_smooth_lst.append(mus_smooth) + self._sigmas_smooth_lst.append(sigmas_smooth) + self._sigmas_tnt_smooth_lst.append(sigmas_tnt_smooth) + + x_smooth_0_outer = einsum2('ri,rj->rij', mus_smooth[:,0,:num_roi], + mus_smooth[:,0,:num_roi]) + self._B0 += w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + x_smooth_0_outer, + axis=0) + + x_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], + mus_smooth[:,1:,:num_roi]) + self._B1 += w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + x_smooth_outer, + axis=0) + + z_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,:-1,:], + mus_smooth[:,:-1,:]) + self._B3 += w_s*np.sum(sigmas_smooth[:,:-1,:,:] + z_smooth_outer, + axis=0) + + mus_smooth_outer_l1 = einsum2('rti,rtj->rtij', + mus_smooth[:,1:,:num_roi], + mus_smooth[:,:-1,:]) + self._B2 += w_s*np.sum(sigmas_tnt_smooth[:,:,:num_roi,:] + + mus_smooth_outer_l1, axis=0) + + if verbose > 0: + print("\n done") + + def _m_step(self, update_A_t_=True, update_roi_cov=True, update_roi_cov_0=True, + stationary_A_t_=False, diag_roi_cov=False, update_sigsq=True, tau=0.1, c1=1e-4, + verbose=0): + self._loglik = None + if verbose > 0: + print(" m-step") + if update_roi_cov_0: + self.roi_cov_0 = (1. 
/ self._ntrials_all) * self._B0 + if diag_roi_cov: + self.roi_cov_0 = np.diag(np.diag(self.roi_cov_0)) + self.update_A_t_and_roi_cov(update_A_t_=update_A_t_, update_roi_cov=update_roi_cov, + stationary_A_t_=stationary_A_t_, diag_roi_cov=diag_roi_cov, + tau=tau, c1=c1, verbose=verbose) + if update_sigsq: + self.update_log_sigsq_lst(verbose=verbose) + + def update_A_t_and_roi_cov(self, update_A_t_=True, update_roi_cov=True, stationary_A_t_=False, + diag_roi_cov=False, tau=0.1, c1=1e-4, verbose=0): + + if verbose > 1: + print(" update A_t_ and roi_cov") + + # gradient descent + At = self.A_t_[:-1] + At_init = At.copy() + L_roi_cov = np.linalg.cholesky(self.roi_cov) + At_L_roi_cov_obj = lambda x, y: self.L2_obj(x, y) + At_obj = lambda x: self.L2_obj(x, L_roi_cov) + grad_At_obj = grad(At_obj) + obj_diff = np.finfo('float').max + obj = At_L_roi_cov_obj(At, L_roi_cov) + inner_it = 0 + + # specify proximal operator to use + if self._penalty == 'ridge': + prox_op = lambda x, y: x + elif self._penalty == 'lasso': + prox_op = soft_thresh_At + elif self._penalty == 'group-lasso': + prox_op = block_thresh_At + + while np.abs(obj_diff / obj) > self._A_t_roi_cov_tol: + + if inner_it > self._A_t_roi_cov_niter: + break + + obj_start = At_L_roi_cov_obj(At, L_roi_cov) + + # update At using gradient descent with backtracking line search + if update_A_t_: + if stationary_A_t_: + B2_sum = np.sum(self._B2, axis=0) + B3_sum = np.sum(self._B3, axis=0) + At[:] = np.linalg.solve(B3_sum.T, B2_sum.T).T + else: + grad_At = grad_At_obj(At) + step_size = linesearch(At_obj, grad_At_obj, At, grad_At, + prox_op=prox_op, lam=self.lam0, + tau=tau, c1=c1) + At[:] = prox_op(At - step_size * grad_At, + self.lam0 * step_size) + + # update roi_cov using closed form + if update_roi_cov: + AtB2T = einsum2('tik,tjk->tij', At, self._B2) + B2AtT = einsum2('tik,tjk->tij', self._B2, At) + tmp = einsum2('tik,tkl->til', At, self._B3) + AtB3AtT = einsum2('til,tjl->tij', tmp, At) + elbo_2 = np.sum(self._B1 - AtB2T - B2AtT + AtB3AtT, axis=0) + self.roi_cov = (1. 
/ (self._ntrials_all * self._timepts + )) * elbo_2 + if diag_roi_cov: + self.roi_cov = np.diag(np.diag(self.roi_cov)) + L_roi_cov = np.linalg.cholesky(self.roi_cov) + + obj = At_L_roi_cov_obj(At, L_roi_cov) + obj_diff = obj_start - obj + inner_it += 1 + + if verbose > 1: + if not stationary_A_t_ and update_A_t_: + grad_norm = np.linalg.norm(grad_At) + norm_change = np.linalg.norm(At - At_init) + print(" last step size: %.3e" % step_size) + print(" last gradient norm: %.3e" % grad_norm) + print(" norm of total change: %.3e" % norm_change) + print(" number of iterations: %d" % inner_it) + print(" done") + + def update_log_sigsq_lst(self, verbose=0): + + if verbose > 1: + print(" update subject log-sigmasq") + + timepts, num_roi, _ = self.A_t_.shape + + # update log_sigsq_vals for each subject and ROI + for s, sdata in enumerate(self.unpack_all_subject_data()): + + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + ntrials, timepts, _ = Y.shape + mus_smooth = self._mus_smooth_lst[s] + sigmas_smooth = self._sigmas_smooth_lst[s] + B4 = self._B4[s] + + log_sigsq = self.log_sigsq_lst[s].copy() + log_sigsq_obj = lambda x: \ + MEGLDS.L3_obj(x, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, timepts) + log_sigsq_val_and_grad = vgrad(log_sigsq_obj) + + options = {'maxiter': 500} + opt_res = spopt.minimize(log_sigsq_val_and_grad, log_sigsq, + method='L-BFGS-B', jac=True, + options=options) + if verbose > 1: + print(" subject %d - %d iterations" % (s+1, opt_res.nit)) + + if not opt_res.success: + print(" log_sigsq opt") + print(" %s" % opt_res.message) + + self.log_sigsq_lst[s] = opt_res.x + + if verbose > 1: + print("\n done") + + def calculate_smoothed_estimates(self): + """ recalculate smoothed estimates with current model parameters """ + + self._mus_smooth_lst = list() + self._sigmas_smooth_lst = list() + self._sigmas_tnt_smooth_lst = list() + self._St_lst = list() + self._loglik = 0. + + for s, sdata in enumerate(self.unpack_all_subject_data()): + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + sigsq_vals = np.exp(self.log_sigsq_lst[s]) + R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) + roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) + with numpy_num_threads(1): + ll, mus_smooth, sigmas_smooth, sigmas_tnt_smooth, _ = \ + rts_smooth(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, self.roi_cov_0, + compute_lag1_cov=True, store_St=False) + self._mus_smooth_lst.append(mus_smooth) + self._sigmas_smooth_lst.append(sigmas_smooth) + self._sigmas_tnt_smooth_lst.append(sigmas_tnt_smooth) + #self._St_lst.append(np.diagonal(St, axis1=-2, axis2=-1)) + self._loglik += ll + + def log_likelihood(self): + """ calculate log marginal likelihood using the Kalman filter """ + + #if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None \ + # or self._sigmas_tnt_smooth_lst is None): + # self.calculate_smoothed_estimates() + # return self._loglik + if self._loglik is not None: + return self._loglik + + self._loglik = 0. 
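+        # For each subject below, the observation covariance comes from the
+        # R_ helper defined at the bottom of this class:
+        #     R = snsr_cov + G @ np.diag(sigsq[which_roi]) @ G.T
+        # with G = fwd_src_snsr: sensor noise plus per-ROI source variance
+        # propagated through the forward model.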
+ for s, sdata in enumerate(self.unpack_all_subject_data()): + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + sigsq_vals = np.exp(self.log_sigsq_lst[s]) + R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) + roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) + ll, _, _, _ = kalman_filter(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, + self.roi_cov_0, store_St=False) + self._loglik += ll + + return self._loglik + + def nparams(self): + timepts, p, _ = self.A_t_.shape + + # this should equal (timepts-1)*p*p unless some shrinkage is used on At + nparams_At = np.sum(np.abs(self.A_t_[:-1]) > 0) + + # nparams = nparams(At) + nparams(roi_cov) + nparams(roi_cov_0) + # + nparams(log_sigsq_lst) + return nparams_At + p*(p+1)/2 + p*(p+1)/2 \ + + np.sum([p+1 for _ in range(len(self.log_sigsq_lst))]) + + def AIC(self): + return -2*self.log_likelihood() + 2*self.nparams() + + def BIC(self): + if self._ntrials_all == 0: + raise RuntimeError("use set_data to add subject data before" \ + + " computing BIC") + return -2*self.log_likelihood() \ + + np.log(self._ntrials_all)*self.nparams() + + def save(self, filename, **kwargs): + savedict = { 'A_t_' : self.A_t_, 'roi_cov' : self.roi_cov, 'mu0' : self.mu0, + 'roi_cov_0' : self.roi_cov_0, 'log_sigsq_lst' : self.log_sigsq_lst, + 'lam0' : self.lam0, 'lam1' : self.lam1} + savedict.update(kwargs) + np.savez_compressed(filename, **savedict) + + def load(self, filename): + loaddict = np.load(filename) + param_names = ['A_t_', 'roi_cov', 'mu0', 'roi_cov_0', 'log_sigsq_lst', 'lam0', 'lam1'] + for name in param_names: + if name not in loaddict.keys(): + raise RuntimeError('specified file is not a saved model:\n%s' + % (filename,)) + for name in param_names: + if name == 'log_sigsq_lst': + self.log_sigsq_lst = [l.copy() for l in loaddict[name]] + elif name in ('lam0', 'lam1'): + self.__setattr__(name, float(loaddict[name])) + else: + self.__setattr__(name, loaddict[name].copy()) + + # return remaining saved items, if there are any + others = {key : loaddict[key] for key in loaddict.keys() \ + if key not in param_names} + if len(others.keys()) > 0: + return others + + @staticmethod + def R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi): + return snsr_cov + np.dot(fwd_src_snsr, sigsq_vals[which_roi][:,None]*fwd_src_snsr.T) + + def L2_obj(self, At, L_roi_cov): + + # import autograd.numpy + # if isinstance(At,autograd.numpy.numpy_boxes.ArrayBox): + # At = At._value + + AtB2T = einsum2('tik,tjk->tij', At, self._B2) + B2AtT = einsum2('tik,tjk->tij', self._B2, At) + tmp = einsum2('tik,tkl->til', At, self._B3) + AtB3AtT = einsum2('til,tjl->tij', tmp, At) + elbo_2 = np.sum(self._B1 - AtB2T - B2AtT + AtB3AtT, axis=0) + + L_roi_cov_inv_elbo_2 = solve_triangular(L_roi_cov, elbo_2, lower=True) + obj = np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_elbo_2, lower=True, + trans='T')) + obj = obj / self._ntrials_all + + if self._penalty == 'ridge': + obj += self.lam0*np.sum(At**2) + AtmAtm1_2 = (At[1:] - At[:-1])**2 + obj += self.lam1*np.sum(AtmAtm1_2) + + return obj + + # TODO: convert to instance method + @staticmethod + def L3_obj(log_sigsq_vals, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, timepts): + R = MEGLDS.R_(snsr_cov, fwd_src_snsr, np.exp(log_sigsq_vals), which_roi) + try: + L_R = np.linalg.cholesky(R) + except LinAlgError: + return np.finfo('float').max + L_R_inv_B4 = solve_triangular(L_R, B4, lower=True) + return (ntrials*timepts*2.*np.sum(np.log(np.diag(L_R))) + + np.trace(solve_triangular(L_R, L_R_inv_B4, lower=True, + 
trans='T'))) + + + # @property + # def A(self): + # return self._A + + # @A.setter + # def A(self, A): + # self._A = A + + # @property + # def roi_cov(self): + # return self._roi_cov + + # @roi_cov.setter + # def roi_cov(self, roi_cov): + # self._roi_cov = roi_cov + + # @property + # def mu0(self): + # return self._mu0 + + # @mu0.setter + # def mu0(self, mu0): + # self._mu0 = mu0 + + # @property + # def roi_cov_0(self): + # return self._roi_cov_0 + + # @roi_cov_0.setter + # def roi_cov_0(self, roi_cov_0): + # self._roi_cov_0 = roi_cov_0 + + # @property + # def log_sigsq_lst(self): + # return self._log_sigsq_lst + + # @log_sigsq_lst.setter + # def log_sigsq_lst(self, log_sigsq_lst): + # self._log_sigsq_lst = log_sigsq_lst + + # @property + # def num_roi(self): + # return self.A.shape[1] + + # @property + # def timepts(self): + # return self._timepts + + # @property + # def lam0(self): + # return self._lam0 + + # @lam0.setter + # def lam0(self, lam0): + # self._lam0 = lam0 + + # @property + # def lam1(self): + # return self._lam1 + + # @lam1.setter + # def lam1(self, lam1): + # self._lam1 = lam1 diff --git a/state_space/megssm/numpy_numthreads.py b/state_space/megssm/numpy_numthreads.py new file mode 100755 index 00000000..550aa235 --- /dev/null +++ b/state_space/megssm/numpy_numthreads.py @@ -0,0 +1,91 @@ +import contextlib +import ctypes +from ctypes.util import find_library + +# heavily based on: +# https://stackoverflow.com/questions/29559338/set-max-number-of-threads-at-runtime-on-numpy-openblas + +# Prioritize hand-compiled OpenBLAS library over version in /usr/lib/ +# from Ubuntu repos +try_paths = [find_library('openblas')] +openblas_lib = None +for libpath in try_paths: + try: + openblas_lib = ctypes.cdll.LoadLibrary(libpath) + break + except Exception: #OSError: + continue +#if openblas_lib is None: + #raise EnvironmentError('Could not locate an OpenBLAS shared library', 2) + +try: + mkl_rt_path = find_library('mkl_rt') + mkl_rt = ctypes.cdll.LoadLibrary(mkl_rt_path) + # print(mkl_rt) +except OSError: + mkl_rt = None + pass + + +def set_num_threads(n): + """Set the current number of threads used by the OpenBLAS server.""" + if mkl_rt: + pass + #mkl_rt.mkl_set_num_threads(ctypes.byref(ctypes.c_int(n))) + elif openblas_lib: + openblas_lib.openblas_set_num_threads(int(n)) + + +# At the time of writing these symbols were very new: +# https://github.com/xianyi/OpenBLAS/commit/65a847c +try: + if mkl_rt: #False: #mkl_rt: + def get_num_threads(): + return mkl_rt.mkl_get_max_threads() + elif openblas_lib: + # do this to throw exception if it doesn't exist + openblas_lib.openblas_get_num_threads() + def get_num_threads(): + """Get the current number of threads used by the OpenBLAS server.""" + return openblas_lib.openblas_get_num_threads() +except AttributeError: + def get_num_threads(): + """Dummy function (symbol not present in %s), returns -1.""" + return -1 + pass + +try: + if False: #mkl_rt: + def get_num_procs(): + # this returns number of procs + return mkl_rt.mkl_get_max_threads() + elif openblas_lib: + # do this to throw exception if it doesn't exist + openblas_lib.openblas_get_num_procs() + def get_num_procs(): + """Get the total number of physical processors""" + return openblas_lib.openblas_get_num_procs() +except AttributeError: + def get_num_procs(): + """Dummy function (symbol not present), returns -1.""" + return -1 + pass + + +@contextlib.contextmanager +def numpy_num_threads(n): + """Temporarily changes the number of OpenBLAS threads. 
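+
+    If neither OpenBLAS nor MKL was loaded above, get_num_threads() returns
+    -1 and set_num_threads() is a no-op, so this context manager degrades
+    gracefully to doing nothing.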
+
+    Example usage:
+
+        print("Before: {}".format(get_num_threads()))
+        with numpy_num_threads(n):
+            print("In thread context: {}".format(get_num_threads()))
+        print("After: {}".format(get_num_threads()))
+    """
+    old_n = get_num_threads()
+    set_num_threads(n)
+    try:
+        yield
+    finally:
+        set_num_threads(old_n)
diff --git a/state_space/megssm/plotting.py b/state_space/megssm/plotting.py
new file mode 100644
index 00000000..8925660a
--- /dev/null
+++ b/state_space/megssm/plotting.py
@@ -0,0 +1,107 @@
+""" plotting functions """
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+def plot_A_t_(A, ci='sd', times=None, ax=None, skipdiag=False, labels=None,
+              showticks=True, **kwargs):
+    """ plot traces of each entry of dynamics A in a square grid of subplots """
+    if A.ndim == 3:
+        T, d, _ = A.shape
+    elif A.ndim == 4:
+        _, T, d, _ = A.shape
+
+    if times is None:
+        times = np.arange(T)
+
+    if ax is None or ax.shape != (d, d):
+        fig, ax = plt.subplots(d, d, sharex=True, sharey=True, squeeze=False)
+    else:
+        fig = ax[0, 0].figure
+
+    for i in range(d):
+        for j in range(d):
+
+            # skip and hide subplots on diagonal
+            if skipdiag and i == j:
+                ax[i, j].set_visible(False)
+                continue
+
+            # plot A entry as trace with/without error band
+            if A.ndim == 3:
+                ax[i, j].plot(times[:-1], A[:-1, i, j], **kwargs)
+            elif A.ndim == 4:
+                plot_fill(A[:, :-1, i, j], ci=ci, times=times[:-1],
+                          ax=ax[i, j], **kwargs)
+
+            # add labels above first row and to the left of the first column
+            if labels is not None:
+                if i == 0 or (skipdiag and (i, j) == (1, 0)):
+                    ax[i, j].set_title(labels[j], fontsize=12)
+                if j == 0 or (skipdiag and (i, j) == (0, 1)):
+                    ax[i, j].set_ylabel(labels[i], fontsize=12)
+
+            # remove x- and y-ticks on subplot
+            if not showticks:
+                ax[i, j].set_xticks([])
+                ax[i, j].set_yticks([])
+
+    diag_lims = [0, 1]
+    off_lims = [-0.25, 0.25]
+    for ri, row in enumerate(ax):
+        for cj, a in enumerate(row):  # avoid shadowing the `ci` argument
+            ylim = diag_lims if ri == cj else off_lims
+            a.set(ylim=ylim, xlim=times[[0, -1]])
+            if ri == 0:
+                a.set_title(a.get_title(), fontsize='small')
+            if cj == 0:
+                a.set_ylabel(a.get_ylabel(), fontsize='small')
+            for line in a.lines:
+                line.set_clip_on(False)
+                line.set(lw=1.)
+            if cj != 0:
+                a.yaxis.set_major_formatter(plt.NullFormatter())
+            if ri != len(ax) - 1:  # len(labels) - 1 would fail when labels=None
+                a.xaxis.set_major_formatter(plt.NullFormatter())
+            if ri == cj:
+                for spine in a.spines.values():
+                    spine.set(lw=2)
+            else:
+                a.axhline(0, color='k', ls=':', lw=1.)
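+    # Typical call, mirroring the example scripts in this patch:
+    #     fig, ax = plt.subplots(d, d, constrained_layout=True,
+    #                            squeeze=False)
+    #     plot_A_t_(A_t_, labels=label_names, times=epochs.times, ax=ax)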
+ + return fig, ax + +def plot_fill(X, times=None, ax=None, ci='sd', **kwargs): + """ plot mean and error band across first axis of X """ + N, T = X.shape + + if times is None: + times = np.arange(T) + if ax is None: + fig, ax = plt.subplots(1, 1) + + mu = np.mean(X, axis=0) + + # define lower and upper band limits based on ci + if ci == 'sd': # standard deviation + sigma = np.std(X, axis=0) + lower, upper = mu - sigma, mu + sigma + elif ci == 'se': # standard error + stderr = np.std(X, axis=0) / np.sqrt(X.shape[0]) + lower, upper = mu - stderr, mu + stderr + elif ci == '2sd': # 2 standard deviations + sigma = np.std(X, axis=0) + lower, upper = mu - 2 * sigma, mu + 2 * sigma + elif ci == 'max': # range (min to max) + lower, upper = np.min(X, axis=0), np.max(X, axis=0) + elif type(ci) is float and 0 < ci < 1: + # quantile-based confidence interval + a = 1 - ci + lower, upper = np.quantile(X, [a / 2, 1 - a / 2], axis=0) + else: + raise ValueError("ci must be in ('sd', 'se', '2sd', 'max') " + "or float in (0, 1)") + + lines = ax.plot(times, mu, **kwargs) + c = lines[0].get_color() + ax.fill_between(times, lower, upper, color=c, alpha=0.3, lw=0) diff --git a/state_space/megssm/util.py b/state_space/megssm/util.py new file mode 100755 index 00000000..63898be3 --- /dev/null +++ b/state_space/megssm/util.py @@ -0,0 +1,117 @@ +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import + +import autograd.numpy as np +from numpy.lib.stride_tricks import as_strided as ast + + +hs = lambda *args: np.concatenate(*args, axis=-1) + +def T_(X): + return np.swapaxes(X, -1, -2) + +def sym(X): + return 0.5*(X + T_(X)) + +def dot3(A, B, C): + return np.dot(A, np.dot(B, C)) + +def relnormdiff(A, B, min_denom=1e-9): + return np.linalg.norm(A - B) / np.maximum(np.linalg.norm(A), min_denom) + +def _ensure_ndim(X, T, ndim): + X = np.require(X, dtype=np.float64, requirements='C') + assert ndim-1 <= X.ndim <= ndim + if X.ndim == ndim: + assert X.shape[0] == T + return X + else: + return ast(X, shape=(T,) + X.shape, strides=(0,) + X.strides) + +def rand_psd(n, minew=0.1, maxew=1.): + # maxew is badly named + if n == 1: + return maxew * np.eye(n) + X = np.random.randn(n,n) + S = np.dot(T_(X), X) + S = sym(S) + ew, ev = np.linalg.eigh(S) + ew -= np.min(ew) + ew /= np.max(ew) + ew *= (maxew - minew) + ew += minew + return dot3(ev, np.diag(ew), T_(ev)) + +def rand_stable(n, maxew=0.9): + A = np.random.randn(n, n) + A *= maxew / np.max(np.abs(np.linalg.eigvals(A))) + return A + +def component_matrix(As, nlags): + """ compute component form of latent VAR process + + [A_1 A_2 ... A_p] + [ I 0 ... 0 ] + [ 0 I 0 0 ] + [ 0 ... 
I 0 ] + + """ + + d = As.shape[0] + res = np.zeros((d*nlags, d*nlags)) + res[:d] = As + + if nlags > 1: + res[np.arange(d,d*nlags), np.arange(d*nlags-d)] = 1 + + return res + +def linesearch(f, grad_f, xk, pk, step_size=1., tau=0.1, c1=1e-4, + prox_op=None, lam=1.): + """ find a step size via backtracking line search with armijo condition """ + obj_start = f(xk) + grad_xk = grad_f(xk) + obj_new = np.finfo('float').max + armijo_condition = 0 + + if prox_op is None: + prox_op = lambda x, y: x + + while obj_new > armijo_condition: + x_new = prox_op(xk - step_size * pk, lam*step_size) + armijo_condition = obj_start - c1*step_size*(np.sum(pk*grad_xk)) + obj_new = f(x_new) + step_size *= tau + + return step_size/tau + +def soft_thresh_At(At, lam): + At = At.copy() + diag_inds = np.diag_indices(At.shape[1]) + At_diag = np.diagonal(At, axis1=-2, axis2=-1) + + At = np.sign(At) * np.maximum(np.abs(At) - lam, 0.) + + # fill in diagonal with originally updated entries as we're not + # going to penalize them + for tt in range(At.shape[0]): + At[tt][diag_inds] = At_diag[tt] + return At + +def block_thresh_At(At, lam, min_norm=1e-16): + At = At.copy() + diag_inds = np.diag_indices(At.shape[1]) + At_diag = np.diagonal(At, axis1=-2, axis2=-1) + + norms = np.linalg.norm(At, axis=0, keepdims=True) + norms = np.maximum(norms, min_norm) + scales = np.maximum(norms - lam, 0.) + At = scales * (At / norms) + + # fill in diagonal with originally updated entries as we're not + # going to penalize them + for tt in range(At.shape[0]): + At[tt][diag_inds] = At_diag[tt] + return At + diff --git a/state_space/state_space_connectivity.py b/state_space/state_space_connectivity.py new file mode 100644 index 00000000..52db4f5c --- /dev/null +++ b/state_space/state_space_connectivity.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Authors: Jordan Drew + +""" + +''' +For 'mne-connectivity/examples/' to show usage of LDS +Use MNE-sample-data for auditory/left +''' + +## import necessary libraries +import mne +import matplotlib.pyplot as plt +import matplotlib as mpl + +#where should these files live within mne-connectivity repo? +from megssm.models import MEGLDS as LDS +from megssm.plotting import plot_A_t_ + +## define paths to sample data +path = None +path = '/Users/jordandrew/Documents/MEG/mne_data'#'/MNE-sample-data' +data_path = mne.datasets.sample.data_path(path=path) +sample_folder = data_path / 'MEG/sample' +subjects_dir = data_path / 'subjects' + +## import raw data and find events +raw_fname = sample_folder / 'sample_audvis_raw.fif' +raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60) +events = mne.find_events(raw, stim_channel='STI 014') + +## define epochs using event_dict +event_dict = {'auditory/left': 1, 'auditory/right': 2, 'visual/left': 3, + 'visual/right': 4, 'face': 5, 'buttonpress': 32} +epochs = mne.Epochs(raw, events, tmin=-0.3, tmax=0.7, event_id=event_dict, + preload=True).pick_types(meg=True,eeg=True) +condition = 'visual' +epochs = epochs[condition] # choose condition for analysis + +## read forward solution, remove bad channels +fwd_fname = sample_folder / 'sample_audvis-meg-eeg-oct-6-fwd.fif' +fwd = mne.read_forward_solution(fwd_fname) + +## read in covariance OR compute noise covariance? 
noise_cov drops bad chs
+cov_fname = sample_folder / 'sample_audvis-cov.fif'
+cov = mne.read_cov(cov_fname)  # bad channels dropped in add_subject
+# noise_cov = mne.compute_covariance(epochs, tmax=0)
+
+## read labels for analysis
+label_names = ['AUD-lh', 'AUD-rh', 'Vis-lh', 'Vis-rh']
+labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label')
+          for label in label_names]
+
+## initialize model; only needs the forward, labels, and noise_cov
+num_rois = len(labels)
+timepts = len(epochs.times)
+model = LDS(num_rois, timepts, lam0=0, lam1=100)
+
+model.add_subject('sample', subjects_dir, epochs, labels, fwd, cov)
+# when to use compute_covariance vs read_cov? i.e. cov vs noise_cov
+
+model.fit(niter=100, verbose=1)
+A_t_ = model.A_t_
+assert A_t_.shape == (timepts, num_rois, num_rois)
+
+with mpl.rc_context({'xtick.labelsize': 'x-small',
+                     'ytick.labelsize': 'x-small'}):
+    fig, ax = plt.subplots(num_rois, num_rois, constrained_layout=True,
+                           squeeze=False, figsize=(12, 10))
+    plot_A_t_(A_t_, labels=label_names, times=epochs.times, ax=ax)
+    fig.suptitle(condition)
diff --git a/state_space/state_space_connectivity_test.py b/state_space/state_space_connectivity_test.py
new file mode 100644
index 00000000..848997ea
--- /dev/null
+++ b/state_space/state_space_connectivity_test.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Authors: Jordan Drew
+
+"""
+
+'''
+For 'mne-connectivity/examples/' to show usage of LDS
+Use MNE-sample-data for auditory/left
+'''
+
+## import necessary libraries
+import mne
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+
+# where should these files live within the mne-connectivity repo?
+from megssm.models import MEGLDS as LDS
+from megssm.plotting import plot_A_t_
+from megssm import label_util
+
+## define paths to sample data
+data_path = '/Users/jordandrew/Documents/MEG/meglds-master/data/sps'
+# sample_folder = data_path / 'MEG/sample'
+# subjects_dir = data_path / 'subjects'
+
+subjects = ['eric_sps_03','eric_sps_04','eric_sps_05','eric_sps_06',
+            'eric_sps_07','eric_sps_09','eric_sps_10','eric_sps_15',
+            'eric_sps_17','eric_sps_18','eric_sps_19','eric_sps_21',
+            'eric_sps_25','eric_sps_26','eric_sps_31','eric_sps_32']
+
+label_names = ['ACC', 'DLPFC-lh', 'DLPFC-rh', 'AUD-lh', 'AUD-rh', 'FEF-lh',
+               'FEF-rh', 'Vis', 'IPS-lh', 'LIPSP', 'IPS-rh', 'RTPJ']
+label_func = 'sps_meglds_base_vision_extra'
+labels = getattr(label_util, label_func)()
+labels = sorted(labels, key=lambda x: x.name)
+
+def eq_trials(epochs, kind):
+    """ equalize trial counts """
+    import numpy as np
+    import mne
+    assert kind in ('sub', 'big')
+    print('  equalizing trial counts', end='')
+    in_names = [
+        'LL3', 'LR3', 'LU3', 'LD3', 'RL3', 'RR3', 'RU3', 'RD3',
+        'UL3', 'UR3', 'UU3', 'UD3', 'DL3', 'DR3', 'DU3', 'DD3',
+        'LL4', 'LR4', 'LU4', 'LD4', 'RL4', 'RR4', 'RU4', 'RD4',
+        'UL4', 'UR4', 'UU4', 'UD4', 'DL4', 'DR4', 'DU4', 'DD4',
+        'VS_', 'VM_',
+        'Junk',
+    ]
+    out_names = ['LL', 'LR', 'LX', 'UX', 'UU', 'UD', 'VS', 'VM']
+
+    # strip 3/4 and combine
+    clean_names = np.unique([ii[:2] for ii in in_names
+                             if not ii.startswith('V')])
+    for name in clean_names:
+        combs = [in_name for in_name in in_names if in_name.startswith(name)]
+        new_id = {name: epochs.event_id[combs[-1]] + 1}
+        mne.epochs.combine_event_ids(epochs, combs, new_id, copy=False)
+
+    # Now we equalize LU+LD, RU+RD, UL+UR, DL+DR, and combine those
+    for n1, n2, oname in zip(('LU', 'RU', 'UL', 'DL'),
+                             ('LD', 'RD', 'UR', 'DR'),
+                             ('LX', 'RX', 'UX', 'DX')):
+        
if kind == 'sub': + epochs.equalize_event_counts([n1, n2]) + new_id = {oname: epochs.event_id[n1] + 1} + mne.epochs.combine_event_ids(epochs, [n1, n2], new_id, copy=False) + + # Now we equalize "sides" + cs = dict(L='R', R='L', U='D', D='U') + for n1 in ['L', 'R', 'U', 'D']: + # first equalize it with its complement in the second pos + if kind == 'sub': + epochs.equalize_event_counts([n1 + n1, n1 + cs[n1]]) + epochs.equalize_event_counts([n1 + n1, cs[n1] + n1]) + epochs.equalize_event_counts([n1 + 'X', cs[n1] + 'X']) + + # now combine cross types + for n1 in ['L', 'U']: + # LR+RL=LR, UD+DU=UD + old_ids = [n1 + cs[n1], cs[n1] + n1] + if kind == 'sub': + epochs.equalize_event_counts(old_ids) + new_id = {n1 + cs[n1]: epochs.event_id[n1 + cs[n1]] + 1} + mne.epochs.combine_event_ids(epochs, old_ids, new_id, copy=False) + # LL+RR=LL, UU+DD=UU + old_ids = [n1 + n1, cs[n1] + cs[n1]] + if kind == 'sub': + epochs.equalize_event_counts(old_ids) + new_id = {n1 + n1: epochs.event_id[n1 + n1] + 1} + mne.epochs.combine_event_ids(epochs, old_ids, new_id, copy=False) + # LC+RC=LC + old_ids = [n1 + 'X', cs[n1] + 'X'] + if kind == 'sub': + epochs.equalize_event_counts(old_ids) + new_id = {n1 + 'X': epochs.event_id[n1 + 'X'] + 1} + mne.epochs.combine_event_ids(epochs, old_ids, new_id, copy=False) + + mne.epochs.combine_event_ids(epochs, ['VM_'], dict(VM=96), copy=False) + assert 'Ju' in epochs.event_id + epochs.drop(np.where(epochs.events[:, 2] == + epochs.event_id['Ju'])[0]) + mne.epochs.combine_event_ids(epochs, ['VS_', 'Ju'], dict(VS=97), + copy=False) + + # at this point we only care about: + eq_names = ('LX', 'UX', 'LL', 'LR', 'UU', 'UD', 'VS') + assert set(eq_names + ('VM',)) == set(epochs.event_id.keys()) + assert set(eq_names + ('VM',)) == set(out_names) + orig_len = len(epochs['LL']) + epochs.equalize_event_counts(eq_names) + new_len = len(epochs['LL']) + print(' (reduced LL %s -> %s)' % (orig_len, new_len)) + for ni, out_name in enumerate(out_names): + idx = (epochs.events[:, 2] == epochs.event_id[out_name]) + epochs.event_id[out_name] = ni + 1 + epochs.events[idx, 2] = ni + 1 + return epochs + +for subject in subjects: + + subject_dir = f'{data_path}/{subject}' + + epochs_fname = f'{subject_dir}/epochs/All_55-sss_{subject}-epo.fif' + epochs = mne.read_epochs(epochs_fname) + epochs = eq_trials(epochs, kind='sub') + epochs = epochs['LL'] + + fwd_fname = f'{subject_dir}/forward/{subject}-sss-fwd.fif' + fwd = mne.read_forward_solution(fwd_fname) + + cov_fname = f'{subject_dir}/covariance/{subject}-55-sss-cov.fif' + cov = mne.read_cov(cov_fname) + + if subject == subjects[0]: + num_rois = len(labels) + timepts = len(epochs.times) + model = LDS(num_rois, timepts, lam0=0, lam1=100) + + model.add_subject(subject, subject_dir, epochs, labels, fwd, cov) #not using subject_dir + + +# model.fit(niter=100, verbose=1) +# A_t_ = model.A_t_ +# assert A_t_.shape == (timepts, num_rois, num_rois) + +# with mpl.rc_context(): +# {'xtick.labelsize': 'x-small', 'ytick.labelsize': 'x-small'} +# fig, ax = plt.subplots(num_rois, num_rois, constrained_layout=True, squeeze=False, +# figsize=(12, 10)) +# plot_A_t_(A_t_, labels=label_names, times=epochs.times, ax=ax) +# fig.suptitle(condition) + + + + + + + + + + + + + + + From 463b804f998af30ad4827b48d610910e296c2682 Mon Sep 17 00:00:00 2001 From: jadrew43 Date: Fri, 12 Aug 2022 16:15:44 -0700 Subject: [PATCH 09/17] SNR boost added via bootstrap_subject function --- state_space/megssm/mne_util.py | 65 ++++- state_space/megssm/models.py | 335 ++++++++++++++++++++---- 
state_space/state_space_connectivity.py | 298 +++++++++++++++++---- 3 files changed, 583 insertions(+), 115 deletions(-) diff --git a/state_space/megssm/mne_util.py b/state_space/megssm/mne_util.py index 1402b648..a8d27134 100644 --- a/state_space/megssm/mne_util.py +++ b/state_space/megssm/mne_util.py @@ -186,12 +186,12 @@ def _scale_sensor_data(epochs, fwd, cov, roi_to_src): del cov, fwd, epochs #neccessary? # rescale data according to covariance whitener? - # rescale_cov = mne.make_ad_hoc_cov(info) - # scaler = mne.cov.compute_whitener(rescale_cov, info) - # del rescale_cov - # fwd_src_snsr = scaler[0] @ fwd_src_snsr - # snsr_cov = scaler[0] @ snsr_cov - # data = scaler[0] @ data + rescale_cov = mne.make_ad_hoc_cov(info, std=1) + scaler = mne.cov.compute_whitener(rescale_cov, info) + del rescale_cov + fwd_src_snsr = scaler[0] @ fwd_src_snsr + snsr_cov = scaler[0] @ snsr_cov + data = scaler[0] @ data fwd_roi_snsr = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, fwd_src_snsr.T).T) @@ -199,6 +199,57 @@ def _scale_sensor_data(epochs, fwd, cov, roi_to_src): return fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs +# def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., mag_scale=1., +# grad_scale=1.): +# """ apply per-channel-type scaling to epochs, forward, and covariance """ +# # # from util import Carray ##skip import just pasted; util also from MEGLDS repo +# # Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C') +# # Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C') +# # Carray = Carray64 + +# epochs = epochs.copy().pick('data', exclude='bads') +# cov = cov.pick_channels(epochs.ch_names, ordered=True) +# fwd = mne.convert_forward_solution(fwd, force_fixed=True) +# fwd = fwd.pick_channels(epochs.ch_names, ordered=True) +# # +# # get indices for each channel type +# ch_names = cov['names'] # same as self.fwd['info']['ch_names'] +# sel_eeg = pick_types(fwd['info'], meg=False, eeg=True, ref_meg=False) +# sel_mag = pick_types(fwd['info'], meg='mag', eeg=False, ref_meg=False) +# sel_grad = pick_types(fwd['info'], meg='grad', eeg=False, ref_meg=False) +# idx_eeg = [ch_names.index(ch_names[c]) for c in sel_eeg] +# idx_mag = [ch_names.index(ch_names[c]) for c in sel_mag] +# idx_grad = [ch_names.index(ch_names[c]) for c in sel_grad] + +# # retrieve forward and sensor covariance +# fwd_src_snsr = fwd['sol']['data'].copy() +# snsr_cov = cov.data.copy() + +# # scale forward matrix +# fwd_src_snsr[idx_eeg,:] *= eeg_scale +# fwd_src_snsr[idx_mag,:] *= mag_scale +# fwd_src_snsr[idx_grad,:] *= grad_scale + +# # construct fwd_roi_snsr matrix +# fwd_roi_snsr = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, fwd_src_snsr.T).T) + +# # scale sensor covariance +# snsr_cov[np.ix_(idx_eeg, idx_eeg)] *= eeg_scale**2 +# snsr_cov[np.ix_(idx_mag, idx_mag)] *= mag_scale**2 +# snsr_cov[np.ix_(idx_grad, idx_grad)] *= grad_scale**2 + +# # scale epochs +# info = epochs.info.copy() +# data = epochs.get_data().copy() + +# data[:,idx_eeg,:] *= eeg_scale +# data[:,idx_mag,:] *= mag_scale +# data[:,idx_grad,:] *= grad_scale + +# epochs = mne.EpochsArray(data, info) + +# return fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs + def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', pctvar=0.99, mean_center=False, label_flip=False): @@ -234,7 +285,7 @@ def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', fwd_src_snsr, fwd_roi_snsr, Q_snsr, epochs = \ _scale_sensor_data(epochs, fwd, cov, roi_to_src) - dat = epochs.get_data() + dat = 
epochs.get_data().copy() dat = Carray(np.swapaxes(dat, -1, -2)) if mean_center: diff --git a/state_space/megssm/models.py b/state_space/megssm/models.py index 2ec1d104..8da2c276 100755 --- a/state_space/megssm/models.py +++ b/state_space/megssm/models.py @@ -1,4 +1,6 @@ import sys +import os +import mne import autograd.numpy as np import scipy.optimize as spopt @@ -40,14 +42,14 @@ class _MEGModel(object): def __init__(self): self._subjectdata = None - self._timepts = 0 + self._n_timepts = 0 self._ntrials_all = 0 self._nsubjects = 0 def set_data(self, subjectdata): - timepts_lst = [self.unpack_subject_data(e)[0].shape[1] for e in subjectdata] - assert len(list(set(timepts_lst))) == 1 - self._timepts = timepts_lst[0] + n_timepts_lst = [self.unpack_subject_data(e)[0].shape[1] for e in subjectdata] + assert len(list(set(n_timepts_lst))) == 1 + self._n_timepts = n_timepts_lst[0] ntrials_lst = [self.unpack_subject_data(e)[0].shape[0] for e in \ subjectdata] self._ntrials_all = np.sum(ntrials_lst) @@ -83,25 +85,27 @@ class MEGLDS(_MEGModel): cross-region dynamic connectivity in MEG/EEG", Yang et al., NIPS 2016. """ - def __init__(self, num_roi, timepts, A_t_=None, roi_cov=None, mu0=None, roi_cov_0=None, - log_sigsq_lst=None, lam0=0., lam1=0., penalty='ridge', - store_St=True): - + # def __init__(self, num_roi, n_timepts, A_t_=None, roi_cov=None, mu0=None, roi_cov_0=None, + # log_sigsq_lst=None, lam0=0., lam1=0., penalty='ridge', + # store_St=True): + def __init__(self, lam0=0., lam1=0., penalty='ridge', store_St=True): + super().__init__() + self._model_initalized = False - set_default = \ - lambda prm, val, deflt: \ - self.__setattr__(prm, val.copy() if val is not None else deflt) + # set_default = \ + # lambda prm, val, deflt: \ + # self.__setattr__(prm, val.copy() if val is not None else deflt) - # initialize parameters - set_default("A_t_", A_t_, - np.stack([rand_stable(num_roi, maxew=0.7) for _ in range(timepts)], - axis=0)) - set_default("roi_cov", roi_cov, rand_psd(num_roi)) - set_default("mu0", mu0, np.zeros(num_roi)) - set_default("roi_cov_0", roi_cov_0, rand_psd(num_roi)) - set_default("log_sigsq_lst", log_sigsq_lst, - [np.log(np.random.gamma(2, 1, size=num_roi+1))]) + # # initialize parameters + # set_default("A_t_", A_t_, + # np.stack([rand_stable(num_roi, maxew=0.7) for _ in range(n_timepts)], + # axis=0)) + # set_default("roi_cov", roi_cov, rand_psd(num_roi)) + # set_default("mu0", mu0, np.zeros(num_roi)) + # set_default("roi_cov_0", roi_cov_0, rand_psd(num_roi)) + # set_default("log_sigsq_lst", log_sigsq_lst, + # [np.log(np.random.gamma(2, 1, size=num_roi+1))]) self.lam0 = lam0 self.lam1 = lam1 @@ -118,37 +122,248 @@ def __init__(self, num_roi, timepts, A_t_=None, roi_cov=None, mu0=None, roi_cov_ self._loglik = None self._store_St = bool(store_St) - # initialize sufficient statistics - timepts, num_roi, _ = self.A_t_.shape - self._B0 = np.zeros((num_roi, num_roi)) - self._B1 = np.zeros((timepts-1, num_roi, num_roi)) - self._B3 = np.zeros((timepts-1, num_roi, num_roi)) - self._B2 = np.zeros((timepts-1, num_roi, num_roi)) - self._B4 = list() + # # initialize sufficient statistics + # n_timepts, num_roi, _ = self.A_t_.shape + # self._B0 = np.zeros((num_roi, num_roi)) + # self._B1 = np.zeros((n_timepts-1, num_roi, num_roi)) + # self._B3 = np.zeros((n_timepts-1, num_roi, num_roi)) + # self._B2 = np.zeros((n_timepts-1, num_roi, num_roi)) + # self._B4 = list() - self._subject_data = dict() + self._all_subject_data = list()#dict() + + #SNR boost epochs, bootstraps of 3 + def 
bootstrap_subject(self, subject_name, seed=8675309, sfreq=100, lower=None, + upper=None, nbootstrap=3, g_nsamples=-5, + overwrite=False, validation_set=True): + # subjects = ['sample'] + datasets = ['train', 'validation'] + use_erm = eq = False + independent = False + if g_nsamples == 0: + print('nsamples == 0, ensuring independence of samples') + independent = True + elif g_nsamples == -1: + print("using half of trials per sample") + elif g_nsamples == -2: + print("using empty room noise at half of trials per sample") + use_erm = True + elif g_nsamples == -3: + print("using independent and trial-count equalized samples") + eq = True + independent = True + elif g_nsamples == -4: + print("using independent, trial-count equalized, non-boosted samples") + assert nbootstrap == 0 # sanity check + eq = True + independent = True + datasets = ['train'] + elif g_nsamples == -5: + print("using independent, trial-count equalized, integer boosted samples") + eq = True + independent = True + datasets = ['train'] + + if lower is not None or upper is not None: + if upper is None: + print('high-pass filtering at %.2f Hz' % lower) + elif lower is None: + print('low-pass filtering at %.2f Hz' % upper) + else: + print('band-pass filtering from %.2f-%.2f Hz' % (lower, upper)) + + if sfreq is not None: + print('resampling to %.2f Hz' % sfreq) + + print(":: processing subject %s" % subject_name) + np.random.seed(seed) + + for dataset in datasets: + + print(' generating ', dataset, ' set') + datadir = './data' + + subj_dir = os.path.join(datadir, subject_name) + print("subject dir:" + subj_dir) + if not os.path.exists(subj_dir): + print(' %s not found, skipping' % subject_name) + return + + epochs_dir = os.path.join(datadir, subject_name, 'epochs') + epochs_fname = "All_55-sss_%s-epo.fif" % subject_name + epochs_bs_fname = (epochs_fname.split('-epo')[0] + + "-bootstrap_%d-nsamples_%d-seed_%d%s%s%s-" + % (nbootstrap, g_nsamples, seed, + '-lower_%.2e' % lower if lower is not None else '', + '-upper_%.2e' % upper if upper is not None else '', + '-sfreq_%.2e' % sfreq if sfreq is not None else '') + + dataset + "-epo.fif") + + if os.path.exists(os.path.join(epochs_dir, epochs_bs_fname)) and \ + not overwrite: + print(" => found existing bootstrapped epochs, skipping") + return + + epochs = mne.read_epochs(os.path.join(epochs_dir, epochs_fname), + preload=True) + condition_map = {'auditory_left':['auditory_left'],'auditory_right': ['auditory_right'], + 'visual_left': ['visual_left'], 'visual_right': ['visual_right']} + condition_eq_map = dict(auditory_left=['auditory_left'], auditory_right=['auditory_right'], + visual_left=['visual_left'], visual_right=['visual_right']) + if eq: + epochs.equalize_event_counts(list(condition_map)) + cond_map = condition_eq_map + else: + cond_map = condition_map + + # apply band-pass filter to limit signal to desired frequency band + if lower is not None or upper is not None: + epochs = epochs.filter(lower, upper) + + # perform resampling with specified sampling frequency + if sfreq is not None: + epochs = epochs.resample(sfreq) + + data_bs_all = list() + events_bs_all = list() + for cond in sorted(cond_map.keys()): + print(" -> condition %s: bootstrapping" % cond, end='') + ep = epochs[cond_map[cond]] + dat = ep.get_data().copy() + ntrials, T, p = dat.shape + + use_bootstrap = nbootstrap + if g_nsamples == -4: + nsamples = 1 + use_bootstrap = ntrials + elif g_nsamples == -5: + nsamples = nbootstrap + use_bootstrap = ntrials // nsamples + elif independent: + nsamples = (ntrials - 1) // use_bootstrap + elif g_nsamples in (-1, -2): + nsamples = ntrials // 2 + else: + assert g_nsamples > 0 + nsamples = g_nsamples + print(" using %d samples (%d trials)" + % (nsamples, use_bootstrap))
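The per-trial averaging below is the SNR boost: a pseudo-trial built from `nsamples` raw trials keeps the evoked signal but reduces additive noise variance by roughly `1/nsamples` (an SNR gain of about `sqrt(nsamples)`). A toy illustration with made-up shapes:

    import numpy as np
    rng = np.random.RandomState(0)
    dat = rng.randn(30, 100, 60)           # (ntrials, T, p), noise only
    inds = rng.choice(30, size=(10, 3))    # 10 pseudo-trials of 3 trials each
    dat_bs = dat[inds].mean(axis=1)        # shape (10, 100, 60)
    assert dat_bs.var() < dat.var()        # variance drops roughly 3-fold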
+ + # bootstrap here + if independent: # independent + if nsamples == 1 and use_bootstrap == ntrials: + inds = np.arange(ntrials) + else: + inds = np.random.choice(ntrials, nsamples * use_bootstrap) + inds.shape = (use_bootstrap, nsamples) + dat_bs = np.mean(dat[inds], axis=1) + events_bs = ep.events[inds[:, 0]] + assert dat_bs.shape[0] == events_bs.shape[0] + else: + dat_bs = np.empty((ntrials, T, p)) + events_bs = np.empty((ntrials, 3), dtype=int) + for i in range(ntrials): + + inds = list(set(range(ntrials)).difference([i])) + inds = np.random.choice(inds, size=nsamples, + replace=False) + inds = np.append(inds, i) + + dat_bs[i] = np.mean(dat[inds], axis=0) + events_bs[i] = ep.events[i] + + inds = np.random.choice(ntrials, size=use_bootstrap, + replace=False) + dat_bs = dat_bs[inds] + events_bs = events_bs[inds] + + assert dat_bs.shape == (use_bootstrap, T, p) + assert events_bs.shape == (use_bootstrap, 3) + assert (events_bs[:, 2] == events_bs[0, 2]).all() #not working for sample_info + + data_bs_all.append(dat_bs) + events_bs_all.append(events_bs) + + # write bootstrap epochs + info_dict = epochs.info.copy() + + dat_all = np.vstack(data_bs_all) + events_all = np.vstack(events_bs_all) + # replace first column with sequential list as we don't really care + # about the raw timings + events_all[:, 0] = np.arange(events_all.shape[0]) + + epochs_bs = mne.EpochsArray( + dat_all, info_dict, events=events_all, tmin=-0.2, + event_id=epochs.event_id.copy(), on_missing='ignore') + + # print(" saving bootstrapped epochs (%s)" % (epochs_bs_fname,)) + # epochs_bs.save(os.path.join(epochs_dir, epochs_bs_fname)) + return epochs_bs + def add_subject(self, subject, subject_dir, epochs, labels, fwd, cov): + + if not self._model_initalized: + n_timepts = len(epochs.times) + num_roi = len(labels) + self._init_model(n_timepts, num_roi) + self._model_initalized = True + if len(epochs.times) != self._n_times: + raise ValueError(f'Number of time points ({len(epochs.times)}) ' + f'does not match original count ({self._n_times})') + roi_to_src = ROIToSourceMap(fwd, labels) # compute ROI-to-source map - fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs = \ - _scale_sensor_data(epochs, fwd, cov, roi_to_src) # cov = cov.pick_channels(epochs.ch_names) + epochs_bs = self.bootstrap_subject(subject) + sdata = run_pca_on_subject(subject, epochs_bs, fwd, cov, labels, dim_mode='pctvar') #check for channel mismatch data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - subjectdata = [(data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi)] + subjectdata = (data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi) - self.set_data(subjectdata) + self._all_subject_data.append(subjectdata) + # self.set_data(subjectdata) # epochs, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = subjectdata self._subject_data[subject] = dict() - self._subject_data[subject]['epochs'] = epochs + self._subject_data[subject]['epochs'] = data self._subject_data[subject]['fwd_src_snsr'] = fwd_src_snsr self._subject_data[subject]['fwd_roi_snsr'] = fwd_roi_snsr self._subject_data[subject]['snsr_cov'] = snsr_cov self._subject_data[subject]['labels'] = labels self._subject_data[subject]['which_roi'] = which_roi
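Intended usage of the deferred-initialization flow above, for reference (a sketch using the names from the example script; the `niter` value is illustrative). Model dimensions come from the first subject's epochs and labels, so `_init_model` runs on the first `add_subject` call, and `fit` should only be called after at least one subject has been added:

    model = MEGLDS(lam0=0, lam1=100)
    model.add_subject('sample', subjects_dir, epochs, labels, fwd, cov)
    model.fit(niter=100, verbose=1)  # parameters were sized by _init_model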
- + + # dict_to_tuple = list(tuple(x.values()) for x in self._subject_data.values()) + # print(f' dict vs tuples = {np.allclose(subjectdata, dict_to_tuple)}') + + def _init_model(self, n_timepts, num_roi, A_t_=None, roi_cov=None, mu0=None, + roi_cov_0=None, log_sigsq_lst=None): + + self._n_times = n_timepts + self._subject_data = dict() + + set_default = \ + lambda prm, val, deflt: \ + self.__setattr__(prm, val.copy() if val is not None else deflt) + + # initialize parameters + set_default("A_t_", A_t_, + np.stack([rand_stable(num_roi, maxew=0.7) for _ in range(n_timepts)], + axis=0)) + set_default("roi_cov", roi_cov, rand_psd(num_roi)) + set_default("mu0", mu0, np.zeros(num_roi)) + set_default("roi_cov_0", roi_cov_0, rand_psd(num_roi)) + set_default("log_sigsq_lst", log_sigsq_lst, + [np.log(np.random.gamma(2, 1, size=num_roi+1))]) + + # initialize sufficient statistics + n_timepts, num_roi, _ = self.A_t_.shape + self._B0 = np.zeros((num_roi, num_roi)) + self._B1 = np.zeros((n_timepts-1, num_roi, num_roi)) + self._B3 = np.zeros((n_timepts-1, num_roi, num_roi)) + self._B2 = np.zeros((n_timepts-1, num_roi, num_roi)) + self._B4 = list() + def set_data(self, subjectdata): # add subject data, re-generate log_sigsq_lst if necessary super().set_data(subjectdata) @@ -183,7 +398,7 @@ def _em_objective(self): Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - ntrials, timepts, _ = Y.shape + ntrials, n_timepts, _ = Y.shape sigsq_vals = np.exp(self.log_sigsq_lst[s]) R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) @@ -191,7 +406,7 @@ def _em_objective(self): if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None or self._sigmas_tnt_smooth_lst is None): - roi_cov_t = _ensure_ndim(self.roi_cov, timepts, 3) + roi_cov_t = _ensure_ndim(self.roi_cov, n_timepts, 3) with numpy_num_threads(1): _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ rts_smooth_fast(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, @@ -237,7 +452,7 @@ def _em_objective(self): # obj += L2(roi_cov, At) L_roi_cov_inv_tmp = solve_triangular(L_roi_cov, tmp, lower=True) - L2 += (ntrials*(timepts-1)*2.*np.sum(np.log(np.diag(L_roi_cov))) + L2 += (ntrials*(n_timepts-1)*2.*np.sum(np.log(np.diag(L_roi_cov))) + np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_tmp, lower=True, trans='T'))) @@ -253,7 +468,7 @@ def _em_objective(self): # obj += L3(sigsq_vals) L_R_inv_B4 = solve_triangular(L_R, B4, lower=True) - L3 += (ntrials*timepts*2*np.sum(np.log(np.diag(L_R))) + L3 += (ntrials*n_timepts*2*np.sum(np.log(np.diag(L_R))) + np.trace(solve_triangular(L_R, L_R_inv_B4, lower=True, trans='T'))) @@ -283,9 +498,11 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, A_t_roi_cov_tol=1e-6, diag_roi_cov=False, update_sigsq=True, do_final_smoothing=True, average_mus_smooth=True, Atrue=None, tau=0.1, c1=1e-4): + self.set_data(self._all_subject_data) + fxn_start = datetime.now() - timepts, num_roi, _ = self.A_t_.shape + n_timepts, num_roi, _ = self.A_t_.shape # make initial A_t_ stationary if stationary_A_t_ option specified if stationary_A_t_: @@ -449,7 +666,7 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, A_t_roi_cov_tol=1e-6, Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata sigsq_vals = np.exp(self.log_sigsq_lst[s]) R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) - roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) + roi_cov_t = _ensure_ndim(self.roi_cov, self._n_timepts, 3) with numpy_num_threads(1): loglik_subject, mus_smooth, _, _, St = \ 
rts_smooth(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, self.roi_cov_0, @@ -475,8 +692,7 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, A_t_roi_cov_tol=1e-6, return objvals, converged, mus_smooth_lst, self._loglik, St_lst def _e_step(self, verbose=0): - - timepts, num_roi, _ = self.A_t_.shape + n_timepts, num_roi, _ = self.A_t_.shape # reset accumulation arrays self._B0[:] = 0. @@ -504,7 +720,7 @@ def _e_step(self, verbose=0): sigsq_vals = np.exp(self.log_sigsq_lst[s]) R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) L_R = np.linalg.cholesky(R) - roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) + roi_cov_t = _ensure_ndim(self.roi_cov, self._n_timepts, 3) with numpy_num_threads(1): _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ @@ -608,7 +824,7 @@ def update_A_t_and_roi_cov(self, update_A_t_=True, update_roi_cov=True, stationa tmp = einsum2('tik,tkl->til', At, self._B3) AtB3AtT = einsum2('til,tjl->tij', tmp, At) elbo_2 = np.sum(self._B1 - AtB2T - B2AtT + AtB3AtT, axis=0) - self.roi_cov = (1. / (self._ntrials_all * self._timepts + self.roi_cov = (1. / (self._ntrials_all * self._n_timepts )) * elbo_2 if diag_roi_cov: self.roi_cov = np.diag(np.diag(self.roi_cov)) @@ -633,20 +849,20 @@ def update_log_sigsq_lst(self, verbose=0): if verbose > 1: print(" update subject log-sigmasq") - timepts, num_roi, _ = self.A_t_.shape + n_timepts, num_roi, _ = self.A_t_.shape # update log_sigsq_vals for each subject and ROI for s, sdata in enumerate(self.unpack_all_subject_data()): Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - ntrials, timepts, _ = Y.shape + ntrials, n_timepts, _ = Y.shape mus_smooth = self._mus_smooth_lst[s] sigmas_smooth = self._sigmas_smooth_lst[s] B4 = self._B4[s] log_sigsq = self.log_sigsq_lst[s].copy() log_sigsq_obj = lambda x: \ - MEGLDS.L3_obj(x, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, timepts) + MEGLDS.L3_obj(x, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, n_timepts) log_sigsq_val_and_grad = vgrad(log_sigsq_obj) options = {'maxiter': 500} @@ -678,7 +894,7 @@ def calculate_smoothed_estimates(self): Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata sigsq_vals = np.exp(self.log_sigsq_lst[s]) R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) - roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) + roi_cov_t = _ensure_ndim(self.roi_cov, self._n_timepts, 3) with numpy_num_threads(1): ll, mus_smooth, sigmas_smooth, sigmas_tnt_smooth, _ = \ rts_smooth(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, self.roi_cov_0, @@ -704,7 +920,7 @@ def log_likelihood(self): Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata sigsq_vals = np.exp(self.log_sigsq_lst[s]) R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) - roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) + roi_cov_t = _ensure_ndim(self.roi_cov, self._n_timepts, 3) ll, _, _, _ = kalman_filter(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, self.roi_cov_0, store_St=False) self._loglik += ll @@ -712,9 +928,9 @@ def log_likelihood(self): return self._loglik def nparams(self): - timepts, p, _ = self.A_t_.shape + n_timepts, p, _ = self.A_t_.shape - # this should equal (timepts-1)*p*p unless some shrinkage is used on At + # this should equal (n_timepts-1)*p*p unless some shrinkage is used on At nparams_At = np.sum(np.abs(self.A_t_[:-1]) > 0) # nparams = nparams(At) + nparams(roi_cov) + nparams(roi_cov_0) @@ -790,17 +1006,24 @@ def L2_obj(self, At, L_roi_cov): # TODO: convert to instance method @staticmethod - 
def L3_obj(log_sigsq_vals, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, timepts): + def L3_obj(log_sigsq_vals, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, n_timepts): R = MEGLDS.R_(snsr_cov, fwd_src_snsr, np.exp(log_sigsq_vals), which_roi) try: L_R = np.linalg.cholesky(R) except LinAlgError: return np.finfo('float').max L_R_inv_B4 = solve_triangular(L_R, B4, lower=True) - return (ntrials*timepts*2.*np.sum(np.log(np.diag(L_R))) + return (ntrials*n_timepts*2.*np.sum(np.log(np.diag(L_R))) + np.trace(solve_triangular(L_R, L_R_inv_B4, lower=True, trans='T'))) + # @property + # def A_t_(self): + # '''The time-varying connectivity''' + # if self._A_t_ is not None: + # return self._A_t_.copy() + # else: + # return None # @property # def A(self): @@ -847,8 +1070,8 @@ def L3_obj(log_sigsq_vals, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, timep # return self.A.shape[1] # @property - # def timepts(self): - # return self._timepts + # def n_timepts(self): + # return self._n_timepts # @property # def lam0(self): diff --git a/state_space/state_space_connectivity.py b/state_space/state_space_connectivity.py index 52db4f5c..65ca4b0d 100644 --- a/state_space/state_space_connectivity.py +++ b/state_space/state_space_connectivity.py @@ -18,59 +18,253 @@ #where should these files live within mne-connectivity repo? from megssm.models import MEGLDS as LDS from megssm.plotting import plot_A_t_ +import pickle + +load = 0 + +if not load: +# define paths to sample data + path = None + path = '/Users/jordandrew/Documents/MEG/mne_data'#'/MNE-sample-data' + data_path = mne.datasets.sample.data_path(path=path) + sample_folder = data_path / 'MEG/sample' + subjects_dir = data_path / 'subjects' + + ## import raw data and find events + raw_fname = sample_folder / 'sample_audvis_raw.fif' + raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60) + events = mne.find_events(raw, stim_channel='STI 014') + + ## define epochs using event_dict + event_dict = {'auditory_left': 1, 'auditory_right': 2, 'visual_left': 3, + 'visual_right': 4}#, 'face': 5, 'buttonpress': 32} + epochs = mne.Epochs(raw, events, tmin=-0.3, tmax=0.7, event_id=event_dict, + preload=True).pick_types(meg=True,eeg=True) + condition = 'auditory_left' + epochs = epochs[condition] # choose condition for analysis + + + + #SNR boost epochs, bootstraps of 3 + # def bootstrap_subject(subject_name, seed=8675309, sfreq=100, lower=None, + # upper=None, nbootstrap=3, g_nsamples=-5, + # overwrite=False, validation_set=True): + # import autograd.numpy as np + # import os + + # # subjects = ['sample'] + # datasets = ['train', 'validation'] + # use_erm = eq = False + # independent = False + # if g_nsamples == 0: + # print('nsamples == 0, ensuring independence of samples') + # independent = True + # elif g_nsamples == -1: + # print("using half of trials per sample") + # elif g_nsamples == -2: + # print("using empty room noise at half of trials per sample") + # use_erm = True + # elif g_nsamples == -3: + # print("using independent and trial-count equalized samples") + # eq = True + # independent = True + # elif g_nsamples == -4: + # print("using independent, trial-count equailized, non-boosted samples") + # assert nbootstrap == 0 # sanity check + # eq = True + # independent = True + # datasets = ['train'] + # elif g_nsamples == -5: + # print("using independent, trial-count equailized, integer boosted samples") + # eq = True + # independent = True + # datasets = ['train'] + + # if lower is not None or upper is not None: + # if upper is None: + # print('high-pass filtering at 
%.2f Hz' % lower) + # elif lower is None: + # print('low-pass filtering at %.2f Hz' % upper) + # else: + # print('band-pass filtering from %.2f-%.2f Hz' % (lower, upper)) + + # if sfreq is not None: + # print('resampling to %.2f Hz' % sfreq) + + # print(":: processing subject %s" % subject_name) + # np.random.seed(seed) + + # for dataset in datasets: + + # print(' generating ', dataset, ' set') + # datadir = './data' + + # subj_dir = os.path.join(datadir, subject_name) + # print("subject dir:" + subj_dir) + # if not os.path.exists(subj_dir): + # print(' %s not found, skipping' % subject_name) + # return + + # epochs_dir = os.path.join(datadir, subject_name, 'epochs') + # epochs_fname = "All_55-sss_%s-epo.fif" % subject_name + # epochs_bs_fname = (epochs_fname.split('-epo')[0] + + # "-bootstrap_%d-nsamples_%d-seed_%d%s%s%s-" + # % (nbootstrap, g_nsamples, seed, + # '-lower_%.2e' % lower if lower is not None else '', + # '-upper_%.2e' % upper if upper is not None else '', + # '-sfreq_%.2e' % sfreq if sfreq is not None else '') + + # dataset + "-epo.fif") + + # if os.path.exists(os.path.join(epochs_dir, epochs_bs_fname)) and \ + # not overwrite: + # print(" => found existing bootstrapped epochs, skipping") + # return + + # epochs = mne.read_epochs(os.path.join(epochs_dir, epochs_fname), + # preload=True) + # condition_map = {'auditory_left':['auditory_left'],'auditory_right': ['auditory_right'], + # 'visual_left': ['visual_left'], 'visual_right': ['visual_right']} + # condition_eq_map = dict(auditory_left=['auditory_left'], auditory_right=['auditory_right'], + # visual_left=['visual_left'], visual_right='visual_right') + # if eq: + # epochs.equalize_event_counts(list(condition_map)) + # cond_map = condition_eq_map + + # # apply band-pass filter to limit signal to desired frequency band + # if lower is not None or upper is not None: + # epochs = epochs.filter(lower, upper) + + # # perform resampling with specified sampling frequency + # if sfreq is not None: + # epochs = epochs.resample(sfreq) + + # data_bs_all = list() + # events_bs_all = list() + # for cond in sorted(cond_map.keys()): + # print(" -> condition %s: bootstrapping" % cond, end='') + # ep = epochs[cond_map[cond]] + # dat = ep.get_data().copy() + # ntrials, T, p = dat.shape + + # use_bootstrap = nbootstrap + # if g_nsamples == -4: + # nsamples = 1 + # use_bootstrap = ntrials + # elif g_nsamples == -5: + # nsamples = nbootstrap + # use_bootstrap = ntrials // nsamples + # elif independent: + # nsamples = (ntrials - 1) // use_bootstrap + # elif g_nsamples in (-1, -2): + # nsamples = ntrials // 2 + # else: + # assert g_nsamples > 0 + # nsamples = g_nsamples + # print(" using %d samples (%d trials)" + # % (nsamples, use_bootstrap)) + + # # bootstrap here + # if independent: # independent + # if nsamples == 1 and use_bootstrap == ntrials: + # inds = np.arange(ntrials) + # else: + # inds = np.random.choice(ntrials, nsamples * use_bootstrap) + # inds.shape = (use_bootstrap, nsamples) + # dat_bs = np.mean(dat[inds], axis=1) + # events_bs = ep.events[inds[:, 0]] + # assert dat_bs.shape[0] == events_bs.shape[0] + # else: + # dat_bs = np.empty((ntrials, T, p)) + # events_bs = np.empty((ntrials, 3), dtype=int) + # for i in range(ntrials): + + # inds = list(set(range(ntrials)).difference([i])) + # inds = np.random.choice(inds, size=nsamples, + # replace=False) + # inds = np.append(inds, i) + + # dat_bs[i] = np.mean(dat[inds], axis=0) + # events_bs[i] = ep.events[i] + + # inds = np.random.choice(ntrials, size=use_bootstrap, + # replace=False) + 
# dat_bs = dat_bs[inds] + # events_bs = events_bs[inds] + + # assert dat_bs.shape == (use_bootstrap, T, p) + # assert events_bs.shape == (use_bootstrap, 3) + # assert (events_bs[:, 2] == events_bs[0, 2]).all() #not working for sample_info + + # data_bs_all.append(dat_bs) + # events_bs_all.append(events_bs) + + # # write bootstrap epochs + # info_dict = epochs.info.copy() + + # dat_all = np.vstack(data_bs_all) + # events_all = np.vstack(events_bs_all) + # # replace first column with sequential list as we don't really care + # # about the raw timings + # events_all[:, 0] = np.arange(events_all.shape[0]) + + # epochs_bs = mne.EpochsArray( + # dat_all, info_dict, events=events_all, tmin=-0.2, + # event_id=epochs.event_id.copy(), on_missing='ignore') + + # print(" saving bootstrapped epochs (%s)" % (epochs_bs_fname,)) + # epochs_bs.save(os.path.join(epochs_dir, epochs_bs_fname)) + # bootstrap_subject('sample') + + + + + + + + ## read forward solution, remove bad channels + fwd_fname = sample_folder / 'sample_audvis-meg-eeg-oct-6-fwd.fif' + fwd = mne.read_forward_solution(fwd_fname) + + ## read in covariance OR compute noise covariance? noise_cov drops bad chs + cov_fname = sample_folder / 'sample_audvis-cov.fif' + cov = mne.read_cov(cov_fname) # drop bad channels in add_subject + # noise_cov = mne.compute_covariance(epochs, tmax=0) + + ## read labels for analysis + label_names = ['AUD-lh', 'AUD-rh', 'Vis-lh', 'Vis-rh'] + labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label', + subject='sample') for label in label_names] + + # initiate model + num_rois = len(labels) + n_timepts = len(epochs.times) + times = epochs.times + # model = LDS(num_rois, n_timepts, lam0=0, lam1=100) # only needs the forward, labels, and noise_cov to be initialized + model = LDS(lam0=0, lam1=100) + + model.add_subject('sample', subjects_dir, epochs, labels, fwd, cov) + # #when to use compute_cov vs read_cov? 
ie cov vs noise_cov + + # model.fit(niter=100, verbose=1) + # A_t_ = model.A_t_ + # assert A_t_.shape == (n_timepts, num_rois, num_rois) + + # file = open('sample_subj_stdNone.pkl','wb') + # pickle.dump([model, num_rois, n_timepts, times, condition, label_names],file) + # file.close + +# if load: +# with open('sample_subj_stdNone.pkl','rb') as f: +# model, num_rois, n_timepts, times, condition, label_names = pickle.load(f) + +# with mpl.rc_context(): +# {'xtick.labelsize': 'x-small', 'ytick.labelsize': 'x-small'} +# fig, ax = plt.subplots(num_rois, num_rois, constrained_layout=True, squeeze=False, +# figsize=(12, 10)) +# plot_A_t_(model.A_t_, labels=label_names, times=times, ax=ax) +# fig.suptitle(condition) + -## define paths to sample data -path = None -path = '/Users/jordandrew/Documents/MEG/mne_data'#'/MNE-sample-data' -data_path = mne.datasets.sample.data_path(path=path) -sample_folder = data_path / 'MEG/sample' -subjects_dir = data_path / 'subjects' - -## import raw data and find events -raw_fname = sample_folder / 'sample_audvis_raw.fif' -raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60) -events = mne.find_events(raw, stim_channel='STI 014') - -## define epochs using event_dict -event_dict = {'auditory/left': 1, 'auditory/right': 2, 'visual/left': 3, - 'visual/right': 4, 'face': 5, 'buttonpress': 32} -epochs = mne.Epochs(raw, events, tmin=-0.3, tmax=0.7, event_id=event_dict, - preload=True).pick_types(meg=True,eeg=True) -condition = 'visual' -epochs = epochs[condition] # choose condition for analysis - -## read forward solution, remove bad channels -fwd_fname = sample_folder / 'sample_audvis-meg-eeg-oct-6-fwd.fif' -fwd = mne.read_forward_solution(fwd_fname) - -## read in covariance OR compute noise covariance? noise_cov drops bad chs -cov_fname = sample_folder / 'sample_audvis-cov.fif' -cov = mne.read_cov(cov_fname) # drop bad channels in add_subject -# noise_cov = mne.compute_covariance(epochs, tmax=0) - -## read labels for analysis -label_names = ['AUD-lh', 'AUD-rh', 'Vis-lh', 'Vis-rh'] -labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label') - for label in label_names] - -## initiate model -num_rois = len(labels) -timepts = len(epochs.times) -model = LDS(num_rois, timepts, lam0=0, lam1=100) # only needs the forward, labels, and noise_cov to be initialized - -model.add_subject('sample', subjects_dir, epochs, labels, fwd, cov) -#when to use compute_cov vs read_cov? 
ie cov vs noise_cov - -model.fit(niter=100, verbose=1) -A_t_ = model.A_t_ -assert A_t_.shape == (timepts, num_rois, num_rois) - -with mpl.rc_context(): - {'xtick.labelsize': 'x-small', 'ytick.labelsize': 'x-small'} - fig, ax = plt.subplots(num_rois, num_rois, constrained_layout=True, squeeze=False, - figsize=(12, 10)) - plot_A_t_(A_t_, labels=label_names, times=epochs.times, ax=ax) - fig.suptitle(condition) From 5164d6c8700d34ce2252bc82cc2cf8287a3d4e45 Mon Sep 17 00:00:00 2001 From: jadrew43 Date: Fri, 26 Aug 2022 17:01:27 -0700 Subject: [PATCH 10/17] model working; _scale_sensor_data needs changed to MNE functions --- state_space/megssm/mne_util.py | 229 ++++------- state_space/megssm/models.py | 487 +++++++----------------- state_space/state_space_connectivity.py | 316 ++++----------- 3 files changed, 303 insertions(+), 729 deletions(-) diff --git a/state_space/megssm/mne_util.py b/state_space/megssm/mne_util.py index a8d27134..81f2ddab 100644 --- a/state_space/megssm/mne_util.py +++ b/state_space/megssm/mne_util.py @@ -12,9 +12,7 @@ from scipy.sparse import csc_matrix, csr_matrix, diags from sklearn.decomposition import PCA -# from util import Carray ##skip import just pasted; util also from MEGLDS repo Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C') -Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C') Carray = Carray64 @@ -67,10 +65,10 @@ def __init__(self, fwd, labels, label_flip=False): lverts = clab.get_vertices_used(vertices=src[hi]['vertno']) - # gets the indices in the source space vertex array, not the huge - # array. - # use `src[hi]['vertno'][lverts]` to get surface vertex indices to - # plot. + # gets the indices in the source space vertex array, not the + # huge array. + # use `src[hi]['vertno'][lverts]` to get surface vertex indices + # to plot. lverts = np.searchsorted(src[hi]['vertno'], lverts) lverts += offsets[hemi] vertidx.extend(lverts) @@ -96,70 +94,9 @@ def __init__(self, fwd, labels, label_flip=False): self.n_lhverts = n_lhverts self.n_rhverts = n_rhverts self.labels = labels - + return - # @property - # def fwd_src_sn(self): - # return self.fwd['sol']['data'] - - # @property - # def fwd_src_roi(self): - # return self._fwd_src_roi - - # @fwd_src_roi.setter - # def fwd_src_roi(self, val): - # self._fwd_src_roi = val - - # @property - # def which_roi(self): - # return self._which_roi - - # @which_roi.setter - # def which_roi(self, val): - # self._which_roi = val - - # @property - # def fwd_roi_snsr(self): - # from util import Carray - # return Carray(csr_matrix.dot(self.fwd_src_roi.T, self.fwd_src_sn.T).T) - - # def get_label_vinds(self, label): - # li = self.labels.index(label) - # if isinstance(label, mne.Label): - # label_vert_idx = self.fwd_src_roi[:, li].nonzero()[0] - # label_vert_idx -= self.offsets[label.hemi] - # return label_vert_idx - # elif isinstance(label, mne.BiHemiLabel): - # # these labels store both hemispheres so subtract the rh offset - # # from that part of the vertex array - # lh_label_vert_idx = self.fwd_src_roi[:self.n_lhverts, li].nonzero()[0] - # rh_label_vert_idx = self.fwd_src_roi[self.n_lhverts:, li].nonzero()[0] - # rh_label_vert_idx[self.n_lhverts:] -= self.offsets['rh'] - # return [lh_label_vert_idx, rh_label_vert_idx] - - # def get_label_verts(self, label, src): - # # if you're thinking of using this to plot, why not just use - # # brain.add_label from pysurfer? 
- # if isinstance(label, mne.Label): - # hi = 0 if label.hemi == 'lh' else 1 - # label_vert_idx = self.get_label_vinds(label) - # varray = src[hi]['vertno'][label_vert_idx] - # elif isinstance(label, mne.BiHemiLabel): - # lh_label_vert_idx, rh_label_vert_idx = self.get_label_vinds(label) - # varray = [src[0]['vertno'][lh_label_vert_idx], - # src[1]['vertno'][rh_label_vert_idx]] - # return varray - - # def get_hemi_idx(self, label): - # if isinstance(label, mne.Label): - # return 0 if label.hemi == 'lh' else 1 - # elif isinstance(label, mne.BiHemiLabel): - # hemis = [None] * 2 - # for i, lab in enumerate([label.lh, label.rh]): - # hemis[i] = 0 if lab.hemi == 'lh' else 1 - # return hemis - def apply_projs(epochs, fwd, cov): """ apply projection operators to fwd and cov """ proj, _ = mne.io.proj.setup_proj(epochs.info, activate=False) @@ -174,83 +111,78 @@ def apply_projs(epochs, fwd, cov): return fwd, cov -def _scale_sensor_data(epochs, fwd, cov, roi_to_src): - """ apply per-channel-type scaling to epochs, forward, and covariance """ - - epochs = epochs.copy().pick('data', exclude='bads') - info = epochs.info.copy() - data = epochs.get_data().copy() - snsr_cov = cov.pick_channels(epochs.ch_names, ordered=True).data - fwd = mne.convert_forward_solution(fwd, force_fixed=True) - fwd_src_snsr = fwd.pick_channels(epochs.ch_names, ordered=True)['sol']['data'] - del cov, fwd, epochs #neccessary? - - # rescale data according to covariance whitener? - rescale_cov = mne.make_ad_hoc_cov(info, std=1) - scaler = mne.cov.compute_whitener(rescale_cov, info) - del rescale_cov - fwd_src_snsr = scaler[0] @ fwd_src_snsr - snsr_cov = scaler[0] @ snsr_cov - data = scaler[0] @ data - - fwd_roi_snsr = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, fwd_src_snsr.T).T) - - epochs = mne.EpochsArray(data, info) - - return fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs - -# def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., mag_scale=1., -# grad_scale=1.): +# def _scale_sensor_data(epochs, fwd, cov, roi_to_src, **std): # """ apply per-channel-type scaling to epochs, forward, and covariance """ -# # # from util import Carray ##skip import just pasted; util also from MEGLDS repo -# # Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C') -# # Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C') -# # Carray = Carray64 - -# epochs = epochs.copy().pick('data', exclude='bads') -# cov = cov.pick_channels(epochs.ch_names, ordered=True) -# fwd = mne.convert_forward_solution(fwd, force_fixed=True) -# fwd = fwd.pick_channels(epochs.ch_names, ordered=True) -# # -# # get indices for each channel type -# ch_names = cov['names'] # same as self.fwd['info']['ch_names'] -# sel_eeg = pick_types(fwd['info'], meg=False, eeg=True, ref_meg=False) -# sel_mag = pick_types(fwd['info'], meg='mag', eeg=False, ref_meg=False) -# sel_grad = pick_types(fwd['info'], meg='grad', eeg=False, ref_meg=False) -# idx_eeg = [ch_names.index(ch_names[c]) for c in sel_eeg] -# idx_mag = [ch_names.index(ch_names[c]) for c in sel_mag] -# idx_grad = [ch_names.index(ch_names[c]) for c in sel_grad] - -# # retrieve forward and sensor covariance -# fwd_src_snsr = fwd['sol']['data'].copy() -# snsr_cov = cov.data.copy() - -# # scale forward matrix -# fwd_src_snsr[idx_eeg,:] *= eeg_scale -# fwd_src_snsr[idx_mag,:] *= mag_scale -# fwd_src_snsr[idx_grad,:] *= grad_scale -# # construct fwd_roi_snsr matrix -# fwd_roi_snsr = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, fwd_src_snsr.T).T) - -# # scale sensor covariance -# 
snsr_cov[np.ix_(idx_eeg, idx_eeg)] *= eeg_scale**2 -# snsr_cov[np.ix_(idx_mag, idx_mag)] *= mag_scale**2 -# snsr_cov[np.ix_(idx_grad, idx_grad)] *= grad_scale**2 +# for s in std: +# std[s] = 1/std[s] +# snsr_cov = cov.data.copy() +# fwd_src_snsr = fwd['sol']['data'].copy() -# # scale epochs # info = epochs.info.copy() # data = epochs.get_data().copy() -# data[:,idx_eeg,:] *= eeg_scale -# data[:,idx_mag,:] *= mag_scale -# data[:,idx_grad,:] *= grad_scale - +# rescale_cov = mne.make_ad_hoc_cov(info, std=std) +# scaler = mne.cov.compute_whitener(rescale_cov, info) +# del rescale_cov +# fwd_src_snsr = scaler[0] @ fwd_src_snsr +# snsr_cov = scaler[0] @ snsr_cov +# data = scaler[0] @ data + + # fwd_roi_snsr = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, + #fwd_src_snsr.T).T) + # epochs = mne.EpochsArray(data, info) # return fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs +def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., + mag_scale=1., grad_scale=1.): + """ apply per-channel-type scaling to epochs, forward, and covariance """ + Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C') + Carray = Carray64 + + + # get indices for each channel type + ch_names = cov['names'] # same as self.fwd['info']['ch_names'] + sel_eeg = pick_types(fwd['info'], meg=False, eeg=True, ref_meg=False) + sel_mag = pick_types(fwd['info'], meg='mag', eeg=False, ref_meg=False) + sel_grad = pick_types(fwd['info'], meg='grad', eeg=False, ref_meg=False) + idx_eeg = [ch_names.index(ch_names[c]) for c in sel_eeg] + idx_mag = [ch_names.index(ch_names[c]) for c in sel_mag] + idx_grad = [ch_names.index(ch_names[c]) for c in sel_grad] + + # retrieve forward and sensor covariance + G = fwd['sol']['data'].copy() + Q = cov.data.copy() + + # scale forward matrix + G[idx_eeg,:] *= eeg_scale + G[idx_mag,:] *= mag_scale + G[idx_grad,:] *= grad_scale + + # construct GL matrix + GL = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, G.T).T) + + # scale sensor covariance + Q[np.ix_(idx_eeg, idx_eeg)] *= eeg_scale**2 + Q[np.ix_(idx_mag, idx_mag)] *= mag_scale**2 + Q[np.ix_(idx_grad, idx_grad)] *= grad_scale**2 + + # scale epochs + info = epochs.info.copy() + data = epochs.get_data().copy() + + data[:,idx_eeg,:] *= eeg_scale + data[:,idx_mag,:] *= mag_scale + data[:,idx_grad,:] *= grad_scale + + epochs = mne.EpochsArray(data, info) + + return G, GL, Q, epochs + + def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', pctvar=0.99, mean_center=False, label_flip=False): """ apply sensor scaling, PCA dimensionality reduction with/without @@ -260,12 +192,16 @@ def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', raise ValueError("dim_mode must be in {'rank', 'pctvar', 'whiten'}") print("running pca for subject %s" % subject_name) - + + scales = {'eeg_scale' : 1e8, 'mag_scale' : 1e16, 'grad_scale' : 1e14} + # compute ROI-to-source map - roi_to_src = ROIToSourceMap(fwd, labels, label_flip) + roi_to_src = ROIToSourceMap(fwd, labels, label_flip) + + if dim_mode == 'whiten': - fwd_src_snsr, fwd_roi_snsr, Q_snsr, epochs = \ + fwd_src_snsr, fwd_roi_snsr, cov_snsr, epochs = \ _scale_sensor_data(epochs, fwd, cov, roi_to_src) dat = epochs.get_data() dat = Carray(np.swapaxes(dat, -1, -2)) @@ -282,9 +218,11 @@ def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', (subject_name, W.shape[0])) else: - - fwd_src_snsr, fwd_roi_snsr, Q_snsr, epochs = \ - _scale_sensor_data(epochs, fwd, cov, roi_to_src) + + fwd_src_snsr, fwd_roi_snsr, cov_snsr, epochs = \ + 
_scale_sensor_data(epochs, fwd, cov, roi_to_src, **scales) + + dat = epochs.get_data().copy() dat = Carray(np.swapaxes(dat, -1, -2)) @@ -292,30 +230,31 @@ def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', dat -= np.mean(dat, axis=1, keepdims=True) dat_stacked = np.reshape(dat, (-1, dat.shape[-1])) - pca = PCA() pca.fit(dat_stacked) if dim_mode == 'rank': idx = np.linalg.matrix_rank(np.cov(dat_stacked, rowvar=False)) else: - idx = np.where(np.cumsum(pca.explained_variance_ratio_) > pctvar)[0][0] + idx = np.where(np.cumsum(pca.explained_variance_ratio_) > + pctvar)[0][0] idx = np.maximum(idx, len(labels)) W = pca.components_[:idx] print("subject %s using %d principal components" % (subject_name, idx)) - + ntrials, T, _ = dat.shape dat_pca = np.dot(dat_stacked, W.T) dat_pca = np.reshape(dat_pca, (ntrials, T, -1)) fwd_src_snsr_pca = np.dot(W, fwd_src_snsr) fwd_roi_snsr_pca = np.dot(W, fwd_roi_snsr) - Q_snsr_pca = np.dot(W,np.dot(Q_snsr, W.T)) + cov_snsr_pca = np.dot(W,np.dot(cov_snsr, W.T)) data = dat_pca - return data, fwd_roi_snsr_pca, fwd_src_snsr_pca, Q_snsr_pca, roi_to_src.which_roi + return (data, fwd_roi_snsr_pca, fwd_src_snsr_pca, cov_snsr_pca, + roi_to_src.which_roi) def combine_medial_labels(labels, subject='fsaverage', surf='white', diff --git a/state_space/megssm/models.py b/state_space/megssm/models.py index 8da2c276..684494b8 100755 --- a/state_space/megssm/models.py +++ b/state_space/megssm/models.py @@ -5,7 +5,7 @@ import autograd.numpy as np import scipy.optimize as spopt -from autograd import grad #autograd --> jax +from autograd import grad from autograd import value_and_grad as vgrad from scipy.linalg import LinAlgError @@ -16,26 +16,19 @@ from .message_passing import predict_step, condition from .numpy_numthreads import numpy_num_threads -from .mne_util import ROIToSourceMap, _scale_sensor_data, run_pca_on_subject +from .mne_util import (ROIToSourceMap, _scale_sensor_data, run_pca_on_subject, + apply_projs) try: from autograd_linalg import solve_triangular except ImportError: raise RuntimeError("must install `autograd_linalg` package") -# einsum2 is a parallel version of einsum that works for two arguments -try: - from einsum2 import einsum2 -except ImportError: - # rename standard numpy function if don't have einsum2 - print("=> WARNING: using standard numpy.einsum,", - "consider installing einsum2 package") - from autograd.numpy import einsum as einsum2 +from autograd.numpy import einsum from datetime import datetime -# TODO: add documentation to all methods class _MEGModel(object): """ Base class for any model applied to MEG data that handles storing and unpacking data from tuples. """ @@ -47,7 +40,8 @@ def __init__(self): self._nsubjects = 0 def set_data(self, subjectdata): - n_timepts_lst = [self.unpack_subject_data(e)[0].shape[1] for e in subjectdata] + n_timepts_lst = [self.unpack_subject_data(e)[0].shape[1] for e in + subjectdata] assert len(list(set(n_timepts_lst))) == 1 self._n_timepts = n_timepts_lst[0] ntrials_lst = [self.unpack_subject_data(e)[0].shape[0] for e in \ @@ -78,35 +72,15 @@ def unpack_subject_data(cls, sdata): return Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi -# TODO: add documentation to all methods -# TODO: make some methods "private" (leading underscore) if necessary class MEGLDS(_MEGModel): """ State-space model for MEG data, as described in "A state-space model of cross-region dynamic connectivity in MEG/EEG", Yang et al., NIPS 2016. 
""" - # def __init__(self, num_roi, n_timepts, A_t_=None, roi_cov=None, mu0=None, roi_cov_0=None, - # log_sigsq_lst=None, lam0=0., lam1=0., penalty='ridge', - # store_St=True): def __init__(self, lam0=0., lam1=0., penalty='ridge', store_St=True): super().__init__() self._model_initalized = False - - # set_default = \ - # lambda prm, val, deflt: \ - # self.__setattr__(prm, val.copy() if val is not None else deflt) - - # # initialize parameters - # set_default("A_t_", A_t_, - # np.stack([rand_stable(num_roi, maxew=0.7) for _ in range(n_timepts)], - # axis=0)) - # set_default("roi_cov", roi_cov, rand_psd(num_roi)) - # set_default("mu0", mu0, np.zeros(num_roi)) - # set_default("roi_cov_0", roi_cov_0, rand_psd(num_roi)) - # set_default("log_sigsq_lst", log_sigsq_lst, - # [np.log(np.random.gamma(2, 1, size=num_roi+1))]) - self.lam0 = lam0 self.lam1 = lam1 @@ -122,22 +96,13 @@ def __init__(self, lam0=0., lam1=0., penalty='ridge', store_St=True): self._loglik = None self._store_St = bool(store_St) - # # initialize sufficient statistics - # n_timepts, num_roi, _ = self.A_t_.shape - # self._B0 = np.zeros((num_roi, num_roi)) - # self._B1 = np.zeros((n_timepts-1, num_roi, num_roi)) - # self._B3 = np.zeros((n_timepts-1, num_roi, num_roi)) - # self._B2 = np.zeros((n_timepts-1, num_roi, num_roi)) - # self._B4 = list() - - self._all_subject_data = list()#dict() + self._all_subject_data = list() #SNR boost epochs, bootstraps of 3 - def bootstrap_subject(self, subject_name, seed=8675309, sfreq=100, lower=None, - upper=None, nbootstrap=3, g_nsamples=-5, + def bootstrap_subject(self, epochs, subject_name, seed=8675309, sfreq=100, + lower=None, upper=None, nbootstrap=3, g_nsamples=-5, overwrite=False, validation_set=True): - # subjects = ['sample'] datasets = ['train', 'validation'] use_erm = eq = False independent = False @@ -154,13 +119,15 @@ def bootstrap_subject(self, subject_name, seed=8675309, sfreq=100, lower=None, eq = True independent = True elif g_nsamples == -4: - print("using independent, trial-count equailized, non-boosted samples") + print("using independent, trial-count equailized, non-boosted" + "samples") assert nbootstrap == 0 # sanity check eq = True independent = True datasets = ['train'] elif g_nsamples == -5: - print("using independent, trial-count equailized, integer boosted samples") + print("using independent, trial-count equailized, integer boosted" + "samples") eq = True independent = True datasets = ['train'] @@ -184,33 +151,15 @@ def bootstrap_subject(self, subject_name, seed=8675309, sfreq=100, lower=None, print(' generating ', dataset, ' set') datadir = './data' - subj_dir = os.path.join(datadir, subject_name) - print("subject dir:" + subj_dir) - if not os.path.exists(subj_dir): - print(' %s not found, skipping' % subject_name) - return - - epochs_dir = os.path.join(datadir, subject_name, 'epochs') - epochs_fname = "All_55-sss_%s-epo.fif" % subject_name - epochs_bs_fname = (epochs_fname.split('-epo')[0] + - "-bootstrap_%d-nsamples_%d-seed_%d%s%s%s-" - % (nbootstrap, g_nsamples, seed, - '-lower_%.2e' % lower if lower is not None else '', - '-upper_%.2e' % upper if upper is not None else '', - '-sfreq_%.2e' % sfreq if sfreq is not None else '') + - dataset + "-epo.fif") - - if os.path.exists(os.path.join(epochs_dir, epochs_bs_fname)) and \ - not overwrite: - print(" => found existing bootstrapped epochs, skipping") - return - - epochs = mne.read_epochs(os.path.join(epochs_dir, epochs_fname), - preload=True) - condition_map = {'auditory_left':['auditory_left'],'auditory_right': 
['auditory_right'], - 'visual_left': ['visual_left'], 'visual_right': ['visual_right']} - condition_eq_map = dict(auditory_left=['auditory_left'], auditory_right=['auditory_right'], - visual_left=['visual_left'], visual_right=['visual_right']) + condition_map = {'auditory_left':['auditory_left'], + 'auditory_right': ['auditory_right'], + 'visual_left': ['visual_left'], + 'visual_right': ['visual_right']} + condition_eq_map = dict(auditory_left=['auditory_left'], + auditory_right=['auditory_right'], + visual_left=['visual_left'], + visual_right=['visual_right']) + if eq: epochs.equalize_event_counts(list(condition_map)) cond_map = condition_eq_map else: cond_map = condition_map @@ -253,7 +202,8 @@ def bootstrap_subject(self, subject_name, seed=8675309, sfreq=100, lower=None, if nsamples == 1 and use_bootstrap == ntrials: inds = np.arange(ntrials) else: - inds = np.random.choice(ntrials, nsamples * use_bootstrap) + inds = np.random.choice(ntrials, + nsamples * use_bootstrap) inds.shape = (use_bootstrap, nsamples) dat_bs = np.mean(dat[inds], axis=1) events_bs = ep.events[inds[:, 0]] @@ -278,7 +228,7 @@ def bootstrap_subject(self, subject_name, seed=8675309, sfreq=100, lower=None, assert dat_bs.shape == (use_bootstrap, T, p) assert events_bs.shape == (use_bootstrap, 3) - assert (events_bs[:, 2] == events_bs[0, 2]).all() #not working for sample_info + assert (events_bs[:, 2] == events_bs[0, 2]).all() data_bs_all.append(dat_bs) events_bs_all.append(events_bs) @@ -296,35 +246,44 @@ def bootstrap_subject(self, subject_name, seed=8675309, sfreq=100, lower=None, dat_all, info_dict, events=events_all, tmin=-0.2, event_id=epochs.event_id.copy(), on_missing='ignore') - # print(" saving bootstrapped epochs (%s)" % (epochs_bs_fname,)) - # epochs_bs.save(os.path.join(epochs_dir, epochs_bs_fname)) return epochs_bs - def add_subject(self, subject, subject_dir, epochs, labels, fwd, cov): + def add_subject(self, subject, condition, epochs, labels, fwd, + cov): + + epochs_bs = self.bootstrap_subject(epochs, subject) + epochs_bs = epochs_bs.crop(tmin=-0.2, tmax=0.7) + epochs_bs = epochs_bs[condition] + epochs = epochs_bs + + cov = cov.pick_channels(epochs.ch_names, ordered=True) + fwd = mne.convert_forward_solution(fwd, force_fixed=True) + fwd = fwd.pick_channels(epochs.ch_names, ordered=True) if not self._model_initalized: n_timepts = len(epochs.times) num_roi = len(labels) self._init_model(n_timepts, num_roi) self._model_initalized = True + self.n_timepts = n_timepts + self.num_roi = num_roi + self.times = epochs.times if len(epochs.times) != self._n_times: raise ValueError(f'Number of time points ({len(epochs.times)}) ' f'does not match original count ({self._n_times})') roi_to_src = ROIToSourceMap(fwd, labels) # compute ROI-to-source map + cov_scale = 3 # number of trials averaged per bootstrapped epoch, so + # the noise covariance of an averaged epoch is ~1/3 of the original + cov['data'] /= cov_scale + fwd, cov = apply_projs(epochs_bs, fwd, cov) - sdata = run_pca_on_subject(subject, epochs_bs, fwd, cov, labels, dim_mode='pctvar') #check for channel mismatch + sdata = run_pca_on_subject(subject, epochs_bs, fwd, cov, labels, + dim_mode='pctvar', mean_center=True) data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + subjectdata = (data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi) self._all_subject_data.append(subjectdata) - # self.set_data(subjectdata) - # epochs, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = 
subjectdata self._subject_data[subject] = dict() self._subject_data[subject]['epochs'] = data self._subject_data[subject]['fwd_src_snsr'] = fwd_src_snsr @@ -333,23 +292,22 @@ def add_subject(self, subject, subject_dir, epochs,labels, fwd, cov): self._subject_data[subject]['labels'] = labels self._subject_data[subject]['which_roi'] = which_roi - # dict_to_tuple = list(tuple(x.values()) for x in self._subject_data.values()) - # print(f' dict vs tuples = {np.allclose(subjectdata, dict_to_tuple)}') - def _init_model(self, n_timepts, num_roi, A_t_=None, roi_cov=None, mu0=None, - roi_cov_0=None, log_sigsq_lst=None): + def _init_model(self, n_timepts, num_roi, A_t_=None, roi_cov=None, + mu0=None, roi_cov_0=None, log_sigsq_lst=None): self._n_times = n_timepts self._subject_data = dict() set_default = \ lambda prm, val, deflt: \ - self.__setattr__(prm, val.copy() if val is not None else deflt) + self.__setattr__(prm, val.copy() if val is not None else + deflt) # initialize parameters set_default("A_t_", A_t_, - np.stack([rand_stable(num_roi, maxew=0.7) for _ in range(n_timepts)], - axis=0)) + np.stack([rand_stable(num_roi, maxew=0.7) for _ in + range(n_timepts)], axis=0)) set_default("roi_cov", roi_cov, rand_psd(num_roi)) set_default("mu0", mu0, np.zeros(num_roi)) set_default("roi_cov_0", roi_cov_0, rand_psd(num_roi)) @@ -380,8 +338,6 @@ def set_data(self, subjectdata): self._loglik = None self._B4 = [None] * self._nsubjects - # TODO: figure out how to initialize smoothed parameters so this doesn't - # break, e.g. if _em_objective is called before em for some reason def _em_objective(self): _, num_roi, _ = self.A_t_.shape @@ -409,61 +365,63 @@ def _em_objective(self): roi_cov_t = _ensure_ndim(self.roi_cov, n_timepts, 3) with numpy_num_threads(1): _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ - rts_smooth_fast(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, - self.roi_cov_0, compute_lag1_cov=True) + rts_smooth_fast(Y, self.A_t_, fwd_roi_snsr, + roi_cov_t, R, self.mu0, + self.roi_cov_0, + compute_lag1_cov=True) else: mus_smooth = self._mus_smooth_lst[s] sigmas_smooth = self._sigmas_smooth_lst[s] sigmas_tnt_smooth = self._sigmas_tnt_smooth_lst[s] - x_smooth_0_outer = einsum2('ri,rj->rij', mus_smooth[:,0,:num_roi], + x_smooth_0_outer = einsum('ri,rj->rij', mus_smooth[:,0,:num_roi], mus_smooth[:,0,:num_roi]) - B0 = w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + x_smooth_0_outer, - axis=0) - - x_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], - mus_smooth[:,1:,:num_roi]) - B1 = w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + x_smooth_outer, axis=0) - - z_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,:-1,:], + B0 = w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + + x_smooth_0_outer, axis=0) + + x_smooth_outer = einsum('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], + mus_smooth[:,1:,:num_roi]) + B1 = w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + + x_smooth_outer, axis=0) + z_smooth_outer = einsum('rti,rtj->rtij', mus_smooth[:,:-1,:], mus_smooth[:,:-1,:]) B3 = w_s*np.sum(sigmas_smooth[:,:-1,:,:] + z_smooth_outer, axis=0) - mus_smooth_outer_l1 = einsum2('rti,rtj->rtij', + mus_smooth_outer_l1 = einsum('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], mus_smooth[:,:-1,:]) - B2 = w_s*np.sum(sigmas_tnt_smooth[:,:,:num_roi,:] + mus_smooth_outer_l1, - axis=0) + B2 = w_s*np.sum(sigmas_tnt_smooth[:,:,:num_roi,:] + + mus_smooth_outer_l1, axis=0) # obj += L1(roi_cov_0) L_roi_cov_0_inv_B0 = solve_triangular(L_roi_cov_0, B0, lower=True) L1 += 
(ntrials*2.*np.sum(np.log(np.diag(L_roi_cov_0))) - + np.trace(solve_triangular(L_roi_cov_0, L_roi_cov_0_inv_B0, lower=True, - trans='T'))) + + np.trace(solve_triangular(L_roi_cov_0, L_roi_cov_0_inv_B0, + lower=True, trans='T'))) At = self.A_t_[:-1] - AtB2T = einsum2('tik,tjk->tij', At, B2) - B2AtT = einsum2('tik,tjk->tij', B2, At) - tmp = einsum2('tik,tkl->til', At, B3) - AtB3AtT = einsum2('tik,tjk->tij', tmp, At) + AtB2T = einsum('tik,tjk->tij', At, B2) + B2AtT = einsum('tik,tjk->tij', B2, At) + tmp = einsum('tik,tkl->til', At, B3) + AtB3AtT = einsum('tik,tjk->tij', tmp, At) tmp = np.sum(B1 - AtB2T - B2AtT + AtB3AtT, axis=0) # obj += L2(roi_cov, At) L_roi_cov_inv_tmp = solve_triangular(L_roi_cov, tmp, lower=True) L2 += (ntrials*(n_timepts-1)*2.*np.sum(np.log(np.diag(L_roi_cov))) - + np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_tmp, lower=True, - trans='T'))) + + np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_tmp, + lower=True, trans='T'))) - res = Y - einsum2('ik,ntk->nti', fwd_roi_snsr, mus_smooth[:,:,:num_roi]) - CP_smooth = einsum2('ik,ntkj->ntij', fwd_roi_snsr, sigmas_smooth[:,:,:num_roi,:num_roi]) + res = Y - einsum('ik,ntk->nti', fwd_roi_snsr, + mus_smooth[:,:,:num_roi]) + CP_smooth = einsum('ik,ntkj->ntij', fwd_roi_snsr, + sigmas_smooth[:,:,:num_roi,:num_roi]) - # TODO: np.sum does not parallelize over the accumulators, possible - # bottleneck. - B4 = w_s*(np.sum(einsum2('nti,ntj->ntij', res, res), axis=(0,1)) - + np.sum(einsum2('ntik,jk->ntij', CP_smooth, fwd_roi_snsr), - axis=(0,1))) + B4 = w_s*(np.sum(einsum('nti,ntj->ntij', res, res), axis=(0,1)) + + np.sum(einsum('ntik,jk->ntij', CP_smooth, + fwd_roi_snsr), axis=(0,1))) self._B4[s] = B4 # obj += L3(sigsq_vals) @@ -493,10 +451,11 @@ def _em_objective(self): return obj - def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, A_t_roi_cov_tol=1e-6, verbose=0, - update_A_t_=True, update_roi_cov=True, update_roi_cov_0=True, stationary_A_t_=False, - diag_roi_cov=False, update_sigsq=True, do_final_smoothing=True, - average_mus_smooth=True, Atrue=None, tau=0.1, c1=1e-4): + def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, + A_t_roi_cov_tol=1e-6, verbose=0, update_A_t_=True, + update_roi_cov=True, update_roi_cov_0=True, stationary_A_t_=False, + diag_roi_cov=False, update_sigsq=True, do_final_smoothing=True, + average_mus_smooth=True, Atrue=None, tau=0.1, c1=1e-4): self.set_data(self._all_subject_data) @@ -523,7 +482,8 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, A_t_roi_cov_tol=1e-6, converged = False best_objval = np.finfo('float').max best_params = (self.A_t_.copy(), self.roi_cov.copy(), self.mu0.copy(), - self.roi_cov_0.copy(), [l.copy() for l in self.log_sigsq_lst]) + self.roi_cov_0.copy(), [l.copy() for l in + self.log_sigsq_lst]) # previous parameter values (for checking convergence) At_prev = None @@ -533,7 +493,8 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, A_t_roi_cov_tol=1e-6, if Atrue is not None: import matplotlib.pyplot as plt - fig_A_t_, ax_A_t_ = plt.subplots(num_roi, num_roi, sharex=True, sharey=True) + fig_A_t_, ax_A_t_ = plt.subplots(num_roi, num_roi, sharex=True, + sharey=True) plt.ion() # calculate initial objective value, check for updated best iterate @@ -560,9 +521,12 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, A_t_roi_cov_tol=1e-6, roi_cov_0_prev = self.roi_cov_0.copy() log_sigsq_lst_prev = np.array(self.log_sigsq_lst).copy() - self._m_step(update_A_t_=update_A_t_, update_roi_cov=update_roi_cov, - update_roi_cov_0=update_roi_cov_0, 
stationary_A_t_=stationary_A_t_, - diag_roi_cov=diag_roi_cov, update_sigsq=update_sigsq, + self._m_step(update_A_t_=update_A_t_, + update_roi_cov=update_roi_cov, + update_roi_cov_0=update_roi_cov_0, + stationary_A_t_=stationary_A_t_, + diag_roi_cov=diag_roi_cov, + update_sigsq=update_sigsq, tau=tau, c1=c1, verbose=verbose) if Atrue is not None: @@ -592,15 +556,16 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, A_t_roi_cov_tol=1e-6, if objval < best_objval: best_objval = objval - best_params = (self.A_t_.copy(), self.roi_cov.copy(), self.mu0.copy(), - self.roi_cov_0.copy(), + best_params = (self.A_t_.copy(), self.roi_cov.copy(), + self.mu0.copy(), self.roi_cov_0.copy(), [l.copy() for l in self.log_sigsq_lst]) # check for convergence if it >= 1: relnormdiff_At = relnormdiff(self.A_t_[:-1], At_prev) relnormdiff_roi_cov = relnormdiff(self.roi_cov, roi_cov_prev) - relnormdiff_roi_cov_0 = relnormdiff(self.roi_cov_0, roi_cov_0_prev) + relnormdiff_roi_cov_0 = relnormdiff(self.roi_cov_0, + roi_cov_0_prev) relnormdiff_log_sigsq_lst = \ np.array( [relnormdiff(self.log_sigsq_lst[s], @@ -616,7 +581,8 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, A_t_roi_cov_tol=1e-6, if verbose > 0: print(" relnormdiff_At: %.3e" % relnormdiff_At) print(" relnormdiff_roi_cov: %.3e" % relnormdiff_roi_cov) - print(" relnormdiff_roi_cov_0: %.3e" % relnormdiff_roi_cov_0) + print(" relnormdiff_roi_cov_0: %.3e" % + relnormdiff_roi_cov_0) print(" relnormdiff_log_sigsq_lst:", relnormdiff_log_sigsq_lst) print(" relobjdiff: %.3e" % relobjdiff) @@ -669,7 +635,8 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, A_t_roi_cov_tol=1e-6, roi_cov_t = _ensure_ndim(self.roi_cov, self._n_timepts, 3) with numpy_num_threads(1): loglik_subject, mus_smooth, _, _, St = \ - rts_smooth(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, self.roi_cov_0, + rts_smooth(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, + self.mu0, self.roi_cov_0, compute_lag1_cov=False, store_St=self._store_St) # just save the mean of the smoothed trials @@ -724,29 +691,30 @@ def _e_step(self, verbose=0): with numpy_num_threads(1): _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ - rts_smooth_fast(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, - self.roi_cov_0, compute_lag1_cov=True) + rts_smooth_fast(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, + R, self.mu0, self.roi_cov_0, + compute_lag1_cov=True) self._mus_smooth_lst.append(mus_smooth) self._sigmas_smooth_lst.append(sigmas_smooth) self._sigmas_tnt_smooth_lst.append(sigmas_tnt_smooth) - x_smooth_0_outer = einsum2('ri,rj->rij', mus_smooth[:,0,:num_roi], + x_smooth_0_outer = einsum('ri,rj->rij', mus_smooth[:,0,:num_roi], mus_smooth[:,0,:num_roi]) - self._B0 += w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + x_smooth_0_outer, - axis=0) + self._B0 += w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + + x_smooth_0_outer, axis=0) - x_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], - mus_smooth[:,1:,:num_roi]) - self._B1 += w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + x_smooth_outer, - axis=0) + x_smooth_outer = einsum('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], + mus_smooth[:,1:,:num_roi]) + self._B1 += w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + + x_smooth_outer, axis=0) - z_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,:-1,:], + z_smooth_outer = einsum('rti,rtj->rtij', mus_smooth[:,:-1,:], mus_smooth[:,:-1,:]) self._B3 += w_s*np.sum(sigmas_smooth[:,:-1,:,:] + z_smooth_outer, axis=0) - mus_smooth_outer_l1 = einsum2('rti,rtj->rtij', + 
mus_smooth_outer_l1 = einsum('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], mus_smooth[:,:-1,:]) self._B2 += w_s*np.sum(sigmas_tnt_smooth[:,:,:num_roi,:] + @@ -755,8 +723,9 @@ def _e_step(self, verbose=0): if verbose > 0: print("\n done") - def _m_step(self, update_A_t_=True, update_roi_cov=True, update_roi_cov_0=True, - stationary_A_t_=False, diag_roi_cov=False, update_sigsq=True, tau=0.1, c1=1e-4, + def _m_step(self, update_A_t_=True, update_roi_cov=True, + update_roi_cov_0=True, stationary_A_t_=False, + diag_roi_cov=False, update_sigsq=True, tau=0.1, c1=1e-4, verbose=0): self._loglik = None if verbose > 0: @@ -765,13 +734,16 @@ def _m_step(self, update_A_t_=True, update_roi_cov=True, update_roi_cov_0=True, self.roi_cov_0 = (1. / self._ntrials_all) * self._B0 if diag_roi_cov: self.roi_cov_0 = np.diag(np.diag(self.roi_cov_0)) - self.update_A_t_and_roi_cov(update_A_t_=update_A_t_, update_roi_cov=update_roi_cov, - stationary_A_t_=stationary_A_t_, diag_roi_cov=diag_roi_cov, - tau=tau, c1=c1, verbose=verbose) + self.update_A_t_and_roi_cov(update_A_t_=update_A_t_, + update_roi_cov=update_roi_cov, + stationary_A_t_=stationary_A_t_, + diag_roi_cov=diag_roi_cov, tau=tau, + c1=c1, verbose=verbose) if update_sigsq: self.update_log_sigsq_lst(verbose=verbose) - def update_A_t_and_roi_cov(self, update_A_t_=True, update_roi_cov=True, stationary_A_t_=False, + def update_A_t_and_roi_cov(self, update_A_t_=True, update_roi_cov=True, + stationary_A_t_=False, diag_roi_cov=False, tau=0.1, c1=1e-4, verbose=0): if verbose > 1: @@ -819,10 +791,10 @@ def update_A_t_and_roi_cov(self, update_A_t_=True, update_roi_cov=True, stationa # update roi_cov using closed form if update_roi_cov: - AtB2T = einsum2('tik,tjk->tij', At, self._B2) - B2AtT = einsum2('tik,tjk->tij', self._B2, At) - tmp = einsum2('tik,tkl->til', At, self._B3) - AtB3AtT = einsum2('til,tjl->tij', tmp, At) + AtB2T = einsum('tik,tjk->tij', At, self._B2) + B2AtT = einsum('tik,tjk->tij', self._B2, At) + tmp = einsum('tik,tkl->til', At, self._B3) + AtB3AtT = einsum('til,tjl->tij', tmp, At) elbo_2 = np.sum(self._B1 - AtB2T - B2AtT + AtB3AtT, axis=0) self.roi_cov = (1. / (self._ntrials_all * self._n_timepts )) * elbo_2 @@ -862,7 +834,8 @@ def update_log_sigsq_lst(self, verbose=0): log_sigsq = self.log_sigsq_lst[s].copy() log_sigsq_obj = lambda x: \ - MEGLDS.L3_obj(x, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, n_timepts) + MEGLDS.L3_obj(x, snsr_cov, fwd_src_snsr, which_roi, B4, + ntrials, n_timepts) log_sigsq_val_and_grad = vgrad(log_sigsq_obj) options = {'maxiter': 500} @@ -881,119 +854,22 @@ def update_log_sigsq_lst(self, verbose=0): if verbose > 1: print("\n done") - def calculate_smoothed_estimates(self): - """ recalculate smoothed estimates with current model parameters """ - - self._mus_smooth_lst = list() - self._sigmas_smooth_lst = list() - self._sigmas_tnt_smooth_lst = list() - self._St_lst = list() - self._loglik = 0. 
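For reference, update_log_sigsq_lst above pairs autograd's value-and-grad with a scipy minimizer that accepts a callable returning (value, gradient) when jac=True. A minimal sketch of that pattern with a toy stand-in objective; the L-BFGS-B method and all names here are illustrative assumptions, not the model's actual objective:

import autograd.numpy as np
import scipy.optimize as spopt
from autograd import value_and_grad

def toy_objective(log_sigsq):
    # stand-in for MEGLDS.L3_obj: pull the variances toward 1
    return np.sum((np.exp(log_sigsq) - 1.) ** 2)

x0 = np.zeros(5)
res = spopt.minimize(value_and_grad(toy_objective), x0, jac=True,
                     method='L-BFGS-B', options={'maxiter': 500})
print(res.x)  # near zero, i.e. variances driven toward 1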
- - for s, sdata in enumerate(self.unpack_all_subject_data()): - Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) - roi_cov_t = _ensure_ndim(self.roi_cov, self._n_timepts, 3) - with numpy_num_threads(1): - ll, mus_smooth, sigmas_smooth, sigmas_tnt_smooth, _ = \ - rts_smooth(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, self.roi_cov_0, - compute_lag1_cov=True, store_St=False) - self._mus_smooth_lst.append(mus_smooth) - self._sigmas_smooth_lst.append(sigmas_smooth) - self._sigmas_tnt_smooth_lst.append(sigmas_tnt_smooth) - #self._St_lst.append(np.diagonal(St, axis1=-2, axis2=-1)) - self._loglik += ll - - def log_likelihood(self): - """ calculate log marginal likelihood using the Kalman filter """ - - #if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None \ - # or self._sigmas_tnt_smooth_lst is None): - # self.calculate_smoothed_estimates() - # return self._loglik - if self._loglik is not None: - return self._loglik - - self._loglik = 0. - for s, sdata in enumerate(self.unpack_all_subject_data()): - Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) - roi_cov_t = _ensure_ndim(self.roi_cov, self._n_timepts, 3) - ll, _, _, _ = kalman_filter(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, - self.roi_cov_0, store_St=False) - self._loglik += ll - - return self._loglik - - def nparams(self): - n_timepts, p, _ = self.A_t_.shape - - # this should equal (n_timepts-1)*p*p unless some shrinkage is used on At - nparams_At = np.sum(np.abs(self.A_t_[:-1]) > 0) - - # nparams = nparams(At) + nparams(roi_cov) + nparams(roi_cov_0) - # + nparams(log_sigsq_lst) - return nparams_At + p*(p+1)/2 + p*(p+1)/2 \ - + np.sum([p+1 for _ in range(len(self.log_sigsq_lst))]) - - def AIC(self): - return -2*self.log_likelihood() + 2*self.nparams() - - def BIC(self): - if self._ntrials_all == 0: - raise RuntimeError("use set_data to add subject data before" \ - + " computing BIC") - return -2*self.log_likelihood() \ - + np.log(self._ntrials_all)*self.nparams() - - def save(self, filename, **kwargs): - savedict = { 'A_t_' : self.A_t_, 'roi_cov' : self.roi_cov, 'mu0' : self.mu0, - 'roi_cov_0' : self.roi_cov_0, 'log_sigsq_lst' : self.log_sigsq_lst, - 'lam0' : self.lam0, 'lam1' : self.lam1} - savedict.update(kwargs) - np.savez_compressed(filename, **savedict) - - def load(self, filename): - loaddict = np.load(filename) - param_names = ['A_t_', 'roi_cov', 'mu0', 'roi_cov_0', 'log_sigsq_lst', 'lam0', 'lam1'] - for name in param_names: - if name not in loaddict.keys(): - raise RuntimeError('specified file is not a saved model:\n%s' - % (filename,)) - for name in param_names: - if name == 'log_sigsq_lst': - self.log_sigsq_lst = [l.copy() for l in loaddict[name]] - elif name in ('lam0', 'lam1'): - self.__setattr__(name, float(loaddict[name])) - else: - self.__setattr__(name, loaddict[name].copy()) - - # return remaining saved items, if there are any - others = {key : loaddict[key] for key in loaddict.keys() \ - if key not in param_names} - if len(others.keys()) > 0: - return others - + @staticmethod def R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi): - return snsr_cov + np.dot(fwd_src_snsr, sigsq_vals[which_roi][:,None]*fwd_src_snsr.T) + return snsr_cov + np.dot(fwd_src_snsr, + sigsq_vals[which_roi][:,None]*fwd_src_snsr.T) def L2_obj(self, At, L_roi_cov): - - # import 
autograd.numpy - # if isinstance(At,autograd.numpy.numpy_boxes.ArrayBox): - # At = At._value - - AtB2T = einsum2('tik,tjk->tij', At, self._B2) - B2AtT = einsum2('tik,tjk->tij', self._B2, At) - tmp = einsum2('tik,tkl->til', At, self._B3) - AtB3AtT = einsum2('til,tjl->tij', tmp, At) + AtB2T = einsum('tik,tjk->tij', At, self._B2) + B2AtT = einsum('tik,tjk->tij', self._B2, At) + tmp = einsum('tik,tkl->til', At, self._B3) + AtB3AtT = einsum('til,tjl->tij', tmp, At) elbo_2 = np.sum(self._B1 - AtB2T - B2AtT + AtB3AtT, axis=0) L_roi_cov_inv_elbo_2 = solve_triangular(L_roi_cov, elbo_2, lower=True) - obj = np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_elbo_2, lower=True, + obj = np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_elbo_2, + lower=True, trans='T')) obj = obj / self._ntrials_all @@ -1004,10 +880,11 @@ def L2_obj(self, At, L_roi_cov): return obj - # TODO: convert to instance method @staticmethod - def L3_obj(log_sigsq_vals, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, n_timepts): - R = MEGLDS.R_(snsr_cov, fwd_src_snsr, np.exp(log_sigsq_vals), which_roi) + def L3_obj(log_sigsq_vals, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, + n_timepts): + R = MEGLDS.R_(snsr_cov, fwd_src_snsr, np.exp(log_sigsq_vals), + which_roi) try: L_R = np.linalg.cholesky(R) except LinAlgError: @@ -1017,74 +894,4 @@ def L3_obj(log_sigsq_vals, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, n_tim + np.trace(solve_triangular(L_R, L_R_inv_B4, lower=True, trans='T'))) - # @property - # def A_t_(self): - # '''The time-varying connectivity''' - # if self._A_t_ is not None: - # return self._A_t_.copy() - # else: - # return None - - # @property - # def A(self): - # return self._A - - # @A.setter - # def A(self, A): - # self._A = A - - # @property - # def roi_cov(self): - # return self._roi_cov - - # @roi_cov.setter - # def roi_cov(self, roi_cov): - # self._roi_cov = roi_cov - - # @property - # def mu0(self): - # return self._mu0 - - # @mu0.setter - # def mu0(self, mu0): - # self._mu0 = mu0 - - # @property - # def roi_cov_0(self): - # return self._roi_cov_0 - - # @roi_cov_0.setter - # def roi_cov_0(self, roi_cov_0): - # self._roi_cov_0 = roi_cov_0 - - # @property - # def log_sigsq_lst(self): - # return self._log_sigsq_lst - - # @log_sigsq_lst.setter - # def log_sigsq_lst(self, log_sigsq_lst): - # self._log_sigsq_lst = log_sigsq_lst - - # @property - # def num_roi(self): - # return self.A.shape[1] - - # @property - # def n_timepts(self): - # return self._n_timepts - - # @property - # def lam0(self): - # return self._lam0 - - # @lam0.setter - # def lam0(self, lam0): - # self._lam0 = lam0 - - # @property - # def lam1(self): - # return self._lam1 - - # @lam1.setter - # def lam1(self, lam1): - # self._lam1 = lam1 + diff --git a/state_space/state_space_connectivity.py b/state_space/state_space_connectivity.py index 65ca4b0d..4ca68cc4 100644 --- a/state_space/state_space_connectivity.py +++ b/state_space/state_space_connectivity.py @@ -15,255 +15,83 @@ import matplotlib.pyplot as plt import matplotlib as mpl -#where should these files live within mne-connectivity repo? 
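For reference, MEGLDS.R_ above assembles the marginal sensor covariance as R = Q_e + G diag(sigsq[which_roi]) G^T, where which_roi assigns each source vertex to an ROI variance (index 0 meaning unassigned). A toy numeric sketch with arbitrary shapes, not the real forward model:

import numpy as np

p, n_src, n_roi = 6, 10, 2
snsr_cov = np.eye(p)                                # Q_e, sensor covariance
fwd_src_snsr = np.random.randn(p, n_src)            # G, source-to-sensor map
which_roi = np.random.randint(0, n_roi + 1, n_src)  # ROI index per source
sigsq_vals = np.random.gamma(2., 1., n_roi + 1)     # per-ROI variances

R = snsr_cov + np.dot(fwd_src_snsr,
                      sigsq_vals[which_roi][:, None] * fwd_src_snsr.T)
assert R.shape == (p, p)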
+ from megssm.models import MEGLDS as LDS from megssm.plotting import plot_A_t_ -import pickle -load = 0 -if not load: # define paths to sample data - path = None - path = '/Users/jordandrew/Documents/MEG/mne_data'#'/MNE-sample-data' - data_path = mne.datasets.sample.data_path(path=path) - sample_folder = data_path / 'MEG/sample' - subjects_dir = data_path / 'subjects' - - ## import raw data and find events - raw_fname = sample_folder / 'sample_audvis_raw.fif' - raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60) - events = mne.find_events(raw, stim_channel='STI 014') - - ## define epochs using event_dict - event_dict = {'auditory_left': 1, 'auditory_right': 2, 'visual_left': 3, - 'visual_right': 4}#, 'face': 5, 'buttonpress': 32} - epochs = mne.Epochs(raw, events, tmin=-0.3, tmax=0.7, event_id=event_dict, - preload=True).pick_types(meg=True,eeg=True) - condition = 'auditory_left' - epochs = epochs[condition] # choose condition for analysis - - - - #SNR boost epochs, bootstraps of 3 - # def bootstrap_subject(subject_name, seed=8675309, sfreq=100, lower=None, - # upper=None, nbootstrap=3, g_nsamples=-5, - # overwrite=False, validation_set=True): - # import autograd.numpy as np - # import os - - # # subjects = ['sample'] - # datasets = ['train', 'validation'] - # use_erm = eq = False - # independent = False - # if g_nsamples == 0: - # print('nsamples == 0, ensuring independence of samples') - # independent = True - # elif g_nsamples == -1: - # print("using half of trials per sample") - # elif g_nsamples == -2: - # print("using empty room noise at half of trials per sample") - # use_erm = True - # elif g_nsamples == -3: - # print("using independent and trial-count equalized samples") - # eq = True - # independent = True - # elif g_nsamples == -4: - # print("using independent, trial-count equailized, non-boosted samples") - # assert nbootstrap == 0 # sanity check - # eq = True - # independent = True - # datasets = ['train'] - # elif g_nsamples == -5: - # print("using independent, trial-count equailized, integer boosted samples") - # eq = True - # independent = True - # datasets = ['train'] - - # if lower is not None or upper is not None: - # if upper is None: - # print('high-pass filtering at %.2f Hz' % lower) - # elif lower is None: - # print('low-pass filtering at %.2f Hz' % upper) - # else: - # print('band-pass filtering from %.2f-%.2f Hz' % (lower, upper)) - - # if sfreq is not None: - # print('resampling to %.2f Hz' % sfreq) - - # print(":: processing subject %s" % subject_name) - # np.random.seed(seed) - - # for dataset in datasets: - - # print(' generating ', dataset, ' set') - # datadir = './data' - - # subj_dir = os.path.join(datadir, subject_name) - # print("subject dir:" + subj_dir) - # if not os.path.exists(subj_dir): - # print(' %s not found, skipping' % subject_name) - # return - - # epochs_dir = os.path.join(datadir, subject_name, 'epochs') - # epochs_fname = "All_55-sss_%s-epo.fif" % subject_name - # epochs_bs_fname = (epochs_fname.split('-epo')[0] + - # "-bootstrap_%d-nsamples_%d-seed_%d%s%s%s-" - # % (nbootstrap, g_nsamples, seed, - # '-lower_%.2e' % lower if lower is not None else '', - # '-upper_%.2e' % upper if upper is not None else '', - # '-sfreq_%.2e' % sfreq if sfreq is not None else '') + - # dataset + "-epo.fif") - - # if os.path.exists(os.path.join(epochs_dir, epochs_bs_fname)) and \ - # not overwrite: - # print(" => found existing bootstrapped epochs, skipping") - # return - - # epochs = mne.read_epochs(os.path.join(epochs_dir, epochs_fname), - # preload=True) - # 
condition_map = {'auditory_left':['auditory_left'],'auditory_right': ['auditory_right'], - # 'visual_left': ['visual_left'], 'visual_right': ['visual_right']} - # condition_eq_map = dict(auditory_left=['auditory_left'], auditory_right=['auditory_right'], - # visual_left=['visual_left'], visual_right='visual_right') - # if eq: - # epochs.equalize_event_counts(list(condition_map)) - # cond_map = condition_eq_map - - # # apply band-pass filter to limit signal to desired frequency band - # if lower is not None or upper is not None: - # epochs = epochs.filter(lower, upper) - - # # perform resampling with specified sampling frequency - # if sfreq is not None: - # epochs = epochs.resample(sfreq) - - # data_bs_all = list() - # events_bs_all = list() - # for cond in sorted(cond_map.keys()): - # print(" -> condition %s: bootstrapping" % cond, end='') - # ep = epochs[cond_map[cond]] - # dat = ep.get_data().copy() - # ntrials, T, p = dat.shape - - # use_bootstrap = nbootstrap - # if g_nsamples == -4: - # nsamples = 1 - # use_bootstrap = ntrials - # elif g_nsamples == -5: - # nsamples = nbootstrap - # use_bootstrap = ntrials // nsamples - # elif independent: - # nsamples = (ntrials - 1) // use_bootstrap - # elif g_nsamples in (-1, -2): - # nsamples = ntrials // 2 - # else: - # assert g_nsamples > 0 - # nsamples = g_nsamples - # print(" using %d samples (%d trials)" - # % (nsamples, use_bootstrap)) - - # # bootstrap here - # if independent: # independent - # if nsamples == 1 and use_bootstrap == ntrials: - # inds = np.arange(ntrials) - # else: - # inds = np.random.choice(ntrials, nsamples * use_bootstrap) - # inds.shape = (use_bootstrap, nsamples) - # dat_bs = np.mean(dat[inds], axis=1) - # events_bs = ep.events[inds[:, 0]] - # assert dat_bs.shape[0] == events_bs.shape[0] - # else: - # dat_bs = np.empty((ntrials, T, p)) - # events_bs = np.empty((ntrials, 3), dtype=int) - # for i in range(ntrials): - - # inds = list(set(range(ntrials)).difference([i])) - # inds = np.random.choice(inds, size=nsamples, - # replace=False) - # inds = np.append(inds, i) - - # dat_bs[i] = np.mean(dat[inds], axis=0) - # events_bs[i] = ep.events[i] - - # inds = np.random.choice(ntrials, size=use_bootstrap, - # replace=False) - # dat_bs = dat_bs[inds] - # events_bs = events_bs[inds] - - # assert dat_bs.shape == (use_bootstrap, T, p) - # assert events_bs.shape == (use_bootstrap, 3) - # assert (events_bs[:, 2] == events_bs[0, 2]).all() #not working for sample_info - - # data_bs_all.append(dat_bs) - # events_bs_all.append(events_bs) - - # # write bootstrap epochs - # info_dict = epochs.info.copy() - - # dat_all = np.vstack(data_bs_all) - # events_all = np.vstack(events_bs_all) - # # replace first column with sequential list as we don't really care - # # about the raw timings - # events_all[:, 0] = np.arange(events_all.shape[0]) - - # epochs_bs = mne.EpochsArray( - # dat_all, info_dict, events=events_all, tmin=-0.2, - # event_id=epochs.event_id.copy(), on_missing='ignore') - - # print(" saving bootstrapped epochs (%s)" % (epochs_bs_fname,)) - # epochs_bs.save(os.path.join(epochs_dir, epochs_bs_fname)) - # bootstrap_subject('sample') - - - - - - - - ## read forward solution, remove bad channels - fwd_fname = sample_folder / 'sample_audvis-meg-eeg-oct-6-fwd.fif' - fwd = mne.read_forward_solution(fwd_fname) - - ## read in covariance OR compute noise covariance? 
noise_cov drops bad chs
-    cov_fname = sample_folder / 'sample_audvis-cov.fif'
-    cov = mne.read_cov(cov_fname) # drop bad channels in add_subject
-    # noise_cov = mne.compute_covariance(epochs, tmax=0)
-    
-    ## read labels for analysis
-    label_names = ['AUD-lh', 'AUD-rh', 'Vis-lh', 'Vis-rh']
-    labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label',
-                             subject='sample') for label in label_names]
-    
-    # initiate model
-    num_rois = len(labels)
-    n_timepts = len(epochs.times)
-    times = epochs.times
-    # model = LDS(num_rois, n_timepts, lam0=0, lam1=100) # only needs the forward, labels, and noise_cov to be initialized
-    model = LDS(lam0=0, lam1=100)
-    
-    model.add_subject('sample', subjects_dir, epochs, labels, fwd, cov)
-    # #when to use compute_cov vs read_cov? ie cov vs noise_cov
-    
-    # model.fit(niter=100, verbose=1)
-    # A_t_ = model.A_t_
-    # assert A_t_.shape == (n_timepts, num_rois, num_rois)
-    
-    # file = open('sample_subj_stdNone.pkl','wb')
-    # pickle.dump([model, num_rois, n_timepts, times, condition, label_names],file)
-    # file.close
-
-# if load:
-#     with open('sample_subj_stdNone.pkl','rb') as f:
-#         model, num_rois, n_timepts, times, condition, label_names = pickle.load(f)
-
-# with mpl.rc_context():
-#     {'xtick.labelsize': 'x-small', 'ytick.labelsize': 'x-small'}
-#     fig, ax = plt.subplots(num_rois, num_rois, constrained_layout=True, squeeze=False,
-#                            figsize=(12, 10))
-#     plot_A_t_(model.A_t_, labels=label_names, times=times, ax=ax)
-#     fig.suptitle(condition)
-
+path = None  # set to a local data directory to avoid re-downloading
+path = '/Users/jordandrew/Documents/MEG/mne_data'  # '/MNE-sample-data'
+data_path = mne.datasets.sample.data_path(path=path)
+sample_folder = data_path / 'MEG/sample'
+subjects_dir = data_path / 'subjects'
+
+## import raw data and find events
+raw_fname = sample_folder / 'sample_audvis_raw.fif'
+raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60)
+events = mne.find_events(raw, stim_channel='STI 014')
+
+## define epochs using event_dict
+event_dict = {'auditory_left': 1, 'auditory_right': 2, 'visual_left': 3,
+              'visual_right': 4}
+epochs = mne.Epochs(raw, events, tmin=-0.2, tmax=0.7, event_id=event_dict,
+                    preload=True).pick_types(meg=True, eeg=True)
+condition = 'auditory_left'
+
+## read forward solution, remove bad channels
+fwd_fname = sample_folder / 'sample_audvis-meg-eeg-oct-6-fwd.fif'
+fwd = mne.read_forward_solution(fwd_fname)
+
+## read in covariance
+cov_fname = sample_folder / 'sample_audvis-cov.fif'
+cov = mne.read_cov(cov_fname)
+
+## read labels for analysis
+label_names = ['AUD-lh', 'AUD-rh', 'Vis-lh', 'Vis-rh']
+labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label',
+                         subject='sample') for label in label_names]
+
+# initiate model
+model = LDS(lam0=0, lam1=100)
+model.add_subject('sample', condition, epochs, labels, fwd, cov)
+model.fit(niter=100, verbose=2)
+
+# plot model output
+num_roi = model.num_roi
+n_timepts = model.n_timepts
+times = model.times
+A_t_ = model.A_t_
+assert A_t_.shape == (n_timepts, num_roi, num_roi)
+
+with mpl.rc_context({'xtick.labelsize': 'x-small',
+                     'ytick.labelsize': 'x-small'}):
+    fig, ax = plt.subplots(num_roi, num_roi, constrained_layout=True,
+                           squeeze=False, figsize=(12, 10))
+    plot_A_t_(A_t_, labels=label_names, times=times, ax=ax)
+    fig.suptitle(condition)
+    diag_lims = [0, 1]
+    off_lims = [-0.6, 0.6]
+    for ri, row in enumerate(ax):
+        for ci, a in enumerate(row):
+            ylim = diag_lims if ri == ci else off_lims
+            a.set(ylim=ylim, xlim=times[[0, -1]])
+            if ri == 0:
+                a.set_title(a.get_title(), fontsize='small')
+            if 
ci == 0: + a.set_ylabel(a.get_ylabel(), fontsize='small') + for line in a.lines: + line.set_clip_on(False) + line.set(lw=1.) + if ci != 0: + a.yaxis.set_major_formatter(plt.NullFormatter()) + if ri != len(label_names) - 1: + a.xaxis.set_major_formatter(plt.NullFormatter()) + if ri == ci: + for spine in a.spines.values(): + spine.set(lw=2) + else: + a.axhline(0, color='k', ls=':', lw=1.) From 460fbbbb10e8ec91901a0b64204877930fd162ed Mon Sep 17 00:00:00 2001 From: jadrew43 Date: Mon, 29 Aug 2022 12:44:01 -0700 Subject: [PATCH 11/17] Deleted files from /examples. Find files in /state_space --- examples/megssm/__init__.py | 0 examples/megssm/message_passing.py | 732 ---------------------- examples/megssm/mne_util.py | 290 --------- examples/megssm/models.py | 872 --------------------------- examples/megssm/numpy_numthreads.py | 91 --- examples/megssm/plotting.py | 107 ---- examples/megssm/util.py | 117 ---- examples/mne_util.py | 294 --------- examples/state_space_connectivity.py | 87 --- 9 files changed, 2590 deletions(-) delete mode 100755 examples/megssm/__init__.py delete mode 100755 examples/megssm/message_passing.py delete mode 100644 examples/megssm/mne_util.py delete mode 100755 examples/megssm/models.py delete mode 100755 examples/megssm/numpy_numthreads.py delete mode 100644 examples/megssm/plotting.py delete mode 100755 examples/megssm/util.py delete mode 100644 examples/mne_util.py delete mode 100644 examples/state_space_connectivity.py diff --git a/examples/megssm/__init__.py b/examples/megssm/__init__.py deleted file mode 100755 index e69de29b..00000000 diff --git a/examples/megssm/message_passing.py b/examples/megssm/message_passing.py deleted file mode 100755 index e560ff49..00000000 --- a/examples/megssm/message_passing.py +++ /dev/null @@ -1,732 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import autograd.numpy as np -from autograd.scipy.linalg import block_diag - -from .util import T_, sym, dot3, _ensure_ndim, component_matrix, hs - -try: - from autograd_linalg import solve_triangular -except ImportError: - raise RuntimeError("must install `autograd_linalg` package") - -# einsum2 is a parallel version of einsum that works for two arguments -try: - from einsum2 import einsum2 -except ImportError: - # rename standard numpy function if don't have einsum2 - print("=> WARNING: using standard numpy.einsum,", - "consider installing einsum2 package") - from numpy import einsum as einsum2 - - -def kalman_filter(Y, A, C, Q, R, mu0, Q0, store_St=True, sum_logliks=True): - """ Kalman filter that broadcasts over the first dimension. - Handles multiple lag dependence using component form. - - Note: This function doesn't handle control inputs (yet). 
- - Y : ndarray, shape=(N, T, D) - Observations - - A : ndarray, shape=(T, D*nlag, D*nlag) - Time-varying dynamics matrices - - C : ndarray, shape=(p, D) - Observation matrix - - mu0: ndarray, shape=(D,) - mean of initial state variable - - Q0 : ndarray, shape=(D, D) - Covariance of initial state variable - - Q : ndarray, shape=(D, D) - Covariance of latent states - - R : ndarray, shape=(D, D) - Covariance of observations - """ - - N = Y.shape[0] - T, D, Dnlags = A.shape - nlags = Dnlags // D - AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) - - p = C.shape[0] - CC = hs([C, np.zeros((p, D*(nlags-1)))]) - - QQ = np.zeros((T, Dnlags, Dnlags)) - QQ[:,:D,:D] = Q - - QQ0 = block_diag(*[Q0 for _ in range(nlags)]) - - mu_predict = np.stack([np.tile(mu0, nlags) for _ in range(N)], axis=0) - sigma_predict = np.stack([QQ0 for _ in range(N)], axis=0) - - St = np.empty((N, T, p, p)) if store_St else None - - mus_filt = np.zeros((N, T, Dnlags)) - sigmas_filt = np.zeros((N, T, Dnlags, Dnlags)) - - ll = np.zeros(T) - - for t in range(T): - - # condition - # dot3(CC, sigma_predict, CC.T) + R - tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict) - sigma_pred = np.dot(tmp1, CC.T) + R - sigma_pred = sym(sigma_pred) - - if St is not None: - St[...,t,:,:] = sigma_pred - - res = Y[...,t,:] - np.dot(mu_predict, CC.T) - - L = np.linalg.cholesky(sigma_pred) - v = solve_triangular(L, res, lower=True) - - # log-likelihood over all trials - ll[t] = -0.5*(2.*np.sum(np.log(np.diagonal(L, axis1=1, axis2=2))) - + np.sum(v*v) - + N*p*np.log(2.*np.pi)) - - mus_filt[...,t,:] = mu_predict + einsum2('nki,nk->ni', tmp1, - solve_triangular(L, v, 'T', lower=True)) - - tmp2 = solve_triangular(L, tmp1, lower=True) - sigmas_filt[...,t,:,:] = sym(sigma_predict - einsum2('nki,nkj->nij', tmp2, tmp2)) - - # prediction - mu_predict = einsum2('ik,nk->ni', AA[t], mus_filt[...,t,:]) - - sigma_predict = einsum2('ik,nkl->nil', AA[t], sigmas_filt[...,t,:,:]) - sigma_predict = sym(einsum2('nil,jl->nij', sigma_predict, AA[t]) + QQ[t]) - - if sum_logliks: - ll = np.sum(ll) - return ll, mus_filt, sigmas_filt, St - - -def rts_smooth(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False, - store_St=True): - """ RTS smoother that broadcasts over the first dimension. - Handles multiple lag dependence using component form. - - Note: This function doesn't handle control inputs (yet). - - Y : ndarray, shape=(N, T, D) - Observations - - A : ndarray, shape=(T, D*nlag, D*nlag) - Time-varying dynamics matrices - - C : ndarray, shape=(p, D) - Observation matrix - - mu0: ndarray, shape=(D,) - mean of initial state variable - - Q0 : ndarray, shape=(D, D) - Covariance of initial state variable - - Q : ndarray, shape=(D, D) - Covariance of latent states - - R : ndarray, shape=(D, D) - Covariance of observations - """ - - N, T, _ = Y.shape - _, D, Dnlags = A.shape - nlags = Dnlags // D - AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) - - p = C.shape[0] - CC = hs([C, np.zeros((p, D*(nlags-1)))]) - - QQ = np.zeros((T, Dnlags, Dnlags)) - QQ[:,:D,:D] = Q - - QQ0 = block_diag(*[Q0 for _ in range(nlags)]) - - mu_predict = np.empty((N, T+1, Dnlags)) - sigma_predict = np.empty((N, T+1, Dnlags, Dnlags)) - - mus_smooth = np.empty((N, T, Dnlags)) - sigmas_smooth = np.empty((N, T, Dnlags, Dnlags)) - - St = np.empty((N, T, p, p)) if store_St else None - - if compute_lag1_cov: - sigmas_smooth_tnt = np.empty((N, T-1, Dnlags, Dnlags)) - else: - sigmas_smooth_tnt = None - - ll = 0. 
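The per-timestep log-likelihood terms above rely on the Cholesky identities log|S| = 2*sum(log diag(L)) and r^T S^{-1} r = ||L^{-1} r||^2 for S = L L^T. A standalone check for a single time point, using scipy's solve_triangular in place of the autograd_linalg version (shapes are arbitrary):

import numpy as np
from scipy.linalg import solve_triangular
from scipy.stats import multivariate_normal

p = 4
S = np.eye(p) + 0.1 * np.ones((p, p))  # innovation covariance
r = np.random.randn(p)                 # observation residual

L = np.linalg.cholesky(S)
v = solve_triangular(L, r, lower=True)
ll = -0.5 * (2. * np.sum(np.log(np.diag(L))) + v @ v
             + p * np.log(2. * np.pi))
assert np.isclose(ll, multivariate_normal(np.zeros(p), S).logpdf(r))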
- mu_predict[:,0,:] = np.tile(mu0, nlags) - sigma_predict[:,0,:,:] = QQ0.copy() - - for t in range(T): - - # condition - # sigma_x = dot3(C, sigma_predict, C.T) + R - tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) - sigma_x = einsum2('nik,jk->nij', tmp1, CC) + R - sigma_x = sym(sigma_x) - - if St is not None: - St[...,t,:,:] = sigma_x - - L = np.linalg.cholesky(sigma_x) - # res[n] = Y[n,t,:] = np.dot(C, mu_predict[n,t,:]) - res = Y[...,t,:] - einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) - v = solve_triangular(L, res, lower=True) - - # log-likelihood over all trials - ll += -0.5*(2.*np.sum(np.log(np.diagonal(L, axis1=1, axis2=2))) - + np.sum(v*v) - + N*p*np.log(2.*np.pi)) - - mus_smooth[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', - tmp1, - solve_triangular(L, v, trans='T', lower=True)) - - # tmp2 = L^{-1}*C*sigma_predict - tmp2 = solve_triangular(L, tmp1, lower=True) - sigmas_smooth[:,t,:,:] = sym(sigma_predict[:,t,:,:] - einsum2('nki,nkj->nij', tmp2, tmp2)) - - # prediction - #mu_predict = np.dot(A[t], mus_smooth[t]) - mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_smooth[:,t,:]) - - #sigma_predict = dot3(A[t], sigmas_smooth[t], A[t].T) + Q[t] - tmp = einsum2('ik,nkl->nil', AA[t], sigmas_smooth[:,t,:,:]) - sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) - - for t in range(T-2, -1, -1): - - # these names are stolen from mattjj and slinderman - #temp_nn = np.dot(A[t], sigmas_smooth[n,t,:,:]) - temp_nn = einsum2('ik,nkj->nij', AA[t], sigmas_smooth[:,t,:,:]) - - L = np.linalg.cholesky(sigma_predict[:,t+1,:,:]) - v = solve_triangular(L, temp_nn, lower=True) - # Look in Saarka for dfn of Gt_T - Gt_T = solve_triangular(L, v, trans='T', lower=True) - - # {mus,sigmas}_smooth[n,t] contain the filtered estimates so we're - # overwriting them on purpose - #mus_smooth[n,t,:] = mus_smooth[n,t,:] + np.dot(T_(Gt_T), mus_smooth[n,t+1,:] - mu_predict[t+1,:]) - mus_smooth[:,t,:] = mus_smooth[:,t,:] + einsum2('nki,nk->ni', Gt_T, mus_smooth[:,t+1,:] - mu_predict[:,t+1,:]) - - #sigmas_smooth[n,t,:,:] = sigmas_smooth[n,t,:,:] + dot3(T_(Gt_T), sigmas_smooth[n,t+1,:,:] - temp_nn, Gt_T) - tmp = einsum2('nki,nkj->nij', Gt_T, sigmas_smooth[:,t+1,:,:] - sigma_predict[:,t+1,:,:]) - tmp = einsum2('nik,nkj->nij', tmp, Gt_T) - sigmas_smooth[:,t,:,:] = sym(sigmas_smooth[:,t,:,:] + tmp) - - if compute_lag1_cov: - # This matrix is NOT symmetric, so don't symmetrize! - #sigmas_smooth_tnt[n,t,:,:] = np.dot(sigmas_smooth[n,t+1,:,:], Gt_T) - sigmas_smooth_tnt[:,t,:,:] = einsum2('nik,nkj->nij', sigmas_smooth[:,t+1,:,:], Gt_T) - - return ll, mus_smooth, sigmas_smooth, sigmas_smooth_tnt, St - - -def rts_smooth_fast(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False): - """ RTS smoother that broadcasts over the first dimension. - Handles multiple lag dependence using component form. - - Note: This function doesn't handle control inputs (yet). 
- - Y : ndarray, shape=(N, T, D) - Observations - - A : ndarray, shape=(T, D*nlag, D*nlag) - Time-varying dynamics matrices - - C : ndarray, shape=(p, D) - Observation matrix - - mu0: ndarray, shape=(D,) - mean of initial state variable - - Q0 : ndarray, shape=(D, D) - Covariance of initial state variable - - Q : ndarray, shape=(D, D) - Covariance of latent states - - R : ndarray, shape=(D, D) - Covariance of observations - """ - - N, T, _ = Y.shape - _, D, Dnlags = A.shape - nlags = Dnlags // D - AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) - - L_R = np.linalg.cholesky(R) - - p = C.shape[0] - CC = hs([C, np.zeros((p, D*(nlags-1)))]) - tmp = solve_triangular(L_R, CC, lower=True) - Rinv_CC = solve_triangular(L_R, tmp, trans='T', lower=True) - CCT_Rinv_CC = einsum2('ki,kj->ij', CC, Rinv_CC) - - # tile L_R across number of trials so solve_triangular - # can broadcast over trials properly - L_R = np.tile(L_R, (N, 1, 1)) - - QQ = np.zeros((T, Dnlags, Dnlags)) - QQ[:,:D,:D] = Q - - QQ0 = block_diag(*[Q0 for _ in range(nlags)]) - - mu_predict = np.empty((N, T+1, Dnlags)) - sigma_predict = np.empty((N, T+1, Dnlags, Dnlags)) - - mus_smooth = np.empty((N, T, Dnlags)) - sigmas_smooth = np.empty((N, T, Dnlags, Dnlags)) - - if compute_lag1_cov: - sigmas_smooth_tnt = np.empty((N, T-1, Dnlags, Dnlags)) - else: - sigmas_smooth_tnt = None - - ll = 0. - mu_predict[:,0,:] = np.tile(mu0, nlags) - sigma_predict[:,0,:,:] = QQ0.copy() - - I_tiled = np.tile(np.eye(Dnlags), (N, 1, 1)) - - for t in range(T): - - # condition - tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) - - res = Y[...,t,:] - einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) - - # Rinv * res - tmp2 = solve_triangular(L_R, res, lower=True) - tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) - - # C^T Rinv * res - tmp3 = einsum2('ki,nk->ni', Rinv_CC, res) - - # (Pinv + C^T Rinv C)_inv * tmp3 - L_P = np.linalg.cholesky(sigma_predict[:,t,:,:]) - tmp = solve_triangular(L_P, I_tiled, lower=True) - Pinv = solve_triangular(L_P, tmp, trans='T', lower=True) - tmp4 = sym(Pinv + CCT_Rinv_CC) - L_tmp4 = np.linalg.cholesky(tmp4) - tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) - tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) - - # Rinv C * tmp3 - tmp3 = einsum2('ik,nk->ni', Rinv_CC, tmp3) - - # add the two Woodbury * res terms together - tmp = tmp2 - tmp3 - - mus_smooth[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', tmp1, tmp) - - # Rinv * tmp1 - tmp2 = solve_triangular(L_R, tmp1, lower=True) - tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) - - # C^T Rinv * tmp1 - tmp3 = einsum2('ki,nkj->nij', Rinv_CC, tmp1) - - # (Pinv + C^T Rinv C)_inv * tmp3 - tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) - tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) - - # Rinv C * tmp3 - tmp3 = einsum2('ik,nkj->nij', Rinv_CC, tmp3) - - # add the two Woodbury * tmp1 terms together, left-multiply by tmp1 - tmp = einsum2('nki,nkj->nij', tmp1, tmp2 - tmp3) - - sigmas_smooth[:,t,:,:] = sym(sigma_predict[:,t,:,:] - tmp) - - # prediction - mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_smooth[:,t,:]) - - #sigma_predict = dot3(A[t], sigmas_smooth[t], A[t].T) + Q[t] - tmp = einsum2('ik,nkl->nil', AA[t], sigmas_smooth[:,t,:,:]) - sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) - - for t in range(T-2, -1, -1): - - # these names are stolen from mattjj and slinderman - #temp_nn = np.dot(A[t], sigmas_smooth[n,t,:,:]) - temp_nn = einsum2('ik,nkj->nij', AA[t], sigmas_smooth[:,t,:,:]) - 
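The backward pass below forms the transpose of the smoother gain, Gt_T = P_{t+1|t}^{-1} A_t P_{t|t}, via two triangular solves against the Cholesky factor of the one-step-ahead covariance rather than an explicit inverse. A minimal single-matrix sketch with random SPD inputs, for illustration only:

import numpy as np
from scipy.linalg import solve_triangular

D = 3
A = 0.5 * np.eye(D)
X = np.random.randn(D, D)
P_filt = X @ X.T + D * np.eye(D)        # P_{t|t}
P_pred = A @ P_filt @ A.T + np.eye(D)   # P_{t+1|t}

L = np.linalg.cholesky(P_pred)
Gt_T = solve_triangular(L, solve_triangular(L, A @ P_filt, lower=True),
                        trans='T', lower=True)
# Gt_T.T is the usual RTS gain J_t = P_{t|t} A_t^T P_{t+1|t}^{-1}
assert np.allclose(Gt_T.T, P_filt @ A.T @ np.linalg.inv(P_pred))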
- L = np.linalg.cholesky(sigma_predict[:,t+1,:,:]) - v = solve_triangular(L, temp_nn, lower=True) - # Look in Saarka for dfn of Gt_T - Gt_T = solve_triangular(L, v, trans='T', lower=True) - - # {mus,sigmas}_smooth[n,t] contain the filtered estimates so we're - # overwriting them on purpose - #mus_smooth[n,t,:] = mus_smooth[n,t,:] + np.dot(T_(Gt_T), mus_smooth[n,t+1,:] - mu_predict[t+1,:]) - mus_smooth[:,t,:] = mus_smooth[:,t,:] + einsum2('nki,nk->ni', Gt_T, mus_smooth[:,t+1,:] - mu_predict[:,t+1,:]) - - #sigmas_smooth[n,t,:,:] = sigmas_smooth[n,t,:,:] + dot3(T_(Gt_T), sigmas_smooth[n,t+1,:,:] - temp_nn, Gt_T) - tmp = einsum2('nki,nkj->nij', Gt_T, sigmas_smooth[:,t+1,:,:] - sigma_predict[:,t+1,:,:]) - tmp = einsum2('nik,nkj->nij', tmp, Gt_T) - sigmas_smooth[:,t,:,:] = sym(sigmas_smooth[:,t,:,:] + tmp) - - if compute_lag1_cov: - # This matrix is NOT symmetric, so don't symmetrize! - #sigmas_smooth_tnt[n,t,:,:] = np.dot(sigmas_smooth[n,t+1,:,:], Gt_T) - sigmas_smooth_tnt[:,t,:,:] = einsum2('nik,nkj->nij', sigmas_smooth[:,t+1,:,:], Gt_T) - - return ll, mus_smooth, sigmas_smooth, sigmas_smooth_tnt - - - -def predict(Y, A, C, Q, R, mu0, Q0, pred_var=False): - if pred_var: - return _predict_mean_var(Y, A, C, Q, R, mu0, Q0) - else: - return _predict_mean(Y, A, C, Q, R, mu0, Q0) - - -def _predict_mean_var(Y, A, C, Q, R, mu0, Q0): - """ Model predictions for Y given model parameters. - - Handles multiple lag dependence using component form. - - Note: This function doesn't handle control inputs (yet). - - Y : ndarray, shape=(N, T, D) - Observations - - A : ndarray, shape=(T, D*nlag, D*nlag) - Time-varying dynamics matrices - - C : ndarray, shape=(p, D) - Observation matrix - - mu0: ndarray, shape=(D,) - mean of initial state variable - - Q0 : ndarray, shape=(D, D) - Covariance of initial state variable - - Q : ndarray, shape=(D, D) - Covariance of latent states - - R : ndarray, shape=(D, D) - Covariance of observations - """ - - N, T, _ = Y.shape - _, D, Dnlags = A.shape - nlags = Dnlags // D - AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) - - L_R = np.linalg.cholesky(R) - - p = C.shape[0] - CC = hs([C, np.zeros((p, D*(nlags-1)))]) - tmp = solve_triangular(L_R, CC, lower=True) - Rinv_CC = solve_triangular(L_R, tmp, trans='T', lower=True) - CCT_Rinv_CC = einsum2('ki,kj->ij', CC, Rinv_CC) - - # tile L_R across number of trials so solve_triangular - # can broadcast over trials properly - L_R = np.tile(L_R, (N, 1, 1)) - - QQ = np.zeros((T, Dnlags, Dnlags)) - QQ[:,:D,:D] = Q - - QQ0 = block_diag(*[Q0 for _ in range(nlags)]) - - mu_predict = np.empty((N, T+1, Dnlags)) - sigma_predict = np.empty((N, T+1, Dnlags, Dnlags)) - - mus_filt = np.empty((N, T, Dnlags)) - sigmas_filt = np.empty((N, T, Dnlags, Dnlags)) - - mu_predict[:,0,:] = np.tile(mu0, nlags) - sigma_predict[:,0,:,:] = QQ0.copy() - - I_tiled = np.tile(np.eye(Dnlags), (N, 1, 1)) - - Yhat = np.empty_like(Y) - St = np.empty((N, T, p, p)) - - for t in range(T): - - # condition - # sigma_x = dot3(C, sigma_predict, C.T) + R - tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) - sigma_x = einsum2('nik,jk->nij', tmp1, CC) + R - sigma_x = sym(sigma_x) - - St[...,t,:,:] = sigma_x - - L = np.linalg.cholesky(sigma_x) - Yhat[...,t,:] = einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) - res = Y[...,t,:] - Yhat[...,t,:] - - v = solve_triangular(L, res, lower=True) - - mus_filt[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', - tmp1, - solve_triangular(L, v, trans='T', lower=True)) - - # tmp2 = L^{-1}*C*sigma_predict - tmp2 = 
solve_triangular(L, tmp1, lower=True) - sigmas_filt[:,t,:,:] = sym(sigma_predict[:,t,:,:] - einsum2('nki,nkj->nij', tmp2, tmp2)) - - # prediction - #mu_predict = np.dot(A[t], mus_filt[t]) - mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_filt[:,t,:]) - - #sigma_predict = dot3(A[t], sigmas_filt[t], A[t].T) + Q[t] - tmp = einsum2('ik,nkl->nil', AA[t], sigmas_filt[:,t,:,:]) - sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) - - # just return the diagonal of the St matrices for marginal predictive - # variances - return Yhat, np.diagonal(St, axis1=-2, axis2=-1) - - -def _predict_mean(Y, A, C, Q, R, mu0, Q0): - """ Model predictions for Y given model parameters. - Handles multiple lag dependence using component form. - - Note: This function doesn't handle control inputs (yet). - - Y : ndarray, shape=(N, T, D) - Observations - - A : ndarray, shape=(T, D*nlag, D*nlag) - Time-varying dynamics matrices - - C : ndarray, shape=(p, D) - Observation matrix - - mu0: ndarray, shape=(D,) - mean of initial state variable - - Q0 : ndarray, shape=(D, D) - Covariance of initial state variable - - Q : ndarray, shape=(D, D) - Covariance of latent states - - R : ndarray, shape=(D, D) - Covariance of observations - """ - - N, T, _ = Y.shape - _, D, Dnlags = A.shape - nlags = Dnlags // D - AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) - - L_R = np.linalg.cholesky(R) - - p = C.shape[0] - CC = hs([C, np.zeros((p, D*(nlags-1)))]) - tmp = solve_triangular(L_R, CC, lower=True) - Rinv_CC = solve_triangular(L_R, tmp, trans='T', lower=True) - CCT_Rinv_CC = einsum2('ki,kj->ij', CC, Rinv_CC) - - # tile L_R across number of trials so solve_triangular - # can broadcast over trials properly - L_R = np.tile(L_R, (N, 1, 1)) - - QQ = np.zeros((T, Dnlags, Dnlags)) - QQ[:,:D,:D] = Q - - QQ0 = block_diag(*[Q0 for _ in range(nlags)]) - - mu_predict = np.empty((N, T+1, Dnlags)) - sigma_predict = np.empty((N, T+1, Dnlags, Dnlags)) - - mus_filt = np.empty((N, T, Dnlags)) - sigmas_filt = np.empty((N, T, Dnlags, Dnlags)) - - mu_predict[:,0,:] = np.tile(mu0, nlags) - sigma_predict[:,0,:,:] = QQ0.copy() - - I_tiled = np.tile(np.eye(Dnlags), (N, 1, 1)) - - Yhat = np.empty_like(Y) - - for t in range(T): - - # condition - tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) - - Yhat[...,t,:] = einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) - res = Y[...,t,:] - Yhat[...,t,:] - - # Rinv * res - tmp2 = solve_triangular(L_R, res, lower=True) - tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) - - # C^T Rinv * res - tmp3 = einsum2('ki,nk->ni', Rinv_CC, res) - - # (Pinv + C^T Rinv C)_inv * tmp3 - L_P = np.linalg.cholesky(sigma_predict[:,t,:,:]) - tmp = solve_triangular(L_P, I_tiled, lower=True) - Pinv = solve_triangular(L_P, tmp, trans='T', lower=True) - tmp4 = sym(Pinv + CCT_Rinv_CC) - L_tmp4 = np.linalg.cholesky(tmp4) - tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) - tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) - - # Rinv C * tmp3 - tmp3 = einsum2('ik,nk->ni', Rinv_CC, tmp3) - - # add the two Woodbury * res terms together - tmp = tmp2 - tmp3 - - mus_filt[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', tmp1, tmp) - - # Rinv * tmp1 - tmp2 = solve_triangular(L_R, tmp1, lower=True) - tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) - - # C^T Rinv * tmp1 - tmp3 = einsum2('ki,nkj->nij', Rinv_CC, tmp1) - - # (Pinv + C^T Rinv C)_inv * tmp3 - tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) - tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) - - # 
Rinv C * tmp3 - tmp3 = einsum2('ik,nkj->nij', Rinv_CC, tmp3) - - # add the two Woodbury * tmp1 terms together, left-multiply by tmp1 - tmp = einsum2('nki,nkj->nij', tmp1, tmp2 - tmp3) - - sigmas_filt[:,t,:,:] = sym(sigma_predict[:,t,:,:] - tmp) - - # prediction - mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_filt[:,t,:]) - - #sigma_predict = dot3(A[t], sigmas_filt[t], A[t].T) + Q[t] - tmp = einsum2('ik,nkl->nil', AA[t], sigmas_filt[:,t,:,:]) - sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) - - return Yhat - - -def predict_step(mu_filt, sigma_filt, A, Q): - mu_predict = einsum2('ik,nk->ni', A, mu_filt) - tmp = einsum2('ik,nkl->nil', A, sigma_filt) - sigma_predict = sym(einsum2('nil,jl->nij', tmp, A) + Q) - - return mu_predict, sigma_predict - - -def condition(y, C, R, mu_predict, sigma_predict): - # dot3(C, sigma_predict, C.T) + R - tmp1 = einsum2('ik,nkj->nij', C, sigma_predict) - sigma_pred = einsum2('nik,jk->nij', tmp1, C) + R - sigma_pred = sym(sigma_pred) - - L = np.linalg.cholesky(sigma_pred) - # the transpose works b/c of how dot broadcasts - #y_hat = np.dot(mu_predict, C.T) - y_hat = einsum2('ik,nk->ni', C, mu_predict) - res = y - y_hat - v = solve_triangular(L, res, lower=True) - - mu_filt = mu_predict + einsum2('nki,nk->ni', tmp1, solve_triangular(L, v, trans='T', lower=True)) - - tmp2 = solve_triangular(L, tmp1, lower=True) - sigma_filt = sym(sigma_predict - einsum2('nki,nkj->nij', tmp2, tmp2)) - - return y_hat, mu_filt, sigma_filt - - -def logZ(Y, A, C, Q, R, mu0, Q0): - """ Log marginal likelihood using the Kalman filter. - - The algorithm broadcasts over the first dimension which are considered - to be independent realizations. - - Note: This function doesn't handle control inputs (yet). - - Y : ndarray, shape=(N, T, D) - Observations - - A : ndarray, shape=(T, D, D) - Time-varying dynamics matrices - - C : ndarray, shape=(p, D) - Observation matrix - - mu0: ndarray, shape=(D,) - mean of initial state variable - - Q0 : ndarray, shape=(D, D) - Covariance of initial state variable - - Q : ndarray, shape=(D, D) - Covariance of latent states - - R : ndarray, shape=(D, D) - Covariance of observations - """ - - N = Y.shape[0] - T, D, _ = A.shape - p = C.shape[0] - - mu_predict = np.stack([np.copy(mu0) for _ in range(N)], axis=0) - sigma_predict = np.stack([np.copy(Q0) for _ in range(N)], axis=0) - - ll = 0. 
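The tmp2 - tmp3 pattern in the fast filter variants above is the Woodbury identity (R + C P C^T)^{-1} = R^{-1} - R^{-1} C (P^{-1} + C^T R^{-1} C)^{-1} C^T R^{-1}, which replaces a fresh p x p inverse at every step with solves against the smaller latent dimension. A quick numeric check with arbitrary dimensions:

import numpy as np

p, D = 5, 3
C = np.random.randn(p, D)
X, Z = np.random.randn(p, p), np.random.randn(D, D)
R = X @ X.T + p * np.eye(p)   # SPD observation covariance
P = Z @ Z.T + D * np.eye(D)   # SPD latent covariance

Rinv = np.linalg.inv(R)
inner = np.linalg.inv(np.linalg.inv(P) + C.T @ Rinv @ C)
lhs = np.linalg.inv(R + C @ P @ C.T)
rhs = Rinv - Rinv @ C @ inner @ C.T @ Rinv
assert np.allclose(lhs, rhs)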
- - for t in range(T): - - # condition - # sigma_x = dot3(C, sigma_predict, C.T) + R - tmp1 = einsum2('ik,nkj->nij', C, sigma_predict) - sigma_x = einsum2('nik,jk->nij', tmp1, C) + R - sigma_x = sym(sigma_x) - - # res[n] = Y[n,t,:] = np.dot(C, mu_predict[n]) - res = Y[...,t,:] - einsum2('ik,nk->ni', C, mu_predict) - - L = np.linalg.cholesky(sigma_x) - v = solve_triangular(L, res, lower=True) - - # log-likelihood over all trials - ll += -0.5*(2.*np.sum(np.log(np.diagonal(L, axis1=1, axis2=2))) - + np.sum(v*v) - + N*p*np.log(2.*np.pi)) - - mus_filt = mu_predict + einsum2('nki,nk->ni', - tmp1, - solve_triangular(L, v, trans='T', lower=True)) - - # tmp2 = L^{-1}*C*sigma_predict - tmp2 = solve_triangular(L, tmp1, lower=True) - sigmas_filt = sigma_predict - einsum2('nki,nkj->nij', tmp2, tmp2) - sigmas_filt = sym(sigmas_filt) - - # prediction - #mu_predict = np.dot(A[t], mus_filt[t]) - mu_predict = einsum2('ik,nk->ni', A[t], mus_filt) - - # originally this worked with time-varying Q, but now it's fixed - #sigma_predict = dot3(A[t], sigmas_filt[t], A[t].T) + Q[t] - sigma_predict = einsum2('ik,nkl->nil', A[t], sigmas_filt) - sigma_predict = einsum2('nil,jl->nij', sigma_predict, A[t]) + Q - sigma_predict = sym(sigma_predict) - - return np.sum(ll) diff --git a/examples/megssm/mne_util.py b/examples/megssm/mne_util.py deleted file mode 100644 index 7e9edd90..00000000 --- a/examples/megssm/mne_util.py +++ /dev/null @@ -1,290 +0,0 @@ -""" MNE-Python utility functions for preprocessing data and constructing - matrices necessary for MEGLDS analysis """ - -import mne -import numpy as np -import os.path as op - -from mne.io.pick import pick_types -from mne.utils import logger -from mne import label_sign_flip - -from scipy.sparse import csc_matrix, csr_matrix, diags -from sklearn.decomposition import PCA - -# from util import Carray ##skip import just pasted; util also from MEGLDS repo -Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C') -Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C') -Carray = Carray64 - - -class ROIToSourceMap(object): - """ class for computing ROI-to-source space mapping matrix - - Notes - ----- - The following variables defined here correspond to various matrices - defined in :footcite:`yang_state-space_2016`: - - fwd_src_snsr : G - - fwd_roi_snsr : C - - fwd_src_roi : L - - snsr_cov : Q_e - - roi_cov : Q - - roi_cov_0 : Q0 """ - - def __init__(self, fwd, labels, label_flip=False): - - src = fwd['src'] - - roiidx = list() - vertidx = list() - - n_lhverts = len(src[0]['vertno']) - n_rhverts = len(src[1]['vertno']) - n_verts = n_lhverts + n_rhverts - offsets = {'lh': 0, 'rh': n_lhverts} - - hemis = {'lh': 0, 'rh': 1} - - # index vector of which ROI a source point belongs to - which_roi = np.zeros(n_verts, dtype=np.int64) - - data = [] - for li, lab in enumerate(labels): - - this_data = np.round(label_sign_flip(lab, src)) - if not label_flip: - this_data.fill(1.) - data.append(this_data) - if isinstance(lab, mne.Label): - comp_labs = [lab] - elif isinstance(lab, mne.BiHemiLabel): - comp_labs = [lab.lh, lab.rh] - - for clab in comp_labs: - hemi = clab.hemi - hi = 0 if hemi == 'lh' else 1 - - lverts = clab.get_vertices_used(vertices=src[hi]['vertno']) - - # gets the indices in the source space vertex array, not the huge - # array. - # use `src[hi]['vertno'][lverts]` to get surface vertex indices to - # plot. 
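np.searchsorted below converts surface vertex numbers into row positions within the per-hemisphere source-space vertex array, which is sorted. A tiny sketch with made-up vertex numbers:

import numpy as np

vertno = np.array([3, 10, 42, 99, 120])  # vertices kept in src[hi]
label_verts = np.array([10, 99])         # vertices used by one label
idx = np.searchsorted(vertno, label_verts)
print(idx)          # [1 3], row indices into the source space
print(vertno[idx])  # [10 99], back to surface vertex numbers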
- lverts = np.searchsorted(src[hi]['vertno'], lverts) - lverts += offsets[hemi] - vertidx.extend(lverts) - roiidx.extend(np.full(lverts.size, li, dtype=np.int64)) - - # add 1 b/c 0 corresponds to unassigned variance - which_roi[lverts] = li + 1 - - N = len(labels) - M = n_verts - - # construct sparse fwd_src_roi matrix - data = np.concatenate(data) - vertidx = np.array(vertidx, int) - roiidx = np.array(roiidx, int) - assert data.shape == vertidx.shape == roiidx.shape - fwd_src_roi = csc_matrix((data, (vertidx, roiidx)), shape=(M, N)) - - self.fwd = fwd - self.fwd_src_roi = fwd_src_roi - self.which_roi = which_roi - self.offsets = offsets - self.n_lhverts = n_lhverts - self.n_rhverts = n_rhverts - self.labels = labels - - return - - @property - def fwd_src_sn(self): - return self.fwd['sol']['data'] - - @property - def fwd_src_roi(self): - return self._fwd_src_roi - - @fwd_src_roi.setter - def fwd_src_roi(self, val): - self._fwd_src_roi = val - - @property - def which_roi(self): - return self._which_roi - - @which_roi.setter - def which_roi(self, val): - self._which_roi = val - - @property - def fwd_roi_snsr(self): - from util import Carray - return Carray(csr_matrix.dot(self.fwd_src_roi.T, self.fwd_src_sn.T).T) - - def get_label_vinds(self, label): - li = self.labels.index(label) - if isinstance(label, mne.Label): - label_vert_idx = self.fwd_src_roi[:, li].nonzero()[0] - label_vert_idx -= self.offsets[label.hemi] - return label_vert_idx - elif isinstance(label, mne.BiHemiLabel): - # these labels store both hemispheres so subtract the rh offset - # from that part of the vertex array - lh_label_vert_idx = self.fwd_src_roi[:self.n_lhverts, li].nonzero()[0] - rh_label_vert_idx = self.fwd_src_roi[self.n_lhverts:, li].nonzero()[0] - rh_label_vert_idx[self.n_lhverts:] -= self.offsets['rh'] - return [lh_label_vert_idx, rh_label_vert_idx] - - def get_label_verts(self, label, src): - # if you're thinking of using this to plot, why not just use - # brain.add_label from pysurfer? 
- if isinstance(label, mne.Label): - hi = 0 if label.hemi == 'lh' else 1 - label_vert_idx = self.get_label_vinds(label) - varray = src[hi]['vertno'][label_vert_idx] - elif isinstance(label, mne.BiHemiLabel): - lh_label_vert_idx, rh_label_vert_idx = self.get_label_vinds(label) - varray = [src[0]['vertno'][lh_label_vert_idx], - src[1]['vertno'][rh_label_vert_idx]] - return varray - - def get_hemi_idx(self, label): - if isinstance(label, mne.Label): - return 0 if label.hemi == 'lh' else 1 - elif isinstance(label, mne.BiHemiLabel): - hemis = [None] * 2 - for i, lab in enumerate([label.lh, label.rh]): - hemis[i] = 0 if lab.hemi == 'lh' else 1 - return hemis - -def apply_projs(epochs, fwd, cov): - """ apply projection operators to fwd and cov """ - proj, _ = mne.io.proj.setup_proj(epochs.info, activate=False) - fwd_src_sn = fwd['sol']['data'] - fwd['sol']['data'] = np.dot(proj, fwd_src_sn) - - roi_cov = cov.data - if not np.allclose(np.dot(proj, roi_cov), roi_cov): - roi_cov = np.dot(proj, np.dot(roi_cov, proj.T)) - cov.data = roi_cov - - return fwd, cov - - -def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., mag_scale=1., - grad_scale=1.): - """ apply per-channel-type scaling to epochs, forward, and covariance """ - # # from util import Carray ##skip import just pasted; util also from MEGLDS repo - # Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C') - # Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C') - # Carray = Carray64 - - # get indices for each channel type - ch_names = cov['names'] # same as self.fwd['info']['ch_names'] - sel_eeg = pick_types(fwd['info'], meg=False, eeg=True, ref_meg=False) - sel_mag = pick_types(fwd['info'], meg='mag', eeg=False, ref_meg=False) - sel_grad = pick_types(fwd['info'], meg='grad', eeg=False, ref_meg=False) - idx_eeg = [ch_names.index(ch_names[c]) for c in sel_eeg] - idx_mag = [ch_names.index(ch_names[c]) for c in sel_mag] - idx_grad = [ch_names.index(ch_names[c]) for c in sel_grad] - - # retrieve forward and sensor covariance - fwd_src_snsr = fwd['sol']['data'].copy() - snsr_cov = cov.data.copy() - - # scale forward matrix - fwd_src_snsr[idx_eeg,:] *= eeg_scale - fwd_src_snsr[idx_mag,:] *= mag_scale - fwd_src_snsr[idx_grad,:] *= grad_scale - - # construct fwd_roi_snsr matrix - fwd_roi_snsr = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, fwd_src_snsr.T).T) - - # scale sensor covariance - snsr_cov[np.ix_(idx_eeg, idx_eeg)] *= eeg_scale**2 - snsr_cov[np.ix_(idx_mag, idx_mag)] *= mag_scale**2 - snsr_cov[np.ix_(idx_grad, idx_grad)] *= grad_scale**2 - - # scale epochs - info = epochs.info.copy() - data = epochs.get_data().copy() - - data[:,idx_eeg,:] *= eeg_scale - data[:,idx_mag,:] *= mag_scale - data[:,idx_grad,:] *= grad_scale - - epochs = mne.EpochsArray(data, info) - - return fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs - - -def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', - pctvar=0.99, mean_center=False, label_flip=False): - """ apply sensor scaling, PCA dimensionality reduction with/without - whitening, and mean-centering to subject data """ - - if dim_mode not in ['rank', 'pctvar', 'whiten']: - raise ValueError("dim_mode must be in {'rank', 'pctvar', 'whiten'}") - - print("running pca for subject %s" % subject_name) - - # compute ROI-to-source map - roi_to_src = ROIToSourceMap(fwd, labels, label_flip) - scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1} - if dim_mode == 'whiten': - - # scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1} - G, GL, 
Q_snsr, epochs = \ - _scale_sensor_data(epochs, fwd, cov, roi_to_src, **scales) - dat = epochs.get_data() - dat = Carray(np.swapaxes(dat, -1, -2)) - - if mean_center: - dat -= np.mean(dat, axis=1, keepdims=True) - - dat_stacked = np.reshape(dat, (-1, dat.shape[-1])) - - W, _ = mne.cov.compute_whitener(subject.sensor_cov, - info=subject.epochs_list[0].info, - pca=True) - print("whitener for subject %s using %d principal components" % - (subject_name, W.shape[0])) - - else: - - G, GL, Q_snsr, epochs = \ - _scale_sensor_data(epochs, fwd, cov, roi_to_src, **scales) #info.scales - dat = epochs.get_data() - dat = Carray(np.swapaxes(dat, -1, -2)) - - if mean_center: - dat -= np.mean(dat, axis=1, keepdims=True) - - dat_stacked = np.reshape(dat, (-1, dat.shape[-1])) - - pca = PCA() - pca.fit(dat_stacked) - - if dim_mode == 'rank': - idx = np.linalg.matrix_rank(np.cov(dat_stacked, rowvar=False)) - else: - idx = np.where(np.cumsum(pca.explained_variance_ratio_) > pctvar)[0][0] - - idx = np.maximum(idx, len(labels)) - W = pca.components_[:idx] - print("subject %s using %d principal components" % (subject_name, idx)) - - ntrials, T, _ = dat.shape - dat_pca = np.dot(dat_stacked, W.T) - dat_pca = np.reshape(dat_pca, (ntrials, T, -1)) - - G_pca = np.dot(W, G) - GL_pca = np.dot(W, GL) - Q_snsr_pca = np.dot(W,np.dot(Q_snsr, W.T)) #dot3(W, Q_snsr, W.T) - - data = dat_pca - - return data, GL_pca, G_pca, Q_snsr_pca, roi_to_src.which_roi \ No newline at end of file diff --git a/examples/megssm/models.py b/examples/megssm/models.py deleted file mode 100755 index 64d606bc..00000000 --- a/examples/megssm/models.py +++ /dev/null @@ -1,872 +0,0 @@ -import sys - -import autograd.numpy as np -import scipy.optimize as spopt - -from autograd import grad #autograd --> jax -from autograd import value_and_grad as vgrad -from scipy.linalg import LinAlgError - -from .util import _ensure_ndim, rand_stable, rand_psd -from .util import linesearch, soft_thresh_At, block_thresh_At -from .util import relnormdiff -from .message_passing import kalman_filter, rts_smooth, rts_smooth_fast -from .message_passing import predict_step, condition -from .numpy_numthreads import numpy_num_threads - -from .mne_util import ROIToSourceMap, _scale_sensor_data, run_pca_on_subject - -try: - from autograd_linalg import solve_triangular -except ImportError: - raise RuntimeError("must install `autograd_linalg` package") - -# einsum2 is a parallel version of einsum that works for two arguments -try: - from einsum2 import einsum2 -except ImportError: - # rename standard numpy function if don't have einsum2 - print("=> WARNING: using standard numpy.einsum,", - "consider installing einsum2 package") - from autograd.numpy import einsum as einsum2 - -from datetime import datetime - - -# TODO: add documentation to all methods -class _MEGModel(object): - """ Base class for any model applied to MEG data that handles storing and - unpacking data from tuples. 
""" - - def __init__(self): - self._subjectdata = None - self._timepts = 0 - self._ntrials_all = 0 - self._nsubjects = 0 - - def set_data(self, subjectdata): - timepts_lst = [self.unpack_subject_data(e)[0].shape[1] for e in subjectdata] - assert len(list(set(timepts_lst))) == 1 - self._timepts = timepts_lst[0] - ntrials_lst = [self.unpack_subject_data(e)[0].shape[0] for e in \ - subjectdata] - self._ntrials_all = np.sum(ntrials_lst) - self._nsubjects = len(subjectdata) - self._subjectdata = subjectdata - - def unpack_all_subject_data(self): - if self._subjectdata is None: - raise ValueError("use set_data to add subject data") - return map(self.unpack_subject_data, self._subjectdata) - - @classmethod - def unpack_subject_data(cls, sdata): - obs, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - Y = obs - w_s = 1. - if isinstance(obs, tuple): - if len(obs) == 2: - Y, w_s = obs - else: - raise ValueError("invalid format for subject data") - else: - Y = obs - w_s = 1. - - return Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi - - -# TODO: add documentation to all methods -# TODO: make some methods "private" (leading underscore) if necessary -class MEGLDS(_MEGModel): - """ State-space model for MEG data, as described in "A state-space model of - cross-region dynamic connectivity in MEG/EEG", Yang et al., NIPS 2016. - """ - - def __init__(self, num_roi, timepts, A=None, roi_cov=None, mu0=None, roi_cov_0=None, - log_sigsq_lst=None, lam0=0., lam1=0., penalty='ridge', - store_St=True): - - super().__init__() - - set_default = \ - lambda prm, val, deflt: \ - self.__setattr__(prm, val.copy() if val is not None else deflt) - - # initialize parameters - set_default("A", A, - np.stack([rand_stable(num_roi, maxew=0.7) for _ in range(timepts)], - axis=0)) - set_default("roi_cov", roi_cov, rand_psd(num_roi)) - set_default("mu0", mu0, np.zeros(num_roi)) - set_default("roi_cov_0", roi_cov_0, rand_psd(num_roi)) - set_default("log_sigsq_lst", log_sigsq_lst, - [np.log(np.random.gamma(2, 1, size=num_roi+1))]) - - self.lam0 = lam0 - self.lam1 = lam1 - - if penalty not in ('ridge', 'lasso', 'group-lasso'): - raise ValueError('penalty must be one of: ridge, lasso,' \ - + ' group-lasso') - self._penalty = penalty - - # initialize lists of smoothed estimates - self._mus_smooth_lst = None - self._sigmas_smooth_lst = None - self._sigmas_tnt_smooth_lst = None - self._loglik = None - self._store_St = bool(store_St) - - # initialize sufficient statistics - timepts, num_roi, _ = self.A.shape - self._B0 = np.zeros((num_roi, num_roi)) - self._B1 = np.zeros((timepts-1, num_roi, num_roi)) - self._B3 = np.zeros((timepts-1, num_roi, num_roi)) - self._B2 = np.zeros((timepts-1, num_roi, num_roi)) - self._B4 = list() - - self._subject_data = dict() - - def add_subject(self, subject, subject_dir, epochs,labels, fwd, cov): - roi_to_src = ROIToSourceMap(fwd, labels) # compute ROI-to-source map - scales = {'eeg_scale' : 1, 'mag_scale' : 1, 'grad_scale' : 1} #scale=1 has no effect - fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs = \ - _scale_sensor_data(epochs, fwd, cov, roi_to_src, **scales) - - # which_roi = roi_to_src.which_roi # array of len(sources); val = ROI of source - # data = epochs._data - # data = np.swapaxes(data,-1,-2) - # subjectdata = [(data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi)] - - sdata = run_pca_on_subject(subject,epochs, fwd, cov, labels) - data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - subjectdata = [(data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi)] - - 
self.set_data(subjectdata) - - # epochs, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = subjectdata - self._subject_data[subject] = dict() - self._subject_data[subject]['epochs'] = epochs - self._subject_data[subject]['fwd_src_snsr'] = fwd_src_snsr - self._subject_data[subject]['fwd_roi_snsr'] = fwd_roi_snsr - self._subject_data[subject]['snsr_cov'] = snsr_cov - self._subject_data[subject]['labels'] = labels - self._subject_data[subject]['which_roi'] = which_roi - - def set_data(self, subjectdata): - # add subject data, re-generate log_sigsq_lst if necessary - super().set_data(subjectdata) - if len(self.log_sigsq_lst) != self._nsubjects: - num_roi = self.log_sigsq_lst[0].shape[0] - self.log_sigsq_lst = [np.log(np.random.gamma(2, 1, size=num_roi)) - for _ in range(self._nsubjects)] - - # reset smoothed estimates and log-likelihood (no longer valid if - # new data was added) - self._mus_smooth_lst = None - self._sigmas_smooth_lst = None - self._sigmas_tnt_smooth_lst = None - self._loglik = None - self._B4 = [None] * self._nsubjects - - # TODO: figure out how to initialize smoothed parameters so this doesn't - # break, e.g. if em_objective is called before em for some reason - def em_objective(self): - - _, num_roi, _ = self.A.shape - - L_roi_cov_0 = np.linalg.cholesky(self.roi_cov_0) - L_roi_cov = np.linalg.cholesky(self.roi_cov) - - L1 = 0. - L2 = 0. - L3 = 0. - - obj = 0. - for s, sdata in enumerate(self.unpack_all_subject_data()): - - Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - - ntrials, timepts, _ = Y.shape - - sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) - L_R = np.linalg.cholesky(R) - - if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None - or self._sigmas_tnt_smooth_lst is None): - roi_cov_t = _ensure_ndim(self.roi_cov, timepts, 3) - with numpy_num_threads(1): - _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ - rts_smooth_fast(Y, self.A, fwd_roi_snsr, roi_cov_t, R, self.mu0, - self.roi_cov_0, compute_lag1_cov=True) - - else: - mus_smooth = self._mus_smooth_lst[s] - sigmas_smooth = self._sigmas_smooth_lst[s] - sigmas_tnt_smooth = self._sigmas_tnt_smooth_lst[s] - - x_smooth_0_outer = einsum2('ri,rj->rij', mus_smooth[:,0,:num_roi], - mus_smooth[:,0,:num_roi]) - B0 = w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + x_smooth_0_outer, - axis=0) - - x_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], - mus_smooth[:,1:,:num_roi]) - B1 = w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + x_smooth_outer, axis=0) - - z_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,:-1,:], - mus_smooth[:,:-1,:]) - B3 = w_s*np.sum(sigmas_smooth[:,:-1,:,:] + z_smooth_outer, axis=0) - - mus_smooth_outer_l1 = einsum2('rti,rtj->rtij', - mus_smooth[:,1:,:num_roi], - mus_smooth[:,:-1,:]) - B2 = w_s*np.sum(sigmas_tnt_smooth[:,:,:num_roi,:] + mus_smooth_outer_l1, - axis=0) - - # obj += L1(roi_cov_0) - L_roi_cov_0_inv_B0 = solve_triangular(L_roi_cov_0, B0, lower=True) - L1 += (ntrials*2.*np.sum(np.log(np.diag(L_roi_cov_0))) - + np.trace(solve_triangular(L_roi_cov_0, L_roi_cov_0_inv_B0, lower=True, - trans='T'))) - - At = self.A[:-1] - AtB2T = einsum2('tik,tjk->tij', At, B2) - B2AtT = einsum2('tik,tjk->tij', B2, At) - tmp = einsum2('tik,tkl->til', At, B3) - AtB3AtT = einsum2('tik,tjk->tij', tmp, At) - - tmp = np.sum(B1 - AtB2T - B2AtT + AtB3AtT, axis=0) - - # obj += L2(roi_cov, At) - L_roi_cov_inv_tmp = solve_triangular(L_roi_cov, tmp, lower=True) - L2 += 
(ntrials*(timepts-1)*2.*np.sum(np.log(np.diag(L_roi_cov))) - + np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_tmp, lower=True, - trans='T'))) - - res = Y - einsum2('ik,ntk->nti', fwd_roi_snsr, mus_smooth[:,:,:num_roi]) - CP_smooth = einsum2('ik,ntkj->ntij', fwd_roi_snsr, sigmas_smooth[:,:,:num_roi,:num_roi]) - - # TODO: np.sum does not parallelize over the accumulators, possible - # bottleneck. - B4 = w_s*(np.sum(einsum2('nti,ntj->ntij', res, res), axis=(0,1)) - + np.sum(einsum2('ntik,jk->ntij', CP_smooth, fwd_roi_snsr), - axis=(0,1))) - self._B4[s] = B4 - - # obj += L3(sigsq_vals) - L_R_inv_B4 = solve_triangular(L_R, B4, lower=True) - L3 += (ntrials*timepts*2*np.sum(np.log(np.diag(L_R))) - + np.trace(solve_triangular(L_R, L_R_inv_B4, lower=True, - trans='T'))) - - obj = (L1 + L2 + L3) / self._ntrials_all - - # obj += penalty - if self.lam0 > 0.: - if self._penalty == 'ridge': - obj += self.lam0*np.sum(At**2) - elif self._penalty == 'lasso': - At_diag = np.diagonal(At, axis1=-2, axis2=-1) - sum_At_diag = np.sum(np.abs(At_diag)) - obj += self.lam0*(np.sum(np.abs(At)) - sum_At_diag) - elif self._penalty == 'group-lasso': - At_diag = np.diagonal(At, axis1=-2, axis2=-1) - norm_At_diag = np.sum(np.linalg.norm(At_diag, axis=0)) - norm_At = np.sum(np.linalg.norm(At, axis=0)) - obj += self.lam1*(norm_At - norm_At_diag) - if self.lam1 > 0.: - AtmAtm1_2 = (At[1:] - At[:-1])**2 - obj += self.lam1*np.sum(AtmAtm1_2) - - return obj - - def fit(self, niter=100, tol=1e-6, A_roi_cov_niter=100, A_roi_cov_tol=1e-6, verbose=0, - update_A=True, update_roi_cov=True, update_roi_cov_0=True, stationary_A=False, - diag_roi_cov=False, update_sigsq=True, do_final_smoothing=True, - average_mus_smooth=True, Atrue=None, tau=0.1, c1=1e-4): - - fxn_start = datetime.now() - - timepts, num_roi, _ = self.A.shape - - # make initial A stationary if stationary_A option specified - if stationary_A: - self.A[:] = np.mean(self.A, axis=0) - - # set parameters for (A, roi_cov) optimization - self._A_roi_cov_niter = A_roi_cov_niter - self._A_roi_cov_tol = A_roi_cov_tol - - # make initial roi_cov, roi_cov_0 diagonal if diag_roi_cov specified - if diag_roi_cov: - self.roi_cov_0 = np.diag(np.diag(self.roi_cov_0)) - self.roi_cov = np.diag(np.diag(self.roi_cov)) - - - # keeping track of objective value and best parameters - objvals = np.zeros(niter+1) - converged = False - best_objval = np.finfo('float').max - best_params = (self.A.copy(), self.roi_cov.copy(), self.mu0.copy(), - self.roi_cov_0.copy(), [l.copy() for l in self.log_sigsq_lst]) - - # previous parameter values (for checking convergence) - At_prev = None - roi_cov_prev = None - roi_cov_0_prev = None - log_sigsq_lst_prev = None - - if Atrue is not None: - import matplotlib.pyplot as plt - fig_A, ax_A = plt.subplots(num_roi, num_roi, sharex=True, sharey=True) - plt.ion() - - # calculate initial objective value, check for updated best iterate - # have to do e-step here to initialize suff stats for m_step - if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None - or self._sigmas_tnt_smooth_lst is None): - self.e_step(verbose=verbose-1) - - objval = self.em_objective() - objvals[0] = objval - - for it in range(1, niter+1): - - iter_start = datetime.now() - - if verbose > 0: - print("em: it %d / %d" % (it, niter)) - sys.stdout.flush() - sys.stderr.flush() - - # record values from previous M-step - At_prev = self.A[:-1].copy() - roi_cov_prev = self.roi_cov.copy() - roi_cov_0_prev = self.roi_cov_0.copy() - log_sigsq_lst_prev = np.array(self.log_sigsq_lst).copy() - - 
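# The previous-iterate copies above feed the convergence test further down,
# which uses relative norm differences (util.relnormdiff). A small numeric
# illustration of that criterion, with an assumed tolerance:
#
#     import numpy as np
#     A_new = np.array([1.00, 2.00])
#     A_old = np.array([1.01, 2.02])
#     rel = np.linalg.norm(A_new - A_old) / np.maximum(np.linalg.norm(A_new), 1e-9)
#     print(rel <= 1e-6)   # rel is ~1e-2 here, so not yet converged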
self.m_step(update_A=update_A, update_roi_cov=update_roi_cov, - update_roi_cov_0=update_roi_cov_0, stationary_A=stationary_A, - diag_roi_cov=diag_roi_cov, update_sigsq=update_sigsq, - tau=tau, c1=c1, verbose=verbose) - - if Atrue is not None: - for i in range(num_roi): - for j in range(num_roi): - ax_A[i, j].cla() - ax_A[i, j].plot(Atrue[:-1, i, j], color='green') - ax_A[i, j].plot(self.A[:-1, i, j], color='red', - alpha=0.7) - fig_A.tight_layout() - fig_A.canvas.draw() - plt.pause(1. / 60.) - - self.e_step(verbose=verbose-1) - - # calculate objective value, check for updated best iterate - objval = self.em_objective() - objvals[it] = objval - - if verbose > 0: - print(" objective: %.4e" % objval) - At = self.A[:-1] - maxAt = np.max(np.abs(np.triu(At, k=1) + np.tril(At, k=-1))) - print(" max |A_t|: %.4e" % (maxAt,)) - sys.stdout.flush() - sys.stderr.flush() - - if objval < best_objval: - best_objval = objval - best_params = (self.A.copy(), self.roi_cov.copy(), self.mu0.copy(), - self.roi_cov_0.copy(), - [l.copy() for l in self.log_sigsq_lst]) - - # check for convergence - if it >= 1: - relnormdiff_At = relnormdiff(self.A[:-1], At_prev) - relnormdiff_roi_cov = relnormdiff(self.roi_cov, roi_cov_prev) - relnormdiff_roi_cov_0 = relnormdiff(self.roi_cov_0, roi_cov_0_prev) - relnormdiff_log_sigsq_lst = \ - np.array( - [relnormdiff(self.log_sigsq_lst[s], - log_sigsq_lst_prev[s]) - for s in range(len(self.log_sigsq_lst))]) - params_converged = (relnormdiff_At <= tol) and \ - (relnormdiff_roi_cov <= tol) and \ - (relnormdiff_roi_cov_0 <= tol) and \ - np.all(relnormdiff_log_sigsq_lst <= tol) - - relobjdiff = np.abs((objval - objvals[it-1]) / objval) - - if verbose > 0: - print(" relnormdiff_At: %.3e" % relnormdiff_At) - print(" relnormdiff_roi_cov: %.3e" % relnormdiff_roi_cov) - print(" relnormdiff_roi_cov_0: %.3e" % relnormdiff_roi_cov_0) - print(" relnormdiff_log_sigsq_lst:", - relnormdiff_log_sigsq_lst) - print(" relobjdiff: %.3e" % relobjdiff) - - objdiff = objval - objvals[it-1] - if objdiff > 0: - print(" \033[0;31mEM objective increased\033[0m") - - sys.stdout.flush() - sys.stderr.flush() - - if params_converged or relobjdiff <= tol: - if verbose > 0: - print("EM objective converged") - sys.stdout.flush() - sys.stderr.flush() - converged = True - objvals = objvals[:it+1] - break - - # retrieve best parameters and load into instance variables. - A, roi_cov, mu0, roi_cov_0, log_sigsq_lst = best_params - self.A = A.copy() - self.roi_cov = roi_cov.copy() - self.mu0 = mu0.copy() - self.roi_cov_0 = roi_cov_0.copy() - self.log_sigsq_lst = [l.copy() for l in log_sigsq_lst] - - if verbose > 0: - print() - print("elapsed, iteration:", datetime.now() - iter_start) - print("=" * 34) - print() - - # perform final smoothing - mus_smooth_lst = None - St_lst = None - if do_final_smoothing: - if verbose >= 1: - print("performing final smoothing") - - mus_smooth_lst = list() - self._loglik = 0. 
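# In the smoothing loop below, the observation noise R is rebuilt per subject
# via MEGLDS.R_: the sensor covariance plus per-source variance pushed through
# the source-level forward matrix. The same computation in plain NumPy (toy
# shapes and values, assumed for illustration):
#
#     import numpy as np
#     n_chan, n_src, n_roi = 60, 500, 4
#     G = np.random.randn(n_chan, n_src)                # fwd_src_snsr
#     sigsq_vals = np.ones(n_roi + 1)                   # exp(log_sigsq); index 0 = unassigned
#     which_roi = np.random.randint(0, n_roi + 1, n_src)
#     snsr_cov = np.eye(n_chan)
#     R = snsr_cov + G @ (sigsq_vals[which_roi][:, None] * G.T)   # (n_chan, n_chan)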
- if self._store_St: - St_lst = list() - for s, sdata in enumerate(self.unpack_all_subject_data()): - Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) - roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) - with numpy_num_threads(1): - loglik_subject, mus_smooth, _, _, St = \ - rts_smooth(Y, self.A, fwd_roi_snsr, roi_cov_t, R, self.mu0, self.roi_cov_0, - compute_lag1_cov=False, - store_St=self._store_St) - # just save the mean of the smoothed trials - if average_mus_smooth: - mus_smooth_lst.append(np.mean(mus_smooth, axis=0)) - else: - mus_smooth_lst.append(mus_smooth) - self._loglik += loglik_subject - # just save the diagonals of St b/c that's what we need for - # connectivity - if self._store_St: - St_lst.append(np.diagonal(St, axis1=-2, axis2=-1)) - - if verbose > 0: - print() - print("elapsed, function:", datetime.now() - fxn_start) - print("=" * 34) - print() - - return objvals, converged, mus_smooth_lst, self._loglik, St_lst - - def e_step(self, verbose=0): - - timepts, num_roi, _ = self.A.shape - - # reset accumulation arrays - self._B0[:] = 0. - self._B1[:] = 0. - self._B3[:] = 0. - self._B2[:] = 0. - - self._mus_smooth_lst = list() - self._sigmas_smooth_lst = list() - self._sigmas_tnt_smooth_lst = list() - - if verbose > 0: - print(" e-step") - print(" subject", end="") - - for s, sdata in enumerate(self.unpack_all_subject_data()): - - if verbose > 0: - print(" %d" % (s+1,), end="") - sys.stdout.flush() - sys.stderr.flush() - - Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - - sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) - L_R = np.linalg.cholesky(R) - roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) - - with numpy_num_threads(1): - _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ - rts_smooth_fast(Y, self.A, fwd_roi_snsr, roi_cov_t, R, self.mu0, - self.roi_cov_0, compute_lag1_cov=True) - - self._mus_smooth_lst.append(mus_smooth) - self._sigmas_smooth_lst.append(sigmas_smooth) - self._sigmas_tnt_smooth_lst.append(sigmas_tnt_smooth) - - x_smooth_0_outer = einsum2('ri,rj->rij', mus_smooth[:,0,:num_roi], - mus_smooth[:,0,:num_roi]) - self._B0 += w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + x_smooth_0_outer, - axis=0) - - x_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], - mus_smooth[:,1:,:num_roi]) - self._B1 += w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + x_smooth_outer, - axis=0) - - z_smooth_outer = einsum2('rti,rtj->rtij', mus_smooth[:,:-1,:], - mus_smooth[:,:-1,:]) - self._B3 += w_s*np.sum(sigmas_smooth[:,:-1,:,:] + z_smooth_outer, - axis=0) - - mus_smooth_outer_l1 = einsum2('rti,rtj->rtij', - mus_smooth[:,1:,:num_roi], - mus_smooth[:,:-1,:]) - self._B2 += w_s*np.sum(sigmas_tnt_smooth[:,:,:num_roi,:] + - mus_smooth_outer_l1, axis=0) - - if verbose > 0: - print("\n done") - - def m_step(self, update_A=True, update_roi_cov=True, update_roi_cov_0=True, - stationary_A=False, diag_roi_cov=False, update_sigsq=True, tau=0.1, c1=1e-4, - verbose=0): - self._loglik = None - if verbose > 0: - print(" m-step") - if update_roi_cov_0: - self.roi_cov_0 = (1. 
/ self._ntrials_all) * self._B0 - if diag_roi_cov: - self.roi_cov_0 = np.diag(np.diag(self.roi_cov_0)) - self.update_A_and_roi_cov(update_A=update_A, update_roi_cov=update_roi_cov, - stationary_A=stationary_A, diag_roi_cov=diag_roi_cov, - tau=tau, c1=c1, verbose=verbose) - if update_sigsq: - self.update_log_sigsq_lst(verbose=verbose) - - def update_A_and_roi_cov(self, update_A=True, update_roi_cov=True, stationary_A=False, - diag_roi_cov=False, tau=0.1, c1=1e-4, verbose=0): - - if verbose > 1: - print(" update A and roi_cov") - - # gradient descent - At = self.A[:-1] - At_init = At.copy() - L_roi_cov = np.linalg.cholesky(self.roi_cov) - At_L_roi_cov_obj = lambda x, y: self.L2_obj(x, y) - At_obj = lambda x: self.L2_obj(x, L_roi_cov) - grad_At_obj = grad(At_obj) - obj_diff = np.finfo('float').max - obj = At_L_roi_cov_obj(At, L_roi_cov) - inner_it = 0 - - # specify proximal operator to use - if self._penalty == 'ridge': - prox_op = lambda x, y: x - elif self._penalty == 'lasso': - prox_op = soft_thresh_At - elif self._penalty == 'group-lasso': - prox_op = block_thresh_At - - while np.abs(obj_diff / obj) > self._A_roi_cov_tol: - - if inner_it > self._A_roi_cov_niter: - break - - obj_start = At_L_roi_cov_obj(At, L_roi_cov) - - # update At using gradient descent with backtracking line search - if update_A: - if stationary_A: - B2_sum = np.sum(self._B2, axis=0) - B3_sum = np.sum(self._B3, axis=0) - At[:] = np.linalg.solve(B3_sum.T, B2_sum.T).T - else: - grad_At = grad_At_obj(At) - step_size = linesearch(At_obj, grad_At_obj, At, grad_At, - prox_op=prox_op, lam=self.lam0, - tau=tau, c1=c1) - At[:] = prox_op(At - step_size * grad_At, - self.lam0 * step_size) - - # update roi_cov using closed form - if update_roi_cov: - AtB2T = einsum2('tik,tjk->tij', At, self._B2) - B2AtT = einsum2('tik,tjk->tij', self._B2, At) - tmp = einsum2('tik,tkl->til', At, self._B3) - AtB3AtT = einsum2('til,tjl->tij', tmp, At) - elbo_2 = np.sum(self._B1 - AtB2T - B2AtT + AtB3AtT, axis=0) - self.roi_cov = (1. 
/ (self._ntrials_all * self._timepts - )) * elbo_2 - if diag_roi_cov: - self.roi_cov = np.diag(np.diag(self.roi_cov)) - L_roi_cov = np.linalg.cholesky(self.roi_cov) - - obj = At_L_roi_cov_obj(At, L_roi_cov) - obj_diff = obj_start - obj - inner_it += 1 - - if verbose > 1: - if not stationary_A and update_A: - grad_norm = np.linalg.norm(grad_At) - norm_change = np.linalg.norm(At - At_init) - print(" last step size: %.3e" % step_size) - print(" last gradient norm: %.3e" % grad_norm) - print(" norm of total change: %.3e" % norm_change) - print(" number of iterations: %d" % inner_it) - print(" done") - - def update_log_sigsq_lst(self, verbose=0): - - if verbose > 1: - print(" update subject log-sigmasq") - - timepts, num_roi, _ = self.A.shape - - # update log_sigsq_vals for each subject and ROI - for s, sdata in enumerate(self.unpack_all_subject_data()): - - Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - ntrials, timepts, _ = Y.shape - mus_smooth = self._mus_smooth_lst[s] - sigmas_smooth = self._sigmas_smooth_lst[s] - B4 = self._B4[s] - - log_sigsq = self.log_sigsq_lst[s].copy() - log_sigsq_obj = lambda x: \ - MEGLDS.L3_obj(x, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, timepts) - log_sigsq_val_and_grad = vgrad(log_sigsq_obj) - - options = {'maxiter': 500} - opt_res = spopt.minimize(log_sigsq_val_and_grad, log_sigsq, - method='L-BFGS-B', jac=True, - options=options) - if verbose > 1: - print(" subject %d - %d iterations" % (s+1, opt_res.nit)) - - if not opt_res.success: - print(" log_sigsq opt") - print(" %s" % opt_res.message) - - self.log_sigsq_lst[s] = opt_res.x - - if verbose > 1: - print("\n done") - - def calculate_smoothed_estimates(self): - """ recalculate smoothed estimates with current model parameters """ - - self._mus_smooth_lst = list() - self._sigmas_smooth_lst = list() - self._sigmas_tnt_smooth_lst = list() - self._St_lst = list() - self._loglik = 0. - - for s, sdata in enumerate(self.unpack_all_subject_data()): - Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) - roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) - with numpy_num_threads(1): - ll, mus_smooth, sigmas_smooth, sigmas_tnt_smooth, _ = \ - rts_smooth(Y, self.A, fwd_roi_snsr, roi_cov_t, R, self.mu0, self.roi_cov_0, - compute_lag1_cov=True, store_St=False) - self._mus_smooth_lst.append(mus_smooth) - self._sigmas_smooth_lst.append(sigmas_smooth) - self._sigmas_tnt_smooth_lst.append(sigmas_tnt_smooth) - #self._St_lst.append(np.diagonal(St, axis1=-2, axis2=-1)) - self._loglik += ll - - def log_likelihood(self): - """ calculate log marginal likelihood using the Kalman filter """ - - #if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None \ - # or self._sigmas_tnt_smooth_lst is None): - # self.calculate_smoothed_estimates() - # return self._loglik - if self._loglik is not None: - return self._loglik - - self._loglik = 0. 
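# The marginal likelihood accumulated below comes from the standard Kalman
# prediction-error decomposition. A self-contained scalar sketch of the same
# quantity (toy numbers, not the MEG model):
#
#     import numpy as np
#     a, c, q, r = 0.9, 1.0, 0.1, 0.2        # dynamics, obs map, state/obs noise
#     mu, p = 0.0, 1.0                       # prior mean and variance
#     y = np.array([0.5, 0.3, -0.1])
#     ll = 0.0
#     for yt in y:
#         s = c * p * c + r                  # innovation variance
#         ll += -0.5 * (np.log(2 * np.pi * s) + (yt - c * mu) ** 2 / s)
#         k = p * c / s                      # Kalman gain
#         mu, p = mu + k * (yt - c * mu), (1 - k * c) * p
#         mu, p = a * mu, a * p * a + q      # predict next step
#     print(ll)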
- for s, sdata in enumerate(self.unpack_all_subject_data()): - Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) - roi_cov_t = _ensure_ndim(self.roi_cov, self._timepts, 3) - ll, _, _, _ = kalman_filter(Y, self.A, fwd_roi_snsr, roi_cov_t, R, self.mu0, - self.roi_cov_0, store_St=False) - self._loglik += ll - - return self._loglik - - def nparams(self): - timepts, p, _ = self.A.shape - - # this should equal (timepts-1)*p*p unless some shrinkage is used on At - nparams_At = np.sum(np.abs(self.A[:-1]) > 0) - - # nparams = nparams(At) + nparams(roi_cov) + nparams(roi_cov_0) - # + nparams(log_sigsq_lst) - return nparams_At + p*(p+1)/2 + p*(p+1)/2 \ - + np.sum([p+1 for _ in range(len(self.log_sigsq_lst))]) - - def AIC(self): - return -2*self.log_likelihood() + 2*self.nparams() - - def BIC(self): - if self._ntrials_all == 0: - raise RuntimeError("use set_data to add subject data before" \ - + " computing BIC") - return -2*self.log_likelihood() \ - + np.log(self._ntrials_all)*self.nparams() - - def save(self, filename, **kwargs): - savedict = { 'A' : self.A, 'roi_cov' : self.roi_cov, 'mu0' : self.mu0, - 'roi_cov_0' : self.roi_cov_0, 'log_sigsq_lst' : self.log_sigsq_lst, - 'lam0' : self.lam0, 'lam1' : self.lam1} - savedict.update(kwargs) - np.savez_compressed(filename, **savedict) - - def load(self, filename): - loaddict = np.load(filename) - param_names = ['A', 'roi_cov', 'mu0', 'roi_cov_0', 'log_sigsq_lst', 'lam0', 'lam1'] - for name in param_names: - if name not in loaddict.keys(): - raise RuntimeError('specified file is not a saved model:\n%s' - % (filename,)) - for name in param_names: - if name == 'log_sigsq_lst': - self.log_sigsq_lst = [l.copy() for l in loaddict[name]] - elif name in ('lam0', 'lam1'): - self.__setattr__(name, float(loaddict[name])) - else: - self.__setattr__(name, loaddict[name].copy()) - - # return remaining saved items, if there are any - others = {key : loaddict[key] for key in loaddict.keys() \ - if key not in param_names} - if len(others.keys()) > 0: - return others - - @staticmethod - def R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi): - return snsr_cov + np.dot(fwd_src_snsr, sigsq_vals[which_roi][:,None]*fwd_src_snsr.T) - - def L2_obj(self, At, L_roi_cov): - - # import autograd.numpy - # if isinstance(At,autograd.numpy.numpy_boxes.ArrayBox): - # At = At._value - - AtB2T = einsum2('tik,tjk->tij', At, self._B2) - B2AtT = einsum2('tik,tjk->tij', self._B2, At) - tmp = einsum2('tik,tkl->til', At, self._B3) - AtB3AtT = einsum2('til,tjl->tij', tmp, At) - elbo_2 = np.sum(self._B1 - AtB2T - B2AtT + AtB3AtT, axis=0) - - L_roi_cov_inv_elbo_2 = solve_triangular(L_roi_cov, elbo_2, lower=True) - obj = np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_elbo_2, lower=True, - trans='T')) - obj = obj / self._ntrials_all - - if self._penalty == 'ridge': - obj += self.lam0*np.sum(At**2) - AtmAtm1_2 = (At[1:] - At[:-1])**2 - obj += self.lam1*np.sum(AtmAtm1_2) - - return obj - - # TODO: convert to instance method - @staticmethod - def L3_obj(log_sigsq_vals, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, timepts): - R = MEGLDS.R_(snsr_cov, fwd_src_snsr, np.exp(log_sigsq_vals), which_roi) - try: - L_R = np.linalg.cholesky(R) - except LinAlgError: - return np.finfo('float').max - L_R_inv_B4 = solve_triangular(L_R, B4, lower=True) - return (ntrials*timepts*2.*np.sum(np.log(np.diag(L_R))) - + np.trace(solve_triangular(L_R, L_R_inv_B4, lower=True, - trans='T'))) - - - 
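The objective terms assembled above (L1, L2, L3) all evaluate a Gaussian
penalty of the form log|Sigma| + tr(Sigma^{-1} B) through a Cholesky factor
rather than an explicit inverse. With Sigma = L L^T,

    \log\det\Sigma = 2\sum_i \log L_{ii},
    \qquad
    \operatorname{tr}(\Sigma^{-1}B) = \operatorname{tr}\!\left(L^{-\top}(L^{-1}B)\right),

which is exactly the pattern `2*np.sum(np.log(np.diag(L_R)))` plus a pair of
triangular solves (trans='T' for the second) used in em_objective and L3_obj;
this is a restatement of what the code computes, in standard notation.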
@property - def A(self): - return self._A - - @A.setter - def A(self, A): - self._A = A - - @property - def roi_cov(self): - return self._roi_cov - - @roi_cov.setter - def roi_cov(self, roi_cov): - self._roi_cov = roi_cov - - @property - def mu0(self): - return self._mu0 - - @mu0.setter - def mu0(self, mu0): - self._mu0 = mu0 - - @property - def roi_cov_0(self): - return self._roi_cov_0 - - @roi_cov_0.setter - def roi_cov_0(self, roi_cov_0): - self._roi_cov_0 = roi_cov_0 - - @property - def log_sigsq_lst(self): - return self._log_sigsq_lst - - @log_sigsq_lst.setter - def log_sigsq_lst(self, log_sigsq_lst): - self._log_sigsq_lst = log_sigsq_lst - - @property - def num_roi(self): - return self.A.shape[1] - - @property - def timepts(self): - return self._timepts - - @property - def lam0(self): - return self._lam0 - - @lam0.setter - def lam0(self, lam0): - self._lam0 = lam0 - - @property - def lam1(self): - return self._lam1 - - @lam1.setter - def lam1(self, lam1): - self._lam1 = lam1 diff --git a/examples/megssm/numpy_numthreads.py b/examples/megssm/numpy_numthreads.py deleted file mode 100755 index 550aa235..00000000 --- a/examples/megssm/numpy_numthreads.py +++ /dev/null @@ -1,91 +0,0 @@ -import contextlib -import ctypes -from ctypes.util import find_library - -# heavily based on: -# https://stackoverflow.com/questions/29559338/set-max-number-of-threads-at-runtime-on-numpy-openblas - -# Prioritize hand-compiled OpenBLAS library over version in /usr/lib/ -# from Ubuntu repos -try_paths = [find_library('openblas')] -openblas_lib = None -for libpath in try_paths: - try: - openblas_lib = ctypes.cdll.LoadLibrary(libpath) - break - except Exception: #OSError: - continue -#if openblas_lib is None: - #raise EnvironmentError('Could not locate an OpenBLAS shared library', 2) - -try: - mkl_rt_path = find_library('mkl_rt') - mkl_rt = ctypes.cdll.LoadLibrary(mkl_rt_path) - # print(mkl_rt) -except OSError: - mkl_rt = None - pass - - -def set_num_threads(n): - """Set the current number of threads used by the OpenBLAS server.""" - if mkl_rt: - pass - #mkl_rt.mkl_set_num_threads(ctypes.byref(ctypes.c_int(n))) - elif openblas_lib: - openblas_lib.openblas_set_num_threads(int(n)) - - -# At the time of writing these symbols were very new: -# https://github.com/xianyi/OpenBLAS/commit/65a847c -try: - if mkl_rt: #False: #mkl_rt: - def get_num_threads(): - return mkl_rt.mkl_get_max_threads() - elif openblas_lib: - # do this to throw exception if it doesn't exist - openblas_lib.openblas_get_num_threads() - def get_num_threads(): - """Get the current number of threads used by the OpenBLAS server.""" - return openblas_lib.openblas_get_num_threads() -except AttributeError: - def get_num_threads(): - """Dummy function (symbol not present in %s), returns -1.""" - return -1 - pass - -try: - if False: #mkl_rt: - def get_num_procs(): - # this returns number of procs - return mkl_rt.mkl_get_max_threads() - elif openblas_lib: - # do this to throw exception if it doesn't exist - openblas_lib.openblas_get_num_procs() - def get_num_procs(): - """Get the total number of physical processors""" - return openblas_lib.openblas_get_num_procs() -except AttributeError: - def get_num_procs(): - """Dummy function (symbol not present), returns -1.""" - return -1 - pass - - -@contextlib.contextmanager -def numpy_num_threads(n): - """Temporarily changes the number of OpenBLAS threads. 
- - Example usage: - - print("Before: {}".format(get_num_threads())) - with num_threads(n): - print("In thread context: {}".format(get_num_threads())) - print("After: {}".format(get_num_threads())) - """ - old_n = get_num_threads() - set_num_threads(n) - try: - yield - finally: - set_num_threads(old_n) diff --git a/examples/megssm/plotting.py b/examples/megssm/plotting.py deleted file mode 100644 index 4e435de6..00000000 --- a/examples/megssm/plotting.py +++ /dev/null @@ -1,107 +0,0 @@ -""" plotting functions """ - -import numpy as np -import matplotlib.pyplot as plt - -def plot_At(A, ci='sd', times=None, ax=None, skipdiag=False, labels=None, - showticks=True, **kwargs): - """ plot traces of each entry of dynamics A in square grid of subplots """ - if A.ndim == 3: - T, d, _ = A.shape - elif A.ndim == 4: - _, T, d, _ = A.shape - - if times is None: - times = np.arange(T) - - if ax is None or ax.shape != (d, d): - fig, ax = plt.subplots(d, d, sharex=True, sharey=True, squeeze=False) - else: - fig = ax[0, 0].figure - - for i in range(d): - for j in range(d): - - # skip and hide subplots on diagonal - if skipdiag and i == j: - ax[i, j].set_visible(False) - continue - - # plot A entry as trace with/without error band - if A.ndim == 3: - ax[i, j].plot(times[:-1], A[:-1, i, j], **kwargs) - elif A.ndim == 4: - plot_fill(A[:, :-1, i, j], ci=ci, times=times[:-1], - ax=ax[i, j], **kwargs) - - # add labels above first row and to the left of the first column - if labels is not None: - if i == 0 or (skipdiag and (i, j) == (1, 0)): - ax[i, j].set_title(labels[j], fontsize=12) - if j == 0 or (skipdiag and (i, j) == (0, 1)): - ax[i, j].set_ylabel(labels[i], fontsize=12) - - # remove x- and y-ticks on subplot - if not showticks: - ax[i, j].set_xticks([]) - ax[i, j].set_yticks([]) - - diag_lims = [0, 1] - off_lims = [-0.25, 0.25] - for ri, row in enumerate(ax): - for ci, a in enumerate(row): - ylim = diag_lims if ri == ci else off_lims - a.set(ylim=ylim, xlim=times[[0, -1]]) - if ri == 0: - a.set_title(a.get_title(), fontsize='small') - if ci == 0: - a.set_ylabel(a.get_ylabel(), fontsize='small') - for line in a.lines: - line.set_clip_on(False) - line.set(lw=1.) - if ci != 0: - a.yaxis.set_major_formatter(plt.NullFormatter()) - if ri != len(labels) - 1: - a.xaxis.set_major_formatter(plt.NullFormatter()) - if ri == ci: - for spine in a.spines.values(): - spine.set(lw=2) - else: - a.axhline(0, color='k', ls=':', lw=1.) 
- - return fig, ax - -def plot_fill(X, times=None, ax=None, ci='sd', **kwargs): - """ plot mean and error band across first axis of X """ - N, T = X.shape - - if times is None: - times = np.arange(T) - if ax is None: - fig, ax = plt.subplots(1, 1) - - mu = np.mean(X, axis=0) - - # define lower and upper band limits based on ci - if ci == 'sd': # standard deviation - sigma = np.std(X, axis=0) - lower, upper = mu - sigma, mu + sigma - elif ci == 'se': # standard error - stderr = np.std(X, axis=0) / np.sqrt(X.shape[0]) - lower, upper = mu - stderr, mu + stderr - elif ci == '2sd': # 2 standard deviations - sigma = np.std(X, axis=0) - lower, upper = mu - 2 * sigma, mu + 2 * sigma - elif ci == 'max': # range (min to max) - lower, upper = np.min(X, axis=0), np.max(X, axis=0) - elif type(ci) is float and 0 < ci < 1: - # quantile-based confidence interval - a = 1 - ci - lower, upper = np.quantile(X, [a / 2, 1 - a / 2], axis=0) - else: - raise ValueError("ci must be in ('sd', 'se', '2sd', 'max') " - "or float in (0, 1)") - - lines = ax.plot(times, mu, **kwargs) - c = lines[0].get_color() - ax.fill_between(times, lower, upper, color=c, alpha=0.3, lw=0) diff --git a/examples/megssm/util.py b/examples/megssm/util.py deleted file mode 100755 index 63898be3..00000000 --- a/examples/megssm/util.py +++ /dev/null @@ -1,117 +0,0 @@ -from __future__ import division -from __future__ import print_function -from __future__ import absolute_import - -import autograd.numpy as np -from numpy.lib.stride_tricks import as_strided as ast - - -hs = lambda *args: np.concatenate(*args, axis=-1) - -def T_(X): - return np.swapaxes(X, -1, -2) - -def sym(X): - return 0.5*(X + T_(X)) - -def dot3(A, B, C): - return np.dot(A, np.dot(B, C)) - -def relnormdiff(A, B, min_denom=1e-9): - return np.linalg.norm(A - B) / np.maximum(np.linalg.norm(A), min_denom) - -def _ensure_ndim(X, T, ndim): - X = np.require(X, dtype=np.float64, requirements='C') - assert ndim-1 <= X.ndim <= ndim - if X.ndim == ndim: - assert X.shape[0] == T - return X - else: - return ast(X, shape=(T,) + X.shape, strides=(0,) + X.strides) - -def rand_psd(n, minew=0.1, maxew=1.): - # maxew is badly named - if n == 1: - return maxew * np.eye(n) - X = np.random.randn(n,n) - S = np.dot(T_(X), X) - S = sym(S) - ew, ev = np.linalg.eigh(S) - ew -= np.min(ew) - ew /= np.max(ew) - ew *= (maxew - minew) - ew += minew - return dot3(ev, np.diag(ew), T_(ev)) - -def rand_stable(n, maxew=0.9): - A = np.random.randn(n, n) - A *= maxew / np.max(np.abs(np.linalg.eigvals(A))) - return A - -def component_matrix(As, nlags): - """ compute component form of latent VAR process - - [A_1 A_2 ... A_p] - [ I 0 ... 0 ] - [ 0 I 0 0 ] - [ 0 ... 
I 0 ] - - """ - - d = As.shape[0] - res = np.zeros((d*nlags, d*nlags)) - res[:d] = As - - if nlags > 1: - res[np.arange(d,d*nlags), np.arange(d*nlags-d)] = 1 - - return res - -def linesearch(f, grad_f, xk, pk, step_size=1., tau=0.1, c1=1e-4, - prox_op=None, lam=1.): - """ find a step size via backtracking line search with armijo condition """ - obj_start = f(xk) - grad_xk = grad_f(xk) - obj_new = np.finfo('float').max - armijo_condition = 0 - - if prox_op is None: - prox_op = lambda x, y: x - - while obj_new > armijo_condition: - x_new = prox_op(xk - step_size * pk, lam*step_size) - armijo_condition = obj_start - c1*step_size*(np.sum(pk*grad_xk)) - obj_new = f(x_new) - step_size *= tau - - return step_size/tau - -def soft_thresh_At(At, lam): - At = At.copy() - diag_inds = np.diag_indices(At.shape[1]) - At_diag = np.diagonal(At, axis1=-2, axis2=-1) - - At = np.sign(At) * np.maximum(np.abs(At) - lam, 0.) - - # fill in diagonal with originally updated entries as we're not - # going to penalize them - for tt in range(At.shape[0]): - At[tt][diag_inds] = At_diag[tt] - return At - -def block_thresh_At(At, lam, min_norm=1e-16): - At = At.copy() - diag_inds = np.diag_indices(At.shape[1]) - At_diag = np.diagonal(At, axis1=-2, axis2=-1) - - norms = np.linalg.norm(At, axis=0, keepdims=True) - norms = np.maximum(norms, min_norm) - scales = np.maximum(norms - lam, 0.) - At = scales * (At / norms) - - # fill in diagonal with originally updated entries as we're not - # going to penalize them - for tt in range(At.shape[0]): - At[tt][diag_inds] = At_diag[tt] - return At - diff --git a/examples/mne_util.py b/examples/mne_util.py deleted file mode 100644 index 04b1c796..00000000 --- a/examples/mne_util.py +++ /dev/null @@ -1,294 +0,0 @@ -""" MNE-Python utility functions for preprocessing data and constructing - matrices necessary for MEGLDS analysis """ - -from __future__ import division -from __future__ import print_function -from __future__ import absolute_import - -import mne -import numpy as np -import os.path as op - -from mne.io.pick import pick_types -from mne.utils import logger -from mne import label_sign_flip - -from scipy.sparse import csc_matrix, csr_matrix, diags - - -class ROIToSourceMap(object): - """ class for computing ROI-to-source space mapping matrix """ - - def __init__(self, fwd, labels, label_flip=False): - - src = fwd['src'] - - roiidx = list() - vertidx = list() - - n_lhverts = len(src[0]['vertno']) - n_rhverts = len(src[1]['vertno']) - n_verts = n_lhverts + n_rhverts - offsets = {'lh': 0, 'rh': n_lhverts} - - hemis = {'lh': 0, 'rh': 1} - - # index vector of which ROI a source point belongs to - Q_J = np.zeros(n_verts, dtype=np.int64) - - data = [] - for li, lab in enumerate(labels): - - this_data = np.round(label_sign_flip(lab, src)) - if not label_flip: - this_data.fill(1.) - data.append(this_data) - if isinstance(lab, mne.Label): - comp_labs = [lab] - elif isinstance(lab, mne.BiHemiLabel): - comp_labs = [lab.lh, lab.rh] - - for clab in comp_labs: - hemi = clab.hemi - hi = 0 if hemi == 'lh' else 1 - - lverts = clab.get_vertices_used(vertices=src[hi]['vertno']) - - # gets the indices in the source space vertex array, not the huge - # array. - # use `src[hi]['vertno'][lverts]` to get surface vertex indices to - # plot. 
- lverts = np.searchsorted(src[hi]['vertno'], lverts) - lverts += offsets[hemi] - vertidx.extend(lverts) - roiidx.extend(np.full(lverts.size, li, dtype=np.int64)) - - # add 1 b/c 0 corresponds to unassigned variance - Q_J[lverts] = li + 1 - - N = len(labels) - M = n_verts - - # construct sparse L matrix - data = np.concatenate(data) - vertidx = np.array(vertidx, int) - roiidx = np.array(roiidx, int) - assert data.shape == vertidx.shape == roiidx.shape - L = csc_matrix((data, (vertidx, roiidx)), shape=(M, N)) - - self.fwd = fwd - self.L = L - self.Q_J = Q_J - self.offsets = offsets - self.n_lhverts = n_lhverts - self.n_rhverts = n_rhverts - self.labels = labels - - return - - @property - def G(self): - return self.fwd['sol']['data'] - - @property - def L(self): - return self._L - - @L.setter - def L(self, val): - self._L = val - - @property - def Q_J(self): - return self._Q_J - - @Q_J.setter - def Q_J(self, val): - self._Q_J = val - - @property - def GL(self): - from util import Carray - return Carray(csr_matrix.dot(self.L.T, self.G.T).T) - - def get_label_vinds(self, label): - li = self.labels.index(label) - if isinstance(label, mne.Label): - label_vert_idx = self.L[:, li].nonzero()[0] - label_vert_idx -= self.offsets[label.hemi] - return label_vert_idx - elif isinstance(label, mne.BiHemiLabel): - # these labels store both hemispheres so subtract the rh offset - # from that part of the vertex array - lh_label_vert_idx = self.L[:self.n_lhverts, li].nonzero()[0] - rh_label_vert_idx = self.L[self.n_lhverts:, li].nonzero()[0] - rh_label_vert_idx[self.n_lhverts:] -= self.offsets['rh'] - return [lh_label_vert_idx, rh_label_vert_idx] - - def get_label_verts(self, label, src): - # if you're thinking of using this to plot, why not just use - # brain.add_label from pysurfer? 
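# Aside on the sparse L matrix built above: csc_matrix((data, (rows, cols)),
# shape) is SciPy's triplet constructor. A tiny standalone illustration with
# toy, assumed indices:
#
#     import numpy as np
#     from scipy.sparse import csc_matrix
#     data = np.array([1., 1., -1.])   # sign-flip weights from label_sign_flip
#     rows = np.array([0, 1, 4])       # source-vertex indices (vertidx)
#     cols = np.array([0, 0, 1])       # ROI indices (roiidx)
#     L = csc_matrix((data, (rows, cols)), shape=(5, 2))
#     # column j of L is nonzero exactly at the vertices assigned to ROI j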
-        if isinstance(label, mne.Label):
-            hi = 0 if label.hemi == 'lh' else 1
-            label_vert_idx = self.get_label_vinds(label)
-            varray = src[hi]['vertno'][label_vert_idx]
-        elif isinstance(label, mne.BiHemiLabel):
-            lh_label_vert_idx, rh_label_vert_idx = self.get_label_vinds(label)
-            varray = [src[0]['vertno'][lh_label_vert_idx],
-                      src[1]['vertno'][rh_label_vert_idx]]
-        return varray
-
-    def get_hemi_idx(self, label):
-        if isinstance(label, mne.Label):
-            return 0 if label.hemi == 'lh' else 1
-        elif isinstance(label, mne.BiHemiLabel):
-            hemis = [None] * 2
-            for i, lab in enumerate([label.lh, label.rh]):
-                hemis[i] = 0 if lab.hemi == 'lh' else 1
-            return hemis
-
-
-def morph_labels(labels, subject_to, subjects_dir=None):
-    """ morph labels from fsaverage to specified subject """
-
-    if subjects_dir is None:
-        subjects_dir = mne.utils.get_subjects_dir()
-
-    if isinstance(labels, mne.Label):
-        labels = [labels]
-
-    labels_morphed = list()
-    for lab in labels:
-        if isinstance(lab, mne.Label):
-            labels_morphed.append(lab.copy())
-        elif isinstance(lab, mne.BiHemiLabel):
-            labels_morphed.append(lab.lh.copy() + lab.rh.copy())
-
-    for i, l in enumerate(labels_morphed):
-        if l.subject == subject_to:
-            continue
-        elif l.subject == 'unknown':
-            print("unknown subject for label %s" % l.name,
-                  "assuming it is 'fsaverage' and morphing")
-            l.subject = 'fsaverage'
-
-        if isinstance(l, mne.Label):
-            l.values.fill(1.0)
-            labels_morphed[i] = l.morph(subject_to=subject_to,
-                                        subjects_dir=subjects_dir)
-        elif isinstance(l, mne.BiHemiLabel):
-            l.lh.values.fill(1.0)
-            l.rh.values.fill(1.0)
-            labels_morphed[i].lh = l.lh.morph(subject_to=subject_to,
-                                              subjects_dir=subjects_dir)
-            labels_morphed[i].rh = l.rh.morph(subject_to=subject_to,
-                                              subjects_dir=subjects_dir)
-
-    # make sure there are no duplicate labels
-    labels_morphed = sorted(list(set(labels_morphed)), key=lambda x: x.name)
-
-    return labels_morphed
-
-
-def apply_projs(epochs, fwd, cov):
-    """ apply projection operators to fwd and cov """
-    proj, _ = mne.io.proj.setup_proj(epochs.info, activate=False)
-    G = fwd['sol']['data']
-    fwd['sol']['data'] = np.dot(proj, G)
-
-    Q = cov.data
-    if not np.allclose(np.dot(proj, Q), Q):
-        Q = np.dot(proj, np.dot(Q, proj.T))
-        cov.data = Q
-
-    return fwd, cov
-
-
-def scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., mag_scale=1.,
-                      grad_scale=1.):
-    """ apply per-channel-type scaling to epochs, forward, and covariance """
-    # from util import Carray ##skip import just pasted; util also from MEGLDS repo
-    Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C')
-    Carray32 = lambda X: np.require(X, dtype=np.float32, requirements='C')
-    Carray = Carray64
-
-    # get indices for each channel type
-    ch_names = cov['names']  # same as self.fwd['info']['ch_names']
-    sel_eeg = pick_types(fwd['info'], meg=False, eeg=True, ref_meg=False)
-    sel_mag = pick_types(fwd['info'], meg='mag', eeg=False, ref_meg=False)
-    sel_grad = pick_types(fwd['info'], meg='grad', eeg=False, ref_meg=False)
-    # 2 channels are removed so idx != ch_name
-    # can we do idx = c for c in sel??
- #idx_eeg = [ch_names.index(ch_names[c]) for c in sel_eeg] - #idx_mag = [ch_names.index(ch_names[c]) for c in sel_mag] - #idx_grad = [ch_names.index(ch_names[c]) for c in sel_grad] - idx_eeg = [c for c in sel_eeg] - idx_mag = [c for c in sel_mag] - idx_grad = [c for c in sel_grad] - - # retrieve forward and sensor covariance - G = fwd['sol']['data'].copy() - Q = cov.data.copy() - - # scale forward matrix - G[idx_eeg,:] *= eeg_scale - G[idx_mag,:] *= mag_scale - G[idx_grad,:] *= grad_scale - - # construct GL matrix - GL = Carray(csr_matrix.dot(roi_to_src.L.T, G.T).T) - - # scale sensor covariance - Q[np.ix_(idx_eeg, idx_eeg)] *= eeg_scale**2 - Q[np.ix_(idx_mag, idx_mag)] *= mag_scale**2 - Q[np.ix_(idx_grad, idx_grad)] *= grad_scale**2 - - # scale epochs - info = epochs.info.copy() - data = epochs.get_data().copy() - - data[:,idx_eeg,:] *= eeg_scale - data[:,idx_mag,:] *= mag_scale - data[:,idx_grad,:] *= grad_scale - - epochs = mne.EpochsArray(data, info) - - return G, GL, Q, epochs - - -def combine_medial_labels(labels, subject='fsaverage', surf='white', - dist_limit=0.02): - """ combine each hemi pair of labels on medial wall into single label """ - subjects_dir = mne.get_config('SUBJECTS_DIR') - rrs = dict((hemi, mne.read_surface(op.join(subjects_dir, subject, 'surf', - '%s.%s' % (hemi, surf)))[0] / 1000.) - for hemi in ('lh', 'rh')) - use_labels = list() - used = np.zeros(len(labels), bool) - - logger.info('Matching medial regions for %s labels on %s %s, d=%0.1f mm' - % (len(labels), subject, surf, 1000 * dist_limit)) - - for li1, l1 in enumerate(labels): - if used[li1]: - continue - used[li1] = True - use_label = l1.copy() - rr1 = rrs[l1.hemi][l1.vertices] - for li2 in np.where(~used)[0]: - l2 = labels[li2] - same_name = (l2.name.replace(l2.hemi, '') == - l1.name.replace(l1.hemi, '')) - if l2.hemi != l1.hemi and same_name: - rr2 = rrs[l2.hemi][l2.vertices] - mean_min = np.mean(mne.surface._compute_nearest( - rr1, rr2, return_dists=True)[1]) - if mean_min <= dist_limit: - use_label += l2 - used[li2] = True - logger.info(' Matched: ' + l1.name) - use_labels.append(use_label) - - logger.info('Total %d labels' % (len(use_labels),)) - - return use_labels diff --git a/examples/state_space_connectivity.py b/examples/state_space_connectivity.py deleted file mode 100644 index 79d13b5f..00000000 --- a/examples/state_space_connectivity.py +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Authors: Jordan Drew - -""" - -''' -For 'mne-connectivity/examples/' to show usage of LDS -Use MNE-sample-data for auditory/left -''' - -## import necessary libraries -import mne -import matplotlib.pyplot as plt - -#where should these files live within mne-connectivity repo? 
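One property worth making explicit about scale_sensor_data above: a scale
applied to the data must hit the covariance quadratically, which is why Q is
multiplied by eeg_scale**2 while the forward rows and the epochs get
eeg_scale. A short NumPy check of that identity (toy numbers):

    import numpy as np
    x = np.random.randn(1000, 3)       # fake 3-channel data
    s = np.array([1., 0.5, 0.2])       # per-channel scales
    np.testing.assert_allclose(np.cov((x * s).T),
                               s[:, None] * np.cov(x.T) * s[None, :])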
-from megssm.models import MEGLDS as LDS
-from megssm.plotting import plot_At
-
-## define paths to sample data
-path = None
-path = '/Users/jordandrew/Documents/MEG/mne_data'#'/MNE-sample-data'
-data_path = mne.datasets.sample.data_path(path=path)
-sample_folder = data_path / 'MEG/sample'
-subjects_dir = data_path / 'subjects'
-
-## import raw data and find events
-raw_fname = sample_folder / 'sample_audvis_raw.fif'
-raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60)
-events = mne.find_events(raw, stim_channel='STI 014')
-
-## define epochs using event_dict
-event_dict = {'auditory/left': 1, 'auditory/right': 2, 'visual/left': 3,
-              'visual/right': 4, 'face': 5, 'buttonpress': 32}
-epochs = mne.Epochs(raw, events, tmin=-0.3, tmax=0.7, event_id=event_dict,
-                    preload=True).pick_types(meg=True, eeg=True, exclude='bads')
-epochs = epochs['auditory/left']  # choose condition for analysis
-
-## read forward solution, remove bad channels
-fwd_fname = sample_folder / 'sample_audvis-meg-eeg-oct-6-fwd.fif'
-fwd = mne.read_forward_solution(fwd_fname, exclude=raw.info['bads'])
-fwd = mne.convert_forward_solution(fwd, force_fixed=True)
-
-## read in covariance OR compute noise covariance? noise_cov drops bad chs
-cov_fname = sample_folder / 'sample_audvis-cov.fif'
-cov = mne.read_cov(cov_fname)  # has all 366 channels; drop 2?
-noise_cov = mne.compute_covariance(epochs, tmax=0)
-
-## read labels for analysis
-label_names = ['AUD-lh', 'AUD-rh', 'Vis-lh', 'Vis-rh']
-labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label')
-          for label in label_names]
-
-## initialize model
-num_rois = len(labels)
-timepts = len(epochs.times)
-model = LDS(num_rois, timepts, lam0=0, lam1=100)  # only the forward, labels,
-                                                  # and noise_cov are needed
-
-model.add_subject('sample', subjects_dir, epochs, labels, fwd, noise_cov)
-# when to use compute_cov vs read_cov?
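On the "compute_cov vs read_cov" question above: mne.read_cov loads a
covariance estimated elsewhere (here with the full original channel set, so
it may need to be restricted to the channels that survive), while
mne.compute_covariance estimates it from the baseline of the very epochs
being analyzed and therefore matches their channel selection by construction.
A sketch using only calls that already appear in this patch series:

    cov = mne.read_cov(cov_fname)
    cov = cov.pick_channels(epochs.ch_names, ordered=True)  # drop channels not in epochs
    # or, estimate from the pre-stimulus interval:
    noise_cov = mne.compute_covariance(epochs, tmax=0)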
- -model.fit(niter=100, verbose=1) -At = model.A -assert At.shape == (timepts, num_rois, num_rois) - -plt.rcParams.update( - {'xtick.labelsize': 'x-small', 'ytick.labelsize': 'x-small'}) -fig, ax = plt.subplots(num_rois, num_rois, constrained_layout=True, squeeze=False, - figsize=(12, 10)) -plot_At(At, labels=label_names, times=epochs.times, ax=ax) - - - - - - - - - - - - - - - - From e60d17518c58238e4e5a876fe522aca6a7fbf800 Mon Sep 17 00:00:00 2001 From: jadrew43 Date: Mon, 29 Aug 2022 13:36:34 -0700 Subject: [PATCH 12/17] compare _mne_scale_sensor_data to _scale_sensor_data --- state_space/megssm/mne_util.py | 55 ++++++++++++++----------- state_space/state_space_connectivity.py | 3 +- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/state_space/megssm/mne_util.py b/state_space/megssm/mne_util.py index 81f2ddab..aa86ea7c 100644 --- a/state_space/megssm/mne_util.py +++ b/state_space/megssm/mne_util.py @@ -111,30 +111,30 @@ def apply_projs(epochs, fwd, cov): return fwd, cov -# def _scale_sensor_data(epochs, fwd, cov, roi_to_src, **std): -# """ apply per-channel-type scaling to epochs, forward, and covariance """ - -# for s in std: -# std[s] = 1/std[s] -# snsr_cov = cov.data.copy() -# fwd_src_snsr = fwd['sol']['data'].copy() - -# info = epochs.info.copy() -# data = epochs.get_data().copy() - -# rescale_cov = mne.make_ad_hoc_cov(info, std=std) -# scaler = mne.cov.compute_whitener(rescale_cov, info) -# del rescale_cov -# fwd_src_snsr = scaler[0] @ fwd_src_snsr -# snsr_cov = scaler[0] @ snsr_cov -# data = scaler[0] @ data +def _mne_scale_sensor_data(epochs, fwd, cov, roi_to_src, **std): + """ apply per-channel-type scaling to epochs, forward, and covariance """ + + for s in std: + std[s] = 1/std[s] + snsr_cov = cov.data.copy() + fwd_src_snsr = fwd['sol']['data'].copy() + + info = epochs.info.copy() + data = epochs.get_data().copy() + + rescale_cov = mne.make_ad_hoc_cov(info, std=std) + scaler = mne.cov.compute_whitener(rescale_cov, info) + del rescale_cov + fwd_src_snsr = scaler[0] @ fwd_src_snsr + snsr_cov = scaler[0] @ snsr_cov + data = scaler[0] @ data - # fwd_roi_snsr = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, - #fwd_src_snsr.T).T) + fwd_roi_snsr = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, + fwd_src_snsr.T).T) -# epochs = mne.EpochsArray(data, info) + epochs = mne.EpochsArray(data, info) -# return fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs + return fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., @@ -218,10 +218,15 @@ def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', (subject_name, W.shape[0])) else: - - fwd_src_snsr, fwd_roi_snsr, cov_snsr, epochs = \ - _scale_sensor_data(epochs, fwd, cov, roi_to_src, **scales) - + + scaled_data = _scale_sensor_data(epochs, fwd, cov, roi_to_src, **scales) + mne_scaled_data = _mne_scale_sensor_data(epochs, fwd, cov, roi_to_src, **scales) + + fwd_src_snsr, fwd_roi_snsr, cov_snsr, epochs = scaled_data#\ + # _scale_sensor_data(epochs, fwd, cov, roi_to_src, **scales) + + # for i in range(len(scaled_data)): + # np.testing.assert_allclose(scaled_data[i], mne_scaled_data[i], atol=1e-3) dat = epochs.get_data().copy() dat = Carray(np.swapaxes(dat, -1, -2)) diff --git a/state_space/state_space_connectivity.py b/state_space/state_space_connectivity.py index 4ca68cc4..0c4ef9e9 100644 --- a/state_space/state_space_connectivity.py +++ b/state_space/state_space_connectivity.py @@ -21,8 +21,7 @@ # define paths to sample data -path = 
None################################ -path = '/Users/jordandrew/Documents/MEG/mne_data'#'/MNE-sample-data' +path = None data_path = mne.datasets.sample.data_path(path=path) sample_folder = data_path / 'MEG/sample' subjects_dir = data_path / 'subjects' From 71c65398127194e0ba5dab6923c25e7b55e07db1 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 30 Aug 2022 13:52:43 -0400 Subject: [PATCH 13/17] FIX: More easily runnable --- state_space/megssm/message_passing.py | 25 +++- state_space/megssm/models.py | 181 ++++++++++++------------ state_space/megssm/numpy_numthreads.py | 18 +-- state_space/state_space_connectivity.py | 43 +++--- 4 files changed, 127 insertions(+), 140 deletions(-) diff --git a/state_space/megssm/message_passing.py b/state_space/megssm/message_passing.py index e560ff49..722cf8e8 100755 --- a/state_space/megssm/message_passing.py +++ b/state_space/megssm/message_passing.py @@ -7,10 +7,23 @@ from .util import T_, sym, dot3, _ensure_ndim, component_matrix, hs -try: - from autograd_linalg import solve_triangular -except ImportError: - raise RuntimeError("must install `autograd_linalg` package") +from scipy.linalg import solve_triangular as _solve_triangular + + +def solve_triangular(L, b, *, lower=True, trans=0): + if hasattr(L, '_value'): # autograd ArrayBox + L = L._value + if hasattr(b, '_value'): + b = b._value + if L.ndim == 3 and b.ndim in (2, 3): + return np.array([ + _solve_triangular(LL, bb, lower=lower, trans=trans) + for LL, bb in zip(L, b)], L.dtype) + elif L.ndim == 2 and b.ndim in (2, 1): + return _solve_triangular(L, b, lower=lower, trans=trans) + raise RuntimeError(f'Unknown shapes {L.shape} and {b.shape}') + + # einsum2 is a parallel version of einsum that works for two arguments try: @@ -304,7 +317,7 @@ def rts_smooth_fast(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False): I_tiled = np.tile(np.eye(Dnlags), (N, 1, 1)) for t in range(T): - + # condition tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) @@ -361,7 +374,7 @@ def rts_smooth_fast(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False): sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) for t in range(T-2, -1, -1): - + # these names are stolen from mattjj and slinderman #temp_nn = np.dot(A[t], sigmas_smooth[n,t,:,:]) temp_nn = einsum2('ik,nkj->nij', AA[t], sigmas_smooth[:,t,:,:]) diff --git a/state_space/megssm/models.py b/state_space/megssm/models.py index 684494b8..b7504ef2 100755 --- a/state_space/megssm/models.py +++ b/state_space/megssm/models.py @@ -5,7 +5,7 @@ import autograd.numpy as np import scipy.optimize as spopt -from autograd import grad +from autograd import grad from autograd import value_and_grad as vgrad from scipy.linalg import LinAlgError @@ -13,18 +13,13 @@ from .util import linesearch, soft_thresh_At, block_thresh_At from .util import relnormdiff from .message_passing import kalman_filter, rts_smooth, rts_smooth_fast -from .message_passing import predict_step, condition +from .message_passing import predict_step, condition, solve_triangular from .numpy_numthreads import numpy_num_threads -from .mne_util import (ROIToSourceMap, _scale_sensor_data, run_pca_on_subject, +from .mne_util import (ROIToSourceMap, _scale_sensor_data, run_pca_on_subject, apply_projs) -try: - from autograd_linalg import solve_triangular -except ImportError: - raise RuntimeError("must install `autograd_linalg` package") - -from autograd.numpy import einsum +from autograd.numpy import einsum from datetime import datetime @@ -40,7 +35,7 @@ def __init__(self): self._nsubjects = 0 def 
set_data(self, subjectdata): - n_timepts_lst = [self.unpack_subject_data(e)[0].shape[1] for e in + n_timepts_lst = [self.unpack_subject_data(e)[0].shape[1] for e in subjectdata] assert len(list(set(n_timepts_lst))) == 1 self._n_timepts = n_timepts_lst[0] @@ -78,7 +73,7 @@ class MEGLDS(_MEGModel): """ def __init__(self, lam0=0., lam1=0., penalty='ridge', store_St=True): - + super().__init__() self._model_initalized = False self.lam0 = lam0 @@ -97,12 +92,12 @@ def __init__(self, lam0=0., lam1=0., penalty='ridge', store_St=True): self._store_St = bool(store_St) self._all_subject_data = list() - - #SNR boost epochs, bootstraps of 3 + + #SNR boost epochs, bootstraps of 3 def bootstrap_subject(self, epochs, subject_name, seed=8675309, sfreq=100, - lower=None, upper=None, nbootstrap=3, g_nsamples=-5, + lower=None, upper=None, nbootstrap=3, g_nsamples=-5, overwrite=False, validation_set=True): - + datasets = ['train', 'validation'] use_erm = eq = False independent = False @@ -142,7 +137,7 @@ def bootstrap_subject(self, epochs, subject_name, seed=8675309, sfreq=100, if sfreq is not None: print('resampling to %.2f Hz' % sfreq) - + print(":: processing subject %s" % subject_name) np.random.seed(seed) @@ -153,17 +148,17 @@ def bootstrap_subject(self, epochs, subject_name, seed=8675309, sfreq=100, condition_map = {'auditory_left':['auditory_left'], 'auditory_right': ['auditory_right'], - 'visual_left': ['visual_left'], + 'visual_left': ['visual_left'], 'visual_right': ['visual_right']} - condition_eq_map = dict(auditory_left=['auditory_left'], + condition_eq_map = dict(auditory_left=['auditory_left'], auditory_right=['auditory_right'], - visual_left=['visual_left'], + visual_left=['visual_left'], visual_right='visual_right') - + if eq: epochs.equalize_event_counts(list(condition_map)) cond_map = condition_eq_map - + # apply band-pass filter to limit signal to desired frequency band if lower is not None or upper is not None: epochs = epochs.filter(lower, upper) @@ -175,11 +170,11 @@ def bootstrap_subject(self, epochs, subject_name, seed=8675309, sfreq=100, data_bs_all = list() events_bs_all = list() for cond in sorted(cond_map.keys()): - print(" -> condition %s: bootstrapping" % cond, end='') - ep = epochs[cond_map[cond]] + print(" -> condition %s: bootstrapping" % cond, end='') + ep = epochs[cond_map[cond]] dat = ep.get_data().copy() - ntrials, T, p = dat.shape - + ntrials, T, p = dat.shape + use_bootstrap = nbootstrap if g_nsamples == -4: nsamples = 1 @@ -202,7 +197,7 @@ def bootstrap_subject(self, epochs, subject_name, seed=8675309, sfreq=100, if nsamples == 1 and use_bootstrap == ntrials: inds = np.arange(ntrials) else: - inds = np.random.choice(ntrials, + inds = np.random.choice(ntrials, nsamples * use_bootstrap) inds.shape = (use_bootstrap, nsamples) dat_bs = np.mean(dat[inds], axis=1) @@ -225,7 +220,7 @@ def bootstrap_subject(self, epochs, subject_name, seed=8675309, sfreq=100, replace=False) dat_bs = dat_bs[inds] events_bs = events_bs[inds] - + assert dat_bs.shape == (use_bootstrap, T, p) assert events_bs.shape == (use_bootstrap, 3) assert (events_bs[:, 2] == events_bs[0, 2]).all() @@ -247,19 +242,19 @@ def bootstrap_subject(self, epochs, subject_name, seed=8675309, sfreq=100, event_id=epochs.event_id.copy(), on_missing='ignore') return epochs_bs - - def add_subject(self, subject,condition,epochs,labels,fwd, + + def add_subject(self, subject,condition,epochs,labels,fwd, cov): - + epochs_bs = self.bootstrap_subject(epochs, subject) epochs_bs = epochs_bs.crop(tmin=-0.2, tmax=0.7) epochs_bs = 
epochs_bs[condition] - epochs = epochs_bs - + epochs = epochs_bs + cov = cov.pick_channels(epochs.ch_names, ordered=True) fwd = mne.convert_forward_solution(fwd, force_fixed=True) fwd = fwd.pick_channels(epochs.ch_names, ordered=True) - + if not self._model_initalized: n_timepts = len(epochs.times) num_roi = len(labels) @@ -271,19 +266,19 @@ def add_subject(self, subject,condition,epochs,labels,fwd, if len(epochs.times) != self._n_times: raise ValueError(f'Number of time points ({len(epochs.times)})' / 'does not match original count ({self._n_times})') - + cov_scale = 3 # equal to number of bootstrap trials cov['data'] /= cov_scale - fwd, cov = apply_projs(epochs_bs, fwd, cov) - - sdata = run_pca_on_subject(subject, epochs_bs, fwd, cov, labels, - dim_mode='pctvar', mean_center=True) + fwd, cov = apply_projs(epochs_bs, fwd, cov) + + sdata = run_pca_on_subject(subject, epochs_bs, fwd, cov, labels, + dim_mode='pctvar', mean_center=True) data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - + subjectdata = (data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi) - + self._all_subject_data.append(subjectdata) - + self._subject_data[subject] = dict() self._subject_data[subject]['epochs'] = data self._subject_data[subject]['fwd_src_snsr'] = fwd_src_snsr @@ -291,29 +286,29 @@ def add_subject(self, subject,condition,epochs,labels,fwd, self._subject_data[subject]['snsr_cov'] = snsr_cov self._subject_data[subject]['labels'] = labels self._subject_data[subject]['which_roi'] = which_roi - - - def _init_model(self, n_timepts, num_roi, A_t_=None, roi_cov=None, + + + def _init_model(self, n_timepts, num_roi, A_t_=None, roi_cov=None, mu0=None, roi_cov_0=None, log_sigsq_lst=None): - + self._n_times = n_timepts self._subject_data = dict() - + set_default = \ lambda prm, val, deflt: \ - self.__setattr__(prm, val.copy() if val is not None else + self.__setattr__(prm, val.copy() if val is not None else deflt) - + # initialize parameters set_default("A_t_", A_t_, - np.stack([rand_stable(num_roi, maxew=0.7) for _ in + np.stack([rand_stable(num_roi, maxew=0.7) for _ in range(n_timepts)], axis=0)) set_default("roi_cov", roi_cov, rand_psd(num_roi)) set_default("mu0", mu0, np.zeros(num_roi)) set_default("roi_cov_0", roi_cov_0, rand_psd(num_roi)) set_default("log_sigsq_lst", log_sigsq_lst, [np.log(np.random.gamma(2, 1, size=num_roi+1))]) - + # initialize sufficient statistics n_timepts, num_roi, _ = self.A_t_.shape self._B0 = np.zeros((num_roi, num_roi)) @@ -321,7 +316,7 @@ def _init_model(self, n_timepts, num_roi, A_t_=None, roi_cov=None, self._B3 = np.zeros((n_timepts-1, num_roi, num_roi)) self._B2 = np.zeros((n_timepts-1, num_roi, num_roi)) self._B4 = list() - + def set_data(self, subjectdata): # add subject data, re-generate log_sigsq_lst if necessary super().set_data(subjectdata) @@ -339,7 +334,7 @@ def set_data(self, subjectdata): self._B4 = [None] * self._nsubjects def _em_objective(self): - + _, num_roi, _ = self.A_t_.shape L_roi_cov_0 = np.linalg.cholesky(self.roi_cov_0) @@ -365,9 +360,9 @@ def _em_objective(self): roi_cov_t = _ensure_ndim(self.roi_cov, n_timepts, 3) with numpy_num_threads(1): _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ - rts_smooth_fast(Y, self.A_t_, fwd_roi_snsr, - roi_cov_t, R, self.mu0, - self.roi_cov_0, + rts_smooth_fast(Y, self.A_t_, fwd_roi_snsr, + roi_cov_t, R, self.mu0, + self.roi_cov_0, compute_lag1_cov=True) else: @@ -377,12 +372,12 @@ def _em_objective(self): x_smooth_0_outer = einsum('ri,rj->rij', mus_smooth[:,0,:num_roi], mus_smooth[:,0,:num_roi]) - B0 = 
w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + + B0 = w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + x_smooth_0_outer, axis=0) x_smooth_outer = einsum('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], mus_smooth[:,1:,:num_roi]) - B1 = w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + + B1 = w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + x_smooth_outer, axis=0) z_smooth_outer = einsum('rti,rtj->rtij', mus_smooth[:,:-1,:], mus_smooth[:,:-1,:]) @@ -391,13 +386,13 @@ def _em_objective(self): mus_smooth_outer_l1 = einsum('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], mus_smooth[:,:-1,:]) - B2 = w_s*np.sum(sigmas_tnt_smooth[:,:,:num_roi,:] + + B2 = w_s*np.sum(sigmas_tnt_smooth[:,:,:num_roi,:] + mus_smooth_outer_l1, axis=0) # obj += L1(roi_cov_0) L_roi_cov_0_inv_B0 = solve_triangular(L_roi_cov_0, B0, lower=True) L1 += (ntrials*2.*np.sum(np.log(np.diag(L_roi_cov_0))) - + np.trace(solve_triangular(L_roi_cov_0, L_roi_cov_0_inv_B0, + + np.trace(solve_triangular(L_roi_cov_0, L_roi_cov_0_inv_B0, lower=True, trans='T'))) At = self.A_t_[:-1] @@ -411,16 +406,16 @@ def _em_objective(self): # obj += L2(roi_cov, At) L_roi_cov_inv_tmp = solve_triangular(L_roi_cov, tmp, lower=True) L2 += (ntrials*(n_timepts-1)*2.*np.sum(np.log(np.diag(L_roi_cov))) - + np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_tmp, + + np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_tmp, lower=True, trans='T'))) - res = Y - einsum('ik,ntk->nti', fwd_roi_snsr, + res = Y - einsum('ik,ntk->nti', fwd_roi_snsr, mus_smooth[:,:,:num_roi]) - CP_smooth = einsum('ik,ntkj->ntij', fwd_roi_snsr, + CP_smooth = einsum('ik,ntkj->ntij', fwd_roi_snsr, sigmas_smooth[:,:,:num_roi,:num_roi]) B4 = w_s*(np.sum(einsum('nti,ntj->ntij', res, res), axis=(0,1)) - + np.sum(einsum('ntik,jk->ntij', CP_smooth, + + np.sum(einsum('ntik,jk->ntij', CP_smooth, fwd_roi_snsr), axis=(0,1))) self._B4[s] = B4 @@ -451,16 +446,16 @@ def _em_objective(self): return obj - def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, - A_t_roi_cov_tol=1e-6, verbose=0, update_A_t_=True, + def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, + A_t_roi_cov_tol=1e-6, verbose=0, update_A_t_=True, update_roi_cov=True, update_roi_cov_0=True, stationary_A_t_=False, diag_roi_cov=False, update_sigsq=True, do_final_smoothing=True, average_mus_smooth=True, Atrue=None, tau=0.1, c1=1e-4): self.set_data(self._all_subject_data) - + fxn_start = datetime.now() - + n_timepts, num_roi, _ = self.A_t_.shape # make initial A_t_ stationary if stationary_A_t_ option specified @@ -482,7 +477,7 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, converged = False best_objval = np.finfo('float').max best_params = (self.A_t_.copy(), self.roi_cov.copy(), self.mu0.copy(), - self.roi_cov_0.copy(), [l.copy() for l in + self.roi_cov_0.copy(), [l.copy() for l in self.log_sigsq_lst]) # previous parameter values (for checking convergence) @@ -493,7 +488,7 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, if Atrue is not None: import matplotlib.pyplot as plt - fig_A_t_, ax_A_t_ = plt.subplots(num_roi, num_roi, sharex=True, + fig_A_t_, ax_A_t_ = plt.subplots(num_roi, num_roi, sharex=True, sharey=True) plt.ion() @@ -521,11 +516,11 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, roi_cov_0_prev = self.roi_cov_0.copy() log_sigsq_lst_prev = np.array(self.log_sigsq_lst).copy() - self._m_step(update_A_t_=update_A_t_, + self._m_step(update_A_t_=update_A_t_, update_roi_cov=update_roi_cov, - update_roi_cov_0=update_roi_cov_0, + update_roi_cov_0=update_roi_cov_0, 
stationary_A_t_=stationary_A_t_, - diag_roi_cov=diag_roi_cov, + diag_roi_cov=diag_roi_cov, update_sigsq=update_sigsq, tau=tau, c1=c1, verbose=verbose) @@ -556,7 +551,7 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, if objval < best_objval: best_objval = objval - best_params = (self.A_t_.copy(), self.roi_cov.copy(), + best_params = (self.A_t_.copy(), self.roi_cov.copy(), self.mu0.copy(), self.roi_cov_0.copy(), [l.copy() for l in self.log_sigsq_lst]) @@ -564,7 +559,7 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, if it >= 1: relnormdiff_At = relnormdiff(self.A_t_[:-1], At_prev) relnormdiff_roi_cov = relnormdiff(self.roi_cov, roi_cov_prev) - relnormdiff_roi_cov_0 = relnormdiff(self.roi_cov_0, + relnormdiff_roi_cov_0 = relnormdiff(self.roi_cov_0, roi_cov_0_prev) relnormdiff_log_sigsq_lst = \ np.array( @@ -581,7 +576,7 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, if verbose > 0: print(" relnormdiff_At: %.3e" % relnormdiff_At) print(" relnormdiff_roi_cov: %.3e" % relnormdiff_roi_cov) - print(" relnormdiff_roi_cov_0: %.3e" % + print(" relnormdiff_roi_cov_0: %.3e" % relnormdiff_roi_cov_0) print(" relnormdiff_log_sigsq_lst:", relnormdiff_log_sigsq_lst) @@ -635,7 +630,7 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, roi_cov_t = _ensure_ndim(self.roi_cov, self._n_timepts, 3) with numpy_num_threads(1): loglik_subject, mus_smooth, _, _, St = \ - rts_smooth(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, + rts_smooth(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, self.mu0, self.roi_cov_0, compute_lag1_cov=False, store_St=self._store_St) @@ -691,8 +686,8 @@ def _e_step(self, verbose=0): with numpy_num_threads(1): _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ - rts_smooth_fast(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, - R, self.mu0, self.roi_cov_0, + rts_smooth_fast(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, + R, self.mu0, self.roi_cov_0, compute_lag1_cov=True) self._mus_smooth_lst.append(mus_smooth) @@ -701,12 +696,12 @@ def _e_step(self, verbose=0): x_smooth_0_outer = einsum('ri,rj->rij', mus_smooth[:,0,:num_roi], mus_smooth[:,0,:num_roi]) - self._B0 += w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + + self._B0 += w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + x_smooth_0_outer, axis=0) x_smooth_outer = einsum('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], mus_smooth[:,1:,:num_roi]) - self._B1 += w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + + self._B1 += w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + x_smooth_outer, axis=0) z_smooth_outer = einsum('rti,rtj->rtij', mus_smooth[:,:-1,:], @@ -723,8 +718,8 @@ def _e_step(self, verbose=0): if verbose > 0: print("\n done") - def _m_step(self, update_A_t_=True, update_roi_cov=True, - update_roi_cov_0=True, stationary_A_t_=False, + def _m_step(self, update_A_t_=True, update_roi_cov=True, + update_roi_cov_0=True, stationary_A_t_=False, diag_roi_cov=False, update_sigsq=True, tau=0.1, c1=1e-4, verbose=0): self._loglik = None @@ -734,15 +729,15 @@ def _m_step(self, update_A_t_=True, update_roi_cov=True, self.roi_cov_0 = (1. 
/ self._ntrials_all) * self._B0 if diag_roi_cov: self.roi_cov_0 = np.diag(np.diag(self.roi_cov_0)) - self.update_A_t_and_roi_cov(update_A_t_=update_A_t_, + self.update_A_t_and_roi_cov(update_A_t_=update_A_t_, update_roi_cov=update_roi_cov, - stationary_A_t_=stationary_A_t_, - diag_roi_cov=diag_roi_cov, tau=tau, + stationary_A_t_=stationary_A_t_, + diag_roi_cov=diag_roi_cov, tau=tau, c1=c1, verbose=verbose) if update_sigsq: self.update_log_sigsq_lst(verbose=verbose) - def update_A_t_and_roi_cov(self, update_A_t_=True, update_roi_cov=True, + def update_A_t_and_roi_cov(self, update_A_t_=True, update_roi_cov=True, stationary_A_t_=False, diag_roi_cov=False, tau=0.1, c1=1e-4, verbose=0): @@ -834,7 +829,7 @@ def update_log_sigsq_lst(self, verbose=0): log_sigsq = self.log_sigsq_lst[s].copy() log_sigsq_obj = lambda x: \ - MEGLDS.L3_obj(x, snsr_cov, fwd_src_snsr, which_roi, B4, + MEGLDS.L3_obj(x, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, n_timepts) log_sigsq_val_and_grad = vgrad(log_sigsq_obj) @@ -854,10 +849,10 @@ def update_log_sigsq_lst(self, verbose=0): if verbose > 1: print("\n done") - + @staticmethod def R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi): - return snsr_cov + np.dot(fwd_src_snsr, + return snsr_cov + np.dot(fwd_src_snsr, sigsq_vals[which_roi][:,None]*fwd_src_snsr.T) def L2_obj(self, At, L_roi_cov): @@ -868,7 +863,7 @@ def L2_obj(self, At, L_roi_cov): elbo_2 = np.sum(self._B1 - AtB2T - B2AtT + AtB3AtT, axis=0) L_roi_cov_inv_elbo_2 = solve_triangular(L_roi_cov, elbo_2, lower=True) - obj = np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_elbo_2, + obj = np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_elbo_2, lower=True, trans='T')) obj = obj / self._ntrials_all @@ -881,9 +876,9 @@ def L2_obj(self, At, L_roi_cov): return obj @staticmethod - def L3_obj(log_sigsq_vals, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, + def L3_obj(log_sigsq_vals, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, n_timepts): - R = MEGLDS.R_(snsr_cov, fwd_src_snsr, np.exp(log_sigsq_vals), + R = MEGLDS.R_(snsr_cov, fwd_src_snsr, np.exp(log_sigsq_vals), which_roi) try: L_R = np.linalg.cholesky(R) @@ -893,5 +888,3 @@ def L3_obj(log_sigsq_vals, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, return (ntrials*n_timepts*2.*np.sum(np.log(np.diag(L_R))) + np.trace(solve_triangular(L_R, L_R_inv_B4, lower=True, trans='T'))) - - diff --git a/state_space/megssm/numpy_numthreads.py b/state_space/megssm/numpy_numthreads.py index 550aa235..d632c21c 100755 --- a/state_space/megssm/numpy_numthreads.py +++ b/state_space/megssm/numpy_numthreads.py @@ -9,23 +9,10 @@ # from Ubuntu repos try_paths = [find_library('openblas')] openblas_lib = None -for libpath in try_paths: - try: - openblas_lib = ctypes.cdll.LoadLibrary(libpath) - break - except Exception: #OSError: - continue +mkl_rt = None #if openblas_lib is None: #raise EnvironmentError('Could not locate an OpenBLAS shared library', 2) -try: - mkl_rt_path = find_library('mkl_rt') - mkl_rt = ctypes.cdll.LoadLibrary(mkl_rt_path) - # print(mkl_rt) -except OSError: - mkl_rt = None - pass - def set_num_threads(n): """Set the current number of threads used by the OpenBLAS server.""" @@ -48,6 +35,9 @@ def get_num_threads(): def get_num_threads(): """Get the current number of threads used by the OpenBLAS server.""" return openblas_lib.openblas_get_num_threads() + else: + def get_num_threads(): + return -1 except AttributeError: def get_num_threads(): """Dummy function (symbol not present in %s), returns -1.""" diff --git a/state_space/state_space_connectivity.py 
b/state_space/state_space_connectivity.py index 0c4ef9e9..577049c0 100644 --- a/state_space/state_space_connectivity.py +++ b/state_space/state_space_connectivity.py @@ -6,7 +6,7 @@ """ ''' -For 'mne-connectivity/examples/' to show usage of LDS +For 'mne-connectivity/examples/' to show usage of LDS Use MNE-sample-data for auditory/left ''' @@ -26,8 +26,8 @@ sample_folder = data_path / 'MEG/sample' subjects_dir = data_path / 'subjects' -## import raw data and find events -raw_fname = sample_folder / 'sample_audvis_raw.fif' +## import raw data and find events +raw_fname = sample_folder / 'sample_audvis_raw.fif' raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60) events = mne.find_events(raw, stim_channel='STI 014') @@ -39,17 +39,24 @@ condition = 'auditory_left' ## read forward solution, remove bad channels -fwd_fname = sample_folder / 'sample_audvis-meg-eeg-oct-6-fwd.fif' +fwd_fname = sample_folder / 'sample_audvis-meg-eeg-oct-6-fwd.fif' fwd = mne.read_forward_solution(fwd_fname) -## read in covariance +## read in covariance cov_fname = sample_folder / 'sample_audvis-cov.fif' -cov = mne.read_cov(cov_fname) +cov = mne.read_cov(cov_fname) ## read labels for analysis -label_names = ['AUD-lh', 'AUD-rh', 'Vis-lh', 'Vis-rh'] -labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label', - subject='sample') for label in label_names] +regexp = '^(G_temp_sup-G_T_transv.*|Pole_occipital)' +labels = mne.read_labels_from_annot( + 'sample', 'aparc.a2009s', regexp=regexp, subjects_dir=subjects_dir) +label_names = [label.name for label in labels] +assert len(label_names) == 4 +# brain = mne.viz.Brain('sample', surf='inflated', subjects_dir=subjects_dir) +# for label in labels: +# brain.add_label(label) +# raise RuntimeError + # initiate model model = LDS(lam0=0, lam1=100) @@ -65,7 +72,7 @@ with mpl.rc_context(): {'xtick.labelsize': 'x-small', 'ytick.labelsize': 'x-small'} - fig, ax = plt.subplots(num_roi, num_roi, constrained_layout=True, + fig, ax = plt.subplots(num_roi, num_roi, constrained_layout=True, squeeze=False, figsize=(12, 10)) plot_A_t_(A_t_, labels=label_names, times=times, ax=ax) fig.suptitle('testing_') @@ -91,19 +98,3 @@ spine.set(lw=2) else: a.axhline(0, color='k', ls=':', lw=1.) - - - - - - - - - - - - - - - - From b05fd713adb4e87c305a3bdee0db4a606e393779 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 30 Aug 2022 14:00:53 -0400 Subject: [PATCH 14/17] FIX: Now equiv --- state_space/megssm/mne_util.py | 92 ++++++++++++++-------------------- 1 file changed, 38 insertions(+), 54 deletions(-) diff --git a/state_space/megssm/mne_util.py b/state_space/megssm/mne_util.py index aa86ea7c..c584d676 100644 --- a/state_space/megssm/mne_util.py +++ b/state_space/megssm/mne_util.py @@ -17,8 +17,8 @@ class ROIToSourceMap(object): - """ class for computing ROI-to-source space mapping matrix - + """ class for computing ROI-to-source space mapping matrix + Notes ----- The following variables defined here correspond to various matrices @@ -65,7 +65,7 @@ def __init__(self, fwd, labels, label_flip=False): lverts = clab.get_vertices_used(vertices=src[hi]['vertno']) - # gets the indices in the source space vertex array, not the + # gets the indices in the source space vertex array, not the # huge array. # use `src[hi]['vertno'][lverts]` to get surface vertex indices # to plot. 
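# Editorial sketch: the comment above describes a two-level indexing scheme.
# Rows of the gain matrix follow positions within src[hi]['vertno'], while
# labels store raw surface vertex numbers, so np.searchsorted converts one
# into the other. A minimal, self-contained illustration (values hypothetical):
import numpy as np
vertno = np.array([2, 7, 11, 40])        # vertices kept in the source space
label_verts = np.array([7, 40])          # vertices belonging to some label
rows = np.searchsorted(vertno, label_verts)
assert rows.tolist() == [1, 3]           # positions in the source space array
assert vertno[rows].tolist() == [7, 40]  # back to surface ids, e.g. for plots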
@@ -94,7 +94,7 @@ def __init__(self, fwd, labels, label_flip=False): self.n_lhverts = n_lhverts self.n_rhverts = n_rhverts self.labels = labels - + return def apply_projs(epochs, fwd, cov): @@ -111,39 +111,13 @@ def apply_projs(epochs, fwd, cov): return fwd, cov -def _mne_scale_sensor_data(epochs, fwd, cov, roi_to_src, **std): - """ apply per-channel-type scaling to epochs, forward, and covariance """ - - for s in std: - std[s] = 1/std[s] - snsr_cov = cov.data.copy() - fwd_src_snsr = fwd['sol']['data'].copy() - - info = epochs.info.copy() - data = epochs.get_data().copy() - - rescale_cov = mne.make_ad_hoc_cov(info, std=std) - scaler = mne.cov.compute_whitener(rescale_cov, info) - del rescale_cov - fwd_src_snsr = scaler[0] @ fwd_src_snsr - snsr_cov = scaler[0] @ snsr_cov - data = scaler[0] @ data - - fwd_roi_snsr = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, - fwd_src_snsr.T).T) - - epochs = mne.EpochsArray(data, info) - - return fwd_src_snsr, fwd_roi_snsr, snsr_cov, epochs - - -def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., +def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., mag_scale=1., grad_scale=1.): """ apply per-channel-type scaling to epochs, forward, and covariance """ Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C') Carray = Carray64 - - + + # get indices for each channel type ch_names = cov['names'] # same as self.fwd['info']['ch_names'] sel_eeg = pick_types(fwd['info'], meg=False, eeg=True, ref_meg=False) @@ -178,8 +152,24 @@ def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., data[:,idx_mag,:] *= mag_scale data[:,idx_grad,:] *= grad_scale + data_mne = epochs.get_data().copy() + std = dict(grad=1. / grad_scale, mag=1. / mag_scale, eeg=1. / eeg_scale) + noproj_info = info.copy() + with noproj_info._unlock(): + noproj_info['projs'] = [] + rescale_cov = mne.make_ad_hoc_cov(noproj_info, std=std) + scaler, ch_names = mne.cov.compute_whitener(rescale_cov, noproj_info) + np.testing.assert_array_equal(np.diag(np.diag(scaler)), scaler) + assert ch_names == info['ch_names'] + data_mne = scaler @ data_mne + assert len(ch_names) == data_mne.shape[1] + for ii, ch_name in enumerate(ch_names): + np.testing.assert_allclose( + data_mne[:, ii].ravel(), data[:, ii].ravel(), + atol=1e-3, rtol=1e-5, err_msg=ch_name) + epochs = mne.EpochsArray(data, info) - + return G, GL, Q, epochs @@ -192,13 +182,13 @@ def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', raise ValueError("dim_mode must be in {'rank', 'pctvar', 'whiten'}") print("running pca for subject %s" % subject_name) - - scales = {'eeg_scale' : 1e8, 'mag_scale' : 1e16, 'grad_scale' : 1e14} - + + scales = {'eeg_scale' : 1e8, 'mag_scale' : 1e16, 'grad_scale' : 1e14} + # compute ROI-to-source map - roi_to_src = ROIToSourceMap(fwd, labels, label_flip) - - + roi_to_src = ROIToSourceMap(fwd, labels, label_flip) + + if dim_mode == 'whiten': fwd_src_snsr, fwd_roi_snsr, cov_snsr, epochs = \ @@ -218,16 +208,10 @@ def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', (subject_name, W.shape[0])) else: - - scaled_data = _scale_sensor_data(epochs, fwd, cov, roi_to_src, **scales) - mne_scaled_data = _mne_scale_sensor_data(epochs, fwd, cov, roi_to_src, **scales) - - fwd_src_snsr, fwd_roi_snsr, cov_snsr, epochs = scaled_data#\ - # _scale_sensor_data(epochs, fwd, cov, roi_to_src, **scales) - - # for i in range(len(scaled_data)): - # np.testing.assert_allclose(scaled_data[i], mne_scaled_data[i], atol=1e-3) - + + fwd_src_snsr, 
fwd_roi_snsr, cov_snsr, epochs = _scale_sensor_data( + epochs, fwd, cov, roi_to_src, **scales) + dat = epochs.get_data().copy() dat = Carray(np.swapaxes(dat, -1, -2)) @@ -241,24 +225,24 @@ def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', if dim_mode == 'rank': idx = np.linalg.matrix_rank(np.cov(dat_stacked, rowvar=False)) else: - idx = np.where(np.cumsum(pca.explained_variance_ratio_) > + idx = np.where(np.cumsum(pca.explained_variance_ratio_) > pctvar)[0][0] idx = np.maximum(idx, len(labels)) W = pca.components_[:idx] print("subject %s using %d principal components" % (subject_name, idx)) - + ntrials, T, _ = dat.shape dat_pca = np.dot(dat_stacked, W.T) dat_pca = np.reshape(dat_pca, (ntrials, T, -1)) fwd_src_snsr_pca = np.dot(W, fwd_src_snsr) fwd_roi_snsr_pca = np.dot(W, fwd_roi_snsr) - cov_snsr_pca = np.dot(W,np.dot(cov_snsr, W.T)) + cov_snsr_pca = np.dot(W,np.dot(cov_snsr, W.T)) data = dat_pca - return (data, fwd_roi_snsr_pca, fwd_src_snsr_pca, cov_snsr_pca, + return (data, fwd_roi_snsr_pca, fwd_src_snsr_pca, cov_snsr_pca, roi_to_src.which_roi) From 39c1d1b217e5af882f995ad33f0bf10bc82ef264 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 30 Aug 2022 14:37:03 -0400 Subject: [PATCH 15/17] FIX: Better scaler (but still wrong) --- state_space/megssm/mne_util.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/state_space/megssm/mne_util.py b/state_space/megssm/mne_util.py index c584d676..6bae2cec 100644 --- a/state_space/megssm/mne_util.py +++ b/state_space/megssm/mne_util.py @@ -117,6 +117,15 @@ def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C') Carray = Carray64 + info = epochs.info.copy() + std = dict(grad=1. / grad_scale, mag=1. / mag_scale, eeg=1. / eeg_scale) + noproj_info = info.copy() + with noproj_info._unlock(): + noproj_info['projs'] = [] + rescale_cov = mne.make_ad_hoc_cov(noproj_info, std=std) + scaler, ch_names = mne.cov.compute_whitener(rescale_cov, noproj_info) + np.testing.assert_array_equal(np.diag(np.diag(scaler)), scaler) + assert ch_names == info['ch_names'] # get indices for each channel type ch_names = cov['names'] # same as self.fwd['info']['ch_names'] @@ -132,41 +141,30 @@ def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., Q = cov.data.copy() # scale forward matrix + G = scaler.T @ G G[idx_eeg,:] *= eeg_scale G[idx_mag,:] *= mag_scale G[idx_grad,:] *= grad_scale + np.testing.assert_allclose(G, G_mne) # construct GL matrix GL = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, G.T).T) # scale sensor covariance + Q_mne = scaler.T @ Q # @ scaler Q[np.ix_(idx_eeg, idx_eeg)] *= eeg_scale**2 Q[np.ix_(idx_mag, idx_mag)] *= mag_scale**2 Q[np.ix_(idx_grad, idx_grad)] *= grad_scale**2 + np.testing.assert_allclose(Q, Q_mne) # scale epochs - info = epochs.info.copy() data = epochs.get_data().copy() + data_mne = scaler @ epochs.get_data() data[:,idx_eeg,:] *= eeg_scale data[:,idx_mag,:] *= mag_scale data[:,idx_grad,:] *= grad_scale - - data_mne = epochs.get_data().copy() - std = dict(grad=1. / grad_scale, mag=1. / mag_scale, eeg=1. 
/ eeg_scale) - noproj_info = info.copy() - with noproj_info._unlock(): - noproj_info['projs'] = [] - rescale_cov = mne.make_ad_hoc_cov(noproj_info, std=std) - scaler, ch_names = mne.cov.compute_whitener(rescale_cov, noproj_info) - np.testing.assert_array_equal(np.diag(np.diag(scaler)), scaler) - assert ch_names == info['ch_names'] - data_mne = scaler @ data_mne - assert len(ch_names) == data_mne.shape[1] - for ii, ch_name in enumerate(ch_names): - np.testing.assert_allclose( - data_mne[:, ii].ravel(), data[:, ii].ravel(), - atol=1e-3, rtol=1e-5, err_msg=ch_name) + np.testing.assert_allclose(data_mne, data) epochs = mne.EpochsArray(data, info) From e7ffafa4f6614647eaed370c5ebcff4f0401aaa5 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 30 Aug 2022 14:38:03 -0400 Subject: [PATCH 16/17] FIX: Small --- state_space/megssm/mne_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/state_space/megssm/mne_util.py b/state_space/megssm/mne_util.py index 6bae2cec..24b64dc9 100644 --- a/state_space/megssm/mne_util.py +++ b/state_space/megssm/mne_util.py @@ -141,7 +141,7 @@ def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., Q = cov.data.copy() # scale forward matrix - G = scaler.T @ G + G_mne = scaler.T @ G G[idx_eeg,:] *= eeg_scale G[idx_mag,:] *= mag_scale G[idx_grad,:] *= grad_scale From e63252ae2dd9160db516539cff29cb5c706ed3a8 Mon Sep 17 00:00:00 2001 From: jadrew43 Date: Wed, 7 Sep 2022 15:32:35 -0700 Subject: [PATCH 17/17] Final committ for GSOC - code cleaned and functioning properly --- state_space/megssm/message_passing.py | 540 ++---------------------- state_space/megssm/mne_util.py | 119 ++---- state_space/megssm/models.py | 42 +- state_space/state_space_connectivity.py | 27 +- state_space/test_state_space.py | 58 +++ 5 files changed, 155 insertions(+), 631 deletions(-) create mode 100644 state_space/test_state_space.py diff --git a/state_space/megssm/message_passing.py b/state_space/megssm/message_passing.py index 722cf8e8..6295059a 100755 --- a/state_space/megssm/message_passing.py +++ b/state_space/megssm/message_passing.py @@ -5,124 +5,14 @@ import autograd.numpy as np from autograd.scipy.linalg import block_diag -from .util import T_, sym, dot3, _ensure_ndim, component_matrix, hs +from .util import sym, component_matrix, hs -from scipy.linalg import solve_triangular as _solve_triangular - - -def solve_triangular(L, b, *, lower=True, trans=0): - if hasattr(L, '_value'): # autograd ArrayBox - L = L._value - if hasattr(b, '_value'): - b = b._value - if L.ndim == 3 and b.ndim in (2, 3): - return np.array([ - _solve_triangular(LL, bb, lower=lower, trans=trans) - for LL, bb in zip(L, b)], L.dtype) - elif L.ndim == 2 and b.ndim in (2, 1): - return _solve_triangular(L, b, lower=lower, trans=trans) - raise RuntimeError(f'Unknown shapes {L.shape} and {b.shape}') - - - -# einsum2 is a parallel version of einsum that works for two arguments try: - from einsum2 import einsum2 + from autograd_linalg import solve_triangular except ImportError: - # rename standard numpy function if don't have einsum2 - print("=> WARNING: using standard numpy.einsum,", - "consider installing einsum2 package") - from numpy import einsum as einsum2 - - -def kalman_filter(Y, A, C, Q, R, mu0, Q0, store_St=True, sum_logliks=True): - """ Kalman filter that broadcasts over the first dimension. - Handles multiple lag dependence using component form. - - Note: This function doesn't handle control inputs (yet). 
- - Y : ndarray, shape=(N, T, D) - Observations - - A : ndarray, shape=(T, D*nlag, D*nlag) - Time-varying dynamics matrices - - C : ndarray, shape=(p, D) - Observation matrix - - mu0: ndarray, shape=(D,) - mean of initial state variable - - Q0 : ndarray, shape=(D, D) - Covariance of initial state variable - - Q : ndarray, shape=(D, D) - Covariance of latent states - - R : ndarray, shape=(D, D) - Covariance of observations - """ - - N = Y.shape[0] - T, D, Dnlags = A.shape - nlags = Dnlags // D - AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) - - p = C.shape[0] - CC = hs([C, np.zeros((p, D*(nlags-1)))]) - - QQ = np.zeros((T, Dnlags, Dnlags)) - QQ[:,:D,:D] = Q - - QQ0 = block_diag(*[Q0 for _ in range(nlags)]) - - mu_predict = np.stack([np.tile(mu0, nlags) for _ in range(N)], axis=0) - sigma_predict = np.stack([QQ0 for _ in range(N)], axis=0) - - St = np.empty((N, T, p, p)) if store_St else None - - mus_filt = np.zeros((N, T, Dnlags)) - sigmas_filt = np.zeros((N, T, Dnlags, Dnlags)) - - ll = np.zeros(T) - - for t in range(T): - - # condition - # dot3(CC, sigma_predict, CC.T) + R - tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict) - sigma_pred = np.dot(tmp1, CC.T) + R - sigma_pred = sym(sigma_pred) - - if St is not None: - St[...,t,:,:] = sigma_pred - - res = Y[...,t,:] - np.dot(mu_predict, CC.T) - - L = np.linalg.cholesky(sigma_pred) - v = solve_triangular(L, res, lower=True) - - # log-likelihood over all trials - ll[t] = -0.5*(2.*np.sum(np.log(np.diagonal(L, axis1=1, axis2=2))) - + np.sum(v*v) - + N*p*np.log(2.*np.pi)) - - mus_filt[...,t,:] = mu_predict + einsum2('nki,nk->ni', tmp1, - solve_triangular(L, v, 'T', lower=True)) - - tmp2 = solve_triangular(L, tmp1, lower=True) - sigmas_filt[...,t,:,:] = sym(sigma_predict - einsum2('nki,nkj->nij', tmp2, tmp2)) - - # prediction - mu_predict = einsum2('ik,nk->ni', AA[t], mus_filt[...,t,:]) - - sigma_predict = einsum2('ik,nkl->nil', AA[t], sigmas_filt[...,t,:,:]) - sigma_predict = sym(einsum2('nil,jl->nij', sigma_predict, AA[t]) + QQ[t]) - - if sum_logliks: - ll = np.sum(ll) - return ll, mus_filt, sigmas_filt, St + raise RuntimeError("must install `autograd_linalg` package") +from numpy import einsum def rts_smooth(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False, store_St=True): @@ -187,8 +77,8 @@ def rts_smooth(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False, # condition # sigma_x = dot3(C, sigma_predict, C.T) + R - tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) - sigma_x = einsum2('nik,jk->nij', tmp1, CC) + R + tmp1 = einsum('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) + sigma_x = einsum('nik,jk->nij', tmp1, CC) + R sigma_x = sym(sigma_x) if St is not None: @@ -196,7 +86,7 @@ def rts_smooth(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False, L = np.linalg.cholesky(sigma_x) # res[n] = Y[n,t,:] = np.dot(C, mu_predict[n,t,:]) - res = Y[...,t,:] - einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) + res = Y[...,t,:] - einsum('ik,nk->ni', CC, mu_predict[...,t,:]) v = solve_triangular(L, res, lower=True) # log-likelihood over all trials @@ -204,27 +94,29 @@ def rts_smooth(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False, + np.sum(v*v) + N*p*np.log(2.*np.pi)) - mus_smooth[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', - tmp1, - solve_triangular(L, v, trans='T', lower=True)) + mus_smooth[:,t,:] = mu_predict[:,t,:] + \ + einsum('nki,nk->ni', tmp1, \ + solve_triangular(L, v, trans='T', lower=True)) # tmp2 = L^{-1}*C*sigma_predict tmp2 = solve_triangular(L, tmp1, lower=True) - sigmas_smooth[:,t,:,:] = sym(sigma_predict[:,t,:,:] - 
einsum2('nki,nkj->nij', tmp2, tmp2)) + sigmas_smooth[:,t,:,:] = sym(sigma_predict[:,t,:,:] - \ + einsum('nki,nkj->nij', tmp2, tmp2)) # prediction #mu_predict = np.dot(A[t], mus_smooth[t]) - mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_smooth[:,t,:]) + mu_predict[:,t+1,:] = einsum('ik,nk->ni', AA[t], mus_smooth[:,t,:]) #sigma_predict = dot3(A[t], sigmas_smooth[t], A[t].T) + Q[t] - tmp = einsum2('ik,nkl->nil', AA[t], sigmas_smooth[:,t,:,:]) - sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) + tmp = einsum('ik,nkl->nil', AA[t], sigmas_smooth[:,t,:,:]) + sigma_predict[:,t+1,:,:] = sym(einsum('nil,jl->nij', tmp, AA[t]) + \ + QQ[t]) for t in range(T-2, -1, -1): # these names are stolen from mattjj and slinderman #temp_nn = np.dot(A[t], sigmas_smooth[n,t,:,:]) - temp_nn = einsum2('ik,nkj->nij', AA[t], sigmas_smooth[:,t,:,:]) + temp_nn = einsum('ik,nkj->nij', AA[t], sigmas_smooth[:,t,:,:]) L = np.linalg.cholesky(sigma_predict[:,t+1,:,:]) v = solve_triangular(L, temp_nn, lower=True) @@ -233,18 +125,18 @@ def rts_smooth(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False, # {mus,sigmas}_smooth[n,t] contain the filtered estimates so we're # overwriting them on purpose - #mus_smooth[n,t,:] = mus_smooth[n,t,:] + np.dot(T_(Gt_T), mus_smooth[n,t+1,:] - mu_predict[t+1,:]) - mus_smooth[:,t,:] = mus_smooth[:,t,:] + einsum2('nki,nk->ni', Gt_T, mus_smooth[:,t+1,:] - mu_predict[:,t+1,:]) + mus_smooth[:,t,:] = mus_smooth[:,t,:] + \ + einsum('nki,nk->ni', Gt_T, mus_smooth[:,t+1,:] - mu_predict[:,t+1,:]) - #sigmas_smooth[n,t,:,:] = sigmas_smooth[n,t,:,:] + dot3(T_(Gt_T), sigmas_smooth[n,t+1,:,:] - temp_nn, Gt_T) - tmp = einsum2('nki,nkj->nij', Gt_T, sigmas_smooth[:,t+1,:,:] - sigma_predict[:,t+1,:,:]) - tmp = einsum2('nik,nkj->nij', tmp, Gt_T) + tmp = einsum('nki,nkj->nij', Gt_T, sigmas_smooth[:,t+1,:,:] - \ + sigma_predict[:,t+1,:,:]) + tmp = einsum('nik,nkj->nij', tmp, Gt_T) sigmas_smooth[:,t,:,:] = sym(sigmas_smooth[:,t,:,:] + tmp) if compute_lag1_cov: # This matrix is NOT symmetric, so don't symmetrize! 
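# Editorial aside on the einsum2 -> numpy.einsum swap made throughout this
# patch: every contraction used here is a batched matrix product, which
# numpy's own einsum (and broadcasting matmul) handles directly, so the extra
# dependency can be dropped. A quick equivalence check, shapes illustrative:
import numpy as np
rng = np.random.default_rng(0)
C = rng.standard_normal((5, 3))          # (p, D) observation matrix
S = rng.standard_normal((4, 3, 3))       # (N, D, D), one matrix per trial
out = np.einsum('ik,nkj->nij', C, S)     # batched C @ S[n] over trials n
np.testing.assert_allclose(out, C @ S)   # broadcasting matmul agrees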
- #sigmas_smooth_tnt[n,t,:,:] = np.dot(sigmas_smooth[n,t+1,:,:], Gt_T) - sigmas_smooth_tnt[:,t,:,:] = einsum2('nik,nkj->nij', sigmas_smooth[:,t+1,:,:], Gt_T) + sigmas_smooth_tnt[:,t,:,:] = einsum('nik,nkj->nij', \ + sigmas_smooth[:,t+1,:,:], Gt_T) return ll, mus_smooth, sigmas_smooth, sigmas_smooth_tnt, St @@ -288,7 +180,7 @@ def rts_smooth_fast(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False): CC = hs([C, np.zeros((p, D*(nlags-1)))]) tmp = solve_triangular(L_R, CC, lower=True) Rinv_CC = solve_triangular(L_R, tmp, trans='T', lower=True) - CCT_Rinv_CC = einsum2('ki,kj->ij', CC, Rinv_CC) + CCT_Rinv_CC = einsum('ki,kj->ij', CC, Rinv_CC) # tile L_R across number of trials so solve_triangular # can broadcast over trials properly @@ -319,16 +211,16 @@ def rts_smooth_fast(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False): for t in range(T): # condition - tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) + tmp1 = einsum('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) - res = Y[...,t,:] - einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) + res = Y[...,t,:] - einsum('ik,nk->ni', CC, mu_predict[...,t,:]) # Rinv * res tmp2 = solve_triangular(L_R, res, lower=True) tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) # C^T Rinv * res - tmp3 = einsum2('ki,nk->ni', Rinv_CC, res) + tmp3 = einsum('ki,nk->ni', Rinv_CC, res) # (Pinv + C^T Rinv C)_inv * tmp3 L_P = np.linalg.cholesky(sigma_predict[:,t,:,:]) @@ -340,44 +232,45 @@ def rts_smooth_fast(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False): tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) # Rinv C * tmp3 - tmp3 = einsum2('ik,nk->ni', Rinv_CC, tmp3) + tmp3 = einsum('ik,nk->ni', Rinv_CC, tmp3) # add the two Woodbury * res terms together tmp = tmp2 - tmp3 - mus_smooth[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', tmp1, tmp) + mus_smooth[:,t,:] = mu_predict[:,t,:] + einsum('nki,nk->ni', tmp1, tmp) # Rinv * tmp1 tmp2 = solve_triangular(L_R, tmp1, lower=True) tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) # C^T Rinv * tmp1 - tmp3 = einsum2('ki,nkj->nij', Rinv_CC, tmp1) + tmp3 = einsum('ki,nkj->nij', Rinv_CC, tmp1) # (Pinv + C^T Rinv C)_inv * tmp3 tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) # Rinv C * tmp3 - tmp3 = einsum2('ik,nkj->nij', Rinv_CC, tmp3) + tmp3 = einsum('ik,nkj->nij', Rinv_CC, tmp3) # add the two Woodbury * tmp1 terms together, left-multiply by tmp1 - tmp = einsum2('nki,nkj->nij', tmp1, tmp2 - tmp3) + tmp = einsum('nki,nkj->nij', tmp1, tmp2 - tmp3) sigmas_smooth[:,t,:,:] = sym(sigma_predict[:,t,:,:] - tmp) # prediction - mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_smooth[:,t,:]) + mu_predict[:,t+1,:] = einsum('ik,nk->ni', AA[t], mus_smooth[:,t,:]) #sigma_predict = dot3(A[t], sigmas_smooth[t], A[t].T) + Q[t] - tmp = einsum2('ik,nkl->nil', AA[t], sigmas_smooth[:,t,:,:]) - sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) + tmp = einsum('ik,nkl->nil', AA[t], sigmas_smooth[:,t,:,:]) + sigma_predict[:,t+1,:,:] = sym(einsum('nil,jl->nij', tmp, AA[t]) + \ + QQ[t]) for t in range(T-2, -1, -1): # these names are stolen from mattjj and slinderman #temp_nn = np.dot(A[t], sigmas_smooth[n,t,:,:]) - temp_nn = einsum2('ik,nkj->nij', AA[t], sigmas_smooth[:,t,:,:]) + temp_nn = einsum('ik,nkj->nij', AA[t], sigmas_smooth[:,t,:,:]) L = np.linalg.cholesky(sigma_predict[:,t+1,:,:]) v = solve_triangular(L, temp_nn, lower=True) @@ -386,360 +279,17 @@ def rts_smooth_fast(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False): # 
{mus,sigmas}_smooth[n,t] contain the filtered estimates so we're # overwriting them on purpose - #mus_smooth[n,t,:] = mus_smooth[n,t,:] + np.dot(T_(Gt_T), mus_smooth[n,t+1,:] - mu_predict[t+1,:]) - mus_smooth[:,t,:] = mus_smooth[:,t,:] + einsum2('nki,nk->ni', Gt_T, mus_smooth[:,t+1,:] - mu_predict[:,t+1,:]) + mus_smooth[:,t,:] = mus_smooth[:,t,:] + \ + einsum('nki,nk->ni', Gt_T, mus_smooth[:,t+1,:] - mu_predict[:,t+1,:]) - #sigmas_smooth[n,t,:,:] = sigmas_smooth[n,t,:,:] + dot3(T_(Gt_T), sigmas_smooth[n,t+1,:,:] - temp_nn, Gt_T) - tmp = einsum2('nki,nkj->nij', Gt_T, sigmas_smooth[:,t+1,:,:] - sigma_predict[:,t+1,:,:]) - tmp = einsum2('nik,nkj->nij', tmp, Gt_T) + tmp = einsum('nki,nkj->nij', Gt_T, sigmas_smooth[:,t+1,:,:] - \ + sigma_predict[:,t+1,:,:]) + tmp = einsum('nik,nkj->nij', tmp, Gt_T) sigmas_smooth[:,t,:,:] = sym(sigmas_smooth[:,t,:,:] + tmp) if compute_lag1_cov: # This matrix is NOT symmetric, so don't symmetrize! - #sigmas_smooth_tnt[n,t,:,:] = np.dot(sigmas_smooth[n,t+1,:,:], Gt_T) - sigmas_smooth_tnt[:,t,:,:] = einsum2('nik,nkj->nij', sigmas_smooth[:,t+1,:,:], Gt_T) + sigmas_smooth_tnt[:,t,:,:] = einsum('nik,nkj->nij', \ + sigmas_smooth[:,t+1,:,:], Gt_T) return ll, mus_smooth, sigmas_smooth, sigmas_smooth_tnt - - - -def predict(Y, A, C, Q, R, mu0, Q0, pred_var=False): - if pred_var: - return _predict_mean_var(Y, A, C, Q, R, mu0, Q0) - else: - return _predict_mean(Y, A, C, Q, R, mu0, Q0) - - -def _predict_mean_var(Y, A, C, Q, R, mu0, Q0): - """ Model predictions for Y given model parameters. - - Handles multiple lag dependence using component form. - - Note: This function doesn't handle control inputs (yet). - - Y : ndarray, shape=(N, T, D) - Observations - - A : ndarray, shape=(T, D*nlag, D*nlag) - Time-varying dynamics matrices - - C : ndarray, shape=(p, D) - Observation matrix - - mu0: ndarray, shape=(D,) - mean of initial state variable - - Q0 : ndarray, shape=(D, D) - Covariance of initial state variable - - Q : ndarray, shape=(D, D) - Covariance of latent states - - R : ndarray, shape=(D, D) - Covariance of observations - """ - - N, T, _ = Y.shape - _, D, Dnlags = A.shape - nlags = Dnlags // D - AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) - - L_R = np.linalg.cholesky(R) - - p = C.shape[0] - CC = hs([C, np.zeros((p, D*(nlags-1)))]) - tmp = solve_triangular(L_R, CC, lower=True) - Rinv_CC = solve_triangular(L_R, tmp, trans='T', lower=True) - CCT_Rinv_CC = einsum2('ki,kj->ij', CC, Rinv_CC) - - # tile L_R across number of trials so solve_triangular - # can broadcast over trials properly - L_R = np.tile(L_R, (N, 1, 1)) - - QQ = np.zeros((T, Dnlags, Dnlags)) - QQ[:,:D,:D] = Q - - QQ0 = block_diag(*[Q0 for _ in range(nlags)]) - - mu_predict = np.empty((N, T+1, Dnlags)) - sigma_predict = np.empty((N, T+1, Dnlags, Dnlags)) - - mus_filt = np.empty((N, T, Dnlags)) - sigmas_filt = np.empty((N, T, Dnlags, Dnlags)) - - mu_predict[:,0,:] = np.tile(mu0, nlags) - sigma_predict[:,0,:,:] = QQ0.copy() - - I_tiled = np.tile(np.eye(Dnlags), (N, 1, 1)) - - Yhat = np.empty_like(Y) - St = np.empty((N, T, p, p)) - - for t in range(T): - - # condition - # sigma_x = dot3(C, sigma_predict, C.T) + R - tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) - sigma_x = einsum2('nik,jk->nij', tmp1, CC) + R - sigma_x = sym(sigma_x) - - St[...,t,:,:] = sigma_x - - L = np.linalg.cholesky(sigma_x) - Yhat[...,t,:] = einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) - res = Y[...,t,:] - Yhat[...,t,:] - - v = solve_triangular(L, res, lower=True) - - mus_filt[:,t,:] = mu_predict[:,t,:] + 
einsum2('nki,nk->ni', - tmp1, - solve_triangular(L, v, trans='T', lower=True)) - - # tmp2 = L^{-1}*C*sigma_predict - tmp2 = solve_triangular(L, tmp1, lower=True) - sigmas_filt[:,t,:,:] = sym(sigma_predict[:,t,:,:] - einsum2('nki,nkj->nij', tmp2, tmp2)) - - # prediction - #mu_predict = np.dot(A[t], mus_filt[t]) - mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_filt[:,t,:]) - - #sigma_predict = dot3(A[t], sigmas_filt[t], A[t].T) + Q[t] - tmp = einsum2('ik,nkl->nil', AA[t], sigmas_filt[:,t,:,:]) - sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) - - # just return the diagonal of the St matrices for marginal predictive - # variances - return Yhat, np.diagonal(St, axis1=-2, axis2=-1) - - -def _predict_mean(Y, A, C, Q, R, mu0, Q0): - """ Model predictions for Y given model parameters. - Handles multiple lag dependence using component form. - - Note: This function doesn't handle control inputs (yet). - - Y : ndarray, shape=(N, T, D) - Observations - - A : ndarray, shape=(T, D*nlag, D*nlag) - Time-varying dynamics matrices - - C : ndarray, shape=(p, D) - Observation matrix - - mu0: ndarray, shape=(D,) - mean of initial state variable - - Q0 : ndarray, shape=(D, D) - Covariance of initial state variable - - Q : ndarray, shape=(D, D) - Covariance of latent states - - R : ndarray, shape=(D, D) - Covariance of observations - """ - - N, T, _ = Y.shape - _, D, Dnlags = A.shape - nlags = Dnlags // D - AA = np.stack([component_matrix(At, nlags) for At in A], axis=0) - - L_R = np.linalg.cholesky(R) - - p = C.shape[0] - CC = hs([C, np.zeros((p, D*(nlags-1)))]) - tmp = solve_triangular(L_R, CC, lower=True) - Rinv_CC = solve_triangular(L_R, tmp, trans='T', lower=True) - CCT_Rinv_CC = einsum2('ki,kj->ij', CC, Rinv_CC) - - # tile L_R across number of trials so solve_triangular - # can broadcast over trials properly - L_R = np.tile(L_R, (N, 1, 1)) - - QQ = np.zeros((T, Dnlags, Dnlags)) - QQ[:,:D,:D] = Q - - QQ0 = block_diag(*[Q0 for _ in range(nlags)]) - - mu_predict = np.empty((N, T+1, Dnlags)) - sigma_predict = np.empty((N, T+1, Dnlags, Dnlags)) - - mus_filt = np.empty((N, T, Dnlags)) - sigmas_filt = np.empty((N, T, Dnlags, Dnlags)) - - mu_predict[:,0,:] = np.tile(mu0, nlags) - sigma_predict[:,0,:,:] = QQ0.copy() - - I_tiled = np.tile(np.eye(Dnlags), (N, 1, 1)) - - Yhat = np.empty_like(Y) - - for t in range(T): - - # condition - tmp1 = einsum2('ik,nkj->nij', CC, sigma_predict[:,t,:,:]) - - Yhat[...,t,:] = einsum2('ik,nk->ni', CC, mu_predict[...,t,:]) - res = Y[...,t,:] - Yhat[...,t,:] - - # Rinv * res - tmp2 = solve_triangular(L_R, res, lower=True) - tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) - - # C^T Rinv * res - tmp3 = einsum2('ki,nk->ni', Rinv_CC, res) - - # (Pinv + C^T Rinv C)_inv * tmp3 - L_P = np.linalg.cholesky(sigma_predict[:,t,:,:]) - tmp = solve_triangular(L_P, I_tiled, lower=True) - Pinv = solve_triangular(L_P, tmp, trans='T', lower=True) - tmp4 = sym(Pinv + CCT_Rinv_CC) - L_tmp4 = np.linalg.cholesky(tmp4) - tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) - tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) - - # Rinv C * tmp3 - tmp3 = einsum2('ik,nk->ni', Rinv_CC, tmp3) - - # add the two Woodbury * res terms together - tmp = tmp2 - tmp3 - - mus_filt[:,t,:] = mu_predict[:,t,:] + einsum2('nki,nk->ni', tmp1, tmp) - - # Rinv * tmp1 - tmp2 = solve_triangular(L_R, tmp1, lower=True) - tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True) - - # C^T Rinv * tmp1 - tmp3 = einsum2('ki,nkj->nij', Rinv_CC, tmp1) - - # (Pinv + C^T Rinv C)_inv * tmp3 
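# Editorial note: the pair of R-solves above plus this inner solve implement
# the Woodbury identity
#   (C P C^T + R)^{-1} = R^{-1} - R^{-1} C (P^{-1} + C^T R^{-1} C)^{-1} C^T R^{-1}
# so the innovation covariance is inverted at the latent dimension D*nlags
# rather than the (typically much larger) sensor dimension p; rts_smooth_fast,
# which survives this patch, keeps the same trick.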
- tmp3 = solve_triangular(L_tmp4, tmp3, lower=True) - tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True) - - # Rinv C * tmp3 - tmp3 = einsum2('ik,nkj->nij', Rinv_CC, tmp3) - - # add the two Woodbury * tmp1 terms together, left-multiply by tmp1 - tmp = einsum2('nki,nkj->nij', tmp1, tmp2 - tmp3) - - sigmas_filt[:,t,:,:] = sym(sigma_predict[:,t,:,:] - tmp) - - # prediction - mu_predict[:,t+1,:] = einsum2('ik,nk->ni', AA[t], mus_filt[:,t,:]) - - #sigma_predict = dot3(A[t], sigmas_filt[t], A[t].T) + Q[t] - tmp = einsum2('ik,nkl->nil', AA[t], sigmas_filt[:,t,:,:]) - sigma_predict[:,t+1,:,:] = sym(einsum2('nil,jl->nij', tmp, AA[t]) + QQ[t]) - - return Yhat - - -def predict_step(mu_filt, sigma_filt, A, Q): - mu_predict = einsum2('ik,nk->ni', A, mu_filt) - tmp = einsum2('ik,nkl->nil', A, sigma_filt) - sigma_predict = sym(einsum2('nil,jl->nij', tmp, A) + Q) - - return mu_predict, sigma_predict - - -def condition(y, C, R, mu_predict, sigma_predict): - # dot3(C, sigma_predict, C.T) + R - tmp1 = einsum2('ik,nkj->nij', C, sigma_predict) - sigma_pred = einsum2('nik,jk->nij', tmp1, C) + R - sigma_pred = sym(sigma_pred) - - L = np.linalg.cholesky(sigma_pred) - # the transpose works b/c of how dot broadcasts - #y_hat = np.dot(mu_predict, C.T) - y_hat = einsum2('ik,nk->ni', C, mu_predict) - res = y - y_hat - v = solve_triangular(L, res, lower=True) - - mu_filt = mu_predict + einsum2('nki,nk->ni', tmp1, solve_triangular(L, v, trans='T', lower=True)) - - tmp2 = solve_triangular(L, tmp1, lower=True) - sigma_filt = sym(sigma_predict - einsum2('nki,nkj->nij', tmp2, tmp2)) - - return y_hat, mu_filt, sigma_filt - - -def logZ(Y, A, C, Q, R, mu0, Q0): - """ Log marginal likelihood using the Kalman filter. - - The algorithm broadcasts over the first dimension which are considered - to be independent realizations. - - Note: This function doesn't handle control inputs (yet). - - Y : ndarray, shape=(N, T, D) - Observations - - A : ndarray, shape=(T, D, D) - Time-varying dynamics matrices - - C : ndarray, shape=(p, D) - Observation matrix - - mu0: ndarray, shape=(D,) - mean of initial state variable - - Q0 : ndarray, shape=(D, D) - Covariance of initial state variable - - Q : ndarray, shape=(D, D) - Covariance of latent states - - R : ndarray, shape=(D, D) - Covariance of observations - """ - - N = Y.shape[0] - T, D, _ = A.shape - p = C.shape[0] - - mu_predict = np.stack([np.copy(mu0) for _ in range(N)], axis=0) - sigma_predict = np.stack([np.copy(Q0) for _ in range(N)], axis=0) - - ll = 0. 
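# Editorial note: the loop below (removed with the rest of logZ in this
# patch) accumulates the Kalman-filter decomposition of the log marginal
# likelihood. With innovation covariance S = L L^T and v = L^{-1} res,
#   log N(res; 0, S) = -0.5 * (2*sum(log(diag(L))) + v.T @ v + p*log(2*pi)),
# summed over the T steps and the N independent trials.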
- - for t in range(T): - - # condition - # sigma_x = dot3(C, sigma_predict, C.T) + R - tmp1 = einsum2('ik,nkj->nij', C, sigma_predict) - sigma_x = einsum2('nik,jk->nij', tmp1, C) + R - sigma_x = sym(sigma_x) - - # res[n] = Y[n,t,:] = np.dot(C, mu_predict[n]) - res = Y[...,t,:] - einsum2('ik,nk->ni', C, mu_predict) - - L = np.linalg.cholesky(sigma_x) - v = solve_triangular(L, res, lower=True) - - # log-likelihood over all trials - ll += -0.5*(2.*np.sum(np.log(np.diagonal(L, axis1=1, axis2=2))) - + np.sum(v*v) - + N*p*np.log(2.*np.pi)) - - mus_filt = mu_predict + einsum2('nki,nk->ni', - tmp1, - solve_triangular(L, v, trans='T', lower=True)) - - # tmp2 = L^{-1}*C*sigma_predict - tmp2 = solve_triangular(L, tmp1, lower=True) - sigmas_filt = sigma_predict - einsum2('nki,nkj->nij', tmp2, tmp2) - sigmas_filt = sym(sigmas_filt) - - # prediction - #mu_predict = np.dot(A[t], mus_filt[t]) - mu_predict = einsum2('ik,nk->ni', A[t], mus_filt) - - # originally this worked with time-varying Q, but now it's fixed - #sigma_predict = dot3(A[t], sigmas_filt[t], A[t].T) + Q[t] - sigma_predict = einsum2('ik,nkl->nil', A[t], sigmas_filt) - sigma_predict = einsum2('nil,jl->nij', sigma_predict, A[t]) + Q - sigma_predict = sym(sigma_predict) - - return np.sum(ll) diff --git a/state_space/megssm/mne_util.py b/state_space/megssm/mne_util.py index c584d676..61b09165 100644 --- a/state_space/megssm/mne_util.py +++ b/state_space/megssm/mne_util.py @@ -3,18 +3,11 @@ import mne import numpy as np -import os.path as op - -from mne.io.pick import pick_types -from mne.utils import logger from mne import label_sign_flip - -from scipy.sparse import csc_matrix, csr_matrix, diags +from scipy.sparse import csc_matrix, csr_matrix from sklearn.decomposition import PCA -Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C') -Carray = Carray64 - +Carray = lambda X: np.require(X, dtype=np.float64, requirements='C') class ROIToSourceMap(object): """ class for computing ROI-to-source space mapping matrix @@ -41,9 +34,7 @@ def __init__(self, fwd, labels, label_flip=False): n_rhverts = len(src[1]['vertno']) n_verts = n_lhverts + n_rhverts offsets = {'lh': 0, 'rh': n_lhverts} - - hemis = {'lh': 0, 'rh': 1} - + # index vector of which ROI a source point belongs to which_roi = np.zeros(n_verts, dtype=np.int64) @@ -114,45 +105,12 @@ def apply_projs(epochs, fwd, cov): def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., mag_scale=1., grad_scale=1.): """ apply per-channel-type scaling to epochs, forward, and covariance """ - Carray64 = lambda X: np.require(X, dtype=np.float64, requirements='C') - Carray = Carray64 - - + # get indices for each channel type - ch_names = cov['names'] # same as self.fwd['info']['ch_names'] - sel_eeg = pick_types(fwd['info'], meg=False, eeg=True, ref_meg=False) - sel_mag = pick_types(fwd['info'], meg='mag', eeg=False, ref_meg=False) - sel_grad = pick_types(fwd['info'], meg='grad', eeg=False, ref_meg=False) - idx_eeg = [ch_names.index(ch_names[c]) for c in sel_eeg] - idx_mag = [ch_names.index(ch_names[c]) for c in sel_mag] - idx_grad = [ch_names.index(ch_names[c]) for c in sel_grad] - - # retrieve forward and sensor covariance - G = fwd['sol']['data'].copy() - Q = cov.data.copy() - - # scale forward matrix - G[idx_eeg,:] *= eeg_scale - G[idx_mag,:] *= mag_scale - G[idx_grad,:] *= grad_scale - - # construct GL matrix - GL = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, G.T).T) - - # scale sensor covariance - Q[np.ix_(idx_eeg, idx_eeg)] *= eeg_scale**2 - Q[np.ix_(idx_mag, idx_mag)] *= 
mag_scale**2 - Q[np.ix_(idx_grad, idx_grad)] *= grad_scale**2 - - # scale epochs + ch_names = cov['names'] + + # build scaler info = epochs.info.copy() - data = epochs.get_data().copy() - - data[:,idx_eeg,:] *= eeg_scale - data[:,idx_mag,:] *= mag_scale - data[:,idx_grad,:] *= grad_scale - - data_mne = epochs.get_data().copy() std = dict(grad=1. / grad_scale, mag=1. / mag_scale, eeg=1. / eeg_scale) noproj_info = info.copy() with noproj_info._unlock(): @@ -161,16 +119,26 @@ def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., scaler, ch_names = mne.cov.compute_whitener(rescale_cov, noproj_info) np.testing.assert_array_equal(np.diag(np.diag(scaler)), scaler) assert ch_names == info['ch_names'] - data_mne = scaler @ data_mne - assert len(ch_names) == data_mne.shape[1] - for ii, ch_name in enumerate(ch_names): - np.testing.assert_allclose( - data_mne[:, ii].ravel(), data[:, ii].ravel(), - atol=1e-3, rtol=1e-5, err_msg=ch_name) + + # retrieve forward and sensor covariance + fwd_src_snsr = fwd['sol']['data'].copy() + roi_cov = cov.data.copy() + + # scale forward matrix + fwd_src_snsr = scaler @ fwd_src_snsr + + # construct fwd_roi_snsr matrix + fwd_roi_snsr = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T, fwd_src_snsr.T).T) + # scale sensor covariance + roi_cov = scaler.T @ roi_cov @ scaler + + # scale epochs + data = epochs.get_data().copy() + data = scaler.T @ data epochs = mne.EpochsArray(data, info) - return G, GL, Q, epochs + return fwd_src_snsr, fwd_roi_snsr, roi_cov, epochs def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', @@ -188,7 +156,6 @@ def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', # compute ROI-to-source map roi_to_src = ROIToSourceMap(fwd, labels, label_flip) - if dim_mode == 'whiten': fwd_src_snsr, fwd_roi_snsr, cov_snsr, epochs = \ @@ -244,41 +211,3 @@ def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank', return (data, fwd_roi_snsr_pca, fwd_src_snsr_pca, cov_snsr_pca, roi_to_src.which_roi) - - -def combine_medial_labels(labels, subject='fsaverage', surf='white', - dist_limit=0.02): - """ combine each hemi pair of labels on medial wall into single label """ - subjects_dir = mne.get_config('SUBJECTS_DIR') - rrs = dict((hemi, mne.read_surface(op.join(subjects_dir, subject, 'surf', - '%s.%s' % (hemi, surf)))[0] / 1000.) 
- for hemi in ('lh', 'rh')) - use_labels = list() - used = np.zeros(len(labels), bool) - - logger.info('Matching medial regions for %s labels on %s %s, d=%0.1f mm' - % (len(labels), subject, surf, 1000 * dist_limit)) - - for li1, l1 in enumerate(labels): - if used[li1]: - continue - used[li1] = True - use_label = l1.copy() - rr1 = rrs[l1.hemi][l1.vertices] - for li2 in np.where(~used)[0]: - l2 = labels[li2] - same_name = (l2.name.replace(l2.hemi, '') == - l1.name.replace(l1.hemi, '')) - if l2.hemi != l1.hemi and same_name: - rr2 = rrs[l2.hemi][l2.vertices] - mean_min = np.mean(mne.surface._compute_nearest( - rr1, rr2, return_dists=True)[1]) - if mean_min <= dist_limit: - use_label += l2 - used[li2] = True - logger.info(' Matched: ' + l1.name) - use_labels.append(use_label) - - logger.info('Total %d labels' % (len(use_labels),)) - - return use_labels \ No newline at end of file diff --git a/state_space/megssm/models.py b/state_space/megssm/models.py index b7504ef2..f6ab8038 100755 --- a/state_space/megssm/models.py +++ b/state_space/megssm/models.py @@ -1,5 +1,4 @@ import sys -import os import mne import autograd.numpy as np @@ -12,19 +11,22 @@ from .util import _ensure_ndim, rand_stable, rand_psd from .util import linesearch, soft_thresh_At, block_thresh_At from .util import relnormdiff -from .message_passing import kalman_filter, rts_smooth, rts_smooth_fast -from .message_passing import predict_step, condition, solve_triangular +from .message_passing import rts_smooth, rts_smooth_fast from .numpy_numthreads import numpy_num_threads -from .mne_util import (ROIToSourceMap, _scale_sensor_data, run_pca_on_subject, - apply_projs) +from .mne_util import run_pca_on_subject, apply_projs +try: + from autograd_linalg import solve_triangular +except ImportError: + raise RuntimeError("must install `autograd_linalg` package") + from autograd.numpy import einsum from datetime import datetime -class _MEGModel(object): +class _Model(object): """ Base class for any model applied to MEG data that handles storing and unpacking data from tuples. """ @@ -67,7 +69,7 @@ def unpack_subject_data(cls, sdata): return Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi -class MEGLDS(_MEGModel): +class LDS(_Model): """ State-space model for MEG data, as described in "A state-space model of cross-region dynamic connectivity in MEG/EEG", Yang et al., NIPS 2016. 
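    Editor's summary, inferred from the implementation below: the latent ROI
    time courses follow a time-varying linear dynamical system,
        x_t = A_t x_{t-1} + w_t,   w_t ~ N(0, roi_cov),
    observed through the ROI-to-sensor forward matrix with structured noise,
        y_t = fwd_roi_snsr x_t + v_t,   v_t ~ N(0, R),
        R = snsr_cov + fwd_src_snsr diag(sigsq[which_roi]) fwd_src_snsr.T,
    and the parameters are estimated by EM with penalties lam0/lam1 on A_t.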
""" @@ -99,7 +101,7 @@ def bootstrap_subject(self, epochs, subject_name, seed=8675309, sfreq=100, overwrite=False, validation_set=True): datasets = ['train', 'validation'] - use_erm = eq = False + # use_erm = eq = False independent = False if g_nsamples == 0: print('nsamples == 0, ensuring independence of samples') @@ -108,7 +110,7 @@ def bootstrap_subject(self, epochs, subject_name, seed=8675309, sfreq=100, print("using half of trials per sample") elif g_nsamples == -2: print("using empty room noise at half of trials per sample") - use_erm = True + # use_erm = True elif g_nsamples == -3: print("using independent and trial-count equalized samples") eq = True @@ -144,7 +146,7 @@ def bootstrap_subject(self, epochs, subject_name, seed=8675309, sfreq=100, for dataset in datasets: print(' generating ', dataset, ' set') - datadir = './data' + # datadir = './data' condition_map = {'auditory_left':['auditory_left'], 'auditory_right': ['auditory_right'], @@ -247,10 +249,10 @@ def add_subject(self, subject,condition,epochs,labels,fwd, cov): epochs_bs = self.bootstrap_subject(epochs, subject) - epochs_bs = epochs_bs.crop(tmin=-0.2, tmax=0.7) epochs_bs = epochs_bs[condition] epochs = epochs_bs + # ensure cov and fwd use correct channels cov = cov.pick_channels(epochs.ch_names, ordered=True) fwd = mne.convert_forward_solution(fwd, force_fixed=True) fwd = fwd.pick_channels(epochs.ch_names, ordered=True) @@ -266,7 +268,8 @@ def add_subject(self, subject,condition,epochs,labels,fwd, if len(epochs.times) != self._n_times: raise ValueError(f'Number of time points ({len(epochs.times)})' / 'does not match original count ({self._n_times})') - + + # scale cov matrix according to number of bootstraps cov_scale = 3 # equal to number of bootstrap trials cov['data'] /= cov_scale fwd, cov = apply_projs(epochs_bs, fwd, cov) @@ -274,7 +277,6 @@ def add_subject(self, subject,condition,epochs,labels,fwd, sdata = run_pca_on_subject(subject, epochs_bs, fwd, cov, labels, dim_mode='pctvar', mean_center=True) data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata - subjectdata = (data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi) self._all_subject_data.append(subjectdata) @@ -352,7 +354,7 @@ def _em_objective(self): ntrials, n_timepts, _ = Y.shape sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) + R = LDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) L_R = np.linalg.cholesky(R) if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None @@ -471,7 +473,6 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, self.roi_cov_0 = np.diag(np.diag(self.roi_cov_0)) self.roi_cov = np.diag(np.diag(self.roi_cov)) - # keeping track of objective value and best parameters objvals = np.zeros(niter+1) converged = False @@ -626,7 +627,7 @@ def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, for s, sdata in enumerate(self.unpack_all_subject_data()): Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) + R = LDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) roi_cov_t = _ensure_ndim(self.roi_cov, self._n_timepts, 3) with numpy_num_threads(1): loglik_subject, mus_smooth, _, _, St = \ @@ -680,8 +681,7 @@ def _e_step(self, verbose=0): Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata sigsq_vals = np.exp(self.log_sigsq_lst[s]) - R = MEGLDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) - L_R = np.linalg.cholesky(R) + R = 
LDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) roi_cov_t = _ensure_ndim(self.roi_cov, self._n_timepts, 3) with numpy_num_threads(1): @@ -823,13 +823,11 @@ def update_log_sigsq_lst(self, verbose=0): Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata ntrials, n_timepts, _ = Y.shape - mus_smooth = self._mus_smooth_lst[s] - sigmas_smooth = self._sigmas_smooth_lst[s] B4 = self._B4[s] log_sigsq = self.log_sigsq_lst[s].copy() log_sigsq_obj = lambda x: \ - MEGLDS.L3_obj(x, snsr_cov, fwd_src_snsr, which_roi, B4, + LDS.L3_obj(x, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, n_timepts) log_sigsq_val_and_grad = vgrad(log_sigsq_obj) @@ -878,7 +876,7 @@ def L2_obj(self, At, L_roi_cov): @staticmethod def L3_obj(log_sigsq_vals, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, n_timepts): - R = MEGLDS.R_(snsr_cov, fwd_src_snsr, np.exp(log_sigsq_vals), + R = LDS.R_(snsr_cov, fwd_src_snsr, np.exp(log_sigsq_vals), which_roi) try: L_R = np.linalg.cholesky(R) diff --git a/state_space/state_space_connectivity.py b/state_space/state_space_connectivity.py index 577049c0..2c464258 100644 --- a/state_space/state_space_connectivity.py +++ b/state_space/state_space_connectivity.py @@ -6,7 +6,7 @@ """ ''' -For 'mne-connectivity/examples/' to show usage of LDS +For 'mne-connectivity' examples to show usage of LDS Use MNE-sample-data for auditory/left ''' @@ -16,15 +16,12 @@ import matplotlib as mpl -from megssm.models import MEGLDS as LDS +from megssm.models import LDS from megssm.plotting import plot_A_t_ - # define paths to sample data -path = None -data_path = mne.datasets.sample.data_path(path=path) +data_path = mne.datasets.sample.data_path() sample_folder = data_path / 'MEG/sample' -subjects_dir = data_path / 'subjects' ## import raw data and find events raw_fname = sample_folder / 'sample_audvis_raw.fif' @@ -47,16 +44,9 @@ cov = mne.read_cov(cov_fname) ## read labels for analysis -regexp = '^(G_temp_sup-G_T_transv.*|Pole_occipital)' -labels = mne.read_labels_from_annot( - 'sample', 'aparc.a2009s', regexp=regexp, subjects_dir=subjects_dir) -label_names = [label.name for label in labels] -assert len(label_names) == 4 -# brain = mne.viz.Brain('sample', surf='inflated', subjects_dir=subjects_dir) -# for label in labels: -# brain.add_label(label) -# raise RuntimeError - +label_names = ['Aud-lh', 'Aud-rh', 'Vis-lh', 'Vis-rh'] +labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label', + subject='sample') for label in label_names] # initiate model model = LDS(lam0=0, lam1=100) @@ -69,13 +59,12 @@ times = model.times A_t_ = model.A_t_ assert A_t_.shape == (n_timepts, num_roi, num_roi) - with mpl.rc_context(): {'xtick.labelsize': 'x-small', 'ytick.labelsize': 'x-small'} fig, ax = plt.subplots(num_roi, num_roi, constrained_layout=True, - squeeze=False, figsize=(12, 10)) + squeeze=False, figsize=(12, 10)) plot_A_t_(A_t_, labels=label_names, times=times, ax=ax) - fig.suptitle('testing_') + fig.suptitle('API output_new Q scale_') diag_lims = [0, 1] off_lims = [-0.6, 0.6] for ri, row in enumerate(ax): diff --git a/state_space/test_state_space.py b/state_space/test_state_space.py new file mode 100644 index 00000000..ce206c26 --- /dev/null +++ b/state_space/test_state_space.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Authors: Jordan Drew + +""" + +''' +Test script to ensure LDS API is functioning properly +''' + +import pickle +import mne +from megssm.models import LDS +import numpy as np + +def test_state_space_output(): + + # define paths to sample data + data_path 
= mne.datasets.sample.data_path() + sample_folder = data_path / 'MEG/sample' + + ## import raw data and find events + raw_fname = sample_folder / 'sample_audvis_raw.fif' + raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60) + events = mne.find_events(raw, stim_channel='STI 014') + + ## define epochs using event_dict + event_dict = {'auditory_left': 1, 'auditory_right': 2, 'visual_left': 3, + 'visual_right': 4} + epochs = mne.Epochs(raw, events, tmin=-0.2, tmax=0.7, event_id=event_dict, + preload=True).pick_types(meg=True,eeg=True) + condition = 'auditory_left' + + ## read forward solution, remove bad channels + fwd_fname = sample_folder / 'sample_audvis-meg-eeg-oct-6-fwd.fif' + fwd = mne.read_forward_solution(fwd_fname) + + ## read in covariance + cov_fname = sample_folder / 'sample_audvis-cov.fif' + cov = mne.read_cov(cov_fname) + + ## read labels for analysis + label_names = ['Aud-lh', 'Aud-rh', 'Vis-lh', 'Vis-rh'] + labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label', + subject='sample') for label in label_names] + + # initiate model + model = LDS(lam0=0, lam1=100) + model.add_subject('sample', condition, epochs, labels, fwd, cov) + model.fit(niter=50, verbose=2) + + with open('sample A_t', 'rb') as f: + A_t_ = pickle.load(f) + np.testing.assert_allclose(A_t_, model.A_t_) + print('Model is working!') + +test_state_space_output() \ No newline at end of file
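Appendix (editorial): the per-channel-type scaler that patches 14-16 converge
on can be reproduced outside the class. A minimal sketch, assuming the MNE
sample data is installed and mirroring the eeg/mag/grad scales hard-coded in
run_pca_on_subject; everything here is illustrative, not part of the patch
series:

import numpy as np
import mne

data_path = mne.datasets.sample.data_path()
raw = mne.io.read_raw_fif(data_path / 'MEG/sample/sample_audvis_raw.fif')
raw.pick_types(meg=True, eeg=True)
info = raw.info.copy()

# per-channel-type gains expressed as the reciprocal std of an ad hoc cov
std = dict(eeg=1. / 1e8, mag=1. / 1e16, grad=1. / 1e14)
with info._unlock():              # drop projectors so the whitener is diagonal
    info['projs'] = []
rescale_cov = mne.make_ad_hoc_cov(info, std=std)
scaler, ch_names = mne.cov.compute_whitener(rescale_cov, info)

# diagonal whitener: entry i is just the scale for channel i's type, so
# scaler @ data multiplies EEG rows by 1e8, magnetometers by 1e16, etc.
np.testing.assert_array_equal(np.diag(np.diag(scaler)), scaler)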