Skip to content

Commit

Permalink
Merge pull request #2 from eli5-org/new_sklearn
Browse files Browse the repository at this point in the history
New sklearn enabling changes
  • Loading branch information
lopuhin authored Jan 20, 2021
2 parents c944b20 + 459027b commit 1eb116e
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ sphinx_rtd_theme
ipython
scipy
numpy > 1.9.0
scikit-learn >= 0.18
scikit-learn >= 0.20
typing
11 changes: 8 additions & 3 deletions eli5/base_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import inspect
try:
from inspect import getfullargspec
except ImportError:
# python 2
from inspect import getargspec as getfullargspec # type: ignore

import attr

Expand All @@ -8,6 +12,7 @@
from singledispatch import singledispatch # type: ignore



def attrs(class_):
""" Like attr.s with slots=True,
but with attributes extracted from __init__ method signature.
Expand All @@ -25,12 +30,12 @@ def attrs(class_):
if method in class_.__dict__:
# Allow to redefine a special method (or else attr.s will do it)
attrs_kwargs[kw_name] = False
init_args = inspect.getargspec(class_.__init__)
init_args = getfullargspec(class_.__init__)
defaults_shift = len(init_args.args) - len(init_args.defaults or []) - 1
these = {}
for idx, arg in enumerate(init_args.args[1:]):
attrib_kwargs = {}
if idx >= defaults_shift:
if idx >= defaults_shift and init_args.defaults:
attrib_kwargs['default'] = init_args.defaults[idx - defaults_shift]
these[arg] = attr.ib(**attrib_kwargs)
return attr.s(class_, these=these, init=False, slots=True, **attrs_kwargs) # type: ignore
2 changes: 1 addition & 1 deletion eli5/sklearn/permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
clone,
is_classifier
)
from sklearn.metrics.scorer import check_scoring
from sklearn.metrics import check_scoring

from eli5.permutation_importance import get_score_importances
from eli5.sklearn.utils import pandas_available
Expand Down
6 changes: 4 additions & 2 deletions eli5/sklearn/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import numpy as np
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.feature_selection.base import SelectorMixin

try:
from sklearn.feature_selection import SelectorMixin
except ImportError:
from sklearn.feature_selection.base import SelectorMixin
from sklearn.preprocessing import (
MinMaxScaler,
StandardScaler,
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
numpy >= 1.9.0
scipy
singledispatch >= 3.4.0.3
scikit-learn >= 0.18
scikit-learn >= 0.20
attrs > 16.0.0
jinja2
pip >= 8.1
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_long_description():
'numpy >= 1.9.0',
'scipy',
'six',
'scikit-learn >= 0.18',
'scikit-learn >= 0.20',
'graphviz',
'tabulate>=0.7.7',
],
Expand Down

0 comments on commit 1eb116e

Please sign in to comment.