vulkan: optimization proposals for coopmat1 mul_mm #12260


Draft
wants to merge 4 commits into base: master

Conversation

remyoudompheng (Contributor)

This PR proposes several changes to speed up coopmat1 matrix multiplication.

So far it has only been tested on RDNA3; feedback on other coopmat architectures is welcome.

  • da686c7: load all A coopmats into registers so they are not loaded from shared memory multiple times (this may increase register pressure, but the tiles are quite small anyway); see the sketch below
  • 8223a84: don't compile unused f16 shaders (no performance impact, but it reduces binary size)
  • da0d698: use f16 for the shared buffers when possible (no measurable performance impact, but it can reduce shared memory consumption)
  • 28c458a: increase LOAD_VEC_A for IQ2 and IQ3 (similar to vulkan: matmul dequantization improvements #12015)

The overall performance gain seems to be around 25% for the IQ2/IQ3 types and 10% for the K-quants.
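For the register-caching change (da686c7), here is a minimal stand-alone sketch of the idea, assuming the GL_KHR_cooperative_matrix (coopmat1) path. The tile shapes, buffer names and constants below are illustrative and are not the identifiers used by the actual ggml-vulkan mul_mm shader:

```glsl
#version 450

#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : enable

// Illustrative tile configuration (assumptions, not the shader's real values).
const uint TM = 16, TN = 16, TK = 16;  // coopmat fragment shape
const uint WMITER = 4;                 // A sub-tiles per warptile
const uint WNITER = 2;                 // B sub-tiles per warptile

// In the real shader these tiles are filled from global memory (with
// dequantization to f16) and a barrier() is issued; here they are just declared.
shared float16_t buf_a[WMITER * TM * TK];
shared float16_t buf_b[WNITER * TK * TN];

layout(binding = 0, std430) buffer D { float data_d[]; };

layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;

void main() {
    coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a[WMITER];
    coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> mat_b;
    coopmat<float, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[WMITER * WNITER];

    for (uint i = 0; i < WMITER * WNITER; i++) {
        sums[i] = coopmat<float, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0);
    }

    // Load every A sub-tile of the warptile into registers once...
    for (uint wm = 0; wm < WMITER; wm++) {
        coopMatLoad(cache_a[wm], buf_a, wm * TM * TK, TK,
                    gl_CooperativeMatrixLayoutRowMajor);
    }

    // ...then reuse the cached A tiles for every B sub-tile, instead of
    // reloading each A tile from shared memory inside the inner loop.
    for (uint wn = 0; wn < WNITER; wn++) {
        coopMatLoad(mat_b, buf_b, wn * TK * TN, TK,
                    gl_CooperativeMatrixLayoutColumnMajor);
        for (uint wm = 0; wm < WMITER; wm++) {
            sums[wn * WMITER + wm] = coopMatMulAdd(cache_a[wm], mat_b, sums[wn * WMITER + wm]);
        }
    }

    // Write the accumulators out (offsets are simplified for the sketch).
    for (uint i = 0; i < WMITER * WNITER; i++) {
        coopMatStore(sums[i], data_d, i * TM * TN, TN, gl_CooperativeMatrixLayoutRowMajor);
    }
}
```

The trade-off is the one noted in the bullet above: each A sub-tile now occupies registers for the whole warptile iteration, but it is read from shared memory only once instead of once per (A, B) sub-tile pair.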

Performance on RDNA3 iGPU (780M)

                                                                                                      master          da686c7b        da0d6989        PR
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):            -   1.55 TFLOPS -   1.58 TFLOPS -   1.61 TFLOPS -   1.59 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):            -   2.53 TFLOPS -   2.19 TFLOPS -   2.22 TFLOPS -   2.19 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):           -   3.86 TFLOPS -   4.69 TFLOPS -   4.68 TFLOPS -   4.68 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):           -   3.78 TFLOPS -   4.68 TFLOPS -   4.73 TFLOPS -   4.75 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):           -   3.44 TFLOPS -   4.15 TFLOPS -   4.19 TFLOPS -   4.25 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):           -   3.50 TFLOPS -   3.55 TFLOPS -   3.82 TFLOPS -   4.13 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):           -   3.56 TFLOPS -   3.91 TFLOPS -   3.93 TFLOPS -   3.91 TFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):           -   3.34 TFLOPS -   4.03 TFLOPS -   3.95 TFLOPS -   4.06 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):           -   2.98 TFLOPS -   3.45 TFLOPS -   3.40 TFLOPS -   3.44 TFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):           -   2.88 TFLOPS -   3.62 TFLOPS -   3.60 TFLOPS -   3.63 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):           -   2.85 TFLOPS -   3.23 TFLOPS -   3.23 TFLOPS -   3.22 TFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):           -   2.81 TFLOPS -   3.36 TFLOPS -   3.17 TFLOPS -   3.27 TFLOPS
  MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):          -   2.99 TFLOPS -   4.17 TFLOPS -   4.19 TFLOPS -   4.06 TFLOPS
  MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):          -   3.38 TFLOPS -   3.72 TFLOPS -   3.74 TFLOPS -   4.19 TFLOPS
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):        -   3.21 TFLOPS -   3.78 TFLOPS -   3.75 TFLOPS -   3.96 TFLOPS
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):         -   3.36 TFLOPS -   3.70 TFLOPS -   3.72 TFLOPS -   3.93 TFLOPS
  MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):          -   2.99 TFLOPS -   3.24 TFLOPS -   3.25 TFLOPS -   3.70 TFLOPS
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):        -   3.11 TFLOPS -   3.65 TFLOPS -   3.61 TFLOPS -   3.89 TFLOPS
  MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):          -   3.12 TFLOPS -   3.67 TFLOPS -   3.69 TFLOPS -   3.86 TFLOPS
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):         -   3.22 TFLOPS -   3.80 TFLOPS -   3.80 TFLOPS -   3.84 TFLOPS                  
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):         -   3.59 TFLOPS -   4.38 TFLOPS -   4.37 TFLOPS -   4.36 TFLOPS

Performance on prompt processing

| model | size | params | backend | ngl | test | master t/s | da686c7 t/s | da0d698 t/s | PR t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| qwen2 3B IQ2_M - 2.7 bpw | 1.06 GiB | 3.09 B | Vulkan | 99 | pp256 | 480.76 ± 1.32 | 519.60 ± 0.73 | 519.22 ± 1.89 | 604.06 ± 1.18 |
| qwen2 3B IQ2_M - 2.7 bpw | 1.06 GiB | 3.09 B | Vulkan | 99 | pp512 | 464.59 ± 1.91 | 506.15 ± 1.64 | 508.66 ± 0.38 | 583.36 ± 0.47 |
| qwen2 3B IQ2_M - 2.7 bpw | 1.06 GiB | 3.09 B | Vulkan | 99 | pp1024 | 456.02 ± 1.75 | 493.56 ± 0.48 | 498.58 ± 0.68 | 571.40 ± 1.12 |
| qwen2 3B IQ3_S mix - 3.66 bpw | 1.38 GiB | 3.09 B | Vulkan | 99 | pp256 | 515.49 ± 1.17 | 571.79 ± 2.63 | 572.92 ± 0.26 | 638.46 ± 0.65 |
| qwen2 3B IQ3_S mix - 3.66 bpw | 1.38 GiB | 3.09 B | Vulkan | 99 | pp512 | 495.93 ± 1.63 | 551.79 ± 1.50 | 549.11 ± 2.14 | 614.63 ± 0.21 |
| qwen2 3B IQ3_S mix - 3.66 bpw | 1.38 GiB | 3.09 B | Vulkan | 99 | pp1024 | 489.35 ± 0.55 | 543.93 ± 1.30 | 533.95 ± 0.35 | 598.97 ± 1.68 |
| qwen2 3B IQ4_XS - 4.25 bpw | 1.61 GiB | 3.09 B | Vulkan | 99 | pp256 | 519.93 ± 2.43 | 591.49 ± 3.12 | 584.19 ± 7.43 | 590.80 ± 3.40 |
| qwen2 3B IQ4_XS - 4.25 bpw | 1.61 GiB | 3.09 B | Vulkan | 99 | pp512 | 500.12 ± 3.63 | 575.09 ± 1.38 | 566.04 ± 4.75 | 573.39 ± 4.06 |
| qwen2 3B IQ4_XS - 4.25 bpw | 1.61 GiB | 3.09 B | Vulkan | 99 | pp1024 | 488.37 ± 2.09 | 557.57 ± 1.72 | 538.20 ± 17.70 | 552.63 ± 0.76 |
| qwen2 3B Q4_K - Medium | 1.79 GiB | 3.09 B | Vulkan | 99 | pp256 | 497.34 ± 3.58 | 547.36 ± 1.54 | 537.15 ± 7.23 | 545.71 ± 2.45 |
| qwen2 3B Q4_K - Medium | 1.79 GiB | 3.09 B | Vulkan | 99 | pp512 | 477.47 ± 1.46 | 530.55 ± 0.66 | 528.13 ± 1.56 | 528.60 ± 2.68 |
| qwen2 3B Q4_K - Medium | 1.79 GiB | 3.09 B | Vulkan | 99 | pp1024 | 466.98 ± 3.19 | 519.90 ± 0.49 | 509.74 ± 0.99 | 512.27 ± 0.65 |
| qwen2 32B IQ3_XS - 3.3 bpw | 12.76 GiB | 32.76 B | Vulkan | 99 | pp256 | 45.32 ± 0.36 | 48.92 ± 0.13 | 48.58 ± 0.21 | 55.61 ± 0.05 |
| qwen2 32B IQ3_XS - 3.3 bpw | 12.76 GiB | 32.76 B | Vulkan | 99 | pp512 | 38.51 ± 0.07 | 41.56 ± 0.08 | 42.70 ± 0.28 | 48.35 ± 0.15 |
| qwen2 32B IQ3_XS - 3.3 bpw | 12.76 GiB | 32.76 B | Vulkan | 99 | pp1024 | 37.77 ± 0.03 | 42.83 ± 0.02 | 42.27 ± 0.16 | 47.66 ± 0.05 |

github-actions bot added the Vulkan (issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels on Mar 7, 2025.
@jeffbolznv (Collaborator)

I did a quick run on RTX 4070 with coopmat2 disabled:

before:
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   326 runs -  3083.09 us/run -  60.13 GFLOP/run -  19.50 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   318 runs -  3148.90 us/run -  60.13 GFLOP/run -  19.10 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  744 runs -  1344.71 us/run -  60.13 GFLOP/run -  44.72 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  656 runs -  1528.82 us/run -  60.13 GFLOP/run -  39.33 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  712 runs -  1406.18 us/run -  60.13 GFLOP/run -  42.76 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  708 runs -  1415.71 us/run -  60.13 GFLOP/run -  42.47 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  644 runs -  1556.51 us/run -  60.13 GFLOP/run -  38.63 TFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  650 runs -  1541.96 us/run -  60.13 GFLOP/run -  39.00 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  528 runs -  1896.95 us/run -  60.13 GFLOP/run -  31.70 TFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  564 runs -  1775.80 us/run -  60.13 GFLOP/run -  33.86 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  472 runs -  2118.84 us/run -  60.13 GFLOP/run -  28.38 TFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  554 runs -  1807.40 us/run -  60.13 GFLOP/run -  33.27 TFLOPS
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):               468 runs -  2137.58 us/run -  60.13 GFLOP/run -  28.13 TFLOPS
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                508 runs -  1972.00 us/run -  60.13 GFLOP/run -  30.49 TFLOPS
  MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                 472 runs -  2121.83 us/run -  60.13 GFLOP/run -  28.34 TFLOPS
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):               438 runs -  2288.97 us/run -  60.13 GFLOP/run -  26.27 TFLOPS
  MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                 526 runs -  1904.99 us/run -  60.13 GFLOP/run -  31.56 TFLOPS
  MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                 414 runs -  2415.89 us/run -  60.13 GFLOP/run -  24.89 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                744 runs -  1345.36 us/run -  60.13 GFLOP/run -  44.69 TFLOPS
  MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                 576 runs -  1737.48 us/run -  60.13 GFLOP/run -  34.61 TFLOPS
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                624 runs -  1604.25 us/run -  60.13 GFLOP/run -  37.48 TFLOPS
  
after:
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   324 runs -  3104.48 us/run -  60.13 GFLOP/run -  19.37 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   330 runs -  3033.47 us/run -  60.13 GFLOP/run -  19.82 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  732 runs -  1367.42 us/run -  60.13 GFLOP/run -  43.97 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  674 runs -  1487.76 us/run -  60.13 GFLOP/run -  40.42 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  714 runs -  1404.04 us/run -  60.13 GFLOP/run -  42.83 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  706 runs -  1418.11 us/run -  60.13 GFLOP/run -  42.40 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  658 runs -  1523.43 us/run -  60.13 GFLOP/run -  39.47 TFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  650 runs -  1539.78 us/run -  60.13 GFLOP/run -  39.05 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  546 runs -  1838.28 us/run -  60.13 GFLOP/run -  32.71 TFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  556 runs -  1801.43 us/run -  60.13 GFLOP/run -  33.38 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  474 runs -  2115.05 us/run -  60.13 GFLOP/run -  28.43 TFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                  556 runs -  1801.09 us/run -  60.13 GFLOP/run -  33.39 TFLOPS
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):               616 runs -  1626.04 us/run -  60.13 GFLOP/run -  36.98 TFLOPS
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                534 runs -  1872.87 us/run -  60.13 GFLOP/run -  32.11 TFLOPS
  MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                 540 runs -  1852.56 us/run -  60.13 GFLOP/run -  32.46 TFLOPS
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):               632 runs -  1585.13 us/run -  60.13 GFLOP/run -  37.93 TFLOPS
  MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                 650 runs -  1540.16 us/run -  60.13 GFLOP/run -  39.04 TFLOPS
  MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                 556 runs -  1803.66 us/run -  60.13 GFLOP/run -  33.34 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                744 runs -  1345.90 us/run -  60.13 GFLOP/run -  44.68 TFLOPS
  MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                 662 runs -  1514.20 us/run -  60.13 GFLOP/run -  39.71 TFLOPS
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                612 runs -  1634.11 us/run -  60.13 GFLOP/run -  36.80 TFLOPS  

Only the IQ types show a real improvement, presumably from the LOAD_VEC_A changes.
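For context on what LOAD_VEC_A controls: roughly, it is the number of A values each invocation dequantizes into the shared tile per load iteration. The sketch below is schematic only, with made-up tile sizes and hypothetical block_scale/block_value helpers standing in for the type-specific decoding; it is not the actual ggml-vulkan shader code:

```glsl
#version 450

#extension GL_EXT_shader_explicit_arithmetic_types_float16 : enable

// Made-up tile geometry; only the loop structure matters here.
const uint LOAD_VEC_A = 8;   // values dequantized per invocation per iteration
const uint BM = 32;          // rows of the shared A tile
const uint BK = 32;          // columns (K extent) of the shared A tile

shared float16_t buf_a[BM * BK];

layout(local_size_x = 128) in;

// Stubs standing in for the per-quant-type block decoding.
float16_t block_scale(uint block_idx)               { return float16_t(1.0); }
float16_t block_value(uint block_idx, uint sub_idx) { return float16_t(0.0); }

void main() {
    // Each invocation fills LOAD_VEC_A consecutive elements of one tile row,
    // so block metadata (scales, sign bits, grid lookups, ...) is decoded once
    // per LOAD_VEC_A values rather than once per value.
    const uint first = gl_LocalInvocationID.x * LOAD_VEC_A;
    const uint row = first / BK;
    const uint col = first % BK;
    const uint block_idx = row;  // simplified block indexing for the sketch

    const float16_t scale = block_scale(block_idx);
    for (uint v = 0; v < LOAD_VEC_A; v++) {
        buf_a[row * BK + col + v] = scale * block_value(block_idx, col + v);
    }
    barrier();  // the real shader synchronizes before the coopmat multiply
}
```

Since the IQ2/IQ3 formats have comparatively expensive block metadata, decoding it once per LOAD_VEC_A values instead of once per value plausibly explains why the IQ types benefit the most here.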

@0cc4m (Collaborator) commented Mar 11, 2025

> da0d698: use f16 for the shared buffers when possible (no measurable performance impact, but it can reduce shared memory consumption)

This is already the case: coopmat requires fp16, which means the buffers will always be 16-bit floats. I don't think the commit changes anything.
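For illustration only, the kind of compile-time shared-buffer type selection being discussed could look like the sketch below (the macro names are hypothetical, not the actual mul_mm.comp defines); as noted above, the coopmat path already ends up on the float16_t side of the switch:

```glsl
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : enable

// Hypothetical macro names; not the actual mul_mm.comp defines.
#if defined(COOPMAT) || defined(DATA_A_F16)
#define SHMEM_TYPE float16_t   // coopmat requires fp16 fragments, so the tiles are already f16
#else
#define SHMEM_TYPE float       // a non-coopmat fallback could keep fp32 tiles
#endif

const uint BM = 64, BN = 64, BK = 32;  // illustrative tile sizes

shared SHMEM_TYPE buf_a[BM * BK];
shared SHMEM_TYPE buf_b[BN * BK];

layout(local_size_x = 128) in;
void main() { }
```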

@0cc4m (Collaborator) commented May 10, 2025

@remyoudompheng Are you still around? Optimizing coopmat1 is a good idea, let's get this PR ready.
