Skip to content

Latest commit

 

History

History

layer-norm

Folders and files

Name
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 

LayerNorm

0x00 说明

包含以下内容：

  • layer_norm_f32_kernel
  • layer_norm_f32x4_kernel
  • layer_norm_f16_f16_kernel
  • layer_norm_f16x2_f16_kernel
  • layer_norm_f16x8_f16_kernel
  • layer_norm_f16x8_pack_f16_kernel
  • layer_norm_f16x8_pack_f32_kernel
  • layer_norm_f16_f32_kernel
  • PyTorch bindings

测试

# 只编译 Ada 架构以缩短编译时间；若不指定 TORCH_CUDA_ARCH_LIST，则默认编译所有架构（Volta, Ampere, Ada, Hopper, ...），耗时较长
export TORCH_CUDA_ARCH_LIST=Ada
python3 layer_norm.py

输出：

-------------------------------------------------------------------------------------
                                        N=4096, K=512
-------------------------------------------------------------------------------------
          out_f32: ['-0.95119929 ', '0.65728813  ', '-0.27701864 '], time:0.01898599ms
        out_f32x4: ['-0.95119929 ', '0.65728813  ', '-0.27701864 '], time:0.00600958ms
       out_f32_th: ['-0.95026982 ', '0.65664589  ', '-0.27674797 '], time:0.07345414ms
-------------------------------------------------------------------------------------
       out_f16f16: ['-0.95068359 ', '0.65722656  ', '-0.27709961 '], time:0.01866651ms
       out_f16f32: ['-0.95117188 ', '0.65722656  ', '-0.27709961 '], time:0.01897073ms
     out_f16x2f16: ['-0.95068359 ', '0.65722656  ', '-0.27709961 '], time:0.00952697ms
     out_f16x8f16: ['-0.95068359 ', '0.65722656  ', '-0.27709961 '], time:0.00470805ms
 out_f16x8packf16: ['-0.95117188 ', '0.65673828  ', '-0.27709961 '], time:0.00427437ms
 out_f16x8packf32: ['-0.95117188 ', '0.65722656  ', '-0.27709961 '], time:0.00418639ms
       out_f16_th: ['-0.94970703 ', '0.65673828  ', '-0.27685547 '], time:0.07291913ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        N=4096, K=1024
-------------------------------------------------------------------------------------
          out_f32: ['0.81839228  ', '0.36616057  ', '-1.71588480 '], time:0.05122757ms
        out_f32x4: ['0.81839228  ', '0.36616057  ', '-1.71588480 '], time:0.01071095ms
       out_f32_th: ['0.81799269  ', '0.36598179  ', '-1.71504688 '], time:0.07267237ms
-------------------------------------------------------------------------------------
       out_f16f16: ['0.81835938  ', '0.36596680  ', '-1.71484375 '], time:0.05317926ms
       out_f16f32: ['0.81835938  ', '0.36621094  ', '-1.71582031 '], time:0.05062103ms
     out_f16x2f16: ['0.81884766  ', '0.36621094  ', '-1.71679688 '], time:0.01855445ms
     out_f16x8f16: ['0.81884766  ', '0.36621094  ', '-1.71679688 '], time:0.00742888ms
 out_f16x8packf16: ['0.81884766  ', '0.36621094  ', '-1.71679688 '], time:0.00645399ms
 out_f16x8packf32: ['0.81835938  ', '0.36621094  ', '-1.71582031 '], time:0.00634456ms
       out_f16_th: ['0.81835938  ', '0.36596680  ', '-1.71582031 '], time:0.07386255ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        N=4096, K=2048
-------------------------------------------------------------------------------------
        out_f32x4: ['-0.65341073 ', '0.10270299  ', '-0.06597849 '], time:0.02200651ms
       out_f32_th: ['-0.65325129 ', '0.10267793  ', '-0.06596238 '], time:0.12027287ms
-------------------------------------------------------------------------------------
     out_f16x2f16: ['-0.65332031 ', '0.10266113  ', '-0.06591797 '], time:0.05352354ms
     out_f16x8f16: ['-0.65380859 ', '0.10272217  ', '-0.06597900 '], time:0.01377678ms
 out_f16x8packf16: ['-0.65332031 ', '0.10266113  ', '-0.06591797 '], time:0.01154637ms
 out_f16x8packf32: ['-0.65332031 ', '0.10272217  ', '-0.06597900 '], time:0.01166582ms
       out_f16_th: ['-0.65380859 ', '0.10272217  ', '-0.06597900 '], time:0.08442783ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        N=4096, K=4096
-------------------------------------------------------------------------------------
        out_f32x4: ['2.38733387  ', '-0.03023042 ', '0.66022825  '], time:0.18884635ms
       out_f32_th: ['2.38704205  ', '-0.03022672 ', '0.66014749  '], time:0.77852798ms
-------------------------------------------------------------------------------------
     out_f16x8f16: ['2.38671875  ', '-0.03024292 ', '0.66015625  '], time:0.03325391ms
 out_f16x8packf16: ['2.38671875  ', '-0.03024292 ', '0.66015625  '], time:0.02401376ms
 out_f16x8packf32: ['2.38671875  ', '-0.03021240 ', '0.66064453  '], time:0.02381730ms
       out_f16_th: ['2.38671875  ', '-0.03021240 ', '0.66015625  '], time:0.17546010ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        N=4096, K=8192
-------------------------------------------------------------------------------------
     out_f16x8f16: ['0.15905762  ', '1.06542969  ', '-0.19396973 '], time:0.19306803ms
 out_f16x8packf16: ['0.15905762  ', '1.06542969  ', '-0.19396973 '], time:0.18665886ms
 out_f16x8packf32: ['0.15905762  ', '1.06542969  ', '-0.19396973 '], time:0.18657684ms
       out_f16_th: ['0.15905762  ', '1.06542969  ', '-0.19396973 '], time:0.84462571ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
                                        N=8192, K=8192
-------------------------------------------------------------------------------------
     out_f16x8f16: ['-0.53662109 ', '2.359375    ', '0.78027344  '], time:0.38366604ms
 out_f16x8packf16: ['-0.53662109 ', '2.359375    ', '0.78027344  '], time:0.40789628ms
 out_f16x8packf32: ['-0.53613281 ', '2.359375    ', '0.78027344  '], time:0.40818143ms
       out_f16_th: ['-0.53662109 ', '2.359375    ', '0.78027344  '], time:1.99523735ms
-------------------------------------------------------------------------------------