From a2934b9bebf447ce3c7c7c0ba9c3274882ddd63c Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Mon, 21 Oct 2024 20:41:47 +0800 Subject: [PATCH] [HGEMM] Add MMA 16816 swizzle, Up to 115 TFLOPS (#98) * Update hgemm_mma.cu * Update README.md * Update hgemm.py * Update hgemm.cu * Update hgemm_mma.cu * Update hgemm.cu * Update hgemm.py * Update README.md * Update hgemm_mma.cu * Update hgemm.py * Update hgemm.cu * Update hgemm_mma.cu * Update README.md * Update hgemm.py * Update README.md * Update README.md * Update hgemm_mma_stage.cu * Update hgemm.py * Update hgemm.cu * Update README.md * Update README.md * Update hgemm_mma_stage.cu * Update hgemm_mma_stage.cu --- README.md | 6 +- hgemm/README.md | 430 +++++++--------------- hgemm/hgemm.cu | 12 + hgemm/hgemm.py | 35 +- hgemm/hgemm_mma.cu | 302 ++++++++++++++- hgemm/hgemm_mma_stage.cu | 776 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 1247 insertions(+), 314 deletions(-) diff --git a/README.md b/README.md index 9c1b88ee..5363c748 100644 --- a/README.md +++ b/README.md @@ -147,6 +147,10 @@ | ✔️ [hgemm_wmma_m32n8k16....dbuf*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️| | ✔️ [hgemm_wmma_m16n16k16...stages*](./hgemm/hgemm_wmma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️| | ✔️ [hgemm_wmma_m16n16k16...swizzle*](./hgemm/hgemm_wmma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️| +| ✔️ [hgemm_mma_m16n8k16...naive*](./hgemm/hgemm_mma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️| +| ✔️ [hgemm_mma_m16n8k16...mma2x4*](./hgemm/hgemm_mma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️| +| ✔️ [hgemm_mma_m16n8k16...stages*](./hgemm/hgemm_mma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️| +| ✔️ [hgemm_mma_m16n8k16...swizzle*](./hgemm/hgemm_mma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️| | ✔️ [sgemv_k32_f32](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️| | ✔️ [sgemv_k128_f32x4](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️| | ✔️ [sgemv_k16_f32](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️| @@ -158,7 +162,7 @@ | ✔️ [hard_nms cpp only](./nms/nms.cc)|f32|/|/|⭐️| | ✔️ [notes v1(deprecated)](./notes-v1.cu)|f32|f32|/|⭐️| -👉TIPS: * means using **Tensor Cores(MMA PTX)**, otherwise, using CUDA Cores by default. +👉TIPS: * means using **Tensor Cores(MMA/WMMA)**, otherwise, using CUDA Cores by default. ## 0x01 📖 博客目录 diff --git a/hgemm/README.md b/hgemm/README.md index 3f7b0fe6..d2f82368 100755 --- a/hgemm/README.md +++ b/hgemm/README.md @@ -23,13 +23,16 @@ - [X] hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages(WMMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle) - [X] hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages(WMMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle) - [X] hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async(WMMA, Tile MMA/Warp, Copy Async, Double Buffers, Pad) +- [X] hgemm_mma_m16n8k16_naive(MMA) +- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4(MMA, Tile MMA/Warp, pack) +- [X] hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(MMA, Tile MMA/Warp, Copy Async, Stages, Pad, Block swizzle) - [X] PyTorch bindings ## 目前性能 - NVIDIA L20 -目前最优的实现,在L20上(理论Tensor Cores FP16算力为 119.5 TFLOPS),能达到cuBLAS大概95%~98%左右的性能(105-113 TFLOPS vs 105-115 TFLOPS),部分case会超越cuBLAS。已知问题为bank conflicts没有完全消除,目前通过padding的方式缓解bank conflicts会导致shared memory浪费,也会影响SM occupancy。并且尚未手工实现Warp swizzle(受限于WMMA API的灵活性以及本人的能力),后续将会尝试通过MMA PTX实现warp swizzle,[点击查看性能数据](#NV-L20)。 +目前最优的实现,在L20上(理论Tensor Cores FP16算力为 119.5 TFLOPS),使用WMMA API能达到cuBLAS大概95%~98%左右的性能(105-113 TFLOPS vs 105-115 TFLOPS),使用MMA API能达到115 TFLOPS,部分case会超越cuBLAS。已知问题为bank conflicts没有完全消除,目前通过padding的方式缓解bank conflicts会导致shared memory浪费,也会影响SM occupancy。并且尚未手工实现Warp swizzle(受限于WMMA API的灵活性以及本人的能力),后续将会尝试通过MMA PTX实现warp swizzle,[点击查看性能数据](#NV-L20)。 - NVIDIA GeForce RTX 3080 Laptop @@ -232,20 +235,23 @@ nsys profile --stats=true -t cuda,osrt,nvtx -o hgemm.prof --force-overwrite true ```bash # 只测试Ada架构 不指定默认编译所有架构 耗时较长: Volta, Ampere, Ada, Hopper, ... export TORCH_CUDA_ARCH_LIST=Ada -python3 hgemm.py # default, test some wmma kernels for all MNK -python3 hgemm.py --wmma # test all wmma kernels for all MNK -python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --wmma # test all wmma kernels for specific MNK -python3 hgemm.py --wmma --no-default # test all wmma kernels, but exclude the default part. +python3 hgemm.py --wmma # test defalut wmma kernels for all MNK +python3 hgemm.py --mma # test defalut mma kernels for all MNK +python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --wmma # test default wmma kernels for specific MNK +python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --mma # test default mma kernels for specific MNK +python3 hgemm.py --wmma --wmma-all # test all wmma kernels for all MNK +python3 hgemm.py --mma --mma-all # test all mma kernels for all MNK ``` ## NVIDIA L20
-Up to 113.76 TFLOPS, 113.76/119.5=95.19% TFLOPS utilization. + +### WMMA & CUDA: Up to 113.76 TFLOPS, 113.76/119.5=95.19% TFLOPS utilization. ```bash -python3 hgemm.py +python3 hgemm.py --cuda --wmma ``` 输出: ```bash @@ -338,79 +344,6 @@ python3 hgemm.py (mma4x2+warp2x4+stage3+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:2.573418ms, swizzle: 2048, TFLOPS: 106.81 (mma4x2+warp2x4+stage2+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:2.533483ms, swizzle: 2048, TFLOPS: 108.50 (cublas): ['9.84375 ', '-46.71875 '], time:2.661132ms, swizzle: NOOP, TFLOPS: 103.29 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=4096, N=8192, K=8192, Warmup=5, Iters=20, 6/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['47.53125 ', '-51.5 '], time:11.79177ms, swizzle: NOOP, TFLOPS: 46.62 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['47.53125 ', '-51.5 '], time:11.25807ms, swizzle: NOOP, TFLOPS: 48.83 (+4.74%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['47.0 ', '-52.25 '], time:12.28225ms, swizzle: NOOP, TFLOPS: 44.76 - (mma4x2+warp2x4): ['47.0 ', '-52.25 '], time:7.306694ms, swizzle: NOOP, TFLOPS: 75.24 (+54.08%) - (mma4x2+warp2x4+stage3): ['47.0 ', '-52.25 '], time:5.185413ms, swizzle: NOOP, TFLOPS: 106.02(+40.91%) - (mma4x2+warp2x4+stage2): ['47.0 ', '-52.25 '], time:5.128622ms, swizzle: NOOP, TFLOPS: 107.19(+1.11%) - (mma4x2+warp2x4+stage3+dsmem): ['47.0 ', '-52.25 '], time:5.165719ms, swizzle: NOOP, TFLOPS: 106.42 - (mma4x2+warp2x4+stage2+dsmem): ['47.0 ', '-52.25 '], time:5.137014ms, swizzle: NOOP, TFLOPS: 107.02 - (mma4x2+warp2x4+stage3+swizzle): ['47.0 ', '-52.25 '], time:5.096411ms, swizzle: 2048, TFLOPS: 107.87(+0.63%) - (mma4x2+warp2x4+stage2+swizzle): ['47.0 ', '-52.25 '], time:5.036878ms, swizzle: 2048, TFLOPS: 109.15(+1.18%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['47.0 ', '-52.25 '], time:5.087852ms, swizzle: 2048, TFLOPS: 108.05 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['47.0 ', '-52.25 '], time:5.011391ms, swizzle: 2048, TFLOPS: 109.70(+0.51%) - (cublas): ['47.0 ', '-52.25 '], time:5.063843ms, swizzle: NOOP, TFLOPS: 108.56 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=4096, N=16384, K=2048, Warmup=5, Iters=20, 7/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['14.765625 ', '-18.640625'], time:5.306124ms, swizzle: NOOP, TFLOPS: 51.80 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['14.765625 ', '-18.640625'], time:5.044364ms, swizzle: NOOP, TFLOPS: 54.49 (+5.19%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['14.28125 ', '-18.6875 '], time:5.916452ms, swizzle: NOOP, TFLOPS: 46.46 - (mma4x2+warp2x4): ['14.28125 ', '-18.6875 '], time:3.550720ms, swizzle: NOOP, TFLOPS: 77.41 (+42.07%) - (mma4x2+warp2x4+stage3): ['14.28125 ', '-18.6875 '], time:2.552175ms, swizzle: NOOP, TFLOPS: 107.70(+39.13%) - (mma4x2+warp2x4+stage2): ['14.28125 ', '-18.6875 '], time:2.537274ms, swizzle: NOOP, TFLOPS: 108.34(+0.59%) - (mma4x2+warp2x4+stage3+dsmem): ['14.28125 ', '-18.6875 '], time:2.545833ms, swizzle: NOOP, TFLOPS: 107.97 - (mma4x2+warp2x4+stage2+dsmem): ['14.28125 ', '-18.6875 '], time:2.546501ms, swizzle: NOOP, TFLOPS: 107.94 - (mma4x2+warp2x4+stage3+swizzle): ['14.28125 ', '-18.6875 '], time:2.544927ms, swizzle: 4096, TFLOPS: 108.01 - (mma4x2+warp2x4+stage2+swizzle): ['14.28125 ', '-18.6875 '], time:2.518939ms, swizzle: 4096, TFLOPS: 109.12(+0.73%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:2.547931ms, swizzle: 4096, TFLOPS: 107.88 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:2.512478ms, swizzle: 4096, TFLOPS: 109.41(+0.26%) - (cublas): ['14.28125 ', '-18.6875 '], time:2.635645ms, swizzle: NOOP, TFLOPS: 104.29 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=4096, N=16384, K=4096, Warmup=5, Iters=20, 8/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['10.296875 ', '-46.6875 '], time:11.61146ms, swizzle: NOOP, TFLOPS: 47.35 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['10.296875 ', '-46.6875 '], time:11.02995ms, swizzle: NOOP, TFLOPS: 49.84 (+5.27%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['9.84375 ', '-46.71875 '], time:15.55149ms, swizzle: NOOP, TFLOPS: 35.35 - (mma4x2+warp2x4): ['9.84375 ', '-46.71875 '], time:7.264566ms, swizzle: NOOP, TFLOPS: 75.68 (+51.83%) - (mma4x2+warp2x4+stage3): ['9.84375 ', '-46.71875 '], time:5.160856ms, swizzle: NOOP, TFLOPS: 106.52(+40.76%) - (mma4x2+warp2x4+stage2): ['9.84375 ', '-46.71875 '], time:5.038166ms, swizzle: NOOP, TFLOPS: 109.12(+2.44%) - (mma4x2+warp2x4+stage3+dsmem): ['9.84375 ', '-46.71875 '], time:5.177164ms, swizzle: NOOP, TFLOPS: 106.19 - (mma4x2+warp2x4+stage2+dsmem): ['9.84375 ', '-46.71875 '], time:5.098938ms, swizzle: NOOP, TFLOPS: 107.82 - (mma4x2+warp2x4+stage3+swizzle): ['9.84375 ', '-46.71875 '], time:5.004787ms, swizzle: 4096, TFLOPS: 109.85(+0.67%) - (mma4x2+warp2x4+stage2+swizzle): ['9.84375 ', '-46.71875 '], time:4.954004ms, swizzle: 4096, TFLOPS: 110.97(+1.03%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:5.003094ms, swizzle: 4096, TFLOPS: 109.88 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:4.933691ms, swizzle: 4096, TFLOPS: 111.43(+0.41%) - (cublas): ['9.84375 ', '-46.71875 '], time:4.990887ms, swizzle: NOOP, TFLOPS: 110.15 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=4096, N=16384, K=8192, Warmup=5, Iters=20, 9/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['47.53125 ', '-51.5 '], time:24.35543ms, swizzle: NOOP, TFLOPS: 45.14 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['47.53125 ', '-51.5 '], time:23.57738ms, swizzle: NOOP, TFLOPS: 46.63 (+3.30%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['47.0 ', '-52.25 '], time:31.01222ms, swizzle: NOOP, TFLOPS: 35.45 - (mma4x2+warp2x4): ['47.0 ', '-52.25 '], time:14.37473ms, swizzle: NOOP, TFLOPS: 76.49 (+64.02%) - (mma4x2+warp2x4+stage3): ['47.0 ', '-52.25 '], time:12.40768ms, swizzle: NOOP, TFLOPS: 88.62 (+15.85%) - (mma4x2+warp2x4+stage2): ['47.0 ', '-52.25 '], time:12.25883ms, swizzle: NOOP, TFLOPS: 89.69 (+1.21%) - (mma4x2+warp2x4+stage3+dsmem): ['47.0 ', '-52.25 '], time:12.40663ms, swizzle: NOOP, TFLOPS: 88.62 - (mma4x2+warp2x4+stage2+dsmem): ['47.0 ', '-52.25 '], time:12.26737ms, swizzle: NOOP, TFLOPS: 89.63 - (mma4x2+warp2x4+stage3+swizzle): ['47.0 ', '-52.25 '], time:9.920740ms, swizzle: 4096, TFLOPS: 110.83(+23.57%) - (mma4x2+warp2x4+stage2+swizzle): ['47.0 ', '-52.25 '], time:9.804654ms, swizzle: 4096, TFLOPS: 112.14(+1.18%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['47.0 ', '-52.25 '], time:9.917545ms, swizzle: 4096, TFLOPS: 110.87 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['47.0 ', '-52.25 '], time:9.778022ms, swizzle: 4096, TFLOPS: 112.45(+0.27%) - (cublas): ['47.0 ', '-52.25 '], time:9.679126ms, swizzle: NOOP, TFLOPS: 113.60(+1.02%) ----------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=8192, N=4096, K=2048, Warmup=5, Iters=20, 10/27 ---------------------------------------------------------------------------------------------------------------------------------- @@ -464,223 +397,6 @@ python3 hgemm.py (mma4x2+warp2x4+stage3+dsmem+swizzle): ['47.0 ', '-52.25 '], time:5.098199ms, swizzle: 1024, TFLOPS: 107.83 (mma4x2+warp2x4+stage2+dsmem+swizzle): ['47.0 ', '-52.25 '], time:5.003476ms, swizzle: 1024, TFLOPS: 109.87(+0.02%) (cublas): ['47.0 ', '-52.25 '], time:5.096864ms, swizzle: NOOP, TFLOPS: 107.86 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=8192, N=8192, K=2048, Warmup=5, Iters=20, 13/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['14.765625 ', '-18.640625'], time:5.346417ms, swizzle: NOOP, TFLOPS: 51.41 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['14.765625 ', '-18.640625'], time:4.942703ms, swizzle: NOOP, TFLOPS: 55.61 (+8.17%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['14.28125 ', '-18.6875 '], time:5.900359ms, swizzle: NOOP, TFLOPS: 46.59 - (mma4x2+warp2x4): ['14.28125 ', '-18.6875 '], time:3.572225ms, swizzle: NOOP, TFLOPS: 76.95 (+38.36%) - (mma4x2+warp2x4+stage3): ['14.28125 ', '-18.6875 '], time:2.547502ms, swizzle: NOOP, TFLOPS: 107.90(+40.22%) - (mma4x2+warp2x4+stage2): ['14.28125 ', '-18.6875 '], time:2.539443ms, swizzle: NOOP, TFLOPS: 108.24(+0.32%) - (mma4x2+warp2x4+stage3+dsmem): ['14.28125 ', '-18.6875 '], time:2.537584ms, swizzle: NOOP, TFLOPS: 108.32(+0.07%) - (mma4x2+warp2x4+stage2+dsmem): ['14.28125 ', '-18.6875 '], time:2.540159ms, swizzle: NOOP, TFLOPS: 108.21 - (mma4x2+warp2x4+stage3+swizzle): ['14.28125 ', '-18.6875 '], time:2.535915ms, swizzle: 2048, TFLOPS: 108.39(+0.07%) - (mma4x2+warp2x4+stage2+swizzle): ['14.28125 ', '-18.6875 '], time:2.510333ms, swizzle: 2048, TFLOPS: 109.50(+1.02%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:2.547550ms, swizzle: 2048, TFLOPS: 107.90 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:2.511882ms, swizzle: 2048, TFLOPS: 109.43 - (cublas): ['14.28125 ', '-18.6875 '], time:2.635979ms, swizzle: NOOP, TFLOPS: 104.28 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=8192, N=8192, K=4096, Warmup=5, Iters=20, 14/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['10.296875 ', '-46.6875 '], time:10.91315ms, swizzle: NOOP, TFLOPS: 50.38 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['10.296875 ', '-46.6875 '], time:10.32221ms, swizzle: NOOP, TFLOPS: 53.26 (+5.72%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['9.84375 ', '-46.71875 '], time:11.72120ms, swizzle: NOOP, TFLOPS: 46.90 - (mma4x2+warp2x4): ['9.84375 ', '-46.71875 '], time:6.984162ms, swizzle: NOOP, TFLOPS: 78.71 (+47.79%) - (mma4x2+warp2x4+stage3): ['9.84375 ', '-46.71875 '], time:5.013513ms, swizzle: NOOP, TFLOPS: 109.65(+39.31%) - (mma4x2+warp2x4+stage2): ['9.84375 ', '-46.71875 '], time:4.993677ms, swizzle: NOOP, TFLOPS: 110.09(+0.40%) - (mma4x2+warp2x4+stage3+dsmem): ['9.84375 ', '-46.71875 '], time:4.979777ms, swizzle: NOOP, TFLOPS: 110.40(+0.28%) - (mma4x2+warp2x4+stage2+dsmem): ['9.84375 ', '-46.71875 '], time:5.007362ms, swizzle: NOOP, TFLOPS: 109.79 - (mma4x2+warp2x4+stage3+swizzle): ['9.84375 ', '-46.71875 '], time:4.975485ms, swizzle: 2048, TFLOPS: 110.49(+0.09%) - (mma4x2+warp2x4+stage2+swizzle): ['9.84375 ', '-46.71875 '], time:4.935383ms, swizzle: 2048, TFLOPS: 111.39(+0.81%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:4.990983ms, swizzle: 2048, TFLOPS: 110.15 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:4.953265ms, swizzle: 2048, TFLOPS: 110.99 - (cublas): ['9.84375 ', '-46.71875 '], time:4.983496ms, swizzle: NOOP, TFLOPS: 110.32 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=8192, N=8192, K=8192, Warmup=5, Iters=20, 15/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['47.53125 ', '-51.5 '], time:23.78494ms, swizzle: NOOP, TFLOPS: 46.23 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['47.53125 ', '-51.5 '], time:22.85294ms, swizzle: NOOP, TFLOPS: 48.11 (+4.08%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['47.0 ', '-52.25 '], time:24.70605ms, swizzle: NOOP, TFLOPS: 44.50 - (mma4x2+warp2x4): ['47.0 ', '-52.25 '], time:14.21954ms, swizzle: NOOP, TFLOPS: 77.32 (+60.72%) - (mma4x2+warp2x4+stage3): ['47.0 ', '-52.25 '], time:10.34536ms, swizzle: NOOP, TFLOPS: 106.28(+37.45%) - (mma4x2+warp2x4+stage2): ['47.0 ', '-52.25 '], time:10.25786ms, swizzle: NOOP, TFLOPS: 107.19(+0.85%) - (mma4x2+warp2x4+stage3+dsmem): ['47.0 ', '-52.25 '], time:10.49890ms, swizzle: NOOP, TFLOPS: 104.73 - (mma4x2+warp2x4+stage2+dsmem): ['47.0 ', '-52.25 '], time:10.29896ms, swizzle: NOOP, TFLOPS: 106.76 - (mma4x2+warp2x4+stage3+swizzle): ['47.0 ', '-52.25 '], time:9.953498ms, swizzle: 2048, TFLOPS: 110.46(+3.06%) - (mma4x2+warp2x4+stage2+swizzle): ['47.0 ', '-52.25 '], time:9.775471ms, swizzle: 2048, TFLOPS: 112.48(+1.82%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['47.0 ', '-52.25 '], time:9.905838ms, swizzle: 2048, TFLOPS: 111.00 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['47.0 ', '-52.25 '], time:9.768342ms, swizzle: 2048, TFLOPS: 112.56(+0.07%) - (cublas): ['47.0 ', '-52.25 '], time:9.739327ms, swizzle: NOOP, TFLOPS: 112.89(+0.30%) ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=8192, N=16384, K=2048, Warmup=5, Iters=20, 16/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['14.765625 ', '-18.640625'], time:10.92975ms, swizzle: NOOP, TFLOPS: 50.30 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['14.765625 ', '-18.640625'], time:10.32440ms, swizzle: NOOP, TFLOPS: 53.25 (+5.86%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['14.28125 ', '-18.6875 '], time:11.78483ms, swizzle: NOOP, TFLOPS: 46.65 - (mma4x2+warp2x4): ['14.28125 ', '-18.6875 '], time:6.915855ms, swizzle: NOOP, TFLOPS: 79.49 (+49.29%) - (mma4x2+warp2x4+stage3): ['14.28125 ', '-18.6875 '], time:5.065703ms, swizzle: NOOP, TFLOPS: 108.53(+36.52%) - (mma4x2+warp2x4+stage2): ['14.28125 ', '-18.6875 '], time:5.030226ms, swizzle: NOOP, TFLOPS: 109.29(+0.71%) - (mma4x2+warp2x4+stage3+dsmem): ['14.28125 ', '-18.6875 '], time:5.033874ms, swizzle: NOOP, TFLOPS: 109.21 - (mma4x2+warp2x4+stage2+dsmem): ['14.28125 ', '-18.6875 '], time:5.028772ms, swizzle: NOOP, TFLOPS: 109.32(+0.03%) - (mma4x2+warp2x4+stage3+swizzle): ['14.28125 ', '-18.6875 '], time:5.024838ms, swizzle: 4096, TFLOPS: 109.41(+0.08%) - (mma4x2+warp2x4+stage2+swizzle): ['14.28125 ', '-18.6875 '], time:4.977488ms, swizzle: 4096, TFLOPS: 110.45(+0.95%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:5.048298ms, swizzle: 4096, TFLOPS: 108.90 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:4.978847ms, swizzle: 4096, TFLOPS: 110.42 - (cublas): ['14.28125 ', '-18.6875 '], time:5.005145ms, swizzle: NOOP, TFLOPS: 109.84 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=8192, N=16384, K=4096, Warmup=5, Iters=20, 17/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['10.296875 ', '-46.6875 '], time:23.64189ms, swizzle: NOOP, TFLOPS: 46.51 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['10.296875 ', '-46.6875 '], time:22.93310ms, swizzle: NOOP, TFLOPS: 47.94 (+3.09%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['9.84375 ', '-46.71875 '], time:31.19268ms, swizzle: NOOP, TFLOPS: 35.25 - (mma4x2+warp2x4): ['9.84375 ', '-46.71875 '], time:14.21802ms, swizzle: NOOP, TFLOPS: 77.33 (+61.30%) - (mma4x2+warp2x4+stage3): ['9.84375 ', '-46.71875 '], time:10.72919ms, swizzle: NOOP, TFLOPS: 102.48(+32.52%) - (mma4x2+warp2x4+stage2): ['9.84375 ', '-46.71875 '], time:10.53795ms, swizzle: NOOP, TFLOPS: 104.34(+1.81%) - (mma4x2+warp2x4+stage3+dsmem): ['9.84375 ', '-46.71875 '], time:10.60345ms, swizzle: NOOP, TFLOPS: 103.69 - (mma4x2+warp2x4+stage2+dsmem): ['9.84375 ', '-46.71875 '], time:10.46254ms, swizzle: NOOP, TFLOPS: 105.09(+0.72%) - (mma4x2+warp2x4+stage3+swizzle): ['9.84375 ', '-46.71875 '], time:9.963369ms, swizzle: 4096, TFLOPS: 110.36(+5.01%) - (mma4x2+warp2x4+stage2+swizzle): ['9.84375 ', '-46.71875 '], time:9.808254ms, swizzle: 4096, TFLOPS: 112.10(+1.58%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:9.931588ms, swizzle: 4096, TFLOPS: 110.71 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:9.800553ms, swizzle: 4096, TFLOPS: 112.19(+0.08%) - (cublas): ['9.84375 ', '-46.71875 '], time:9.695315ms, swizzle: NOOP, TFLOPS: 113.41(+1.09%) ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=8192, N=16384, K=8192, Warmup=5, Iters=20, 18/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['47.53125 ', '-51.5 '], time:49.22699ms, swizzle: NOOP, TFLOPS: 44.67 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['47.53125 ', '-51.5 '], time:48.20067ms, swizzle: NOOP, TFLOPS: 45.62 (+2.13%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['47.0 ', '-52.25 '], time:61.79366ms, swizzle: NOOP, TFLOPS: 35.59 - (mma4x2+warp2x4): ['47.0 ', '-52.25 '], time:28.10072ms, swizzle: NOOP, TFLOPS: 78.26 (+71.53%) - (mma4x2+warp2x4+stage3): ['47.0 ', '-52.25 '], time:24.79410ms, swizzle: NOOP, TFLOPS: 88.69 (+13.34%) - (mma4x2+warp2x4+stage2): ['47.0 ', '-52.25 '], time:24.75156ms, swizzle: NOOP, TFLOPS: 88.84 (+0.17%) - (mma4x2+warp2x4+stage3+dsmem): ['47.0 ', '-52.25 '], time:24.81336ms, swizzle: NOOP, TFLOPS: 88.62 - (mma4x2+warp2x4+stage2+dsmem): ['47.0 ', '-52.25 '], time:24.72374ms, swizzle: NOOP, TFLOPS: 88.94 (+0.11%) - (mma4x2+warp2x4+stage3+swizzle): ['47.0 ', '-52.25 '], time:19.72281ms, swizzle: 4096, TFLOPS: 111.50(+25.36%) - (mma4x2+warp2x4+stage2+swizzle): ['47.0 ', '-52.25 '], time:19.45309ms, swizzle: 4096, TFLOPS: 113.04(+1.39%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['47.0 ', '-52.25 '], time:19.70534ms, swizzle: 4096, TFLOPS: 111.60 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['47.0 ', '-52.25 '], time:19.46532ms, swizzle: 4096, TFLOPS: 112.97 - (cublas): ['47.0 ', '-52.25 '], time:19.07067ms, swizzle: NOOP, TFLOPS: 115.31(+2.01%) ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=16384, N=4096, K=2048, Warmup=5, Iters=20, 19/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['14.765625 ', '-18.640625'], time:5.374526ms, swizzle: NOOP, TFLOPS: 51.14 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['14.765625 ', '-18.640625'], time:4.960513ms, swizzle: NOOP, TFLOPS: 55.41 (+8.35%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['14.28125 ', '-18.6875 '], time:5.899548ms, swizzle: NOOP, TFLOPS: 46.59 - (mma4x2+warp2x4): ['14.28125 ', '-18.6875 '], time:3.571367ms, swizzle: NOOP, TFLOPS: 76.97 (+38.90%) - (mma4x2+warp2x4+stage3): ['14.28125 ', '-18.6875 '], time:2.558302ms, swizzle: NOOP, TFLOPS: 107.45(+39.60%) - (mma4x2+warp2x4+stage2): ['14.28125 ', '-18.6875 '], time:2.541303ms, swizzle: NOOP, TFLOPS: 108.16(+0.67%) - (mma4x2+warp2x4+stage3+dsmem): ['14.28125 ', '-18.6875 '], time:2.538442ms, swizzle: NOOP, TFLOPS: 108.29(+0.11%) - (mma4x2+warp2x4+stage2+dsmem): ['14.28125 ', '-18.6875 '], time:2.542233ms, swizzle: NOOP, TFLOPS: 108.12 - (mma4x2+warp2x4+stage3+swizzle): ['14.28125 ', '-18.6875 '], time:2.538371ms, swizzle: 1024, TFLOPS: 108.29(+0.00%) - (mma4x2+warp2x4+stage2+swizzle): ['14.28125 ', '-18.6875 '], time:2.512073ms, swizzle: 1024, TFLOPS: 109.42(+1.05%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:2.551054ms, swizzle: 1024, TFLOPS: 107.75 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:2.513241ms, swizzle: 1024, TFLOPS: 109.37 - (cublas): ['14.28125 ', '-18.6875 '], time:2.619862ms, swizzle: NOOP, TFLOPS: 104.92 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=16384, N=4096, K=4096, Warmup=5, Iters=20, 20/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['10.296875 ', '-46.6875 '], time:11.02216ms, swizzle: NOOP, TFLOPS: 49.88 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['10.296875 ', '-46.6875 '], time:10.51564ms, swizzle: NOOP, TFLOPS: 52.28 (+4.82%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['9.84375 ', '-46.71875 '], time:11.72006ms, swizzle: NOOP, TFLOPS: 46.91 - (mma4x2+warp2x4): ['9.84375 ', '-46.71875 '], time:6.976056ms, swizzle: NOOP, TFLOPS: 78.81 (+50.74%) - (mma4x2+warp2x4+stage3): ['9.84375 ', '-46.71875 '], time:5.025959ms, swizzle: NOOP, TFLOPS: 109.38(+38.80%) - (mma4x2+warp2x4+stage2): ['9.84375 ', '-46.71875 '], time:4.991459ms, swizzle: NOOP, TFLOPS: 110.14(+0.69%) - (mma4x2+warp2x4+stage3+dsmem): ['9.84375 ', '-46.71875 '], time:4.996275ms, swizzle: NOOP, TFLOPS: 110.03 - (mma4x2+warp2x4+stage2+dsmem): ['9.84375 ', '-46.71875 '], time:4.992103ms, swizzle: NOOP, TFLOPS: 110.13 - (mma4x2+warp2x4+stage3+swizzle): ['9.84375 ', '-46.71875 '], time:4.988074ms, swizzle: 1024, TFLOPS: 110.21(+0.07%) - (mma4x2+warp2x4+stage2+swizzle): ['9.84375 ', '-46.71875 '], time:4.960775ms, swizzle: 1024, TFLOPS: 110.82(+0.55%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:5.020546ms, swizzle: 1024, TFLOPS: 109.50 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:4.957389ms, swizzle: 1024, TFLOPS: 110.90(+0.07%) - (cublas): ['9.84375 ', '-46.71875 '], time:4.962539ms, swizzle: NOOP, TFLOPS: 110.78 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=16384, N=4096, K=8192, Warmup=5, Iters=20, 21/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['47.53125 ', '-51.5 '], time:22.25213ms, swizzle: NOOP, TFLOPS: 49.41 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['47.53125 ', '-51.5 '], time:21.25067ms, swizzle: NOOP, TFLOPS: 51.74 (+4.71%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['47.0 ', '-52.25 '], time:23.32034ms, swizzle: NOOP, TFLOPS: 47.15 - (mma4x2+warp2x4): ['47.0 ', '-52.25 '], time:13.78231ms, swizzle: NOOP, TFLOPS: 79.78 (+54.19%) - (mma4x2+warp2x4+stage3): ['47.0 ', '-52.25 '], time:9.944629ms, swizzle: NOOP, TFLOPS: 110.56(+38.59%) - (mma4x2+warp2x4+stage2): ['47.0 ', '-52.25 '], time:9.877133ms, swizzle: NOOP, TFLOPS: 111.32(+0.68%) - (mma4x2+warp2x4+stage3+dsmem): ['47.0 ', '-52.25 '], time:9.891724ms, swizzle: NOOP, TFLOPS: 111.15 - (mma4x2+warp2x4+stage2+dsmem): ['47.0 ', '-52.25 '], time:9.875774ms, swizzle: NOOP, TFLOPS: 111.33(+0.01%) - (mma4x2+warp2x4+stage3+swizzle): ['47.0 ', '-52.25 '], time:9.909319ms, swizzle: 1024, TFLOPS: 110.96 - (mma4x2+warp2x4+stage2+swizzle): ['47.0 ', '-52.25 '], time:9.821128ms, swizzle: 1024, TFLOPS: 111.95(+0.56%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['47.0 ', '-52.25 '], time:10.00571ms, swizzle: 1024, TFLOPS: 109.89 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['47.0 ', '-52.25 '], time:9.818959ms, swizzle: 1024, TFLOPS: 111.98(+0.02%) - (cublas): ['47.0 ', '-52.25 '], time:9.649991ms, swizzle: NOOP, TFLOPS: 113.94(+1.75%) ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=16384, N=8192, K=2048, Warmup=5, Iters=20, 22/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['14.765625 ', '-18.640625'], time:10.99567ms, swizzle: NOOP, TFLOPS: 50.00 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['14.765625 ', '-18.640625'], time:10.49816ms, swizzle: NOOP, TFLOPS: 52.37 (+4.74%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['14.28125 ', '-18.6875 '], time:11.76815ms, swizzle: NOOP, TFLOPS: 46.72 - (mma4x2+warp2x4): ['14.28125 ', '-18.6875 '], time:6.931424ms, swizzle: NOOP, TFLOPS: 79.31 (+51.46%) - (mma4x2+warp2x4+stage3): ['14.28125 ', '-18.6875 '], time:5.055880ms, swizzle: NOOP, TFLOPS: 108.74(+37.10%) - (mma4x2+warp2x4+stage2): ['14.28125 ', '-18.6875 '], time:5.022001ms, swizzle: NOOP, TFLOPS: 109.47(+0.67%) - (mma4x2+warp2x4+stage3+dsmem): ['14.28125 ', '-18.6875 '], time:5.026936ms, swizzle: NOOP, TFLOPS: 109.36 - (mma4x2+warp2x4+stage2+dsmem): ['14.28125 ', '-18.6875 '], time:5.020689ms, swizzle: NOOP, TFLOPS: 109.50(+0.03%) - (mma4x2+warp2x4+stage3+swizzle): ['14.28125 ', '-18.6875 '], time:5.018496ms, swizzle: 2048, TFLOPS: 109.55(+0.04%) - (mma4x2+warp2x4+stage2+swizzle): ['14.28125 ', '-18.6875 '], time:4.968738ms, swizzle: 2048, TFLOPS: 110.64(+1.00%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:5.040884ms, swizzle: 2048, TFLOPS: 109.06 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['14.28125 ', '-18.6875 '], time:4.972743ms, swizzle: 2048, TFLOPS: 110.55 - (cublas): ['14.28125 ', '-18.6875 '], time:4.969763ms, swizzle: NOOP, TFLOPS: 110.62 ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=16384, N=8192, K=4096, Warmup=5, Iters=20, 23/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['10.296875 ', '-46.6875 '], time:22.06621ms, swizzle: NOOP, TFLOPS: 49.83 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['10.296875 ', '-46.6875 '], time:21.04604ms, swizzle: NOOP, TFLOPS: 52.24 (+4.85%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['9.84375 ', '-46.71875 '], time:23.35319ms, swizzle: NOOP, TFLOPS: 47.08 - (mma4x2+warp2x4): ['9.84375 ', '-46.71875 '], time:13.64452ms, swizzle: NOOP, TFLOPS: 80.58 (+54.25%) - (mma4x2+warp2x4+stage3): ['9.84375 ', '-46.71875 '], time:9.970688ms, swizzle: NOOP, TFLOPS: 110.27(+36.85%) - (mma4x2+warp2x4+stage2): ['9.84375 ', '-46.71875 '], time:9.907126ms, swizzle: NOOP, TFLOPS: 110.98(+0.64%) - (mma4x2+warp2x4+stage3+dsmem): ['9.84375 ', '-46.71875 '], time:9.919929ms, swizzle: NOOP, TFLOPS: 110.84 - (mma4x2+warp2x4+stage2+dsmem): ['9.84375 ', '-46.71875 '], time:9.905028ms, swizzle: NOOP, TFLOPS: 111.01(+0.02%) - (mma4x2+warp2x4+stage3+swizzle): ['9.84375 ', '-46.71875 '], time:9.890699ms, swizzle: 2048, TFLOPS: 111.17(+0.14%) - (mma4x2+warp2x4+stage2+swizzle): ['9.84375 ', '-46.71875 '], time:9.796810ms, swizzle: 2048, TFLOPS: 112.23(+0.96%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:9.931468ms, swizzle: 2048, TFLOPS: 110.71 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['9.84375 ', '-46.71875 '], time:9.799408ms, swizzle: 2048, TFLOPS: 112.20 - (cublas): ['9.84375 ', '-46.71875 '], time:9.670424ms, swizzle: NOOP, TFLOPS: 113.70(+1.31%) ----------------------------------------------------------------------------------------------------------------------------------- ----------------------------------------------------------------------------------------------------------------------------------- - M=16384, N=8192, K=8192, Warmup=5, Iters=20, 24/27 ----------------------------------------------------------------------------------------------------------------------------------- - (f16x8pack+t8x8+dbuf): ['47.53125 ', '-51.5 '], time:47.87569ms, swizzle: NOOP, TFLOPS: 45.93 (+0.00%) - (f16x8pack+t8x8+k16+dbuf): ['47.53125 ', '-51.5 '], time:46.51024ms, swizzle: NOOP, TFLOPS: 47.28 (+2.94%) ---------------------------------------------------------------------WMMA---------------------------------------------------------- - (mma4x2): ['47.0 ', '-52.25 '], time:49.74699ms, swizzle: NOOP, TFLOPS: 44.20 - (mma4x2+warp2x4): ['47.0 ', '-52.25 '], time:28.01880ms, swizzle: NOOP, TFLOPS: 78.48 (+66.00%) - (mma4x2+warp2x4+stage3): ['47.0 ', '-52.25 '], time:21.40111ms, swizzle: NOOP, TFLOPS: 102.75(+30.92%) - (mma4x2+warp2x4+stage2): ['47.0 ', '-52.25 '], time:20.99103ms, swizzle: NOOP, TFLOPS: 104.76(+1.95%) - (mma4x2+warp2x4+stage3+dsmem): ['47.0 ', '-52.25 '], time:21.22135ms, swizzle: NOOP, TFLOPS: 103.62 - (mma4x2+warp2x4+stage2+dsmem): ['47.0 ', '-52.25 '], time:21.01814ms, swizzle: NOOP, TFLOPS: 104.62 - (mma4x2+warp2x4+stage3+swizzle): ['47.0 ', '-52.25 '], time:19.75324ms, swizzle: 2048, TFLOPS: 111.32(+6.27%) - (mma4x2+warp2x4+stage2+swizzle): ['47.0 ', '-52.25 '], time:19.45850ms, swizzle: 2048, TFLOPS: 113.01(+1.51%) - (mma4x2+warp2x4+stage3+dsmem+swizzle): ['47.0 ', '-52.25 '], time:19.70596ms, swizzle: 2048, TFLOPS: 111.59 - (mma4x2+warp2x4+stage2+dsmem+swizzle): ['47.0 ', '-52.25 '], time:19.45621ms, swizzle: 2048, TFLOPS: 113.02(+0.01%) - (cublas): ['47.0 ', '-52.25 '], time:19.04292ms, swizzle: NOOP, TFLOPS: 115.48(+2.17%) ----------------------------------------------------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------------------------------------- M=16384, N=16384, K=2048, Warmup=5, Iters=20, 25/27 ---------------------------------------------------------------------------------------------------------------------------------- @@ -737,11 +453,129 @@ python3 hgemm.py ---------------------------------------------------------------------------------------------------------------------------------- ``` +### MMA: Up to 115 TFLOPS, 115/119.5=96.23% TFLOPS utilization. + +```bash +python3 hgemm.py --mma +``` + +输出: +```bash +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=4096, K=8192, Warmup=5, Iters=20, 21/27 +---------------------------------------------------------------------------------------------------------------------------------- +--------------------------------------------------------------------MMA----------------------------------------------------------- + (mma2x4+warp4x4): ['21.984375 ', '58.0 '], time:10.29069ms, swizzle: NOOP, TFLOPS: 106.85(+0.00%) + (mma2x4+warp4x4+stage3): ['21.984375 ', '58.0 '], time:9.866333ms, swizzle: NOOP, TFLOPS: 111.44(+4.30%) + (mma2x4+warp4x4+stage2): ['21.984375 ', '58.0 '], time:9.776329ms, swizzle: NOOP, TFLOPS: 112.47(+0.92%) + (mma2x4+warp4x4+stage3+dsmem): ['21.984375 ', '58.0 '], time:9.924983ms, swizzle: NOOP, TFLOPS: 110.78 + (mma2x4+warp4x4+stage2+dsmem): ['21.984375 ', '58.0 '], time:9.772467ms, swizzle: NOOP, TFLOPS: 112.51(+0.04%) + (mma2x4+warp4x4+stage3+swizzle): ['21.984375 ', '58.0 '], time:9.879112ms, swizzle: 1024, TFLOPS: 111.30 + (mma2x4+warp4x4+stage2+swizzle): ['21.984375 ', '58.0 '], time:9.752583ms, swizzle: 1024, TFLOPS: 112.74(+0.20%) + (mma2x4+warp4x4+stage3+dsmem+swizzle): ['21.984375 ', '58.0 '], time:9.922742ms, swizzle: 1024, TFLOPS: 110.81 + (mma2x4+warp4x4+stage2+dsmem+swizzle): ['21.984375 ', '58.0 '], time:9.673309ms, swizzle: 1024, TFLOPS: 113.66(+0.82%) + (cublas): ['21.984375 ', '58.0 '], time:9.443545ms, swizzle: NOOP, TFLOPS: 116.43(+2.43%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=8192, K=2048, Warmup=5, Iters=20, 22/27 +---------------------------------------------------------------------------------------------------------------------------------- +--------------------------------------------------------------------MMA----------------------------------------------------------- + (mma2x4+warp4x4): ['32.40625 ', '-4.0039062'], time:5.229425ms, swizzle: NOOP, TFLOPS: 105.13(+0.00%) + (mma2x4+warp4x4+stage3): ['32.40625 ', '-4.0039062'], time:5.009818ms, swizzle: NOOP, TFLOPS: 109.74(+4.38%) + (mma2x4+warp4x4+stage2): ['32.40625 ', '-4.0039062'], time:4.968261ms, swizzle: NOOP, TFLOPS: 110.65(+0.84%) + (mma2x4+warp4x4+stage3+dsmem): ['32.40625 ', '-4.0039062'], time:5.031824ms, swizzle: NOOP, TFLOPS: 109.26 + (mma2x4+warp4x4+stage2+dsmem): ['32.40625 ', '-4.0039062'], time:4.965233ms, swizzle: NOOP, TFLOPS: 110.72(+0.06%) + (mma2x4+warp4x4+stage3+swizzle): ['32.40625 ', '-4.0039062'], time:5.021595ms, swizzle: 2048, TFLOPS: 109.48 + (mma2x4+warp4x4+stage2+swizzle): ['32.40625 ', '-4.0039062'], time:4.914212ms, swizzle: 2048, TFLOPS: 111.87(+1.04%) + (mma2x4+warp4x4+stage3+dsmem+swizzle): ['32.40625 ', '-4.0039062'], time:5.039000ms, swizzle: 2048, TFLOPS: 109.10 + (mma2x4+warp4x4+stage2+dsmem+swizzle): ['32.40625 ', '-4.0039062'], time:4.895591ms, swizzle: 2048, TFLOPS: 112.30(+0.38%) + (cublas): ['32.40625 ', '-4.0039062'], time:4.766654ms, swizzle: NOOP, TFLOPS: 115.33(+2.70%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=8192, K=4096, Warmup=5, Iters=20, 23/27 +---------------------------------------------------------------------------------------------------------------------------------- +--------------------------------------------------------------------MMA----------------------------------------------------------- + (mma2x4+warp4x4): ['99.125 ', '22.8125 '], time:10.30406ms, swizzle: NOOP, TFLOPS: 106.71(+0.00%) + (mma2x4+warp4x4+stage3): ['99.125 ', '22.8125 '], time:9.895300ms, swizzle: NOOP, TFLOPS: 111.11(+4.13%) + (mma2x4+warp4x4+stage2): ['99.125 ', '22.8125 '], time:9.813237ms, swizzle: NOOP, TFLOPS: 112.04(+0.84%) + (mma2x4+warp4x4+stage3+dsmem): ['99.125 ', '22.8125 '], time:9.948658ms, swizzle: NOOP, TFLOPS: 110.52 + (mma2x4+warp4x4+stage2+dsmem): ['99.125 ', '22.8125 '], time:9.798026ms, swizzle: NOOP, TFLOPS: 112.22(+0.16%) + (mma2x4+warp4x4+stage3+swizzle): ['99.125 ', '22.8125 '], time:9.914517ms, swizzle: 2048, TFLOPS: 110.90 + (mma2x4+warp4x4+stage2+swizzle): ['99.125 ', '22.8125 '], time:9.733128ms, swizzle: 2048, TFLOPS: 112.97(+0.67%) + (mma2x4+warp4x4+stage3+dsmem+swizzle): ['99.125 ', '22.8125 '], time:9.941744ms, swizzle: 2048, TFLOPS: 110.60 + (mma2x4+warp4x4+stage2+dsmem+swizzle): ['99.125 ', '22.8125 '], time:9.670472ms, swizzle: 2048, TFLOPS: 113.70(+0.65%) + (cublas): ['99.125 ', '22.8125 '], time:9.453558ms, swizzle: NOOP, TFLOPS: 116.31(+2.29%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=8192, K=8192, Warmup=5, Iters=20, 24/27 +---------------------------------------------------------------------------------------------------------------------------------- +--------------------------------------------------------------------MMA----------------------------------------------------------- + (mma2x4+warp4x4): ['21.984375 ', '58.0 '], time:21.51823ms, swizzle: NOOP, TFLOPS: 102.19(+0.00%) + (mma2x4+warp4x4+stage3): ['21.984375 ', '58.0 '], time:20.90017ms, swizzle: NOOP, TFLOPS: 105.22(+2.96%) + (mma2x4+warp4x4+stage2): ['21.984375 ', '58.0 '], time:20.75178ms, swizzle: NOOP, TFLOPS: 105.97(+0.72%) + (mma2x4+warp4x4+stage3+dsmem): ['21.984375 ', '58.0 '], time:20.97730ms, swizzle: NOOP, TFLOPS: 104.83 + (mma2x4+warp4x4+stage2+dsmem): ['21.984375 ', '58.0 '], time:20.83809ms, swizzle: NOOP, TFLOPS: 105.53 + (mma2x4+warp4x4+stage3+swizzle): ['21.984375 ', '58.0 '], time:19.78309ms, swizzle: 2048, TFLOPS: 111.16(+4.90%) + (mma2x4+warp4x4+stage2+swizzle): ['21.984375 ', '58.0 '], time:19.33062ms, swizzle: 2048, TFLOPS: 113.76(+2.34%) + (mma2x4+warp4x4+stage3+dsmem+swizzle): ['21.984375 ', '58.0 '], time:19.74017ms, swizzle: 2048, TFLOPS: 111.40 + (mma2x4+warp4x4+stage2+dsmem+swizzle): ['21.984375 ', '58.0 '], time:19.22986ms, swizzle: 2048, TFLOPS: 114.35(+0.52%) + (cublas): ['21.984375 ', '58.0 '], time:18.83535ms, swizzle: NOOP, TFLOPS: 116.75(+2.09%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=16384, K=2048, Warmup=5, Iters=20, 25/27 +---------------------------------------------------------------------------------------------------------------------------------- +--------------------------------------------------------------------MMA----------------------------------------------------------- + (mma2x4+warp4x4): ['32.40625 ', '-4.0039062'], time:10.34352ms, swizzle: NOOP, TFLOPS: 106.30(+0.00%) + (mma2x4+warp4x4+stage3): ['32.40625 ', '-4.0039062'], time:9.953904ms, swizzle: NOOP, TFLOPS: 110.46(+3.91%) + (mma2x4+warp4x4+stage2): ['32.40625 ', '-4.0039062'], time:9.861850ms, swizzle: NOOP, TFLOPS: 111.49(+0.93%) + (mma2x4+warp4x4+stage3+dsmem): ['32.40625 ', '-4.0039062'], time:9.998512ms, swizzle: NOOP, TFLOPS: 109.97 + (mma2x4+warp4x4+stage2+dsmem): ['32.40625 ', '-4.0039062'], time:9.855365ms, swizzle: NOOP, TFLOPS: 111.56(+0.07%) + (mma2x4+warp4x4+stage3+swizzle): ['32.40625 ', '-4.0039062'], time:9.974408ms, swizzle: 4096, TFLOPS: 110.23 + (mma2x4+warp4x4+stage2+swizzle): ['32.40625 ', '-4.0039062'], time:9.743142ms, swizzle: 4096, TFLOPS: 112.85(+1.15%) + (mma2x4+warp4x4+stage3+dsmem+swizzle): ['32.40625 ', '-4.0039062'], time:9.995770ms, swizzle: 4096, TFLOPS: 110.00 + (mma2x4+warp4x4+stage2+dsmem+swizzle): ['32.40625 ', '-4.0039062'], time:9.701442ms, swizzle: 4096, TFLOPS: 113.33(+0.43%) + (cublas): ['32.40625 ', '-4.0039062'], time:9.485888ms, swizzle: NOOP, TFLOPS: 115.91(+2.27%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=16384, K=4096, Warmup=5, Iters=20, 26/27 +---------------------------------------------------------------------------------------------------------------------------------- +--------------------------------------------------------------------MMA----------------------------------------------------------- + (mma2x4+warp4x4): ['99.125 ', '22.8125 '], time:22.18379ms, swizzle: NOOP, TFLOPS: 99.13 (+0.00%) + (mma2x4+warp4x4+stage3): ['99.125 ', '22.8125 '], time:21.83485ms, swizzle: NOOP, TFLOPS: 100.71(+1.60%) + (mma2x4+warp4x4+stage2): ['99.125 ', '22.8125 '], time:21.14553ms, swizzle: NOOP, TFLOPS: 103.99(+3.26%) + (mma2x4+warp4x4+stage3+dsmem): ['99.125 ', '22.8125 '], time:21.59111ms, swizzle: NOOP, TFLOPS: 101.85 + (mma2x4+warp4x4+stage2+dsmem): ['99.125 ', '22.8125 '], time:20.96095ms, swizzle: NOOP, TFLOPS: 104.91(+0.88%) + (mma2x4+warp4x4+stage3+swizzle): ['99.125 ', '22.8125 '], time:19.78907ms, swizzle: 4096, TFLOPS: 111.12(+5.92%) + (mma2x4+warp4x4+stage2+swizzle): ['99.125 ', '22.8125 '], time:19.28851ms, swizzle: 4096, TFLOPS: 114.01(+2.60%) + (mma2x4+warp4x4+stage3+dsmem+swizzle): ['99.125 ', '22.8125 '], time:19.74153ms, swizzle: 4096, TFLOPS: 111.39 + (mma2x4+warp4x4+stage2+dsmem+swizzle): ['99.125 ', '22.8125 '], time:19.19734ms, swizzle: 4096, TFLOPS: 114.55(+0.47%) + (cublas): ['99.125 ', '22.8125 '], time:18.88573ms, swizzle: NOOP, TFLOPS: 116.44(+1.65%) +---------------------------------------------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------------------------------------- + M=16384, N=16384, K=8192, Warmup=5, Iters=20, 27/27 +---------------------------------------------------------------------------------------------------------------------------------- +--------------------------------------------------------------------MMA----------------------------------------------------------- + (mma2x4+warp4x4): ['21.984375 ', '58.0 '], time:45.41800ms, swizzle: NOOP, TFLOPS: 96.83 (+0.00%) + (mma2x4+warp4x4+stage3): ['21.984375 ', '58.0 '], time:49.64394ms, swizzle: NOOP, TFLOPS: 88.59 + (mma2x4+warp4x4+stage2): ['21.984375 ', '58.0 '], time:49.82240ms, swizzle: NOOP, TFLOPS: 88.27 + (mma2x4+warp4x4+stage3+dsmem): ['21.984375 ', '58.0 '], time:49.68290ms, swizzle: NOOP, TFLOPS: 88.52 + (mma2x4+warp4x4+stage2+dsmem): ['21.984375 ', '58.0 '], time:49.83477ms, swizzle: NOOP, TFLOPS: 88.25 + (mma2x4+warp4x4+stage3+swizzle): ['21.984375 ', '58.0 '], time:39.11197ms, swizzle: 4096, TFLOPS: 112.45(+16.12%) + (mma2x4+warp4x4+stage2+swizzle): ['21.984375 ', '58.0 '], time:38.40293ms, swizzle: 4096, TFLOPS: 114.52(+1.85%) + (mma2x4+warp4x4+stage3+dsmem+swizzle): ['21.984375 ', '58.0 '], time:39.23041ms, swizzle: 4096, TFLOPS: 112.11 + (mma2x4+warp4x4+stage2+dsmem+swizzle): ['21.984375 ', '58.0 '], time:38.21511ms, swizzle: 4096, TFLOPS: 115.09(+0.49%) + (cublas): ['21.984375 ', '58.0 '], time:37.87384ms, swizzle: NOOP, TFLOPS: 116.12(+0.90%) +---------------------------------------------------------------------------------------------------------------------------------- +``` + + ## NVIDIA GeForce RTX 3080 Laptop +- WMMA + ```bash -python3 hgemm.py --wmma --no-default +python3 hgemm.py --wmma --wmma-all ``` 输出: ```bash diff --git a/hgemm/hgemm.cu b/hgemm/hgemm.cu index b0179ca0..8845b188 100644 --- a/hgemm/hgemm.cu +++ b/hgemm/hgemm.cu @@ -1011,6 +1011,12 @@ void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages(torch::Tensor a, torch::Tensor b void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); +// from hgemm_mma.cu +void hgemm_mma_m16n8k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c); +void hgemm_mma_m16n8k16_mma2x4_warp4x4(torch::Tensor a, torch::Tensor b, torch::Tensor c); +// from hgemm_mma_stage.cu +void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); +void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // CUDA Cores FP16 @@ -1042,5 +1048,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem) TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem) TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem) + // MMA API Tensor Cores + TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_naive) + TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4) + // stage, thread block swizzle, dsmem + TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages) + TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem) } diff --git a/hgemm/hgemm.py b/hgemm/hgemm.py index 597d8220..458d0818 100644 --- a/hgemm/hgemm.py +++ b/hgemm/hgemm.py @@ -14,12 +14,15 @@ def get_args(): parser.add_argument("--K", type=int, default=None, help="Matrix K size") parser.add_argument("--warmup", "--w", type=int, default=5, help="Warmup iters") parser.add_argument("--iters", "--i", type=int, default=20, help="Benchmark iters") - parser.add_argument("--enable-mma-all", "--mma", action="store_true", help="Enable all MMA kernel tests") - parser.add_argument("--enable-wmma-all", "--wmma", action="store_true", help="Enable all WMMA kernel tests") - parser.add_argument("--enable-cuda-all", "--cuda", action="store_true", help="Enable all CUDA kernel tests") + parser.add_argument("--show-all", "--show", action="store_true", help="Show all matrix values ") + parser.add_argument("--enable-mma", "--mma", action="store_true", help="Enable MMA kernel tests") + parser.add_argument("--enable-wmma", "--wmma", action="store_true", help="Enable WMMA kernel tests") + parser.add_argument("--enable-cuda", "--cuda", action="store_true", help="Enable CUDA kernel tests") + parser.add_argument("--enable-mma-all", "--mma-all", action="store_true", help="Enable all MMA kernel tests") + parser.add_argument("--enable-wmma-all", "--wmma-all", action="store_true", help="Enable all WMMA kernel tests") + parser.add_argument("--enable-cuda-all", "--cuda-all", action="store_true", help="Enable all CUDA kernel tests") parser.add_argument("--enable-torch", "--torch", action="store_true", help="Enable torch matmul") parser.add_argument("--disable-cublas", "--no-cublas", action="store_true", help="Disable cublas hgemm") - parser.add_argument("--disable-default", "--no-default", action="store_true", help="Disable default tests") return parser.parse_args() args = get_args() @@ -29,7 +32,8 @@ def get_args(): print("Loading hgemm lib ...") lib = load(name='hgemm_lib', sources=['hgemm.cu', 'hgemm_async.cu', 'hgemm_wmma.cu', - 'hgemm_wmma_stage.cu', 'hgemm_cublas.cu'], + 'hgemm_wmma_stage.cu', 'hgemm_cublas.cu', + 'hgemm_mma.cu', 'hgemm_mma_stage.cu'], extra_cuda_cflags=[ "-O3", "-U__CUDA_NO_HALF_OPERATORS__", @@ -40,7 +44,8 @@ def get_args(): "--expt-extended-lambda", "--use_fast_math" ], - extra_cflags=['-std=c++17']) + extra_cflags=['-std=c++17'], + verbose=False) MAX_TFLOPS = -1 @@ -52,7 +57,7 @@ def run_benchmark(perf_func: callable, swizzle_stride: int = 1, warmup: int = args.warmup, iters: int = args.iters, - show_all: bool = False): + show_all: bool = args.show_all): global MAX_TFLOPS M = a.size(0) @@ -163,9 +168,10 @@ def run_benchmark(perf_func: callable, # CUDA Cores FP16 run_benchmark(lib.hgemm_naive_f16, a, b, "(naive)", c) run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf, a, b, "(f16x8pack+t8x8+bcf)", c) - if not args.disable_default: + if args.enable_cuda or args.enable_cuda_all: run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf, a, b, "(f16x8pack+t8x8+dbuf)", c) run_benchmark(lib.hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf, a, b, "(f16x8pack+t8x8+k16+dbuf)", c) + if args.enable_wmma or args.enable_wmma_all: print("-" * 68 + "WMMA" + "-" * 58) # wmma api, stages, dsmem, swizzle run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2, a, b, "(mma4x2)", c) @@ -193,8 +199,19 @@ def run_benchmark(perf_func: callable, run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem, a, b, "(mma4x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(mma4x2+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem, a, b, "(mma4x2+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) - if args.enable_mma_all: # more mma kernel tests. + if args.enable_mma or args.enable_mma_all: print("-" * 68 + "MMA" + "-" * 59) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4, a, b, "(mma2x4+warp4x4)", c) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage3)", c, stages=3) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage2)", c, stages=2) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage3+dsmem)", c, stages=3) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage2+dsmem)", c, stages=2) + # thread block swizzle + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage3+swizzle)", c, stages=3, swizzle=True) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage2+swizzle)", c, stages=2, swizzle=True) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) + run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem, a, b, "(mma2x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) + if args.enable_mma_all: # more mma kernel tests. pass if not args.disable_cublas: run_benchmark(lib.hgemm_cublas_tensor_op_row_major, a, b, "(cublas)", c) diff --git a/hgemm/hgemm_mma.cu b/hgemm/hgemm_mma.cu index 678c40d2..7a3b5fe8 100644 --- a/hgemm/hgemm_mma.cu +++ b/hgemm/hgemm_mma.cu @@ -28,16 +28,306 @@ using namespace nvcuda; // ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes. #define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) #define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) -// Support A and B matrix with row-major inorder to compare with the kernels using CUDA Cores in -// hgemm.cu and hgemm_async.cu. - +#define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) +#define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) +#define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) +#define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) +#define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) +#define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) +#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1)) HOST_DEVICE_INLINE int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); } -// only 1 warp per block(32 threads), m16n16k16. A, B, C: all row_major. -template