From 72f896763f641dabb39573c1f0ec152b04d85315 Mon Sep 17 00:00:00 2001 From: Apolline Mellot Date: Wed, 9 Aug 2023 16:22:12 +0200 Subject: [PATCH] alignment and domains params in all functions --- coffeine/pipelines.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/coffeine/pipelines.py b/coffeine/pipelines.py index 13897df..e71202c 100644 --- a/coffeine/pipelines.py +++ b/coffeine/pipelines.py @@ -194,6 +194,8 @@ def make_filter_bank_transformer( alignment : list of str | None Alignment steps to include in the pipeline. Can be ``'re-center'``, ``'re-scale'``. + domains : list of str | None + Domains for each matrix. projection_params : dict | None The parameters for the projection step. vectorization_params : dict | None @@ -349,6 +351,8 @@ def _get_projector_vectorizer(projection, vectorization, def make_filter_bank_regressor( names: list[str], method: str = 'riemann', + alignment: Union[list[str], None] = None, + domains: Union[list[str], None] = None, projection_params: Union[dict, None] = None, vectorization_params: Union[dict, None] = None, categorical_interaction: Union[bool, None] = None, @@ -391,6 +395,11 @@ def make_filter_bank_regressor( to ``'riemann'``. Can be ``'riemann'``, ``'lw_riemann'``, ``'diag'``, ``'log_diag'``, ``'random'``, ``'naive'``, ``'spoc'``, ``'riemann_wasserstein'``. + alignment : list of str | None + Alignment steps to include in the pipeline. Can be ``'re-center'``, + ``'re-scale'``. + domains : list of str | None + Domains for each matrix. projection_params : dict | None The parameters for the projection step. vectorization_params : dict | None @@ -414,7 +423,8 @@ def make_filter_bank_regressor( https://doi.org/10.1016/j.neuroimage.2020.116893 """ filter_bank_transformer = make_filter_bank_transformer( - names=names, method=method, projection_params=projection_params, + names=names, method=method, alignment=alignment, domains=domains, + projection_params=projection_params, vectorization_params=vectorization_params, categorical_interaction=categorical_interaction ) @@ -439,6 +449,8 @@ def make_filter_bank_regressor( def make_filter_bank_classifier( names: list[str], method: str = 'riemann', + alignment: Union[list[str], None] = None, + domains: Union[list[str], None] = None, projection_params: Union[dict, None] = None, vectorization_params: Union[dict, None] = None, categorical_interaction: Union[bool, None] = None, @@ -481,6 +493,11 @@ def make_filter_bank_classifier( to ``'riemann'``. Can be ``'riemann'``, ``'lw_riemann'``, ``'diag'``, ``'log_diag'``, ``'random'``, ``'naive'``, ``'spoc'``, ``'riemann_wasserstein'``. + alignment : list of str | None + Alignment steps to include in the pipeline. Can be ``'re-center'``, + ``'re-scale'``. + domains : list of str | None + Domains for each matrix. projection_params : dict | None The parameters for the projection step. vectorization_params : dict | None @@ -505,7 +522,8 @@ def make_filter_bank_classifier( """ filter_bank_transformer = make_filter_bank_transformer( - names=names, method=method, projection_params=projection_params, + names=names, method=method, alignment=alignment, domains=domains, + projection_params=projection_params, vectorization_params=vectorization_params, categorical_interaction=categorical_interaction )