[ENH] Allow fit_resample to receive metadata routed parameters #1111

Open
ShimantoRahman opened this issue Dec 16, 2024 · 5 comments · Fixed by #1115
Comments

@ShimantoRahman

Is your feature request related to a problem? Please describe

In cost-sensitive learning, resampling techniques are used to address the asymmetrical importance of data points. These techniques require the amount of resampling to depend on instance-specific parameters, such as cost weights associated with individual data points. These cost weights are usually collected in a cost matrix for each data point $i$:

| | Actual Positive ($y_i = 1$) | Actual Negative ($y_i = 0$) |
|---|---|---|
| Predicted Positive ($\hat y_i = 1$) | $C_{TP_i}$ | $C_{FP_i}$ |
| Predicted Negative ($\hat y_i = 0$) | $C_{FN_i}$ | $C_{TN_i}$ |

Since these cost weights depend on the data point, they cannot be predetermined during initialization (`__init__`); instead they must be derived dynamically from the input data during `fit_resample`.
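For concreteness, such instance-dependent costs can be stored as an array of shape `(n_samples, 2, 2)`, one 2x2 cost matrix per data point. This is only a sketch; the layout (rows = predicted class, columns = actual class) and the amount-based costs are illustrative assumptions, not part of any library API:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 5
amounts = rng.uniform(10, 1000, size=n_samples)  # e.g. transaction amounts

# One 2x2 cost matrix per instance: rows = predicted class, cols = actual class.
# Misclassification costs scale with the transaction amount; correct
# classifications cost nothing.
cost_matrix = np.zeros((n_samples, 2, 2))
cost_matrix[:, 1, 0] = amounts       # C_FP_i: predicted positive, actually negative
cost_matrix[:, 0, 1] = 10 * amounts  # C_FN_i: predicted negative, actually positive

print(cost_matrix.shape)  # (5, 2, 2)
```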

The current imbalanced-learn `Pipeline` implementation does not natively support passing metadata through its `fit_resample` method. Metadata routing, which would let instance-dependent parameters flow seamlessly through the pipeline, is critical for implementing cost-sensitive learning workflows.

Desired workflow (DOES NOT CURRENTLY WORK)

```python
import numpy as np
from imblearn.pipeline import Pipeline
from sklearn import set_config
from sklearn.utils._metadata_requests import MetadataRequest, RequestMethod
from sklearn.base import BaseEstimator
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

set_config(enable_metadata_routing=True)


class CostSensitiveSampler(BaseEstimator):

    _estimator_type = "sampler"
    __metadata_request__fit_resample = {'cost_matrix': True}

    def __init__(self, random_state=None):
        self.random_state = random_state

    def fit_resample(self, X, y, cost_matrix=None):
        # resample based on cost_matrix
        # ...
        return X, y

    def _get_metadata_request(self):
        routing = MetadataRequest(owner=self.__class__.__name__)
        routing.fit_resample.add_request(param='cost_matrix', alias=True)
        return routing

    set_fit_resample_request = RequestMethod('fit_resample', ['cost_matrix'])


X, y = make_classification()
cost_matrix = np.random.rand(X.shape[0], 2, 2)
pipeline = Pipeline([
    ('sampler', CostSensitiveSampler().set_fit_resample_request(cost_matrix=True)),
    ('model', LogisticRegression())
])
pipeline.fit(X, y, cost_matrix=cost_matrix)
```

Describe the solution you'd like

From what I understand of the metadata routing implementation in the `Pipeline` object, only a couple of changes are needed:

1. The `SIMPLE_METHODS` constant found here needs to include `"fit_resample"`:

```python
SIMPLE_METHODS = [
    "fit",
    "partial_fit",
    "fit_resample",  # add this line
    "predict",
    "predict_proba",
    "predict_log_proba",
    "decision_function",
    "score",
    "split",
    "transform",
    "inverse_transform",
]
```

Note that this requires imbalanced-learn to redefine the classes and functions that use the `SIMPLE_METHODS` constant internally; these are currently imported from scikit-learn when scikit-learn 1.4 or later is installed. They include `MetadataRequest` and `_MetadataRequester`.
2. A method mapping from caller `"fit"` to callee `"fit_resample"` has to be added in the `get_metadata_routing(self)` method found here, and the `filter_resample` parameter of the `self._iter` method needs to be set to `False`:

```python
def get_metadata_routing(self):
    """Get metadata routing of this object.

    Please check :ref:`User Guide <metadata_routing>` on how the routing
    mechanism works.

    Returns
    -------
    routing : MetadataRouter
        A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
        routing information.
    """
    router = MetadataRouter(owner=self.__class__.__name__)

    # first we add all steps except the last one
    for _, name, trans in self._iter(
        with_final=False,
        filter_passthrough=True,
        filter_resample=False,  # change filter_resample to False
    ):
        method_mapping = MethodMapping()
        # fit, fit_predict, and fit_transform call fit_transform if it
        # exists, or else fit and transform
        if hasattr(trans, "fit_transform"):
            (
                method_mapping.add(caller="fit", callee="fit_transform")
                .add(caller="fit_transform", callee="fit_transform")
                .add(caller="fit_predict", callee="fit_transform")
                .add(caller="fit_resample", callee="fit_transform")
            )
        else:
            (
                method_mapping.add(caller="fit", callee="fit")
                .add(caller="fit", callee="transform")
                .add(caller="fit_transform", callee="fit")
                .add(caller="fit_transform", callee="transform")
                .add(caller="fit_predict", callee="fit")
                .add(caller="fit_predict", callee="transform")
                .add(caller="fit_resample", callee="fit")
                .add(caller="fit_resample", callee="transform")
            )

        (
            method_mapping.add(caller="predict", callee="transform")
            .add(caller="predict_proba", callee="transform")
            .add(caller="decision_function", callee="transform")
            .add(caller="predict_log_proba", callee="transform")
            .add(caller="transform", callee="transform")
            .add(caller="inverse_transform", callee="inverse_transform")
            .add(caller="score", callee="transform")
            .add(caller="fit_resample", callee="transform")
            .add(caller="fit", callee="fit_resample")  # add this line
        )

        router.add(method_mapping=method_mapping, **{name: trans})
    # add final estimator method mapping
    ...
```

Additional context

I am a PhD researcher who used these methods in my paper, and I am the author of a Python package, Empulse, which implements samplers that require cost parameters to be passed to the `fit_resample` method, as in the dummy example above (see Empulse/Samplers). I find the whole metadata routing implementation incredibly confusing, so apologies if I made mistakes in my reasoning.

@glemaitre
Member

In principle, I think it would indeed be nice to accept metadata.

For your specific use case, I'm not sure that resampling is actually best. While working on the scikit-learn project, we found that resampling breaks the calibration of the classifier, and what users are actually trying to solve can usually be done by post-tuning the classifier's decision threshold.

We recently added `TunedThresholdClassifierCV` to scikit-learn, and we show an example of cost-sensitive learning in the documentation: https://scikit-learn.org/1.5/auto_examples/model_selection/plot_cost_sensitive_learning.html#sphx-glr-auto-examples-model-selection-plot-cost-sensitive-learning-py

We also worked on the following tutorial, which shows some internals that could be of interest to you: https://probabl-ai.github.io/calibration-cost-sensitive-learning/intro.html
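The threshold-tuning idea can be illustrated without scikit-learn: for a fixed 2x2 cost matrix, cost-sensitive decision theory (e.g. Elkan, 2001) gives a closed-form optimal probability threshold. A minimal sketch, assuming constant (not instance-dependent) costs; this is an illustration of the principle, not the `TunedThresholdClassifierCV` implementation, which tunes the threshold empirically via cross-validation:

```python
def optimal_threshold(c_tp, c_fp, c_fn, c_tn):
    """Probability threshold minimizing expected cost: predict positive when
    P(y=1|x) >= t*, with t* = (C_FP - C_TN) / ((C_FP - C_TN) + (C_FN - C_TP))."""
    return (c_fp - c_tn) / ((c_fp - c_tn) + (c_fn - c_tp))

# Symmetric misclassification costs recover the usual 0.5 threshold.
print(optimal_threshold(c_tp=0, c_fp=1, c_fn=1, c_tn=0))  # 0.5

# When false negatives are 10x as costly as false positives, the threshold
# drops, so more instances get flagged as positive.
print(optimal_threshold(c_tp=0, c_fp=1, c_fn=10, c_tn=0))  # 1/11 ≈ 0.0909
```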

@glemaitre
Member

So I think we had an underlying bug in `get_metadata_routing`.

I took your example and made a minimal reproducer:

https://github.com/scikit-learn-contrib/imbalanced-learn/pull/1115/files#diff-82b96c4de3880642afa90f01a32ca3b1dbac2918d037990a7826ba4dc206a939R1501-R1519

So it means that it should work out of the box.

@ShimantoRahman
Author

> For your specific use case, I'm not sure that resampling is actually best. While working on the scikit-learn project, we found that resampling breaks the calibration of the classifier, and what users are actually trying to solve can usually be done by post-tuning the classifier's decision threshold.

Thank you for your recommendation. A couple of days ago I watched your podcast episode with Vincent Warmerdam on the Probabl YouTube channel; it was quite insightful and prompted me to read the scikit-learn documentation you linked above, which definitely changed my perspective on the problem. I was planning to do some benchmarking of my own once I finished implementing some of the techniques I found in the literature, and I will definitely explore calibration further.

I just tested version 0.13.0 and it works like a charm! Thank you for the quick implementation, and best wishes for the holiday period <3

@ShimantoRahman
Author

One small suggestion related to type checking. As of now, type checkers will not recognize the `set_fit_resample_request` method, since it is constructed dynamically at runtime. Perhaps adding this to `SamplerMixin` could be useful:

```python
from abc import ABCMeta
from typing import TYPE_CHECKING


class SamplerMixin(metaclass=ABCMeta):
    """Mixin class for samplers with an abstract method.

    Warning: This class should not be used directly. Use the derived
    classes instead.
    """

    _estimator_type = "sampler"

    if TYPE_CHECKING:
        def set_fit_resample_request(self, **kwargs): ...

    ...
```
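Since `TYPE_CHECKING` is `False` at runtime, such a stub is visible to static type checkers only and never shadows the dynamically constructed method. A minimal standalone check with a hypothetical `Demo` class:

```python
from typing import TYPE_CHECKING


class Demo:
    # Visible to type checkers only; this body is never executed at runtime.
    if TYPE_CHECKING:
        def set_fit_resample_request(self, **kwargs): ...


# At runtime TYPE_CHECKING is False, so the stub was never defined.
print(hasattr(Demo, "set_fit_resample_request"))  # False
```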

@glemaitre
Member

Let me reopen to not forget about this last issue. Thanks for reporting.

@glemaitre glemaitre reopened this Dec 22, 2024