Clang format
ggengnv committed Nov 12, 2024
1 parent 5ce5628 commit 20f9ba0
Showing 7 changed files with 94 additions and 86 deletions.
5 changes: 3 additions & 2 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -197,8 +197,9 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);

enum class MMALoadType {
SharedV3,
Registers, // may be v2 or v3
DoNotPipeline, // could be a valid shared/registers MMA operand, but skip pipelining
Registers, // may be v2 or v3
DoNotPipeline, // could be a valid shared/registers MMA operand, but skip
// pipelining
};
MMALoadType getMMALoadType(Operation *loadOp);
} // namespace mlir
4 changes: 2 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp
@@ -149,8 +149,8 @@ filterPipelinedLoad(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
auto dot = dyn_cast<tt::DotOp>(use);
auto warpGroupDot = dyn_cast<ttng::WarpGroupDotOp>(use);
bool isMMAv3Shared = mmaLoadType == MMALoadType::SharedV3;
bool isMMAv3Registers = (mmaLoadType == MMALoadType::Registers)
&& warpGroupDot;
bool isMMAv3Registers =
(mmaLoadType == MMALoadType::Registers) && warpGroupDot;

if (isMMAv3Shared) {
hasSharedEncoding = true;
111 changes: 60 additions & 51 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -23,7 +23,7 @@ namespace {
// Roughly, whether op is elementwise and thus threads don't need
// to exchange elements. But some ops are not currently supported even though
// they meet that criterion.
bool canHoistDotOpEncV2(Operation* op, DotOperandEncodingAttr& dotOpEnc) {
bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) {
// Only consider custom conversions or arith ops.
// TODO(jlebar): Is this too restrictive?
if (!isa<FpToFpOp, BitcastOp>(op) && !isPureUnaryInlineAsm(op) &&
@@ -34,7 +34,7 @@ bool canHoistDotOpEncV2(Operation* op, DotOperandEncodingAttr& dotOpEnc) {
// bitwidth is unable to realize that there is a mixed-precision dot
// (hence kWidth = 1) but wants to hoist through the type conversion.
if (isa<arith::ExtFOp>(op) && dotOpEnc.getKWidth() == 1)
return false;
return false;

// Currently, these instructions are not supported during lowering of
// shared -> dot_operand layout. Not all types and type conversions are
@@ -55,7 +55,7 @@ bool canHoistDotOpEncV2(Operation* op, DotOperandEncodingAttr& dotOpEnc) {

// Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A
// is in registers).
bool canHoistDotOpEncV3(Operation* op) {
bool canHoistDotOpEncV3(Operation *op) {
// Must have exactly one result and at least one operand
if (op->getNumOperands() == 0 || op->getNumResults() != 1)
return false;
@@ -64,7 +64,8 @@ bool canHoistDotOpEncV3(Operation* op) {
auto tensorTy = dyn_cast<RankedTensorType>(ty);
if (!tensorTy)
return false;
return isa<BlockedEncodingAttr, DotOperandEncodingAttr>(tensorTy.getEncoding());
return isa<BlockedEncodingAttr, DotOperandEncodingAttr>(
tensorTy.getEncoding());
};

// Operands and results must be of RankedTensorType and Blocked or DotOp
@@ -102,24 +103,27 @@ bool canHoistDotOpEncV3(Operation* op) {
// returning a tuple (newSlice, sliceMap), where newSlice is the cloned slice,
// and sliceMap the IRMapping that maps the ops and result values of the
// original slice to those in the cloned slice.
auto cloneSlice(PatternRewriter& rewriter, const SetVector<Operation *>& slice) {
auto cloneSlice(PatternRewriter &rewriter,
const SetVector<Operation *> &slice) {
IRMapping sliceMap;
SetVector<Operation*> newSlice;
SetVector<Operation *> newSlice;

// First pass: clone ops; the result values are cloned as well, but the operands still
// refer to the original result values
// First pass: clone ops; the result values are cloned as well, but the
// operands still refer to the original result values
for (Operation *op : slice) {
rewriter.setInsertionPoint(op);
auto newOp = rewriter.clone(*op);
newSlice.insert(newOp);
sliceMap.map(op, newOp);
for (auto [result, newResult] : llvm::zip(op->getResults(), newOp->getResults())) {
for (auto [result, newResult] :
llvm::zip(op->getResults(), newOp->getResults())) {
assert(result != newResult);
sliceMap.map(result, newResult);
}
}

// Second pass: replace operand references in cloned ops to point to cloned values
// Second pass: replace operand references in cloned ops to point to cloned
// values
for (auto [op, newOp] : sliceMap.getOperationMap())
for (auto [oprIdx, operand] : llvm::enumerate(newOp->getOperands())) {
auto defOp = operand.getDefiningOp();
@@ -405,21 +409,22 @@ struct MMAV3UseRegOperand
}
};

// MMAV3's analog of HoistLayoutConversion, for operand A only; will make WarpGroupDot
// accept operand A in registers instead of shmem.
// MMAV3's analog of HoistLayoutConversion, for operand A only; will make
// WarpGroupDot accept operand A in registers instead of shmem.
//
// Before: load #blocked; (elementwise #blocked)+; local_alloc; warp_group_dot
// After: load #blocked; convert_layout #dot_op; (elementwise #dot_op)+; warp_group_dot
// After: load #blocked; convert_layout #dot_op; (elementwise #dot_op)+;
// warp_group_dot
//
// Whereas (MMAV2) HoistLayoutConversion hoists thru one elementwise op at a time and
// requires multiple passes, this pattern will directly hoist the convert to the right
// place in one pass.
// Whereas (MMAV2) HoistLayoutConversion hoists thru one elementwise op at a
// time and requires multiple passes, this pattern will directly hoist the
// convert to the right place in one pass.
//
// Or, to be more precise, this pattern deletes the local_alloc op and inserts a
// convert_layout op after each load that warp_group_dot uses; so this is not simply hoisting
// a convert_layout op up as in V2, but can be considered as first changing local_alloc to
// convert_layout and then hoisting, which results in WGMMA now accepting operand A in DotOp
// layout rather than Shared.
// convert_layout op after each load that warp_group_dot uses; so this is not
// simply hoisting a convert_layout op up as in V2, but can be considered as
// first changing local_alloc to convert_layout and then hoisting, which results
// in WGMMA now accepting operand A in DotOp layout rather than Shared.
struct MMAV3HoistLayoutConversion
: public OpRewritePattern<triton::nvidia_gpu::WarpGroupDotOp> {
using OpRewritePattern::OpRewritePattern;
@@ -429,49 +434,50 @@ struct MMAV3HoistLayoutConversion
// Can only hoist operand 0
auto alloc = dotOp.getOperand(0).getDefiningOp<LocalAllocOp>();
if (!alloc || !alloc.getSrc())
return rewriter.notifyMatchFailure(dotOp,
"operand A must be produced by local_alloc");
return rewriter.notifyMatchFailure(
dotOp, "operand A must be produced by local_alloc");

auto getEncoding = [](Value v) {
return cast<TensorOrMemDesc>(v.getType()).getEncoding();
};

if (!isa<SharedEncodingAttr>(getEncoding(dotOp.getOperand(0))))
return rewriter.notifyMatchFailure(dotOp,
"requires Shared encoding for operand A");
return rewriter.notifyMatchFailure(
dotOp, "requires Shared encoding for operand A");

// Step 1: Performs checks for early stop
auto srcEnc = dyn_cast<BlockedEncodingAttr>(getEncoding(alloc.getSrc()));
if (!srcEnc)
return rewriter.notifyMatchFailure(alloc,
"requires src to have Blocked encoding");
return rewriter.notifyMatchFailure(
alloc, "requires src to have Blocked encoding");

auto dstEnc = dyn_cast<NvidiaMmaEncodingAttr>(getEncoding(dotOp.getResult()));
auto dstEnc =
dyn_cast<NvidiaMmaEncodingAttr>(getEncoding(dotOp.getResult()));
if (!dstEnc || dstEnc.getVersionMajor() != 3)
return rewriter.notifyMatchFailure(dotOp,
"requires result in NvidiaMma encoding");
return rewriter.notifyMatchFailure(
dotOp, "requires result in NvidiaMma encoding");

// Step 2: Obtain slice of ops between load/constant and local_alloc
SetVector<Operation *> slice;
BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = [&](Operation *op) {
// Stop before Load, ConstantOp, or LocalLoad
return (op->getParentRegion() == alloc->getParentRegion())
&& !isa<LoadOp, arith::ConstantOp, LocalLoadOp>(op)
&& (op->getNumOperands() != 0);
return (op->getParentRegion() == alloc->getParentRegion()) &&
!isa<LoadOp, arith::ConstantOp, LocalLoadOp>(op) &&
(op->getNumOperands() != 0);
};
getBackwardSlice(alloc.getOperation(), &slice, opt);

// Step 3: Verify slice can be hoisted through
if (slice.empty())
return rewriter.notifyMatchFailure(dotOp, "nothing to hoist through");

// We define frontierOp as an op outside this slice whose result is used by an op in
// this slice. We must eventually convert the result of all frontierOps to
// DotOperandEncoding. This is done via the insertion of ConvertLayout after each
// frontierOp.
// We currently support frontierOp to be load or constant.
// We define frontierOp as an op outside this slice whose result is used by
// an op in this slice. We must eventually convert the result of all
// frontierOps to DotOperandEncoding. This is done via the insertion of
// ConvertLayout after each frontierOp. We currently support frontierOp to
// be load or constant.
for (Operation *currOp : slice) {
if (!canHoistDotOpEncV3(currOp))
return rewriter.notifyMatchFailure(currOp, "cannot hoist through");
@@ -482,7 +488,8 @@ struct MMAV3HoistLayoutConversion
if (!slice.contains(defOp)) {
// ensure frontierOp is load or constant
if (!isa<LoadOp, arith::ConstantOp>(defOp))
return rewriter.notifyMatchFailure(defOp, "must be load or constant");
return rewriter.notifyMatchFailure(defOp,
"must be load or constant");
}
}
}
@@ -491,12 +498,14 @@ struct MMAV3HoistLayoutConversion
auto [newSlice, sliceMap] = cloneSlice(rewriter, slice);

// Step 5: Modify the cloned slice to have dotOp encoding.
// Before: load #blocked; (elementwise #blocked)+; local_alloc; warp_group_dot
// After: load #blocked; convert_layout #dot_op; (elementwise #dot_op)+; warp_group_dot
// Before: load #blocked; (elementwise #blocked)+; local_alloc;
// warp_group_dot After: load #blocked; convert_layout #dot_op;
// (elementwise #dot_op)+; warp_group_dot
//
// Specifically, this step will change all value types from #blocked to #dot_op
// encoding in the cloned slice, and for those values produced by frontierOps (i.e.,
// outside the slice), we will insert convert_layout's after the frontierOp.
// Specifically, this step will change all value types from #blocked to
// #dot_op encoding in the cloned slice, and for those values produced by
// frontierOps (i.e., outside the slice), we will insert convert_layout's
// after the frontierOp.
auto srcTy = cast<RankedTensorType>(alloc.getSrc().getType());
Type inputEltTy = srcTy.getElementType();
auto dotOperandEnc = DotOperandEncodingAttr::get(
@@ -509,8 +518,8 @@ struct MMAV3HoistLayoutConversion

auto defOp = operand.getDefiningOp();

// defOp is not frontier (i.e. it's within slice); no need to convert the
// layout of its result
// defOp is not frontier (i.e. it's within slice); no need to convert
// the layout of its result
if (newSlice.contains(defOp))
continue;

@@ -521,7 +530,8 @@ struct MMAV3HoistLayoutConversion
Type cvtTy = RankedTensorType::get(
operandTy.getShape(), operandTy.getElementType(), dotOperandEnc);
rewriter.setInsertionPoint(op);
auto cvt = rewriter.create<ConvertLayoutOp>(defOp->getLoc(), cvtTy, operand);
auto cvt =
rewriter.create<ConvertLayoutOp>(defOp->getLoc(), cvtTy, operand);

op->setOperand(oprIdx, cvt);
}
@@ -533,12 +543,11 @@ struct MMAV3HoistLayoutConversion
}

// Step 6: replace LHS operand with alloc's parent in the cloned slice
// This changes the warpGroupDot to accept a DotOp tensor as operand A instead of
// a Shared memdesc.
// This changes the warpGroupDot to accept a DotOp tensor as operand A
// instead of a Shared memdesc.
auto newDotOperand = sliceMap.lookup(alloc.getSrc());
rewriter.modifyOpInPlace(dotOp, [&]() {
dotOp.setOperand(0, newDotOperand);
});
rewriter.modifyOpInPlace(dotOp,
[&]() { dotOp.setOperand(0, newDotOperand); });

return success();
}
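To make the MMAV3HoistLayoutConversion rewrite above concrete, here is a hand-written sketch of the IR shape before and after the pattern fires. It is schematic only: #blocked, #shared, #mma and #dotA stand for concrete triton_gpu encodings whose definitions are elided, the shapes and values are made up, and the exact assembly format may differ from what the compiler actually prints.

// Before: load #blocked; (elementwise #blocked)+; local_alloc; warp_group_dot
%cst = arith.constant dense<5.000000e-01> : tensor<128x64xf16, #blocked>
%a   = tt.load %ptrs : tensor<128x64x!tt.ptr<f16>, #blocked>
%a2  = arith.mulf %a, %cst : tensor<128x64xf16, #blocked>
%as  = triton_gpu.local_alloc %a2 : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared>
%d   = triton_nvidia_gpu.warp_group_dot %as, %b, %acc
       : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x256xf16, #shared> -> tensor<128x256xf32, #mma>

// After: the local_alloc is gone; a convert_layout to the dot-operand layout
// (#dotA = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>) is
// inserted after each frontier op (the load and the constant), and the
// elementwise op now runs in that layout.
%cst  = arith.constant dense<5.000000e-01> : tensor<128x64xf16, #blocked>
%a    = tt.load %ptrs : tensor<128x64x!tt.ptr<f16>, #blocked>
%cstc = triton_gpu.convert_layout %cst : tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #dotA>
%ac   = triton_gpu.convert_layout %a : tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #dotA>
%a2   = arith.mulf %ac, %cstc : tensor<128x64xf16, #dotA>
%d    = triton_nvidia_gpu.warp_group_dot %a2, %b, %acc
       : tensor<128x64xf16, #dotA> * !tt.memdesc<64x256xf16, #shared> -> tensor<128x256xf32, #mma>

Unlike the MMAv2 HoistLayoutConversion, which hoists through one elementwise op per pattern application, the whole backward slice between the frontier ops and the warp_group_dot is retyped here in a single rewrite.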
44 changes: 20 additions & 24 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -125,10 +125,8 @@ static int createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,

auto convertBlockLayout = [&](Value val, ttg::BlockedEncodingAttr enc) {
auto ty = cast<RankedTensorType>(val.getType());
auto newTy =
RankedTensorType::get(ty.getShape(), ty.getElementType(), enc);
auto cvt =
builder.create<ttg::ConvertLayoutOp>(loc, newTy, val);
auto newTy = RankedTensorType::get(ty.getShape(), ty.getElementType(), enc);
auto cvt = builder.create<ttg::ConvertLayoutOp>(loc, newTy, val);
return cvt.getResult();
};

@@ -169,20 +167,16 @@

SmallVector<unsigned> newSizePerThread;
llvm::transform(blockEnc.getSizePerThread(),
std::back_inserter(newSizePerThread),
[&](auto size) { return std::min(size, sharedVec); });
std::back_inserter(newSizePerThread),
[&](auto size) { return std::min(size, sharedVec); });

if (newSizePerThread != blockEnc.getSizePerThread()) {
auto mod = loadOp->getParentOfType<ModuleOp>();
int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
auto newBlockEnc = ttg::BlockedEncodingAttr::get(
loadOp.getContext(),
tensorTy.getShape(),
newSizePerThread,
blockEnc.getOrder(),
numWarps,
threadsPerWarp,
loadOp.getContext(), tensorTy.getShape(), newSizePerThread,
blockEnc.getOrder(), numWarps, threadsPerWarp,
blockEnc.getCTALayout());

src = convertBlockLayout(src, newBlockEnc);
@@ -528,8 +522,8 @@ assignMemoryLayouts(scf::ForOp &forOp,

loadInfo.usedByDot = true;
loadInfo.isMMAv3Shared = mmaLoadType == MMALoadType::SharedV3;
loadInfo.isMMAv3Registers = (mmaLoadType == MMALoadType::Registers)
&& warpGroupDot;
loadInfo.isMMAv3Registers =
(mmaLoadType == MMALoadType::Registers) && warpGroupDot;

if (loadInfo.isMMAv3Shared) {
loadInfo.sharedEncoding =
@@ -771,9 +765,9 @@ createAsyncOps(scf::ForOp &forOp,
auto &rhs) {
return lhs.distToUse < rhs.distToUse;
})->distToUse;
bool hasMMAV3 =
llvm::any_of(loadToInfo, [](auto &kv) {
return kv.second.isMMAv3Shared || kv.second.isMMAv3Registers; });
bool hasMMAV3 = llvm::any_of(loadToInfo, [](auto &kv) {
return kv.second.isMMAv3Shared || kv.second.isMMAv3Registers;
});
if (hasMMAV3) {
// For MMAv3, we need an extra buffer as this is assumed in the wgmma
// pipelining post-processing.
@@ -1182,14 +1176,15 @@ static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait,
//
// 1. All operands that touch shared memory are multi-buffered, i.e. can't read
// an incomplete value while it's being written asynchronously by a load.
// 1a. If operand A is in registers, these registers cannot be updated inside
// 1a. If operand A is in registers, these registers cannot be updated
// inside
// the loop.
// **Exception** if the operand is produced by a preceding WGMMA,
// then this op can be properly async. Either the f16 shortcut is possible
// and the WGMMA's can run back-to-back (see rule 3 below), or elementwise
// truncate is needed, in which case the preceding WGMMA is not async and
// a WarpGroupDotWait is inserted right after, which guarantees exclusive
// access to the operand registers.
// then this op can be properly async. Either the f16 shortcut is
// possible and the WGMMA's can run back-to-back (see rule 3 below), or
// elementwise truncate is needed, in which case the preceding WGMMA is
// not async and a WarpGroupDotWait is inserted right after, which
// guarantees exclusive access to the operand registers.
//
// 2. If the dot is used by any op in the loop, it must be used under an `if`,
// and will be synced with a `wait 0` at the beginning of the `if` block.
@@ -1228,7 +1223,8 @@ static std::optional<int> dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp,
// Rule 1a: Register operands must not be modified within the loop.
// First, check for chained WGMMA as an exception.
if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(operand.getDefiningOp())) {
return isa<ttg::NvidiaMmaEncodingAttr>(cvt.getSrc().getType().getEncoding());
return isa<ttg::NvidiaMmaEncodingAttr>(
cvt.getSrc().getType().getEncoding());
}
// And then, do a stricter-than-necessary check for now, that the operand
// is defined outside the loop.
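The "Rule 1a" exception above (a register operand that is produced by a preceding WGMMA) is what dotCanBeProperlyAsync checks by looking through a ConvertLayoutOp whose source carries an NvidiaMma encoding. A sketch of the chained case it accepts is below; as before, the encodings, shapes and exact op spellings are illustrative rather than real compiler output.

// The A operand of the second warp_group_dot is the (truncated) accumulator
// of the first, reaching it through an #mma -> #dotA layout conversion, so no
// other op in the loop writes those registers.
%d0  = triton_nvidia_gpu.warp_group_dot %a0, %b0, %acc0
       : !tt.memdesc<128x64xf16, #shared> * !tt.memdesc<64x128xf16, #shared> -> tensor<128x128xf32, #mma>
%d0h = arith.truncf %d0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%a1  = triton_gpu.convert_layout %d0h : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #dotA>
%d1  = triton_nvidia_gpu.warp_group_dot %a1, %b1, %acc1
       : tensor<128x128xf16, #dotA> * !tt.memdesc<128x256xf16, #shared> -> tensor<128x256xf32, #mma>

Because of the truncf, this is the "elementwise truncate" case from the comment: the first WGMMA is not kept async, and the WarpGroupDotWait inserted right after it guarantees exclusive access to the operand registers.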
12 changes: 7 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -978,10 +978,11 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {

MMALoadType getMMALoadType(Operation *loadOp) {
if (!loadOp->hasOneUse())
return MMALoadType::DoNotPipeline;
return MMALoadType::DoNotPipeline;

if (auto alloc = dyn_cast<ttg::LocalAllocOp>(*loadOp->getUsers().begin())) {
auto sharedEnc = cast<ttg::SharedEncodingAttr>(alloc.getType().getEncoding());
auto sharedEnc =
cast<ttg::SharedEncodingAttr>(alloc.getType().getEncoding());

if (!sharedEnc.getHasLeadingOffset())
return MMALoadType::DoNotPipeline;
@@ -995,8 +996,10 @@ MMALoadType getMMALoadType(Operation *loadOp) {
// be changed after FuseTranspositions Pass. So we only pipeline the
// load if the order of the loaded BlockedEncoding is the same as the
// order of the SharedEncoding it is converted to.
return oldOrder == newOrder ? MMALoadType::SharedV3 : MMALoadType::DoNotPipeline;
} else if (auto cvt = dyn_cast<ttg::ConvertLayoutOp>(*loadOp->getUsers().begin())) {
return oldOrder == newOrder ? MMALoadType::SharedV3
: MMALoadType::DoNotPipeline;
} else if (auto cvt =
dyn_cast<ttg::ConvertLayoutOp>(*loadOp->getUsers().begin())) {
auto resTy = dyn_cast<RankedTensorType>(cvt->getResultTypes()[0]);
if (!resTy) {
return MMALoadType::DoNotPipeline;
@@ -1012,7 +1015,6 @@ MMALoadType getMMALoadType(Operation *loadOp) {
}
}


namespace {

/// Detect dead arguments in scf.for op by assuming all the values are dead and
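getMMALoadType above classifies a candidate load by its single user: a local_alloc into a shared encoding with hasLeadingOffset and a matching order yields SharedV3, a convert_layout whose result is a ranked tensor (presumably carrying a dot-operand encoding; that check falls in the part of the function not shown in these hunks) yields Registers, and anything else is DoNotPipeline. Schematically, the two pipelined shapes look like this, again with hand-written IR and elided encoding definitions:

// SharedV3: the MMAv3 operand stays in shared memory.
%a  = tt.load %a_ptrs : tensor<128x64x!tt.ptr<f16>, #blocked>
%as = triton_gpu.local_alloc %a : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared>

// Registers: the operand is converted to a dot-operand layout (MMAv2, or
// MMAv3 operand A after MMAV3HoistLayoutConversion).
%b  = tt.load %b_ptrs : tensor<128x64x!tt.ptr<f16>, #blocked>
%bc = triton_gpu.convert_layout %b : tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #dotA>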
1 change: 0 additions & 1 deletion test/TritonGPU/loop-pipeline-hopper.mlir
@@ -994,4 +994,3 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
tt.return %17#0 : tensor<128x16xf32, #mma>
}
}
