diff --git a/python/perf-kernels/tune_gemm/README.md b/python/perf-kernels/tune_gemm/README.md
index da45dcda5c3c..c22382143544 100644
--- a/python/perf-kernels/tune_gemm/README.md
+++ b/python/perf-kernels/tune_gemm/README.md
@@ -1,8 +1,9 @@
-# GEMM tuning script (current v3.3)
+# GEMM tuning script (current v3.4)
 
 ## matmul kernel
 
 The matmul kernel implementation can be found as [matmul_kernel.py](https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/tune_gemm/matmul_kernel.py), which includes the following features:
+- XCD-based pid remapping, which is controlled by `NUM_XCDS`.
 - grouping order of workgroup id, which is controlled by `GROUP_SIZE_M`, that implements L2 cache optimization introduced in the [tutorial](https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#l2-cache-optimizations).
 - split-k algorithm, which is controlled by `SPLIT_K`.
 
@@ -144,7 +145,7 @@ The default value is 1000.
 
 The general idea of the tuning script can be summarized as
 - Compile all the kernels in the tuning space in parallel.
-- Divide the tuning space into tasks and invoke `rocprofv2` once per
+- Divide the tuning space into tasks and invoke `rocprof` once per
 task. This will save invocation overhead of the profiler.
 - Profile tasks in parallel on multiple GPUs.
 
@@ -309,6 +310,21 @@ places:
 - Statically set `device` and `stream` in the [jit.py](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/python/triton/runtime/jit.py#L588-L589)
 
 
+# GEMM Tuning Script v3.4
+
+## API changes
+
+No API changes.
+
+## Implementation changes
+
+- The matmul kernel now supports XCD-based pid remapping. Details and experiments
+will be added later.
+- Switched back to rocprofv1. See [ticket#228](https://github.com/ROCm/triton-internal/issues/228) for more details.
+- Improved the post-processing logic to filter out "spikes" in the profiling results.
+- Reduced the number of iterations in tuning mode (to 120) and in benchmark mode (to 200).
+
+
 # One config running script
 
 `one_config.py` is a script that runs one given matmul config.
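*Editor's note:* the XCD-based pid remapping mentioned in the v3.4 notes above can be illustrated outside of Triton. The sketch below (not part of the PR) mirrors the arithmetic from `matmul_kernel.py`, under the assumption that the hardware dispatches consecutive workgroup ids round-robin across XCDs and that the grid size is a multiple of `NUM_XCDS`; the values `GRID_MN = 16` and `NUM_XCDS = 4` are toy numbers for illustration only.

```python
# Plain-Python sketch of the XCD-based pid remapping used in matmul_kernel.py.
# Assumption: consecutive pids are dispatched round-robin across XCDs, so
# pid % NUM_XCDS identifies the XCD a workgroup runs on.
def remap_xcd(pid, GRID_MN, NUM_XCDS):
    # Number of pids each XCD owns in the new arrangement
    pids_per_xcd = GRID_MN // NUM_XCDS
    # Current XCD and the local pid within that XCD
    xcd = pid % NUM_XCDS
    local_pid = pid // NUM_XCDS
    # After remapping, each XCD covers a contiguous pid range, so the
    # GROUP_SIZE_M swizzle then groups tiles that share that XCD's L2
    return xcd * pids_per_xcd + local_pid


if __name__ == "__main__":
    GRID_MN, NUM_XCDS = 16, 4  # toy sizes, not tuned values
    print([remap_xcd(pid, GRID_MN, NUM_XCDS) for pid in range(GRID_MN)])
    # pids 0, 4, 8, 12 (all dispatched to XCD 0) become 0, 1, 2, 3
```

When `SPLIT_K > 1` the remapping is disabled (`NUM_XCDS` is set to 1), matching the `num_xcds = 1 if split_k > 1 else 8` logic in `tune_gemm.py` and `utils/file_generator.py` below.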
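*Editor's note:* the "spike" filtering mentioned in the v3.4 notes operates on the `DurationNs` column of the `rocprof --stats` output after the rows for one kernel config have been selected: any run slower than the first recorded run is dropped before averaging. A toy example with made-up durations and a hypothetical kernel name (the numbers are illustrative, not measured):

```python
import pandas as pd

# Fabricated example data; the real input comes from results_{jobId}.csv
df = pd.DataFrame({
    'KernelName': ['matmul_kernel_BM256_BN256_BK64'] * 6,
    'DurationNs': [9500, 9400, 25000, 9450, 9480, 9520],
})

# Same filtering logic as the new extract_kernel_time()
first_value = df['DurationNs'].iloc[0]
filtered_data = df['DurationNs'][df['DurationNs'] <= first_value]
mean_time = filtered_data.tail(100).mean()

print(mean_time)  # 9457.5 -- the 25000 ns spike (and the 9520 ns run) are dropped
```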
diff --git a/python/perf-kernels/tune_gemm/matmul_kernel.py b/python/perf-kernels/tune_gemm/matmul_kernel.py
index 336a643dca50..1d9902bc2de6 100644
--- a/python/perf-kernels/tune_gemm/matmul_kernel.py
+++ b/python/perf-kernels/tune_gemm/matmul_kernel.py
@@ -6,11 +6,22 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
                   stride_bk, stride_bn, stride_cm, stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr,
                   BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, SPLIT_K: tl.constexpr,
                   GROUP_SIZE_M: tl.constexpr, BIAS: tl.constexpr,
-                  EVEN_K: tl.constexpr):
+                  EVEN_K: tl.constexpr, GRID_MN: tl.constexpr, NUM_XCDS: tl.constexpr):
     pid = tl.program_id(axis=0)
     pid_z = tl.program_id(1)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+    if NUM_XCDS != 1:
+        ## pid remapping on xcds
+        # Number of pids per XCD in the new arrangement
+        pids_per_xcd = GRID_MN // NUM_XCDS
+        # Compute current XCD and local pid within the XCD
+        xcd = pid % NUM_XCDS
+        local_pid = pid // NUM_XCDS
+        # Calculate new pid based on the new grouping
+        pid = xcd * pids_per_xcd + local_pid
+
     if GROUP_SIZE_M == 1:
         pid_m = pid // num_pid_n
         pid_n = pid % num_pid_n
@@ -19,8 +30,9 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
         group_id = pid // num_pid_in_group
         first_pid_m = group_id * GROUP_SIZE_M
         group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-        pid_m = first_pid_m + (pid % group_size_m)
+        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
         pid_n = (pid % num_pid_in_group) // group_size_m
+
     if SPLIT_K == 1:
         offs_k = tl.arange(0, BLOCK_SIZE_K)
     else:
diff --git a/python/perf-kernels/tune_gemm/tune_gemm.py b/python/perf-kernels/tune_gemm/tune_gemm.py
index d49823306e3a..291096b3d7af 100755
--- a/python/perf-kernels/tune_gemm/tune_gemm.py
+++ b/python/perf-kernels/tune_gemm/tune_gemm.py
@@ -54,7 +54,7 @@ def get_full_tuning_space():
     block_k_range = [16, 32, 64, 128, 256]
     split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
     num_warps_range = [1, 2, 4, 8]
-    group_m_range = [1, 4, 8, 16, 32]
+    group_m_range = [1, 2, 4, 8, 16, 32]
     # For now we see better perf with num_stages=0 for all gemm configs we care
     # But keep this explicit so that we do not forget we may need to set it to
     # other values in the future
@@ -157,7 +157,7 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
         if num_warps < 4:
             continue
         # check if tiling is integer multiple of GEMM size because we have no boundary check
-        if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0 or K % BLOCK_SIZE_K != 0:
+        if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0:
             continue
         pruned_configs.append(config)
@@ -169,20 +169,15 @@ def need_split_k(SIZE_M, SIZE_N, SIZE_K):
     return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
 
 
-def extract_kernel_time(M, N, K, config, df, bias_size):
-    # Correct the header by removing 'sig' and 'obj' to reduce number from 21 to 19
-    # once the bug(https://github.com/ROCm/rocprofiler/issues/144) fixed, we should
-    # not need below two lines
-    cols = [
-        'Index', 'KernelName', 'gpu-id', 'queue-id', 'queue-index', 'pid', 'tid', 'grd', 'wgr', 'lds', 'scr',
-        'arch_vgpr', 'accum_vgpr', 'sgpr', 'wave_size', 'DispatchNs', 'BeginNs', 'EndNs', 'CompleteNs'
-    ]
-    df.columns = cols
+def extract_kernel_time(M, N, K, config, df):
     configStr = gen_configStr(config)
-    filtered_df = df[df['KernelName'].str.contains(configStr, na=False)].copy()
-    filtered_df['DurationNs'] = filtered_df['EndNs'] - filtered_df['BeginNs']
-    meanTime = filtered_df['DurationNs'].tail(100).mean()
-    return config, meanTime
+    df = df[df['KernelName'].str.contains(configStr)]
+
+    first_value = df['DurationNs'].iloc[0]
+    filtered_data = df['DurationNs'][df['DurationNs'] <= first_value]
+    new_meanTime = filtered_data.tail(100).mean()
+
+    return config, new_meanTime
 
 
 def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
@@ -197,7 +192,7 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
         if verbose:
            print(f"profiling {kernel_name} on GPU {gpuid}")
         run_bash_command_wrapper(
-            f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {get_filename_profile_driver(M, N, K, jobId)}",
+            f"rocprof --stats -o results_{jobId}.csv python {get_filename_profile_driver(M, N, K, jobId)}",
             capture=(verbose < 2))
         jobId += ngpus
@@ -244,13 +239,10 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
     thread_pool = multiprocessing.Pool(processes=num_threads)
     tasks = []
     idx = 0
-    df_prof = [
-        pd.read_csv(f"results_{i}.csv", skiprows=1, header=None, delimiter=',', quotechar='"', escapechar='\\')
-        for i in range(jobs)
-    ]
+    df_prof = [pd.read_csv(f"results_{i}.csv") for i in range(jobs)]
     for config in configs:
         file_idx = idx % jobs
-        tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx], bias_size))]
+        tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx]))]
         idx += 1
     thread_pool.close()
     thread_pool.join()
@@ -366,11 +358,12 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps
     grid = triton.cdiv(M, block_m) * triton.cdiv(N, block_n), split_k
     stride_bias = bias.stride(0) if use_bias else 0
     EVEN_K = K % block_k == 0
+    num_xcds = 1 if split_k > 1 else 8
     matmul_kernel[grid](a, b, c, bias, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0),
                         c.stride(1), stride_bias=stride_bias, BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n,
                         BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, SPLIT_K=split_k, num_warps=num_warps,
                         num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize,
-                        kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K)
+                        kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K, GRID_MN=grid[0], NUM_XCDS=num_xcds)
     return c
 
 
@@ -441,7 +434,7 @@ def parse_args():
     parser.add_argument("--num_threads", type=int, default=32,
                         help="number of threads to use for kernel compilation and post processing")
     parser.add_argument("--jobs", type=int, default=1, help="number of tasks during the profiling process")
-    parser.add_argument("--iters", type=int, default=1000, help="number of iterations used in --benchmark mode")
+    parser.add_argument("--iters", type=int, default=200, help="number of iterations used in --benchmark mode")
     parser.add_argument("--init_type", type=str, default='randn', choices=['randn', 'hpl', 'trig_float', 'zeros'],
                         help="Input tensor initialization (default normal distribution)")
     parser.add_argument(
diff --git a/python/perf-kernels/tune_gemm/utils/file_generator.py b/python/perf-kernels/tune_gemm/utils/file_generator.py
index 1011bc9df805..d92079dab9a0 100644
--- a/python/perf-kernels/tune_gemm/utils/file_generator.py
+++ b/python/perf-kernels/tune_gemm/utils/file_generator.py
@@ -76,6 +76,10 @@ def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype
 
     use_bias = bias_size > 0
 
+    ## Enable XCD-based pid remapping only when split-K is NOT used
+    ## Also, the number of XCDs is fixed to 8; if tuning for MI308, change it to 4
+    num_xcds = 1 if split_k > 1 else 8
+
     if warmup:
         torch_dtype_a = 'fp16'
         torch_dtype_b = 'fp16'
@@ -89,6 +93,7 @@ def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype
         matmul_def_str = f"""
 def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn):
+    grid_mn = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n})
     matmul_kernel_{configStr}.warmup(
         {torch_dtype_a}, {torch_dtype_b}, {torch_dtype_c}, {torch_dtype_c},
         M, N, K,
@@ -103,8 +108,10 @@ def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn):
         waves_per_eu = {waves_per_eu},
         matrix_instr_nonkdim = {mfmaInstrSize},
         kpack = {kpack},
-        BIAS={use_bias},
-        EVEN_K={EVEN_K},
+        BIAS = {use_bias},
+        EVEN_K = {EVEN_K},
+        GRID_MN = grid_mn,
+        NUM_XCDS = {num_xcds},
         grid=(1,),
     )
     return None
@@ -136,7 +143,9 @@ def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn):
         matrix_instr_nonkdim = {mfmaInstrSize},
         kpack = {kpack},
         BIAS = {use_bias},
-        EVEN_K = {EVEN_K}
+        EVEN_K = {EVEN_K},
+        GRID_MN = grid[0],
+        NUM_XCDS = {num_xcds}
     )
     return c
 """
@@ -310,7 +319,7 @@ def generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, ini
     # call all matmul_xxx functions
     idx = 0
-    runs = iters if run_bench else 200
+    runs = iters if run_bench else 120
     for config in configs:
         configStr = gen_configStr(config)
         matmul_call_str = f"""