MAINT adapt for scikit-learn 1.6 (#1135)
glemaitre authored Dec 9, 2024
1 parent 399fd93 commit 18af508
Showing 11 changed files with 272 additions and 69 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
@@ -38,6 +38,10 @@ Bug fixes
 :user:`Jérôme Dockès <jeromedockes>` and the matplotlib issue can be tracked
 [here](https://github.com/matplotlib/matplotlib/issues/25041).
 
+Maintenance
+-----------
+* Make `skrub` compatible with scikit-learn 1.6.
+  :pr:`1135` by :user:`Guillaume Lemaitre <glemaitre>`.
 
 Release 0.4.0
 =============
26 changes: 21 additions & 5 deletions benchmarks/bench_minhash_batch_number.py
@@ -15,9 +15,11 @@
 import numpy as np
 import pandas as pd
 import seaborn as sns
+import sklearn
 from joblib import Parallel, delayed, effective_n_jobs
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.utils import gen_even_slices, murmurhash3_32
+from sklearn.utils.fixes import parse_version
 from utils import default_parser, find_result, monitor
 
 from skrub._fast_hash import ngram_min_hash
@@ -32,6 +34,11 @@
 # flake8: noqa: E501
 
 
+sklearn_below_1_6 = parse_version(
+    parse_version(sklearn.__version__).base_version
+) < parse_version("1.6")
+
+
 class MinHashEncoder(BaseEstimator, TransformerMixin):
     """
     Encode string categorical features as a numeric array, minhash method
@@ -126,11 +133,20 @@ def __init__(
         self.batch_per_job = batch_per_job
         self.n_jobs = n_jobs
 
-    def _more_tags(self):
-        """
-        Used internally by sklearn to ease the estimator checks.
-        """
-        return {"X_types": ["categorical"]}
+    if sklearn_below_1_6:
+
+        def _more_tags(self):
+            """
+            Used internally by sklearn to ease the estimator checks.
+            """
+            return {"X_types": ["categorical"]}
+
+    else:
+
+        def __sklearn_tags__(self):
+            tags = super().__sklearn_tags__()
+            tags.input_tags.categorical = True
+            return tags
 
     def _get_murmur_hash(self, string):
         """
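A note on the double `parse_version` call used throughout this commit: `base_version` strips pre-release and dev suffixes before comparing, so a development build such as "1.6.dev0" (which PEP 440 orders before 1.6) is treated like the 1.6 release itself. A minimal sketch of the pattern:

    import sklearn
    from sklearn.utils.fixes import parse_version

    # parse_version("1.6.dev0") < parse_version("1.6") is True, which would
    # wrongly route a 1.6 dev build to the legacy branch; comparing on
    # base_version ("1.6") avoids that.
    stable = parse_version(parse_version(sklearn.__version__).base_version)
    sklearn_below_1_6 = stable < parse_version("1.6")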
5 changes: 4 additions & 1 deletion skrub/_dataframe/_common.py
@@ -106,6 +106,8 @@
     "total_seconds",
 ]
 
+pandas_version = parse_version(parse_version(pd.__version__).base_version)
+
 #
 # Inspecting containers' type and module
 # ======================================
@@ -330,7 +332,8 @@ def _concat_horizontal_pandas(*dataframes):
     init_index = dataframes[0].index
     dataframes = [df.reset_index(drop=True) for df in dataframes]
     dataframes = _join_utils.make_column_names_unique(*dataframes)
-    result = pd.concat(dataframes, axis=1, copy=False)
+    kwargs = {"copy": False} if pandas_version < parse_version("3.0") else {}
+    result = pd.concat(dataframes, axis=1, **kwargs)
     result.index = init_index
     return result
 
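pandas 3.0 makes copy-on-write the default and deprecates the `copy` keyword to `pd.concat`, so the keyword is now forwarded only on older pandas. A small self-contained sketch of the same gate (the frames are illustrative):

    import pandas as pd
    from sklearn.utils.fixes import parse_version

    pandas_version = parse_version(parse_version(pd.__version__).base_version)

    left = pd.DataFrame({"a": [1, 2]})
    right = pd.DataFrame({"b": [3, 4]})
    # Only pass copy=False where pandas still accepts it.
    kwargs = {"copy": False} if pandas_version < parse_version("3.0") else {}
    result = pd.concat([left, right], axis=1, **kwargs)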
21 changes: 21 additions & 0 deletions skrub/_datetime_encoder.py
@@ -1,6 +1,8 @@
 from datetime import datetime, timezone
 
 import pandas as pd
+import sklearn
+from sklearn.utils.fixes import parse_version
 from sklearn.utils.validation import check_is_fitted
 
 try:
@@ -26,6 +28,11 @@
 ]
 
 
+sklearn_below_1_6 = parse_version(
+    parse_version(sklearn.__version__).base_version
+) < parse_version("1.6")
+
+
 @dispatch
 def _is_date(col):
     raise NotImplementedError()
@@ -323,3 +330,17 @@ def _check_params(self):
             raise ValueError(
                 f"'resolution' options are {allowed}, got {self.resolution!r}."
             )
+
+    if sklearn_below_1_6:
+
+        def _more_tags(self):
+            return {"preserves_dtype": []}
+
+    else:
+
+        def __sklearn_tags__(self):
+            tags = super().__sklearn_tags__()
+            from sklearn.utils import TransformerTags
+
+            tags.transformer_tags = TransformerTags()
+            return tags
19 changes: 19 additions & 0 deletions skrub/_fixes.py
@@ -0,0 +1,19 @@
+import sklearn
+from sklearn.utils.fixes import parse_version
+
+sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
+
+
+if sklearn_version < parse_version("1.6"):
+    from sklearn.utils._tags import _safe_tags as get_tags  # noqa
+else:
+    from sklearn.utils import get_tags  # noqa
+
+
+def _check_n_features(estimator, X, *, reset):
+    if hasattr(estimator, "_check_n_features"):
+        estimator._check_n_features(X, reset=reset)
+    else:
+        from sklearn.utils.validation import _check_n_features
+
+        _check_n_features(estimator, X, reset=reset)
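A minimal usage sketch of the two shims, valid on either side of the 1.6 boundary (`LinearRegression` is just an arbitrary estimator for illustration):

    from dataclasses import is_dataclass

    import numpy as np
    from sklearn.linear_model import LinearRegression

    from skrub._fixes import _check_n_features, get_tags

    est = LinearRegression().fit(np.zeros((5, 3)), np.zeros(5))

    tags = get_tags(est)
    if is_dataclass(tags):
        multioutput = tags.target_tags.multi_output  # >= 1.6: Tags dataclass
    else:
        multioutput = tags.get("multioutput", False)  # < 1.6: plain dict

    # Validates X's width against est.n_features_in_, whether the private
    # method still exists (< 1.6) or only the free function does (>= 1.6).
    _check_n_features(est, np.zeros((2, 3)), reset=False)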
12 changes: 10 additions & 2 deletions skrub/_interpolation_joiner.py
@@ -1,4 +1,5 @@
 import warnings
+from dataclasses import is_dataclass
 
 import joblib
 import numpy as np
@@ -7,11 +8,11 @@
     HistGradientBoostingClassifier,
     HistGradientBoostingRegressor,
 )
-from sklearn.utils._tags import _safe_tags
 
 from . import _dataframe as sbd
 from . import _join_utils, _utils
 from . import _selectors as s
+from ._fixes import get_tags
 from ._minhash_encoder import MinHashEncoder
 from ._table_vectorizer import TableVectorizer
 
@@ -403,7 +404,14 @@ def _get_assignments_for_estimator(table, estimator):
 
 
 def _handles_multioutput(estimator):
-    return _safe_tags(estimator).get("multioutput", False)
+    tags = get_tags(estimator)
+    if isinstance(tags, dict):
+        # scikit-learn < 1.6
+        return tags.get("multioutput", False)
+    elif is_dataclass(tags):
+        # scikit-learn >= 1.6
+        return tags.target_tags.multi_output
+    return False
 
 
 def _fit(key_values, target_table, estimator, propagate_exceptions):
46 changes: 31 additions & 15 deletions skrub/_similarity_encoder.py
@@ -3,7 +3,6 @@
 which encodes similarity instead of equality of values.
 """
 
-
 import numpy as np
 import pandas as pd
 import sklearn
@@ -14,12 +13,18 @@
 from sklearn.utils.fixes import parse_version
 from sklearn.utils.validation import check_is_fitted
 
+from ._fixes import _check_n_features
 from ._string_distances import get_ngram_count, preprocess
 
 # Ignore lines too long, first docstring lines can't be cut
 # flake8: noqa: E501
 
 
+sklearn_below_1_6 = parse_version(
+    parse_version(sklearn.__version__).base_version
+) < parse_version("1.6")
+
+
 def _ngram_similarity_one_sample_inplace(
     x_count_vector,
     vocabulary_count_matrix,
@@ -334,7 +339,7 @@ def fit(self, X, y=None):
             X[mask] = self.handle_missing
 
         Xlist, n_samples, n_features = self._check_X(X)
-        self._check_n_features(X, reset=True)
+        _check_n_features(self, X, reset=True)
 
         if self.handle_unknown not in ["error", "ignore"]:
             raise ValueError(
@@ -453,7 +458,7 @@ def transform(self, X, fast=True):
             X[mask] = self.handle_missing
 
         Xlist, n_samples, n_features = self._check_X(X)
-        self._check_n_features(X, reset=False)
+        _check_n_features(self, X, reset=False)
 
         for i in range(n_features):
             Xi = Xlist[i]
@@ -550,15 +555,26 @@ def _ngram_similarity_fast(
 
         return np.nan_to_num(out, copy=False)
 
-    def _more_tags(self):
-        return {
-            "X_types": ["2darray", "categorical", "string"],
-            "preserves_dtype": [],
-            "allow_nan": True,
-            "_xfail_checks": {
-                "check_estimator_sparse_data": (
-                    "Cannot create sparse matrix with strings."
-                ),
-                "check_estimators_dtypes": "We only support string dtypes.",
-            },
-        }
+    if sklearn_below_1_6:
+
+        def _more_tags(self):
+            return {
+                "X_types": ["2darray", "categorical", "string"],
+                "preserves_dtype": [],
+                "allow_nan": True,
+                "_xfail_checks": {
+                    "check_estimator_sparse_data": (
+                        "Cannot create sparse matrix with strings."
+                    ),
+                    "check_estimators_dtypes": "We only support string dtypes.",
+                },
+            }
+
+    else:
+
+        def __sklearn_tags__(self):
+            tags = super().__sklearn_tags__()
+            tags.input_tags.categorical = True
+            tags.input_tags.string = True
+            tags.transformer_tags.preserves_dtype = []
+            return tags
39 changes: 28 additions & 11 deletions skrub/_table_vectorizer.py
@@ -3,10 +3,12 @@
 from typing import Iterable
 
 import numpy as np
+import sklearn
 from sklearn.base import BaseEstimator, TransformerMixin, clone
 from sklearn.pipeline import make_pipeline
 from sklearn.preprocessing import OneHotEncoder
 from sklearn.utils._estimator_html_repr import _VisualBlock
+from sklearn.utils.fixes import parse_version
 from sklearn.utils.validation import check_is_fitted
 
 from . import _dataframe as sbd
@@ -28,6 +30,11 @@
 __all__ = ["TableVectorizer"]
 
 
+sklearn_below_1_6 = parse_version(
+    parse_version(sklearn.__version__).base_version
+) < parse_version("1.6")
+
+
 class PassThrough(SingleColumnTransformer):
     def fit_transform(self, column, y=None):
         return column
@@ -658,17 +665,27 @@ def _sk_visual_block_(self):
 
     # scikit-learn compatibility
 
-    def _more_tags(self):
-        """
-        Used internally by sklearn to ease the estimator checks.
-        """
-        return {
-            "X_types": ["2darray", "string"],
-            "allow_nan": [True],
-            "_xfail_checks": {
-                "check_complex_data": "Passthrough complex columns as-is.",
-            },
-        }
+    if sklearn_below_1_6:
+
+        def _more_tags(self):
+            """
+            Used internally by sklearn to ease the estimator checks.
+            """
+            return {
+                "X_types": ["2darray", "string"],
+                "allow_nan": [True],
+                "_xfail_checks": {
+                    "check_complex_data": "Passthrough complex columns as-is.",
+                },
+            }
+
+    else:
+
+        def __sklearn_tags__(self):
+            tags = super().__sklearn_tags__()
+            tags.input_tags.string = True
+            tags.input_tags.allow_nan = True
+            return tags
 
     def get_feature_names_out(self):
         """Return the column names of the output of ``transform`` as a list of strings.
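Worth noting: the legacy dict declared `"allow_nan": [True]`, a truthy list rather than a bool, while the new dataclass field takes a plain `True`. A quick check of the new-style tags, assuming scikit-learn >= 1.6 and this branch of skrub:

    from skrub import TableVectorizer

    tags = TableVectorizer().__sklearn_tags__()
    assert tags.input_tags.string and tags.input_tags.allow_nan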
15 changes: 12 additions & 3 deletions skrub/_tabular_learner.py
@@ -1,3 +1,5 @@
+from dataclasses import is_dataclass
+
 import sklearn
 from sklearn import ensemble
 from sklearn.base import BaseEstimator
@@ -6,6 +8,7 @@
 from sklearn.preprocessing import OrdinalEncoder, StandardScaler
 from sklearn.utils.fixes import parse_version
 
+from ._fixes import get_tags
 from ._minhash_encoder import MinHashEncoder
 from ._table_vectorizer import TableVectorizer
 from ._to_categorical import ToCategorical
@@ -270,9 +273,15 @@ def tabular_learner(estimator, *, n_jobs=None):
         high_cardinality=MinHashEncoder(),
     )
     steps = [vectorizer]
-    if not hasattr(estimator, "_get_tags") or not estimator._get_tags().get(
-        "allow_nan", False
-    ):
+    try:
+        tags = get_tags(estimator)
+        if is_dataclass(tags):
+            allow_nan = tags.input_tags.allow_nan
+        else:
+            allow_nan = tags.get("allow_nan", False)
+    except TypeError:
+        allow_nan = False
+    if not allow_nan:
        steps.append(SimpleImputer(add_indicator=True))
     if not isinstance(estimator, _TREE_ENSEMBLE_CLASSES):
         steps.append(StandardScaler())
2 changes: 2 additions & 0 deletions skrub/_to_datetime.py
@@ -28,6 +28,8 @@ def _get_time_zone_pandas(col):
         return None
     if hasattr(tz, "zone"):
         return tz.zone
+    if hasattr(tz, "key"):
+        return tz.key
     return tz.tzname(None)
 
 
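The new branch covers `zoneinfo.ZoneInfo` objects, which expose the IANA name as `.key`, whereas pytz timezones use `.zone`; fixed-offset timezones fall through to `tzname(None)`. A hedged illustration of the three branches (assumes pytz is installed and Python >= 3.9 for zoneinfo):

    from datetime import timezone
    from zoneinfo import ZoneInfo

    import pytz

    pytz.timezone("Europe/Paris").zone  # 'Europe/Paris', the .zone branch
    ZoneInfo("Europe/Paris").key        # 'Europe/Paris', the new .key branch
    timezone.utc.tzname(None)           # 'UTC', the fallback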