[AMD] Introduce an OptimizeLDSUsage pass #3730
Conversation
int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt);
int newCvtLDS = getCvtOpLDSUsage(newEpilogueCvt);
if (tmpCvtLDS <= LDSSize && newCvtLDS <= LDSSize) {
  int LDSUsage = std::max(tmpCvtLDS, newCvtLDS);
@oplavsic
I've changed this part of the algorithm: https://github.com/openai/triton/pull/3730/files#diff-0d63e5cd9cf58151489fd9a5206b43a0902939004e58f3a7ec5258fa7d473267L227
Was it crucial?
To clarify what this PR is doing: at the moment we have an LDS optimization in the DecomposeUnsupportedConversions pass. The current approach can not optimize a convert_layout op in the Hopper flash attention test, so LDS overflows.
The first item is needed because the old set of intermediate layouts was not able to optimize the conversions found in Hopper FA. The second item is needed to generalize the optimization. For example, take a look at this example:
%1 consumes 16 KB of LDS, %2 requires ~64 KB of LDS for a scratch buffer. P.S. I had some concerns that the new optimization could affect existing benchmarks. I had an offline conversation with the author of the original optimization (@oplavsic) and we decided that it is best to leave the old optimization functionally the same, but move some functions to a common place and make them parameterizable.
 * ->
 * %1 = cvtOp %0 (srcLayout -> dstLayout)
 * %2 = cvtOp %0 (srcLayout -> tmpLayout)
 * %3 = cvtOp %1 (tmpLayout -> dstLayout)
Should this be %3 = cvtOp %2?
This function creates two cvtOps based on a given cvtOp. Could you be more specific in the comment about which cvtOp is the new one and which is the old one?
// LDS reduction is possible by changing the shape of WarpsPerCta attribute in
// mfma layout. The implicit LDS usage of cvt(mfma->blocked) op depends on the
// number of warps per CTA that mfma layout uses along x dimension and block
// layout uses across y dimension.
It's a little confusing whether x refers to the row or column. We can use dim 0 and dim 1 instead.
// LDS usage of this op is roughly calculated as:
// LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layout)[1] * sizeof(data_type)
// LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCta(blocked_layout)[1] * C,
// where C = 32 * sizePerWarp(blocked_layout)[1] * threadsPerWarp(blocked_layout)[1] * sizeof(data_type)
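For illustration, the quoted estimate can be sketched as plain arithmetic. This is a hypothetical helper, not the pass's actual code: the parameter names and the hardcoded 32 are taken verbatim from the comment above (the 32 is exactly what the review question below is about).

```cpp
#include <cstdint>

// Hypothetical sketch of the LDS usage estimate quoted above.
// The 32 and the dim-1 parameters follow the comment literally and may
// only be valid for mfma32-style layouts.
int64_t estimateCvtLDSUsage(int64_t mfmaWarpsDim0, int64_t blockedWarpsDim1,
                            int64_t blockedSizePerThreadDim1,
                            int64_t blockedThreadsPerWarpDim1,
                            int64_t elemBytes) {
  // C = 32 * sizePerWarp(blocked)[1] * threadsPerWarp(blocked)[1] * sizeof(data_type)
  int64_t c =
      32 * blockedSizePerThreadDim1 * blockedThreadsPerWarpDim1 * elemBytes;
  // LDS_USAGE = warpsPerCTA(mfma)[0] * warpsPerCTA(blocked)[1] * C
  return mfmaWarpsDim0 * blockedWarpsDim1 * c;
}
```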
Why is 32 hardcoded? Is it assuming mfma32 is used?
To be honest, I did not look deep into this comment, I just copied it from the original algorithm.
It was implemented a long time ago; we probably had only mfma32 at the time.
I'll take a closer look and adjust.
for (int i = 0; i < tmpLayouts.size(); i++) {
  auto tmpLayout = tmpLayouts[i];
  std::tie(tmpCvt, newEpilogueCvt) =
      createNewConvertOps(builder, cvtOp, tmpLayout);
In this loop, we only want to know the index of the tmpLayout that gives us the min LDS usage. Do we really need to create the cvtOps and erase them at the end of each iteration?
This creation/deletion is needed because the algorithm uses the getScratchConfigForCvtLayout(ConvertLayoutOp, unsigned&, unsigned&)
function from Allocation.cpp
to estimate LDS usage.
I can introduce a new interface, so we can avoid this redundant work.
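The reviewer's suggestion could look like the following sketch: select the index of the candidate layout with minimal estimated LDS usage via a pure cost callback, without materializing and erasing convert ops on each iteration. The template, the callback, and the limit check are illustrative assumptions, not the pass's real interface.

```cpp
#include <cstdint>
#include <functional>
#include <limits>
#include <vector>

// Hypothetical argmin over candidate temporary layouts: returns the index
// of the layout with the smallest estimated LDS usage that fits within
// ldsLimit, or -1 if none fits. No IR is created or erased here.
template <typename Layout>
int findMinLDSLayout(const std::vector<Layout> &tmpLayouts,
                     const std::function<int64_t(const Layout &)> &ldsCost,
                     int64_t ldsLimit) {
  int best = -1;
  int64_t bestCost = std::numeric_limits<int64_t>::max();
  for (int i = 0; i < static_cast<int>(tmpLayouts.size()); ++i) {
    int64_t cost = ldsCost(tmpLayouts[i]); // estimate only, no op creation
    if (cost <= ldsLimit && cost < bestCost) {
      bestCost = cost;
      best = i;
    }
  }
  return best;
}
```

Only the winning index would then be used to create the two convert ops once, after the loop.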
 * @return mapping from operation to list of live LDS buffers
 */
std::map<mlir::Operation *, SmallVector<Allocation::BufferId>>
analyzeBufferLiveness(FunctionOpInterface func, const Allocation *allocations) {
This is not AMD specific. Maybe we should put it in Analysis/Allocation.cpp?
}

SmallVector<triton::gpu::ConvertLayoutOp>
findLDSBottleneck(ModuleAllocation &allocAnalysis, FunctionOpInterface func) {
We could also put this in the common part since it can benefit the NV path. But then again, NV GPUs have pretty large shared memory ...
@binarman I have a question regarding
In this example,
namespace {

constexpr int LDSSize = 65536;
Could we not hardcode it but pass it from the front end?
@zhanglx13 About the condition: it filters out cases which will definitely overflow LDS, and there is no early exit.
Yes, at least the early return condition needs to be removed.
Now I see, I've missed this early return, thank you!
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func public @alloc_convert_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>) attributes {noinline = false} {
    %1 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared>
    %2 = triton_gpu.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma>
Sorry, I forgot to mention that I think this cvtOp is decomposed just because it uses more than 64 KB of LDS, since padding is used. Therefore, this test does not test the functionality that a cvtOp can still be decomposed even if it uses less than 64 KB of LDS.
Added a new test: it uses fp16 instead of fp32, so the cvt scratch buffer is 2x smaller.
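The arithmetic behind the two tests can be sketched as follows. This is back-of-the-envelope reasoning only: the real allocation in Triton may add padding, which is exactly why the original fp32 case exceeds 64 KB.

```cpp
#include <cstdint>

// Unpadded scratch-buffer size for a rows x cols tensor conversion.
// Padding is ignored, so real LDS usage can be larger than this.
constexpr int64_t scratchBytes(int64_t rows, int64_t cols, int64_t elemBytes) {
  return rows * cols * elemBytes;
}

// 128x128 fp32 already fills the whole 64 KB of LDS before padding,
// while the fp16 variant leaves headroom even with some padding.
static_assert(scratchBytes(128, 128, 4) == 64 * 1024, "fp32 case");
static_assert(scratchBytes(128, 128, 2) == 32 * 1024, "fp16 case");
```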
third_party/amd/backend/compiler.py
Outdated
@@ -147,6 +147,8 @@ def make_llir(src, metadata, options):
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm)
    lds_size = 65536
I am not sure where to place the code choosing the LDS size, so it is a plain constant at this point.
Let's introduce some interface in a later PR.
It should be convenient to rebase onto Lei's PR #3808
(converting to draft as we chatted -- need to first get all issues addressed from the AMD side before making it open)
@antiagainst @zhanglx13
namespace triton {
namespace AMD {

constexpr int kPtrBitWidth = 64;
Do we really need to hardcode the pointer bitwidth? Can we just use inline constant?
This part is copied from Allocation.cpp
(it is not part of the public interface).
Maybe I can actually move this part into some public interface, for example the Analysis/Utility
module.
This is what I was talking about: binarman#6
res.LDS = std::numeric_limits<typeof(res.LDS)>::max();

triton::gpu::ConvertLayoutOp tmpCvt;
triton::gpu::ConvertLayoutOp newEpilogueCvt;
The above three lines are not used.
threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];
auto order = triton::gpu::getOrder(srcEnc);
auto layoutCTA = triton::gpu::getCTALayout(srcEnc);
auto fallbackLayout = triton::gpu::BlockedEncodingAttr::get(
- For this fallbackLayout, all the components except warpsPerCTA are loop invariants. Maybe we can create a base BlockLayout out of the loop and use createTmpLayout(blockEnc, warpsPerCTA) inside the loop to update only the warpsPerCTA?
- Why is 8 chosen in warpSize / 8?
- In general, why do we need this fallbackLayout? Is it covered by either srcEnc or dstEnc?
- Why is 8 chosen in warpSize / 8?

For wave64 it will be [8, 8]; for wave32 it will be [4, 8]. This is done to make the layout tile "square", so no dimension of the minimal tile dominates.

- In general, why do we need this fallbackLayout? Is it covered by either srcEnc or dstEnc?

In some cases a different warpsPerCTA for the src or dst layout is not enough to reduce LDS usage, but some other layouts can be appropriate. These fallback layouts are designed to have as compact a tile as possible, i.e. elementsPerThread = [1, ..., 1], and a threadsPerWarp that is as "square" as possible.
I believe that in most cases the fallback layout will be chosen as the temporary layout. This could be non-optimal in terms of performance, but it is fine, because without this transformation the kernel will not compile at all.
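The "as square as possible" choice described above can be sketched like this. The helper name is hypothetical; the pass computes this on MLIR layout attributes rather than plain integers.

```cpp
#include <cstdint>
#include <utility>

// Hypothetical sketch: fix dim 1 of threadsPerWarp to 8 lanes and give
// the remaining lanes to dim 0, so the per-warp tile stays near-square.
std::pair<int64_t, int64_t> squareThreadsPerWarp(int64_t warpSize) {
  int64_t dim1 = 8;
  return {warpSize / dim1, dim1}; // wave64 -> [8, 8], wave32 -> [4, 8]
}
```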
  return;
}

triton::gpu::ConvertLayoutOp tmpCvt;
are we using this tmpCvt?
Nope, I will rewrite this part as done in the DecomposeUnsupportedConversions pass.
if (offset + size > LDSSize) {
  auto maxScratchBufferSize = computeMaxScratchBufferSize(
      cvtOp, funcAnalysis, liveBuffers[cvtOp]);
  candidates.push_back({cvtOp, maxScratchBufferSize});
This function is very confusing to me.
- Why do we need opBuffer? Just to check that it's valid?
- Does liveBuffers[cvtOp] include opBuffer? To put it another way, is one of the bufIds the scratch buffer allocated for this cvtOp?
- It seems to me that this function assumes there is at most one extra buffer that can overlap with the buffer for this cvtOp. If there are more live buffers that overlap with this cvtOp, we should still push cvtOp into candidates only once, but compute maxScratchBufferSize based on all overlapping live buffers.
Why do we need opBuffer? Just to check that it's valid?

Sorry, this is a leftover after refactoring. I used to pass it to computeMaxScratchBufferSize, but then started computing it inside the function.

Does liveBuffers[cvtOp] include opBuffer? To put it another way, is one of the bufIds the scratch buffer allocated for this cvtOp?

Yes, the scratch buffer is the same as the "long-living" buffers; the only difference is that its lifetime is limited to one operation.

It seems to me that this function assumes there is at most one extra buffer that can overlap with the buffer for this cvtOp. If there are more live buffers that overlap with this cvtOp, we should still push cvtOp into candidates only once, but compute maxScratchBufferSize based on all overlapping live buffers.

No, there can be any number of buffers whose lifetimes overlap with the scratch buffer.
Let me remove the loop from this function; it should make the algorithm clearer.
int64_t scratchBufferSize = allocation->getAllocatedSize(scratchBufferId);
size_t totalLDSConsumption = 0;
for (auto buf : liveBuffers)
  totalLDSConsumption = std::max(
If all liveBuffers are live at this cvtOp, should we use sum instead of max here?
Max is a more conservative metric in this sense. Consider that we have "holes" in memory:
say the green buffer is the scratch buffer we want to optimize, and the violet and blue buffers are long-living buffers in shared layout.
A hole is created because the pink tensor is allocated on tick 1 and deallocated on tick 2, while the previously allocated violet tensor continues to live.
Summing the buffer sizes would suggest we have 24 KB (3 * 8 KB) for the scratch buffer, but in reality we probably want to make it smaller.
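The max-based estimate can be sketched as follows. The struct and helper are illustrative stand-ins for the Allocation API: free contiguous space for the scratch buffer is bounded by the highest end offset among live buffers, because holes between allocations may not be usable.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative stand-in for an allocated LDS buffer (not the real API).
struct LiveBuffer {
  int64_t offset;
  int64_t size;
};

// Conservative upper bound on scratch space: everything above the highest
// occupied end offset. Summing sizes instead would count unusable holes.
int64_t maxScratchBytes(const std::vector<LiveBuffer> &live, int64_t ldsSize) {
  int64_t highestEnd = 0;
  for (const auto &b : live)
    highestEnd = std::max(highestEnd, b.offset + b.size);
  return ldsSize - highestEnd;
}
```

For two 8 KB buffers at offsets 0 and 16 KB (an 8 KB hole between them), summing sizes would promise 48 KB of a 64 KB LDS, while the max-based bound yields only 40 KB, which is the contiguous space actually guaranteed.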
 * space available for scratch buffer.
 */
int64_t
computeMaxScratchBufferSize(triton::gpu::ConvertLayoutOp op,
Maybe computeTargetBufferSize? I feel like "target" or "desired" is more accurate about what we want to do here.
third_party/amd/backend/compiler.py
Outdated
@@ -169,6 +169,9 @@ def make_llir(src, metadata, options):
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm, options.arch)
    # experimental parameter, specifies custom LDS usage limit
Can you elaborate on this parameter? What does it mean when it's set to a non-zero and zero?
Especially when it's set to non-zero value, does it mean the total LDS usage is guaranteed to be lower than that? Or is it just a hint?
Approved. Could you please add more comments for custom_lds_size? See below.
} // namespace

namespace mlir::triton::AMD {
@antiagainst You changed this part from multiple lines to one; is that the preferred code style now?
Imported from GitHub PR openxla/xla#15477 This build break is introduced by openxla/xla#15257 and ROcm has a new optimized LDS pass on openai/triton triton-lang/triton#3730 @xla-rotation Copybara import of the project: -- 6f86fdbd090a4fc3fa2346ba6969d7ddeae773e3 by Chao Chen <[email protected]>: updated rocm triton OptimizeLDSUsage pass due to triton-lang/triton#3730 Merging this change closes #15477 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15477 from ROCm:ci_hotfix_20240730 6f86fdbd090a4fc3fa2346ba6969d7ddeae773e3 PiperOrigin-RevId: 657498811
This PR introduces an OptimizeLDSUsage pass, which generalizes the LDS optimization
that was part of the DecomposeUnsupportedConversions pass.
Overall, it tries to reduce LDS usage of convert ops by adding an intermediate layout
to the conversion.
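The core idea can be sketched with illustrative numbers (not real layout computations): a direct src->dst conversion needs one large scratch buffer, while src->tmp followed by tmp->dst needs only the larger of two smaller buffers, since the two scratch buffers are not live at the same time.

```cpp
#include <algorithm>
#include <cstdint>

// Peak scratch usage of a decomposed conversion: the two convert ops run
// one after the other, so only the larger of the two buffers is live.
int64_t decomposedPeakScratch(int64_t srcToTmpBytes, int64_t tmpToDstBytes) {
  return std::max(srcToTmpBytes, tmpToDstBytes);
}
```

With a well-chosen intermediate layout, each half of the conversion can need far less scratch space than the direct conversion, letting a kernel that would overflow LDS compile.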