[BACKEND][NVIDIA] Add DotOp Hoisting Pass for WGMMA and Add Lowering for SMEM-to-MMAv3 DotOp Copy #5003

Open · wants to merge 16 commits into base: main

Conversation

@ggengnv ggengnv commented Oct 28, 2024

Hopper has two kinds of WGMMAs, "SS" (both operands in shmem) and "RS" (LHS operand A in registers).
In cases where we apply elementwise operations on A before the WGMMA, Triton would previously copy A from global memory (GMEM) into registers (RF), perform the elementwise ops, and then copy the result to shared memory (SMEM) to perform an SS WGMMA.

This PR adds an optimization for the case above to use RS GEMM. This requires the following changes:

  • In TritonGPU OptimizeDotOperands pass, add optimizations to change SS GEMM into RS GEMM.
  • Add TritonGPU -> LLVM lowering for copying from SMEM to RF in MMA v3 dotOperand layout.

NOTE: This may not see a perf gain, and may even see a perf loss, for certain shapes (e.g. small K); additional optimizations are in a separate PR (with more still WIP). Please advise on the merging strategy.

@ggengnv ggengnv requested a review from ptillet as a code owner October 28, 2024 22:22

ggengnv commented Oct 28, 2024

@lezcano I've transferred the previous PR here per your request. Also @Jokeren

I've addressed some of your comments in the latest commit. The remaining comments I've responded to, asking for clarification or explaining why I think changes won't be needed. We can discuss further, and I'll add changes to this PR from now on.

One thing that needs to be resolved before merge is that this PR is only the first part of my changes. The second PR in the XLA fork includes optimizations that may be necessary to see perf gains.

In addition, even with the second PR's changes, we don't see a perf gain for all shapes. It looks like for some smaller shapes we currently see a perf loss, so further optimizations (or heuristics to enable/disable hoisting) may be necessary. I'm not sure what the merge strategy is for these kinds of larger changes, so please advise on this, thanks :)

Jokeren commented Oct 28, 2024

This may not see perf gain, and may even see perf loss, for certain shapes

Perf loss is concerning to me

@ThomasRaoux (Collaborator) left a comment

Nice work.
One thing I wonder is how the wgmma pipelining will work in this case.
PTX spec says:

Accessing the accumulator register or the input register containing the fragments of matrix A of a wgmma.mma_async instruction without first performing a wgmma.wait_group instruction that waits on a wgmma-group including that wgmma.mma_async instruction is undefined behavior.

Therefore, when we do wgmma pipelining with the operand coming from registers, we are going to break this rule, as we would keep re-using the same registers for the A operand.
I believe ptxas will see that and fall back, but that is likely to cause significant performance problems.
Is this something we need to handle? Do you know how libraries handle it?


// Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A
// is in registers).
bool canHoistDotOpEncV3(Operation* op) {
Collaborator:

why is this any different than canHoistDotOpEncV2? I would expect it to be the same

Contributor Author:

They're almost the same; the only difference is that MMAv3 hoisting doesn't support downcasting yet, because the SMEM-to-MMAv3-dotOperand copy lowering that I added doesn't yet support it.

For v3, I also added stricter checks like

  // Must have exactly one result and at least one operand
  if (op->getNumOperands() == 0 || op->getNumResults() != 1)
    return false;

  // Operands and results must be of RankedTensorType and Blocked or DotOp
  if (!(all_of(op->getOperandTypes(), isBlockedOrDotOpRankedTensor) &&
        all_of(op->getResultTypes(), isBlockedOrDotOpRankedTensor)))
    return false;

but left the v2 one intact in case something breaks.
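
For context, here is a minimal sketch of what a helper like isBlockedOrDotOpRankedTensor could look like; this is my illustration of the check's assumed semantics, not necessarily the PR's exact code:

  // Sketch (assumed semantics): true iff the type is a RankedTensorType whose
  // encoding is either Blocked or DotOperand.
  static bool isBlockedOrDotOpRankedTensor(mlir::Type ty) {
    auto tensorTy = mlir::dyn_cast<mlir::RankedTensorType>(ty);
    if (!tensorTy)
      return false;
    mlir::Attribute enc = tensorTy.getEncoding();
    return enc && mlir::isa<mlir::triton::gpu::BlockedEncodingAttr,
                            mlir::triton::gpu::DotOperandEncodingAttr>(enc);
  }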

// First pass: clone ops; the result values are cloned as well, but the operands still
// refer to the original result values
for (Operation *op : slice) {
  auto newOp = rewriter.clone(*op);
Collaborator:

should the clone op be inserted right before the old op?

Contributor Author:

do you mean that I should use setInsertionPoint before the clones?

Collaborator:

yeah? To avoid pulling all the operations down to the dot
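
A minimal sketch of the suggested change, reusing the variables from the snippet above (illustrative, not the final PR code):

  // Clone each op in the slice right before the original op, so the cloned
  // computation stays where it was instead of being pulled down to the dot.
  for (Operation *op : slice) {
    rewriter.setInsertionPoint(op);
    Operation *newOp = rewriter.clone(*op);
    // ... operand remapping happens in the second pass, as before ...
  }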

}

// Step 5b: Change the result to have DotOp rather than Blocked encoding
auto resTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
Collaborator:

nit: use cast when you know that the cast should succeed
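
For example, on the line quoted above this would mean (illustrative):

  // cast<> asserts the type is what we expect; dyn_cast<> is only needed when
  // the cast may legitimately fail and we check the result.
  auto resTy = cast<RankedTensorType>(op->getResult(0).getType());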

Comment on lines 387 to 393
// In case LHS is in registers, don't pipeline for now TODO(ggengnv) is this necessary?
auto op = *alloc->getUsers().begin();
if (auto localLoad = dyn_cast<ttg::LocalLoadOp>(op)) {
  auto resTy = cast<RankedTensorType>(localLoad->getResultTypes()[0]);
  if (!resTy || isa<ttg::DotOperandEncodingAttr>(resTy.getEncoding()))
    return false;
}
Collaborator:

this sounds like it will be inefficient? Why do we need that?

@ggengnv (Contributor Author) Oct 29, 2024

Oh this isn't necessary and I forgot to delete this TODO - I originally had pipelining logic enabled in this PR but there were concerns that this PR was getting too large, so I separated it into another PR.

If you think it's more natural, I can add back the pipelining logic into this PR?

Contributor Author:

Actually I just saw your other comment. I can keep the changes separate for now and combine them before merging.

Contributor:

Regarding splitting the PR, I think the PR could be split between the changes to kWidth, which are simple and benign, and the hoisting + pipelining logic, which seems a bit trickier. @ggengnv how does that sound?

Contributor Author:

sounds good to me.

by changes to kWidth, you mean the lowering of shared to register copies for MMAv3, correct?

the changes should be able to be split cleanly. only thing is that it might be hard to test by itself?

Contributor:

everything but the hoisting pass, yes.
When it comes to testing, you can add a couple lit tests that exercise kWidth != 4 / elemSize

Contributor:

but yeah, since these changes are orthogonal, splitting it in two shouldn't be terribly difficult. Tag me as a reviewer once you put it up and I'll approve it. That should make this PR much leaner, and it should then be fine to merge this PR and the pipeline PR into a single manageable PR.

Comment on lines 285 to 287
// This ordering is decided when a tensor in DotOpEnc is lowered into llvm.
// For WGMMA this happens in both SharedToDotOperand and MMAToDotOperand.
// Thus, both lowerings must obey this above ordering for the below code to be correct.
Collaborator:

The decision should be based on the layout definition rather than a convention between different lowerings. This comment is a bit misleading; maybe we should more explicitly describe the layout instead.

Contributor Author:

There's currently nothing in the dot operand layout attributes that would indicate the ordering of matM and matK though, so I assumed it was just implicit logic. I could move this comment to the definition of DotOpEncoding or perhaps remove it altogether to avoid confusion?

Collaborator:

yes the layout is not well documented and/or defined but this is how it should work :) I think moving it to DotOpEncoding is good, this is still valuable in my opinion

@ThomasRaoux (Collaborator)

This may not see perf gain, and may even see perf loss, for certain shapes

Perf loss is concerning to me

What workload did you use to measure this? If we need both PRs to avoid perf loss we should merge them together, they can be reviewed separately but we should avoid creating perf regressions

ggengnv commented Oct 29, 2024

Therefore when we do wgmma pipelining with operand coming from register we are going to break this rule as we would keep re-using the same register for A operand.
I believe ptxas will see that and fallback but that is likely to cause significant performance problems.
Is this something we need to handle? Do you know how libraries handle it?

You're right -- I believe the pipelining logic that I currently have in the other PR suffers from this issue. I believe CUTLASS uses more than one RF "buffer" to ping-pong the loads, so that's something I should look into.

This may not see perf gain, and may even see perf loss, for certain shapes

Perf loss is concerning to me

What workload did you use to measure this? If we need both PRs to avoid perf loss we should merge them together, they can be reviewed separately but we should avoid creating perf regressions

I did simple mixed-precision GEMMs on various shapes with one or more dimensions being small (<4k). But now that you bring up the issue with pipelining, I believe I should fix the pipelining logic first before benchmarking again. The perf loss might disappear then.

elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1}));
}
for (int k = 0; k < n1; ++k)
if (isHopper) {
Contributor:

Looks like it's built on top of a bad base

Contributor Author:

Yes, made a mistake there. Just reverted -- please ignore :)

@ThomasRaoux (Collaborator)

You're right -- I believe the pipelining logic that I currently have in the other PR suffers from this issue. I believe CUTLASS uses more than one RF "buffer" to ping-pong the loads, so that's something I should look into.

interesting, I wonder how that's possible unless we unroll the loop?

ggengnv commented Oct 29, 2024

You're right -- I believe the pipelining logic that I currently have in the other PR suffers from this issue. I believe CUTLASS uses more than one RF "buffer" to ping-pong the loads, so that's something I should look into.

interesting, I wonder how that's possible unless we unroll the loop?

afaik cutlass subdivides each block ("ktiles" in cutlass terminology) into subblocks ("kblocks"). So while TMA operates on the granularity of blocks, each WGMMA instruction will handle only a subblock. The inner loop of iterating over the subblocks of each block is unrolled, and the shared-to-register copies and WGMMAs are interleaved.

I'm not fully certain if it's a ping-pong buffer or something else -- in the process of confirming it
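
For illustration only, here is a conceptual C++-style sketch of the double-buffered interleaving described above; all names (Fragment, copy_smem_to_regs, wgmma_async_rs, etc.) are placeholders, not CUTLASS or Triton APIs:

  // Two register fragments for operand A, reused round-robin so that a
  // SMEM->RF copy never overwrites registers an in-flight wgmma may still read.
  Fragment regA[2];
  for (int kblock = 0; kblock < kBlocksPerTile; ++kblock) {  // fully unrolled
    int cur = kblock % 2;
    // Before overwriting regA[cur], wait until at most one wgmma group is
    // pending; the group issued two iterations ago (which read regA[cur]) is
    // then guaranteed to have completed.
    wgmma_wait_group(/*max_pending=*/1);
    copy_smem_to_regs(regA[cur], smemA, kblock);    // SMEM -> RF for this subblock
    wgmma_async_rs(acc, regA[cur], smemB, kblock);  // RS WGMMA on this subblock
    wgmma_commit_group();
  }
  wgmma_wait_group(/*max_pending=*/0);              // drain before consuming acc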

@ThomasRaoux (Collaborator)

You're right -- I believe the pipelining logic that I currently have in the other PR suffers from this issue. I believe CUTLASS uses more than one RF "buffer" to ping-pong the loads, so that's something I should look into.

interesting, I wonder how that's possible unless we unroll the loop?

afaik cutlass subdivides each block ("ktiles" in cutlass terminology) into subblocks ("kblocks"). So while TMA operates on the granularity of blocks, each WGMMA instruction will handle only a subblock. The inner loop of iterating over the subblocks of each block is unrolled, and the shared-to-register copies and WGMMAs are interleaved.

I'm not fully certain if it's a ping-pong buffer or something else -- in the process of confirming it

interesting, thanks for the information. It would probably be tricky to implement it that way in Triton. I think there are other strategies we can apply to respect the PTX spec. We could have the wgmma_wait at the end of the loop, or we could place it before the A operand is set and add some IR changes to make sure nobody re-orders things incorrectly.

If this is the main blocker for this PR to land I would suggest having this pass turned off by default so that the infra can be pushed. Then we can work on wgmma pipelining changes to make this right and ideally performant.

ggengnv commented Oct 29, 2024

interesting, thanks for the information. It would probably be tricky to implement it that way in Triton. I think there are other strategies we can apply to respect the PTX spec. We could have the wgmma_wait at the end of the loop, or we could place it before the A operand is set and add some IR changes to make sure nobody re-orders things incorrectly.

If this is the main blocker for this PR to land I would suggest having this pass turned off by default so that the infra can be pushed. Then we can work on wgmma pipelining changes to make this right and ideally performant.

Per @lezcano's suggestion I've first split out the dotOp lowering changes (which is a substantial part of this PR): #5009

After that PR is merged, I think there's still an edge case I need to resolve. But sounds good -- I can turn off this pass by default after everything else is resolved.

lezcano pushed a commit that referenced this pull request Oct 30, 2024
Allows for upcasting in DotOp encoding in RF.
This lowering path is not currently in use; pending
#5003
@ggengnv ggengnv marked this pull request as draft November 1, 2024 01:30
lezcano pushed a commit that referenced this pull request Nov 4, 2024
Two bugfixes following #5009.

- When `BLOCK_M=64` and `num_warps > 4`, the order of warps for
DotOpEncoded tensor should be M-major instead of N-major, since WGMMA
expects the 4 warps in each warp group to be stacked along the M
dimension.
- Should use `mmaBitwidth` instead of `bitwidth` when calculating
`numRep` in `SharedToDotOperandMMAv2OrV3`. This was missed in a bad
rebase.

@lezcano I encountered these bugs when attempting to locally test the
[DotOp hoisting PR](#5003)
after rebasing (they normally would be caught by `test_core.py` but that
path was not yet enabled in the last PR). With these fixes added, I was
able to successfully validate against pytorch.

ggengnv commented Nov 4, 2024

Update: cherrypicked pipelining changes into this PR.
Current issues:

  • Suboptimal codegen when num_stages = 1 (should be relatively easy to fix)
  • Perf regression for small MNKs. Currently getting benchmark numbers for these shapes. Will require interleaved pipelining of WGMMAs and scales.

// Hopper may not contain 32b along kWidth; Ampere always does
int kBits = 8 * elemBytes * kWidth;
assert(kBits == 32 || isHopper);
int vecSize = kBits / canonBits;
@ggengnv (Contributor Author) Nov 5, 2024

For the non-transposed case, vecSize = kBits / canonBits was always 1, so the previous logic (isHopper ? 1 : ...) worked.

For the transposed case, vecSize can be more than 1, since we're loading each value separately (i.e. canonBits = 8 * elemBytes).
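
To make the arithmetic concrete, an illustrative example with assumed values (int8 operand, elemBytes = 1, kWidth = 4; not taken from the PR):

  // Illustrative only, assuming elemBytes = 1 (int8) and kWidth = 4:
  //   kBits = 8 * elemBytes * kWidth = 32
  //   non-transposed: canonBits == kBits == 32      -> vecSize = 32 / 32 = 1
  //   transposed:     canonBits = 8 * elemBytes = 8 -> vecSize = 32 / 8  = 4 (== kWidth)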

ggengnv commented Nov 6, 2024

@ThomasRaoux I have obtained some benchmark results. cc @lezcano

All GEMMs are bf16 x int8, with both operands k-major. The numbers are in microseconds (us).

  • "BxA Base" means A and B are swapped, running on main branch triton build
  • "BxA RS" means A and B are not swapped, running on my PR's build
  • "AxB" means A and B are not swapped, so the hoisting optimization does not apply
  • Relative perf is the runtime of Base over RS

[Benchmark results table omitted (image)]

You can notice that for the last three rows as well as some other small shapes (i.e. all shapes with small M), it's faster to just do AxB, i.e. not swap operands. These cases just use Ampere MMAv2 and the LHS optimization wouldn't apply.

@ThomasRaoux (Collaborator)

  • "BxA RS" means A and B are not swapped, running on my PR's build

you meant A and B are swapped right?

ggengnv commented Nov 6, 2024

  • "BxA RS" means A and B are not swapped, running on my PR's build

you meant A and B are swapped right?

oh yes, that was a typo.

@ThomasRaoux (Collaborator)

You can notice that for the last three rows as well as some other small shapes (i.e. all shapes with small M), it's faster to just do AxB, i.e. not swap operands. These cases just use Ampere MMAv2 and the LHS optimization wouldn't apply.

why are the results different for those 3 rows between BxA and BxA RS? I thought that when using MMAv2 the PR would have no effect?

ggengnv commented Nov 6, 2024

You can notice that for the last three rows as well as some other small shapes (i.e. all shapes with small M), it's faster to just do AxB, i.e. not swap operands. These cases just use Ampere MMAv2 and the LHS optimization wouldn't apply.

why are the results different for those 3 rows between BxA and BxA RS? I thought that when using MMAv2 the PR would have no effect?

BxA and BxA RS both have the operands swapped, so the PR would apply. AxB is the one without swapping; hence there's only one column for AxB.

AxB is faster but I included BxA and BxA RS for reference anyway.

@ThomasRaoux (Collaborator)

BxA and BxA RS both have the operands swapped, so the PR would apply. AxB is the one without swapping; hence there's only one column for AxB.

But it wouldn't use wgmma even with the operands swapped? I thought the code in the PR would not change this case

ggengnv commented Nov 6, 2024

What I was concerned about earlier was that this PR's optimization might not make sense when pipelining is disabled, since I expected RS WGMMA (operand A in registers) to be only useful in conjunction with pipelining.

But I just tested on a shape with small K, where we might realistically disable pipelining. The shape profiled was (32, 4096, 128) with block size (64, 16, 128), num_warps = 4, and num_stages = 1. A and B were swapped.
This PR's runtime was 3.25us; baseline runtime 3.31us.

So I think I might've been wrong, and we won't need to worry about this case at all.

@ThomasRaoux (Collaborator)

What I was concerned about earlier was that this PR's optimization might not make sense when pipelining is disabled, since I expected RS WGMMA (operand A in registers) to be only useful in conjunction with pipelining.

But I just tested on a shape with small K, where we might realistically disable pipelining. The shape profiled was (32, 4096, 128) with block size (64, 16, 128), num_warps = 4, and num_stages = 1. A and B were swapped. This PR's runtime was 3.25us; baseline runtime 3.31us.

So I think I might've been wrong, and we won't need to worry about this case at all.

interesting, I expect that, like for MMAv2, this should still be better even without pipelining, but I haven't looked at the IR for this case in a while. Anyway, if it improves runtime I wouldn't worry much about it; this is a case we can optimize in future PRs.

@Moerafaat (Contributor)

if (getNumStagesOrDefault(forOp) > 1)

No worries about num_stages=1

We do have cases that run faster with num_stages=1 so I would say it would still be useful if we can avoid regressing this case.

ggengnv commented Nov 7, 2024

We do have cases that run faster with num_stages=1 so I would say it would still be useful if we can avoid regressing this case.

I think I can simply disable the pass if the encapsulating for loop has num_stages == 1.
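
A rough sketch of that gating, assuming the hoisting code can see the enclosing scf.for and that a helper like getNumStagesOrDefault (mentioned above) is available there; dotOp is a placeholder for the dot operation being considered, and the exact plumbing would differ:

  // Skip hoisting when the enclosing loop will not be software-pipelined
  // (num_stages == 1), since RS WGMMA is mainly expected to pay off together
  // with pipelining.
  auto forOp = dotOp->getParentOfType<scf::ForOp>();
  if (forOp && getNumStagesOrDefault(forOp) <= 1)
    return false;  // keep the SS WGMMA path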

lezcano commented Nov 7, 2024

@Moerafaat can you check what regime of K's you are using num_stages=1 with?

@Moerafaat (Contributor)

We do have cases that run faster with num_stages=1 so I would say it would still be useful if we can avoid regressing this case.

I think I can simply disable the pass if the encapsulating for loop has num_stages == 1.

If this is not too much work then I would say it's preferable to have that option.

@Moerafaat (Contributor)

@Moerafaat can you check what regime of K's you are using num_stages=1 with?

I'm checking internally for this information and will get back to you.

ggengnv commented Nov 8, 2024

If this is not too much work then I would say it's preferable to have that option.

If I want to conditionally disable hoisting based on the value of num_stages, I can pass num_stages to OptimizeDotOperands as a parameter. Is this an acceptable approach?
@lezcano @Jokeren

lezcano commented Nov 10, 2024

Is this an acceptable approach?

SGTM. That being said, we should wait for @Moerafaat who is looking into their kernels. In the reasonably likely case that they use it for small K, it may be just alright to always keep this on and avoid branching the logic.

@ThomasRaoux (Collaborator)

If this is not too much work then I would say it's preferable to have that option.

If I want to conditionally disable hoisting based on the value of num_stages, I can pass num_stages to OptimizeDotOperands as a parameter. Is this an acceptable approach? @lezcano @Jokeren

to be honest I'm a bit concerned with those kinds of ad hoc heuristics. It is easy to have them creep in and make it even harder to address the problem later on. If there are proven regressions on Google's side then we will need to get to the bottom of it; maybe the pass can be turned off on Google's side for some time, but I don't think adding heuristics based on parameters that should be independent is a good solution.

@ThomasRaoux (Collaborator)

also, my understanding is that XLA has its own pass manager, so it is easy to not enable the optimization until the remaining performance problems are resolved, if needed

@Moerafaat (Contributor)

Is this an acceptable approach?

SGTM. That being said, we should wait for @Moerafaat who is looking into their kernels. In the reasonably likely case that they use it for small K, it may be just alright to always keep this on and avoid branching the logic.

I haven't seen all the cases where we have num_stages = 1, but we already see a few examples where K is between 4k and 8k and the best configuration includes num_stages=1. A close second usually has num_stages>1, but that still means we can get the best tilings without pipelining. That being said, we also don't mind the change if we can enable/disable the entire pass (though that could mean that we would need to address regressions later on to enable it from XLA).

@ggengnv ggengnv marked this pull request as ready for review November 13, 2024 02:51