Skip to content

Commit cae96b4

Browse files
committed
[tritonbench] Report tflops by default for gemm; fix exception handling
Summary: TFLOPS is the core metric for gemm. Along the way I hit some bugs and weirdness: - You couldn't Ctrl-C out of tritonbench, because the `finally` clause contained a return, which [suppresses the exception](https://docs.python.org/3/tutorial/errors.html#defining-clean-up-actions) - In general I don't think the framework should catch RuntimeErrors; it makes it really hard to debug stuff because the desired result just ends up missing - In fact we had a typo (`metric` instead of `metrics`) in the framework code that was never noticed, because the resulting exception was caught and suppressed Test Plan: ``` python run_benchmark.py triton --op gemm --splitk ```
1 parent 4c7ec3a commit cae96b4

File tree

2 files changed

+5
-14
lines changed

2 files changed

+5
-14
lines changed

torchbenchmark/operators/gemm/operator.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989

9090

9191
class Operator(BenchmarkOperator):
92-
DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
92+
DEFAULT_METRICS = ["latency", "speedup", "accuracy", "tflops"]
9393
DEFAULT_PRECISION = "fp16"
9494

9595
def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
@@ -202,13 +202,7 @@ def get_input_iter(self) -> Generator:
202202
def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
203203
output = fn()
204204
baseline_output = baseline_fn()
205-
accuracy = True
206-
try:
207-
torch.testing.assert_close(output, baseline_output)
208-
except Exception:
209-
accuracy = False
210-
finally:
211-
return accuracy
205+
return torch.allclose(output, baseline_output)
212206

213207
def plot(self):
214208
@triton.testing.perf_report(

torchbenchmark/util/triton_op.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -702,9 +702,9 @@ def _init_extra_metrics() -> Dict[str, Any]:
702702
if "hw_roofline" in self.required_metrics:
703703
metrics.hw_roofline = self.hw_roofline()
704704
if "tflops" in self.required_metrics:
705-
metrics.tflops = self.tflops(fn_name, self.example_inputs, metric)
705+
metrics.tflops = self.tflops(fn_name, self.example_inputs, metrics)
706706
if "compile_time" in self.required_metrics:
707-
metrics.compile_time = self.compile_time(input_id, fn_name, metric)
707+
metrics.compile_time = self.compile_time(input_id, fn_name, metrics)
708708
if "ncu_trace" in self.required_metrics:
709709
metrics.ncu_trace = self.ncu_trace(input_id, fn_name)
710710
if "kineto_trace" in self.required_metrics:
@@ -749,10 +749,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
749749
metrics.extra_metrics[metric_name] = func(fn, self.example_inputs, metrics)
750750
except torch.cuda.OutOfMemoryError:
751751
metrics.error_msg = "CUDA OOM"
752-
except RuntimeError as e:
753-
metrics.error_msg = str(e)
754-
finally:
755-
return metrics
752+
return metrics
756753

757754
def get_peak_mem(
758755
self, fn: Callable

0 commit comments

Comments
 (0)