From b5467f29a09023197142dede304288ca6cb42266 Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian@gmail.com>
Date: Mon, 20 Jan 2025 09:58:17 -0600
Subject: [PATCH] avoid using union

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
---
 tuner/tuner/libtuner.py      | 22 ++++++++++------------
 tuner/tuner/libtuner_test.py | 17 ++++++++++-------
 2 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py
index dceb79b66..c257ea910 100644
--- a/tuner/tuner/libtuner.py
+++ b/tuner/tuner/libtuner.py
@@ -876,14 +876,13 @@ def is_valid_for_device(self, device_id: str) -> bool:
 
     def get_candidates_ordered_by_speedup(
         self, candidate_results: list[BenchmarkResult]
-    ) -> Union[list[BenchmarkResult], list[tuple[BenchmarkResult, float]]]:
+    ) -> list[tuple[BenchmarkResult, float]]:
         """
-        returns:
-        - `list[BenchmarkResult]` sorted by runtime if no valid baselines exist.
-        - `list[tuple[BenchmarkResult, float]]` sorted by speedup when baselines are available.
+        Returns a list of tuples (BenchmarkResult, speedup) sorted in ascending order based on speedup
+        or raw runtime.
 
-        If no valid baseline times are available across all devices, candidates are sorted
-        and returned based on their raw runtime in ascending order.
+        If no valid baseline times are available across all devices, candidates are sorted based on
+        their raw runtime. A placeholder speedup value of 1.0 is assigned to each candidate.
 
         If valid baseline times exist, speedup is defined as the ratio of the candidate's runtime to
         the average baseline time for the corresponding device as:
@@ -896,7 +895,10 @@ def get_candidates_ordered_by_speedup(
         if not self.is_valid():
             logging.warning("No valid baseline times available.")
             # Use the candidate time directly when no baselines are available.
-            return sorted(candidate_results, key=lambda candidate: candidate.time)
+            return sorted(
+                [(candidate, 1.0) for candidate in candidate_results],
+                key=lambda x: x[0].time,
+            )
 
         # Calculate the fallback baseline as the average of all valid times across devices.
         valid_baseline_times = [
@@ -1065,11 +1067,7 @@ def benchmark(
                 f"({percentage_of_baseline:.1f}% of baseline)"
             )
     else:
-        top_candiates = [
-            result if isinstance(result, BenchmarkResult) else result[0]
-            for result in top_candidates_with_speedup
-        ]
-        for candidate in top_candiates:
+        for candidate, _ in top_candidates_with_speedup:
             time_ms = candidate.time
             candidate_id = candidate.candidate_id
             top_candidate_ids.append(candidate_id)
diff --git a/tuner/tuner/libtuner_test.py b/tuner/tuner/libtuner_test.py
index e0e3f6ba0..bbca2d370 100644
--- a/tuner/tuner/libtuner_test.py
+++ b/tuner/tuner/libtuner_test.py
@@ -290,11 +290,14 @@ def test_baseline_result_handler_speedup():
     ]
 
     handler = libtuner.BaselineResultHandler()
-    all_candidates = handler.get_candidates_ordered_by_speedup(candidates)
-    assert all_candidates == [
-        candidates[1],
-        candidates[0],
-        candidates[2],
+    all_candidates_with_speedup = handler.get_candidates_ordered_by_speedup(candidates)
+    assert all_candidates_with_speedup == [
+        (candidates[1], 1.0),
+        (candidates[0], 1.0),
+        (candidates[2], 1.0),
+    ]
+    top_candidates_with_speedup = all_candidates_with_speedup[:2]
+    assert [candidate.candidate_id for candidate, _ in top_candidates_with_speedup] == [
+        6,
+        5,
     ]
-    top_candidates = all_candidates[:2]
-    assert [candidate.candidate_id for candidate in top_candidates] == [6, 5]