
[FA fwd D=128] Reduce LDS usage in epilogue #340

Merged · 9 commits into triton-mlir · Oct 25, 2023

Conversation

@oplavsic (Collaborator) commented on Sep 26, 2023

This PR reduces LDS usage in the epilogue by breaking the convert_layout #mfma --> #blocked to multiple convert_layouts, each of which uses less LDS than the original one. The issue with the original convert_layout is that padding is used in LDS to avoid bank conflicts. It's simpler than swizzling but extra LDS space is required.

Note:

  • This PR only reduces LDS usage if the convert_layout in the epilogue uses more than 64 KB of LDS, which is the case for FA fwd D=128.
  • For smaller block sizes, as in the case of FA fwd D=64, the convert_layout in the epilogue does not exceed 64 KB of LDS, but it still uses more LDS than the block size warrants, which can hurt occupancy. We'll fix this issue for D=64 in a future PR. A rough cost model for the split is sketched below.
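
As a rough illustration of why splitting the convert helps (a hypothetical sketch using the cost model from the in-code comments quoted later in this review, not the PR's actual helper):

```cpp
#include <cstdio>

// Rough LDS estimate for one convert_layout, following the in-code comment:
//   LDS_USAGE = getShapePerCTA(mfma)[0] * getShapePerCTA(blocked)[1] * sizeof(T)
// The function name and tile sizes here are illustrative only.
static int convertLdsBytes(int mfmaShapeM, int blockedShapeN, int elemBytes) {
  return mfmaShapeM * blockedShapeN * elemBytes;
}

int main() {
  // A hypothetical 256 x 128 fp32 epilogue tile: 256 * 128 * 4 = 128 KB,
  // double the 64 KB LDS budget, so a single convert cannot be allocated.
  std::printf("single convert: %d KB\n", convertLdsBytes(256, 128, 4) / 1024);
  // Splitting into two converts that each touch half the tile halves the
  // scratch requirement of each step: 64 KB live at a time.
  std::printf("split convert:  %d KB per step\n", convertLdsBytes(128, 128, 4) / 1024);
  return 0;
}
```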

@zhanglx13 commented on Sep 26, 2023

**Benchmarks**

| nheads | bs | seqlen | d64-False | d64-True | d128-False | d128-True | d64-bwd |
|---|---|---|---|---|---|---|---|
| 48 | 4 | 1024 | 101 | 58 | 92 | 48 | 18 |
| 48 | 4 | 2048 | 109 | 79 | 102 | 55 | 23 |
| 48 | 4 | 4096 | 108 | 92 | 107 | 78 | 27 |
| 48 | 4 | 8192 | 108 | 98 | 110 | 96 | 29 |
| 48 | 4 | 16384 | 109 | 101 | 112 | 107 | 30 |

@zhanglx13 commented

**Benchmarks (TFLOPS)**

**fwd causal=False**

This PR:

| N_CTX | d64-waves2 | d64-waves3 | d128-waves2 | d128-waves3 |
|---|---|---|---|---|
| 1024 | 82.841357 | 99.843707 | 80.969017 | 38.079945 |
| 2048 | 88.786731 | 108.1741 | 88.17704 | 40.056298 |
| 4096 | 91.643274 | 110.840493 | 92.150436 | 40.94952 |
| 8192 | 92.731951 | 108.502042 | 95.371288 | 41.567236 |
| 16384 | 93.154725 | 109.13215 | 96.29653 | 41.749272 |

triton-mlir:

| N_CTX | d64-waves2 | d64-waves3 | d128-waves2 | d128-waves3 |
|---|---|---|---|---|
| 1024 | 82.62418 | 99.66007 | 70.760866 | 13.822745 |
| 2048 | 88.981469 | 108.095599 | 74.312649 | 14.039361 |
| 4096 | 91.777661 | 110.77687 | 76.916985 | 14.165306 |
| 8192 | 92.766691 | 108.49514 | 78.215578 | 14.202539 |
| 16384 | 93.005973 | 109.051868 | 78.136207 | 14.230571 |

**fwd causal=True**

This PR:

| N_CTX | d64-waves2 | d64-waves3 | d128-waves2 | d128-waves3 |
|---|---|---|---|---|
| 1024 | 27.837838 | 6.637805 | 33.545301 | 3.285597 |
| 2048 | 39.584447 | 7.950494 | 37.155564 | 3.715085 |
| 4096 | 48.823643 | 9.170838 | 52.601148 | 4.312236 |
| 8192 | 54.126445 | 9.959126 | 63.314607 | 4.54692 |
| 16384 | 56.297693 | 10.381401 | 68.738105 | 4.680015 |

triton-mlir:

| N_CTX | d64-waves2 | d64-waves3 | d128-waves2 | d128-waves3 |
|---|---|---|---|---|
| 1024 | 27.821601 | 6.644341 | 34.41814 | 5.357317 |
| 2048 | 39.642029 | 7.943067 | 48.589815 | 6.658589 |
| 4096 | 48.824999 | 9.121136 | 60.285348 | 7.383126 |
| 8192 | 54.124562 | 9.959886 | 66.399662 | 7.724 |
| 16384 | 56.184433 | 10.382775 | 68.18188 | 7.868713 |

**bwd kernel**

I used the tutorial benchmark; this PR and triton-mlir show the same perf numbers. For the bwd kernel there is another benchmark repo, which we can try later.

**Conclusions**

  • This PR improves the performance of the fwd kernel with causal=False and d=128 from 78 to 96 TFLOPS.
  • To get the best perf for d=128, just use the default value of waves-per-eu.
  • To get the best perf for d=64, we still need to set waves-per-eu=3 explicitly.

cc @oplavsic @alefimov-amd @scxiao @jayfurmanek @sunway513

@oplavsic (Collaborator, Author) replied

@zhanglx13 Thanks for running benchmarks!
A bit of context that explains some of these results:

The d=128 case uses more registers. With BLOCK_M=128, BLOCK_N=64, which we use for d=64, the kernel uses a bit more than half of the registers available on one SIMD. So it allows only 1 wave per SIMD, yet it doesn't utilize all the registers available. That's why I wanted to increase the block sizes: since we already don't have great occupancy, we can at least launch fewer workgroups and utilize all 512 registers available. It turned out that BLOCK_M=256, BLOCK_N=32 uses around 512 registers, so we still have no spills and can still run only one wave per SIMD, but each wave does more computation. To run this config, I needed to solve some LDS issues that I encountered (detailed explanation in the code).
That is also why you can't gain any performance by setting waves_per_eu to anything larger than 1: since a wave needs about 512 registers, trying to fit more than 1 wave on a SIMD results in a lot of spills.

Another way to go about this problem would be the opposite. Instead of increasing register pressure to better utilize each SIMD, we could decrease it as much as we can so that a wave needs fewer than 256 registers. That way we could naturally fit 2 waves per SIMD, thus increasing occupancy.

TLDR: This PR improves performance by making the best of the already low CU occupancy, increasing register pressure as much as possible without introducing any spills. The register arithmetic is sketched below.
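
For concreteness, a back-of-the-envelope check of the occupancy arithmetic above (a hypothetical helper, assuming the 512-VGPR-per-SIMD budget described in this comment):

```cpp
#include <cstdio>

// How many waves fit on one SIMD given per-wave VGPR usage, assuming a
// 512-VGPR register file per SIMD as discussed above.
static int wavesPerSimd(int vgprsPerWave, int vgprsPerSimd = 512) {
  return vgprsPerWave > 0 ? vgprsPerSimd / vgprsPerWave : 0;
}

int main() {
  // BLOCK_M=256, BLOCK_N=32: ~512 VGPRs per wave -> 1 wave/SIMD, no spills.
  std::printf("~512 VGPRs/wave: %d wave(s)/SIMD\n", wavesPerSimd(512));
  // The opposite direction: get below 256 VGPRs per wave to fit 2 waves.
  std::printf("~255 VGPRs/wave: %d wave(s)/SIMD\n", wavesPerSimd(255));
  return 0;
}
```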

@zhanglx13 commented

I remember that the default setting of waves-per-eu gives us 2 waves per SIMD for d=64 with BLOCK_M=128 and BLOCK_N=64. Have you checked with a thread trace that we only get 1 wave per SIMD?


```cpp
assert(minIdx >= 0 && minIdx < factorizedNumWarps.size());
auto warpsPerCTAPair = factorizedNumWarps[minIdx];
std::tie(tmpCvt, newEpliogueCvt) =
```
typo: epilogue

@oplavsic (Collaborator, Author) replied
Good catch, thanks!
What a coincidence: a few days ago I watched your talk https://www.youtube.com/watch?v=VbFqA9rvxPs. Very interesting presentation! :)

```cpp
// clang-format off
//
// LDS usage of this op is roughly calculated as:
// LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layoput)[1] * sizeof(data_type)
```


Nitpick:

Suggested change:

```diff
- // LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layoput)[1] * sizeof(data_type)
+ // LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layout)[1] * sizeof(data_type)
```

Comment on lines +609 to +788:

```cpp
// LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCta(blocked_layout)[1] * C,
// where C = 32 * sizePerWarp(blocked_layout)[1] * threadsPerWarp(blocked_layout)[1] * sizeof(data_type)
```


Are these lines redundant?

```cpp
int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt);
int newCvtLDS = getCvtOpLDSUsage(newEpliogueCvt);
if (tmpCvtLDS <= LDSSize && newCvtLDS <= LDSSize) {
  int LDSUsage = tmpCvtLDS + newCvtLDS;
```


Do the lifetimes of the scratch buffers of tmpCvt and newEpilogueCvt overlap?
If not, I think it is better to take the maximum of these two values instead of summing them. A sketch of that idea follows.
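
A minimal sketch of that suggestion, assuming the two scratch buffers are indeed never live at the same time (whether they are depends on the LDS allocator; peakLdsUsage is a hypothetical name, not code from the PR):

```cpp
#include <algorithm>

// If tmpCvt's scratch buffer is dead before newEpilogueCvt's is allocated,
// the peak LDS footprint is the larger of the two converts, not their sum.
int peakLdsUsage(int tmpCvtLDS, int newCvtLDS) {
  return std::max(tmpCvtLDS, newCvtLDS); // instead of tmpCvtLDS + newCvtLDS
}
```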

@zhanglx13 zhanglx13 changed the base branch from triton-mlir to improve_fa_fwd October 11, 2023 04:25
@zhanglx13 zhanglx13 changed the title Reduce implicit LDS usage of convert ops [DO NOT MERGE] Reduce implicit LDS usage of convert ops Oct 11, 2023
@zhanglx13 zhanglx13 changed the base branch from improve_fa_fwd to triton-mlir October 12, 2023 17:42
@zhanglx13 zhanglx13 changed the title [DO NOT MERGE] Reduce implicit LDS usage of convert ops Reduce LDS usage in epilogue Oct 24, 2023
@zhanglx13 zhanglx13 changed the title Reduce LDS usage in epilogue [FA fwd D=128] Reduce LDS usage in epilogue Oct 24, 2023
@zhanglx13 zhanglx13 merged commit 715a589 into triton-mlir Oct 25, 2023
1 check passed