Skip to content

Commit 8a8c1fc

Browse files
bertmaherfacebook-github-bot
authored andcommitted
Report tflops by default for gemm; fix exception handling (#2259)
Summary: TFLOPS is the core metric for gemm. Along the way I hit some bugs and weirdness: - You couldn't Ctrl-C out of tritonbench, because the `finally` clause contained a return, which [suppresses the exception](https://docs.python.org/3/tutorial/errors.html#defining-clean-up-actions) - In generally I don't think the framework should catch RuntimeErrors, it makes it really hard to debug stuff because the desired result just ends up missing - In fact we had a typo (`metric` instead of `metrics` in the framework code that was never caught because it was caught and suppressed Pull Request resolved: #2259 Test Plan: ``` python run_benchmark.py triton --op gemm --splitk ``` Reviewed By: xuzhao9 Differential Revision: D57171806 Pulled By: bertmaher fbshipit-source-id: 74568625ad10907d9def8916abfc2f6292cdc6d6
1 parent c1f2dc8 commit 8a8c1fc

File tree

2 files changed

+2
-10
lines changed

2 files changed

+2
-10
lines changed

torchbenchmark/operators/gemm/operator.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@
8989

9090

9191
class Operator(BenchmarkOperator):
92-
DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
92+
DEFAULT_METRICS = ["latency", "speedup", "accuracy", "tflops"]
9393
DEFAULT_PRECISION = "fp16"
9494

9595
def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
@@ -202,13 +202,7 @@ def get_input_iter(self) -> Generator:
202202
def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
203203
output = fn()
204204
baseline_output = baseline_fn()
205-
accuracy = True
206-
try:
207-
torch.testing.assert_close(output, baseline_output)
208-
except Exception:
209-
accuracy = False
210-
finally:
211-
return accuracy
205+
return torch.allclose(output, baseline_output)
212206

213207
def plot(self):
214208
@triton.testing.perf_report(

torchbenchmark/util/triton_op.py

-2
Original file line numberDiff line numberDiff line change
@@ -752,8 +752,6 @@ def _init_extra_metrics() -> Dict[str, Any]:
752752
metrics.extra_metrics[metric_name] = func(fn, self.example_inputs, metrics)
753753
except torch.cuda.OutOfMemoryError:
754754
metrics.error_msg = "CUDA OOM"
755-
except RuntimeError as e:
756-
metrics.error_msg = str(e)
757755
return metrics
758756

759757
def get_peak_mem(

0 commit comments

Comments
 (0)