diff --git a/eli5/sklearn/permutation_importance.py b/eli5/sklearn/permutation_importance.py index 5a963880..370be8be 100644 --- a/eli5/sklearn/permutation_importance.py +++ b/eli5/sklearn/permutation_importance.py @@ -214,8 +214,12 @@ def _cv_scores_importances(self, X, y, groups=None, **fit_params): cv = check_cv(self.cv, y, is_classifier(self.estimator)) feature_importances = [] # type: List base_scores = [] # type: List[float] + weights = fit_params.pop('sample_weight', None) + fold_fit_params = fit_params.copy() for train, test in cv.split(X, y, groups): - est = clone(self.estimator).fit(X[train], y[train], **fit_params) + if weights is not None: + fold_fit_params['sample_weight'] = weights[train] + est = clone(self.estimator).fit(X[train], y[train], **fold_fit_params) score_func = partial(self.scorer_, est) _base_score, _importances = self._get_score_importances( score_func, X[test], y[test]) diff --git a/tests/test_sklearn_permutation_importance.py b/tests/test_sklearn_permutation_importance.py index 4fe942fd..4ffec3ba 100644 --- a/tests/test_sklearn_permutation_importance.py +++ b/tests/test_sklearn_permutation_importance.py @@ -3,7 +3,7 @@ import numpy as np from sklearn.base import is_classifier, is_regressor from sklearn.svm import SVR, SVC -from sklearn.ensemble import RandomForestRegressor +from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier from sklearn.model_selection import train_test_split, cross_val_score from sklearn.pipeline import make_pipeline from sklearn.feature_selection import SelectFromModel @@ -165,6 +165,7 @@ def test_explain_weights(iris_train): for _expl in res: assert "petal width (cm)" in _expl + def test_pandas_xgboost_support(iris_train): xgboost = pytest.importorskip('xgboost') pd = pytest.importorskip('pandas') @@ -175,3 +176,17 @@ def test_pandas_xgboost_support(iris_train): est.fit(X, y) # we expect no exception to be raised here when using xgboost with pd.DataFrame perm = 
PermutationImportance(est).fit(X, y)
+
+
+def test_cv_sample_weight(iris_train):
+    """With ``cv`` set, unit sample weights must not change the importances."""
+    X, y, _feature_names, _target_names = iris_train
+    weights_ones = np.ones(len(y))
+    model = RandomForestClassifier(random_state=42)
+    # we expect no exception to be raised when passing weights with a CV
+    perm_weights = PermutationImportance(
+        model, cv=5, random_state=42).fit(X, y, sample_weight=weights_ones)
+    perm = PermutationImportance(model, cv=5, random_state=42).fit(X, y)
+    # a vector of weights filled with ones should match passing no weights
+    np.testing.assert_allclose(
+        perm.feature_importances_, perm_weights.feature_importances_)