diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index dceb79b66..7377d6701 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -876,14 +876,14 @@ def is_valid_for_device(self, device_id: str) -> bool: def get_candidates_ordered_by_speedup( self, candidate_results: list[BenchmarkResult] - ) -> Union[list[BenchmarkResult], list[tuple[BenchmarkResult, float]]]: + ) -> list[tuple[BenchmarkResult, float]]: """ - returns: - - `list[BenchmarkResult]` sorted by runtime if no valid baselines exist. - - `list[tuple[BenchmarkResult, float]]` sorted by speedup when baselines are available. + Returns a list of tuples (BenchmarkResult, speedup) sorted in ascending order based on speedup + or raw runtime. - If no valid baseline times are available across all devices, candidates are sorted - and returned based on their raw runtime in ascending order. + If no valid baseline times are available across all devices, candidates are sorted based on + their raw runtime in ascending order. A placeholder speedup value of 1.0 is assigned to each + candidate. If valid baseline times exist, speedup is defined as the ratio of the candidate's runtime to the average baseline time for the corresponding device as: @@ -896,7 +896,10 @@ def get_candidates_ordered_by_speedup( if not self.is_valid(): logging.warning("No valid baseline times available.") # Use the candidate time directly when no baselines are available. - return sorted(candidate_results, key=lambda candidate: candidate.time) + return sorted( + [(candidate, 1.0) for candidate in candidate_results], + key=lambda x: x[0].time, + ) # Calculate the fallback baseline as the average of all valid times across devices. valid_baseline_times = [ @@ -1065,11 +1068,7 @@ def benchmark( f"({percentage_of_baseline:.1f}% of baseline)" ) else: - top_candiates = [ - result if isinstance(result, BenchmarkResult) else result[0] - for result in top_candidates_with_speedup - ] - for candidate in top_candiates: + for candidate, _ in top_candidates_with_speedup: time_ms = candidate.time candidate_id = candidate.candidate_id top_candidate_ids.append(candidate_id) diff --git a/tuner/tuner/libtuner_test.py b/tuner/tuner/libtuner_test.py index e0e3f6ba0..bbca2d370 100644 --- a/tuner/tuner/libtuner_test.py +++ b/tuner/tuner/libtuner_test.py @@ -290,11 +290,14 @@ def test_baseline_result_handler_speedup(): ] handler = libtuner.BaselineResultHandler() - all_candidates = handler.get_candidates_ordered_by_speedup(candidates) - assert all_candidates == [ - candidates[1], - candidates[0], - candidates[2], + all_candidates_with_speedup = handler.get_candidates_ordered_by_speedup(candidates) + assert all_candidates_with_speedup == [ + (candidates[1], 1.0), + (candidates[0], 1.0), + (candidates[2], 1.0), + ] + top_candidates_with_speedup = all_candidates_with_speedup[:2] + assert [candidate.candidate_id for candidate, _ in top_candidates_with_speedup] == [ + 6, + 5, ] - top_candidates = all_candidates[:2] - assert [candidate.candidate_id for candidate in top_candidates] == [6, 5]