Merge branch 'main_perf' into atomics_fix
xiaohuguo2023 authored Aug 19, 2024
2 parents 49b0a46 + 15cb3a8 commit 3956d23
Showing 4 changed files with 61 additions and 31 deletions.
20 changes: 18 additions & 2 deletions python/perf-kernels/tune_gemm/README.md
@@ -1,8 +1,9 @@
-# GEMM tuning script (current v3.3)
+# GEMM tuning script (current v3.4)

## matmul kernel

The matmul kernel implementation can be found as [matmul_kernel.py](https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/tune_gemm/matmul_kernel.py), which includes the following features:
+- XCD-based pid remapping
- grouping order of workgroup id, which is controlled by `GROUP_SIZE_M`, that
implements L2 cache optimization introduced in the [tutorial](https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#l2-cache-optimizations).
- split-k algorithm, which is controlled by `SPLIT_K`.
@@ -144,7 +145,7 @@ The default value is 1000.

The general idea of the tuning script can be summarized as
- Compile all the kernels in the tuning space in parallel.
-- Divide the tuning space into tasks and invoke `rocprofv2` once per
+- Divide the tuning space into tasks and invoke `rocprof` once per
task. This will save invocation overhead of the profiler.
- Profile tasks in parallel on multiple GPUs.
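
The task-splitting idea can be sketched in plain Python. This is a hypothetical illustration rather than the script's actual code; `profile_on_gpu` stands in for one `rocprof` invocation per task:

```python
import multiprocessing

def profile_on_gpu(gpu_id, task_batches):
    # Stand-in for one rocprof launch per task batch on this GPU.
    for batch in task_batches:
        print(f"GPU {gpu_id}: profiling {len(batch)} configs in one rocprof run")

def profile_all(configs, jobs, ngpus):
    # Split the tuning space into `jobs` tasks so the profiler is launched
    # once per task instead of once per config.
    tasks = [configs[i::jobs] for i in range(jobs)]
    # GPU g handles tasks g, g + ngpus, g + 2*ngpus, ... concurrently.
    with multiprocessing.Pool(processes=ngpus) as pool:
        pool.starmap(profile_on_gpu,
                     [(g, tasks[g::ngpus]) for g in range(ngpus)])

if __name__ == "__main__":
    profile_all(configs=list(range(100)), jobs=8, ngpus=4)
```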

@@ -309,6 +310,21 @@ places:
- Statically set `device` and `stream` in the [jit.py](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/python/triton/runtime/jit.py#L588-L589)


+# GEMM Tuning Script v3.4
+
+## API changes
+
+No API changes.
+
+## Implementation changes
+
+- The `matmul_kernel` now supports XCD-based pid remapping. Details and
+  experiments will be added later.
+- Switched back to rocprofv1. Check [ticket#228](https://github.com/ROCm/triton-internal/issues/228) for more details.
+- Improved the post-processing logic to filter out the "spikes" in the profiling results.
+- Reduced the number of iterations in both tuning and benchmark modes (120 and 200, respectively).

# One config running script

`one_config.py` is a script that runs one given matmul config.
16 changes: 14 additions & 2 deletions python/perf-kernels/tune_gemm/matmul_kernel.py
@@ -6,11 +6,22 @@
def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm,
                  stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                  BLOCK_SIZE_K: tl.constexpr, SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BIAS: tl.constexpr,
-                  EVEN_K: tl.constexpr):
+                  EVEN_K: tl.constexpr, GRID_MN: tl.constexpr, NUM_XCDS: tl.constexpr):
    pid = tl.program_id(axis=0)
    pid_z = tl.program_id(1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

+    if NUM_XCDS != 1:
+        ## pid remapping on xcds
+        # Number of pids per XCD in the new arrangement
+        pids_per_xcd = GRID_MN // NUM_XCDS
+        # Compute current XCD and local pid within the XCD
+        xcd = pid % NUM_XCDS
+        local_pid = pid // NUM_XCDS
+        # Calculate new pid based on the new grouping
+        pid = xcd * pids_per_xcd + local_pid
+
    if GROUP_SIZE_M == 1:
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
@@ -19,8 +30,9 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-        pid_m = first_pid_m + (pid % group_size_m)
+        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

    if SPLIT_K == 1:
        offs_k = tl.arange(0, BLOCK_SIZE_K)
    else:
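
To see what the `NUM_XCDS` branch above does, here is a plain-Python rendering of the remapping. It assumes, as the kernel appears to, that the hardware dispatches consecutive workgroup ids round-robin across XCDs and that `GRID_MN` is divisible by `NUM_XCDS`:

```python
GRID_MN, NUM_XCDS = 16, 4
pids_per_xcd = GRID_MN // NUM_XCDS

for pid in range(GRID_MN):
    xcd = pid % NUM_XCDS         # XCD this pid is dispatched to (round-robin)
    local_pid = pid // NUM_XCDS  # this pid's position within its XCD
    new_pid = xcd * pids_per_xcd + local_pid
    print(f"hw pid {pid:2d} on XCD {xcd} -> tile {new_pid:2d}")
```

With this mapping, XCD 0 (hardware pids 0, 4, 8, 12) computes tiles 0-3, XCD 1 computes tiles 4-7, and so on: each XCD works on a contiguous block of output tiles and gets better reuse out of its own L2. Presumably this is also why the remap is disabled (`NUM_XCDS = 1`) under split-K, where the launch grid gains a second axis and the round-robin assumption on axis 0 may no longer hold.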
39 changes: 16 additions & 23 deletions python/perf-kernels/tune_gemm/tune_gemm.py
@@ -54,7 +54,7 @@ def get_full_tuning_space():
    block_k_range = [16, 32, 64, 128, 256]
    split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
    num_warps_range = [1, 2, 4, 8]
-    group_m_range = [1, 4, 8, 16, 32]
+    group_m_range = [1, 2, 4, 8, 16, 32]
    # For now we see better perf with num_stages=0 for all gemm configs we care
    # about. But keep this explicit so that we do not forget we may need to set
    # it to other values in the future.
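
For a sense of scale, the full space is essentially a Cartesian product of these ranges. A sketch, assuming symmetric `block_m`/`block_n` ranges and omitting the `num_stages`, `waves_per_eu`, mfma-size, and `kpack` dimensions:

```python
from itertools import product

block_mn_range = [16, 32, 64, 128, 256]  # assumed identical for block_m and block_n
block_k_range = [16, 32, 64, 128, 256]
split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 2, 4, 8, 16, 32]

space = [
    dict(BLOCK_SIZE_M=bm, BLOCK_SIZE_N=bn, BLOCK_SIZE_K=bk,
         SPLIT_K=sk, num_warps=nw, GROUP_SIZE_M=gm)
    for bm, bn, bk, sk, nw, gm in product(
        block_mn_range, block_mn_range, block_k_range,
        split_k_range, num_warps_range, group_m_range)
]
print(len(space))  # 33000 raw candidates before pruning
```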
@@ -157,7 +157,7 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
        if num_warps < 4:
            continue
        # check if tiling is integer multiple of GEMM size because we have no boundary check
-        if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0 or K % BLOCK_SIZE_K != 0:
+        if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0:
            continue

        pruned_configs.append(config)
@@ -169,20 +169,15 @@ def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024


-def extract_kernel_time(M, N, K, config, df, bias_size):
-    # Correct the header by removing 'sig' and 'obj' to reduce number from 21 to 19
-    # once the bug(https://github.com/ROCm/rocprofiler/issues/144) fixed, we should
-    # not need below two lines
-    cols = [
-        'Index', 'KernelName', 'gpu-id', 'queue-id', 'queue-index', 'pid', 'tid', 'grd', 'wgr', 'lds', 'scr',
-        'arch_vgpr', 'accum_vgpr', 'sgpr', 'wave_size', 'DispatchNs', 'BeginNs', 'EndNs', 'CompleteNs'
-    ]
-    df.columns = cols
+def extract_kernel_time(M, N, K, config, df):
    configStr = gen_configStr(config)
-    filtered_df = df[df['KernelName'].str.contains(configStr, na=False)].copy()
-    filtered_df['DurationNs'] = filtered_df['EndNs'] - filtered_df['BeginNs']
-    meanTime = filtered_df['DurationNs'].tail(100).mean()
-    return config, meanTime
+    df = df[df['KernelName'].str.contains(configStr)]
+
+    first_value = df['DurationNs'].iloc[0]
+    filtered_data = df['DurationNs'][df['DurationNs'] <= first_value]
+    new_meanTime = filtered_data.tail(100).mean()
+
+    return config, new_meanTime
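
The new filtering logic reads as: treat the first recorded dispatch, typically the slowest cold run, as an upper bound and discard anything slower as a spike. A toy reproduction of the same pandas steps:

```python
import pandas as pd

# Toy trace: the first dispatch is a slow cold run; the 5000 ns row is a spike.
df = pd.DataFrame({"DurationNs": [1200, 1000, 990, 5000, 1010]})

first_value = df["DurationNs"].iloc[0]
# Keep only dispatches no slower than the first one; the spike is dropped.
filtered_data = df["DurationNs"][df["DurationNs"] <= first_value]
print(filtered_data.tail(100).mean())  # 1050.0 = mean of [1200, 1000, 990, 1010]
```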


def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
@@ -197,7 +192,7 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
        if verbose:
            print(f"profiling {kernel_name} on GPU {gpuid}")
        run_bash_command_wrapper(
-            f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {get_filename_profile_driver(M, N, K, jobId)}",
+            f"rocprof --stats -o results_{jobId}.csv python {get_filename_profile_driver(M, N, K, jobId)}",
            capture=(verbose < 2))
        jobId += ngpus

@@ -244,13 +239,10 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
    thread_pool = multiprocessing.Pool(processes=num_threads)
    tasks = []
    idx = 0
-    df_prof = [
-        pd.read_csv(f"results_{i}.csv", skiprows=1, header=None, delimiter=',', quotechar='"', escapechar='\\')
-        for i in range(jobs)
-    ]
+    df_prof = [pd.read_csv(f"results_{i}.csv") for i in range(jobs)]
    for config in configs:
        file_idx = idx % jobs
-        tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx], bias_size))]
+        tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx]))]
        idx += 1
    thread_pool.close()
    thread_pool.join()
@@ -366,11 +358,12 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps
    grid = triton.cdiv(M, block_m) * triton.cdiv(N, block_n), split_k
    stride_bias = bias.stride(0) if use_bias else 0
    EVEN_K = K % block_k == 0
+    num_xcds = 1 if split_k > 1 else 8
    matmul_kernel[grid](a, b, c, bias, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0),
                        c.stride(1), stride_bias=stride_bias, BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n,
                        BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, SPLIT_K=split_k, num_warps=num_warps,
                        num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize,
-                        kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K)
+                        kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K, GRID_MN=grid[0], NUM_XCDS=num_xcds)
    return c
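
For reference, the two new constexpr arguments are derived from the launch shape as below; a standalone sketch with `cdiv` inlined so it runs without Triton:

```python
def cdiv(a, b):
    return (a + b - 1) // b

M, N, block_m, block_n, split_k = 4096, 4096, 128, 128, 1

grid = (cdiv(M, block_m) * cdiv(N, block_n), split_k)
GRID_MN = grid[0]                   # number of MN tiles on grid axis 0
NUM_XCDS = 1 if split_k > 1 else 8  # remap only without split-K; 8 XCDs assumed (MI300X)
print(GRID_MN, NUM_XCDS)            # 1024 8
```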


@@ -441,7 +434,7 @@ def parse_args():
parser.add_argument("--num_threads", type=int, default=32,
help="number of threads to use for kernel compilation and post processing")
parser.add_argument("--jobs", type=int, default=1, help="number of tasks during the profiling process")
parser.add_argument("--iters", type=int, default=1000, help="number of iterations used in --benchmark mode")
parser.add_argument("--iters", type=int, default=200, help="number of iterations used in --benchmark mode")
parser.add_argument("--init_type", type=str, default='randn', choices=['randn', 'hpl', 'trig_float', 'zeros'],
help="Input tensor initialization (default normal distribution)")
parser.add_argument(
17 changes: 13 additions & 4 deletions python/perf-kernels/tune_gemm/utils/file_generator.py
@@ -76,6 +76,10 @@ def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype

    use_bias = bias_size > 0

+    ## Let's enable xcd-based pid remapping only when split-K is NOT used
+    ## Also #xcd is fixed to 8. If we are tuning for MI308, please change it to 4
+    num_xcds = 1 if split_k > 1 else 8
+
    if warmup:
        torch_dtype_a = 'fp16'
        torch_dtype_b = 'fp16'
@@ -89,6 +93,7 @@

    matmul_def_str = f"""
def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn):
+    grid_mn = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n})
    matmul_kernel_{configStr}.warmup(
        {torch_dtype_a}, {torch_dtype_b}, {torch_dtype_c}, {torch_dtype_c},
        M, N, K,
@@ -103,8 +108,10 @@ def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn):
        waves_per_eu = {waves_per_eu},
        matrix_instr_nonkdim = {mfmaInstrSize},
        kpack = {kpack},
-        BIAS={use_bias},
-        EVEN_K={EVEN_K},
+        BIAS = {use_bias},
+        EVEN_K = {EVEN_K},
+        GRID_MN = grid_mn,
+        NUM_XCDS = {num_xcds},
        grid=(1,),
    )
    return None
@@ -136,7 +143,9 @@ def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn):
        matrix_instr_nonkdim = {mfmaInstrSize},
        kpack = {kpack},
        BIAS = {use_bias},
-        EVEN_K = {EVEN_K}
+        EVEN_K = {EVEN_K},
+        GRID_MN = grid[0],
+        NUM_XCDS = {num_xcds}
    )
    return c
"""
@@ -310,7 +319,7 @@ def generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, ini

    # call all matmul_xxx functions
    idx = 0
-    runs = iters if run_bench else 200
+    runs = iters if run_bench else 120
    for config in configs:
        configStr = gen_configStr(config)
        matmul_call_str = f"""
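
The wrapper templates above bake each config's values into generated source via f-strings, which is then written to a driver file. To make that concrete, here is a toy version of the approach; it substitutes `exec` for file generation, and the `configStr` naming scheme is hypothetical:

```python
block_m, block_n, split_k = 256, 128, 1
num_xcds = 1 if split_k > 1 else 8
configStr = f"BM{block_m}_BN{block_n}"  # hypothetical naming scheme

matmul_def_str = f"""
def matmul_{configStr}(M, N):
    # GRID_MN and NUM_XCDS would be passed to the kernel as constexprs
    grid_mn = ((M + {block_m} - 1) // {block_m}) * ((N + {block_n} - 1) // {block_n})
    return grid_mn, {num_xcds}
"""
exec(matmul_def_str)
print(matmul_BM256_BN128(4096, 4096))  # (512, 8)
```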
