[BACKEND][NVIDIA] Add DotOp Hoisting Pass for WGMMA and Add Lowering for SMEM-to-MMAv3 DotOp Copy #5003
Conversation
@lezcano I've transferred the previous PR here per your request. Also @Jokeren I've addressed some of your comments in the latest commit. The remaining comments I've responded to, asking for clarification or explaining why I think changes won't be needed. We can discuss further, and I'll add changes to this PR from now on. One thing that needs to be resolved before merge is that this PR is only the first part of my changes. The second PR in the XLA fork includes optimizations that may be necessary to see perf gains. In addition, even with the second PR's changes, we don't see perf gain for all shapes. It looks like for some smaller shapes we currently see perf loss, so further optimizations (or heuristics to enable/disable hoisting) may be necessary. I'm not sure what the merge strategy is for these kinds of larger changes, so please advise on this, thanks :)
Force-pushed from f47f5d6 to a08b09b
Perf loss is concerning to me.
Nice work.
One thing I wonder is how the wgmma pipelining will work in this case.
PTX spec says:
Accessing the accumulator register or the input register containing the fragments of matrix A of a wgmma.mma_async instruction without first performing a wgmma.wait_group instruction that waits on a wgmma-group including that wgmma.mma_async instruction is undefined behavior.
Therefore, when we do WGMMA pipelining with the A operand coming from registers, we are going to break this rule, as we would keep re-using the same registers for the A operand.
I believe ptxas will see that and fall back, but that is likely to cause significant performance problems.
Is this something we need to handle? Do you know how libraries handle it?
// Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A
// is in registers).
bool canHoistDotOpEncV3(Operation* op) {
Why is this any different than `canHoistDotOpEncV2`? I would expect it to be the same.
They're almost the same; the only difference is that MMAv3 hoisting doesn't support downcasting yet, because the lowering of shared-to-MMAv3-dotop-copy logic that I added doesn't yet support it.
For v3, I also added stricter checks like
// Must have exactly one result and at least one operand
if (op->getNumOperands() == 0 || op->getNumResults() != 1)
  return false;

// Operands and results must be of RankedTensorType and Blocked or DotOp
if (!(all_of(op->getOperandTypes(), isBlockedOrDotOpRankedTensor) &&
      all_of(op->getResultTypes(), isBlockedOrDotOpRankedTensor)))
  return false;
but left the v2 one intact in case something breaks.
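For context, here is a minimal sketch of what a predicate like `isBlockedOrDotOpRankedTensor` could look like; this is an illustration only, not necessarily the exact helper in the PR:

```cpp
// Sketch only: accept ranked tensors whose encoding is Blocked or DotOperand.
static bool isBlockedOrDotOpRankedTensor(Type ty) {
  auto tensorTy = dyn_cast<RankedTensorType>(ty);
  if (!tensorTy)
    return false;
  Attribute enc = tensorTy.getEncoding();
  return enc && isa<ttg::BlockedEncodingAttr, ttg::DotOperandEncodingAttr>(enc);
}
```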
// First pass: clone ops; the result values are cloned as well, but the operands still
// refer to the original result values
for (Operation *op : slice) {
  auto newOp = rewriter.clone(*op);
should the clone op be inserted right before the old op?
Do you mean that I should use `setInsertionPoint` before the `clone`s?
yeah? To avoid pulling all the operations down to the dot
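For illustration, a minimal sketch of that suggestion, assuming the cloning loop from the snippet above (not the exact code in the PR):

```cpp
// First pass: clone each op in the slice, placing the clone immediately
// before the original op so the computation is not pulled down to the dot.
for (Operation *op : slice) {
  rewriter.setInsertionPoint(op);         // insertion point right before the original
  Operation *newOp = rewriter.clone(*op);
  // operand remapping to the cloned results happens in a later pass, as before
}
```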
}

// Step 5b: Change the result to have DotOp rather than Blocked encoding
auto resTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
Nit: use `cast` when you know that the cast should succeed.
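In other words, the suggested form would be something like the following (assuming the result type is known to be a RankedTensorType at this point):

```cpp
// cast<> asserts on mismatch instead of returning null, documenting the invariant.
auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
```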
// In case LHS is in registers, don't pipeline for now TODO(ggengnv) is this necessary?
auto op = *alloc->getUsers().begin();
if (auto localLoad = dyn_cast<ttg::LocalLoadOp>(op)) {
  auto resTy = cast<RankedTensorType>(localLoad->getResultTypes()[0]);
  if (!resTy || isa<ttg::DotOperandEncodingAttr>(resTy.getEncoding()))
    return false;
}
this sounds like it will be inefficient? Why do we need that?
Oh this isn't necessary and I forgot to delete this TODO - I originally had pipelining logic enabled in this PR but there were concerns that this PR was getting too large, so I separated it into another PR.
If you think it's more natural, I can add back the pipelining logic into this PR?
Actually I just saw your other comment. I can keep the changes separate for now and combine them before merging.
Regarding splitting the PR, I think the PR could be split between the changes to `kWidth`, which are simple and benevolent, and the hoisting + pipelining logic, which seem a bit trickier. @ggengnv how does that sound?
Sounds good to me. By changes to `kWidth`, you mean the lowering of shared-to-register copies for MMAv3, correct?
The changes should be able to be split cleanly. The only thing is that it might be hard to test by itself?
Everything but the hoisting pass, yes.
When it comes to testing, you can add a couple of lit tests that exercise `kWidth != 4 / elemSize`.
But yeah, since these changes are orthogonal, splitting it in two shouldn't be terribly difficult. Tag me as a reviewer once you put it up and I'll approve it. That should make this PR much leaner, and it should then be fine to merge this PR and the pipeline PR into a single manageable PR.
// This ordering is decided when a tensor in DotOpEnc is lowered into llvm.
// For WGMMA this happens in both SharedToDotOperand and MMAToDotOperand.
// Thus, both lowerings must obey this above ordering for the below code to be correct.
The decision should be based on the layout definition rather than a convention between different lowerings. This comment is a bit misleading, and maybe we should describe the layout more explicitly instead.
There's currently nothing in the dot operand layout attributes that would indicate the ordering of matM and matK though, so I assumed it was just implicit logic. I could move this comment to the definition of DotOpEncoding or perhaps remove it altogether to avoid confusion?
Yes, the layout is not well documented and/or defined, but this is how it should work :) I think moving it to DotOpEncoding is good; this is still valuable in my opinion.
What workload did you use to measure this? If we need both PRs to avoid perf loss, we should merge them together; they can be reviewed separately, but we should avoid creating perf regressions.
You're right -- I believe the pipelining logic that I currently have in the other PR suffers from this issue. I believe CUTLASS uses more than one RF "buffer" to ping-pong the loads, so that's something I should look into.
I did simple mixed-precision GEMMs on various shapes with one or more dimensions being small (<4k). But now that you bring up the issue with pipelining, I believe I should fix the pipelining logic first before benchmarking again. The perf loss might disappear then.
  elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1}));
}
for (int k = 0; k < n1; ++k)
  if (isHopper) {
Looks like it's built on top of a bad base:
Line 514 in 86a2ac7
for (int k = 0; k < n1; ++k) {
Yes, made a mistake there. Just reverted -- please ignore :)
Interesting, I wonder how that's possible unless we unroll the loop?
AFAIK CUTLASS subdivides each block (a "ktile" in CUTLASS terminology) into subblocks ("kblocks"). So while TMA operates on the granularity of blocks, each WGMMA instruction will handle only a subblock. The inner loop iterating over the subblocks of each block is unrolled, and the shared-to-register copies and WGMMAs are interleaved. I'm not fully certain if it's a ping-pong buffer or something else -- I'm in the process of confirming it.
Interesting, thanks for the information. It would probably be tricky to implement it that way in Triton. I think there are other strategies we can apply to respect the PTX spec. We could have the wgmma_wait at the end of the loop, or we could place it before the A operand is set and add some IR changes to make sure nobody re-orders things incorrectly. If this is the main blocker for this PR to land, I would suggest having this pass turned off by default so that the infra can be pushed. Then we can work on WGMMA pipelining changes to make this right and ideally performant.
Per @lezcano's suggestion, I've first split out the dotOp lowering changes (which is a substantial part of this PR): #5009. After that PR's merged, I think there's still an edge case I need to resolve. But sounds good -- I can turn off this pass by default after everything else's resolved.
Allows for upcasting in DotOp encoding in RF. This lowering path is not currently in use; pending #5003
Two bugfixes following #5009.
- When `BLOCK_M=64` and `num_warps > 4`, the order of warps for DotOpEncoded tensor should be M-major instead of N-major, since WGMMA expects the 4 warps in each warp group to be stacked along the M dimension.
- Should use `mmaBitwidth` instead of `bitwidth` when calculating `numRep` in `SharedToDotOperandMMAv2OrV3`. This was missed in a bad rebase.

@lezcano I encountered these bugs when attempting to locally test the [DotOp hoisting PR](#5003) after rebasing (they normally would be caught by `test_core.py` but that path was not yet enabled in the last PR). With these fixes added, I was able to successfully validate against pytorch.
Force-pushed from dceb453 to e9217d1
Update: cherrypicked pipelining changes into this PR.
// Hopper may not contain 32b along kWidth; Ampere always does
int kBits = 8 * elemBytes * kWidth;
assert(kBits == 32 || isHopper);
int vecSize = kBits / canonBits;
For the non-transposed case, `vecSize = kBits / canonBits` was always 1, so the previous logic (`isHopper ? 1 : ...`) worked.
For the transposed case, `vecSize` can be more than 1, since we're loading each value separately (i.e. `canonBits = 8 * elemBytes`).
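As a toy check of that arithmetic (made-up values, just to illustrate the relationship between `kWidth`, `canonBits`, and `vecSize`):

```cpp
#include <cassert>

int main() {
  // Illustrative only: an int8 operand (elemBytes = 1) with kWidth = 4.
  int elemBytes = 1, kWidth = 4;
  int kBits = 8 * elemBytes * kWidth;            // 32 bits along k

  int canonBitsNonTransposed = kBits;            // one contiguous 32b chunk
  int canonBitsTransposed = 8 * elemBytes;       // each element loaded separately

  assert(kBits / canonBitsNonTransposed == 1);   // previous logic: vecSize always 1
  assert(kBits / canonBitsTransposed == kWidth); // transposed: vecSize can exceed 1
  return 0;
}
```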
Force-pushed from 31b233e to b5c407f
@ThomasRaoux I have obtained some benchmark results (cc @lezcano). All GEMMs are bf16 x int8, both operands k-major. The numbers are in us.
You can notice that for the last three rows, as well as some other small shapes (i.e. all shapes with small M), it's faster to just do AxB, i.e. not swap operands. These cases just use Ampere MMAv2 and the LHS optimization wouldn't apply.
You meant A and B are swapped, right?
Oh yes, that was a typo.
Why are the results different for those 3 rows between BxA and BxA RS? I thought when using MMAv2 the PR would have no effect?
BxA and BxA RS are still with the operands swapped, so the PR would apply. AxB is the one without swapping; hence there's only one column for AxB. AxB is faster, but I included BxA and BxA RS for reference anyway.
But it wouldn't use WGMMA even with the operands swapped? I thought the code in the PR would not change this case.
What I was concerned about earlier was that this PR's optimization might not make sense when pipelining is disabled, since I expected RS WGMMA (operand A in registers) to be only useful in conjunction with pipelining. But I just tested on a shape with small K, where we might realistically disable pipelining. The shape profiled was (32, 4096, 128) with block size (64, 16, 128), num_warps = 4, and num_stages = 1. A and B were swapped. So I think I might've been wrong, and we won't need to worry about this case at all.
Interesting. I expect that, like for MMAv2, this should still be better even without pipelining, but I haven't looked at the IR for this case in a while. Anyway, if it improves runtime I wouldn't worry much about it, and this is a case we can optimize in future PRs.
We do have cases that run faster with num_stages=1, so I would say it would still be useful if we can avoid regressing this case.
I think I can simply disable the pass if the encapsulating for loop has num_stages == 1.
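A minimal sketch of that gating idea (the `tt.num_stages` loop attribute name is an assumption here; whatever mechanism actually records the stage count would be used instead):

```cpp
// Skip hoisting when the enclosing loop is not pipelined (num_stages == 1).
static bool enclosingLoopIsPipelined(Operation *dotOp) {
  auto forOp = dotOp->getParentOfType<scf::ForOp>();
  if (!forOp)
    return false;
  // Attribute name assumed for illustration.
  auto numStages = forOp->getAttrOfType<IntegerAttr>("tt.num_stages");
  return !numStages || numStages.getInt() > 1;
}
```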
@Moerafaat can you check what regime of |
If this is not too much work, then I would say it's preferable to have that option.
I'm checking internally for this information and will get back to you.
SGTM. That being said, we should wait for @Moerafaat, who is looking into their kernels. In the reasonably likely case that they use it for small K, it may be just alright to always keep this on and avoid branching the logic.
To be honest, I'm a bit concerned with those kinds of ad hoc heuristics. It is easy to have them creep up and make it even harder to address the problem later on. If there are proven regressions on Google's side then we will need to get to the bottom of it; maybe the pass can be turned off on Google's side for some time, but I don't think adding heuristics based on parameters that should be independent is a good solution.
Also, my understanding is that XLA has its own pass manager, so it is easy to not enable the optimizations until more performance problems are resolved, if needed.
I haven't seen all cases where we have num_stages = 1, but we already see a few examples where K is between 4k and 8k and the best configuration includes num_stages=1. A close second usually has num_stages>1, but that still means we can get the best tilings without pipelining. That being said, we also don't mind the change if we can enable/disable an entire pass (though that could mean that we would need to address regressions later on to enable it from XLA).
Force-pushed from 717941d to 20f9ba0
Hopper has two kinds of WGMMAs, "SS" (both operands in shmem) and "RS" (LHS operand A in registers).
In cases where we apply elementwise operations on A before the WGMMA, Triton previously would copy A from global memory (GMEM) into registers (RF), perform the elementwise ops, and then copy to shared memory (SMEM) to perform an SS WGMMA.
This PR adds an optimization for the case above to use RS WGMMA. Per the title, this requires the following changes:
- a DotOp hoisting pass for WGMMA (MMAv3), and
- a lowering for the SMEM-to-MMAv3 DotOp copy.
NOTE: This may not see perf gain, and may even see perf loss, for certain shapes (e.g. small-K), and additional optimizations are in a separate PR (still more optimizations are WIP). Please advise on the merging strategy.