Skip to content

Commit

Permalink
rclean the code agin due to using handler class
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Jan 17, 2025
1 parent 6b4b292 commit 32aecfe
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 78 deletions.
48 changes: 5 additions & 43 deletions tuner/tuner/libtuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,47 +222,6 @@ def validate_devices(user_devices: list[str]) -> None:
)


def get_valid_benchmark_results(
benchmark_results: list[BenchmarkResult],
) -> list[BenchmarkResult]:
"""
Filter the benchmark_results list to return list with finite `time` values.
"""
filtered_benchmark_results = [r for r in benchmark_results if math.isfinite(r.time)]
if len(filtered_benchmark_results) == 0:
logging.error("No successful candidate benchmarks.")

return filtered_benchmark_results


def are_baseline_devices_unique(
baseline_results: list[BenchmarkResult],
) -> bool:
return len(baseline_results) == len(
set(map(lambda r: r.device_id, baseline_results))
)


def map_baseline_by_device(
baseline_result: list[BenchmarkResult],
) -> dict[str, float]:
if not are_baseline_devices_unique(baseline_result):
logging.warning("Duplicate device IDs detected in the baseline results.")
baseline_device_times = defaultdict(list)

for r in baseline_result:
if math.isfinite(r.time):
baseline_device_times[r.device_id].append(r.time)

average_device_times = {
device_id: sum(times) / len(times)
for device_id, times in baseline_device_times.items()
if times
}

return average_device_times


class ExecutionPhases(str, Enum):
dont_stop = ""
generate_candidates = "generate-candidates"
Expand Down Expand Up @@ -849,6 +808,9 @@ def add_run(self, results: list[BenchmarkResult]) -> None:
for result in results:
self.device_baseline_times[result.device_id].append(result.time)

def are_baseline_devices_unique(self, results: list[BenchmarkResult]) -> bool:
return len(results) == len(set(map(lambda r: r.device_id, results)))

def get_valid_time_ms(self, device_id: str) -> list[float]:
return [
time
Expand Down Expand Up @@ -1054,7 +1016,7 @@ def benchmark(
if not baseline_handler.is_valid():
logging.warning("Baseline result is not valid after first run")

if not are_baseline_devices_unique(first_baseline_result):
if not baseline_handler.are_baseline_devices_unique(first_baseline_result):
logging.warning("Duplicate device IDs detected in the first baseline results.")

candidate_indices = [i for i in compiled_candidates if i != 0]
Expand All @@ -1081,7 +1043,7 @@ def benchmark(
if not baseline_handler.is_valid():
logging.warning("Baseline result is not valid after second run")

if not are_baseline_devices_unique(second_baseline_result):
if not baseline_handler.are_baseline_devices_unique(second_baseline_result):
logging.warning("Duplicate device IDs detected in the second baseline results.")

speedup_result = baseline_handler.calculate_speedup(candidate_results)
Expand Down
38 changes: 3 additions & 35 deletions tuner/tuner/libtuner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,41 +192,6 @@ def test_enum_collision():
from iree.compiler.dialects import linalg, vector, iree_gpu, iree_codegen, iree_input # type: ignore


def test_validate_benchmark_results():
benchmark_results = [
libtuner.BenchmarkResult(0, math.inf, "hip://0"),
]

result = libtuner.get_valid_benchmark_results(benchmark_results)
assert result == []

benchmark_results = [
libtuner.BenchmarkResult(0, math.inf, "hip://0"),
libtuner.BenchmarkResult(0, 0.1, "hip://1"),
]
result = libtuner.get_valid_benchmark_results(benchmark_results)
assert len(result) == 1
assert result[0].candidate_id == 0
assert result[0].time == 0.1
assert result[0].device_id == "hip://1"


def test_check_baseline_devices_uniqueness():
baseline_results = [
libtuner.BenchmarkResult(0, 1000.0, "hip://0"),
libtuner.BenchmarkResult(0, 2000.0, "hip://1"),
libtuner.BenchmarkResult(0, 3000.0, "hip://2"),
]
assert libtuner.are_baseline_devices_unique(baseline_results)

baseline_results = [
libtuner.BenchmarkResult(0, 1000.0, "hip://0"),
libtuner.BenchmarkResult(0, 2000.0, "hip://0"),
libtuner.BenchmarkResult(0, 3000.0, "hip://2"),
]
assert not libtuner.are_baseline_devices_unique(baseline_results)


def test_baseline_result_handler_valid():
handler = libtuner.BaselineResultHandler()
assert not handler.is_valid()
Expand All @@ -235,6 +200,8 @@ def test_baseline_result_handler_valid():
libtuner.BenchmarkResult(0, math.inf, "hip://1"),
libtuner.BenchmarkResult(0, 0.7, "hip://0"),
]
assert handler.are_baseline_devices_unique([])
assert not handler.are_baseline_devices_unique(baseline)
handler.add_run(baseline)
assert handler.is_valid()
assert handler.is_valid_for_device("hip://0")
Expand All @@ -253,6 +220,7 @@ def test_baseline_result_handler_valid():
libtuner.BenchmarkResult(0, 1.2, "hip://1"),
libtuner.BenchmarkResult(0, 0.8, "hip://1"),
]
assert not handler.are_baseline_devices_unique(additional_baseline)
handler.add_run(additional_baseline)
assert handler.num_successful_runs("hip://0") == 2
assert handler.num_successful_runs("hip://0") == 2
Expand Down

0 comments on commit 32aecfe

Please sign in to comment.