diff --git a/python/perf-kernels/tools/tune_gemm/benchmarking_results_ref@709e3fdbf.csv b/python/perf-kernels/tools/tune_gemm/benchmarking_results_ref@709e3fdbf.csv
new file mode 100644
index 000000000000..dec37e4de94f
--- /dev/null
+++ b/python/perf-kernels/tools/tune_gemm/benchmarking_results_ref@709e3fdbf.csv
@@ -0,0 +1,14 @@
+trans,M,N,K,TFLOPS,us
+TN,4864,4096,4096,467.39,349.19
+TN,4864,4096,4160,567.17,292.26
+TN,4864,4096,4224,557.49,301.90
+TN,4864,4096,4288,569.55,299.99
+TN,4864,4096,4097,501.58,325.47
+TN,4864,4096,4098,491.96,331.92
+TN,4864,4096,4100,503.51,324.46
+TN,4864,4096,4104,515.70,317.10
+TN,4864,4096,4112,525.66,311.70
+TN,4864,8192,4096,519.95,627.79
+TN,4864,8192,4160,579.14,572.43
+TN,4864,8192,8192,543.30,1201.6
+TN,4864,8192,8256,563.43,1167.7
diff --git a/python/perf-kernels/tools/tune_gemm/database.yaml b/python/perf-kernels/tools/tune_gemm/database.yaml
new file mode 100644
index 000000000000..33c892f3cdde
--- /dev/null
+++ b/python/perf-kernels/tools/tune_gemm/database.yaml
@@ -0,0 +1,18 @@
+# M // BLOCK_M * N // BLOCK_N % 304 == 0
+## 1 workgroup / CU
+- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
+- {'M': 4864, 'N': 4096, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
+- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
+- {'M': 4864, 'N': 4096, 'K': 4288, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
+## 1 workgroup / CU masked loadK
+- {'M': 4864, 'N': 4096, 'K': 4097, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
+- {'M': 4864, 'N': 4096, 'K': 4098, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
+- {'M': 4864, 'N': 4096, 'K': 4100, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
+- {'M': 4864, 'N': 4096, 'K': 4104, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
+- {'M': 4864, 'N': 4096, 'K': 4112, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
+
+## 2 workgroups / CU
+- {'M': 4864, 'N': 8192, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
+- {'M': 4864, 'N': 8192, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
+- {'M': 4864, 'N': 8192, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
+- {'M': 4864, 'N': 8192, 'K': 8256, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
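The `database.yaml` entries above are plain YAML flow mappings keyed by GEMM shape and layout, so they can be read back with PyYAML. A minimal sketch of a consumer (the `lookup_config` helper is hypothetical, not part of this PR):

```python
import yaml


def lookup_config(path, M, N, K, rowMajorA='T', rowMajorB='N'):
    """Return the first tuned entry matching shape and layout, or None."""
    with open(path) as f:
        entries = yaml.safe_load(f)  # list of dicts, one per tuned shape
    for entry in entries:
        if ((entry['M'], entry['N'], entry['K']) == (M, N, K)
                and entry['rowMajorA'] == rowMajorA
                and entry['rowMajorB'] == rowMajorB):
            return entry
    return None


cfg = lookup_config('database.yaml', 4864, 4096, 4097)
print(cfg['BLOCK_SIZE_K'])  # 32 for the masked-loadK shapes above
```

Note that the masked-loadK shapes (K = 4097 through 4112) all settle on `BLOCK_SIZE_K: 32`, while every evenly divisible shape keeps 64.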
diff --git a/python/perf-kernels/tools/tune_gemm/matmul_kernel.py b/python/perf-kernels/tools/tune_gemm/matmul_kernel.py
index 1d9902bc2de6..6491340d10d9 100644
--- a/python/perf-kernels/tools/tune_gemm/matmul_kernel.py
+++ b/python/perf-kernels/tools/tune_gemm/matmul_kernel.py
@@ -46,16 +46,26 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
         bias = tl.load(bias_ptrs, mask=offs_am < M, other=0.0)
     acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
-    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
-        if EVEN_K:
-            a = tl.load(a_ptrs)
-            b = tl.load(b_ptrs)
-        else:
-            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
-            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+    max_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) - 1
+    if EVEN_K:
+        max_k += 1
+    for k in range(0, max_k):
+        a = tl.load(tl.multiple_of(a_ptrs, (1, 16)))
+        b = tl.load(tl.multiple_of(b_ptrs, (16, 1)))
         accumulator += tl.dot(a, b)
         a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
+
+    if not EVEN_K:
+        k = max_k
+        offs_k = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+        a_ptrsX = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
+        b_ptrsX = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
+        a = tl.load(a_ptrsX, mask=offs_k[None, :] < K, other=0.0)
+        b = tl.load(b_ptrsX, mask=offs_k[:, None] < K, other=0.0)
+        accumulator += tl.dot(a, b)
+
     c = accumulator.to(c_ptr.type.element_ty)
     if BIAS:
         c += bias[:, None]
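The kernel change above peels the last K iteration out of the main loop: when `EVEN_K` is false, the hot loop runs only over guaranteed-full K blocks with unmasked, `tl.multiple_of`-hinted loads, and a single masked epilogue iteration handles the ragged tail. A pure-NumPy analogue of the same peeling, for illustration only (assumes `SPLIT_K == 1`; masking is modeled as zero padding):

```python
import numpy as np


def gemm_peeled(A, B, BLOCK_K=64):
    """Illustrative NumPy analogue of the peeled K loop (SPLIT_K == 1)."""
    M, K = A.shape
    _, N = B.shape
    EVEN_K = (K % BLOCK_K == 0)
    max_k = -(-K // BLOCK_K) - 1  # tl.cdiv(K, BLOCK_K) - 1
    if EVEN_K:
        max_k += 1
    acc = np.zeros((M, N), dtype=np.float32)
    # Main loop: every K block is full, so no bounds mask is needed.
    for k in range(max_k):
        ks = slice(k * BLOCK_K, (k + 1) * BLOCK_K)
        acc += A[:, ks] @ B[ks, :]
    # Peeled epilogue: only the ragged tail pays for masking, modeled
    # here by zero-padding up to a full BLOCK_K block.
    if not EVEN_K:
        lo = max_k * BLOCK_K
        a = np.zeros((M, BLOCK_K), dtype=A.dtype)
        b = np.zeros((BLOCK_K, N), dtype=B.dtype)
        a[:, :K - lo] = A[:, lo:]
        b[:K - lo, :] = B[lo:, :]
        acc += a @ b
    return acc


rng = np.random.default_rng(0)
A = rng.standard_normal((8, 67), dtype=np.float32)
B = rng.standard_normal((67, 5), dtype=np.float32)
assert np.allclose(gemm_peeled(A, B, BLOCK_K=16), A @ B, atol=1e-4)
```

Only the peeled iteration pays the cost of bounds masking, which is why the main loop can keep the faster unmasked loads for every full block.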
diff --git a/python/perf-kernels/tools/tune_gemm/tune_gemm.py b/python/perf-kernels/tools/tune_gemm/tune_gemm.py
index 291096b3d7af..b3d0d0bccc12 100755
--- a/python/perf-kernels/tools/tune_gemm/tune_gemm.py
+++ b/python/perf-kernels/tools/tune_gemm/tune_gemm.py
@@ -108,7 +108,8 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
         num_warps = config.get("num_warps")
         num_stages = config.get("num_stages")
         matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
+        EVEN_K = (K % BLOCK_SIZE_K == 0)
         if matrix_instr_nonkdim > mfma:
             continue
         if mfma == 4 and BLOCK_SIZE_K < 64:
             continue
@@ -149,10 +150,11 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
             continue
         # Skip small block sizes and num_warps for large gemm
         # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
+        # We only want to use a small BLOCK_SIZE_K if not EVEN_K
         if large_gemm:
             if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
                 continue
-            if BLOCK_SIZE_K < 64:
+            if BLOCK_SIZE_K < 64 and EVEN_K:
                 continue
             if num_warps < 4:
                 continue
@@ -657,14 +659,14 @@ def main():
 
         # write best config to tuning_results.yaml
         if run_bench:
-            print(f"{formatted_tflops} {minTime}")
+            print(f"{formatted_tflops} {minTime} {bestConfig_compact_str}")
             f_results.write(f"{formatted_tflops},{minTime}\n")
 
         sizeDict = {'M': M, 'N': N, 'K': K, 'rowMajorA': row_a_str, 'rowMajorB': row_b_str}
         sizeDict.update(bestConfig)
         if not run_bench:
             f_results.write("- " + str(sizeDict) + " ")
-            f_results.write(f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n')
+            f_results.write(f'# {bestConfig_compact_str}\n')
 
     # remove generated files if asked to
     if not keepTmp:
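The new pruning rule is the flip side of the peeled loop: for large GEMMs, a small `BLOCK_SIZE_K` is only kept when K is not evenly divisible, where a smaller K block shrinks the masked epilogue's wasted work. A hedged distillation of the predicate (the `keep_block_k` wrapper is ours, not a tune_gemm.py API):

```python
def keep_block_k(BLOCK_SIZE_K, K, large_gemm=True):
    """Distilled version of the new prune rule: for large GEMMs, drop
    BLOCK_SIZE_K < 64 unless K % BLOCK_SIZE_K != 0 (i.e. not EVEN_K),
    where the small K block keeps the peeled, masked epilogue cheap."""
    EVEN_K = (K % BLOCK_SIZE_K == 0)
    if large_gemm and BLOCK_SIZE_K < 64 and EVEN_K:
        return False
    return True


assert not keep_block_k(32, 4096)  # evenly divisible K: small BLOCK_SIZE_K pruned
assert keep_block_k(32, 4097)      # ragged K: small BLOCK_SIZE_K kept
assert keep_block_k(64, 8192)      # BLOCK_SIZE_K >= 64 passes this check either way
```

This matches the tuned database above, where the masked-loadK shapes pick `BLOCK_SIZE_K: 32` while every evenly divisible shape keeps 64.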