Includes the following kernels (a minimal sketch of the basic pattern follows the list):
- layer_norm_f32_kernel
- layer_norm_f32x4_kernel
- layer_norm_f16_f16_kernel
- layer_norm_f16x2_f16_kernel
- layer_norm_f16x8_f16_kernel
- layer_norm_f16x8_pack_f16_kernel
- layer_norm_f16x8_pack_f32_kernel
- layer_norm_f16_f32_kernel
- PyTorch bindings
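
All of these kernels compute LayerNorm row by row. As a rough illustration of the pattern, here is a minimal sketch of a scalar f32 variant (my own simplification, not the repository's actual implementation): one thread block per row, one element per thread, a shared-memory block reduction for the mean and variance, and scalar gamma/beta. Signatures, reduction strategy (e.g. warp shuffles), and edge handling in the real kernels may differ.

```cuda
// Minimal sketch, NOT the repository's actual implementation.
// Assumes K <= NUM_THREADS and a launch of <<<N, NUM_THREADS>>>.
#include <cuda_runtime.h>

template <int NUM_THREADS = 256>
__global__ void layer_norm_f32_sketch(const float *x, float *y, float g,
                                      float b, int N, int K, float epsilon) {
  __shared__ float s_buf[NUM_THREADS]; // reduction scratch
  __shared__ float s_mean;             // row mean
  __shared__ float s_rstd;             // 1 / sqrt(var + eps)

  const int row = blockIdx.x;  // one block per row
  const int tid = threadIdx.x; // one element of the row per thread
  const int idx = row * K + tid;
  const float v = (row < N && tid < K) ? x[idx] : 0.0f;

  // 1) block reduction: row sum -> mean
  s_buf[tid] = v;
  __syncthreads();
  for (int stride = NUM_THREADS / 2; stride > 0; stride >>= 1) {
    if (tid < stride) s_buf[tid] += s_buf[tid + stride];
    __syncthreads();
  }
  if (tid == 0) s_mean = s_buf[0] / (float)K;
  __syncthreads();

  // 2) block reduction: sum of squared deviations -> 1/std
  const float d = v - s_mean;
  s_buf[tid] = (tid < K) ? d * d : 0.0f;
  __syncthreads();
  for (int stride = NUM_THREADS / 2; stride > 0; stride >>= 1) {
    if (tid < stride) s_buf[tid] += s_buf[tid + stride];
    __syncthreads();
  }
  if (tid == 0) s_rstd = rsqrtf(s_buf[0] / (float)K + epsilon);
  __syncthreads();

  // 3) normalize, then apply scale and shift
  if (row < N && tid < K) y[idx] = d * s_rstd * g + b;
}
```

The x2/x4/x8 and `pack` variants follow the same structure but have each thread process 2-8 elements per memory transaction (e.g. `float4` or packed 128-bit `half` loads), which is what drives the speedups visible in the benchmark output below.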
```bash
# Test only the Ada architecture. If TORCH_CUDA_ARCH_LIST is not set, all
# architectures (Volta, Ampere, Ada, Hopper, ...) are compiled by default, which takes much longer.
export TORCH_CUDA_ARCH_LIST=Ada
python3 layer_norm.py
```
Output:

```bash
-------------------------------------------------------------------------------------
N=4096, K=512
-------------------------------------------------------------------------------------
out_f32: ['-0.95119929 ', '0.65728813 ', '-0.27701864 '], time:0.01898599ms
out_f32x4: ['-0.95119929 ', '0.65728813 ', '-0.27701864 '], time:0.00600958ms
out_f32_th: ['-0.95026982 ', '0.65664589 ', '-0.27674797 '], time:0.07345414ms
-------------------------------------------------------------------------------------
out_f16f16: ['-0.95068359 ', '0.65722656 ', '-0.27709961 '], time:0.01866651ms
out_f16f32: ['-0.95117188 ', '0.65722656 ', '-0.27709961 '], time:0.01897073ms
out_f16x2f16: ['-0.95068359 ', '0.65722656 ', '-0.27709961 '], time:0.00952697ms
out_f16x8f16: ['-0.95068359 ', '0.65722656 ', '-0.27709961 '], time:0.00470805ms
out_f16x8packf16: ['-0.95117188 ', '0.65673828 ', '-0.27709961 '], time:0.00427437ms
out_f16x8packf32: ['-0.95117188 ', '0.65722656 ', '-0.27709961 '], time:0.00418639ms
out_f16_th: ['-0.94970703 ', '0.65673828 ', '-0.27685547 '], time:0.07291913ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=4096, K=1024
-------------------------------------------------------------------------------------
out_f32: ['0.81839228 ', '0.36616057 ', '-1.71588480 '], time:0.05122757ms
out_f32x4: ['0.81839228 ', '0.36616057 ', '-1.71588480 '], time:0.01071095ms
out_f32_th: ['0.81799269 ', '0.36598179 ', '-1.71504688 '], time:0.07267237ms
-------------------------------------------------------------------------------------
out_f16f16: ['0.81835938 ', '0.36596680 ', '-1.71484375 '], time:0.05317926ms
out_f16f32: ['0.81835938 ', '0.36621094 ', '-1.71582031 '], time:0.05062103ms
out_f16x2f16: ['0.81884766 ', '0.36621094 ', '-1.71679688 '], time:0.01855445ms
out_f16x8f16: ['0.81884766 ', '0.36621094 ', '-1.71679688 '], time:0.00742888ms
out_f16x8packf16: ['0.81884766 ', '0.36621094 ', '-1.71679688 '], time:0.00645399ms
out_f16x8packf32: ['0.81835938 ', '0.36621094 ', '-1.71582031 '], time:0.00634456ms
out_f16_th: ['0.81835938 ', '0.36596680 ', '-1.71582031 '], time:0.07386255ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=4096, K=2048
-------------------------------------------------------------------------------------
out_f32x4: ['-0.65341073 ', '0.10270299 ', '-0.06597849 '], time:0.02200651ms
out_f32_th: ['-0.65325129 ', '0.10267793 ', '-0.06596238 '], time:0.12027287ms
-------------------------------------------------------------------------------------
out_f16x2f16: ['-0.65332031 ', '0.10266113 ', '-0.06591797 '], time:0.05352354ms
out_f16x8f16: ['-0.65380859 ', '0.10272217 ', '-0.06597900 '], time:0.01377678ms
out_f16x8packf16: ['-0.65332031 ', '0.10266113 ', '-0.06591797 '], time:0.01154637ms
out_f16x8packf32: ['-0.65332031 ', '0.10272217 ', '-0.06597900 '], time:0.01166582ms
out_f16_th: ['-0.65380859 ', '0.10272217 ', '-0.06597900 '], time:0.08442783ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=4096, K=4096
-------------------------------------------------------------------------------------
out_f32x4: ['2.38733387 ', '-0.03023042 ', '0.66022825 '], time:0.18884635ms
out_f32_th: ['2.38704205 ', '-0.03022672 ', '0.66014749 '], time:0.77852798ms
-------------------------------------------------------------------------------------
out_f16x8f16: ['2.38671875 ', '-0.03024292 ', '0.66015625 '], time:0.03325391ms
out_f16x8packf16: ['2.38671875 ', '-0.03024292 ', '0.66015625 '], time:0.02401376ms
out_f16x8packf32: ['2.38671875 ', '-0.03021240 ', '0.66064453 '], time:0.02381730ms
out_f16_th: ['2.38671875 ', '-0.03021240 ', '0.66015625 '], time:0.17546010ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=4096, K=8192
-------------------------------------------------------------------------------------
out_f16x8f16: ['0.15905762 ', '1.06542969 ', '-0.19396973 '], time:0.19306803ms
out_f16x8packf16: ['0.15905762 ', '1.06542969 ', '-0.19396973 '], time:0.18665886ms
out_f16x8packf32: ['0.15905762 ', '1.06542969 ', '-0.19396973 '], time:0.18657684ms
out_f16_th: ['0.15905762 ', '1.06542969 ', '-0.19396973 '], time:0.84462571ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=8192, K=8192
-------------------------------------------------------------------------------------
out_f16x8f16: ['-0.53662109 ', '2.359375 ', '0.78027344 '], time:0.38366604ms
out_f16x8packf16: ['-0.53662109 ', '2.359375 ', '0.78027344 '], time:0.40789628ms
out_f16x8packf32: ['-0.53613281 ', '2.359375 ', '0.78027344 '], time:0.40818143ms
out_f16_th: ['-0.53662109 ', '2.359375 ', '0.78027344 '], time:1.99523735ms
-------------------------------------------------------------------------------------
```