From de35d1fa4c84c9e38e2238898505f6a5dc3fed45 Mon Sep 17 00:00:00 2001 From: Apolline Mellot Date: Wed, 9 Aug 2023 12:10:18 +0200 Subject: [PATCH] patched function --- coffeine/tests/test_transfer_learning.py | 8 +-- coffeine/transfer_learning.py | 86 +++++++++++++++++++++--- 2 files changed, 81 insertions(+), 13 deletions(-) diff --git a/coffeine/tests/test_transfer_learning.py b/coffeine/tests/test_transfer_learning.py index d2b1ecc..2afc008 100644 --- a/coffeine/tests/test_transfer_learning.py +++ b/coffeine/tests/test_transfer_learning.py @@ -51,12 +51,12 @@ def test_rescale(): X_test_str = str.transform(X_test) # Test if dispersion = 1 M_train = mean_covariance(X_train_str, metric='riemann') - disp_train = np.sum( + disp_train = np.mean( distance(X_train_str, M_train, metric='riemann')**2 - ) / X_train_str.shape[0] + ) assert np.isclose(disp_train, 1.0) M_test = mean_covariance(X_test_str, metric='riemann') - disp_test = np.sum( + disp_test = np.mean( distance(X_test_str, M_test, metric='riemann')**2 - ) / X_test_str.shape[0] + ) assert np.isclose(disp_test, 1.0) diff --git a/coffeine/transfer_learning.py b/coffeine/transfer_learning.py index 89f3f62..5a7bdea 100644 --- a/coffeine/transfer_learning.py +++ b/coffeine/transfer_learning.py @@ -1,4 +1,8 @@ import numpy as np +import pyriemann +from pyriemann.transfer._tools import decode_domains +from pyriemann.utils.mean import mean_riemann +from pyriemann.utils.distance import distance from pyriemann.transfer import TLCenter, TLStretch, encode_domains from sklearn.base import BaseEstimator, TransformerMixin @@ -21,6 +25,54 @@ def _check_data(X): return out +class TLStretch_patch(TLStretch): + """Patched function of TLStretch. + + To use in ReScale when pyRiemann version is lower than 0.6""" + + def __init__(self, target_domain, final_dispersion=1.0, + centered_data=False, metric='riemann'): + super().__init__(target_domain, final_dispersion, + centered_data, metric) + + def fit(self, X, y_enc): + """Fit TLStretch_patch. + + Calculate the dispersion around the mean for each domain. + + Parameters + ---------- + X : ndarray, shape (n_matrices, n_channels, n_channels) + Set of SPD matrices. + y_enc : ndarray, shape (n_matrices,) + Extended labels for each matrix. + + Returns + ------- + self : TLStretch_patch instance + The TLStretch_patch instance. + """ + + _, _, domains = decode_domains(X, y_enc) + n_dim = X[0].shape[1] + self._means = {} + self.dispersions_ = {} + for d in np.unique(domains): + if self.centered_data: + self._means[d] = np.eye(n_dim) + else: + self._means[d] = mean_riemann(X[domains == d]) + disp_domain = distance( + X[domains == d], + self._means[d], + metric=self.metric, + squared=True, + ).mean() + self.dispersions_[d] = disp_domain + + return self + + class ReCenter(BaseEstimator, TransformerMixin): """Re-center each dataset seperately for transfer learning. @@ -145,9 +197,14 @@ def fit(self, X, y): """ X = _check_data(X) _, y_enc = encode_domains(X, y, self.domains) - self.re_scale_ = TLStretch('target_domain', - centered_data=False, - metric=self.metric) + if pyriemann.__version__ != '0.6': + self.re_scale_ = TLStretch_patch( + 'target_domain', centered_data=False, metric=self.metric + ) + else: + self.re_scale_ = TLStretch( + 'target_domain', centered_data=False, metric=self.metric + ) self.dispersions_ = self.re_scale_.fit(X, y_enc).dispersions_ return self @@ -172,9 +229,14 @@ def transform(self, X, y=None): X = _check_data(X) n_sample = X.shape[0] _, y_enc = encode_domains(X, [0]*n_sample, ['target_domain']*n_sample) - self.re_scale_ = TLStretch('target_domain', - centered_data=False, - metric=self.metric) + if pyriemann.__version__ != '0.6': + self.re_scale_ = TLStretch_patch( + 'target_domain', centered_data=False, metric=self.metric + ) + else: + self.re_scale_ = TLStretch( + 'target_domain', centered_data=False, metric=self.metric + ) self.re_scale_.fit(X, y_enc) X_str = self.re_scale_.transform(X) return X_str @@ -198,8 +260,14 @@ def fit_transform(self, X, y): """ X = _check_data(X) _, y_enc = encode_domains(X, y, self.domains) - self.re_scale_ = TLStretch('target_domain', - centered_data=False, - metric=self.metric) + print(pyriemann.__version__) + if pyriemann.__version__ != '0.6': + self.re_scale_ = TLStretch_patch( + 'target_domain', centered_data=False, metric=self.metric + ) + else: + self.re_scale_ = TLStretch( + 'target_domain', centered_data=False, metric=self.metric + ) X_str = self.re_scale_.fit_transform(X, y_enc) return X_str