Skip to content

Commit

Permalink
address reviewer comments and format code
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Jan 17, 2025
1 parent 51c0054 commit 44b7e92
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 39 deletions.
61 changes: 30 additions & 31 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ class BenchmarkResult:
time: float
device_id: str

def is_valid(self) -> bool:
    """Report whether this result carries a usable timing.

    A time of ``inf``, ``-inf`` or ``NaN`` is treated as invalid.
    """
    measured_time = self.time
    return math.isfinite(measured_time)


def unit_to_microseconds(real_time: float, time_unit: str) -> float:
unit_conversions = {
Expand Down Expand Up @@ -750,7 +753,9 @@ def collision_handler(index_hash_list: list[tuple[int, str]]) -> tuple[bool, lis
return collision_detected, unique_indexes


def benchmark_candidates(candidate_indices, devices, tuning_client, candidate_trackers):
def benchmark_candidates(
candidate_indices, devices, tuning_client, candidate_trackers
) -> list[BenchmarkResult]:
"""
Runs the benchmarking for a given list of candidate indices.
"""
Expand Down Expand Up @@ -807,22 +812,24 @@ def __init__(self) -> None:
)

def add_run(self, results: list[BenchmarkResult]) -> None:
    """Record one baseline run, grouping its results by device ID.

    A warning is logged when the run contains more than one result for the
    same device ID; the results are recorded regardless.
    """
    if not BaselineResultHandler.are_baseline_devices_unique(results):
        # add_run() is invoked for every baseline run, so the message must
        # not claim that the results belong to the "first" run.
        logging.warning("Duplicate device IDs detected in the baseline results.")
    for result in results:
        self.device_baseline_results[result.device_id].append(result)

def are_baseline_devices_unique(self, results: list[BenchmarkResult]) -> bool:
@staticmethod
def are_baseline_devices_unique(results: list[BenchmarkResult]) -> bool:
return len(results) == len(set(result.device_id for result in results))

def get_valid_time_ms(self, device_id: str) -> list[float]:
    """Return the times of the valid baseline results stored for ``device_id``.

    Results whose ``is_valid()`` is false are skipped. A device with no
    stored results yields an empty list.
    """
    return [
        result.time
        for result in self.device_baseline_results.get(device_id, [])
        if result.is_valid()
    ]

def num_successful_runs(self, device_id: str) -> int:
    """Count the valid baseline times recorded for ``device_id``."""
    valid_times = self.get_valid_time_ms(device_id)
    return len(valid_times)

def get_average_result_ms(self, device_id: str) -> Optional[float]:
valid_times = self.get_valid_time_ms(device_id)
if valid_times:
Expand All @@ -834,38 +841,32 @@ def detect_regressions(
baseline_results: list[BenchmarkResult],
threshold: float = 1.03,
) -> list[str]:
"""
Return a list of device IDs where regressions were detected.
"""
regressions = []
for result in baseline_results:
if not math.isfinite(result.time):
if not result.is_valid():
continue

baseline_avg = self.get_average_result_ms(result.device_id)
if baseline_avg is not None and result.time > baseline_avg * threshold:
regressions.append(result.device_id)
logging.warning(
f"Performance regression detected on device {result.device_id}: "
f"Stored average baseline time = {baseline_avg:.2f} ms, "
f"New baseline time = {result.time:.2f} ms, "
f"Slower by {((result.time - baseline_avg) / baseline_avg) * 100:.2f}%"
)

return regressions

def is_valid(self) -> bool:
    """
    Return True iff at least one valid baseline time has been recorded
    for any device.
    """
    for device_id in self.device_baseline_results:
        # A non-empty list of valid times for any device is sufficient.
        if self.get_valid_time_ms(device_id):
            return True
    return False

def is_valid_for_device(self, device_id: str) -> bool:
    """Return True iff ``device_id`` has at least one valid baseline time."""
    return bool(self.get_valid_time_ms(device_id))

def calculate_speedup(
self, candidate_results: list[BenchmarkResult]
Expand All @@ -876,18 +877,18 @@ def calculate_speedup(
"""
if not self.is_valid():
logging.warning("No valid baseline times available.")
# Use the candidate time directly when no baselines are available
# Use the candidate time directly when no baselines are available.
return {
candidate.candidate_id: candidate.time
for candidate in candidate_results
}

# Calculate the fallback baseline as the average of all valid times across devices
# Calculate the fallback baseline as the average of all valid times across devices.
valid_baseline_times = [
result.time
for device_id in self.device_baseline_results
for result in self.device_baseline_results[device_id]
if math.isfinite(result.time)
if result.is_valid()
]

fallback_baseline = sum(valid_baseline_times) / len(valid_baseline_times)
Expand Down Expand Up @@ -1017,10 +1018,7 @@ def benchmark(
baseline_handler = BaselineResultHandler()
baseline_handler.add_run(first_baseline_result)
if not baseline_handler.is_valid():
logging.warning("Baseline result is not valid after first run")

if not baseline_handler.are_baseline_devices_unique(first_baseline_result):
logging.warning("Duplicate device IDs detected in the first baseline results.")
logging.warning("Baseline run failed.")

candidate_indices = [i for i in compiled_candidates if i != 0]
candidate_results = benchmark_candidates(
Expand All @@ -1039,26 +1037,27 @@ def benchmark(
regression_devices = baseline_handler.detect_regressions(second_baseline_result)
if regression_devices:
logging.warning(
f"Performance regressions detected for the following devices: {', '.join(regression_devices)}"
f"Performance regressions detected for the following devices: {', '.join(regression_devices)}."
)
baseline_handler.add_run(second_baseline_result)

if not baseline_handler.is_valid():
logging.warning("Baseline result is not valid after second run")

if not baseline_handler.are_baseline_devices_unique(second_baseline_result):
logging.warning("Duplicate device IDs detected in the second baseline results.")
logging.warning("Baseline run failed.")

speedup_result = baseline_handler.calculate_speedup(candidate_results)
# If the baseline is valid (`baseline_handler.is_valid()`), `speedup_result` represents the speedup values.
# Otherwise, `speedup_result` contains the raw time values.
top_candidates = baseline_handler.get_top_candidates(speedup_result, num_candidates)
if baseline_handler.is_valid():
candidate_time_map = {
result.candidate_id: result.time for result in candidate_results
}
for candidate_id in top_candidates:
speedup_value = speedup_result[candidate_id]
actual_time = candidate_time_map[candidate_id]
percentage_of_baseline = speedup_value * 100
logging.info(
f"Candidate {candidate_id} time: {speedup_value:.2f} ms "
f"Candidate {candidate_id} time: {actual_time:.2f} ms "
f"({percentage_of_baseline:.1f}% of baseline)"
)
else:
Expand Down
18 changes: 10 additions & 8 deletions tuner/tuner/libtuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ def test_baseline_result_handler_valid():
libtuner.BenchmarkResult(0, math.inf, "hip://1"),
libtuner.BenchmarkResult(0, 0.7, "hip://0"),
]
assert handler.are_baseline_devices_unique([])
assert not handler.are_baseline_devices_unique(baseline)
assert libtuner.BaselineResultHandler.are_baseline_devices_unique([])
assert not libtuner.BaselineResultHandler.are_baseline_devices_unique(baseline)
handler.add_run(baseline)
assert handler.is_valid()
assert handler.is_valid_for_device("hip://0")
Expand All @@ -214,20 +214,22 @@ def test_baseline_result_handler_valid():
libtuner.BenchmarkResult(0, math.inf, "hip://1"),
]

assert handler.num_successful_runs("hip://0") == 2
assert handler.num_successful_runs("hip://1") == 0
assert handler.num_successful_runs("hip://2") == 0
assert handler.get_valid_time_ms("hip://0") == [0.5, 0.7]
assert handler.get_valid_time_ms("hip://1") == []
assert handler.get_valid_time_ms("hip://2") == []

additional_baseline = [
libtuner.BenchmarkResult(0, math.inf, "hip://1"),
libtuner.BenchmarkResult(0, math.nan, "hip://1"),
libtuner.BenchmarkResult(0, 1.2, "hip://1"),
libtuner.BenchmarkResult(0, 0.8, "hip://1"),
]
assert not handler.are_baseline_devices_unique(additional_baseline)
assert not libtuner.BaselineResultHandler.are_baseline_devices_unique(
additional_baseline
)
handler.add_run(additional_baseline)
assert handler.num_successful_runs("hip://0") == 2
assert handler.num_successful_runs("hip://0") == 2
assert handler.get_valid_time_ms("hip://0") == [0.5, 0.7]
assert handler.get_valid_time_ms("hip://1") == [1.2, 0.8]
assert handler.is_valid_for_device("hip://1")

assert handler.get_average_result_ms("hip://0") == 0.6
Expand Down

0 comments on commit 44b7e92

Please sign in to comment.