Skip to content

Commit

Permalink
update the code comments
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Jan 17, 2025
1 parent f188afe commit 180c435
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
24 changes: 16 additions & 8 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,9 +881,13 @@ def calculate_speedup(
Speedup is defined as the ratio of the candidate's runtime to the average baseline time
for the corresponding device as:
speedup = candidate_runtime / avg_baseline_time
speedup = candidate_runtime / avg_baseline_time (or fallback_baseline)
If no valid baseline times are available, the candidate'sruntime is used directly as:
If no valid baseline times are available for a specific device, the fallback baseline is used.
The fallback baseline is calculated as the average of all valid baseline times across devices.
If no valid baseline times are available across all devices, the candidate's runtime is
used directly as:
speedup = candidate_runtime
Expand All @@ -909,13 +913,15 @@ def calculate_speedup(

speedup_by_candidate = {}
for candidate in candidate_results:
baseline_avg = self.get_average_result_ms(candidate.device_id)
if baseline_avg is None or not math.isfinite(baseline_avg):
baseline_avg = fallback_baseline
speedup_by_candidate[candidate.candidate_id] = candidate.time / baseline_avg
baseline_avg_ms = self.get_average_result_ms(candidate.device_id)
if baseline_avg_ms is None:
baseline_avg_ms = fallback_baseline
speedup_by_candidate[candidate.candidate_id] = (
candidate.time / baseline_avg_ms
)
return speedup_by_candidate

def get_top_candidates(
def sort_candidates_with_speedup(
self,
speedup_by_candidate: dict[int, float],
) -> list[tuple[int, float]]:
Expand Down Expand Up @@ -1054,7 +1060,9 @@ def benchmark(
logging.warning("Baseline run failed.")

speedup_result = baseline_handler.calculate_speedup(candidate_results)
all_candidates_with_speedup = baseline_handler.get_top_candidates(speedup_result)
all_candidates_with_speedup = baseline_handler.sort_candidates_with_speedup(
speedup_result
)
top_candidates_with_speedup = all_candidates_with_speedup[:num_candidates]

if baseline_handler.is_valid():
Expand Down
7 changes: 3 additions & 4 deletions tuner/tuner/libtuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def test_baseline_result_handler_speedup():
4: 0.2 / 0.875,
}

all_candidates_with_speedup = handler.get_top_candidates(speedup)
all_candidates_with_speedup = handler.sort_candidates_with_speedup(speedup)
assert all_candidates_with_speedup == [
(4, 0.2 / 0.875),
(1, 0.4 / 0.9),
Expand All @@ -290,8 +290,7 @@ def test_baseline_result_handler_speedup():
7: 0.8 / 1.2,
}

all_candidates_with_speedup = handler.get_top_candidates(speedup)
print(all_candidates_with_speedup)
all_candidates_with_speedup = handler.sort_candidates_with_speedup(speedup)
assert all_candidates_with_speedup == [
(5, 0.6 / 0.9),
(7, 0.8 / 1.2),
Expand All @@ -307,7 +306,7 @@ def test_baseline_result_handler_speedup():
6: 0.4,
7: 0.8,
}
all_candidates_with_speedup = handler.get_top_candidates(speedup)
all_candidates_with_speedup = handler.sort_candidates_with_speedup(speedup)
assert all_candidates_with_speedup == [
(6, 0.4),
(5, 0.6),
Expand Down

0 comments on commit 180c435

Please sign in to comment.