Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMD] Introduce an OptimizeLDSUsage pass #3730

Merged
merged 22 commits into from
Jul 20, 2024

Conversation

binarman
Copy link
Contributor

@binarman binarman commented Apr 22, 2024

This PR inroduces OptimizeLDSUsage pass which generalizes LDS optimization,
which was part of DecomposeUnsupportedLayouts pass.

Overall it tries to reduce LDS usage of convert op by adding intermediate layout
in conversion.

int tmpCvtLDS = getCvtOpLDSUsage(tmpCvt);
int newCvtLDS = getCvtOpLDSUsage(newEpilogueCvt);
if (tmpCvtLDS <= LDSSize && newCvtLDS <= LDSSize) {
int LDSUsage = std::max(tmpCvtLDS, newCvtLDS);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@binarman
Copy link
Contributor Author

To clarify, what this PR is doing:

At the moment we have an optimization in DecomposeUnsupportedLayouts pass, which is looking for convert_layout operations that requires more shared memory, than we have. Optimization tries to decompose such convert_layouts in two converts with some intermediate layout, In many cases this helps to reduce LDS usage.

Current approach can not optimize convert_layout in hopper flash attention test, so LDS overflows.
This PR introduces two things:

  1. adding more intermediate layouts variants
  2. doing global analysis, to catch convert_layout operation which do not overflow LDS on its own, but overflows memory if there are some shared tensors.

First item is needed, because old set of intermediate layouts was not able to optimize conversions found int hopper FA.

Second item is needed to generalize optimization. For example, take a look at this example:

 %1 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared>
 %2 = triton_gpu.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma>
 %3 = triton_gpu.local_load %1 : !tt.memdesc<128x128xf16, #shared> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>

%1 consumes 16 KB of LDS, %2 requires ~64KB of lds for a scratch buffer.
If there are no padding, %2 can be exactly 64KB, which fits into LDS, but %1 and %2 together do not.

P.s. I had some concerns that new optimization can affect existing benchmarks. I had an offline conversation with author of original optimization (@oplavsic) and we decided that best to leave old optimization functionally same, but move some functions in common place and make them parameterizable.

* ->
* %1 = cvtOp %0 (srcLayout -> dstLayout)
* %2 = cvtOp %0 (srcLayout -> tmpLayout)
* %3 = cvtOp %1 (tmpLayout -> dstLayout)
Copy link
Collaborator

@zhanglx13 zhanglx13 Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be %3 = cvtOp %2?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function creates two cvtOps based on a given cvtOps. Could you be more specific about which cvtOp is the new one and which is the old one in the comment?

// LDS reduction is possible by changing the shape of WarpsPerCta attribute in
// mfma layout. The implicit LDS usage of cvt(mfma->blocked) op depends on the
// number of warps per CTA that mfma layout uses along x dimension and block
// layout uses across y dimension.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little confusing whether x refers to the row or column. We can use dim 0 and dim 1 instead.

// LDS usage of this op is roughly calculated as:
// LDS_USAGE = getShapePerCTA(mfma_layout)[0] * getShapePerCTA(blocked_layout)[1] * sizeof(data_type)
// LDS_USAGE = warpsPerCTA(mfma_layout)[0] * warpsPerCta(blocked_layout)[1] * C,
// where C = 32 * sizePerWarp(blocked_layout)[1] * threadsPerWarp(blocked_layout)[1] * sizeof(data_type)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is 32 hardcoded? Is it assuming mfma32 is used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest, I did not look deep into this comment, just copied it from original algorithm.
It was implemented a log ago, we probably had only mfma32 at the time.

I'll take a closer look and adjust.

for (int i = 0; i < tmpLayouts.size(); i++) {
auto tmpLayout = tmpLayouts[i];
std::tie(tmpCvt, newEpilogueCvt) =
createNewConvertOps(builder, cvtOp, tmpLayout);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this loop, we only want to know the index of the tmpLayout that gives us the min LDS usage. Do we really need to create the cvtOps and erase them at the end of each iteration?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This creation/deletion is needed because algorithm use getScratchConfigForCvtLayout(ConvertLayoutOp unsigned&, unsigned&) function from Allocation.cpp to estimate LDS usage.

I can introduce new interface, so we can avoid these redundant stuff.

* @return mapping from operation to list of live LDS buffers
*/
std::map<mlir::Operation *, SmallVector<Allocation::BufferId>>
analyzeBufferLiveness(FunctionOpInterface func, const Allocation *allocations) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not AMD specific. Maybe we should put it in Analysis/Allocation.cpp?

}

SmallVector<triton::gpu::ConvertLayoutOp>
findLDSBottleneck(ModuleAllocation &allocAnalysis, FunctionOpInterface func) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also put this to the common part since it can benefit NV path. But after realizing NV GPUs have pretty large shared memory ....

@zhanglx13
Copy link
Collaborator

@binarman I have a question regarding tryMinimizeLDS.

%1 consumes 16 KB of LDS, %2 requires ~64KB of lds for a scratch buffer.
If there are no padding, %2 can be exactly 64KB, which fits into LDS, but %1 and %2 together do not.

In this example, %2 will be a candidate from findLDSBottleneck, and tryMinimizeLDS is called on it. However, tryMinimizeLDS will early return since currLDSUsage <= LDSSize. I think the problem is that tryMinimizeLDS should not take LDSSize as target, instead it should take LDSSize - offset as target, where offset can be kept when we look for candidates in findLDSBottleneck.


namespace {

constexpr int LDSSize = 65536;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we not hardcode it but pass it from the front end?

@binarman
Copy link
Contributor Author

@zhanglx13 about tryMinimizeLDS

Condition is filters out cases which will definitely overflow LDS and there are no early exit.
We can actually remove this condition at all, because we are looking for the smallest LDS usage anyway.

@zhanglx13
Copy link
Collaborator

yes, at least the early return condition needs to be removed
And when you find the minLDSUsage, it could still be larger than LDSSize - offset, so tryMinimizeLDS should also return nothing in this case.

@binarman
Copy link
Contributor Author

the early return condition needs to be removed

Now I see, I've missed this early return, thank you!
At first I thought you were talking about early exit from loop.

module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @alloc_convert_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>) attributes {noinline = false} {
%1 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared>
%2 = triton_gpu.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I forgot to mention that I think this cvtOp is decomposed just because it uses more than 64 KB of LDS since padding is used. Therefore, this test does not test the functionality that a cvtOp could still be decomposed even it uses less than 64 KB LDS.

Copy link
Contributor Author

@binarman binarman Apr 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added new test: it uses fp16 instead of fp32, so cvt scratch buffer is x2 smaller

@@ -147,6 +147,8 @@ def make_llir(src, metadata, options):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm)
lds_size = 65536
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure, where to place code choosing LDS size, so it is plain constant at this point.
Let's introduce some interface in later PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be convenient to rebase onto Lei's PR #3808

@antiagainst antiagainst marked this pull request as draft April 30, 2024 22:15
@antiagainst
Copy link
Collaborator

(coverting to draft as we chatted--need to first get all issues addressed from AMD side before making it as open)

@binarman
Copy link
Contributor Author

binarman commented May 1, 2024

@antiagainst @zhanglx13
This PR is ready for review, PTAL 🙂

namespace triton {
namespace AMD {

constexpr int kPtrBitWidth = 64;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to hardcode the pointer bitwidth? Can we just use inline constant?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is copied from Allocation.cpp (it is not part of public interface).
Maybe I can actually take this part in some public interface, for example in Analysis/Utility module.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what I was talking about: binarman#6

res.LDS = std::numeric_limits<typeof(res.LDS)>::max();

triton::gpu::ConvertLayoutOp tmpCvt;
triton::gpu::ConvertLayoutOp newEpilogueCvt;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above three lines are not used.

threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];
auto order = triton::gpu::getOrder(srcEnc);
auto layoutCTA = triton::gpu::getCTALayout(srcEnc);
auto fallbackLayout = triton::gpu::BlockedEncodingAttr::get(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. For this fallbackLayout, all the components, except warpsPerCTA, are loop invariants. Maybe we can create a base BlockLayout out of the loop and use createTmpLayout(blockEnc, warpsPerCTA) inside the loop to update the warpsPerCTA only?
  2. Why is 8 chosen in warpSize / 8?
  3. In general, why we need this fallbackLayout? Is it covered by either srcEnc or dstEnc?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Why is 8 chosen in warpSize / 8

For wave64 it will be [8, 8], for wave32 it will be [4, 8]. This is done to make layout tile "square", so no dim size of minimal tile is dominating.

  1. In general, why we need this fallbackLayout? Is it covered by either srcEnc or dstEnc?

In some cases different warpsPerCTA of src or dst layout is not enough to reduce LDS usage, but some other layouts can be appropriate. These fallback layouts are designed to have as compact tile as possible, i.e. elementsPerThread = [1, ... 1], and threadsPerWarp are as "square" as possible.

I believe, that in most cases fallback layout will be chosen as a temporary layout. This could be non optimal in terms of performance, but it is fine, because without this transformation kernel will not compile at all.

return;
}

triton::gpu::ConvertLayoutOp tmpCvt;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we using this tmpCvt?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, will rewrite this part as done in DecomposeUnsupportedConversions pass.

if (offset + size > LDSSize) {
auto maxScratchBufferSize = computeMaxScratchBufferSize(
cvtOp, funcAnalysis, liveBuffers[cvtOp]);
candidates.push_back({cvtOp, maxScratchBufferSize});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is very confusing to me.

  1. Why do we need opBuffer? Just to check it's valid?
  2. Does liveBuffers[cvtOp] include opBuffer? To put it another way, does one of the bufId's for the scratch buffer allocated for this cvtOp?
  3. It seems to me that this function assumes that there is at most one extra buffer that can overlap with the buffer for this cvtOp? If there are more live buffers that overlap with this cvtOp, we should still only push cvtOp into candidates once, but compute maxScratchBufferSize based on all overlapped live buffers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need opBuffer? Just to check it's valid?

Sorry, this is reminder after refactoring, I used to pass it to computeMaxScratchBufferSize, but then start compute it inside function.

Does liveBuffers[cvtOp] include opBuffer? To put it another way, does one of the bufId's for the scratch buffer allocated for this cvtOp?

Yes, scratch buffer is the same as "long-living" buffers, the only difference, that it's live time is limited to one operation.

It seems to me that this function assumes that there is at most one extra buffer that can overlap with the buffer for this cvtOp? If there are more live buffers that overlap with this cvtOp, we should still only push cvtOp into candidates once, but compute maxScratchBufferSize based on all overlapped live buffers.

No, there could be any number of buffers with live-time overlapping with scratch buffer.

let me remove loop from this function, it should make algorithm clearer.

int64_t scratchBufferSize = allocation->getAllocatedSize(scratchBufferId);
size_t totalLDSConsumption = 0;
for (auto buf : liveBuffers)
totalLDSConsumption = std::max(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If all liveBuffers are live at this cvtOp, should we use sum instead of max here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Max is more conservative metric in this sense. Let's consider that we have "holes" in memory:

Screenshot from 2024-05-08 16-12-09

let's consider that green buffer is scratch buffer that we want to optimize, viollet and blue are long-living buffers in shared layout.

Hole is created, because pink tensor is allocated on tick 1 and reallocated on tick 2, but previously allocated violet tensor continue live.

Summarizing buffer sizes will tell that we have 20 KB(3 * 8 KB) for scratch buffer, but in reality we probably wan to make it smaller.

* space available for scratch buffer.
*/
int64_t
computeMaxScratchBufferSize(triton::gpu::ConvertLayoutOp op,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe computeTargetBufferSize? I feel like "target" or "desired" is more accurate about what we want to do here.

@@ -169,6 +169,9 @@ def make_llir(src, metadata, options):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm, options.arch)
# experimental parameter, specifies custom LDS usage limit
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate on this parameter? What does it mean when it's set to a non-zero and zero?
Especially when it's set to non-zero value, does it mean the total LDS usage is guaranteed to be lower than that? Or is it just a hint?

Copy link
Collaborator

@zhanglx13 zhanglx13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved.
Could you please add more comments for custom_lds_size? See below.

@zhanglx13 zhanglx13 marked this pull request as ready for review July 19, 2024 02:59
@antiagainst antiagainst force-pushed the reduce_lds_usage branch 3 times, most recently from a0fa2f8 to b825705 Compare July 20, 2024 06:05
@antiagainst antiagainst changed the title [AMD] OptimizeLDSUsage pass [AMD] Introduce an OptimizeLDSUsage pass Jul 20, 2024
@antiagainst antiagainst enabled auto-merge (squash) July 20, 2024 06:31
@antiagainst antiagainst merged commit aa3ac0a into triton-lang:main Jul 20, 2024
6 checks passed

} // namespace

namespace mlir::triton::AMD {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@antiagainst
You changed this part from multiple line to one, is it preferred code style now?

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 30, 2024
Imported from GitHub PR openxla/xla#15477

This build break is introduced by openxla/xla#15257

and ROcm has a new optimized LDS pass on openai/triton triton-lang/triton#3730

@xla-rotation
Copybara import of the project:

--
6f86fdbd090a4fc3fa2346ba6969d7ddeae773e3 by Chao Chen <[email protected]>:

updated rocm triton OptimizeLDSUsage pass due to triton-lang/triton#3730

Merging this change closes #15477

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15477 from ROCm:ci_hotfix_20240730 6f86fdbd090a4fc3fa2346ba6969d7ddeae773e3
PiperOrigin-RevId: 657498811
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Jul 30, 2024
Imported from GitHub PR #15477

This build break is introduced by #15257

and ROcm has a new optimized LDS pass on openai/triton triton-lang/triton#3730

@xla-rotation
Copybara import of the project:

--
6f86fdb by Chao Chen <[email protected]>:

updated rocm triton OptimizeLDSUsage pass due to triton-lang/triton#3730

Merging this change closes #15477

COPYBARA_INTEGRATE_REVIEW=#15477 from ROCm:ci_hotfix_20240730 6f86fdb
PiperOrigin-RevId: 657634867
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 30, 2024
Imported from GitHub PR openxla/xla#15477

This build break is introduced by openxla/xla#15257

and ROcm has a new optimized LDS pass on openai/triton triton-lang/triton#3730

@xla-rotation
Copybara import of the project:

--
6f86fdbd090a4fc3fa2346ba6969d7ddeae773e3 by Chao Chen <[email protected]>:

updated rocm triton OptimizeLDSUsage pass due to triton-lang/triton#3730

Merging this change closes #15477

PiperOrigin-RevId: 657634867
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants