本示例包含以下内容:
- dot_prod_f32_f32_kernel(fp32 输入,使用 fp32 累加)
- dot_prod_f32x4_f32_kernel(float4 向量化版本,使用 fp32 累加)
- dot_prod_f16_f32_kernel(fp16 输入,使用 fp32 累加)
- dot_prod_f16x2_f32_kernel(fp16 向量化版本,每次处理 2 个元素,使用 fp32 累加)
- dot_prod_f16x8_pack_f32_kernel(fp16 向量化版本,每次打包(pack)处理 8 个元素,使用 fp32 累加)
- PyTorch bindings
# Restrict the CUDA extension build to the Ada architecture only. When
# TORCH_CUDA_ARCH_LIST is left unset, the build defaults to compiling for
# every architecture (Volta, Ampere, Ada, Hopper, ...), which takes much
# longer. The variable is exported so the python3 child process sees it.
TORCH_CUDA_ARCH_LIST=Ada
export TORCH_CUDA_ARCH_LIST
python3 dot_product.py
输出:
--------------------------------------------------------------------------------
S=1024, K=1024
out_f32f32: -332.80715942 , time:0.01124835ms
out_f32x4f32: -332.80645752 , time:0.01134133ms
out_f32f32_th: -332.80691528 , time:0.01127815ms
--------------------------------------------------------------------------------
out_f16f32: -333.19879150 , time:0.01110196ms
out_f16x2f32: -333.44345093 , time:0.01122665ms
out_f16x8packf32: -333.64193726 , time:0.01099825ms
out_f16f16_th: -332.75000000 , time:0.01118803ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=1024, K=2048
out_f32f32: -142.86260986 , time:0.01630998ms
out_f32x4f32: -142.86064148 , time:0.01116729ms
out_f32f32_th: -142.86035156 , time:0.01143432ms
--------------------------------------------------------------------------------
out_f16f32: -143.31562805 , time:0.01554394ms
out_f16x2f32: -142.84217834 , time:0.01099968ms
out_f16x8packf32: -143.60864258 , time:0.01112890ms
out_f16f16_th: -143.00000000 , time:0.01136470ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=1024, K=4096
out_f32f32: -3116.77270508 , time:0.02791572ms
out_f32x4f32: -3116.77929688 , time:0.01236105ms
out_f32f32_th: -3116.77709961 , time:0.01418424ms
--------------------------------------------------------------------------------
out_f16f32: -3118.24951172 , time:0.02777576ms
out_f16x2f32: -3118.13208008 , time:0.01556611ms
out_f16x8packf32: -3118.15527344 , time:0.01114249ms
out_f16f16_th: -3118.00000000 , time:0.01161337ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=2048, K=1024
out_f32f32: -1549.67492676 , time:0.01551032ms
out_f32x4f32: -1549.67419434 , time:0.01115298ms
out_f32f32_th: -1549.67382812 , time:0.01146293ms
--------------------------------------------------------------------------------
out_f16f32: -1549.45434570 , time:0.01545978ms
out_f16x2f32: -1549.04064941 , time:0.01100898ms
out_f16x8packf32: -1549.04748535 , time:0.01111746ms
out_f16f16_th: -1550.00000000 , time:0.01136041ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=2048, K=2048
out_f32f32: -4219.10205078 , time:0.02766943ms
out_f32x4f32: -4219.10009766 , time:0.01223850ms
out_f32f32_th: -4219.10693359 , time:0.01404524ms
--------------------------------------------------------------------------------
out_f16f32: -4218.69335938 , time:0.02764416ms
out_f16x2f32: -4219.42822266 , time:0.01547956ms
out_f16x8packf32: -4219.27929688 , time:0.01113629ms
out_f16f16_th: -4220.00000000 , time:0.01157045ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=2048, K=4096
out_f32f32: -2869.79296875 , time:0.05231595ms
out_f32x4f32: -2869.78149414 , time:0.02043509ms
out_f32f32_th: -2869.78759766 , time:0.02305937ms
--------------------------------------------------------------------------------
out_f16f32: -2870.39965820 , time:0.05218816ms
out_f16x2f32: -2871.60571289 , time:0.02775407ms
out_f16x8packf32: -2870.28857422 , time:0.01228762ms
out_f16f16_th: -2870.00000000 , time:0.01509762ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=4096, K=1024
out_f32f32: -1801.87890625 , time:0.02767515ms
out_f32x4f32: -1801.88061523 , time:0.01203156ms
out_f32f32_th: -1801.88317871 , time:0.01396847ms
--------------------------------------------------------------------------------
out_f16f32: -1801.71777344 , time:0.02766609ms
out_f16x2f32: -1801.05224609 , time:0.01547670ms
out_f16x8packf32: -1799.91137695 , time:0.01112270ms
out_f16f16_th: -1801.00000000 , time:0.01154137ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=4096, K=2048
out_f32f32: 643.72991943 , time:0.05231857ms
out_f32x4f32: 643.72863770 , time:0.02044320ms
out_f32f32_th: 643.73022461 , time:0.02305865ms
--------------------------------------------------------------------------------
out_f16f32: 644.73352051 , time:0.05214262ms
out_f16x2f32: 644.69067383 , time:0.02766657ms
out_f16x8packf32: 644.65740967 , time:0.01228309ms
out_f16f16_th: 644.00000000 , time:0.01508307ms
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
S=4096, K=4096
out_f32f32: 7372.59375000 , time:0.17362595ms
out_f32x4f32: 7372.59960938 , time:0.18044138ms
out_f32f32_th: 7372.58251953 , time:0.18282819ms
--------------------------------------------------------------------------------
out_f16f32: 7371.09033203 , time:0.10100150ms
out_f16x2f32: 7371.48632812 , time:0.05214143ms
out_f16x8packf32: 7369.69873047 , time:0.02043009ms
out_f16f16_th: 7372.00000000 , time:0.02451396ms
--------------------------------------------------------------------------------