Skip to content

Commit

Permalink
deal domains with metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
apmellot committed Mar 5, 2024
1 parent 1cea3fb commit ef0830c
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 61 deletions.
31 changes: 15 additions & 16 deletions coffeine/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@
ReScale
)

import sklearn
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.compose import make_column_transformer
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import RidgeCV, LogisticRegression

sklearn.set_config(enable_metadata_routing=True)


class GaussianKernel(BaseEstimator, TransformerMixin):
"""Gaussian (squared exponential) Kernel.
Expand Down Expand Up @@ -148,7 +151,6 @@ def make_filter_bank_transformer(
names: list[str],
method: str = 'riemann',
alignment: Union[list[str], None] = None,
domains: Union[list[str], None] = None,
projection_params: Union[dict, None] = None,
vectorization_params: Union[dict, None] = None,
kernel: Union[str, Pipeline, None] = None,
Expand Down Expand Up @@ -194,8 +196,6 @@ def make_filter_bank_transformer(
alignment : list of str | None
Alignment steps to include in the pipeline. Can be ``'re-center'``,
``'re-scale'``.
domains : list of str | None
Domains for each matrix.
projection_params : dict | None
The parameters for the projection step.
vectorization_params : dict | None
Expand Down Expand Up @@ -261,8 +261,6 @@ def make_filter_bank_transformer(
if vectorization_params is not None:
vectorization_params_.update(**vectorization_params)

alignment_params_ = dict(domains=domains)

def _get_projector_vectorizer(projection, vectorization,
recenter, rescale,
kernel=None):
Expand All @@ -271,11 +269,18 @@ def _get_projector_vectorizer(projection, vectorization,
steps = [projection(**projection_params_)]

if recenter is not None:
steps.append(recenter(**alignment_params_))
steps.append(
recenter().set_fit_request(
domains=True
).set_transform_request(domains=True)
)

if rescale is not None:
steps.append(rescale(**alignment_params_))

steps.append(
rescale().set_fit_request(
domains=True
).set_transform_request(domains=True)
)
steps.append(vectorization(**vectorization_params_))

if kernel is not None:
Expand Down Expand Up @@ -348,7 +353,6 @@ def make_filter_bank_regressor(
names: list[str],
method: str = 'riemann',
alignment: Union[list[str], None] = None,
domains: Union[list[str], None] = None,
projection_params: Union[dict, None] = None,
vectorization_params: Union[dict, None] = None,
categorical_interaction: Union[bool, None] = None,
Expand Down Expand Up @@ -394,8 +398,6 @@ def make_filter_bank_regressor(
alignment : list of str | None
Alignment steps to include in the pipeline. Can be ``'re-center'``,
``'re-scale'``.
domains : list of str | None
Domains for each matrix.
projection_params : dict | None
The parameters for the projection step.
vectorization_params : dict | None
Expand All @@ -419,7 +421,7 @@ def make_filter_bank_regressor(
https://doi.org/10.1016/j.neuroimage.2020.116893
"""
filter_bank_transformer = make_filter_bank_transformer(
names=names, method=method, alignment=alignment, domains=domains,
names=names, method=method, alignment=alignment,
projection_params=projection_params,
vectorization_params=vectorization_params,
categorical_interaction=categorical_interaction
Expand All @@ -446,7 +448,6 @@ def make_filter_bank_classifier(
names: list[str],
method: str = 'riemann',
alignment: Union[list[str], None] = None,
domains: Union[list[str], None] = None,
projection_params: Union[dict, None] = None,
vectorization_params: Union[dict, None] = None,
categorical_interaction: Union[bool, None] = None,
Expand Down Expand Up @@ -492,8 +493,6 @@ def make_filter_bank_classifier(
alignment : list of str | None
Alignment steps to include in the pipeline. Can be ``'re-center'``,
``'re-scale'``.
domains : list of str | None
Domains for each matrix.
projection_params : dict | None
The parameters for the projection step.
vectorization_params : dict | None
Expand All @@ -518,7 +517,7 @@ def make_filter_bank_classifier(
"""
filter_bank_transformer = make_filter_bank_transformer(
names=names, method=method, alignment=alignment, domains=domains,
names=names, method=method, alignment=alignment,
projection_params=projection_params,
vectorization_params=vectorization_params,
categorical_interaction=categorical_interaction
Expand Down
85 changes: 49 additions & 36 deletions coffeine/transfer_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,23 @@ def _check_data(X):
return out


def _check_domains(domains, n_sample, target=False):
out = None
if domains is None:
if not target:
out = ['source_domain']*n_sample
else:
out = ['target_domain']*n_sample
else:
if target and 'target_domain' not in domains:
raise ValueError(
"The target domains should include 'target_domain'"
)
else:
out = domains
return out


class TLStretch_patch(TLStretch):
"""Patched function of TLStretch.
Expand Down Expand Up @@ -86,11 +103,10 @@ class ReCenter(BaseEstimator, TransformerMixin):
metric : str, default='riemann'
The metric to compute the mean.
"""
def __init__(self, domains, metric='riemann'):
self.domains = domains
def __init__(self, metric='riemann'):
self.metric = metric

def fit(self, X, y):
def fit(self, X, y, domains=None):
"""Fit ReCenter.
The mean of each domain is calculated with TLCenter from
Expand All @@ -107,12 +123,14 @@ def fit(self, X, y):
self : ReCenter instance
"""
X = _check_data(X)
_, y_enc = encode_domains(X, y, self.domains)
domains = _check_domains(domains, X.shape[0], target=False)
self._domains_source = domains
_, y_enc = encode_domains(X, y, domains)
self.re_center_ = TLCenter('target_domain', metric=self.metric)
self.means_ = self.re_center_.fit(X, y_enc).recenter_
return self

def transform(self, X, y=None):
def transform(self, X, y=None, domains=None):
"""Re-center the test data.
Calculate the mean and then transform the data.
Expand All @@ -132,12 +150,16 @@ def transform(self, X, y=None):
"""
X = _check_data(X)
n_sample = X.shape[0]
_, y_enc = encode_domains(X, [0]*n_sample, ['target_domain']*n_sample)
self.re_center_ = TLCenter('target_domain', metric=self.metric)
X_rct = self.re_center_.fit_transform(X, y_enc)
domains = _check_domains(domains, n_sample, target=True)
_, y_enc = encode_domains(X, [0]*n_sample, domains)
# self.re_center_ = TLCenter('target_domain', metric=self.metric)
if 'target_domain' in self._domains_source:
X_rct = self.re_center_.transform(X, y_enc)
else:
X_rct = self.re_center_.fit_transform(X, y_enc)
return X_rct

def fit_transform(self, X, y):
def fit_transform(self, X, y, domains=None):
"""Fit ReCenter and transform the data.
Calculate the mean of each domain with TLCenter from pyRiemann and
Expand All @@ -154,9 +176,10 @@ def fit_transform(self, X, y):
X_rct : ndarray, shape (n_matrices, n_channels, n_channels)
Set of SPD matrices with mean at the Identity.
"""
self.fit(X, y, domains)
X = _check_data(X)
_, y_enc = encode_domains(X, y, self.domains)
self.re_center_ = TLCenter('target_domain', metric=self.metric)
domains = _check_domains(domains, X.shape[0], target=False)
_, y_enc = encode_domains(X, y, domains)
X_rct = self.re_center_.fit_transform(X, y_enc)
return X_rct

Expand All @@ -175,11 +198,10 @@ class ReScale(BaseEstimator, TransformerMixin):
metric : str, default='riemann'
The metric to compute the dispersion.
"""
def __init__(self, domains, metric='riemann'):
self.domains = domains
def __init__(self, metric='riemann'):
self.metric = metric

def fit(self, X, y):
def fit(self, X, y, domains):
"""Fit ReScale.
Dispersions around the mean of each domain are calculated with
Expand All @@ -196,7 +218,9 @@ def fit(self, X, y):
self : ReScale instance
"""
X = _check_data(X)
_, y_enc = encode_domains(X, y, self.domains)
domains = _check_domains(domains, X.shape[0], target=False)
self._domains_source = domains
_, y_enc = encode_domains(X, y, domains)
if pyriemann.__version__ != '0.6':
self.re_scale_ = TLStretch_patch(
'target_domain', centered_data=False, metric=self.metric
Expand All @@ -208,7 +232,7 @@ def fit(self, X, y):
self.dispersions_ = self.re_scale_.fit(X, y_enc).dispersions_
return self

def transform(self, X, y=None):
def transform(self, X, y=None, domains=None):
"""Re-scale the test data.
Calculate the dispersion around the mean and then transform the data.
Expand All @@ -228,20 +252,15 @@ def transform(self, X, y=None):
"""
X = _check_data(X)
n_sample = X.shape[0]
_, y_enc = encode_domains(X, [0]*n_sample, ['target_domain']*n_sample)
if pyriemann.__version__ != '0.6':
self.re_scale_ = TLStretch_patch(
'target_domain', centered_data=False, metric=self.metric
)
domains = _check_domains(domains, n_sample, target=True)
_, y_enc = encode_domains(X, [0]*n_sample, domains)
if 'target_domain' in self._domains_source:
X_str = self.re_scale_.transform(X)
else:
self.re_scale_ = TLStretch(
'target_domain', centered_data=False, metric=self.metric
)
self.re_scale_.fit(X, y_enc)
X_str = self.re_scale_.transform(X)
X_str = self.re_scale_.fit_transform(X, y_enc)
return X_str

def fit_transform(self, X, y):
def fit_transform(self, X, y, domains):
"""Fit ReScale and transform the data.
Calculate the dispersions around the mean of each domain with
Expand All @@ -258,15 +277,9 @@ def fit_transform(self, X, y):
X_str : ndarray, shape (n_matrices, n_channels, n_channels)
Set of SPD matrices with a dispersion equal to 1.
"""
self.fit(X, y, domains)
X = _check_data(X)
_, y_enc = encode_domains(X, y, self.domains)
if pyriemann.__version__ != '0.6':
self.re_scale_ = TLStretch_patch(
'target_domain', centered_data=False, metric=self.metric
)
else:
self.re_scale_ = TLStretch(
'target_domain', centered_data=False, metric=self.metric
)
domains = _check_domains(domains, X.shape[0], target=False)
_, y_enc = encode_domains(X, y, domains)
X_str = self.re_scale_.fit_transform(X, y_enc)
return X_str
Loading

0 comments on commit ef0830c

Please sign in to comment.