Skip to content

Commit

Permalink
fix n_iter report
Browse files Browse the repository at this point in the history
  • Loading branch information
fcharras committed Jan 15, 2024
1 parent 2f9d369 commit 1cf6569
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
4 changes: 3 additions & 1 deletion benchmarks/ridge/consolidate_result_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@
BACKEND_PROVIDER: str,
COMPUTE_DEVICE: str,
COMPUTE_RUNTIME: str,
RESULT_NB_ITERATIONS: np.int64,
# NB: following should be int but str is more practical because it enables
# use of missing values for solver for which it doesn't apply.
RESULT_NB_ITERATIONS: str,
OBJECTIVE_FUNCTION_VALUE: np.float64,
SOLVER: str,
PLATFORM: str,
Expand Down
18 changes: 18 additions & 0 deletions benchmarks/ridge/objective.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numbers
from datetime import datetime

from benchopt import BaseObjective, safe_import_context
Expand Down Expand Up @@ -87,6 +88,23 @@ def evaluate_result(self, weights, intercept, n_iter, **solver_parameters):
all_parameters.update(
{("solver_param_" + key): value for key, value in solver_parameters.items()}
)

if not (isinstance(n_iter, numbers.Number) or (n_iter is None)):
n_iter = set(n_iter)
if len(n_iter) > 1:
raise ValueError(
"In multitarget mode, the same number of iterations is expected "
"for all targets, to keep reports comparable."
)
n_iter = n_iter.pop()

# NB: str for n_iter is a more practical type because it enables
# using missing values for solvers for which it doesn't apply
if n_iter is None:
n_iter = ""
else:
n_iter = str(n_iter)

return dict(
value=value,
n_iter=n_iter,
Expand Down
12 changes: 1 addition & 11 deletions benchmarks/ridge/solvers/scikit_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,10 @@ def run(self, _):
self.n_iter_ = estimator.n_iter_

def get_result(self):
n_iter = self.n_iter_
if isinstance(n_iter, list):
n_iter = set(n_iter)
if len(n_iter) > 1:
raise ValueError(
"In multitarget mode, the same number of iterations is expected "
"for all targets, to keep reports comparable."
)
n_iter = n_iter.pop()

return dict(
weights=self.weights,
intercept=self.intercept,
n_iter=n_iter,
n_iter=self.n_iter_,
version_info=f"scikit-learn {version('scikit-learn')}",
__name=self.name,
**self._parameters,
Expand Down

0 comments on commit 1cf6569

Please sign in to comment.