From a2934b9bebf447ce3c7c7c0ba9c3274882ddd63c Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Mon, 21 Oct 2024 20:41:47 +0800 Subject: [PATCH] [HGEMM] Add MMA 16816 swizzle, Up to 115 TFLOPS (#98) * Update hgemm_mma.cu * Update README.md * Update hgemm.py * Update hgemm.cu * Update hgemm_mma.cu * Update hgemm.cu * Update hgemm.py * Update README.md * Update hgemm_mma.cu * Update hgemm.py * Update hgemm.cu * Update hgemm_mma.cu * Update README.md * Update hgemm.py * Update README.md * Update README.md * Update hgemm_mma_stage.cu * Update hgemm.py * Update hgemm.cu * Update README.md * Update README.md * Update hgemm_mma_stage.cu * Update hgemm_mma_stage.cu --- README.md | 6 +- hgemm/README.md | 430 +++++++--------------- hgemm/hgemm.cu | 12 + hgemm/hgemm.py | 35 +- hgemm/hgemm_mma.cu | 302 ++++++++++++++- hgemm/hgemm_mma_stage.cu | 776 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 1247 insertions(+), 314 deletions(-) diff --git a/README.md b/README.md index 9c1b88ee..5363c748 100644 --- a/README.md +++ b/README.md @@ -147,6 +147,10 @@ | ✔️ [hgemm_wmma_m32n8k16....dbuf*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️| | ✔️ [hgemm_wmma_m16n16k16...stages*](./hgemm/hgemm_wmma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️| | ✔️ [hgemm_wmma_m16n16k16...swizzle*](./hgemm/hgemm_wmma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️| +| ✔️ [hgemm_mma_m16n8k16...naive*](./hgemm/hgemm_mma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️| +| ✔️ [hgemm_mma_m16n8k16...mma2x4*](./hgemm/hgemm_mma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️| +| ✔️ [hgemm_mma_m16n8k16...stages*](./hgemm/hgemm_mma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️| +| ✔️ [hgemm_mma_m16n8k16...swizzle*](./hgemm/hgemm_mma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️| | ✔️ [sgemv_k32_f32](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️| | ✔️ [sgemv_k128_f32x4](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️| | ✔️ [sgemv_k16_f32](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️| @@ -158,7 +162,7 @@ | ✔️ [hard_nms cpp only](./nms/nms.cc)|f32|/|/|⭐️| | ✔️ [notes v1(deprecated)](./notes-v1.cu)|f32|f32|/|⭐️| -👉TIPS: * means using **Tensor Cores(MMA PTX)**, otherwise, using CUDA Cores by default. +👉TIPS: * means using **Tensor Cores(MMA/WMMA)**, otherwise, using CUDA Cores by default. 
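For readers comparing the two Tensor Core paths flagged by `*`, here is a minimal sketch (illustrative names, not code from this repo): the WMMA kernels let `nvcuda::wmma` own the fragment layout, while the MMA kernels issue the `m16n8k16` `mma.sync` PTX directly and manage the per-lane fragment registers by hand, which is what the `HMMA16816` macro added in this patch wraps.

```cuda
#include <mma.h>
#include <cuda_fp16.h>
#include <cstdint>
using namespace nvcuda;

// WMMA path: the compiler owns the 16x16x16 fragment layout.
// Assumes a single warp (blockDim.x == 32) executes this; names are illustrative.
__global__ void wmma_tile_16x16x16(const half* A, const half* B, half* C) {
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;
  wmma::fill_fragment(c_frag, __float2half(0.0f));
  wmma::load_matrix_sync(a_frag, A, 16);            // leading dimension = 16
  wmma::load_matrix_sync(b_frag, B, 16);
  wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);   // C += A * B
  wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
}

// MMA PTX path: one m16n8k16 tensor-core instruction; RA/RB/RC are the raw
// per-lane fragment registers, exactly what the HMMA16816 macro in this patch wraps.
__device__ __forceinline__ void hmma_m16n8k16(uint32_t* RC, const uint32_t* RA,
                                              const uint32_t* RB) {
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
      "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n"
      : "=r"(RC[0]), "=r"(RC[1])
      : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]),
        "r"(RB[0]), "r"(RB[1]), "r"(RC[0]), "r"(RC[1]));
}
```

The extra control over register fragments is what makes hand-written swizzling possible, something the WMMA API does not expose.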
## 0x01 📖 Blog Contents

diff --git a/hgemm/README.md b/hgemm/README.md
index 3f7b0fe6..d2f82368 100755
--- a/hgemm/README.md
+++ b/hgemm/README.md
@@ -23,13 +23,16 @@
- [X] hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages(WMMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle)
- [X] hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages(WMMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle)
- [X] hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async(WMMA, Tile MMA/Warp, Copy Async, Double Buffers, Pad)
+- [X] hgemm_mma_m16n8k16_naive(MMA)
+- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4(MMA, Tile MMA/Warp, pack)
+- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(MMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle)
- [X] PyTorch bindings

## Current Performance

- NVIDIA L20

-On the L20 (theoretical Tensor Cores FP16 throughput: 119.5 TFLOPS), the best implementation reaches roughly 95%~98% of cuBLAS performance (105-113 TFLOPS vs 105-115 TFLOPS) and beats cuBLAS in some cases. Known issue: bank conflicts are not fully eliminated; the current padding-based mitigation wastes shared memory and also hurts SM occupancy. Warp swizzle has not been implemented by hand yet (limited by the flexibility of the WMMA API and by my own ability); a warp swizzle via MMA PTX will be attempted later. [Click to view the performance data](#NV-L20).
+On the L20 (theoretical Tensor Cores FP16 throughput: 119.5 TFLOPS), the best implementation reaches roughly 95%~98% of cuBLAS performance with the WMMA API (105-113 TFLOPS vs 105-115 TFLOPS) and 115 TFLOPS with the MMA API, beating cuBLAS in some cases. Known issue: bank conflicts are not fully eliminated; the current padding-based mitigation wastes shared memory and also hurts SM occupancy. Warp swizzle has not been implemented by hand yet (limited by the flexibility of the WMMA API and by my own ability); a warp swizzle via MMA PTX will be attempted later. [Click to view the performance data](#NV-L20).

- NVIDIA GeForce RTX 3080 Laptop

@@ -232,20 +235,23 @@ nsys profile --stats=true -t cuda,osrt,nvtx -o hgemm.prof --force-overwrite true
```bash
# Build only for the Ada architecture; if unset, all architectures are compiled, which takes much longer: Volta, Ampere, Ada, Hopper, ...
export TORCH_CUDA_ARCH_LIST=Ada
-python3 hgemm.py # default, test some wmma kernels for all MNK
-python3 hgemm.py --wmma # test all wmma kernels for all MNK
-python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --wmma # test all wmma kernels for specific MNK
-python3 hgemm.py --wmma --no-default # test all wmma kernels, but exclude the default part.
+python3 hgemm.py --wmma # test default wmma kernels for all MNK
+python3 hgemm.py --mma # test default mma kernels for all MNK
+python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --wmma # test default wmma kernels for specific MNK
+python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --mma # test default mma kernels for specific MNK
+python3 hgemm.py --wmma --wmma-all # test all wmma kernels for all MNK
+python3 hgemm.py --mma --mma-all # test all mma kernels for all MNK
```

## NVIDIA L20
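One note before the L20 numbers: the performance summary above attributes part of the remaining gap to padding-based bank-conflict mitigation, which trades shared memory (and thus occupancy) for conflict-free accesses. A minimal sketch of that trade-off follows; every tile and pad size here is hypothetical, not the configuration used by the kernels in this repo.

```cuda
#include <cuda_fp16.h>

// Hypothetical sizes, chosen only to show the shape of the trade-off.
constexpr int BM = 128, BN = 128, BK = 16;
constexpr int A_PAD = 8;  // extra halves per s_a row
constexpr int B_PAD = 8;  // extra halves per s_b row

__global__ void padded_smem_demo() {
  // Padding each row breaks the power-of-two stride that makes shared-memory
  // loads/stores hit the same banks, but it costs
  // (BM * A_PAD + BK * B_PAD) * sizeof(half) bytes of shared memory per block,
  // which is the occupancy cost described in the note above.
  __shared__ half s_a[BM][BK + A_PAD];
  __shared__ half s_b[BK][BN + B_PAD];
  if (threadIdx.x < BK && threadIdx.y < BM) s_a[threadIdx.y][threadIdx.x] = __float2half(0.0f);
  if (threadIdx.x < BN && threadIdx.y < BK) s_b[threadIdx.y][threadIdx.x] = __float2half(0.0f);
  __syncthreads();
}
```

A swizzled shared-memory layout can avoid the conflicts without this padding cost, which is the direction the note above says will be explored via MMA PTX.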
-Up to 113.76 TFLOPS, 113.76/119.5=95.19% TFLOPS utilization. + +### WMMA & CUDA: Up to 113.76 TFLOPS, 113.76/119.5=95.19% TFLOPS utilization. ```bash -python3 hgemm.py +python3 hgemm.py --cuda --wmma ``` 输出: ```bash @@ -338,79 +344,6 @@ python3 hgemm.py (mma4x2+warp2x4+stage3+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:2.573418ms, swizzle: 2048, TFLOPS: 106.81 (mma4x2+warp2x4+stage2+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:2.533483ms, swizzle: 2048, TFLOPS: 108.50 (cublas): ['9.84375 ', '-46.71875 '], time:2.661132ms, swizzle: NOOP, TFLOPS: 103.29 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=4096, N=8192, K=8192, Warmup=5, Iters=20, 6/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['47.53125 ', '-51.5 '], time:11.79177ms, swizzle: NOOP, TFLOPS: 46.62 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['47.53125 ', '-51.5 '], time:11.25807ms, swizzle: NOOP, TFLOPS: 48.83 (+4.74%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['47.0 ', '-52.25 '], time:12.28225ms, swizzle: NOOP, TFLOPS: 44.76 - (mma4x2+warp2x4): ['47.0 ', '-52.25 '], time:7.306694ms, swizzle: NOOP, TFLOPS: 75.24 (+54.08%) - (mma4x2+warp2x4+stage3): ['47.0 ', '-52.25 '], time:5.185413ms, swizzle: NOOP, TFLOPS: 106.02(+40.91%) - (mma4x2+warp2x4+stage2): ['47.0 ', '-52.25 '], time:5.128622ms, swizzle: NOOP, TFLOPS: 107.19(+1.11%) - (mma4x2+warp2x4+stage3+dsmem): ['47.0 ', '-52.25 '], time:5.165719ms, swizzle: NOOP, TFLOPS: 106.42 - (mma4x2+warp2x4+stage2+dsmem): ['47.0 ', '-52.25 '], time:5.137014ms, swizzle: NOOP, TFLOPS: 107.02 - (mma4x2+warp2x4+stage3+swizzle): ['47.0 ', '-52.25 '], time:5.096411ms, swizzle: 2048, TFLOPS: 107.87(+0.63%) - (mma4x2+warp2x4+stage2+swizzle): ['47.0 ', '-52.25 '], time:5.036878ms, swizzle: 2048, TFLOPS: 109.15(+1.18%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['47.0 ', '-52.25 '], time:5.087852ms, swizzle: 2048, TFLOPS: 108.05 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['47.0 ', '-52.25 '], time:5.011391ms, swizzle: 2048, TFLOPS: 109.70(+0.51%) - (cublas): ['47.0 ', '-52.25 '], time:5.063843ms, swizzle: NOOP, TFLOPS: 108.56 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=4096, N=16384, K=2048, Warmup=5, Iters=20, 7/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['14.765625 ', '-18.640625'], time:5.306124ms, swizzle: NOOP, TFLOPS: 51.80 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['14.765625 ', '-18.640625'], time:5.044364ms, swizzle: NOOP, TFLOPS: 54.49 (+5.19%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['14.28125 ', '-18.6875 '], time:5.916452ms, swizzle: NOOP, TFLOPS: 46.46 - (mma4x2+warp2x4): ['14.28125 ', '-18.6875 '], time:3.550720ms, swizzle: NOOP, TFLOPS: 77.41 (+42.07%) - (mma4x2+warp2x4+stage3): ['14.28125 ', '-18.6875 '], time:2.552175ms, 
swizzle: NOOP, TFLOPS: 107.70(+39.13%) - (mma4x2+warp2x4+stage2): ['14.28125 ', '-18.6875 '], time:2.537274ms, swizzle: NOOP, TFLOPS: 108.34(+0.59%) - (mma4x2+warp2x4+stage3+dsmem): ['14.28125 ', '-18.6875 '], time:2.545833ms, swizzle: NOOP, TFLOPS: 107.97 - (mma4x2+warp2x4+stage2+dsmem): ['14.28125 ', '-18.6875 '], time:2.546501ms, swizzle: NOOP, TFLOPS: 107.94 - (mma4x2+warp2x4+stage3+swizzle): ['14.28125 ', '-18.6875 '], time:2.544927ms, swizzle: 4096, TFLOPS: 108.01 - (mma4x2+warp2x4+stage2+swizzle): ['14.28125 ', '-18.6875 '], time:2.518939ms, swizzle: 4096, TFLOPS: 109.12(+0.73%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:2.547931ms, swizzle: 4096, TFLOPS: 107.88 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:2.512478ms, swizzle: 4096, TFLOPS: 109.41(+0.26%) - (cublas): ['14.28125 ', '-18.6875 '], time:2.635645ms, swizzle: NOOP, TFLOPS: 104.29 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=4096, N=16384, K=4096, Warmup=5, Iters=20, 8/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['10.296875 ', '-46.6875 '], time:11.61146ms, swizzle: NOOP, TFLOPS: 47.35 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['10.296875 ', '-46.6875 '], time:11.02995ms, swizzle: NOOP, TFLOPS: 49.84 (+5.27%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['9.84375 ', '-46.71875 '], time:15.55149ms, swizzle: NOOP, TFLOPS: 35.35 - (mma4x2+warp2x4): ['9.84375 ', '-46.71875 '], time:7.264566ms, swizzle: NOOP, TFLOPS: 75.68 (+51.83%) - (mma4x2+warp2x4+stage3): ['9.84375 ', '-46.71875 '], time:5.160856ms, swizzle: NOOP, TFLOPS: 106.52(+40.76%) - (mma4x2+warp2x4+stage2): ['9.84375 ', '-46.71875 '], time:5.038166ms, swizzle: NOOP, TFLOPS: 109.12(+2.44%) - (mma4x2+warp2x4+stage3+dsmem): ['9.84375 ', '-46.71875 '], time:5.177164ms, swizzle: NOOP, TFLOPS: 106.19 - (mma4x2+warp2x4+stage2+dsmem): ['9.84375 ', '-46.71875 '], time:5.098938ms, swizzle: NOOP, TFLOPS: 107.82 - (mma4x2+warp2x4+stage3+swizzle): ['9.84375 ', '-46.71875 '], time:5.004787ms, swizzle: 4096, TFLOPS: 109.85(+0.67%) - (mma4x2+warp2x4+stage2+swizzle): ['9.84375 ', '-46.71875 '], time:4.954004ms, swizzle: 4096, TFLOPS: 110.97(+1.03%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:5.003094ms, swizzle: 4096, TFLOPS: 109.88 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:4.933691ms, swizzle: 4096, TFLOPS: 111.43(+0.41%) - (cublas): ['9.84375 ', '-46.71875 '], time:4.990887ms, swizzle: NOOP, TFLOPS: 110.15 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=4096, N=16384, K=8192, Warmup=5, Iters=20, 9/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['47.53125 ', '-51.5 '], time:24.35543ms, swizzle: NOOP, TFLOPS: 45.14 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['47.53125 ', '-51.5 '], time:23.57738ms, swizzle: NOOP, 
TFLOPS: 46.63 (+3.30%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['47.0 ', '-52.25 '], time:31.01222ms, swizzle: NOOP, TFLOPS: 35.45 - (mma4x2+warp2x4): ['47.0 ', '-52.25 '], time:14.37473ms, swizzle: NOOP, TFLOPS: 76.49 (+64.02%) - (mma4x2+warp2x4+stage3): ['47.0 ', '-52.25 '], time:12.40768ms, swizzle: NOOP, TFLOPS: 88.62 (+15.85%) - (mma4x2+warp2x4+stage2): ['47.0 ', '-52.25 '], time:12.25883ms, swizzle: NOOP, TFLOPS: 89.69 (+1.21%) - (mma4x2+warp2x4+stage3+dsmem): ['47.0 ', '-52.25 '], time:12.40663ms, swizzle: NOOP, TFLOPS: 88.62 - (mma4x2+warp2x4+stage2+dsmem): ['47.0 ', '-52.25 '], time:12.26737ms, swizzle: NOOP, TFLOPS: 89.63 - (mma4x2+warp2x4+stage3+swizzle): ['47.0 ', '-52.25 '], time:9.920740ms, swizzle: 4096, TFLOPS: 110.83(+23.57%) - (mma4x2+warp2x4+stage2+swizzle): ['47.0 ', '-52.25 '], time:9.804654ms, swizzle: 4096, TFLOPS: 112.14(+1.18%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['47.0 ', '-52.25 '], time:9.917545ms, swizzle: 4096, TFLOPS: 110.87 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['47.0 ', '-52.25 '], time:9.778022ms, swizzle: 4096, TFLOPS: 112.45(+0.27%) - (cublas): ['47.0 ', '-52.25 '], time:9.679126ms, swizzle: NOOP, TFLOPS: 113.60(+1.02%) ----------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=8192, N=4096, K=2048, Warmup=5, Iters=20, 10/27 ---------------------------------------------------------------------------------------------------------------------------------- @@ -464,223 +397,6 @@ python3 hgemm.py (mma4x2+warp2x4+stage3+dsmem+swizzle): ['47.0 ', '-52.25 '], time:5.098199ms, swizzle: 1024, TFLOPS: 107.83 (mma4x2+warp2x4+stage2+dsmem+swizzle): ['47.0 ', '-52.25 '], time:5.003476ms, swizzle: 1024, TFLOPS: 109.87(+0.02%) (cublas): ['47.0 ', '-52.25 '], time:5.096864ms, swizzle: NOOP, TFLOPS: 107.86 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=8192, N=8192, K=2048, Warmup=5, Iters=20, 13/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['14.765625 ', '-18.640625'], time:5.346417ms, swizzle: NOOP, TFLOPS: 51.41 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['14.765625 ', '-18.640625'], time:4.942703ms, swizzle: NOOP, TFLOPS: 55.61 (+8.17%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['14.28125 ', '-18.6875 '], time:5.900359ms, swizzle: NOOP, TFLOPS: 46.59 - (mma4x2+warp2x4): ['14.28125 ', '-18.6875 '], time:3.572225ms, swizzle: NOOP, TFLOPS: 76.95 (+38.36%) - (mma4x2+warp2x4+stage3): ['14.28125 ', '-18.6875 '], time:2.547502ms, swizzle: NOOP, TFLOPS: 107.90(+40.22%) - (mma4x2+warp2x4+stage2): ['14.28125 ', '-18.6875 '], time:2.539443ms, swizzle: NOOP, TFLOPS: 108.24(+0.32%) - (mma4x2+warp2x4+stage3+dsmem): ['14.28125 ', '-18.6875 '], time:2.537584ms, swizzle: NOOP, TFLOPS: 108.32(+0.07%) - (mma4x2+warp2x4+stage2+dsmem): ['14.28125 ', '-18.6875 '], time:2.540159ms, swizzle: NOOP, TFLOPS: 108.21 - (mma4x2+warp2x4+stage3+swizzle): 
['14.28125 ', '-18.6875 '], time:2.535915ms, swizzle: 2048, TFLOPS: 108.39(+0.07%) - (mma4x2+warp2x4+stage2+swizzle): ['14.28125 ', '-18.6875 '], time:2.510333ms, swizzle: 2048, TFLOPS: 109.50(+1.02%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:2.547550ms, swizzle: 2048, TFLOPS: 107.90 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:2.511882ms, swizzle: 2048, TFLOPS: 109.43 - (cublas): ['14.28125 ', '-18.6875 '], time:2.635979ms, swizzle: NOOP, TFLOPS: 104.28 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=8192, N=8192, K=4096, Warmup=5, Iters=20, 14/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['10.296875 ', '-46.6875 '], time:10.91315ms, swizzle: NOOP, TFLOPS: 50.38 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['10.296875 ', '-46.6875 '], time:10.32221ms, swizzle: NOOP, TFLOPS: 53.26 (+5.72%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['9.84375 ', '-46.71875 '], time:11.72120ms, swizzle: NOOP, TFLOPS: 46.90 - (mma4x2+warp2x4): ['9.84375 ', '-46.71875 '], time:6.984162ms, swizzle: NOOP, TFLOPS: 78.71 (+47.79%) - (mma4x2+warp2x4+stage3): ['9.84375 ', '-46.71875 '], time:5.013513ms, swizzle: NOOP, TFLOPS: 109.65(+39.31%) - (mma4x2+warp2x4+stage2): ['9.84375 ', '-46.71875 '], time:4.993677ms, swizzle: NOOP, TFLOPS: 110.09(+0.40%) - (mma4x2+warp2x4+stage3+dsmem): ['9.84375 ', '-46.71875 '], time:4.979777ms, swizzle: NOOP, TFLOPS: 110.40(+0.28%) - (mma4x2+warp2x4+stage2+dsmem): ['9.84375 ', '-46.71875 '], time:5.007362ms, swizzle: NOOP, TFLOPS: 109.79 - (mma4x2+warp2x4+stage3+swizzle): ['9.84375 ', '-46.71875 '], time:4.975485ms, swizzle: 2048, TFLOPS: 110.49(+0.09%) - (mma4x2+warp2x4+stage2+swizzle): ['9.84375 ', '-46.71875 '], time:4.935383ms, swizzle: 2048, TFLOPS: 111.39(+0.81%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:4.990983ms, swizzle: 2048, TFLOPS: 110.15 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:4.953265ms, swizzle: 2048, TFLOPS: 110.99 - (cublas): ['9.84375 ', '-46.71875 '], time:4.983496ms, swizzle: NOOP, TFLOPS: 110.32 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=8192, N=8192, K=8192, Warmup=5, Iters=20, 15/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['47.53125 ', '-51.5 '], time:23.78494ms, swizzle: NOOP, TFLOPS: 46.23 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['47.53125 ', '-51.5 '], time:22.85294ms, swizzle: NOOP, TFLOPS: 48.11 (+4.08%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['47.0 ', '-52.25 '], time:24.70605ms, swizzle: NOOP, TFLOPS: 44.50 - (mma4x2+warp2x4): ['47.0 ', '-52.25 '], time:14.21954ms, swizzle: NOOP, TFLOPS: 77.32 (+60.72%) - (mma4x2+warp2x4+stage3): ['47.0 ', '-52.25 '], time:10.34536ms, 
swizzle: NOOP, TFLOPS: 106.28(+37.45%) - (mma4x2+warp2x4+stage2): ['47.0 ', '-52.25 '], time:10.25786ms, swizzle: NOOP, TFLOPS: 107.19(+0.85%) - (mma4x2+warp2x4+stage3+dsmem): ['47.0 ', '-52.25 '], time:10.49890ms, swizzle: NOOP, TFLOPS: 104.73 - (mma4x2+warp2x4+stage2+dsmem): ['47.0 ', '-52.25 '], time:10.29896ms, swizzle: NOOP, TFLOPS: 106.76 - (mma4x2+warp2x4+stage3+swizzle): ['47.0 ', '-52.25 '], time:9.953498ms, swizzle: 2048, TFLOPS: 110.46(+3.06%) - (mma4x2+warp2x4+stage2+swizzle): ['47.0 ', '-52.25 '], time:9.775471ms, swizzle: 2048, TFLOPS: 112.48(+1.82%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['47.0 ', '-52.25 '], time:9.905838ms, swizzle: 2048, TFLOPS: 111.00 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['47.0 ', '-52.25 '], time:9.768342ms, swizzle: 2048, TFLOPS: 112.56(+0.07%) - (cublas): ['47.0 ', '-52.25 '], time:9.739327ms, swizzle: NOOP, TFLOPS: 112.89(+0.30%) ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=8192, N=16384, K=2048, Warmup=5, Iters=20, 16/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['14.765625 ', '-18.640625'], time:10.92975ms, swizzle: NOOP, TFLOPS: 50.30 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['14.765625 ', '-18.640625'], time:10.32440ms, swizzle: NOOP, TFLOPS: 53.25 (+5.86%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['14.28125 ', '-18.6875 '], time:11.78483ms, swizzle: NOOP, TFLOPS: 46.65 - (mma4x2+warp2x4): ['14.28125 ', '-18.6875 '], time:6.915855ms, swizzle: NOOP, TFLOPS: 79.49 (+49.29%) - (mma4x2+warp2x4+stage3): ['14.28125 ', '-18.6875 '], time:5.065703ms, swizzle: NOOP, TFLOPS: 108.53(+36.52%) - (mma4x2+warp2x4+stage2): ['14.28125 ', '-18.6875 '], time:5.030226ms, swizzle: NOOP, TFLOPS: 109.29(+0.71%) - (mma4x2+warp2x4+stage3+dsmem): ['14.28125 ', '-18.6875 '], time:5.033874ms, swizzle: NOOP, TFLOPS: 109.21 - (mma4x2+warp2x4+stage2+dsmem): ['14.28125 ', '-18.6875 '], time:5.028772ms, swizzle: NOOP, TFLOPS: 109.32(+0.03%) - (mma4x2+warp2x4+stage3+swizzle): ['14.28125 ', '-18.6875 '], time:5.024838ms, swizzle: 4096, TFLOPS: 109.41(+0.08%) - (mma4x2+warp2x4+stage2+swizzle): ['14.28125 ', '-18.6875 '], time:4.977488ms, swizzle: 4096, TFLOPS: 110.45(+0.95%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:5.048298ms, swizzle: 4096, TFLOPS: 108.90 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:4.978847ms, swizzle: 4096, TFLOPS: 110.42 - (cublas): ['14.28125 ', '-18.6875 '], time:5.005145ms, swizzle: NOOP, TFLOPS: 109.84 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=8192, N=16384, K=4096, Warmup=5, Iters=20, 17/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['10.296875 ', '-46.6875 '], time:23.64189ms, swizzle: NOOP, TFLOPS: 46.51 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['10.296875 ', '-46.6875 '], time:22.93310ms, swizzle: NOOP, TFLOPS: 47.94 (+3.09%) 
---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['9.84375 ', '-46.71875 '], time:31.19268ms, swizzle: NOOP, TFLOPS: 35.25 - (mma4x2+warp2x4): ['9.84375 ', '-46.71875 '], time:14.21802ms, swizzle: NOOP, TFLOPS: 77.33 (+61.30%) - (mma4x2+warp2x4+stage3): ['9.84375 ', '-46.71875 '], time:10.72919ms, swizzle: NOOP, TFLOPS: 102.48(+32.52%) - (mma4x2+warp2x4+stage2): ['9.84375 ', '-46.71875 '], time:10.53795ms, swizzle: NOOP, TFLOPS: 104.34(+1.81%) - (mma4x2+warp2x4+stage3+dsmem): ['9.84375 ', '-46.71875 '], time:10.60345ms, swizzle: NOOP, TFLOPS: 103.69 - (mma4x2+warp2x4+stage2+dsmem): ['9.84375 ', '-46.71875 '], time:10.46254ms, swizzle: NOOP, TFLOPS: 105.09(+0.72%) - (mma4x2+warp2x4+stage3+swizzle): ['9.84375 ', '-46.71875 '], time:9.963369ms, swizzle: 4096, TFLOPS: 110.36(+5.01%) - (mma4x2+warp2x4+stage2+swizzle): ['9.84375 ', '-46.71875 '], time:9.808254ms, swizzle: 4096, TFLOPS: 112.10(+1.58%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:9.931588ms, swizzle: 4096, TFLOPS: 110.71 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:9.800553ms, swizzle: 4096, TFLOPS: 112.19(+0.08%) - (cublas): ['9.84375 ', '-46.71875 '], time:9.695315ms, swizzle: NOOP, TFLOPS: 113.41(+1.09%) ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=8192, N=16384, K=8192, Warmup=5, Iters=20, 18/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['47.53125 ', '-51.5 '], time:49.22699ms, swizzle: NOOP, TFLOPS: 44.67 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['47.53125 ', '-51.5 '], time:48.20067ms, swizzle: NOOP, TFLOPS: 45.62 (+2.13%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['47.0 ', '-52.25 '], time:61.79366ms, swizzle: NOOP, TFLOPS: 35.59 - (mma4x2+warp2x4): ['47.0 ', '-52.25 '], time:28.10072ms, swizzle: NOOP, TFLOPS: 78.26 (+71.53%) - (mma4x2+warp2x4+stage3): ['47.0 ', '-52.25 '], time:24.79410ms, swizzle: NOOP, TFLOPS: 88.69 (+13.34%) - (mma4x2+warp2x4+stage2): ['47.0 ', '-52.25 '], time:24.75156ms, swizzle: NOOP, TFLOPS: 88.84 (+0.17%) - (mma4x2+warp2x4+stage3+dsmem): ['47.0 ', '-52.25 '], time:24.81336ms, swizzle: NOOP, TFLOPS: 88.62 - (mma4x2+warp2x4+stage2+dsmem): ['47.0 ', '-52.25 '], time:24.72374ms, swizzle: NOOP, TFLOPS: 88.94 (+0.11%) - (mma4x2+warp2x4+stage3+swizzle): ['47.0 ', '-52.25 '], time:19.72281ms, swizzle: 4096, TFLOPS: 111.50(+25.36%) - (mma4x2+warp2x4+stage2+swizzle): ['47.0 ', '-52.25 '], time:19.45309ms, swizzle: 4096, TFLOPS: 113.04(+1.39%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['47.0 ', '-52.25 '], time:19.70534ms, swizzle: 4096, TFLOPS: 111.60 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['47.0 ', '-52.25 '], time:19.46532ms, swizzle: 4096, TFLOPS: 112.97 - (cublas): ['47.0 ', '-52.25 '], time:19.07067ms, swizzle: NOOP, TFLOPS: 115.31(+2.01%) ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=16384, N=4096, 
K=2048, Warmup=5, Iters=20, 19/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['14.765625 ', '-18.640625'], time:5.374526ms, swizzle: NOOP, TFLOPS: 51.14 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['14.765625 ', '-18.640625'], time:4.960513ms, swizzle: NOOP, TFLOPS: 55.41 (+8.35%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['14.28125 ', '-18.6875 '], time:5.899548ms, swizzle: NOOP, TFLOPS: 46.59 - (mma4x2+warp2x4): ['14.28125 ', '-18.6875 '], time:3.571367ms, swizzle: NOOP, TFLOPS: 76.97 (+38.90%) - (mma4x2+warp2x4+stage3): ['14.28125 ', '-18.6875 '], time:2.558302ms, swizzle: NOOP, TFLOPS: 107.45(+39.60%) - (mma4x2+warp2x4+stage2): ['14.28125 ', '-18.6875 '], time:2.541303ms, swizzle: NOOP, TFLOPS: 108.16(+0.67%) - (mma4x2+warp2x4+stage3+dsmem): ['14.28125 ', '-18.6875 '], time:2.538442ms, swizzle: NOOP, TFLOPS: 108.29(+0.11%) - (mma4x2+warp2x4+stage2+dsmem): ['14.28125 ', '-18.6875 '], time:2.542233ms, swizzle: NOOP, TFLOPS: 108.12 - (mma4x2+warp2x4+stage3+swizzle): ['14.28125 ', '-18.6875 '], time:2.538371ms, swizzle: 1024, TFLOPS: 108.29(+0.00%) - (mma4x2+warp2x4+stage2+swizzle): ['14.28125 ', '-18.6875 '], time:2.512073ms, swizzle: 1024, TFLOPS: 109.42(+1.05%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:2.551054ms, swizzle: 1024, TFLOPS: 107.75 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:2.513241ms, swizzle: 1024, TFLOPS: 109.37 - (cublas): ['14.28125 ', '-18.6875 '], time:2.619862ms, swizzle: NOOP, TFLOPS: 104.92 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=16384, N=4096, K=4096, Warmup=5, Iters=20, 20/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['10.296875 ', '-46.6875 '], time:11.02216ms, swizzle: NOOP, TFLOPS: 49.88 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['10.296875 ', '-46.6875 '], time:10.51564ms, swizzle: NOOP, TFLOPS: 52.28 (+4.82%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['9.84375 ', '-46.71875 '], time:11.72006ms, swizzle: NOOP, TFLOPS: 46.91 - (mma4x2+warp2x4): ['9.84375 ', '-46.71875 '], time:6.976056ms, swizzle: NOOP, TFLOPS: 78.81 (+50.74%) - (mma4x2+warp2x4+stage3): ['9.84375 ', '-46.71875 '], time:5.025959ms, swizzle: NOOP, TFLOPS: 109.38(+38.80%) - (mma4x2+warp2x4+stage2): ['9.84375 ', '-46.71875 '], time:4.991459ms, swizzle: NOOP, TFLOPS: 110.14(+0.69%) - (mma4x2+warp2x4+stage3+dsmem): ['9.84375 ', '-46.71875 '], time:4.996275ms, swizzle: NOOP, TFLOPS: 110.03 - (mma4x2+warp2x4+stage2+dsmem): ['9.84375 ', '-46.71875 '], time:4.992103ms, swizzle: NOOP, TFLOPS: 110.13 - (mma4x2+warp2x4+stage3+swizzle): ['9.84375 ', '-46.71875 '], time:4.988074ms, swizzle: 1024, TFLOPS: 110.21(+0.07%) - (mma4x2+warp2x4+stage2+swizzle): ['9.84375 ', '-46.71875 '], time:4.960775ms, swizzle: 1024, TFLOPS: 110.82(+0.55%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:5.020546ms, swizzle: 1024, TFLOPS: 109.50 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['9.84375 ', 
'-46.71875 '], time:4.957389ms, swizzle: 1024, TFLOPS: 110.90(+0.07%) - (cublas): ['9.84375 ', '-46.71875 '], time:4.962539ms, swizzle: NOOP, TFLOPS: 110.78 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=16384, N=4096, K=8192, Warmup=5, Iters=20, 21/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['47.53125 ', '-51.5 '], time:22.25213ms, swizzle: NOOP, TFLOPS: 49.41 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['47.53125 ', '-51.5 '], time:21.25067ms, swizzle: NOOP, TFLOPS: 51.74 (+4.71%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['47.0 ', '-52.25 '], time:23.32034ms, swizzle: NOOP, TFLOPS: 47.15 - (mma4x2+warp2x4): ['47.0 ', '-52.25 '], time:13.78231ms, swizzle: NOOP, TFLOPS: 79.78 (+54.19%) - (mma4x2+warp2x4+stage3): ['47.0 ', '-52.25 '], time:9.944629ms, swizzle: NOOP, TFLOPS: 110.56(+38.59%) - (mma4x2+warp2x4+stage2): ['47.0 ', '-52.25 '], time:9.877133ms, swizzle: NOOP, TFLOPS: 111.32(+0.68%) - (mma4x2+warp2x4+stage3+dsmem): ['47.0 ', '-52.25 '], time:9.891724ms, swizzle: NOOP, TFLOPS: 111.15 - (mma4x2+warp2x4+stage2+dsmem): ['47.0 ', '-52.25 '], time:9.875774ms, swizzle: NOOP, TFLOPS: 111.33(+0.01%) - (mma4x2+warp2x4+stage3+swizzle): ['47.0 ', '-52.25 '], time:9.909319ms, swizzle: 1024, TFLOPS: 110.96 - (mma4x2+warp2x4+stage2+swizzle): ['47.0 ', '-52.25 '], time:9.821128ms, swizzle: 1024, TFLOPS: 111.95(+0.56%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['47.0 ', '-52.25 '], time:10.00571ms, swizzle: 1024, TFLOPS: 109.89 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['47.0 ', '-52.25 '], time:9.818959ms, swizzle: 1024, TFLOPS: 111.98(+0.02%) - (cublas): ['47.0 ', '-52.25 '], time:9.649991ms, swizzle: NOOP, TFLOPS: 113.94(+1.75%) ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=16384, N=8192, K=2048, Warmup=5, Iters=20, 22/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['14.765625 ', '-18.640625'], time:10.99567ms, swizzle: NOOP, TFLOPS: 50.00 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['14.765625 ', '-18.640625'], time:10.49816ms, swizzle: NOOP, TFLOPS: 52.37 (+4.74%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['14.28125 ', '-18.6875 '], time:11.76815ms, swizzle: NOOP, TFLOPS: 46.72 - (mma4x2+warp2x4): ['14.28125 ', '-18.6875 '], time:6.931424ms, swizzle: NOOP, TFLOPS: 79.31 (+51.46%) - (mma4x2+warp2x4+stage3): ['14.28125 ', '-18.6875 '], time:5.055880ms, swizzle: NOOP, TFLOPS: 108.74(+37.10%) - (mma4x2+warp2x4+stage2): ['14.28125 ', '-18.6875 '], time:5.022001ms, swizzle: NOOP, TFLOPS: 109.47(+0.67%) - (mma4x2+warp2x4+stage3+dsmem): ['14.28125 ', '-18.6875 '], time:5.026936ms, swizzle: NOOP, TFLOPS: 109.36 - (mma4x2+warp2x4+stage2+dsmem): ['14.28125 ', '-18.6875 '], time:5.020689ms, swizzle: NOOP, TFLOPS: 109.50(+0.03%) - 
(mma4x2+warp2x4+stage3+swizzle): ['14.28125 ', '-18.6875 '], time:5.018496ms, swizzle: 2048, TFLOPS: 109.55(+0.04%) - (mma4x2+warp2x4+stage2+swizzle): ['14.28125 ', '-18.6875 '], time:4.968738ms, swizzle: 2048, TFLOPS: 110.64(+1.00%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:5.040884ms, swizzle: 2048, TFLOPS: 109.06 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:4.972743ms, swizzle: 2048, TFLOPS: 110.55 - (cublas): ['14.28125 ', '-18.6875 '], time:4.969763ms, swizzle: NOOP, TFLOPS: 110.62 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=16384, N=8192, K=4096, Warmup=5, Iters=20, 23/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['10.296875 ', '-46.6875 '], time:22.06621ms, swizzle: NOOP, TFLOPS: 49.83 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['10.296875 ', '-46.6875 '], time:21.04604ms, swizzle: NOOP, TFLOPS: 52.24 (+4.85%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['9.84375 ', '-46.71875 '], time:23.35319ms, swizzle: NOOP, TFLOPS: 47.08 - (mma4x2+warp2x4): ['9.84375 ', '-46.71875 '], time:13.64452ms, swizzle: NOOP, TFLOPS: 80.58 (+54.25%) - (mma4x2+warp2x4+stage3): ['9.84375 ', '-46.71875 '], time:9.970688ms, swizzle: NOOP, TFLOPS: 110.27(+36.85%) - (mma4x2+warp2x4+stage2): ['9.84375 ', '-46.71875 '], time:9.907126ms, swizzle: NOOP, TFLOPS: 110.98(+0.64%) - (mma4x2+warp2x4+stage3+dsmem): ['9.84375 ', '-46.71875 '], time:9.919929ms, swizzle: NOOP, TFLOPS: 110.84 - (mma4x2+warp2x4+stage2+dsmem): ['9.84375 ', '-46.71875 '], time:9.905028ms, swizzle: NOOP, TFLOPS: 111.01(+0.02%) - (mma4x2+warp2x4+stage3+swizzle): ['9.84375 ', '-46.71875 '], time:9.890699ms, swizzle: 2048, TFLOPS: 111.17(+0.14%) - (mma4x2+warp2x4+stage2+swizzle): ['9.84375 ', '-46.71875 '], time:9.796810ms, swizzle: 2048, TFLOPS: 112.23(+0.96%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:9.931468ms, swizzle: 2048, TFLOPS: 110.71 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:9.799408ms, swizzle: 2048, TFLOPS: 112.20 - (cublas): ['9.84375 ', '-46.71875 '], time:9.670424ms, swizzle: NOOP, TFLOPS: 113.70(+1.31%) ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=16384, N=8192, K=8192, Warmup=5, Iters=20, 24/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['47.53125 ', '-51.5 '], time:47.87569ms, swizzle: NOOP, TFLOPS: 45.93 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['47.53125 ', '-51.5 '], time:46.51024ms, swizzle: NOOP, TFLOPS: 47.28 (+2.94%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['47.0 ', '-52.25 '], time:49.74699ms, swizzle: NOOP, TFLOPS: 44.20 - (mma4x2+warp2x4): ['47.0 ', '-52.25 '], time:28.01880ms, swizzle: NOOP, TFLOPS: 78.48 (+66.00%) - (mma4x2+warp2x4+stage3): 
['47.0 ', '-52.25 '], time:21.40111ms, swizzle: NOOP, TFLOPS: 102.75(+30.92%) - (mma4x2+warp2x4+stage2): ['47.0 ', '-52.25 '], time:20.99103ms, swizzle: NOOP, TFLOPS: 104.76(+1.95%) - (mma4x2+warp2x4+stage3+dsmem): ['47.0 ', '-52.25 '], time:21.22135ms, swizzle: NOOP, TFLOPS: 103.62 - (mma4x2+warp2x4+stage2+dsmem): ['47.0 ', '-52.25 '], time:21.01814ms, swizzle: NOOP, TFLOPS: 104.62 - (mma4x2+warp2x4+stage3+swizzle): ['47.0 ', '-52.25 '], time:19.75324ms, swizzle: 2048, TFLOPS: 111.32(+6.27%) - (mma4x2+warp2x4+stage2+swizzle): ['47.0 ', '-52.25 '], time:19.45850ms, swizzle: 2048, TFLOPS: 113.01(+1.51%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['47.0 ', '-52.25 '], time:19.70596ms, swizzle: 2048, TFLOPS: 111.59 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['47.0 ', '-52.25 '], time:19.45621ms, swizzle: 2048, TFLOPS: 113.02(+0.01%) - (cublas): ['47.0 ', '-52.25 '], time:19.04292ms, swizzle: NOOP, TFLOPS: 115.48(+2.17%) ----------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=16384, N=16384, K=2048, Warmup=5, Iters=20, 25/27 ---------------------------------------------------------------------------------------------------------------------------------- @@ -737,11 +453,129 @@ python3 hgemm.py ---------------------------------------------------------------------------------------------------------------------------------- ``` +### MMA: Up to 115 TFLOPS, 115/119.5=96.23% TFLOPS utilization. + +```bash +python3 hgemm.py --mma +``` + +输出: +```bash +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=4096, K=8192, Warmup=5, Iters=20, 21/27 +---------------------------------------------------------------------------------------------------------------------------------- +--------------------------------------------------------------------MMA----------------------------------------------------------- + (mma2x4+warp4x4): ['21.984375 ', '58.0 '], time:10.29069ms, swizzle: NOOP, TFLOPS: 106.85(+0.00%) + (mma2x4+warp4x4+stage3): ['21.984375 ', '58.0 '], time:9.866333ms, swizzle: NOOP, TFLOPS: 111.44(+4.30%) + (mma2x4+warp4x4+stage2): ['21.984375 ', '58.0 '], time:9.776329ms, swizzle: NOOP, TFLOPS: 112.47(+0.92%) + (mma2x4+warp4x4+stage3+dsmem): ['21.984375 ', '58.0 '], time:9.924983ms, swizzle: NOOP, TFLOPS: 110.78 + (mma2x4+warp4x4+stage2+dsmem): ['21.984375 ', '58.0 '], time:9.772467ms, swizzle: NOOP, TFLOPS: 112.51(+0.04%) + (mma2x4+warp4x4+stage3+swizzle): ['21.984375 ', '58.0 '], time:9.879112ms, swizzle: 1024, TFLOPS: 111.30 + (mma2x4+warp4x4+stage2+swizzle): ['21.984375 ', '58.0 '], time:9.752583ms, swizzle: 1024, TFLOPS: 112.74(+0.20%) + (mma2x4+warp4x4+stage3+dsmem+swizzle): ['21.984375 ', '58.0 '], time:9.922742ms, swizzle: 1024, TFLOPS: 110.81 + (mma2x4+warp4x4+stage2+dsmem+swizzle): ['21.984375 ', '58.0 '], time:9.673309ms, swizzle: 1024, TFLOPS: 113.66(+0.82%) + (cublas): ['21.984375 ', '58.0 '], time:9.443545ms, swizzle: NOOP, TFLOPS: 116.43(+2.43%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=8192, K=2048, Warmup=5, Iters=20, 22/27 
+---------------------------------------------------------------------------------------------------------------------------------- +--------------------------------------------------------------------MMA----------------------------------------------------------- + (mma2x4+warp4x4): ['32.40625 ', '-4.0039062'], time:5.229425ms, swizzle: NOOP, TFLOPS: 105.13(+0.00%) + (mma2x4+warp4x4+stage3): ['32.40625 ', '-4.0039062'], time:5.009818ms, swizzle: NOOP, TFLOPS: 109.74(+4.38%) + (mma2x4+warp4x4+stage2): ['32.40625 ', '-4.0039062'], time:4.968261ms, swizzle: NOOP, TFLOPS: 110.65(+0.84%) + (mma2x4+warp4x4+stage3+dsmem): ['32.40625 ', '-4.0039062'], time:5.031824ms, swizzle: NOOP, TFLOPS: 109.26 + (mma2x4+warp4x4+stage2+dsmem): ['32.40625 ', '-4.0039062'], time:4.965233ms, swizzle: NOOP, TFLOPS: 110.72(+0.06%) + (mma2x4+warp4x4+stage3+swizzle): ['32.40625 ', '-4.0039062'], time:5.021595ms, swizzle: 2048, TFLOPS: 109.48 + (mma2x4+warp4x4+stage2+swizzle): ['32.40625 ', '-4.0039062'], time:4.914212ms, swizzle: 2048, TFLOPS: 111.87(+1.04%) + (mma2x4+warp4x4+stage3+dsmem+swizzle): ['32.40625 ', '-4.0039062'], time:5.039000ms, swizzle: 2048, TFLOPS: 109.10 + (mma2x4+warp4x4+stage2+dsmem+swizzle): ['32.40625 ', '-4.0039062'], time:4.895591ms, swizzle: 2048, TFLOPS: 112.30(+0.38%) + (cublas): ['32.40625 ', '-4.0039062'], time:4.766654ms, swizzle: NOOP, TFLOPS: 115.33(+2.70%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=8192, K=4096, Warmup=5, Iters=20, 23/27 +---------------------------------------------------------------------------------------------------------------------------------- +--------------------------------------------------------------------MMA----------------------------------------------------------- + (mma2x4+warp4x4): ['99.125 ', '22.8125 '], time:10.30406ms, swizzle: NOOP, TFLOPS: 106.71(+0.00%) + (mma2x4+warp4x4+stage3): ['99.125 ', '22.8125 '], time:9.895300ms, swizzle: NOOP, TFLOPS: 111.11(+4.13%) + (mma2x4+warp4x4+stage2): ['99.125 ', '22.8125 '], time:9.813237ms, swizzle: NOOP, TFLOPS: 112.04(+0.84%) + (mma2x4+warp4x4+stage3+dsmem): ['99.125 ', '22.8125 '], time:9.948658ms, swizzle: NOOP, TFLOPS: 110.52 + (mma2x4+warp4x4+stage2+dsmem): ['99.125 ', '22.8125 '], time:9.798026ms, swizzle: NOOP, TFLOPS: 112.22(+0.16%) + (mma2x4+warp4x4+stage3+swizzle): ['99.125 ', '22.8125 '], time:9.914517ms, swizzle: 2048, TFLOPS: 110.90 + (mma2x4+warp4x4+stage2+swizzle): ['99.125 ', '22.8125 '], time:9.733128ms, swizzle: 2048, TFLOPS: 112.97(+0.67%) + (mma2x4+warp4x4+stage3+dsmem+swizzle): ['99.125 ', '22.8125 '], time:9.941744ms, swizzle: 2048, TFLOPS: 110.60 + (mma2x4+warp4x4+stage2+dsmem+swizzle): ['99.125 ', '22.8125 '], time:9.670472ms, swizzle: 2048, TFLOPS: 113.70(+0.65%) + (cublas): ['99.125 ', '22.8125 '], time:9.453558ms, swizzle: NOOP, TFLOPS: 116.31(+2.29%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=8192, K=8192, Warmup=5, Iters=20, 24/27 +---------------------------------------------------------------------------------------------------------------------------------- 
+--------------------------------------------------------------------MMA----------------------------------------------------------- + (mma2x4+warp4x4): ['21.984375 ', '58.0 '], time:21.51823ms, swizzle: NOOP, TFLOPS: 102.19(+0.00%) + (mma2x4+warp4x4+stage3): ['21.984375 ', '58.0 '], time:20.90017ms, swizzle: NOOP, TFLOPS: 105.22(+2.96%) + (mma2x4+warp4x4+stage2): ['21.984375 ', '58.0 '], time:20.75178ms, swizzle: NOOP, TFLOPS: 105.97(+0.72%) + (mma2x4+warp4x4+stage3+dsmem): ['21.984375 ', '58.0 '], time:20.97730ms, swizzle: NOOP, TFLOPS: 104.83 + (mma2x4+warp4x4+stage2+dsmem): ['21.984375 ', '58.0 '], time:20.83809ms, swizzle: NOOP, TFLOPS: 105.53 + (mma2x4+warp4x4+stage3+swizzle): ['21.984375 ', '58.0 '], time:19.78309ms, swizzle: 2048, TFLOPS: 111.16(+4.90%) + (mma2x4+warp4x4+stage2+swizzle): ['21.984375 ', '58.0 '], time:19.33062ms, swizzle: 2048, TFLOPS: 113.76(+2.34%) + (mma2x4+warp4x4+stage3+dsmem+swizzle): ['21.984375 ', '58.0 '], time:19.74017ms, swizzle: 2048, TFLOPS: 111.40 + (mma2x4+warp4x4+stage2+dsmem+swizzle): ['21.984375 ', '58.0 '], time:19.22986ms, swizzle: 2048, TFLOPS: 114.35(+0.52%) + (cublas): ['21.984375 ', '58.0 '], time:18.83535ms, swizzle: NOOP, TFLOPS: 116.75(+2.09%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=16384, K=2048, Warmup=5, Iters=20, 25/27 +---------------------------------------------------------------------------------------------------------------------------------- +--------------------------------------------------------------------MMA----------------------------------------------------------- + (mma2x4+warp4x4): ['32.40625 ', '-4.0039062'], time:10.34352ms, swizzle: NOOP, TFLOPS: 106.30(+0.00%) + (mma2x4+warp4x4+stage3): ['32.40625 ', '-4.0039062'], time:9.953904ms, swizzle: NOOP, TFLOPS: 110.46(+3.91%) + (mma2x4+warp4x4+stage2): ['32.40625 ', '-4.0039062'], time:9.861850ms, swizzle: NOOP, TFLOPS: 111.49(+0.93%) + (mma2x4+warp4x4+stage3+dsmem): ['32.40625 ', '-4.0039062'], time:9.998512ms, swizzle: NOOP, TFLOPS: 109.97 + (mma2x4+warp4x4+stage2+dsmem): ['32.40625 ', '-4.0039062'], time:9.855365ms, swizzle: NOOP, TFLOPS: 111.56(+0.07%) + (mma2x4+warp4x4+stage3+swizzle): ['32.40625 ', '-4.0039062'], time:9.974408ms, swizzle: 4096, TFLOPS: 110.23 + (mma2x4+warp4x4+stage2+swizzle): ['32.40625 ', '-4.0039062'], time:9.743142ms, swizzle: 4096, TFLOPS: 112.85(+1.15%) + (mma2x4+warp4x4+stage3+dsmem+swizzle): ['32.40625 ', '-4.0039062'], time:9.995770ms, swizzle: 4096, TFLOPS: 110.00 + (mma2x4+warp4x4+stage2+dsmem+swizzle): ['32.40625 ', '-4.0039062'], time:9.701442ms, swizzle: 4096, TFLOPS: 113.33(+0.43%) + (cublas): ['32.40625 ', '-4.0039062'], time:9.485888ms, swizzle: NOOP, TFLOPS: 115.91(+2.27%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=16384, K=4096, Warmup=5, Iters=20, 26/27 +---------------------------------------------------------------------------------------------------------------------------------- +--------------------------------------------------------------------MMA----------------------------------------------------------- + (mma2x4+warp4x4): ['99.125 ', '22.8125 '], 
time:22.18379ms, swizzle: NOOP, TFLOPS: 99.13 (+0.00%) + (mma2x4+warp4x4+stage3): ['99.125 ', '22.8125 '], time:21.83485ms, swizzle: NOOP, TFLOPS: 100.71(+1.60%) + (mma2x4+warp4x4+stage2): ['99.125 ', '22.8125 '], time:21.14553ms, swizzle: NOOP, TFLOPS: 103.99(+3.26%) + (mma2x4+warp4x4+stage3+dsmem): ['99.125 ', '22.8125 '], time:21.59111ms, swizzle: NOOP, TFLOPS: 101.85 + (mma2x4+warp4x4+stage2+dsmem): ['99.125 ', '22.8125 '], time:20.96095ms, swizzle: NOOP, TFLOPS: 104.91(+0.88%) + (mma2x4+warp4x4+stage3+swizzle): ['99.125 ', '22.8125 '], time:19.78907ms, swizzle: 4096, TFLOPS: 111.12(+5.92%) + (mma2x4+warp4x4+stage2+swizzle): ['99.125 ', '22.8125 '], time:19.28851ms, swizzle: 4096, TFLOPS: 114.01(+2.60%) + (mma2x4+warp4x4+stage3+dsmem+swizzle): ['99.125 ', '22.8125 '], time:19.74153ms, swizzle: 4096, TFLOPS: 111.39 + (mma2x4+warp4x4+stage2+dsmem+swizzle): ['99.125 ', '22.8125 '], time:19.19734ms, swizzle: 4096, TFLOPS: 114.55(+0.47%) + (cublas): ['99.125 ', '22.8125 '], time:18.88573ms, swizzle: NOOP, TFLOPS: 116.44(+1.65%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=16384, K=8192, Warmup=5, Iters=20, 27/27 +---------------------------------------------------------------------------------------------------------------------------------- +--------------------------------------------------------------------MMA----------------------------------------------------------- + (mma2x4+warp4x4): ['21.984375 ', '58.0 '], time:45.41800ms, swizzle: NOOP, TFLOPS: 96.83 (+0.00%) + (mma2x4+warp4x4+stage3): ['21.984375 ', '58.0 '], time:49.64394ms, swizzle: NOOP, TFLOPS: 88.59 + (mma2x4+warp4x4+stage2): ['21.984375 ', '58.0 '], time:49.82240ms, swizzle: NOOP, TFLOPS: 88.27 + (mma2x4+warp4x4+stage3+dsmem): ['21.984375 ', '58.0 '], time:49.68290ms, swizzle: NOOP, TFLOPS: 88.52 + (mma2x4+warp4x4+stage2+dsmem): ['21.984375 ', '58.0 '], time:49.83477ms, swizzle: NOOP, TFLOPS: 88.25 + (mma2x4+warp4x4+stage3+swizzle): ['21.984375 ', '58.0 '], time:39.11197ms, swizzle: 4096, TFLOPS: 112.45(+16.12%) + (mma2x4+warp4x4+stage2+swizzle): ['21.984375 ', '58.0 '], time:38.40293ms, swizzle: 4096, TFLOPS: 114.52(+1.85%) + (mma2x4+warp4x4+stage3+dsmem+swizzle): ['21.984375 ', '58.0 '], time:39.23041ms, swizzle: 4096, TFLOPS: 112.11 + (mma2x4+warp4x4+stage2+dsmem+swizzle): ['21.984375 ', '58.0 '], time:38.21511ms, swizzle: 4096, TFLOPS: 115.09(+0.49%) + (cublas): ['21.984375 ', '58.0 '], time:37.87384ms, swizzle: NOOP, TFLOPS: 116.12(+0.90%) +---------------------------------------------------------------------------------------------------------------------------------- +``` + + ## NVIDIA GeForce RTX 3080 Laptop
+- WMMA + ```bash -python3 hgemm.py --wmma --no-default +python3 hgemm.py --wmma --wmma-all ``` 输出: ```bash diff --git a/hgemm/hgemm.cu b/hgemm/hgemm.cu index b0179ca0..8845b188 100644 --- a/hgemm/hgemm.cu +++ b/hgemm/hgemm.cu @@ -1011,6 +1011,12 @@ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages(torch::Tensor a, torch::Tensor b void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); +// from hgemm_mma.cu +void hgemm_mma_m16n8k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c); +void hgemm_mma_m16n8k16_mma2x4_warp4x4(torch::Tensor a, torch::Tensor b, torch::Tensor c); +// from hgemm_mma_stage.cu +void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); +void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // CUDA Cores FP16 @@ -1042,5 +1048,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem) TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem) TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem) + // MMA API Tensor Cores + TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_naive) + TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4) + // stage, thread block swizzle, dsmem + TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages) + TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem) } diff --git a/hgemm/hgemm.py b/hgemm/hgemm.py index 597d8220..458d0818 100644 --- a/hgemm/hgemm.py +++ b/hgemm/hgemm.py @@ -14,12 +14,15 @@ def get_args(): parser.add_argument("--K", type=int, default=None, help="Matrix K size") parser.add_argument("--warmup", "--w", type=int, default=5, help="Warmup iters") parser.add_argument("--iters", "--i", type=int, default=20, help="Benchmark iters") - parser.add_argument("--enable-mma-all", "--mma", action="store_true", help="Enable all MMA kernel tests") - parser.add_argument("--enable-wmma-all", "--wmma", action="store_true", help="Enable all WMMA kernel tests") - parser.add_argument("--enable-cuda-all", "--cuda", action="store_true", help="Enable all CUDA kernel tests") + parser.add_argument("--show-all", "--show", action="store_true", help="Show all matrix values ") + parser.add_argument("--enable-mma", "--mma", action="store_true", help="Enable MMA kernel tests") + parser.add_argument("--enable-wmma", "--wmma", action="store_true", help="Enable WMMA kernel tests") + parser.add_argument("--enable-cuda", "--cuda", action="store_true", help="Enable CUDA kernel tests") + parser.add_argument("--enable-mma-all", "--mma-all", action="store_true", help="Enable all MMA kernel tests") + parser.add_argument("--enable-wmma-all", "--wmma-all", action="store_true", help="Enable all WMMA kernel tests") + parser.add_argument("--enable-cuda-all", "--cuda-all", action="store_true", help="Enable all CUDA kernel tests") parser.add_argument("--enable-torch", "--torch", 
action="store_true", help="Enable torch matmul") parser.add_argument("--disable-cublas", "--no-cublas", action="store_true", help="Disable cublas hgemm") - parser.add_argument("--disable-default", "--no-default", action="store_true", help="Disable default tests") return parser.parse_args() args = get_args() @@ -29,7 +32,8 @@ def get_args(): print("Loading hgemm lib ...") lib = load(name='hgemm_lib', sources=['hgemm.cu', 'hgemm_async.cu', 'hgemm_wmma.cu', - 'hgemm_wmma_stage.cu', 'hgemm_cublas.cu'], + 'hgemm_wmma_stage.cu', 'hgemm_cublas.cu', + 'hgemm_mma.cu', 'hgemm_mma_stage.cu'], extra_cuda_cflags=[ "-O3", "-U__CUDA_NO_HALF_OPERATORS__", @@ -40,7 +44,8 @@ def get_args(): "--expt-extended-lambda", "--use_fast_math" ], - extra_cflags=['-std=c++17']) + extra_cflags=['-std=c++17'], + verbose=False) MAX_TFLOPS = -1 @@ -52,7 +57,7 @@ def run_benchmark(perf_func: callable, swizzle_stride: int = 1, warmup: int = args.warmup, iters: int = args.iters, - show_all: bool = False): + show_all: bool = args.show_all): global MAX_TFLOPS M = a.size(0) @@ -163,9 +168,10 @@ def run_benchmark(perf_func: callable, # CUDA Cores FP16 run_benchmark(lib.hgemm_naive_f16, a, b, "(naive)", c) run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "(f16x8pack+t8x8+bcf)", c) - if not args.disable_default: + if args.enable_cuda or args.enable_cuda_all: run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "(f16x8pack+t8x8+dbuf)", c) run_benchmark(lib.hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf, a, b, "(f16x8pack+t8x8+k16+dbuf)", c) + if args.enable_wmma or args.enable_wmma_all: print("-" * 68 + "WMMA" + "-" * 58) # wmma api, stages, dsmem, swizzle run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2, a, b, "(mma4x2)", c) @@ -193,8 +199,19 @@ def run_benchmark(perf_func: callable, run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(mma4x2+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(mma4x2+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) - if args.enable_mma_all: # more mma kernel tests. + if args.enable_mma or args.enable_mma_all: print("-" * 68 + "MMA" + "-" * 59) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4, a, b, "(mma2x4+warp4x4)", c) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage3)", c, stages=3) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage2)", c, stages=2) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage3+dsmem)", c, stages=3) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage2+dsmem)", c, stages=2) + # thread block swizzle + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage3+swizzle)", c, stages=3, swizzle=True) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage2+swizzle)", c, stages=2, swizzle=True) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) + if args.enable_mma_all: # more mma kernel tests. 
pass if not args.disable_cublas: run_benchmark(lib.hgemm_cublas_tensor_op_row_major, a, b, "(cublas)", c) diff --git a/hgemm/hgemm_mma.cu b/hgemm/hgemm_mma.cu index 678c40d2..7a3b5fe8 100644 --- a/hgemm/hgemm_mma.cu +++ b/hgemm/hgemm_mma.cu @@ -28,16 +28,306 @@ using namespace nvcuda; // ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes. #define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) #define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) -// Support A and B matrix with row-major inorder to compare with the kernels using CUDA Cores in -// hgemm.cu and hgemm_async.cu. - +#define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) +#define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) +#define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) +#define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) +#define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) +#define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) +#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1)) HOST_DEVICE_INLINE int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); } -// only 1 warp per block(32 threads), m16n16k16. A, B, C: all row_major. -template -__global__ void hgemm_mma_m16n16k16_naive_kernel(half* A, half* B, half* C, +// only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major. 
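+// Per-lane fragment sizes for mma.sync.aligned.m16n8k16 (f16 in, f16 out):
+// each of the 32 lanes holds 8 halves of A (4 x .b32 regs), 4 halves of B
+// (2 x .b32 regs) and 4 halves of C/D (2 x .b32 regs); ldmatrix fills the
+// A/B registers from shared memory and HMMA16816 accumulates into RC.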
+template +__global__ void hgemm_mma_m16n8k16_naive_kernel(half* A, half* B, half* C, int M, int N, int K) { + const int bx = blockIdx.x; + const int by = blockIdx.y; + const int NUM_K_TILES = div_ceil(K, MMA_K); + constexpr int BM = MMA_M; // 16 + constexpr int BN = MMA_N; // 8 + constexpr int BK = MMA_K; // 16 + + __shared__ half s_a[MMA_M][MMA_K]; // 16x16 + __shared__ half s_b[MMA_K][MMA_N]; // 16x8 + __shared__ half s_c[MMA_M][MMA_N]; // 16x8 + + const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block + const int warp_id = tid / WARP_SIZE; // 0 + const int lane_id = tid % WARP_SIZE; // 0~31 + + // s_a[16][16], 每行16,每线程load 8,需要2线程,共16行,需2x16=32线程 + const int load_smem_a_m = tid / 2; // row 0~15 + const int load_smem_a_k = (tid % 2) * 8; // col 0,8 + // s_b[16][8], 每行8,每线程load 8,需要1线程,共16行,需16线程,只需一半线程加载 + const int load_smem_b_k = tid; // row 0~31, but only use 0~15 + const int load_smem_b_n = 0; // col 0 + const int load_gmem_a_m = by * BM + load_smem_a_m; // global m + const int load_gmem_b_n = bx * BN + load_smem_b_n; // global n + if (load_gmem_a_m >= M && load_gmem_b_n >= N) return; + + uint32_t RC[2] = {0, 0}; + + #pragma unroll + for (int k = 0; k < NUM_K_TILES; ++k) { + // gmem_a -> smem_a + int load_gmem_a_k = k * MMA_K + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = ( + LDST128BITS(A[load_gmem_a_addr])); + + // gmem_b -> smem_b + if (lane_id < MMA_K) { + int load_gmem_b_k = k * MMA_K + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = ( + LDST128BITS(B[load_gmem_b_addr])); + } + __syncthreads(); + + uint32_t RA[4]; + uint32_t RB[2]; + + // ldmatrix for s_a, ldmatrix.trans for s_b. + // s_a: (0,1)*8 -> 0,8 -> [(0~15),(0,8)] + uint32_t load_smem_a_ptr = __cvta_generic_to_shared( + &s_a[lane_id % 16][(lane_id / 16) * 8]); + LDMATRIX_X4(RA[0], RA[1], RA[2], RA[3], load_smem_a_ptr); + uint32_t load_smem_b_ptr = __cvta_generic_to_shared( + &s_b[lane_id % 16][0]); + LDMATRIX_X2_T(RB[0], RB[1], load_smem_b_ptr); + + HMMA16816(RC[0], RC[1], RA[0], RA[1], RA[2], RA[3], RB[0], RB[1], RC[0], RC[1]); + + __syncthreads(); + } + + // s_c[16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16] + LDST32BITS(s_c[lane_id / 4 ][(lane_id % 4) * 2]) = LDST32BITS(RC[0]); + LDST32BITS(s_c[lane_id / 4 + 8][(lane_id % 4) * 2]) = LDST32BITS(RC[1]); + + __syncthreads(); + + // store s_c[16][8] + if (lane_id < MMA_M) { + // store 128 bits per memory issue. 
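+    // Lanes 0~15 each copy one 8-half row (16 bytes = 128 bits) of s_c back
+    // to C, so the whole 16x8 tile is written with 16 vectorized stores.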
+ int store_gmem_c_m = by * BM + lane_id; + int store_gmem_c_n = bx * BN; + int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n; + LDST128BITS(C[store_gmem_c_addr]) = (LDST128BITS(s_c[lane_id][0])); + } +} + +// 128x128, mma2x4, warp4x4(64,32,16) +template +__global__ void __launch_bounds__(256) +hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel( + half* A, half* B, half* C, int M, int N, int K) { + const int bx = blockIdx.x; + const int by = blockIdx.y; + const int NUM_K_TILES = div_ceil(K, MMA_K); + constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128 + constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128 + constexpr int BK = MMA_K; // 16 + + __shared__ half s_a[BM][BK+A_PAD]; // 128*16*2=4KB + __shared__ half s_b[BK][BN+B_PAD]; // 16*128*2=4KB, 16*(128+16)*2=4.5KB + + const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block + const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block + const int lane_id = tid % WARP_SIZE; // 0~31 + const int warp_m = warp_id % 2; // 0,1 + const int warp_n = warp_id / 2; // 0,1,2,3 + + // 先计算shared memory中的索引 + // tid和需要加载的smem s_a[BM][BK] 之间的索引关系 BM=128 BK=16 按行读取 A行主序 + // 对于s_a每行16个数据,每个线程读取8个,需要2个线程;总共128行,需要128x2刚好256线程 + int load_smem_a_m = tid / 2; // row 0~127 + int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8 + // tid和需要加载的smem s_b[BK][BN] 之间的索引关系 BK=16 BN=128 按行读取 B行主序 + // 对于s_b每行128个数据,每个线程读8个数据,需要16个线程;总共16行,需要16x16=256个线程 + int load_smem_b_k = tid / 16; // row 0~15 + int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120 + // 再计算全局内存中的索引 + // 要加载到s_a中的元素对应到A全局内存中的行数 每个block负责出C中大小为BM*BN的块 + int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c + int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c + + uint32_t RC[WARP_TILE_M][WARP_TILE_N][2]; + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + RC[i][j][0] = 0; + RC[i][j][1] = 0; + } + } + + #pragma unroll + for (int k = 0; k < NUM_K_TILES; ++k) { + // gmem -> smem + int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = ( + LDST128BITS(B[load_gmem_b_addr])); + LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = ( + LDST128BITS(A[load_gmem_a_addr])); + __syncthreads(); + + // ldmatrix for s_a, ldmatrix.trans for s_b. 
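+      // Each lane passes the shared-memory address of one 8x8 b16 row to
+      // ldmatrix: .x4 gathers four 8x8 tiles (a full 16x16 A tile) and
+      // .x2.trans gathers a 16x8 B tile, transposing it so the row-major s_b
+      // matches the col-major B operand expected by mma.m16n8k16.row.col.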
+ uint32_t RA[WARP_TILE_M][4]; + uint32_t RB[WARP_TILE_N][2]; + + // smem -> reg + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 + int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_a_ptr = __cvta_generic_to_shared( + &s_a[lane_smem_a_m][lane_smem_a_k]); + LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + int lane_smem_b_k = lane_id % 16; // 0~15 + int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8 + uint32_t lane_smem_b_ptr = __cvta_generic_to_shared( + &s_b[lane_smem_b_k][lane_smem_b_n]); + LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr); + } + + // MMA compute + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + HMMA16816(RC[i][j][0], RC[i][j][1], + RA[i][0], RA[i][1], RA[i][2], RA[i][3], + RB[j][0], RB[j][1], + RC[i][j][0], RC[i][j][1]); + } + } + __syncthreads(); + } + + // reg -> gmem, MMA_MxMMA_N=16x8 + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + // mapping lane smem index -> global index. + // [16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16] + int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4; + int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n + (lane_id % 4) * 2; + int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n; + int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n; + // TODO: how to use LDST128BITS here ? reverse the loop order ? + LDST32BITS(C[store_gmem_c_addr_0]) = LDST32BITS(RC[i][j][0]); + LDST32BITS(C[store_gmem_c_addr_1]) = LDST32BITS(RC[i][j][1]); + } + } +} + + +// --------------------- PyTorch bindings for custom kernel ----------------------- +#define STRINGFY(str) #str +#define TORCH_BINDING_COMMON_EXTENSION(func) \ + m.def(STRINGFY(func), &func, STRINGFY(func)); + +#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ +if(((T).options().dtype() != (th_type))) { \ + std::cout << "Tensor Info:" << (T).options() << std::endl; \ + throw std::runtime_error("values must be "#th_type); \ +} + +#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \ +if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ + throw std::runtime_error("Tensor size mismatch!"); \ +} + +// only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major. 
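+// Launch config: one warp (32 threads) per 16x8 C tile, i.e.
+// grid = (div_ceil(N, MMA_N=8), div_ceil(M, MMA_M=16)), block = 32.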
+void hgemm_mma_m16n8k16_naive( + torch::Tensor a, torch::Tensor b, torch::Tensor c) { + CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) + const int M = a.size(0); + const int K = a.size(1); + const int N = b.size(1); + CHECK_TORCH_TENSOR_SHAPE(a, M, K) + CHECK_TORCH_TENSOR_SHAPE(b, K, N) + CHECK_TORCH_TENSOR_SHAPE(c, M, N) + constexpr int MMA_M = 16; + constexpr int MMA_N = 8; + constexpr int MMA_K = 16; + + dim3 block(WARP_SIZE); + dim3 grid(div_ceil(N, MMA_N), div_ceil(M, MMA_M)); + + hgemm_mma_m16n8k16_naive_kernel< + MMA_M, MMA_N, MMA_K><<>>( + reinterpret_cast(a.data_ptr()), + reinterpret_cast(b.data_ptr()), + reinterpret_cast(c.data_ptr()), + M, N, K + ); +} + +// 128x128, mma2x4, warp4x4(64,32,16) +void hgemm_mma_m16n8k16_mma2x4_warp4x4( + torch::Tensor a, torch::Tensor b, torch::Tensor c) { + CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) + const int M = a.size(0); + const int K = a.size(1); + const int N = b.size(1); + CHECK_TORCH_TENSOR_SHAPE(a, M, K) + CHECK_TORCH_TENSOR_SHAPE(b, K, N) + CHECK_TORCH_TENSOR_SHAPE(c, M, N) + constexpr int MMA_M = 16; + constexpr int MMA_N = 8; + constexpr int MMA_K = 16; + constexpr int MMA_TILE_M = 2; + constexpr int MMA_TILE_N = 4; + constexpr int WARP_TILE_M = 4; + constexpr int WARP_TILE_N = 4; + constexpr int A_PAD = 0; + constexpr int B_PAD = 16; + constexpr int NUM_THREADS= ( + MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 + + dim3 block(NUM_THREADS); + dim3 grid(div_ceil(N, MMA_N * MMA_TILE_N * WARP_TILE_N), + div_ceil(M, MMA_M * MMA_TILE_M * WARP_TILE_M)); + hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel< + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD><<>>( + reinterpret_cast(a.data_ptr()), + reinterpret_cast(b.data_ptr()), + reinterpret_cast(c.data_ptr()), + M, N, K + ); } diff --git a/hgemm/hgemm_mma_stage.cu b/hgemm/hgemm_mma_stage.cu index 8b137891..c065a9b5 100644 --- a/hgemm/hgemm_mma_stage.cu +++ b/hgemm/hgemm_mma_stage.cu @@ -1 +1,777 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace nvcuda; +#define WARP_SIZE 32 +#define DEVICE_INLINE __device__ inline +#define HOST_DEVICE_INLINE __device__ __host__ inline +#define INT4(value) (reinterpret_cast(&(value))[0]) +#define FLOAT4(value) (reinterpret_cast(&(value))[0]) +#define HALF2(value) (reinterpret_cast(&(value))[0]) +#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) +#define LDST32BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST64BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) +#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::) +#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::) +#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n)) +// ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes. 
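+// cp.async.cg copies 16 bytes gmem->smem through L2 only (bypassing L1);
+// commit_group closes the batch of copies issued so far, and wait_group(n)
+// blocks until at most n of the most recently committed groups are still in flight.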
+#define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) +#define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) +#define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) +#define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) +#define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) +#define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) +#define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) +#define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) +#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1)) + +HOST_DEVICE_INLINE +int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); } + +// 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle +template +__global__ void __launch_bounds__(256) +hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_kernel( + half* A, half* B, half* C, int M, int N, int K) { + // BLOCK_SWIZZLE 0/1 control use block swizzle or not. + const int bx = ((int) BLOCK_SWIZZLE) * blockIdx.z * gridDim.x + blockIdx.x; + const int by = blockIdx.y; + const int NUM_K_TILES = div_ceil(K, MMA_K); + constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128 + constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128 + constexpr int BK = MMA_K; // 16 + + __shared__ half s_a[K_STAGE][BM][BK+A_PAD]; // 128*16*2=4KB + __shared__ half s_b[K_STAGE][BK][BN+B_PAD]; // 16*128*2=4KB, 16*(128+16)*2=4.5KB + constexpr int s_a_stage_offset = BM * (BK + A_PAD); + constexpr int s_b_stage_offset = BK * (BN + B_PAD); + + const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block + const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block + const int lane_id = tid % WARP_SIZE; // 0~31 + const int warp_m = warp_id % 2; // 0,1 + const int warp_n = warp_id / 2; // 0,1,2,3 + + // 先计算shared memory中的索引 + // tid和需要加载的smem s_a[BM][BK] 之间的索引关系 BM=128 BK=16 按行读取 A行主序 + // 对于s_a每行16个数据,每个线程读取8个,需要2个线程;总共128行,需要128x2刚好256线程 + int load_smem_a_m = tid / 2; // row 0~127 + int load_smem_a_k = (tid % 2 == 0) ? 
0 : 8; // col 0,8 + // tid和需要加载的smem s_b[BK][BN] 之间的索引关系 BK=16 BN=128 按行读取 B行主序 + // 对于s_b每行128个数据,每个线程读8个数据,需要16个线程;总共16行,需要16x16=256个线程 + int load_smem_b_k = tid / 16; // row 0~15 + int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120 + // 再计算全局内存中的索引 + // 要加载到s_a中的元素对应到A全局内存中的行数 每个block负责出C中大小为BM*BN的块 + int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c + int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c + + uint32_t RC[WARP_TILE_M][WARP_TILE_N][2]; + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + RC[i][j][0] = 0; + RC[i][j][1] = 0; + } + } + + // may avoid cvta overhead ? only cvta smem base ptr once for cp.async. + uint32_t smem_a_base_ptr = __cvta_generic_to_shared(s_a); + uint32_t smem_b_base_ptr = __cvta_generic_to_shared(s_b); + + #pragma unroll + for (int k = 0; k < (K_STAGE - 1); ++k) { // 0, 1 + // k * WMMA_K, WMMA_K=16 -> (k << 4) + int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + + uint32_t load_smem_a_ptr = ( + smem_a_base_ptr + (k * s_a_stage_offset + + load_smem_a_m * (BK + A_PAD) + + load_smem_a_k) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); + + uint32_t load_smem_b_ptr = ( + smem_b_base_ptr + (k * s_b_stage_offset + + load_smem_b_k * (BN + B_PAD) + + load_smem_b_n) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); + + CP_ASYNC_COMMIT_GROUP(); + } + + CP_ASYNC_WAIT_GROUP(K_STAGE-2); // s2->0, s3->1, s4->2 + __syncthreads(); + + #pragma unroll + for (int k = (K_STAGE - 1); k < NUM_K_TILES; ++k) { + // gmem -> smem + // s2/4 can use bitwise ops but s3 can not, so, we use mod + // ops for all stages kernel. s2: (k + 1)&1, s4: (k + 1)&3 + // s3: (k + 1) % 3 + int smem_sel = (k + 1) % K_STAGE; // s3 k 2->0, k 3->1, k 4->2... + int smem_sel_next = k % K_STAGE; // s3 k 2->2, k 3->0, k 4->1... + + int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + + uint32_t load_smem_a_ptr = ( + smem_a_base_ptr + (smem_sel_next * s_a_stage_offset + + load_smem_a_m * (BK + A_PAD) + + load_smem_a_k) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); + + uint32_t load_smem_b_ptr = ( + smem_b_base_ptr + (smem_sel_next * s_b_stage_offset + + load_smem_b_k * (BN + B_PAD) + + load_smem_b_n) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); + CP_ASYNC_COMMIT_GROUP(); + + // ldmatrix for s_a, ldmatrix.trans for s_b. 
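+      // The ldmatrix loads below read stage smem_sel, which the
+      // CP_ASYNC_WAIT_GROUP + __syncthreads executed before this iteration
+      // already made visible; the cp.async just issued into smem_sel_next
+      // is allowed to stay in flight while we compute.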
+ uint32_t RA[WARP_TILE_M][4]; + uint32_t RB[WARP_TILE_N][2]; + + // smem -> reg + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 + int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_a_ptr = __cvta_generic_to_shared( + &s_a[smem_sel][lane_smem_a_m][lane_smem_a_k]); + LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + int lane_smem_b_k = lane_id % 16; // 0~15 + int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8 + uint32_t lane_smem_b_ptr = __cvta_generic_to_shared( + &s_b[smem_sel][lane_smem_b_k][lane_smem_b_n]); + LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr); + } + + // MMA compute + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + HMMA16816(RC[i][j][0], RC[i][j][1], + RA[i][0], RA[i][1], RA[i][2], RA[i][3], + RB[j][0], RB[j][1], + RC[i][j][0], RC[i][j][1]); + } + } + + CP_ASYNC_WAIT_GROUP(K_STAGE-2); + __syncthreads(); + } + + // make sure all memory issues ready. + if ((K_STAGE - 2) > 0) { + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + + // processing last (K_STAGE-1) k iters. + { + #pragma unroll + for (int k = 0; k < (K_STAGE - 1); k++) { + int stage_sel = ((NUM_K_TILES - (K_STAGE - 1) + k) % K_STAGE); + // ldmatrix for s_a, ldmatrix.trans for s_b. + uint32_t RA[WARP_TILE_M][4]; + uint32_t RB[WARP_TILE_N][2]; + + // smem -> reg + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 + int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_a_ptr = __cvta_generic_to_shared( + &s_a[stage_sel][lane_smem_a_m][lane_smem_a_k]); + LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + int lane_smem_b_k = lane_id % 16; // 0~15 + int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8 + uint32_t lane_smem_b_ptr = __cvta_generic_to_shared( + &s_b[stage_sel][lane_smem_b_k][lane_smem_b_n]); + LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr); + } + + // MMA compute + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + HMMA16816(RC[i][j][0], RC[i][j][1], + RA[i][0], RA[i][1], RA[i][2], RA[i][3], + RB[j][0], RB[j][1], + RC[i][j][0], RC[i][j][1]); + } + } + } + } + + // reg -> gmem, MMA_MxMMA_N=16x8 + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + // mapping lane smem index -> global index. 
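+      // Within each 16x8 C tile, lane_id/4 selects the row (0~7 for RC[..][0],
+      // +8 for RC[..][1]) and (lane_id%4)*2 selects a pair of adjacent columns,
+      // so every lane stores two halves (32 bits) per fragment register.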
+ // [16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16] + int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4; + int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n + (lane_id % 4) * 2; + int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n; + int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n; + // TODO: how to use LDST128BITS here ? + LDST32BITS(C[store_gmem_c_addr_0]) = LDST32BITS(RC[i][j][0]); + LDST32BITS(C[store_gmem_c_addr_1]) = LDST32BITS(RC[i][j][1]); + } + } +} + +// 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle, dsmem +template +__global__ void __launch_bounds__(256) +hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_kernel( + half* A, half* B, half* C, int M, int N, int K) { + // BLOCK_SWIZZLE 0/1 control use block swizzle or not. + const int bx = ((int) BLOCK_SWIZZLE) * blockIdx.z * gridDim.x + blockIdx.x; + const int by = blockIdx.y; + const int NUM_K_TILES = div_ceil(K, MMA_K); + constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128 + constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128 + constexpr int BK = MMA_K; // 16 + + extern __shared__ half smem[]; + half* s_a = smem; + half* s_b = smem + K_STAGE * BM * (BK + A_PAD); + constexpr int s_a_stage_offset = BM * (BK + A_PAD); + constexpr int s_b_stage_offset = BK * (BN + B_PAD); + + const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block + const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block + const int lane_id = tid % WARP_SIZE; // 0~31 + const int warp_m = warp_id % 2; // 0,1 + const int warp_n = warp_id / 2; // 0,1,2,3 + + // 先计算shared memory中的索引 + // tid和需要加载的smem s_a[BM][BK] 之间的索引关系 BM=128 BK=16 按行读取 A行主序 + // 对于s_a每行16个数据,每个线程读取8个,需要2个线程;总共128行,需要128x2刚好256线程 + int load_smem_a_m = tid / 2; // row 0~127 + int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8 + // tid和需要加载的smem s_b[BK][BN] 之间的索引关系 BK=16 BN=128 按行读取 B行主序 + // 对于s_b每行128个数据,每个线程读8个数据,需要16个线程;总共16行,需要16x16=256个线程 + int load_smem_b_k = tid / 16; // row 0~15 + int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120 + // 再计算全局内存中的索引 + // 要加载到s_a中的元素对应到A全局内存中的行数 每个block负责出C中大小为BM*BN的块 + int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c + int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c + + uint32_t RC[WARP_TILE_M][WARP_TILE_N][2]; + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + RC[i][j][0] = 0; + RC[i][j][1] = 0; + } + } + + // may avoid cvta overhead ? only cvta smem base ptr once for cp.async. 
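+  // __cvta_generic_to_shared yields the 32-bit shared-memory address that
+  // cp.async and ldmatrix expect; converting only the two base pointers here
+  // lets the loops below form per-stage addresses with plain integer offsets.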
+ uint32_t smem_a_base_ptr = __cvta_generic_to_shared(s_a); + uint32_t smem_b_base_ptr = __cvta_generic_to_shared(s_b); + + #pragma unroll + for (int k = 0; k < (K_STAGE - 1); ++k) { // 0, 1 + // k * WMMA_K, WMMA_K=16 -> (k << 4) + int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + + uint32_t load_smem_a_ptr = ( + smem_a_base_ptr + (k * s_a_stage_offset + + load_smem_a_m * (BK + A_PAD) + + load_smem_a_k) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); + + uint32_t load_smem_b_ptr = ( + smem_b_base_ptr + (k * s_b_stage_offset + + load_smem_b_k * (BN + B_PAD) + + load_smem_b_n) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); + + CP_ASYNC_COMMIT_GROUP(); + } + + CP_ASYNC_WAIT_GROUP(K_STAGE-2); // s2->0, s3->1, s4->2 + __syncthreads(); + + #pragma unroll + for (int k = (K_STAGE - 1); k < NUM_K_TILES; ++k) { + // gmem -> smem + // s2/4 can use bitwise ops but s3 can not, so, we use mod + // ops for all stages kernel. s2: (k + 1)&1, s4: (k + 1)&3 + // s3: (k + 1) % 3 + int smem_sel = (k + 1) % K_STAGE; // s3 k 2->0, k 3->1, k 4->2... + int smem_sel_next = k % K_STAGE; // s3 k 2->2, k 3->0, k 4->1... + + int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + + uint32_t load_smem_a_ptr = ( + smem_a_base_ptr + (smem_sel_next * s_a_stage_offset + + load_smem_a_m * (BK + A_PAD) + + load_smem_a_k) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); + + uint32_t load_smem_b_ptr = ( + smem_b_base_ptr + (smem_sel_next * s_b_stage_offset + + load_smem_b_k * (BN + B_PAD) + + load_smem_b_n) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); + CP_ASYNC_COMMIT_GROUP(); + + // ldmatrix for s_a, ldmatrix.trans for s_b. + uint32_t RA[WARP_TILE_M][4]; + uint32_t RB[WARP_TILE_N][2]; + + // smem -> reg + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 + int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_a_ptr = ( + smem_a_base_ptr + (smem_sel * s_a_stage_offset + + lane_smem_a_m * (BK + A_PAD) + + lane_smem_a_k) * sizeof(half) + ); + LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + int lane_smem_b_k = lane_id % 16; // 0~15 + int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8 + uint32_t lane_smem_b_ptr = ( + smem_b_base_ptr + (smem_sel * s_b_stage_offset + + lane_smem_b_k * (BN + B_PAD) + + lane_smem_b_n) * sizeof(half) + ); + LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr); + } + + // MMA compute + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + HMMA16816(RC[i][j][0], RC[i][j][1], + RA[i][0], RA[i][1], RA[i][2], RA[i][3], + RB[j][0], RB[j][1], + RC[i][j][0], RC[i][j][1]); + } + } + + CP_ASYNC_WAIT_GROUP(K_STAGE-2); + __syncthreads(); + } + + // make sure all memory issues ready. 
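+  // For K_STAGE > 2 the main loop can exit with cp.async groups still pending,
+  // so drain them all (wait_group 0) before consuming the last stages.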
+ if ((K_STAGE - 2) > 0) { + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + + // processing last (K_STAGE-1) k iters. + { + #pragma unroll + for (int k = 0; k < (K_STAGE - 1); k++) { + int stage_sel = ((NUM_K_TILES - (K_STAGE - 1) + k) % K_STAGE); + // ldmatrix for s_a, ldmatrix.trans for s_b. + uint32_t RA[WARP_TILE_M][4]; + uint32_t RB[WARP_TILE_N][2]; + + // smem -> reg + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 + int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_a_ptr = ( + smem_a_base_ptr + (stage_sel * s_a_stage_offset + + lane_smem_a_m * (BK + A_PAD) + + lane_smem_a_k) * sizeof(half) + ); + LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + int lane_smem_b_k = lane_id % 16; // 0~15 + int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8 + uint32_t lane_smem_b_ptr = ( + smem_b_base_ptr + (stage_sel * s_b_stage_offset + + lane_smem_b_k * (BN + B_PAD) + + lane_smem_b_n) * sizeof(half) + ); + LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr); + } + + // MMA compute + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + HMMA16816(RC[i][j][0], RC[i][j][1], + RA[i][0], RA[i][1], RA[i][2], RA[i][3], + RB[j][0], RB[j][1], + RC[i][j][0], RC[i][j][1]); + } + } + } + } + + // reg -> gmem, MMA_MxMMA_N=16x8 + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + // mapping lane smem index -> global index. + // [16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16] + int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4; + int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n + (lane_id % 4) * 2; + int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n; + int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n; + // TODO: how to use LDST128BITS here ? + LDST32BITS(C[store_gmem_c_addr_0]) = LDST32BITS(RC[i][j][0]); + LDST32BITS(C[store_gmem_c_addr_1]) = LDST32BITS(RC[i][j][1]); + } + } +} + +// TODO: Warp swizzle/permute support ? 
(MMA, not WMMA) + +// --------------------- PyTorch bindings for custom kernel ----------------------- +#define STRINGFY(str) #str +#define TORCH_BINDING_COMMON_EXTENSION(func) \ + m.def(STRINGFY(func), &func, STRINGFY(func)); + +#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ +if(((T).options().dtype() != (th_type))) { \ + std::cout << "Tensor Info:" << (T).options() << std::endl; \ + throw std::runtime_error("values must be "#th_type); \ +} + +#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \ +if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ + throw std::runtime_error("Tensor size mismatch!"); \ +} + +// 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle +#define LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_KERNEL(stages, stride) \ +{ \ + const int N_SWIZZLE = (N + (stride) - 1) / (stride); \ + dim3 block(NUM_THREADS); \ + dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \ + div_ceil(M, BM), \ + N_SWIZZLE); \ + hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_kernel< \ + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, \ + (stages), true><<>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +#define LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_KERNEL(stages) \ +{ \ + dim3 block(NUM_THREADS); \ + dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \ + hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_kernel< \ + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, \ + (stages), false><<>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +// 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle +void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages( + torch::Tensor a, torch::Tensor b, torch::Tensor c, + int stages, bool swizzle, int swizzle_stride) { + CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) + const int M = a.size(0); + const int K = a.size(1); + const int N = b.size(1); + CHECK_TORCH_TENSOR_SHAPE(a, M, K) + CHECK_TORCH_TENSOR_SHAPE(b, K, N) + CHECK_TORCH_TENSOR_SHAPE(c, M, N) + constexpr int MMA_M = 16; + constexpr int MMA_N = 8; + constexpr int MMA_K = 16; + constexpr int MMA_TILE_M = 2; + constexpr int MMA_TILE_N = 4; + constexpr int WARP_TILE_M = 4; + constexpr int WARP_TILE_N = 4; + // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts. + // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts. + // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts. + // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus, + // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD. 
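+  // Cost check: s_a has BM=128 padded rows per stage vs only BK=16 for s_b,
+  // so one element of A_PAD costs 8x more smem than one element of B_PAD
+  // (128*A_PAD*2B vs 16*B_PAD*2B per stage); hence A_PAD=0, B_PAD=16 below.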
+ constexpr int A_PAD = 0; // 0,8,16 + constexpr int B_PAD = 16; // 0,8,16 + constexpr int NUM_THREADS= ( + MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 + constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; + constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; + constexpr int BK = MMA_K; + // s2: 2*128*(16)*2=8KB, 2*16*(128+16)*2=9KB, ~17KB + // s3: 3*128*(16)*2=12KB, 3*16*(128+16)*2=13.5KB, ~26KB + // s4: 4*128*(16)*2=16KB, 4*16*(128+16)*2=18KB, ~34KB + // s5: 5*128*(16)*2=20KB, 5*16*(128+16)*2=22.5KB, ~43KB + if (swizzle) { + assert(swizzle_stride % 256 == 0); + switch (stages) + { + case 2: // ~17KB + LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_KERNEL(2, swizzle_stride); + break; + case 3: // ~26KB + LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_KERNEL(3, swizzle_stride); + break; + case 4: // ~34KB + LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_KERNEL(4, swizzle_stride); + break; + case 5: // ~43KB + LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_KERNEL(5, swizzle_stride); + break; + default: + LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_KERNEL(2, swizzle_stride); + break; + } + } else { + switch (stages) + { + case 2: + LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_KERNEL(2); + break; + case 3: + LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_KERNEL(3); + break; + case 4: + LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_KERNEL(4); + break; + case 5: + LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_KERNEL(5); + break; + default: + LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_KERNEL(2); + break; + } + } +} + +// 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle, dsmem +#define LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_KERNEL(stages, stride) \ +{ \ + const int smem_max_size = ( \ + (stages) * BM * (BK + A_PAD) * sizeof(half) + \ + (stages) * BK * (BN + B_PAD) * sizeof(half)); \ + cudaFuncSetAttribute( \ + hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_kernel< \ + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true>, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + 98304); \ + const int N_SWIZZLE = (N + (stride) - 1) / (stride); \ + dim3 block(NUM_THREADS); \ + dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \ + div_ceil(M, BM), \ + N_SWIZZLE); \ + hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_kernel< \ + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), true><<< \ + grid, block, smem_max_size>>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +#define LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_KERNEL(stages) \ +{ \ + const int smem_max_size = ( \ + (stages) * BM * (BK + A_PAD) * sizeof(half) + \ + (stages) * BK * (BN + B_PAD) * sizeof(half)); \ + cudaFuncSetAttribute( \ + hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_kernel< \ + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), false>, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + 98304); \ + dim3 block(NUM_THREADS); \ + dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \ + hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_kernel< \ + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD, (stages), false><<< \ + grid, block, smem_max_size>>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +// 128x128, mma2x4, warp4x4(64,32,16), stages, block swizzle, dsmem 
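+// The dsmem variants allocate s_a/s_b as dynamic shared memory sized by K_STAGE;
+// cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, 98304)
+// raises the opt-in cap to 96 KB so the extern __shared__ buffer may exceed the
+// default 48 KB per-block limit (the s2~s5 configs, ~17-43 KB, fit either way).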
+void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem( + torch::Tensor a, torch::Tensor b, torch::Tensor c, + int stages, bool swizzle, int swizzle_stride) { + CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) + const int M = a.size(0); + const int K = a.size(1); + const int N = b.size(1); + CHECK_TORCH_TENSOR_SHAPE(a, M, K) + CHECK_TORCH_TENSOR_SHAPE(b, K, N) + CHECK_TORCH_TENSOR_SHAPE(c, M, N) + constexpr int MMA_M = 16; + constexpr int MMA_N = 8; + constexpr int MMA_K = 16; + constexpr int MMA_TILE_M = 2; + constexpr int MMA_TILE_N = 4; + constexpr int WARP_TILE_M = 4; + constexpr int WARP_TILE_N = 4; + // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts. + // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts. + // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts. + // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus, + // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD. + constexpr int A_PAD = 0; // 0,8,16 + constexpr int B_PAD = 16; // 0,8,16 + constexpr int NUM_THREADS= ( + MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 + constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; + constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; + constexpr int BK = MMA_K; + // s2: 2*128*(16)*2=8KB, 2*16*(128+16)*2=9KB, ~17KB + // s3: 3*128*(16)*2=12KB, 3*16*(128+16)*2=13.5KB, ~26KB + // s4: 4*128*(16)*2=16KB, 4*16*(128+16)*2=18KB, ~34KB + // s5: 5*128*(16)*2=20KB, 5*16*(128+16)*2=22.5KB, ~43KB + if (swizzle) { + assert(swizzle_stride % 256 == 0); + switch (stages) + { + case 2: // ~17KB + LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_KERNEL(2, swizzle_stride); + break; + case 3: // ~26KB + LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_KERNEL(3, swizzle_stride); + break; + case 4: // ~34KB + LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_KERNEL(4, swizzle_stride); + break; + case 5: // ~43KB + LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_KERNEL(5, swizzle_stride); + break; + default: + LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4_DSMEM_KERNEL(2, swizzle_stride); + break; + } + } else { + switch (stages) + { + case 2: + LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_KERNEL(2); + break; + case 3: + LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_KERNEL(3); + break; + case 4: + LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_KERNEL(4); + break; + case 5: + LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_KERNEL(5); + break; + default: + LAUNCH_16816_STAGE_NO_SWIZZLE_MMA2x4_WARP4x4_DSMEM_KERNEL(2); + break; + } + } +}
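For readers who want to see the thread block swizzle in isolation: the standalone sketch below mirrors the grid shape and index math used by the LAUNCH_16816_STAGE_SWIZZLE macros and the `bx` computation in the stages kernels. It is illustrative only and not part of the patch; `swizzle_demo_kernel`, the M/N/BM/BN values and `swizzle_stride = 1024` are assumptions chosen for the example (the host wrappers above assert that the stride is a multiple of 256).

```cuda
#include <cstdio>

// Toy kernel: prints which C tile each block owns after the swizzle remap.
// Blocks sharing blockIdx.z stay within one band of ~swizzle_stride columns,
// which is what improves L2 reuse of B in the real stages kernels.
__global__ void swizzle_demo_kernel(int M, int N, int BM, int BN) {
  const int bx = blockIdx.z * gridDim.x + blockIdx.x;  // same remap as BLOCK_SWIZZLE=true
  const int by = blockIdx.y;
  if (threadIdx.x == 0 && bx * BN < N && by * BM < M) {
    printf("z=%d -> C tile (rows %d.., cols %d..)\n", blockIdx.z, by * BM, bx * BN);
  }
}

int main() {
  const int M = 1024, N = 4096, BM = 128, BN = 128;   // illustrative sizes
  const int swizzle_stride = 1024;                    // assumed; multiple of 256
  const int n_swizzle = (N + swizzle_stride - 1) / swizzle_stride;
  const int n_tiles   = (N + BN - 1) / BN;
  dim3 block(32);
  dim3 grid((n_tiles + n_swizzle - 1) / n_swizzle,    // column tiles per band
            (M + BM - 1) / BM,                        // row tiles
            n_swizzle);                               // number of column bands
  swizzle_demo_kernel<<<grid, block>>>(M, N, BM, BN);
  cudaDeviceSynchronize();
  return 0;
}
```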