From c2f2479527f5b3010092a4eb5be2bc3664a1eca1 Mon Sep 17 00:00:00 2001 From: Franck Charras <29153872+fcharras@users.noreply.github.com> Date: Mon, 15 Jan 2024 20:20:27 +0100 Subject: [PATCH] Faster warmup --- benchmarks/ridge/solvers/cuml.py | 5 ++++- benchmarks/ridge/solvers/scikit_learn.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/benchmarks/ridge/solvers/cuml.py b/benchmarks/ridge/solvers/cuml.py index 20482a1..c64cd41 100644 --- a/benchmarks/ridge/solvers/cuml.py +++ b/benchmarks/ridge/solvers/cuml.py @@ -69,11 +69,14 @@ def set_objective( self.random_state = random_state def warm_up(self): + sample_weight = self.sample_weight + if sample_weight is not None: + sample_weight = sample_weight[:2] cuml.Ridge( alpha=self.alpha, fit_intercept=self.fit_intercept, solver=self.solver, - ).fit(self.X, self.y, sample_weight=self.sample_weight) + ).fit(self.X[:2], self.y[:2], sample_weight=sample_weight) def run(self, _): estimator = cuml.Ridge( diff --git a/benchmarks/ridge/solvers/scikit_learn.py b/benchmarks/ridge/solvers/scikit_learn.py index 0fbbe8b..97db1c6 100644 --- a/benchmarks/ridge/solvers/scikit_learn.py +++ b/benchmarks/ridge/solvers/scikit_learn.py @@ -48,6 +48,9 @@ def skip(self, **objective_dict): return False, None def warm_up(self): + sample_weight = self.sample_weight + if sample_weight is not None: + sample_weight = sample_weight[:2] Ridge( alpha=self.alpha, fit_intercept=self.fit_intercept, @@ -57,7 +60,7 @@ def warm_up(self): solver=self.solver, positive=True if (self.solver == "lbfgs") else False, random_state=self.random_state, - ).fit(self.X, self.y, self.sample_weight) + ).fit(self.X[:2], self.y[:2], sample_weight) def run(self, _): estimator = Ridge(