[BACKEND][NVIDIA] Add DotOp Hoisting Pass for WGMMA and Add Lowering for SMEM-to-MMAv3 DotOp Copy #5003

Open

ggengnv wants to merge 16 commits into triton-lang:main from ggengnv:oai-lhs-reg-hoist
+548 −59 (changes from all 16 commits)

Commits (all by ggengnv):
f0fe49d  Add preliminary logic to hoist elt-wise ops for MMAv3
d2fff26  Lower shared > v3 dotOp & improve hoisting logic
7308447  Fix test regressions
8aef99b  Rewrite OptimizeDotOperands logic and add tests
32651e9  Improve comments
25fc6be  Improve documentation and refactor
27b2333  Rename SharedToDotOperandMMAv2 -> ...v2OrV3
d5932b2  Remove debug flags in test_core.py
26e7407  Address even more comments
a40e519  Add pipelining and tests
73363cf  Fix bug in MMAv2OrV3 lowering for transposed case
b3dc4f0  Tighten pipeline properlyAsync check for DotOp in registers
c1272f1  Fix test regressions related to properlyAsync logic
5ce5628  Fix rebase: use getMMALoadType in LoopScheduling as well
20f9ba0  Clang format
882aefc  Fix Hopper path in composeValues...
@@ -1,9 +1,11 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
@@ -15,6 +17,125 @@ namespace gpu {

namespace {

// Helpers

// Returns whether we can hoist DotOp Encoding through `op`.
// Roughly, whether op is elementwise and thus threads don't need
// to exchange elements. But some ops are not currently supported even though
// they meet that criterion.
bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) {
  // Only consider custom conversions or arith ops.
  // TODO(jlebar): Is this too restrictive?
  if (!isa<FpToFpOp, BitcastOp>(op) && !isPureUnaryInlineAsm(op) &&
      !isa<arith::ArithDialect>(op->getDialect()))
    return false;

  // Quick handling to fix loading issues when computing the original
  // bitwidth is unable to realize that there is a mixed-precision dot
  // (hence kWidth = 1) but wants to hoist through the type conversion.
  if (isa<arith::ExtFOp>(op) && dotOpEnc.getKWidth() == 1)
    return false;

  // Currently, these instructions are not supported during lowering of
  // shared -> dot_operand layout. Not all types and type conversions are
  // supported.
  if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(op))
    return false;

  // Don't hoist through u1 -> fp casts as they aren't supported in
  // ElementwiseOpToLLVM::reorderValues().
  if (isa<arith::UIToFPOp>(op)) {
    Type opType = getElementTypeOrSelf(op->getOperand(0));
    if (opType.isInteger(1))
      return false;
  }

  return true;
}

// Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A
// is in registers).
bool canHoistDotOpEncV3(Operation *op) {
  // Must have exactly one result and at least one operand
  if (op->getNumOperands() == 0 || op->getNumResults() != 1)
    return false;

  auto isBlockedOrDotOpRankedTensor = [](Type ty) {
    auto tensorTy = dyn_cast<RankedTensorType>(ty);
    if (!tensorTy)
      return false;
    return isa<BlockedEncodingAttr, DotOperandEncodingAttr>(
        tensorTy.getEncoding());
  };

  // Operands and results must be of RankedTensorType and Blocked or DotOp
  if (!(all_of(op->getOperandTypes(), isBlockedOrDotOpRankedTensor) &&
        all_of(op->getResultTypes(), isBlockedOrDotOpRankedTensor)))
    return false;

  // Only consider custom conversions or arith ops.
  if (!isa<FpToFpOp, BitcastOp>(op) && !isPureUnaryInlineAsm(op) &&
      !isa<arith::ArithDialect>(op->getDialect()))
    return false;

  // Currently, these instructions are not supported during lowering of
  // shared -> dot_operand layout. Not all types and type conversions are
  // supported.
  if (isa<arith::SelectOp>(op))
    return false;

  // Downcasting not currently supported; it will likely require minor
  // adjustments in SharedToDotOperandMMAv2
  auto oprType = getElementTypeOrSelf(op->getOperand(0));
  auto resType = getElementTypeOrSelf(op->getResult(0));
  if (oprType.getIntOrFloatBitWidth() > resType.getIntOrFloatBitWidth())
    return false;

  // Don't hoist through u1 -> fp casts as they aren't supported in
  // ElementwiseOpToLLVM::reorderValues().
  if (isa<arith::UIToFPOp>(op) && oprType.isInteger(1))
    return false;

  return true;
}

// Helper to perform a "deep" clone of the given slice (i.e., set of ops),
// returning a tuple (newSlice, sliceMap), where newSlice is the cloned slice,
// and sliceMap is the IRMapping that maps the ops and result values of the
// original slice to those in the cloned slice.
auto cloneSlice(PatternRewriter &rewriter,
                const SetVector<Operation *> &slice) {
  IRMapping sliceMap;
  SetVector<Operation *> newSlice;

  // First pass: clone ops; the result values are cloned as well, but the
  // operands still refer to the original result values.
  for (Operation *op : slice) {
    rewriter.setInsertionPoint(op);
    auto newOp = rewriter.clone(*op);
    newSlice.insert(newOp);
    sliceMap.map(op, newOp);
    for (auto [result, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      assert(result != newResult);
      sliceMap.map(result, newResult);
    }
  }

  // Second pass: replace operand references in cloned ops to point to cloned
  // values.
  for (auto [op, newOp] : sliceMap.getOperationMap())
    for (auto [oprIdx, operand] : llvm::enumerate(newOp->getOperands())) {
      auto defOp = operand.getDefiningOp();
      if (!slice.contains(defOp))
        continue;

      newOp->setOperand(oprIdx, sliceMap.lookup(operand));
    }

  return std::make_tuple(newSlice, sliceMap);
}
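For orientation only (not part of the diff): a minimal usage sketch of cloneSlice, mirroring Steps 2, 4, and 6 of the MMAv3 pattern further down. The anchor op and variable names here are hypothetical.

// Hypothetical usage inside a rewrite pattern; `rewriter` is the
// PatternRewriter and `anchor` is some op whose producer chain we duplicate.
SetVector<Operation *> slice;
BackwardSliceOptions opt;
opt.omitBlockArguments = true;
getBackwardSlice(anchor, &slice, opt);

// Clone each op in the slice (right before its original) and rewire
// intra-slice operands to the cloned values.
auto [newSlice, sliceMap] = cloneSlice(rewriter, slice);

// Any value produced inside the original slice can then be translated to its
// cloned counterpart, e.g. to re-point an operand of `anchor`.
Value cloned = sliceMap.lookup(anchor->getOperand(0));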

// Given
// convert(trans(src)) #dot_operand ->
// convert(local_load(trans(alloc(src))))
@@ -111,7 +232,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
                                PatternRewriter &rewriter) const override {
    // Only consider conversions to dot operand.
    auto cvtTy = cast<RankedTensorType>(cvt.getType());
    if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
    auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
    if (!dotOpEnc)
      return failure();

    auto src = cvt.getSrc().getDefiningOp();

@@ -126,16 +248,7 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
                   [](Type ty) { return isa<RankedTensorType>(ty); }))
      return failure();

    // Only consider custom conversions or arith ops.
    // TODO(jlebar): Is this too restrictive?
    if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
        src->getDialect()->getTypeID() != TypeID::get<arith::ArithDialect>())
      return failure();

    // Currently, these instructions are not supported during lowering of
    // shared -> dot_operand layout. Not all types and type conversions are
    // supported.
    if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
    if (!canHoistDotOpEncV2(src, dotOpEnc))
      return failure();

    // Check that the conversion is transitively dependent on a load, and all

@@ -165,12 +278,7 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
      if (isa<LoadOp>(currOp)) {
        foundLoad = true;
      } else if (foundLoad) {
        // Bail out if there exists an op after Load that is not FpToFp,
        // Bitcast, or Arith.
        if (!isa<FpToFpOp, BitcastOp>(currOp) &&
            !isPureUnaryInlineAsm(currOp) &&
            currOp->getDialect()->getTypeID() !=
                TypeID::get<arith::ArithDialect>())
        if (!canHoistDotOpEncV2(currOp, dotOpEnc))
          return failure();
      }
    }

@@ -301,6 +409,150 @@ struct MMAV3UseRegOperand
  }
};

// MMAV3's analog of HoistLayoutConversion, for operand A only; will make
// WarpGroupDot accept operand A in registers instead of shmem.
//
// Before: load #blocked; (elementwise #blocked)+; local_alloc; warp_group_dot
// After:  load #blocked; convert_layout #dot_op; (elementwise #dot_op)+;
//         warp_group_dot
//
// Whereas (MMAV2) HoistLayoutConversion hoists through one elementwise op at a
// time and requires multiple passes, this pattern will directly hoist the
// convert to the right place in one pass.
//
// Or, to be more precise, this pattern deletes the local_alloc op and inserts a
// convert_layout op after each load that warp_group_dot uses; so this is not
// simply hoisting a convert_layout op up as in V2, but can be considered as
// first changing local_alloc to convert_layout and then hoisting, which results
// in WGMMA now accepting operand A in DotOp layout rather than Shared.
struct MMAV3HoistLayoutConversion
    : public OpRewritePattern<triton::nvidia_gpu::WarpGroupDotOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(triton::nvidia_gpu::WarpGroupDotOp dotOp,
                                PatternRewriter &rewriter) const override {
    // Can only hoist operand 0
    auto alloc = dotOp.getOperand(0).getDefiningOp<LocalAllocOp>();
    if (!alloc || !alloc.getSrc())
      return rewriter.notifyMatchFailure(
          dotOp, "operand A must be produced by local_alloc");

    auto getEncoding = [](Value v) {
      return cast<TensorOrMemDesc>(v.getType()).getEncoding();
    };

    if (!isa<SharedEncodingAttr>(getEncoding(dotOp.getOperand(0))))
      return rewriter.notifyMatchFailure(
          dotOp, "requires Shared encoding for operand A");

    // Step 1: Perform checks for early stop
    auto srcEnc = dyn_cast<BlockedEncodingAttr>(getEncoding(alloc.getSrc()));
    if (!srcEnc)
      return rewriter.notifyMatchFailure(
          alloc, "requires src to have Blocked encoding");

    auto dstEnc =
        dyn_cast<NvidiaMmaEncodingAttr>(getEncoding(dotOp.getResult()));
    if (!dstEnc || dstEnc.getVersionMajor() != 3)
      return rewriter.notifyMatchFailure(
          dotOp, "requires result in NvidiaMma encoding");

    // Step 2: Obtain slice of ops between load/constant and local_alloc
    SetVector<Operation *> slice;
    BackwardSliceOptions opt;
    opt.omitBlockArguments = true;
    opt.filter = [&](Operation *op) {
      // Stop before Load, ConstantOp, or LocalLoad
      return (op->getParentRegion() == alloc->getParentRegion()) &&
             !isa<LoadOp, arith::ConstantOp, LocalLoadOp>(op) &&
             (op->getNumOperands() != 0);
    };
    getBackwardSlice(alloc.getOperation(), &slice, opt);

    // Step 3: Verify slice can be hoisted through
    if (slice.empty())
      return rewriter.notifyMatchFailure(dotOp, "nothing to hoist through");

    // We define frontierOp as an op outside this slice whose result is used by
    // an op in this slice. We must eventually convert the result of all
    // frontierOps to DotOperandEncoding. This is done via the insertion of
    // ConvertLayout after each frontierOp. We currently support frontierOp to
    // be load or constant.
    for (Operation *currOp : slice) {
      if (!canHoistDotOpEncV3(currOp))
        return rewriter.notifyMatchFailure(currOp, "cannot hoist through");

      // We previously ensured that all ops in slice have at least one operand
      for (auto operand : currOp->getOperands()) {
        auto defOp = operand.getDefiningOp();
        if (!slice.contains(defOp)) {
          // ensure frontierOp is load or constant
          if (!isa<LoadOp, arith::ConstantOp>(defOp))
            return rewriter.notifyMatchFailure(defOp,
                                               "must be load or constant");
        }
      }
    }

    // Step 4: Clone slice
    auto [newSlice, sliceMap] = cloneSlice(rewriter, slice);

    // Step 5: Modify the cloned slice to have dotOp encoding.
    // Before: load #blocked; (elementwise #blocked)+; local_alloc;
    //         warp_group_dot
    // After:  load #blocked; convert_layout #dot_op; (elementwise #dot_op)+;
    //         warp_group_dot
    //
    // Specifically, this step will change all value types from #blocked to
    // #dot_op encoding in the cloned slice, and for those values produced by
    // frontierOps (i.e., outside the slice), we will insert convert_layout's
    // after the frontierOp.
    auto srcTy = cast<RankedTensorType>(alloc.getSrc().getType());
    Type inputEltTy = srcTy.getElementType();
    auto dotOperandEnc = DotOperandEncodingAttr::get(
        dotOp.getContext(), /*opIdx=*/0, dstEnc, inputEltTy);

    for (auto op : newSlice) {
      // Step 5a: If any operand is defined by a frontierOp, we must insert a
      // convert_layout(#dot_op) after the frontierOp and before currOp
      for (auto [oprIdx, operand] : llvm::enumerate(op->getOperands())) {

        auto defOp = operand.getDefiningOp();

        // defOp is not frontier (i.e. it's within slice); no need to convert
        // the layout of its result
        if (newSlice.contains(defOp))
          continue;

        // We checked earlier that all operands are ranked tensors
        auto operandTy = cast<RankedTensorType>(operand.getType());
        auto operandEltTy = operandTy.getElementType();

        Type cvtTy = RankedTensorType::get(
            operandTy.getShape(), operandTy.getElementType(), dotOperandEnc);
        rewriter.setInsertionPoint(op);
        auto cvt =
            rewriter.create<ConvertLayoutOp>(defOp->getLoc(), cvtTy, operand);

        op->setOperand(oprIdx, cvt);
      }

      // Step 5b: Change the result to have DotOp rather than Blocked encoding
      auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
      op->getResult(0).setType(RankedTensorType::get(
          resTy.getShape(), resTy.getElementType(), dotOperandEnc));
    }

    // Step 6: Replace LHS operand with alloc's parent in the cloned slice.
    // This changes the warpGroupDot to accept a DotOp tensor as operand A
    // instead of a Shared memdesc.
    auto newDotOperand = sliceMap.lookup(alloc.getSrc());
    rewriter.modifyOpInPlace(dotOp,
                             [&]() { dotOp.setOperand(0, newDotOperand); });

    return success();
  }
};
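To make the frontierOp rule from Step 3 concrete, here is the same check factored into a standalone helper. This is an illustrative sketch only (the helper is not part of the PR); it assumes the canHoistDotOpEncV3 predicate and the ops named above.

// Hypothetical refactoring of Step 3, shown only to illustrate the rule:
// every op in the slice must be hoistable, and every value flowing into the
// slice from outside ("frontier" values) must come from a load or a constant,
// since those are the points where a convert_layout to #dot_op is inserted.
static LogicalResult verifySliceIsHoistable(const SetVector<Operation *> &slice,
                                            PatternRewriter &rewriter) {
  for (Operation *currOp : slice) {
    if (!canHoistDotOpEncV3(currOp))
      return rewriter.notifyMatchFailure(currOp, "cannot hoist through");
    for (Value operand : currOp->getOperands()) {
      Operation *defOp = operand.getDefiningOp();
      if (!slice.contains(defOp) && !isa<LoadOp, arith::ConstantOp>(defOp))
        return rewriter.notifyMatchFailure(defOp, "must be load or constant");
    }
  }
  return success();
}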

} // namespace

#define GEN_PASS_DEF_TRITONGPUOPTIMIZEDOTOPERANDS
@@ -322,9 +574,11 @@ class TritonGPUOptimizeDotOperandsPass
    auto ret = pm.run(m);

    mlir::RewritePatternSet patterns(context);
    patterns.add<MMAV3HoistLayoutConversion>(context);
    patterns.add<SwizzleShmemConvert>(context);
    if (this->hoistLayoutConversion.getValue())
    if (this->hoistLayoutConversion.getValue()) {
      patterns.add<HoistLayoutConversion>(context);
    }
    patterns.add<FuseTransHopper>(context);
    patterns.add<MMAV3UseRegOperand>(context);
    ConvertLayoutOp::getCanonicalizationPatterns(patterns, context);

Inline review comment on the added braces: "nit: following MLIR style we usually don't have braces here"
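The hunk stops before the patterns are applied. Given the GreedyPatternRewriteDriver.h include at the top of the file, the pattern set is presumably handed to MLIR's greedy driver in the usual way; a minimal sketch of that (not shown in the diff):

// Sketch of the usual driver invocation at the end of the pass.
if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
  signalPassFailure();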
Review discussion (on the cloning logic):

Reviewer: should the clone op be inserted right before the old op?

Author: do you mean that I should use setInsertionPoint before the clones?

Reviewer: yeah? To avoid pulling all the operations down to the dot
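For reference, the idiom under discussion, as a minimal sketch: setting the insertion point to the original op before each clone keeps the cloned op adjacent to its original rather than sinking all cloned ops down toward the dot. This is what the first pass of cloneSlice above does.

for (Operation *op : slice) {
  // Placing the insertion point at the original op means the clone is
  // created right before it, preserving the original ordering.
  rewriter.setInsertionPoint(op);
  Operation *newOp = rewriter.clone(*op);
  sliceMap.map(op, newOp);
}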