diff --git a/coffeine/pipelines.py b/coffeine/pipelines.py index 8657bde..21aa292 100644 --- a/coffeine/pipelines.py +++ b/coffeine/pipelines.py @@ -21,12 +21,15 @@ ReScale ) +import sklearn from sklearn.base import BaseEstimator, TransformerMixin from sklearn.compose import make_column_transformer from sklearn.pipeline import make_pipeline, Pipeline from sklearn.preprocessing import StandardScaler from sklearn.linear_model import RidgeCV, LogisticRegression +sklearn.set_config(enable_metadata_routing=True) + class GaussianKernel(BaseEstimator, TransformerMixin): """Gaussian (squared exponential) Kernel. @@ -148,7 +151,6 @@ def make_filter_bank_transformer( names: list[str], method: str = 'riemann', alignment: Union[list[str], None] = None, - domains: Union[list[str], None] = None, projection_params: Union[dict, None] = None, vectorization_params: Union[dict, None] = None, kernel: Union[str, Pipeline, None] = None, @@ -194,8 +196,6 @@ def make_filter_bank_transformer( alignment : list of str | None Alignment steps to include in the pipeline. Can be ``'re-center'``, ``'re-scale'``. - domains : list of str | None - Domains for each matrix. projection_params : dict | None The parameters for the projection step. vectorization_params : dict | None @@ -261,8 +261,6 @@ def make_filter_bank_transformer( if vectorization_params is not None: vectorization_params_.update(**vectorization_params) - alignment_params_ = dict(domains=domains) - def _get_projector_vectorizer(projection, vectorization, recenter, rescale, kernel=None): @@ -271,11 +269,18 @@ def _get_projector_vectorizer(projection, vectorization, steps = [projection(**projection_params_)] if recenter is not None: - steps.append(recenter(**alignment_params_)) + steps.append( + recenter().set_fit_request( + domains=True + ).set_transform_request(domains=True) + ) if rescale is not None: - steps.append(rescale(**alignment_params_)) - + steps.append( + rescale().set_fit_request( + domains=True + ).set_transform_request(domains=True) + ) steps.append(vectorization(**vectorization_params_)) if kernel is not None: @@ -348,7 +353,6 @@ def make_filter_bank_regressor( names: list[str], method: str = 'riemann', alignment: Union[list[str], None] = None, - domains: Union[list[str], None] = None, projection_params: Union[dict, None] = None, vectorization_params: Union[dict, None] = None, categorical_interaction: Union[bool, None] = None, @@ -394,8 +398,6 @@ def make_filter_bank_regressor( alignment : list of str | None Alignment steps to include in the pipeline. Can be ``'re-center'``, ``'re-scale'``. - domains : list of str | None - Domains for each matrix. projection_params : dict | None The parameters for the projection step. vectorization_params : dict | None @@ -419,7 +421,7 @@ def make_filter_bank_regressor( https://doi.org/10.1016/j.neuroimage.2020.116893 """ filter_bank_transformer = make_filter_bank_transformer( - names=names, method=method, alignment=alignment, domains=domains, + names=names, method=method, alignment=alignment, projection_params=projection_params, vectorization_params=vectorization_params, categorical_interaction=categorical_interaction @@ -446,7 +448,6 @@ def make_filter_bank_classifier( names: list[str], method: str = 'riemann', alignment: Union[list[str], None] = None, - domains: Union[list[str], None] = None, projection_params: Union[dict, None] = None, vectorization_params: Union[dict, None] = None, categorical_interaction: Union[bool, None] = None, @@ -492,8 +493,6 @@ def make_filter_bank_classifier( alignment : list of str | None Alignment steps to include in the pipeline. Can be ``'re-center'``, ``'re-scale'``. - domains : list of str | None - Domains for each matrix. projection_params : dict | None The parameters for the projection step. vectorization_params : dict | None @@ -518,7 +517,7 @@ def make_filter_bank_classifier( """ filter_bank_transformer = make_filter_bank_transformer( - names=names, method=method, alignment=alignment, domains=domains, + names=names, method=method, alignment=alignment, projection_params=projection_params, vectorization_params=vectorization_params, categorical_interaction=categorical_interaction diff --git a/coffeine/transfer_learning.py b/coffeine/transfer_learning.py index aadeac7..2ffada2 100644 --- a/coffeine/transfer_learning.py +++ b/coffeine/transfer_learning.py @@ -25,6 +25,23 @@ def _check_data(X): return out +def _check_domains(domains, n_sample, target=False): + out = None + if domains is None: + if not target: + out = ['source_domain']*n_sample + else: + out = ['target_domain']*n_sample + else: + if target and 'target_domain' not in domains: + raise ValueError( + "The target domains should include 'target_domain'" + ) + else: + out = domains + return out + + class TLStretch_patch(TLStretch): """Patched function of TLStretch. @@ -86,11 +103,10 @@ class ReCenter(BaseEstimator, TransformerMixin): metric : str, default='riemann' The metric to compute the mean. """ - def __init__(self, domains, metric='riemann'): - self.domains = domains + def __init__(self, metric='riemann'): self.metric = metric - def fit(self, X, y): + def fit(self, X, y, domains=None): """Fit ReCenter. Mean of each domain are calculated with TLCenter from @@ -107,12 +123,14 @@ def fit(self, X, y): self : ReCenter instance """ X = _check_data(X) - _, y_enc = encode_domains(X, y, self.domains) + domains = _check_domains(domains, X.shape[0], target=False) + self._domains_source = domains + _, y_enc = encode_domains(X, y, domains) self.re_center_ = TLCenter('target_domain', metric=self.metric) self.means_ = self.re_center_.fit(X, y_enc).recenter_ return self - def transform(self, X, y=None): + def transform(self, X, y=None, domains=None): """Re-center the test data. Calculate the mean and then transform the data. @@ -132,12 +150,16 @@ def transform(self, X, y=None): """ X = _check_data(X) n_sample = X.shape[0] - _, y_enc = encode_domains(X, [0]*n_sample, ['target_domain']*n_sample) - self.re_center_ = TLCenter('target_domain', metric=self.metric) - X_rct = self.re_center_.fit_transform(X, y_enc) + domains = _check_domains(domains, n_sample, target=True) + _, y_enc = encode_domains(X, [0]*n_sample, domains) + # self.re_center_ = TLCenter('target_domain', metric=self.metric) + if 'target_domain' in self._domains_source: + X_rct = self.re_center_.transform(X, y_enc) + else: + X_rct = self.re_center_.fit_transform(X, y_enc) return X_rct - def fit_transform(self, X, y): + def fit_transform(self, X, y, domains=None): """Fit ReCenter and transform the data. Calculate the mean of each domain with TLCenter from pyRiemann and @@ -154,9 +176,10 @@ def fit_transform(self, X, y): X_rct : ndarray, shape (n_matrices, n_channels, n_channels) Set of SPD matrices with mean at the Identity. """ + self.fit(X, y, domains) X = _check_data(X) - _, y_enc = encode_domains(X, y, self.domains) - self.re_center_ = TLCenter('target_domain', metric=self.metric) + domains = _check_domains(domains, X.shape[0], target=False) + _, y_enc = encode_domains(X, y, domains) X_rct = self.re_center_.fit_transform(X, y_enc) return X_rct @@ -175,11 +198,10 @@ class ReScale(BaseEstimator, TransformerMixin): metric : str, default='riemann' The metric to compute the dispersion. """ - def __init__(self, domains, metric='riemann'): - self.domains = domains + def __init__(self, metric='riemann'): self.metric = metric - def fit(self, X, y): + def fit(self, X, y, domains): """Fit ReScale. Dispersions around the mean of each domain are calculated with @@ -196,7 +218,9 @@ def fit(self, X, y): self : ReScale instance """ X = _check_data(X) - _, y_enc = encode_domains(X, y, self.domains) + domains = _check_domains(domains, X.shape[0], target=False) + self._domains_source = domains + _, y_enc = encode_domains(X, y, domains) if pyriemann.__version__ != '0.6': self.re_scale_ = TLStretch_patch( 'target_domain', centered_data=False, metric=self.metric @@ -208,7 +232,7 @@ def fit(self, X, y): self.dispersions_ = self.re_scale_.fit(X, y_enc).dispersions_ return self - def transform(self, X, y=None): + def transform(self, X, y=None, domains=None): """Re-scale the test data. Calculate the dispersion around the mean iand then transform the data. @@ -228,20 +252,15 @@ def transform(self, X, y=None): """ X = _check_data(X) n_sample = X.shape[0] - _, y_enc = encode_domains(X, [0]*n_sample, ['target_domain']*n_sample) - if pyriemann.__version__ != '0.6': - self.re_scale_ = TLStretch_patch( - 'target_domain', centered_data=False, metric=self.metric - ) + domains = _check_domains(domains, n_sample, target=True) + _, y_enc = encode_domains(X, [0]*n_sample, domains) + if 'target_domain' in self._domains_source: + X_str = self.re_scale_.transform(X) else: - self.re_scale_ = TLStretch( - 'target_domain', centered_data=False, metric=self.metric - ) - self.re_scale_.fit(X, y_enc) - X_str = self.re_scale_.transform(X) + X_str = self.re_scale_.fit_transform(X, y_enc) return X_str - def fit_transform(self, X, y): + def fit_transform(self, X, y, domains): """Fit ReScale and transform the data. Calculate the dispersions around the mean of each domain with @@ -258,15 +277,9 @@ def fit_transform(self, X, y): X_str : ndarray, shape (n_matrices, n_channels, n_channels) Set of SPD matrices with a dispersion equal to 1. """ + self.fit(X, y, domains) X = _check_data(X) - _, y_enc = encode_domains(X, y, self.domains) - if pyriemann.__version__ != '0.6': - self.re_scale_ = TLStretch_patch( - 'target_domain', centered_data=False, metric=self.metric - ) - else: - self.re_scale_ = TLStretch( - 'target_domain', centered_data=False, metric=self.metric - ) + domains = _check_domains(domains, X.shape[0], target=False) + _, y_enc = encode_domains(X, y, domains) X_str = self.re_scale_.fit_transform(X, y_enc) return X_str diff --git a/doc/tutorials/filterbank_cross_subject_classification_bci.ipynb b/doc/tutorials/filterbank_cross_subject_classification_bci.ipynb index f1ed641..31ea930 100644 --- a/doc/tutorials/filterbank_cross_subject_classification_bci.ipynb +++ b/doc/tutorials/filterbank_cross_subject_classification_bci.ipynb @@ -172,7 +172,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -188,7 +188,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -197,7 +197,7 @@ "0.5333333333333333" ] }, - "execution_count": 18, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -208,7 +208,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -216,17 +216,17 @@ " names=list(X_df_source.columns),\n", " method='riemann',\n", " alignment=['re-center', 're-scale'],\n", - " domains=['source']*X_df_source.shape[0],\n", + " # domains=['source']*X_df_source.shape[0],\n", " projection_params=dict(scale=1, n_compo=60, reg=0),\n", " estimator=LogisticRegression(solver='liblinear', C=1e7)\n", ")\n", - "fb_model.fit(X_df_source, labels_source)\n", - "score = fb_model.score(X_df_target, labels_target)" + "fb_model.fit(X_df_source, labels_source, domains=['source']*X_df_source.shape[0])\n", + "score = fb_model.score(X_df_target, labels_target, domains=['target_domain']*X_df_target.shape[0])" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -235,7 +235,7 @@ "0.6222222222222222" ] }, - "execution_count": 20, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -243,6 +243,73 @@ "source": [ "score" ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "cv = ShuffleSplit(10, test_size=0.8, random_state=42)\n", + "scores = []\n", + "for train_index, test_index in cv.split(X_df_target):\n", + " X_df_target_train = X_df_target.iloc[train_index]\n", + " labels_target_train = labels_target[train_index]\n", + " X_df_target_test = X_df_target.iloc[test_index]\n", + " labels_target_test = labels_target[test_index]\n", + " X_df_train = pd.concat([X_df_source, X_df_target_train])\n", + " y_train = np.concatenate([labels_source, labels_target_train])\n", + " domains = ['source']*X_df_source.shape[0] + ['target_domain']*X_df_target_train.shape[0]\n", + " fb_model.fit(X_df_train, y_train, domains=domains)\n", + " scores.append(fb_model.score(X_df_target_test, labels_target_test,\n", + " domains=['target_domain']*X_df_target_test.shape[0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean classification accuracy: 0.61\n" + ] + } + ], + "source": [ + "print(f'Mean classification accuracy: {np.mean(scores):0.2f}')" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.75,\n", + " 0.5833333333333334,\n", + " 0.6388888888888888,\n", + " 0.6666666666666666,\n", + " 0.4722222222222222,\n", + " 0.5833333333333334,\n", + " 0.6388888888888888,\n", + " 0.6111111111111112,\n", + " 0.5,\n", + " 0.6944444444444444]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scores" + ] } ], "metadata": {