Includes the following (a PyTorch reference sketch of the computation follows the list):
- rms_norm_f32_kernel
- rms_norm_f32x4_kernel
- rms_norm_f16_f16_kernel
- rms_norm_f16x2_f16_kernel
- rms_norm_f16x8_f16_kernel
- rms_norm_f16x8_f32_kernel
- rms_norm_f16x8_pack_f16_kernel
- rms_norm_f16x8_pack_f32_kernel
- rms_norm_f16_f32_kernel
- PyTorch bindings
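
All of the variants listed above presumably compute the same row-wise RMS normalization, with the suffixes encoding the vectorization width (`x2`/`x4`/`x8`, `pack` for packed loads) and the precision used for the reduction (trailing `_f16` vs. `_f32`). As a point of reference, here is a minimal PyTorch sketch of that computation, assuming a simple scalar gain `g` (the actual kernels may handle the gain differently):

```python
import torch

def rms_norm_ref(x: torch.Tensor, g: float = 1.0, eps: float = 1e-5) -> torch.Tensor:
    # Row-wise RMS normalization over the last dimension (size K):
    #   y = x / sqrt(mean(x^2) + eps) * g
    # The mean is accumulated in fp32 even for fp16 inputs, mirroring the *_f32 reductions.
    rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * rms * g).to(x.dtype)

# Same shape as the first benchmark case below (N=4096, K=512).
x = torch.randn(4096, 512, device="cuda", dtype=torch.float16)
y = rms_norm_ref(x)
```

The `out_f32_th` / `out_f16_th` rows in the output below appear to be the corresponding PyTorch baseline that the hand-written kernels are compared against.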
```bash
# Build only for the Ada architecture; if TORCH_CUDA_ARCH_LIST is not set, all
# architectures (Volta, Ampere, Ada, Hopper, ...) are compiled by default, which takes much longer.
export TORCH_CUDA_ARCH_LIST=Ada
python3 rms_norm.py
```
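
`rms_norm.py` presumably JIT-builds the CUDA sources with `torch.utils.cpp_extension.load` and times each kernel with CUDA events. A minimal sketch of that flow is below; the source file name (`rms_norm.cu`), the exported function name (`rms_norm_f32`), and its `(x, y)` signature are assumptions for illustration, not the repo's actual API:

```python
import torch
from torch.utils.cpp_extension import load

# JIT-compile the extension (file and function names here are illustrative assumptions).
lib = load(name="rms_norm_lib",
           sources=["rms_norm.cu"],
           extra_cuda_cflags=["-O3", "--use_fast_math"],
           verbose=True)

N, K = 4096, 1024
x = torch.randn(N, K, device="cuda", dtype=torch.float32)
y = torch.empty_like(x)

# CUDA-event timing, similar in spirit to the per-kernel times printed below.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for _ in range(10):              # warmup
    lib.rms_norm_f32(x, y)       # assumed binding name and signature
torch.cuda.synchronize()

iters = 1000
start.record()
for _ in range(iters):
    lib.rms_norm_f32(x, y)
end.record()
torch.cuda.synchronize()
print(f"out_f32: time:{start.elapsed_time(end) / iters:.8f}ms")
```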
Output:

```
-------------------------------------------------------------------------------------
N=4096, K=512
out_f32: ['0.04078517 ', '0.74503314 ', '0.87149841 '], time:0.01198173ms
out_f32x4: ['0.04078517 ', '0.74503314 ', '0.87149841 '], time:0.00517488ms
out_f32_th: ['0.04078539 ', '0.74503714 ', '0.87150306 '], time:0.04351616ms
-------------------------------------------------------------------------------------
out_f16f16: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.01200986ms
out_f16f32: ['0.040802 ', '0.74511719 ', '0.87109375 '], time:0.01180410ms
out_f16x2f16: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.00670171ms
out_f16x8f16: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.00411820ms
out_f16x8f32: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.00411677ms
out_f16x8packf16: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.00411630ms
out_f16x8packf32: ['0.040802 ', '0.74511719 ', '0.87109375 '], time:0.00399137ms
out_f16_th: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.04383564ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=4096, K=1024
out_f32: ['-0.76329279 ', '-0.62111992 ', '-1.45531178 '], time:0.03398657ms
out_f32x4: ['-0.76329279 ', '-0.62111992 ', '-1.45531178 '], time:0.00862885ms
out_f32_th: ['-0.76329684 ', '-0.62112319 ', '-1.4553194 '], time:0.04355550ms
-------------------------------------------------------------------------------------
out_f16f16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.03526235ms
out_f16f32: ['-0.76318359 ', '-0.62109375 ', '-1.45605469 '], time:0.03302288ms
out_f16x2f16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.01215649ms
out_f16x8f16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.00632071ms
out_f16x8f32: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.00631690ms
out_f16x8packf16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.00528240ms
out_f16x8packf32: ['-0.76318359 ', '-0.62109375 ', '-1.45605469 '], time:0.00519514ms
out_f16_th: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.04399920ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=4096, K=2048
out_f32x4: ['-0.17984088 ', '-1.76387513 ', '-0.32782754 '], time:0.01650691ms
out_f32_th: ['-0.17984176 ', '-1.76388371 ', '-0.32782915 '], time:0.09451318ms
-------------------------------------------------------------------------------------
out_f16x2f16: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.03497124ms
out_f16x8f16: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.01254177ms
out_f16x8f32: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.01253581ms
out_f16x8packf16: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.00903535ms
out_f16x8packf32: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.00894380ms
out_f16_th: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.04889655ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=4096, K=4096
out_f32x4: ['-1.14100003 ', '-0.71529448 ', '2.26544118 '], time:0.18783689ms
out_f32_th: ['-1.14100587 ', '-0.71529812 ', '2.26545286 '], time:0.52556086ms
-------------------------------------------------------------------------------------
out_f16x8f16: ['-1.140625 ', '-0.71484375 ', '2.26367188 '], time:0.03605795ms
out_f16x8f32: ['-1.140625 ', '-0.71484375 ', '2.26367188 '], time:0.03605533ms
out_f16x8packf16: ['-1.140625 ', '-0.71484375 ', '2.26367188 '], time:0.01718473ms
out_f16x8packf32: ['-1.140625 ', '-0.71533203 ', '2.26367188 '], time:0.01735568ms
out_f16_th: ['-1.140625 ', '-0.71484375 ', '2.26367188 '], time:0.11150384ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=4096, K=8192
out_f16x8f16: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.19292974ms
out_f16x8f32: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.19298863ms
out_f16x8packf16: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.18497562ms
out_f16x8packf32: ['-0.40844727 ', '-0.14294434 ', '-0.93310547 '], time:0.18479729ms
out_f16_th: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.59557104ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
N=8192, K=8192
out_f16x8f16: ['-0.35253906 ', '-1.04101562 ', '0.17358398 '], time:0.38169765ms
out_f16x8f32: ['-0.35253906 ', '-1.04101562 ', '0.17358398 '], time:0.38264203ms
out_f16x8packf16: ['-0.35253906 ', '-1.04101562 ', '0.17358398 '], time:0.40794849ms
out_f16x8packf32: ['-0.35229492 ', '-1.04003906 ', '0.17346191 '], time:0.40747380ms
out_f16_th: ['-0.35229492 ', '-1.04003906 ', '0.17346191 '], time:1.35807014ms
-------------------------------------------------------------------------------------
```