diff --git "a/docs/project/\351\203\250\347\275\262\344\274\230\345\214\226/\346\267\261\345\272\246\345\255\246\344\271\240\347\274\226\350\257\221\345\231\250/OpenAI Triton MLIR \347\254\254\345\233\233\347\253\240: ROCm-triton\351\205\215\347\275\256.md" "b/docs/project/\351\203\250\347\275\262\344\274\230\345\214\226/\346\267\261\345\272\246\345\255\246\344\271\240\347\274\226\350\257\221\345\231\250/OpenAI Triton MLIR \347\254\254\345\233\233\347\253\240: ROCm-triton\351\205\215\347\275\256.md" new file mode 100644 index 0000000..6c65bcc --- /dev/null +++ "b/docs/project/\351\203\250\347\275\262\344\274\230\345\214\226/\346\267\261\345\272\246\345\255\246\344\271\240\347\274\226\350\257\221\345\231\250/OpenAI Triton MLIR \347\254\254\345\233\233\347\253\240: ROCm-triton\351\205\215\347\275\256.md" @@ -0,0 +1,627 @@ +# OpenAI/Triton MLIR 第四章: ROCm-triton配置 + +## 本文首发于GiantPandaCV,未经作者允许不得转载 + +最近在整理python-based的benchmark代码,反过来在NV的GPU上又把Triton装了一遍,发现Triton的github repo已经给出了对应的llvm的commit id以及对应的编译细节,然后跟着走了一遍,也顺利的安装成功,只需要按照如下方式即可完成NV GPU上的安装, + +``` +1. git clone https://github.com/openai/triton.git; +2. cd triton; +3. cd $HOME/llvm-project # your clone of LLVM. +4. git checkout 49af6502 +5. mkdir build +6. cd build +7. cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm" +8. ninja -j8 + +export LLVM_BUILD_DIR=$HOME/llvm-project/build + +cd +LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \ +LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \ +LLVM_SYSPATH=$LLVM_BUILD_DIR \ +pip install -e python +``` + +![img](https://picx.zhimg.com/80/v2-427448f490d6e325cf39e2456db60933_1440w.png?source=d16d100b) + + + + + +添加图片注释,不超过 140 字(可选) + +出现3.0.0说明triton已经安装成功了,装完triton后一定要安装Torch,为个人使用的是CUDA 12.1版本,按照下面的命令无脑安装即可。 + +``` +pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121 +``` + +NV GPU上triton的安装和使用其实已经轻车熟路了,接下来,让我们来探索一下AMD GPU上如何安装和配置triton。 + +### 0x00 软件安装 + +关于triton amd的backend,虽然triton的官方将其作为third-party来进行支持,但是我还是推荐大家使用AMD专门维护的一套triton版本,因为在最开始的官方triton的main分支下,开启 TRITON_CODEGEN_AMD_HIP_BACKEND=1 没有正确完成编译。所以找到了 + +按照对应的安装流程进行安装即可,我推荐使用如下命令进行安装,亲测有效 + +``` +1. git clone https://github.com/ROCmSoftwarePlatform/triton.git +2. cd triton +3. git checkout triton-mlir +``` + +这里已经准备好了需要编译的triton,但是triton后端是基于LLVM的,所以要想借助triton去生成可以跑在对应设备上的代码,我们还需要对LLVM进行编译,本教程中将会手动编译LLVM,当然如果你选择直接编译好的LLVM也是没有问题的。关于LLVM,由于triton是基于b1115f8c这个commit id进行开发的,那么我们只需要将LLVM clone下来后,checkout到对应的commit id,然后按照如下完整命令进行编译即可。 + +``` +1. git clone https://github.com/llvm/llvm-project +2. git checkout b1115f8c +3. cd llvm-project +4. mkdir build +5. cd build +6. cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm" +7. ninja -j8 +``` + +等LLVM全部装好后,就可以去将当前这个LLVM的路径写入到你的bashrc下 + +``` +export PATH=/home/llvm-project/build/bin:$PATH +``` + +然后进入到一开始clone下来的triton目录下进行如下命令 + +``` +1. cd triton +2. vim CMakeLists.txt (option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" ON)) +3. mkdir build +4. cd build +5. cmake .. +6. make -j8 +``` + +在编译完全正确后,就会在当前的 build 目录下产生一个 [libtriton.so](http://libtriton.so/) 文件。那么接下来只要将 + +[libtriton.so](http://libtriton.so/) 文件移动到 triton/python/triton/_C 目录下,将 triton 的 python 路径下入 bashrc + +``` +export TRITON_HOME=/home/Documents/compiler/triton +export PYTHONPATH=$TRITON_HOME/python:${PYTHONPATH} +``` + +如果在编译的过程中出现 goolge test 找不到的情况,按照如下命令进行安装: + +``` +1. git clone https://github.com/google/googletest +2. cd googletest +3. cmake CMakeLists.txt +4. make -j8 +5. cp ./lib/libgtest*.a /usr/lib +6. cd googletest +7. cp –a include/gtest /usr/include +``` + +如果在编译的过程中出现 pybind11 找不到的情况,按照如下命令进行按照: + +``` +1. pip install pytest +2. git clone https://github.com/pybind/pybind11.git +3. cd pybind11 +4. mkdir build +5. cd build +6. cmake .. +7. make check -j 8 +8. sudo make instal +``` + +关于 在AMD GPU上的pytorch 一定要去安装适配 ROCM 版本的 pytorch,由于我的机器使用的是5.6版本的ROCm,所以我的安装的命令如下,仅供参考: + +``` +pip3 install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url +https://download.pytorch.org/whl/rocm5.6 +``` + +关于 ROCM 版本可以通过如下命令进行查询: + +``` +dpkg -l | grep rocm +``` + +这里要记住,pytorch在AMD GPU上的使用和在NV GPU上的使用非常相似,也是用.cuda()来指定变量所在位置。 + +### 0x01 GEMM代码示例 + +全部编译好后,就可以通过执行下面的代码得到对应的 GEMM 在 AMD 显卡上针对 Triton和 rocBLAS 的 benchmark 了。 + +``` +import torch + +import triton +import triton.language as tl +import sys +import argparse +import pytest + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + ] if torch.version.hip is None else [ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, + num_warps=8, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=8, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, + num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, + num_warps=4, num_stages=0), + ], + key=['M', 'N', 'K'], +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0, +}) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + EVEN_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + if GROUP_SIZE_M == 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetics` section for details + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`. +@triton.jit +def leaky_relu(x): + x = x + 1 + return tl.where(x >= 0, x, 0.01 * x) + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul(a, b, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + assert b.is_contiguous(), "Matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + ACTIVATION=activation # + ) + return c + + +# %% +# Unit Test +# --------- +# +# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS). +@pytest.mark.parametrize("M, N, K, in_dtype, out_dtype", +[ (*shape, in_dtype, out_dtype) + for shape in [(128, 256, 32), (128, 16, 32), (32, 128, 64), + (128, 128, 64), (64, 128, 128), (32, 128, 64), + (64, 64, 32), (32, 32, 128), (128, 128, 64), + (64, 128, 128), (512, 512, 512), (1024, 1024, 1024)] + for in_dtype, out_dtype in [('int8', 'int8'), + ('float16', 'float16'), + ('bfloat16', 'bfloat16'), + ('float16', 'float32'), + ('float32', 'float32')]] +) +def test_correctness(M, N, K, in_dtype, out_dtype): + torch.manual_seed(0) + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + triton_output = matmul(a, b) + torch_output = torch.matmul(a, b) + print(f"triton_output={triton_output}") + print(f"torch_output={torch_output}") + rtol = 0 if torch.version.hip is None else 1e-2 + if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print("✅ Triton and Torch match") + else: + print("❌ Triton and Torch differ") + assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol) + + +# %% +# Benchmark +# --------- +# +# Square Matrix Performance +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, +# but feel free to arrange this script as you wish to benchmark any other matrix shape. + +global verbose +verbose = False + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['M', 'N', 'K'], # Argument names to use as an x-axis for the plot + x_vals=[ + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + (8192, 8192, 8192), + (9728, 8192, 65536) + ], # Different possible values for `x_name` + line_arg='provider', # Argument name whose value corresponds to a different line in the plot + # Possible values for `line_arg` + line_vals=['rocblas', 'triton'], + # Label name for the lines + line_names=["rocBLAS", "Triton"], + # Line styles + styles=[('green', '-'), ('blue', '-')], + ylabel="TFLOPS", # Label name for the y-axis + plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. + args={}, + )) +def benchmark(M, N, K, provider): + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + quantiles = [0.5, 0.2, 0.8] + if provider == 'rocblas': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + global verbose + if verbose: + print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.get_best_config()})') + perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="GEMM tutorial example", + allow_abbrev=False, + ) + + parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config") + args = parser.parse_args() + + return args + + +def main(): + # assign to a global verbose var to indicate whether print + # best tuning config + global verbose + args = parse_args() + verbose=args.v + benchmark.run(show_plots=True, print_data=True) + +if __name__ == '__main__': + sys.exit(main()) +``` + +### 0x10 GEMM代码详细解读 + +首先是对于搜索空间的定义,这里 + +``` +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + ] if torch.version.hip is None else [ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, + num_warps=8, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=8, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, + num_warps=4, num_stages=0), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, + num_warps=4, num_stages=0), + ], + key=['M', 'N', 'K'], +) +``` + +其中的torch.version.hip走的就是AMD GPU所对应的搜索空间,我们看到其对应的可以tuning的knob,有最常规的BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M外,还有了一个新的wave_per_eu,我一开始看到这个概念的时候也很陌生,随后和AMD的技术人员请教了下,总结下来就是: + +AMD GPU由计算单元(CU)组成,这相当于NVIDIA GPU上的流处理器(SM)。在每个CU中,有4个SIMD单元(也称执行引擎或EU)。你可以把SIMD单元看成是一个矢量执行单元,它具有执行计算所需的一定数量的寄存器和ALUs。当你发起一个计算网格时,工作组(相当于NVIDIA GPU上的线程块)会安排在CU上运行。在CU中,波前(相当于NVIDIA GPU上的波纹)会安排在SIMD单元上运行。这里提出了occupancy的概念,它表示每个SIMD单元上可同时运行的波前数。这取决于每个波前需要的资源量和每个SIMD单元的资源量。waves_per_eu参数重点关注寄存器使用情况。例如,每个SIMD(EU)有512个寄存器。如果每个波前需要256个寄存器,那么occupancy为2。但如果我们设置waves_per_eu=3,编译器会试图将每个波前的寄存器使用量减少到170,这样occupancy就可以是3了。但是提高waves_per_eu存在寄存器溢出的风险和性能下降。所以增加waves_per_eu可能会增加occupancy,但不一定能提高性能。 + +然后是具体的kernel定义,这部分的定义其实和NV GPU上的写法没有本质区别 + +``` +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + EVEN_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + if GROUP_SIZE_M == 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetics` section for details + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) +``` + +接下来是单元测试,用来说明triton的输出结果和torch的输出结果必须是相同的 + +``` +def test_correctness(M, N, K, in_dtype, out_dtype): + torch.manual_seed(0) + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + triton_output = matmul(a, b) + torch_output = torch.matmul(a, b) + print(f"triton_output={triton_output}") + print(f"torch_output={torch_output}") + rtol = 0 if torch.version.hip is None else 1e-2 + if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print("✅ Triton and Torch match") + else: + print("❌ Triton and Torch differ") + assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol) +``` + +接下来你只需要指定好对应的GEMM的尺寸,我们的默认输入顺序还是以M,N,K为主,剩下都是中规中局的操作了。 + +``` +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['M', 'N', 'K'], # Argument names to use as an x-axis for the plot + x_vals=[ + (1024, 1024, 1024), + (2048, 2048, 2048), + (4096, 4096, 4096), + (8192, 8192, 8192), + (9728, 8192, 65536) + ], # Different possible values for `x_name` + line_arg='provider', # Argument name whose value corresponds to a different line in the plot + # Possible values for `line_arg` + line_vals=['rocblas', 'triton'], + # Label name for the lines + line_names=["rocBLAS", "Triton"], + # Line styles + styles=[('green', '-'), ('blue', '-')], + ylabel="TFLOPS", # Label name for the y-axis + plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. + args={}, + )) +def benchmark(M, N, K, provider): + a = torch.randn((M, K), device='cuda', dtype=torch.float16) + b = torch.randn((K, N), device='cuda', dtype=torch.float16) + quantiles = [0.5, 0.2, 0.8] + if provider == 'rocblas': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) + if provider == 'triton': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + global verbose + if verbose: + print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.get_best_config()})') + perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="GEMM tutorial example", + allow_abbrev=False, + ) + + parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config") + args = parser.parse_args() + + return args + + +def main(): + # assign to a global verbose var to indicate whether print + # best tuning config + global verbose + args = parse_args() + verbose=args.v + benchmark.run(show_plots=True, print_data=True) + +if __name__ == '__main__': + sys.exit(main()) +``` + +关于在AMD GPU上更加自动化的GEMM benchmark调优脚本,我们将在后面的章节中来为大家进行解读。