Remove kpack from the decision of how many elements to copy per thread #1714

Merged: 2 commits into develop from 1667-improve-int8-kernel-performance, Jan 28, 2025

Conversation

@dhernandez0 (Contributor) commented Jan 9, 2025

kpack is used to decide how many K elements to copy per thread. Currently we do:

maxVlen = 128 / elementType.getIntOrFloatBitWidth()
copyPerThread = (kPerBlock * dPerBlock) / blockSize
kpack is a tuning parameter

copyKPerThread = gcd(maxVlen, gcd(kpack, copyPerThread))

For example, if kpack=8, copyPerThread=16, and maxVlen=16, copyKPerThread would be limited to 8. I think kpack should not be used to limit copyKPerThread. This PR changes the code to:

copyKPerThread = gcd(maxVlen, copyPerThread)
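
For illustration, here is a minimal standalone C++ sketch (not the actual rocMLIR code; the values mirror the example above) of how dropping kpack from the gcd changes the chosen copyKPerThread:

```cpp
// Minimal sketch of the old vs. new copyKPerThread selection, using the
// example values above (int8, so maxVlen = 128 / 8 = 16).
#include <cstdint>
#include <iostream>
#include <numeric>

int main() {
  int64_t kpack = 8, copyPerThread = 16, maxVlen = 16;

  // Old: kpack participates in the gcd and caps the vector length at 8.
  int64_t oldCopyKPerThread = std::gcd(maxVlen, std::gcd(kpack, copyPerThread));
  // New: kpack no longer limits the result, so the full 16 is used.
  int64_t newCopyKPerThread = std::gcd(maxVlen, copyPerThread);

  std::cout << oldCopyKPerThread << " -> " << newCopyKPerThread << "\n"; // 8 -> 16
}
```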

I've run some experiments on gfx942 for int8 (which is affected by this change because maxVlen = 16 in that case), and they show a performance improvement. Note these are the unfused int8 kernels from ResNet-50 (int8 input, int8 output, batch size 32):

| file | develop | this PR | speed up |
|---|---|---|---|
| mlir_convolution_1x1024x14x14s200704x1x14336x1024_256x1024x1x1.mxr.py | 0.010477 | 0.010467 | 1.00 |
| mlir_convolution_1x1024x14x14s200704x1x14336x1024_512x1024x1x1.mxr.py | 0.017557 | 0.017301 | 1.01 |
| mlir_convolution_1x128x28x28s100352x1x3584x128_128x128x3x3s1152x1x384x128.mxr.py | 0.01607 | 0.015092 | 1.06 |
| mlir_convolution_1x128x28x28s100352x1x3584x128_512x128x1x1.mxr.py | 0.01111 | 0.010649 | 1.04 |
| mlir_convolution_1x128x56x56s401408x1x7168x128_128x128x3x3s1152x1x384x128.mxr.py | 0.018084 | 0.016068 | 1.13 |
| mlir_convolution_1x128x56x56s401408x1x7168x128_256x128x1x1.mxr.py | 0.02151 | 0.020434 | 1.05 |
| mlir_convolution_1x1536x7x7s75264x1x10752x1536_2048x1536x1x1.mxr.py | 0.023993 | 0.024301 | 0.99 |
| mlir_convolution_1x2048x7x7s100352x1x14336x2048_512x2048x1x1.mxr.py | 0.011723 | 0.011719 | 1.00 |
| mlir_convolution_1x256x14x14s50176x1x3584x256_1024x256x1x1.mxr.py | 0.009614 | 0.009476 | 1.01 |
| mlir_convolution_1x256x14x14s50176x1x3584x256_256x256x3x3s2304x1x768x256.mxr.py | 0.020837 | 0.015433 | 1.35 |
| mlir_convolution_1x256x28x28s200704x1x7168x256_256x256x3x3s2304x1x768x256.mxr.py | 0.021589 | 0.019241 | 1.12 |
| mlir_convolution_1x256x56x56s802816x1x14336x256_128x256x1x1.mxr.py | 0.017802 | 0.017515 | 1.02 |
| mlir_convolution_1x256x56x56s802816x1x14336x256_64x256x1x1.mxr.py | 0.010705 | 0.010758 | 1.00 |
| mlir_convolution_1x384x28x28s301056x1x10752x384_512x384x1x1.mxr.py | 0.02302 | 0.021914 | 1.05 |
| mlir_convolution_1x3x224x224s150528x1x672x3_64x3x7x7s147x1x21x3.mxr.py | 0.06569 | 0.066277 | 0.99 |
| mlir_convolution_1x512x14x14s100352x1x7168x512_512x512x3x3s4608x1x1536x512.mxr.py | 0.024977 | 0.025044 | 1.00 |
| mlir_convolution_1x512x28x28s401408x1x14336x512_128x512x1x1.mxr.py | 0.010838 | 0.01001 | 1.08 |
| mlir_convolution_1x512x28x28s401408x1x14336x512_256x512x1x1.mxr.py | 0.017405 | 0.016605 | 1.05 |
| mlir_convolution_1x512x7x7s25088x1x3584x512_2048x512x1x1.mxr.py | 0.008887 | 0.008363 | 1.06 |
| mlir_convolution_1x512x7x7s25088x1x3584x512_512x512x3x3s4608x1x1536x512.mxr.py | 0.026264 | 0.022071 | 1.19 |
| mlir_convolution_1x64x56x56s200704x1x3584x64_256x64x1x1.mxr.py | 0.014953 | 0.014183 | 1.05 |
| mlir_convolution_1x64x56x56s200704x1x3584x64_64x64x1x1.mxr.py | 0.005723 | 0.005713 | 1.00 |
| mlir_convolution_1x64x56x56s200704x1x3584x64_64x64x3x3s576x1x192x64.mxr.py | 0.018053 | 0.016827 | 1.07 |
| mlir_convolution_1x768x14x14s150528x1x10752x768_1024x768x1x1.mxr.py | 0.021956 | 0.021533 | 1.02 |

I've tested "sdxl-gemm-configs", "attention-configs" and "sdxl-conv-configs". All configs but one show the same performance (within -5% to +5%); the following sdxl-conv config gets a 1.25x speedup:

convfp16 -F 1 -f NHWC -I NHWC -O NHWC -n 2 -c 1920 -H 64 -W 64 -k 640 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1

@dhernandez0 dhernandez0 requested a review from umangyadav January 9, 2025 16:12
@dhernandez0 dhernandez0 self-assigned this Jan 9, 2025
@dhernandez0 dhernandez0 requested a review from causten as a code owner January 9, 2025 16:12

codecov bot commented Jan 10, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 78.60%. Comparing base (a4e8230) to head (1f64791).
Report is 26 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #1714      +/-   ##
===========================================
- Coverage    78.88%   78.60%   -0.28%     
===========================================
  Files          100      100              
  Lines        28346    28347       +1     
  Branches      4130     4130              
===========================================
- Hits         22361    22283      -78     
- Misses        4368     4402      +34     
- Partials      1617     1662      +45     
Flag | Coverage Δ
mfma | 78.60% <100.00%> (-0.28%) ⬇️


@@ -127,8 +127,7 @@ computeCopyPerThread(Type elementType, int64_t copyPerThread, int64_t kPerBlock,
     copyDPerThread = math_util::gcd(maxVlen, copyPerThread);
     copyKPerThread = copyPerThread / copyDPerThread;
   } else {
-    copyKPerThread =
-        math_util::gcd(maxVlen, math_util::gcd(kpack, copyPerThread));
+    copyKPerThread = math_util::gcd(maxVlen, copyPerThread);
Member (review comment on the diff):

copyPerThread = (KPerBlock * MPerBlock) / blockSize

where KPerBlock = kPack * kPacksPerBlock

and blockSize = ((MPerBlock * NPerBlock) / (MPerWave * NPerWave)) * waveSize

therefore
copyPerThread = (KPerBlock * MPerWave * NPerWave) / (NPerBlock * waveSize)

So it seems the GCD should always be kPack:
gcd(kPack, copyPerThread) = kPack
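
As a hypothetical numeric check of this derivation (the parameter values below are illustrative only and are not taken from the PR), both forms of copyPerThread agree and the gcd comes out as kPack:

```cpp
// Hypothetical numeric check; all tuning values are illustrative, not from the PR.
#include <cstdint>
#include <iostream>
#include <numeric>

int main() {
  int64_t kPack = 8, kPacksPerBlock = 4;
  int64_t mPerBlock = 128, nPerBlock = 128, mPerWave = 64, nPerWave = 64, waveSize = 64;

  int64_t kPerBlock = kPack * kPacksPerBlock;                                        // 32
  int64_t blockSize = ((mPerBlock * nPerBlock) / (mPerWave * nPerWave)) * waveSize;  // 256

  int64_t direct = (kPerBlock * mPerBlock) / blockSize;                              // 16
  int64_t substituted = (kPerBlock * mPerWave * nPerWave) / (nPerBlock * waveSize);  // 16

  std::cout << direct << " == " << substituted
            << ", gcd(kPack, copyPerThread) = " << std::gcd(kPack, direct) << "\n";  // 16 == 16, gcd = 8
}
```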

Member:

So the previous logic was indeed limiting copyKPerThread to kPack.

Member:

I think we should run some more experiments on larger-bitwidth dtypes; I am not sure how this affects register pressure.

You can change the rocMLIR SHA on the MIGraphX side, similar to ROCm/AMDMIGraphX#3743, and the MIGraphX CI would then run more models with this change.

@causten (Collaborator) commented Jan 28, 2025

Ran some tests with BERT Large and ResNet-50. I didn't see any performance drop, and the performance improvements were very small. It would be good to merge though, since MIGraphX needs to support int8 as a return type to take full advantage of this PR.

@dhernandez0 dhernandez0 merged commit 0e25440 into develop Jan 28, 2025
15 of 24 checks passed
@dhernandez0 dhernandez0 deleted the 1667-improve-int8-kernel-performance branch January 28, 2025 16:45