Skip to content

Commit

Permalink
changes as PR suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
rg2410 committed Jan 21, 2020
1 parent f587bfa commit 729e557
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions eli5/sklearn/permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,12 @@ def _cv_scores_importances(self, X, y, groups=None, **fit_params):
cv = check_cv(self.cv, y, is_classifier(self.estimator))
feature_importances = [] # type: List
base_scores = [] # type: List[float]
weights = fit_params.get('sample_weight', None)
fit_params.pop('sample_weight', None)
weights = fit_params.pop('sample_weight', None)
fold_fit_params = fit_params.copy()
for train, test in cv.split(X, y, groups):
if weights is None:
est = clone(self.estimator).fit(X[train], y[train], **fit_params)
else:
est = clone(self.estimator).fit(X[train], y[train], sample_weight=weights[train], **fit_params)
if weights is not None:
fold_fit_params['sample_weight'] = weights[train]
est = clone(self.estimator).fit(X[train], y[train], **fold_fit_params)
score_func = partial(self.scorer_, est)
_base_score, _importances = self._get_score_importances(
score_func, X[test], y[test])
Expand Down

0 comments on commit 729e557

Please sign in to comment.