Skip to content

Commit

Permalink
alignment and domains params in all functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Apolline Mellot committed Aug 9, 2023
1 parent f2d5809 commit 72f8967
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions coffeine/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def make_filter_bank_transformer(
alignment : list of str | None
Alignment steps to include in the pipeline. Can be ``'re-center'``,
``'re-scale'``.
domains : list of str | None
Domains for each matrix.
projection_params : dict | None
The parameters for the projection step.
vectorization_params : dict | None
Expand Down Expand Up @@ -349,6 +351,8 @@ def _get_projector_vectorizer(projection, vectorization,
def make_filter_bank_regressor(
names: list[str],
method: str = 'riemann',
alignment: Union[list[str], None] = None,
domains: Union[list[str], None] = None,
projection_params: Union[dict, None] = None,
vectorization_params: Union[dict, None] = None,
categorical_interaction: Union[bool, None] = None,
Expand Down Expand Up @@ -391,6 +395,11 @@ def make_filter_bank_regressor(
to ``'riemann'``. Can be ``'riemann'``, ``'lw_riemann'``, ``'diag'``,
``'log_diag'``, ``'random'``, ``'naive'``, ``'spoc'``,
``'riemann_wasserstein'``.
alignment : list of str | None
Alignment steps to include in the pipeline. Can be ``'re-center'``,
``'re-scale'``.
domains : list of str | None
Domains for each matrix.
projection_params : dict | None
The parameters for the projection step.
vectorization_params : dict | None
Expand All @@ -414,7 +423,8 @@ def make_filter_bank_regressor(
https://doi.org/10.1016/j.neuroimage.2020.116893
"""
filter_bank_transformer = make_filter_bank_transformer(
names=names, method=method, projection_params=projection_params,
names=names, method=method, alignment=alignment, domains=domains,
projection_params=projection_params,
vectorization_params=vectorization_params,
categorical_interaction=categorical_interaction
)
Expand All @@ -439,6 +449,8 @@ def make_filter_bank_regressor(
def make_filter_bank_classifier(
names: list[str],
method: str = 'riemann',
alignment: Union[list[str], None] = None,
domains: Union[list[str], None] = None,
projection_params: Union[dict, None] = None,
vectorization_params: Union[dict, None] = None,
categorical_interaction: Union[bool, None] = None,
Expand Down Expand Up @@ -481,6 +493,11 @@ def make_filter_bank_classifier(
to ``'riemann'``. Can be ``'riemann'``, ``'lw_riemann'``, ``'diag'``,
``'log_diag'``, ``'random'``, ``'naive'``, ``'spoc'``,
``'riemann_wasserstein'``.
alignment : list of str | None
Alignment steps to include in the pipeline. Can be ``'re-center'``,
``'re-scale'``.
domains : list of str | None
Domains for each matrix.
projection_params : dict | None
The parameters for the projection step.
vectorization_params : dict | None
Expand All @@ -505,7 +522,8 @@ def make_filter_bank_classifier(
"""
filter_bank_transformer = make_filter_bank_transformer(
names=names, method=method, projection_params=projection_params,
names=names, method=method, alignment=alignment, domains=domains,
projection_params=projection_params,
vectorization_params=vectorization_params,
categorical_interaction=categorical_interaction
)
Expand Down

0 comments on commit 72f8967

Please sign in to comment.