Skip to content

Commit

Permalink
Faster warmup
Browse files Browse the repository at this point in the history
  • Loading branch information
fcharras committed Jan 15, 2024
1 parent f3e9765 commit c2f2479
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion benchmarks/ridge/solvers/cuml.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,14 @@ def set_objective(
self.random_state = random_state

def warm_up(self):
sample_weight = self.sample_weight
if sample_weight is not None:
sample_weight = sample_weight[:2]
cuml.Ridge(
alpha=self.alpha,
fit_intercept=self.fit_intercept,
solver=self.solver,
).fit(self.X, self.y, sample_weight=self.sample_weight)
).fit(self.X[:2], self.y[:2], sample_weight=sample_weight)

def run(self, _):
estimator = cuml.Ridge(
Expand Down
5 changes: 4 additions & 1 deletion benchmarks/ridge/solvers/scikit_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def skip(self, **objective_dict):
return False, None

def warm_up(self):
sample_weight = self.sample_weight
if sample_weight is not None:
sample_weight = sample_weight[:2]
Ridge(
alpha=self.alpha,
fit_intercept=self.fit_intercept,
Expand All @@ -57,7 +60,7 @@ def warm_up(self):
solver=self.solver,
positive=True if (self.solver == "lbfgs") else False,
random_state=self.random_state,
).fit(self.X, self.y, self.sample_weight)
).fit(self.X[:2], self.y[:2], sample_weight)

def run(self, _):
estimator = Ridge(
Expand Down

0 comments on commit c2f2479

Please sign in to comment.