diff --git a/benchmarks/benchmark_gemm.py b/benchmarks/benchmark_gemm.py index 6a7dc7bd7..df0d56b8f 100644 --- a/benchmarks/benchmark_gemm.py +++ b/benchmarks/benchmark_gemm.py @@ -4,6 +4,10 @@ from triton.testing import do_bench +if torch.version.cuda: + backendBLAS = "cuBLAS" +elif torch.version.hip: + backendBLAS = "hipBLAS" def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs): """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" @@ -34,10 +38,10 @@ def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2) nFLOPS_matmul = 2 * m * n * k time.sleep(2) # to reduce power throttling - timing = benchmark_forward(torch.matmul, a, b, desc='cuBLAS', verbose=verbose, repeats=repeats)[1] + timing = benchmark_forward(torch.matmul, a, b, desc=backendBLAS, verbose=verbose, repeats=repeats)[1] tflops_matmul[k] = nFLOPS_matmul / timing.mean * 1e-12 - print(f'[torch.utils.benchmark] cuBLAS, {m = }, {n = }, {k = }: {timing.mean * 1e3:.3f}ms, {tflops_matmul[k]:.1f} TFLOPS') + print(f'[torch.utils.benchmark] {backendBLAS}, {m = }, {n = }, {k = }: {timing.mean * 1e3:.3f}ms, {tflops_matmul[k]:.1f} TFLOPS') time.sleep(2) # to reduce power throttling ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=repeats) tflops_matmul1[k] = nFLOPS_matmul / ms * 1e-9 - print(f'[triton.test.do_bench] cuBLAS, {m = }, {n = }, {k = }: {ms:.3f}ms, {tflops_matmul1[k]:.1f} TFLOPS') + print(f'[triton.test.do_bench] {backendBLAS}, {m = }, {n = }, {k = }: {ms:.3f}ms, {tflops_matmul1[k]:.1f} TFLOPS')