[FA fwd D=128] Reduce LDS usage in epilogue #340
Benchmarks
@zhanglx13 Thanks for running the benchmarks! The d=128 case uses more registers. I think that with BLOCK_M = 128, BLOCK_N = 64, the config we use for d=64, it uses a bit more than half of the registers available on one SIMD. So it allows only 1 wave per SIMD, but it doesn't use all the registers available. That's why I wanted to increase the block sizes: since occupancy is already low, we can at least launch a smaller grid and use all 512 registers available. It turned out that BLOCK_M = 256, BLOCK_N = 32 uses around 512 registers, so we still don't have any spills and can still run only one wave per SIMD, but each wave does a larger computation. To run this config, I needed to solve some LDS issues that I encountered (detailed explanation in the code).
Another way to approach this problem would be the opposite: instead of increasing register pressure to better utilize the SIMD, we could try to decrease it as much as we can, so that a wave needs < 256 registers. That way we could naturally fit 2 waves per SIMD, thus increasing occupancy.
TL;DR: This PR improves performance by making the best of the already low CU occupancy, increasing register pressure as much as possible without causing any spills.
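To make the register arithmetic in the comment above concrete, here is a minimal sketch of how the per-wave register count maps to waves per SIMD. It assumes the simple model described in the comment (512 registers available on one SIMD, no allocation granularity); the helper names and the value 300 used in the usage note are hypothetical and only illustrate "a bit more than half".

```cpp
#include <cassert>

// Assumption taken from the comment above: 512 registers are available
// to waves on one SIMD. Real hardware allocates registers in granules,
// which this sketch ignores.
constexpr int kRegsPerSimd = 512;

// Hypothetical helper: number of waves that fit on one SIMD when each
// wave needs `regsPerWave` registers.
inline int wavesPerSimd(int regsPerWave) {
  assert(regsPerWave > 0);
  return kRegsPerSimd / regsPerWave;
}

// Hypothetical helper: fraction of the SIMD's registers actually in use
// by the resident waves.
inline double registerUtilization(int regsPerWave) {
  return static_cast<double>(wavesPerSimd(regsPerWave) * regsPerWave) /
         kRegsPerSimd;
}
```

Under this model, a wave needing 300 registers and a wave needing 512 registers both yield one wave per SIMD, but only the latter uses the whole register budget; dropping to 256 or fewer is what allows a second wave.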
I remembered the default setting of |
assert(minIdx >= 0 && minIdx < factorizedNumWarps.size());
auto warpsPerCTAPair = factorizedNumWarps[minIdx];
std::tie(tmpCvt, newEpliogueCvt) =
typo: epilogue
Good catch, thanks!
What a coincidence, a few days ago I watched your talk https://www.youtube.com/watch?v=VbFqA9rvxPs. Very interesting presentation! :)
// clang-format off
//
// LDS usage of this op is roughly calculated as:
// LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layoput)[1] * sizeof(data_type)
nitpick:
- // LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layoput)[1] * sizeof(data_type)
+ // LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layout)[1] * sizeof(data_type)
// LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCta(blocked_layout)[1] * C,
// where C = 32 * sizePerWarp(blocked_layout)[1] * threadsPerWarp(blocked_layout)[1] * sizeof(data_type)
Are these lines redundant?
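For reference, the rough estimate quoted in the code comment under review (rows covered by the mfma layout times columns covered by the blocked layout times element size) can be sketched as below. The function name and the shapes in the usage note are hypothetical; this is not the actual Triton helper, and it ignores the extra padding the PR description mentions.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper mirroring the quoted comment:
//   LDS_USAGE = getShapePerCTA(mfma_layout)[0]
//             * getShapePerCTA(blocked_layout)[1]
//             * sizeof(data_type)
// The two shape values are passed in directly instead of being derived
// from real layout attributes.
inline std::size_t roughCvtLDSBytes(std::size_t mfmaShapePerCTA0,
                                    std::size_t blockedShapePerCTA1,
                                    std::size_t elemBytes) {
  return mfmaShapePerCTA0 * blockedShapePerCTA1 * elemBytes;
}
```

With illustrative values of 256 rows, 128 columns, and a 2-byte element type, the estimate is 65536 bytes; halving the row extent to 128 halves it to 32768 bytes.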
int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt);
int newCvtLDS = getCvtOpLDSUsage(newEpliogueCvt);
if (tmpCvtLDS <= LDSSize && newCvtLDS <= LDSSize) {
int LDSUsage = tmpCvtLDS + newCvtLDS;
Do the lifetimes of the scratch buffers of tmpCvt and newEpilogueCvt overlap?
If not, I think it is better to take the maximum of these values instead of summing them.
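The suggestion above could look roughly like the following sketch. `combinedLDSUsage` and the `lifetimesOverlap` flag are hypothetical names introduced for illustration; how the overlap would actually be determined in the pass is not shown in this thread.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper: combined scratch requirement of two layout
// conversions. If both scratch buffers are live at the same time they
// need separate space (sum); if their lifetimes are disjoint, the same
// LDS region can be reused and only the larger buffer matters (max).
inline int combinedLDSUsage(int tmpCvtLDS, int newCvtLDS,
                            bool lifetimesOverlap) {
  assert(tmpCvtLDS >= 0 && newCvtLDS >= 0);
  return lifetimesOverlap ? tmpCvtLDS + newCvtLDS
                          : std::max(tmpCvtLDS, newCvtLDS);
}
```

Using the maximum when lifetimes are disjoint makes the cost estimate less pessimistic, so candidate splits are not penalized for LDS they would never hold simultaneously.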
This PR reduces LDS usage in the epilogue by breaking the convert_layout #mfma --> #blocked into multiple convert_layouts, each of which uses less LDS than the original one. The issue with the original convert_layout is that padding is used in LDS to avoid bank conflicts. It's simpler than swizzling, but extra LDS space is required.
Note: