FEA: Ridge support for Array API compliant inputs #27800

Merged · 72 commits · Apr 25, 2024

Changes from 63 commits

Commits (72)
e0429db
update r2 score to use the array API, and write initial tests
elindgren Aug 18, 2023
b9c1720
Merge remote-tracking branch 'upstream/main'
elindgren Aug 18, 2023
5666ce5
Merge branch 'main' into ENH/r2_score_array_api
ogrisel Aug 19, 2023
4580d1c
Merge branch 'main' into ENH/r2_score_array_api
ogrisel Sep 7, 2023
a4dd594
Fix some review comments and move stuff to CPU
elindgren Sep 8, 2023
adc7680
Add regression tests to the test_common framework
elindgren Sep 28, 2023
85469a9
Update sklearn/metrics/tests/test_regression.py
elindgren Oct 5, 2023
b7efaa5
Update sklearn/metrics/tests/test_regression.py
elindgren Oct 5, 2023
ac533c2
Remove hardcoded device choice in _weighted_sum
betatim Aug 30, 2023
35be22e
Factor out max float precision determination
betatim Sep 7, 2023
7c53e19
Use convenience function to find highest accuracy float in r2_score
elindgren Oct 5, 2023
230ae46
add tests for _average for Array API
elindgren Oct 5, 2023
e4672d1
MNT Ignore ruff errors (#27094)
lesteve Aug 18, 2023
8ba9485
DOC fix docstring for `sklearn.datasets.get_data_home` (#27073)
kachayev Aug 18, 2023
490e0b4
TST Extend tests for `scipy.sparse.*array` in `sklearn/cluster/tests/…
jjerphan Aug 18, 2023
a8a820c
MNT Remove DeprecationWarning for scipy.sparse.linalg.cg tol vs rtol …
lesteve Aug 18, 2023
552e421
Merge branch 'main' into ENH/r2_score_array_api
elindgren Oct 5, 2023
ff52710
Merge remote-tracking branch 'upstream/main' into ENH/r2_score_array_api
elindgren Oct 5, 2023
fe9cc1c
remove temporary file
elindgren Oct 5, 2023
93257ba
WIP: solving dtype and device maze
fcharras Dec 5, 2023
45bbe4e
Fix changelog conflict
fcharras Dec 5, 2023
2145a6b
Tests fixups
fcharras Dec 6, 2023
bd4b224
Tests fixups
fcharras Dec 6, 2023
34aceb1
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras Dec 6, 2023
56d5308
Fix dtype parameterization in common metric tests
fcharras Dec 6, 2023
75cb3f3
Tests fixups
fcharras Dec 6, 2023
d9fff24
Tests fixups
fcharras Dec 6, 2023
d72137c
Adds lru_cache on device inspection function + user _convert_to_numpy…
fcharras Dec 11, 2023
16ab95f
Adequatly define hash of _ArrayAPIWrapper to avoid wrong equality
fcharras Dec 11, 2023
9862a85
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras Dec 19, 2023
143ce54
Remove _weighted_sum and only use _average
fcharras Dec 19, 2023
4e9401b
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras Dec 19, 2023
2b095c4
Linting on unrelated diff, pre-commit broken ? + fixes
fcharras Dec 19, 2023
42f5d8d
Merge branch 'main' into ENH/r2_score_array_api
fcharras Dec 26, 2023
ff0b860
re add faster, simpler code branch for _weighted_sum in _classificati…
fcharras Dec 27, 2023
efe36f3
re add faster, simpler code branch for _weighted_sum in _classificati…
fcharras Dec 27, 2023
abb9ee9
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras Dec 27, 2023
08f5433
fix
fcharras Dec 28, 2023
618579e
Ridge support for Array API compliant inputs
fcharras Dec 28, 2023
7ae648b
Fix
fcharras Dec 29, 2023
7a299a7
General fixes + fix tests with torch+cuda
fcharras Dec 29, 2023
0cc52a1
fix tests with torch+cuda
fcharras Dec 29, 2023
9a6b425
fix tests with torch+cuda
fcharras Dec 29, 2023
1d58e83
Merge branch 'main' into enh/ridge_support_for_array_api
fcharras Jan 2, 2024
219a7f5
fixup alpha check when multitarget
fcharras Jan 16, 2024
046cf00
wip: merge main after merge of #27904
fcharras Mar 15, 2024
19591f5
Factor repeated patterns
fcharras Mar 15, 2024
ea95282
Fixup
fcharras Mar 15, 2024
2cc32b0
Fixup
fcharras Mar 15, 2024
020c5e5
Fixup
fcharras Mar 15, 2024
aef4d00
Adress comments - cleaning and fixes
fcharras Mar 18, 2024
d72a1de
Fixup - remove unnecessary _asarray_with_order call
fcharras Mar 18, 2024
55ce84c
Fixup
fcharras Mar 18, 2024
2116646
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras Mar 19, 2024
408e205
Fixup
fcharras Mar 19, 2024
d8c0d3c
Wip: Remove _item / keep auto behavior / test error and warning behavior
fcharras Mar 23, 2024
47afb9d
Fixup doctest
fcharras Mar 23, 2024
5d05527
Fixup
fcharras Mar 23, 2024
894204d
fix output type for intercept in _ridge_regression / add yield_namesp…
fcharras Mar 23, 2024
ef732d9
Fixup
fcharras Mar 23, 2024
93f2ec1
Test setup fixup
fcharras Mar 23, 2024
6fbaeb7
Fix coverage + register solver_ attribute in Ridge and RidgeClassifier
fcharras Mar 23, 2024
9745941
Adress review suggestions
fcharras Mar 26, 2024
741a86c
Properly interpolate strings.
fcharras Mar 27, 2024
4edebac
Also fix typo and use helper and properly interpolate strings in uni…
fcharras Mar 27, 2024
5600a2a
Update credits
fcharras Mar 28, 2024
fc6521b
Merge branch 'main' into enh/ridge_support_for_array_api
ogrisel Mar 28, 2024
25da280
Merge branch 'main' into enh/ridge_support_for_array_api
ogrisel Apr 12, 2024
1a29866
Typo
ogrisel Apr 24, 2024
e51d6e4
Rename solver resolution helper functions
ogrisel Apr 24, 2024
6a84850
Remove useless xp.asarray(intercept) in the saga branch
ogrisel Apr 24, 2024
a02e7a4
Improved warning message based on Tim's suggestion
ogrisel Apr 24, 2024
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
```diff
@@ -95,6 +95,7 @@ Estimators

 - :class:`decomposition.PCA` (with `svd_solver="full"`,
   `svd_solver="randomized"` and `power_iteration_normalizer="QR"`)
+- :class:`linear_model.Ridge` (with `solver="svd"`)
 - :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`)
 - :class:`preprocessing.KernelCenterer`
 - :class:`preprocessing.MaxAbsScaler`
```
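In practice, the documented support means `Ridge(solver="svd")` can consume Array API inputs such as PyTorch tensors directly. A minimal usage sketch, assuming PyTorch and `array-api-compat` are installed (the data here is illustrative, not from the PR):

```python
import numpy as np
import torch
from sklearn import config_context
from sklearn.linear_model import Ridge

X = torch.asarray(np.random.rand(100, 5), dtype=torch.float32)
y = torch.asarray(np.random.rand(100), dtype=torch.float32)

# Array API dispatch is opt-in; inside this context the fitted attributes
# (coef_, intercept_) stay in the input namespace and on the input device.
with config_context(array_api_dispatch=True):
    ridge = Ridge(solver="svd").fit(X, y)
    print(type(ridge.coef_))  # <class 'torch.Tensor'>
```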
7 changes: 6 additions & 1 deletion doc/whats_new/v1.5.rst
```diff
@@ -38,6 +38,11 @@ See :ref:`array_api` for more details.

 **Classes:**

+- :class:`linear_model.Ridge` now supports the Array API for the `svd` solver.
+  See :ref:`array_api` for more details.
+  :pr:`27800` by :user:`Franck Charras <fcharras>`, :user:`TODO <TODO>` and
+  :user:`TODO <TODO>`.
+
 Support for building with Meson
 -------------------------------

@@ -293,7 +298,7 @@ Changelog
   :func:`preprocessing.quantile_transform` now supports disabling
   subsampling explicitly.
   :pr:`27636` by :user:`Ralph Urlus <rurlus>`.
-
+
 :mod:`sklearn.tree`
 ...................

```
77 changes: 54 additions & 23 deletions sklearn/linear_model/_base.py
```diff
@@ -33,7 +33,14 @@
     _fit_context,
 )
 from ..utils import check_array, check_random_state
-from ..utils._array_api import get_namespace, indexing_dtype
+from ..utils._array_api import (
+    _asarray_with_order,
+    _average,
+    get_namespace,
+    get_namespace_and_device,
+    indexing_dtype,
+    supported_float_dtypes,
+)
 from ..utils._seq_dataset import (
     ArrayDataset32,
     ArrayDataset64,
@@ -43,7 +50,7 @@
 from ..utils.extmath import safe_sparse_dot
 from ..utils.parallel import Parallel, delayed
 from ..utils.sparsefuncs import mean_variance_axis
-from ..utils.validation import FLOAT_DTYPES, _check_sample_weight, check_is_fitted
+from ..utils.validation import _check_sample_weight, check_is_fitted

 # TODO: bayesian_ridge_regression and bayesian_regression_ard
 # should be squashed into its respective objects.
```
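These imports bring in the namespace-dispatch pattern used throughout the rest of the diff: fetch an `xp` module that mimics NumPy for whatever array type was passed in, then write the numerical code once against `xp`. A rough sketch of the idea (`get_namespace` is a scikit-learn internal, so its exact signature may change):

```python
import numpy as np
from sklearn.utils._array_api import get_namespace

X = np.asarray([[1.0, 2.0], [3.0, 4.0]])
xp, is_array_api_compliant = get_namespace(X)

# xp wraps numpy here, but would be an Array-API-compatible namespace for
# torch/cupy inputs when array_api_dispatch is enabled; the code below then
# runs unchanged on any of them.
X_centered = X - xp.mean(X, axis=0)
```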
```diff
@@ -155,43 +162,51 @@ def _preprocess_data(
         Always an array of ones. TODO: refactor the code base to make it
         possible to remove this unused variable.
     """
+    xp, _, device_ = get_namespace_and_device(X, y, sample_weight)
+    n_samples, n_features = X.shape
+    X_is_sparse = sp.issparse(X)
+
     if isinstance(sample_weight, numbers.Number):
         sample_weight = None
     if sample_weight is not None:
-        sample_weight = np.asarray(sample_weight)
+        sample_weight = xp.asarray(sample_weight)

     if check_input:
-        X = check_array(X, copy=copy, accept_sparse=["csr", "csc"], dtype=FLOAT_DTYPES)
+        X = check_array(
+            X, copy=copy, accept_sparse=["csr", "csc"], dtype=supported_float_dtypes(xp)
+        )
         y = check_array(y, dtype=X.dtype, copy=copy_y, ensure_2d=False)
     else:
-        y = y.astype(X.dtype, copy=copy_y)
+        y = xp.astype(y, X.dtype, copy=copy_y)
         if copy:
-            if sp.issparse(X):
+            if X_is_sparse:
                 X = X.copy()
             else:
-                X = X.copy(order="K")
+                X = _asarray_with_order(X, order="K", copy=True, xp=xp)
+
+    dtype_ = X.dtype

     if fit_intercept:
-        if sp.issparse(X):
+        if X_is_sparse:
             X_offset, X_var = mean_variance_axis(X, axis=0, weights=sample_weight)
         else:
-            X_offset = np.average(X, axis=0, weights=sample_weight)
+            X_offset = _average(X, axis=0, weights=sample_weight, xp=xp)

-        X_offset = X_offset.astype(X.dtype, copy=False)
+        X_offset = xp.astype(X_offset, X.dtype, copy=False)
         X -= X_offset

-        y_offset = np.average(y, axis=0, weights=sample_weight)
+        y_offset = _average(y, axis=0, weights=sample_weight, xp=xp)
         y -= y_offset
     else:
-        X_offset = np.zeros(X.shape[1], dtype=X.dtype)
+        X_offset = xp.zeros(n_features, dtype=X.dtype, device=device_)
         if y.ndim == 1:
-            y_offset = X.dtype.type(0)
+            y_offset = xp.asarray(0.0, dtype=dtype_, device=device_)
         else:
-            y_offset = np.zeros(y.shape[1], dtype=X.dtype)
+            y_offset = xp.zeros(y.shape[1], dtype=dtype_, device=device_)

     # XXX: X_scale is no longer needed. It is an historic artifact from the
     # time where linear model exposed the normalize parameter.
-    X_scale = np.ones(X.shape[1], dtype=X.dtype)
+    X_scale = xp.ones(n_features, dtype=X.dtype, device=device_)
     return X, y, X_offset, y_offset, X_scale
```


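The `fit_intercept` branch above centers `X` and `y` by their (optionally weighted) column means so the solver can fit without an explicit intercept column; the offsets are used later by `_set_intercept` to restore the intercept. A NumPy-only sketch of that contract (the function name here is illustrative, not a sklearn internal):

```python
import numpy as np

def center_for_linear_fit(X, y, sample_weight=None):
    # np.average with weights matches the semantics of the _average helper.
    X_offset = np.average(X, axis=0, weights=sample_weight)
    y_offset = np.average(y, axis=0, weights=sample_weight)
    return X - X_offset, y - y_offset, X_offset, y_offset
```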
```diff
@@ -224,8 +239,9 @@ def _rescale_data(X, y, sample_weight, inplace=False):
     """
     # Assume that _validate_data and _check_sample_weight have been called by
     # the caller.
+    xp, _ = get_namespace(X, y, sample_weight)
     n_samples = X.shape[0]
-    sample_weight_sqrt = np.sqrt(sample_weight)
+    sample_weight_sqrt = xp.sqrt(sample_weight)

     if sp.issparse(X) or sp.issparse(y):
         sw_matrix = sparse.dia_matrix(
@@ -236,9 +252,9 @@
         X = safe_sparse_dot(sw_matrix, X)
     else:
         if inplace:
-            X *= sample_weight_sqrt[:, np.newaxis]
+            X *= sample_weight_sqrt[:, None]
         else:
-            X = X * sample_weight_sqrt[:, np.newaxis]
+            X = X * sample_weight_sqrt[:, None]

     if sp.issparse(y):
         y = safe_sparse_dot(sw_matrix, y)
@@ -247,12 +263,12 @@
         if y.ndim == 1:
             y *= sample_weight_sqrt
         else:
-            y *= sample_weight_sqrt[:, np.newaxis]
+            y *= sample_weight_sqrt[:, None]
     else:
         if y.ndim == 1:
             y = y * sample_weight_sqrt
         else:
-            y = y * sample_weight_sqrt[:, np.newaxis]
+            y = y * sample_weight_sqrt[:, None]
     return X, y, sample_weight_sqrt
```


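Why `_rescale_data` multiplies by the square root of the weights: weighted least squares, minimizing sum_i w_i * (y_i - x_i @ beta)**2, is exactly ordinary least squares on sqrt(w_i) * x_i and sqrt(w_i) * y_i. A quick numerical check of that identity (illustrative, NumPy only):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))
y = rng.normal(size=20)
w = rng.uniform(0.5, 2.0, size=20)

# OLS on sqrt-weight-rescaled data ...
sw = np.sqrt(w)
beta_rescaled = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)[0]
# ... agrees with the weighted normal equations (X'WX) beta = X'Wy.
beta_weighted = np.linalg.solve(X.T @ (w[:, None] * X), X.T @ (w * y))
assert np.allclose(beta_rescaled, beta_weighted)
```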
```diff
@@ -267,7 +283,11 @@ def _decision_function(self, X):
         check_is_fitted(self)

         X = self._validate_data(X, accept_sparse=["csr", "csc", "coo"], reset=False)
-        return safe_sparse_dot(X, self.coef_.T, dense_output=True) + self.intercept_
+        coef_ = self.coef_
+        if coef_.ndim == 1:
+            return X @ coef_ + self.intercept_
+        else:
+            return X @ coef_.T + self.intercept_

     def predict(self, X):
         """
```
```diff
@@ -287,11 +307,22 @@

     def _set_intercept(self, X_offset, y_offset, X_scale):
         """Set the intercept_"""
+
+        xp, _ = get_namespace(X_offset, y_offset, X_scale)
+
         if self.fit_intercept:
             # We always want coef_.dtype=X.dtype. For instance, X.dtype can differ from
             # coef_.dtype if warm_start=True.
-            self.coef_ = np.divide(self.coef_, X_scale, dtype=X_scale.dtype)
-            self.intercept_ = y_offset - np.dot(X_offset, self.coef_.T)
+            coef_ = xp.astype(self.coef_, X_scale.dtype, copy=False)
+            coef_ = self.coef_ = xp.divide(coef_, X_scale)
+
+            if coef_.ndim == 1:
+                intercept_ = y_offset - X_offset @ coef_
+            else:
+                intercept_ = y_offset - X_offset @ coef_.T
+
+            self.intercept_ = intercept_
         else:
             self.intercept_ = 0.0
```
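`_set_intercept` relies on the standard identity for a linear model fitted on centered data: intercept = y_mean - X_mean @ coef. A NumPy check of that identity (illustrative data):

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(50, 4))
y = X @ np.array([1.0, -2.0, 0.5, 3.0]) + 0.7 + rng.normal(scale=0.1, size=50)

# Fit on centered data, then recover the intercept from the offsets.
X_offset, y_offset = X.mean(axis=0), y.mean()
coef = np.linalg.lstsq(X - X_offset, y - y_offset, rcond=None)[0]
intercept = y_offset - X_offset @ coef

# Predictions match a model fitted with an explicit intercept column.
pred = X @ coef + intercept
```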