diff --git a/benchmarks/ridge/consolidate_result_csv.py b/benchmarks/ridge/consolidate_result_csv.py index 25b2cf2..2933f92 100644 --- a/benchmarks/ridge/consolidate_result_csv.py +++ b/benchmarks/ridge/consolidate_result_csv.py @@ -94,7 +94,9 @@ BACKEND_PROVIDER: str, COMPUTE_DEVICE: str, COMPUTE_RUNTIME: str, - RESULT_NB_ITERATIONS: np.int64, + # NB: following should be int but str is more practical because it enables + # use of missing values for solver for which it doesn't apply. + RESULT_NB_ITERATIONS: str, OBJECTIVE_FUNCTION_VALUE: np.float64, SOLVER: str, PLATFORM: str, diff --git a/benchmarks/ridge/objective.py b/benchmarks/ridge/objective.py index 94c241c..970fb21 100644 --- a/benchmarks/ridge/objective.py +++ b/benchmarks/ridge/objective.py @@ -1,3 +1,4 @@ +import numbers from datetime import datetime from benchopt import BaseObjective, safe_import_context @@ -87,6 +88,23 @@ def evaluate_result(self, weights, intercept, n_iter, **solver_parameters): all_parameters.update( {("solver_param_" + key): value for key, value in solver_parameters.items()} ) + + if not (isinstance(n_iter, numbers.Number) or (n_iter is None)): + n_iter = set(n_iter) + if len(n_iter) > 1: + raise ValueError( + "In multitarget mode, the same number of iterations is expected " + "for all targets, to keep reports comparable." + ) + n_iter = n_iter.pop() + + # NB: str for n_iter is a more practical type because it enables + # using missing values for solvers for which it doesn't apply + if n_iter is None: + n_iter = "" + else: + n_iter = str(n_iter) + return dict( value=value, n_iter=n_iter, diff --git a/benchmarks/ridge/solvers/scikit_learn.py b/benchmarks/ridge/solvers/scikit_learn.py index 202e720..0d49ac0 100644 --- a/benchmarks/ridge/solvers/scikit_learn.py +++ b/benchmarks/ridge/solvers/scikit_learn.py @@ -64,20 +64,10 @@ def run(self, _): self.n_iter_ = estimator.n_iter_ def get_result(self): - n_iter = self.n_iter_ - if isinstance(n_iter, list): - n_iter = set(n_iter) - if len(n_iter) > 1: - raise ValueError( - "In multitarget mode, the same number of iterations is expected " - "for all targets, to keep reports comparable." - ) - n_iter = n_iter.pop() - return dict( weights=self.weights, intercept=self.intercept, - n_iter=n_iter, + n_iter=self.n_iter_, version_info=f"scikit-learn {version('scikit-learn')}", __name=self.name, **self._parameters,