This example contains the following:
- gelu_f32_kernel
- gelu_f32x4_kernel (float4-vectorized version; see the sketch after this list)
- gelu_f16_kernel
- gelu_f16x2_kernel (half2-vectorized version)
- gelu_f16x8_kernel (unpacked version)
- gelu_f16x8_pack_kernel (packed version)
- PyTorch bindings
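
As a reference point, here is a minimal sketch of what the float4-vectorized path and its PyTorch binding can look like. Helper names, launch configuration, and the wrapper signature are illustrative rather than copied from gelu.cu; the point is that each thread moves four floats per 128-bit transaction:

```c++
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <math.h>

// tanh-approximation GELU for a single float:
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
__device__ __forceinline__ float gelu_tanh_approx(float x) {
  float inner = 0.7978845608028654f * (x + 0.044715f * x * x * x);
  return 0.5f * x * (1.0f + tanhf(inner));
}

// float4-vectorized kernel: each thread handles 4 contiguous floats with a
// single 128-bit load and store (assumes N is a multiple of 4 and x/y are
// 16-byte aligned, which holds for the benchmark shapes below).
__global__ void gelu_f32x4_kernel(const float* x, float* y, int N) {
  int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx + 3 < N) {
    float4 v = reinterpret_cast<const float4*>(x + idx)[0];
    float4 r;
    r.x = gelu_tanh_approx(v.x);
    r.y = gelu_tanh_approx(v.y);
    r.z = gelu_tanh_approx(v.z);
    r.w = gelu_tanh_approx(v.w);
    reinterpret_cast<float4*>(y + idx)[0] = r;
  }
}

// Host-side wrapper exposed to Python via the PyTorch extension API
// (signature and launch configuration are illustrative).
void gelu_f32x4(torch::Tensor x, torch::Tensor y) {
  const int N = x.numel();
  dim3 block(256);
  dim3 grid((N / 4 + block.x - 1) / block.x);
  gelu_f32x4_kernel<<<grid, block>>>(x.data_ptr<float>(), y.data_ptr<float>(), N);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gelu_f32x4", &gelu_f32x4, "GELU f32x4 (CUDA)");
}
```

gelu_f16x8_pack_kernel applies the same idea to fp16: one 128-bit transaction covers eight halves at once, which is likely why the packed variant is consistently among the fastest custom kernels in the tables below.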
For the half-precision (half) GELU, CUDA's half-precision math library provides no tanh intrinsic, so tanh has to be emulated with hexp, which introduces a noticeably larger error (this could perhaps be addressed at the assembly/PTX level). PyTorch, by contrast, avoids the problem by converting the data type. Verifying this is straightforward: modify the f16 code in the .cu file to add an explicit cast to float:
```c++
y[idx] = HALF_GELU_OPS(__half2float(v));         // line 96
reg_y.x = HALF_GELU_OPS(__half2float(reg_x.x));  // line 109
reg_y.y = HALF_GELU_OPS(__half2float(reg_x.y));  // line 110
```
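
For context on where the error comes from, here is a sketch of a pure-fp16 GELU that emulates tanh with hexp. The actual HALF_GELU_OPS in gelu.cu may be written differently; the names below are illustrative. Every intermediate hexp, multiply, and divide rounds to fp16, which is exactly what the float cast above avoids:

```c++
#include <cuda_fp16.h>

// CUDA's fp16 math library has no tanh intrinsic, so it is emulated via hexp:
// tanh(t) = (exp(2t) - 1) / (exp(2t) + 1)
__device__ __forceinline__ half half_tanh_via_hexp(half t) {
  half e2t = hexp(__hmul(__float2half(2.0f), t));
  return __hdiv(__hsub(e2t, __float2half(1.0f)),
                __hadd(e2t, __float2half(1.0f)));
}

// tanh-approximation GELU done entirely in fp16 arithmetic.
__device__ __forceinline__ half half_gelu_tanh_approx(half v) {
  half x3 = __hmul(v, __hmul(v, v));
  half inner = __hmul(__float2half(0.7978845608f),
                      __hadd(v, __hmul(__float2half(0.044715f), x3)));
  return __hmul(__hmul(__float2half(0.5f), v),
                __hadd(__float2half(1.0f), half_tanh_via_hexp(inner)));
}
```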
Test results are shown below. Not every input triggers the error, so a case that does was chosen; after the modification, out_f16 and out_f16x2 match torch exactly:
-------------------------------------------------------------------------------------
S=2048, K=4096
out_f32: [-0.08196318, -0.1613517], time:0.13425708ms
out_f32x4: [-0.08196318, -0.1613517], time:0.14128804ms
out_f32_th: [-0.08196313, -0.1613517], time:0.08195782ms
-------------------------------------------------------------------------------------
out_f16: [-0.08197021, -0.16137695], time:0.12120271ms
out_f16x2: [-0.08197021, -0.16137695], time:0.12122369ms
out_f16x8: [-0.08251953, -0.16137695], time:0.04196978ms
out_f16x8pack: [-0.08251953, -0.16137695], time:0.04215288ms
out_f16_th: [-0.08197021, -0.16137695], time:0.04287958ms
-------------------------------------------------------------------------------------
In addition, mirroring torch, the float GELU is implemented with both the tanh and none (exact, erf-based) approximations; switching the macro in gelu.cu selects which version gets compiled.
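
A sketch of what that compile-time switch can look like (the macro name and exact expressions in gelu.cu may differ):

```c++
#include <math.h>

// Mirrors torch.nn.functional.gelu(approximate="none" / "tanh").
// Define GELU_TANH_APPROX at compile time to select the tanh version.
#ifdef GELU_TANH_APPROX
  // approximate="tanh": 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  #define GELU_OPS(x) \
    (0.5f * (x) * (1.0f + tanhf(0.7978845608f * ((x) + 0.044715f * (x) * (x) * (x)))))
#else
  // approximate="none": exact GELU, x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
  #define GELU_OPS(x) \
    (0.5f * (x) * (1.0f + erff((x) * 0.70710678118f)))
#endif
```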
```bash
# Test only the Ada architecture; if unset, all architectures (Volta, Ampere, Ada, Hopper, ...)
# are compiled by default, which takes much longer.
export TORCH_CUDA_ARCH_LIST=Ada
python3 gelu.py
```
Output (without the type conversion, so the half-precision error is visible):
-------------------------------------------------------------------------------------
S=1024, K=1024
out_f32: [-0.13358943, -0.06881647], time:0.01621890ms
out_f32x4: [-0.13358943, -0.06881647], time:0.01278400ms
out_f32_th: [-0.13358943, -0.06881647], time:0.00897789ms
-------------------------------------------------------------------------------------
out_f16: [-0.13378906, -0.06884766], time:0.00663781ms
out_f16x2: [-0.13378906, -0.06884766], time:0.00366306ms
out_f16x8: [-0.13378906, -0.06884766], time:0.00343323ms
out_f16x8pack: [-0.13378906, -0.06884766], time:0.00331473ms
out_f16_th: [-0.13354492, -0.06884766], time:0.00907278ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=1024, K=2048
out_f32: [1.38783729, -0.06707606], time:0.02223682ms
out_f32x4: [1.38783729, -0.06707606], time:0.02367806ms
out_f32_th: [1.38783729, -0.06707606], time:0.00959325ms
-------------------------------------------------------------------------------------
out_f16: [1.38769531, -0.06713867], time:0.00834370ms
out_f16x2: [1.38769531, -0.06713867], time:0.00784707ms
out_f16x8: [1.38769531, -0.06713867], time:0.00499964ms
out_f16x8pack: [1.38769531, -0.06713867], time:0.00461078ms
out_f16_th: [1.38769531, -0.06707764], time:0.00895357ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=1024, K=4096
out_f32: [0.47386399, 0.05760021], time:0.04273629ms
out_f32x4: [0.47386399, 0.05760021], time:0.05011940ms
out_f32_th: [0.47386405, 0.05760022], time:0.00933146ms
-------------------------------------------------------------------------------------
out_f16: [0.47387695, 0.05761719], time:0.01495123ms
out_f16x2: [0.47387695, 0.05761719], time:0.01039743ms
out_f16x8: [0.47387695, 0.05761719], time:0.00936055ms
out_f16x8pack: [0.47387695, 0.05761719], time:0.00845838ms
out_f16_th: [0.47387695, 0.05758667], time:0.00918818ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=1024
out_f32: [1.3562144, 0.40408486], time:0.03009892ms
out_f32x4: [1.3562144, 0.40408486], time:0.02289677ms
out_f32_th: [1.3562144, 0.40408486], time:0.00921512ms
-------------------------------------------------------------------------------------
out_f16: [1.35644531, 0.40405273], time:0.01173806ms
out_f16x2: [1.35644531, 0.40405273], time:0.00565076ms
out_f16x8: [1.35644531, 0.40405273], time:0.00502610ms
out_f16x8pack: [1.35644531, 0.40405273], time:0.00457048ms
out_f16_th: [1.35644531, 0.40429688], time:0.00904894ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=2048
out_f32: [-0.16498716, -0.15077244], time:0.04273534ms
out_f32x4: [-0.16498716, -0.15077244], time:0.04386163ms
out_f32_th: [-0.16498716, -0.15077244], time:0.00913596ms
-------------------------------------------------------------------------------------
out_f16: [-0.16516113, -0.15075684], time:0.01495862ms
out_f16x2: [-0.16516113, -0.15075684], time:0.01407337ms
out_f16x8: [-0.16516113, -0.15075684], time:0.00796247ms
out_f16x8pack: [-0.16516113, -0.15075684], time:0.00734925ms
out_f16_th: [-0.16503906, -0.15075684], time:0.00917435ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=4096
out_f32: [-0.03888749, 0.32139146], time:0.08363676ms
out_f32x4: [-0.03888749, 0.32139146], time:0.09505510ms
out_f32_th: [-0.03888749, 0.32139146], time:0.04022837ms
-------------------------------------------------------------------------------------
out_f16: [-0.03887939, 0.3215332], time:0.02813959ms
out_f16x2: [-0.03887939, 0.3215332], time:0.01906514ms
out_f16x8: [-0.03887939, 0.3215332], time:0.01664281ms
out_f16x8pack: [-0.03887939, 0.3215332], time:0.01474833ms
out_f16_th: [-0.03887939, 0.32128906], time:0.01357365ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=1024
out_f32: [-0.13875209, 1.08477271], time:0.05790567ms
out_f32x4: [-0.13875209, 1.08477271], time:0.04317236ms
out_f32_th: [-0.13875209, 1.08477271], time:0.00910425ms
-------------------------------------------------------------------------------------
out_f16: [-0.13903809, 1.08496094], time:0.02198315ms
out_f16x2: [-0.13903809, 1.08496094], time:0.00964355ms
out_f16x8: [-0.13903809, 1.08496094], time:0.00780869ms
out_f16x8pack: [-0.13903809, 1.08496094], time:0.00729132ms
out_f16_th: [-0.13879395, 1.08496094], time:0.00926042ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=2048
out_f32: [0.82045084, -0.0894338], time:0.08363843ms
out_f32x4: [0.82045084, -0.0894338], time:0.08431888ms
out_f32_th: [0.82045084, -0.0894338], time:0.03837347ms
-------------------------------------------------------------------------------------
out_f16: [0.8203125, -0.08947754], time:0.02813506ms
out_f16x2: [0.8203125, -0.08947754], time:0.02643061ms
out_f16x8: [0.8203125, -0.08947754], time:0.01383305ms
out_f16x8pack: [0.8203125, -0.08947754], time:0.01273918ms
out_f16_th: [0.82080078, -0.0894165], time:0.01357722ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=4096
out_f32: [-0.06997654, -0.16092129], time:0.19113564ms
out_f32x4: [-0.06997654, -0.16092129], time:0.20371628ms
out_f32_th: [-0.06997654, -0.16092129], time:0.20496607ms
-------------------------------------------------------------------------------------
out_f16: [-0.07012939, -0.16113281], time:0.05451322ms
out_f16x2: [-0.07012939, -0.16113281], time:0.03633785ms
out_f16x8: [-0.07012939, -0.16113281], time:0.03115463ms
out_f16x8pack: [-0.07012939, -0.16113281], time:0.02735877ms
out_f16_th: [-0.07000732, -0.16088867], time:0.03889561ms
-------------------------------------------------------------------------------------