[BACKEND][NVIDIA] Add DotOp Hoisting Pass for WGMMA and Add Lowering for SMEM-to-MMAv3 DotOp Copy #5003

Open · wants to merge 16 commits into base: main
8 changes: 7 additions & 1 deletion include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -195,7 +195,13 @@ int getNVIDIAComputeCapability(Operation *module);
std::optional<mlir::triton::gpu::SharedEncodingAttr>
getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);

- bool loadIsMMAv3(Operation *loadOp);
+ enum class MMALoadType {
+   SharedV3,
+   Registers,     // may be v2 or v3
+   DoNotPipeline, // could be a valid shared/registers MMA operand, but skip
+                  // pipelining
+ };
+ MMALoadType getMMALoadType(Operation *loadOp);
} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
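
For orientation, here is a minimal sketch of how the new three-way classification is meant to be consumed. `decidePipelining` is a hypothetical caller invented for illustration; the real consumer is filterPipelinedLoad in LoopScheduling.cpp below:

void decidePipelining(mlir::Operation *loadOp) {
  // Sketch only: branch on the classification that replaces the old
  // boolean loadIsMMAv3(loadOp) query.
  switch (mlir::getMMALoadType(loadOp)) {
  case mlir::MMALoadType::SharedV3:
    // MMAv3 operand kept in shared memory (the case the old query flagged).
    break;
  case mlir::MMALoadType::Registers:
    // Operand consumed in registers; may feed MMAv2, or MMAv3 once the
    // hoisting pattern below has run.
    break;
  case mlir::MMALoadType::DoNotPipeline:
    // Recognizable MMA operand, but pipelining should be skipped.
    break;
  }
}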
11 changes: 9 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp
@@ -145,11 +145,18 @@ filterPipelinedLoad(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>

bool hasSharedEncoding = false;
if (use->hasTrait<OpTrait::DotLike>()) {
- if (loadIsMMAv3(op)) {
+ auto mmaLoadType = getMMALoadType(op);
+ auto dot = dyn_cast<tt::DotOp>(use);
+ auto warpGroupDot = dyn_cast<ttng::WarpGroupDotOp>(use);
+ bool isMMAv3Shared = mmaLoadType == MMALoadType::SharedV3;
+ bool isMMAv3Registers =
+     (mmaLoadType == MMALoadType::Registers) && warpGroupDot;
+
+ if (isMMAv3Shared) {
hasSharedEncoding = true;
} else if (isa<tt::ExperimentalDescriptorLoadOp>(op)) {
hasSharedEncoding = true;
- } else if (auto dot = dyn_cast<tt::DotOp>(use)) {
+ } else if (isMMAv3Registers || dot) {
// FIXME: if we have a better solution in handling incompatible shared
// encoding, we can simplify the logic here by checking if all users are
// dot encoding. For now, getSharedEncIfAllUsersAreDotEnc will be used
290 changes: 272 additions & 18 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -1,9 +1,11 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
@@ -15,6 +17,125 @@ namespace gpu {

namespace {

// Helpers

// Returns whether we can hoist DotOp Encoding through `op`.
// Roughly, whether op is elementwise and thus threads don't need
// to exchange elements. But some ops are not currently supported even though
// they meet that criterion.
bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) {
// Only consider custom conversions or arith ops.
// TODO(jlebar): Is this too restrictive?
if (!isa<FpToFpOp, BitcastOp>(op) && !isPureUnaryInlineAsm(op) &&
!isa<arith::ArithDialect>(op->getDialect()))
return false;

// Quick handling to fix loading issues where computing the original
// bitwidth is unable to realize that there is a mixed-precision dot
// (hence kWidth = 1) but still wants to hoist through the type conversion.
if (isa<arith::ExtFOp>(op) && dotOpEnc.getKWidth() == 1)
return false;

// Currently, these instructions are not supported during lowering of
// shared -> dot_operand layout. Not all types and type conversions are
// supported.
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(op))
return false;

// Don't hoist through u1 -> fp casts as they aren't supported in
// ElementwiseOpToLLVM::reorderValues().
if (isa<arith::UIToFPOp>(op)) {
Type opType = getElementTypeOrSelf(op->getOperand(0));
if (opType.isInteger(1))
return false;
}

return true;
}
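
For instance (illustrative IR, not taken from the patch), an upcast such as

  %y = arith.extf %x : tensor<128x64xf16, #blocked> to tensor<128x64xf32, #blocked>

passes these checks whenever kWidth > 1, while arith.truncf, arith.select, and i1-to-float arith.uitofp casts are rejected above.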

// Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A
// is in registers).
bool canHoistDotOpEncV3(Operation *op) {
// Must have exactly one result and at least one operand
if (op->getNumOperands() == 0 || op->getNumResults() != 1)
return false;

auto isBlockedOrDotOpRankedTensor = [](Type ty) {
auto tensorTy = dyn_cast<RankedTensorType>(ty);
if (!tensorTy)
return false;
return isa<BlockedEncodingAttr, DotOperandEncodingAttr>(
tensorTy.getEncoding());
};

// Operands and results must be of RankedTensorType and Blocked or DotOp
if (!(all_of(op->getOperandTypes(), isBlockedOrDotOpRankedTensor) &&
all_of(op->getResultTypes(), isBlockedOrDotOpRankedTensor)))
return false;

// Only consider custom conversions or arith ops.
if (!isa<FpToFpOp, BitcastOp>(op) && !isPureUnaryInlineAsm(op) &&
!isa<arith::ArithDialect>(op->getDialect()))
return false;

// Currently, these instructions are not supported during lowering of
// shared -> dot_operand layout. Not all types and type conversions are
// supported.
if (isa<arith::SelectOp>(op))
return false;

// Downcasting not currently supported; it will likely require minor
// adjustments in sharedToDotOperandMMv2
auto oprType = getElementTypeOrSelf(op->getOperand(0));
auto resType = getElementTypeOrSelf(op->getResult(0));
if (oprType.getIntOrFloatBitWidth() > resType.getIntOrFloatBitWidth())
return false;

// Don't hoist through u1 -> fp casts as they aren't supported in
// ElementwiseOpToLLVM::reorderValues().
if (isa<arith::UIToFPOp>(op) && oprType.isInteger(1))
return false;

return true;
}
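
Likewise illustrative: a downcast such as %y = arith.truncf %x : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked> fails the bitwidth check above, while the corresponding arith.extf upcast is accepted.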

// Helper to perform a "deep" clone of the given slice (i.e., set of ops),
// returning a tuple (newSlice, sliceMap), where newSlice is the cloned slice
// and sliceMap is the IRMapping that maps the ops and result values of the
// original slice to those in the cloned slice.
auto cloneSlice(PatternRewriter &rewriter,
const SetVector<Operation *> &slice) {
IRMapping sliceMap;
SetVector<Operation *> newSlice;

// First pass: clone ops; the result values are cloned as well, but the
// operands still refer to the original result values
for (Operation *op : slice) {
rewriter.setInsertionPoint(op);
auto newOp = rewriter.clone(*op);
[Review thread on the rewriter.clone(*op) call]
Collaborator: should the clone op be inserted right before the old op?
Author: do you mean that I should use setInsertionPoint before the clones?
Collaborator: yeah? To avoid pulling all the operations down to the dot
newSlice.insert(newOp);
sliceMap.map(op, newOp);
for (auto [result, newResult] :
llvm::zip(op->getResults(), newOp->getResults())) {
assert(result != newResult);
sliceMap.map(result, newResult);
}
}

// Second pass: replace operand references in cloned ops to point to cloned
// values
for (auto [op, newOp] : sliceMap.getOperationMap())
for (auto [oprIdx, operand] : llvm::enumerate(newOp->getOperands())) {
auto defOp = operand.getDefiningOp();
if (!slice.contains(defOp))
continue;

newOp->setOperand(oprIdx, sliceMap.lookup(operand));
}

return std::make_tuple(newSlice, sliceMap);
}
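
As a concrete illustration (hypothetical values): for the slice {%b = arith.extf %a, %c = arith.mulf %b, %b}, the first pass clones both ops while their operands still reference the original %a and %b; the second pass rewires the cloned mulf to consume the cloned %b. The operand %a, whose defining op lies outside the slice, is deliberately left untouched; it is handled later by the frontier-conversion logic in the pattern below.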

// Given
// convert(trans(src)) #dot_operand ->
// convert(local_load(trans(alloc(src))))
@@ -111,7 +232,8 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
PatternRewriter &rewriter) const override {
// Only consider conversions to dot operand.
auto cvtTy = cast<RankedTensorType>(cvt.getType());
- if (!isa<DotOperandEncodingAttr>(cvtTy.getEncoding()))
+ auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(cvtTy.getEncoding());
+ if (!dotOpEnc)
return failure();

auto src = cvt.getSrc().getDefiningOp();
@@ -126,16 +248,7 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
[](Type ty) { return isa<RankedTensorType>(ty); }))
return failure();

- // Only consider custom conversions or arith ops.
- // TODO(jlebar): Is this too restrictive?
- if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
-     src->getDialect()->getTypeID() != TypeID::get<arith::ArithDialect>())
-   return failure();
-
- // Currently, these instructions are not supported during lowering of
- // shared -> dot_operand layout. Not all types and type conversions are
- // supported.
- if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
+ if (!canHoistDotOpEncV2(src, dotOpEnc))
return failure();

// Check that the conversion is transitively dependent on a load, and all
@@ -165,12 +278,7 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
if (isa<LoadOp>(currOp)) {
foundLoad = true;
} else if (foundLoad) {
- // Bail out if there exists an op after Load that is not FpToFp,
- // Bitcast, or Arith.
- if (!isa<FpToFpOp, BitcastOp>(currOp) &&
-     !isPureUnaryInlineAsm(currOp) &&
-     currOp->getDialect()->getTypeID() !=
-         TypeID::get<arith::ArithDialect>())
+ if (!canHoistDotOpEncV2(currOp, dotOpEnc))
return failure();
}
}
@@ -301,6 +409,150 @@ struct MMAV3UseRegOperand
}
};

// MMAV3's analog of HoistLayoutConversion, for operand A only; will make
// WarpGroupDot accept operand A in registers instead of shmem.
//
// Before: load #blocked; (elementwise #blocked)+; local_alloc; warp_group_dot
// After: load #blocked; convert_layout #dot_op; (elementwise #dot_op)+;
// warp_group_dot
//
// Whereas (MMAV2) HoistLayoutConversion hoists through one elementwise op at a
// time and requires multiple passes, this pattern will directly hoist the
// convert to the right place in one pass.
//
// Or, to be more precise, this pattern deletes the local_alloc op and inserts a
// convert_layout op after each load that warp_group_dot uses; so this is not
// simply hoisting a convert_layout op up as in V2, but can be considered as
// first changing local_alloc to convert_layout and then hoisting, which results
// in WGMMA now accepting operand A in DotOp layout rather than Shared.
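//
// As an illustrative sketch (encodings and types abbreviated; not taken
// verbatim from a test):
//
//   Before:
//     %a = tt.load ... : tensor<128x64xf8E4M3, #blocked>
//     %b = tt.fp_to_fp %a : tensor<..., #blocked> -> tensor<128x64xf16, #blocked>
//     %s = ttg.local_alloc %b : !ttg.memdesc<128x64xf16, #shared>
//     %d = ttng.warp_group_dot %s, %B, %acc
//
//   After:
//     %a = tt.load ... : tensor<128x64xf8E4M3, #blocked>
//     %c = ttg.convert_layout %a : tensor<..., #blocked> -> tensor<128x64xf8E4M3, #dot_op>
//     %b = tt.fp_to_fp %c : tensor<..., #dot_op> -> tensor<128x64xf16, #dot_op>
//     %d = ttng.warp_group_dot %b, %B, %acc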
struct MMAV3HoistLayoutConversion
: public OpRewritePattern<triton::nvidia_gpu::WarpGroupDotOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(triton::nvidia_gpu::WarpGroupDotOp dotOp,
PatternRewriter &rewriter) const override {
// Can only hoist operand 0
auto alloc = dotOp.getOperand(0).getDefiningOp<LocalAllocOp>();
if (!alloc || !alloc.getSrc())
return rewriter.notifyMatchFailure(
dotOp, "operand A must be produced by local_alloc");

auto getEncoding = [](Value v) {
return cast<TensorOrMemDesc>(v.getType()).getEncoding();
};

if (!isa<SharedEncodingAttr>(getEncoding(dotOp.getOperand(0))))
return rewriter.notifyMatchFailure(
dotOp, "requires Shared encoding for operand A");

// Step 1: Perform checks for an early exit
auto srcEnc = dyn_cast<BlockedEncodingAttr>(getEncoding(alloc.getSrc()));
if (!srcEnc)
return rewriter.notifyMatchFailure(
alloc, "requires src to have Blocked encoding");

auto dstEnc =
dyn_cast<NvidiaMmaEncodingAttr>(getEncoding(dotOp.getResult()));
if (!dstEnc || dstEnc.getVersionMajor() != 3)
return rewriter.notifyMatchFailure(
dotOp, "requires result in NvidiaMma encoding");

// Step 2: Obtain slice of ops between load/constant and local_alloc
SetVector<Operation *> slice;
BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = [&](Operation *op) {
// Stop before Load, ConstantOp, or LocalLoad
return (op->getParentRegion() == alloc->getParentRegion()) &&
!isa<LoadOp, arith::ConstantOp, LocalLoadOp>(op) &&
(op->getNumOperands() != 0);
};
getBackwardSlice(alloc.getOperation(), &slice, opt);

// Step 3: Verify slice can be hoisted through
if (slice.empty())
return rewriter.notifyMatchFailure(dotOp, "nothing to hoist through");

// We define frontierOp as an op outside this slice whose result is used by
// an op in this slice. We must eventually convert the result of all
// frontierOps to DotOperandEncoding. This is done via the insertion of
// ConvertLayout after each frontierOp. We currently support frontierOp to
// be load or constant.
for (Operation *currOp : slice) {
if (!canHoistDotOpEncV3(currOp))
return rewriter.notifyMatchFailure(currOp, "cannot hoist through");

// We previously ensured that all ops in slice have at least one operand
for (auto operand : currOp->getOperands()) {
auto defOp = operand.getDefiningOp();
if (!slice.contains(defOp)) {
// ensure frontierOp is load or constant
if (!isa<LoadOp, arith::ConstantOp>(defOp))
return rewriter.notifyMatchFailure(defOp,
"must be load or constant");
}
}
}

// Step 4: Clone slice
auto [newSlice, sliceMap] = cloneSlice(rewriter, slice);

// Step 5: Modify the cloned slice to have dotOp encoding.
// Before: load #blocked; (elementwise #blocked)+; local_alloc;
//         warp_group_dot
// After:  load #blocked; convert_layout #dot_op;
//         (elementwise #dot_op)+; warp_group_dot
//
// Specifically, this step will change all value types from #blocked to
// #dot_op encoding in the cloned slice, and for those values produced by
// frontierOps (i.e., outside the slice), we will insert convert_layout's
// after the frontierOp.
auto srcTy = cast<RankedTensorType>(alloc.getSrc().getType());
Type inputEltTy = srcTy.getElementType();
auto dotOperandEnc = DotOperandEncodingAttr::get(
dotOp.getContext(), /*opIdx=*/0, dstEnc, inputEltTy);

for (auto op : newSlice) {
// Step 5a: If any operand is defined by a frontierOp, we must insert a
// convert_layout(#dot_op) after the frontierOp and before currOp
for (auto [oprIdx, operand] : llvm::enumerate(op->getOperands())) {

auto defOp = operand.getDefiningOp();

// defOp is not frontier (i.e. it's within slice); no need to convert
// the layout of its result
if (newSlice.contains(defOp))
continue;

// We checked earlier that all operands are ranked tensors
auto operandTy = cast<RankedTensorType>(operand.getType());
auto operandEltTy = operandTy.getElementType();

Type cvtTy = RankedTensorType::get(
operandTy.getShape(), operandTy.getElementType(), dotOperandEnc);
rewriter.setInsertionPoint(op);
auto cvt =
rewriter.create<ConvertLayoutOp>(defOp->getLoc(), cvtTy, operand);

op->setOperand(oprIdx, cvt);
}

// Step 5b: Change the result to have DotOp rather than Blocked encoding
auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
op->getResult(0).setType(RankedTensorType::get(
resTy.getShape(), resTy.getElementType(), dotOperandEnc));
}

// Step 6: Replace the LHS operand with the cloned counterpart of alloc's
// source. This changes the warpGroupDot to accept a DotOp tensor as
// operand A instead of a Shared memdesc.
auto newDotOperand = sliceMap.lookup(alloc.getSrc());
rewriter.modifyOpInPlace(dotOp,
[&]() { dotOp.setOperand(0, newDotOperand); });

return success();
}
};

} // namespace

#define GEN_PASS_DEF_TRITONGPUOPTIMIZEDOTOPERANDS
@@ -322,9 +574,11 @@ class TritonGPUOptimizeDotOperandsPass
auto ret = pm.run(m);

mlir::RewritePatternSet patterns(context);
+ patterns.add<MMAV3HoistLayoutConversion>(context);
patterns.add<SwizzleShmemConvert>(context);
- if (this->hoistLayoutConversion.getValue())
+ if (this->hoistLayoutConversion.getValue()) {
[Review thread on the added braces]
Collaborator: nit: following MLIR style we usually don't have braces here
patterns.add<HoistLayoutConversion>(context);
+ }
patterns.add<FuseTransHopper>(context);
patterns.add<MMAV3UseRegOperand>(context);
ConvertLayoutOp::getCanonicalizationPatterns(patterns, context);