diff --git a/python/perf-kernels/03-matrix-multiplication-stream-k.py b/python/perf-kernels/03-matrix-multiplication-stream-k.py
index 77d459f624a9..7493e682b4a1 100755
--- a/python/perf-kernels/03-matrix-multiplication-stream-k.py
+++ b/python/perf-kernels/03-matrix-multiplication-stream-k.py
@@ -153,7 +153,7 @@ def set_debug(debug: bool):
         matmul._debug = debug
 
     @staticmethod
-    def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int):
+    def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int):
         device = a.device
 
         assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported"
@@ -167,7 +167,7 @@ def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M:
         total_blocks_M = triton.cdiv(M, BLK_M)
         total_blocks_N = triton.cdiv(N, BLK_N)
         iters_per_tile = triton.cdiv(K, BLK_K)
-        GROUP_M = 0 # 0 to disable swizzling
+        GROUP_M = 4 # 0 to disable swizzling
         total_tiles = total_blocks_M * total_blocks_N
 
         if total_programs_streamk > 0: # Stream-K
@@ -232,6 +232,8 @@ def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M:
             num_stages=num_stages,
             num_warps=num_warps,
             waves_per_eu = waves_per_eu,
+            matrix_instr_nonkdim = mfmaInstrSize,
+            kpack = kpack,
         )
         if matmul._debug:
             print(f"{kk.n_regs} registers used, {kk.n_spills} spills")
@@ -241,8 +243,8 @@ def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M:
         return c
 
     @staticmethod
-    def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, num_stages=3, num_warps=4, waves_per_eu = 2):
-        return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu = waves_per_eu)
+    def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, num_stages=3, num_warps=4, waves_per_eu = 2, mfmaInstrSize = 16, kpack = 1):
+        return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu = waves_per_eu, mfmaInstrSize = mfmaInstrSize, kpack = kpack)
 
 # ---------------------------------------------------------------------------
 # Example and Benchmark
@@ -250,20 +252,24 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=1
 
 perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3)
 
-m, n, k = 4864, 4096, 8256 # some problem size to test
+#m, n, k = 4864, 4096, 8256 # some problem size to test
+#m, n, k = 4096, 4096, 8192 # some problem size to test
 #m, n, k = 8192, 8192, 8192 # some problem size to test
+m, n, k = 6912, 768, 256 # some problem size to test
 A = torch.randn(m, k, device="cuda", dtype=torch.float16)
 B = torch.randn(k, n, device="cuda", dtype=torch.float16)
-BLK_M = 256
-BLK_N = 128
-BLK_K = 32
+BLK_M = 64
+BLK_N = 64
+BLK_K = 64
 two_tiles = 'True'
 num_stages = 0
-num_warps = 8
+num_warps = 4
 waves_per_eu = 0
+mfmaInstrSize = 16
+kpack = 2
 matmul.set_debug(True)
-C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu)
+C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)
 #exit(0)
 matmul.set_debug(False)
 expected = A @ B
@@ -277,13 +283,13 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=1
 triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B))
 print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
 
-triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu))
+triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
 print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
 
-triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu))
+triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
 print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
 
-triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu))
+triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
 print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
 
 exit(0)
diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py
index 9ab14247b9d9..0491e71bcde8 100644
--- a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py
+++ b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py
@@ -122,36 +122,40 @@ def get_full_tuning_space():
     if not tuning_full_space:
         return configs
 
-    block_mn_range = [128, 256]
+    block_mn_range = [64, 128, 256]
     block_k_range = [16, 32, 64]
     num_warps_range = [1, 2, 4, 8]
     # group_m_range = [0, 1, 2, 4, 8]
-    group_m_range = [4, 8]
+    group_m_range = [0, 4, 8]
     # For now we see better perf with num_stages=0 for all gemm configs we care
     # But keep this explicit so that we do not forget we may need to set it to
     # other values in the future
     num_stage_range = [0]
     waves_per_eu_range =[0]
+    matrix_instr_nonkdim_range = [16, 32]
+    kpack_range = [1, 2]
 
     for block_m in block_mn_range:
         for block_n in block_mn_range:
             for block_k in block_k_range:
                 for num_warps in num_warps_range:
                     for group_m in group_m_range:
-                        for num_waves_per_eu in waves_per_eu_range:
-                            for num_stages in num_stage_range:
-                                configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'GROUP_M': group_m, 'waves_per_eu': num_waves_per_eu}, num_stages=num_stages, num_warps=num_warps))
+                        for num_stages in num_stage_range:
+                            for num_waves_per_eu in waves_per_eu_range:
+                                for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
+                                    for kpack in kpack_range:
+                                        configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'GROUP_M': group_m, 'waves_per_eu': num_waves_per_eu, 'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack}, num_stages=num_stages, num_warps=num_warps,))
 
     return configs
-
+# TODO: we need to update the default autotune configurations once we have gone through the whole set of performance tests.
 @triton.autotune(
     configs= get_full_tuning_space() if tuning_full_space else [
-        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 0}, num_warps=4, num_stages=0),
-        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 2}, num_warps=4, num_stages=0),
-        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 0}, num_warps=4, num_stages=0),
-        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 2}, num_warps=4, num_stages=0),
-        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 16, 'waves_per_eu': 0}, num_warps=4, num_stages=0),
-        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 16, 'GROUP_M': 0, 'waves_per_eu': 0}, num_warps=4, num_stages=4),
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 1}, num_warps=4, num_stages=0),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 2, 'matrix_instr_nonkdim': 16, 'kpack': 1}, num_warps=4, num_stages=0),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 1}, num_warps=4, num_stages=0),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 2, 'matrix_instr_nonkdim': 16, 'kpack': 1}, num_warps=4, num_stages=0),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 16, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 1}, num_warps=4, num_stages=0),
+        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 16, 'GROUP_M': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 1}, num_warps=4, num_stages=4),
     ],
     key=['M', 'N', 'K'],
     # prune_configs_by={
@@ -264,7 +268,7 @@ def set_debug(debug: bool):
         matmul._debug = debug
 
     @staticmethod
-    def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLOCK_M: int, BLOCK_N: int, BLOCK_K: int, two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int):
+    def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLOCK_M: int, BLOCK_N: int, BLOCK_K: int, two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int):
         def compute_total_blocking_tiles(M, N, BLOCK_M, BLOCK_N, two_tiles, total_programs_streamk):
             total_blocks_M = triton.cdiv(M, BLOCK_M)
             total_blocks_N = triton.cdiv(N, BLOCK_N)
@@ -370,8 +374,8 @@ def compute_total_blocking_tiles(M, N, BLOCK_M, BLOCK_N, two_tiles, total_progra
         return c
 
     @staticmethod
-    def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, two_tiles=True, num_stages=3, num_warps=4, waves_per_eu = 2):
-        return matmul._call(a=a, b=b, total_programs_streamk=grid, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu = waves_per_eu)
+    def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, two_tiles=True, num_stages=3, num_warps=4, waves_per_eu = 2, mfmaInstrSize = 16, kpack = 1):
+        return matmul._call(a=a, b=b, total_programs_streamk=grid, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu = waves_per_eu, mfmaInstrSize = mfmaInstrSize, kpack = kpack)
 
 # ---------------------------------------------------------------------------
 # Example and Benchmark
@@ -380,21 +384,24 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK
 perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3)
 
 #m, n, k = 1792, 7424, 4864 # some problem size to test
-m, n, k = 8192, 8192, 8192 # some problem size to test
-A = torch.randn(m, k, device="cuda", dtype=torch.float32)
-B = torch.randn(k, n, device="cuda", dtype=torch.float32)
+#m, n, k = 8192, 8192, 8192 # some problem size to test
+m, n, k = 4096, 4096, 8192 # some problem size to test
+A = torch.randn(m, k, device="cuda", dtype=torch.float16)
+B = torch.randn(k, n, device="cuda", dtype=torch.float16)
 #A = torch.ones((m, k), device="cuda", dtype=torch.float16)
 #B = torch.ones((k, n), device="cuda", dtype=torch.float16)
 BLOCK_M = 256
-BLOCK_N = 128
-BLOCK_K = 16
+BLOCK_N = 256
+BLOCK_K = 64
 two_tiles = True
 num_stages = 0
-num_warps = 4
-waves_per_eu = 2
+num_warps = 8
+waves_per_eu = 0
+mfmaInstrSize = 16
+kpack = 1
 matmul.set_debug(True)
-C = matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu)
+C = matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)
 matmul.set_debug(False)
 expected = A @ B
@@ -407,15 +414,15 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK
 triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B))
 print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
 
-triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu))
+triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
 print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
 print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})')
 
-triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu))
+triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
 print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
 print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})')
 
-triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu))
+triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
 print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
 print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})')
 
@@ -425,6 +432,7 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK
 # ---------------------------------------------------------------------------
 # tried to reproduce the tests described in the paper
 
+perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3)
 num_samples = 1000 # 32768
 step = 256
 values = ((torch.logspace(torch.tensor(step).log2(), torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist()
diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py
index 3db9e368dd83..71e83c78469f 100644
--- a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py
+++ b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py
@@ -161,7 +161,7 @@ def set_debug(debug: bool):
         matmul._debug = debug
 
     @staticmethod
-    def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int):
+    def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int):
         device = a.device
 
         assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported"
@@ -175,7 +175,7 @@ def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M:
         total_blocks_M = triton.cdiv(M, BLK_M)
         total_blocks_N = triton.cdiv(N, BLK_N)
         iters_per_tile = triton.cdiv(K, BLK_K)
-        GROUP_M = 0 # 0 to disable swizzling
+        GROUP_M = 4 # 0 to disable swizzling
         total_tiles = total_blocks_M * total_blocks_N
 
         if total_programs_streamk > 0: # Stream-K
@@ -235,6 +235,8 @@ def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M:
             num_stages=num_stages,
             num_warps=num_warps,
             waves_per_eu = waves_per_eu,
+            matrix_instr_nonkdim = mfmaInstrSize,
+            kpack = kpack
         )
         if matmul._debug:
             print(f"{k1.n_regs} registers used, {k1.n_spills} spills")
@@ -260,6 +262,8 @@ def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M:
             num_stages=num_stages,
             num_warps=num_warps,
             waves_per_eu = waves_per_eu,
+            matrix_instr_nonkdim = mfmaInstrSize,
+            kpack = kpack
         )
         if matmul._debug:
             print(f"{k2.n_regs} registers used, {k2.n_spills} spills")
@@ -267,8 +271,8 @@ def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M:
         return c
 
    @staticmethod
-    def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, num_stages=3, num_warps=4, waves_per_eu = 2):
-        return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu = waves_per_eu)
+    def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, num_stages=3, num_warps=4, waves_per_eu = 2, mfmaInstrSize = 16, kpack = 1):
+        return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu = waves_per_eu, mfmaInstrSize = mfmaInstrSize, kpack = kpack)
 
 
 # ---------------------------------------------------------------------------
@@ -277,23 +281,25 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=1
 
 perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3)
 
-m, n, k = 4864, 4096, 8256 # some problem size to test
-#m, n, k = 6912, 768, 256 # some problem size to test
+#m, n, k = 4864, 4096, 8256 # some problem size to test
+m, n, k = 6912, 768, 256 # some problem size to test
 #m, n, k = 8192, 8192, 8192 # some problem size to test
 A = torch.randn(m, k, device="cuda", dtype=torch.float16)
 B = torch.randn(k, n, device="cuda", dtype=torch.float16)
 #A = torch.ones((m, k), device="cuda", dtype=torch.float16)
 #B = torch.ones((k, n), device="cuda", dtype=torch.float16)
-BLK_M = 128
-BLK_N = 256
-BLK_K = 32
+BLK_M = 64
+BLK_N = 64
+BLK_K = 64
 two_tiles='True'
 num_stages=0
 num_warps=4
 waves_per_eu = 0
+mfmaInstrSize = 16
+kpack = 2
 matmul.set_debug(True)
-C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu)
+C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)
 #exit(0)
 matmul.set_debug(False)
 expected = A @ B
@@ -305,13 +311,13 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=1
 triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B))
 print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
 
-triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu))
+triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
 print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
 
-triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu))
+triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
 print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
 
-triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu))
+triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
 print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
 
 exit(0)
@@ -355,7 +361,7 @@ def wrapper_matmul(*args, **kwargs):
     nb_sm.append(total_tile)
     nb_sm += random.sample(range(2, total_sm * 2, 2), 10)
     for sm in nb_sm:
-        triton_ms = triton.testing.do_bench(lambda: wrapper_matmul(A, B, sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu))
+        triton_ms = triton.testing.do_bench(lambda: wrapper_matmul(A, B, sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
         max_disc = (output - expected).abs().max().item()
         # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs.
         assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}"
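
Note (illustrative, not part of the patch): with the new matrix_instr_nonkdim and kpack ranges, get_full_tuning_space() in the autotune script now nests nine for loops to enumerate the Cartesian product of tuning ranges. The sketch below shows an equivalent, flatter construction using itertools.product, assuming the same ranges and the triton.Config call used in the diff; the helper name build_full_tuning_space is hypothetical and only meant to illustrate the idea.

import itertools

import triton


def build_full_tuning_space():
    # Same ranges as the patched get_full_tuning_space() above.
    block_mn_range = [64, 128, 256]
    block_k_range = [16, 32, 64]
    num_warps_range = [1, 2, 4, 8]
    group_m_range = [0, 4, 8]
    num_stage_range = [0]
    waves_per_eu_range = [0]
    matrix_instr_nonkdim_range = [16, 32]
    kpack_range = [1, 2]

    configs = []
    # itertools.product walks the same Cartesian product as the nested loops.
    for (block_m, block_n, block_k, num_warps, group_m, num_stages,
         waves_per_eu, matrix_instr_nonkdim, kpack) in itertools.product(
            block_mn_range, block_mn_range, block_k_range, num_warps_range,
            group_m_range, num_stage_range, waves_per_eu_range,
            matrix_instr_nonkdim_range, kpack_range):
        configs.append(
            triton.Config(
                {
                    'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
                    'GROUP_M': group_m, 'waves_per_eu': waves_per_eu,
                    'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack,
                },
                num_stages=num_stages, num_warps=num_warps))
    return configs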