Skip to content

Commit

Permalink
patched function
Browse files Browse the repository at this point in the history
  • Loading branch information
Apolline Mellot committed Aug 9, 2023
1 parent c95d96f commit de35d1f
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 13 deletions.
8 changes: 4 additions & 4 deletions coffeine/tests/test_transfer_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def test_rescale():
X_test_str = str.transform(X_test)
# Test if dispersion = 1
M_train = mean_covariance(X_train_str, metric='riemann')
disp_train = np.sum(
disp_train = np.mean(
distance(X_train_str, M_train, metric='riemann')**2
) / X_train_str.shape[0]
)
assert np.isclose(disp_train, 1.0)
M_test = mean_covariance(X_test_str, metric='riemann')
disp_test = np.sum(
disp_test = np.mean(
distance(X_test_str, M_test, metric='riemann')**2
) / X_test_str.shape[0]
)
assert np.isclose(disp_test, 1.0)
86 changes: 77 additions & 9 deletions coffeine/transfer_learning.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import numpy as np
import pyriemann
from pyriemann.transfer._tools import decode_domains
from pyriemann.utils.mean import mean_riemann
from pyriemann.utils.distance import distance
from pyriemann.transfer import TLCenter, TLStretch, encode_domains
from sklearn.base import BaseEstimator, TransformerMixin

Expand All @@ -21,6 +25,54 @@ def _check_data(X):
return out


class TLStretch_patch(TLStretch):
"""Patched function of TLStretch.
To use in ReScale when pyRiemann version is lower than 0.6"""

def __init__(self, target_domain, final_dispersion=1.0,
centered_data=False, metric='riemann'):
super().__init__(target_domain, final_dispersion,
centered_data, metric)

def fit(self, X, y_enc):
"""Fit TLStretch_patch.
Calculate the dispersion around the mean for each domain.
Parameters
----------
X : ndarray, shape (n_matrices, n_channels, n_channels)
Set of SPD matrices.
y_enc : ndarray, shape (n_matrices,)
Extended labels for each matrix.
Returns
-------
self : TLStretch_patch instance
The TLStretch_patch instance.
"""

_, _, domains = decode_domains(X, y_enc)
n_dim = X[0].shape[1]
self._means = {}
self.dispersions_ = {}
for d in np.unique(domains):
if self.centered_data:
self._means[d] = np.eye(n_dim)
else:
self._means[d] = mean_riemann(X[domains == d])
disp_domain = distance(
X[domains == d],
self._means[d],
metric=self.metric,
squared=True,
).mean()
self.dispersions_[d] = disp_domain

return self


class ReCenter(BaseEstimator, TransformerMixin):
"""Re-center each dataset seperately for transfer learning.
Expand Down Expand Up @@ -145,9 +197,14 @@ def fit(self, X, y):
"""
X = _check_data(X)
_, y_enc = encode_domains(X, y, self.domains)
self.re_scale_ = TLStretch('target_domain',
centered_data=False,
metric=self.metric)
if pyriemann.__version__ != '0.6':
self.re_scale_ = TLStretch_patch(
'target_domain', centered_data=False, metric=self.metric
)
else:
self.re_scale_ = TLStretch(
'target_domain', centered_data=False, metric=self.metric
)
self.dispersions_ = self.re_scale_.fit(X, y_enc).dispersions_
return self

Expand All @@ -172,9 +229,14 @@ def transform(self, X, y=None):
X = _check_data(X)
n_sample = X.shape[0]
_, y_enc = encode_domains(X, [0]*n_sample, ['target_domain']*n_sample)
self.re_scale_ = TLStretch('target_domain',
centered_data=False,
metric=self.metric)
if pyriemann.__version__ != '0.6':
self.re_scale_ = TLStretch_patch(
'target_domain', centered_data=False, metric=self.metric
)
else:
self.re_scale_ = TLStretch(
'target_domain', centered_data=False, metric=self.metric
)
self.re_scale_.fit(X, y_enc)
X_str = self.re_scale_.transform(X)
return X_str
Expand All @@ -198,8 +260,14 @@ def fit_transform(self, X, y):
"""
X = _check_data(X)
_, y_enc = encode_domains(X, y, self.domains)
self.re_scale_ = TLStretch('target_domain',
centered_data=False,
metric=self.metric)
print(pyriemann.__version__)
if pyriemann.__version__ != '0.6':
self.re_scale_ = TLStretch_patch(
'target_domain', centered_data=False, metric=self.metric
)
else:
self.re_scale_ = TLStretch(
'target_domain', centered_data=False, metric=self.metric
)
X_str = self.re_scale_.fit_transform(X, y_enc)
return X_str

0 comments on commit de35d1f

Please sign in to comment.