From 20f9ba08c8e6db58dcffdfbc8b21df7e6bf82b4f Mon Sep 17 00:00:00 2001 From: Gary Geng Date: Tue, 12 Nov 2024 22:24:10 +0000 Subject: [PATCH] Clang format --- .../Dialect/TritonGPU/Transforms/Utility.h | 5 +- .../TritonGPU/Transforms/LoopScheduling.cpp | 4 +- .../Transforms/OptimizeDotOperands.cpp | 111 ++++++++++-------- .../Pipeliner/MatmulLoopPipeline.cpp | 44 ++++--- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 12 +- test/TritonGPU/loop-pipeline-hopper.mlir | 1 - .../DotOpToLLVM/WGMMA.cpp | 3 +- 7 files changed, 94 insertions(+), 86 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 4baee61ab2f5..f1e361f64d66 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -197,8 +197,9 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible); enum class MMALoadType { SharedV3, - Registers, // may be v2 or v3 - DoNotPipeline, // could be a valid shared/registers MMA operand, but skip pipelining + Registers, // may be v2 or v3 + DoNotPipeline, // could be a valid shared/registers MMA operand, but skip + // pipelining }; MMALoadType getMMALoadType(Operation *loadOp); } // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp b/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp index 63e9823cfcb8..9d6d903f4d2c 100644 --- a/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp +++ b/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp @@ -149,8 +149,8 @@ filterPipelinedLoad(llvm::SmallVector> auto dot = dyn_cast(use); auto warpGroupDot = dyn_cast(use); bool isMMAv3Shared = mmaLoadType == MMALoadType::SharedV3; - bool isMMAv3Registers = (mmaLoadType == MMALoadType::Registers) - && warpGroupDot; + bool isMMAv3Registers = + (mmaLoadType == MMALoadType::Registers) && warpGroupDot; if (isMMAv3Shared) { hasSharedEncoding = true; diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 437ed0c40ba2..f5b9ba880a0a 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -23,7 +23,7 @@ namespace { // Roughly, whether op is elementwise and thus threads don't need // to exchange elements. But some ops are not currently supported even though // they meet that criterion. -bool canHoistDotOpEncV2(Operation* op, DotOperandEncodingAttr& dotOpEnc) { +bool canHoistDotOpEncV2(Operation *op, DotOperandEncodingAttr &dotOpEnc) { // Only consider custom conversions or arith ops. // TODO(jlebar): Is this too restrictive? if (!isa(op) && !isPureUnaryInlineAsm(op) && @@ -34,7 +34,7 @@ bool canHoistDotOpEncV2(Operation* op, DotOperandEncodingAttr& dotOpEnc) { // bitwidth is unable to realize that there is a mixed-precision dot // (hence kWidth = 1) but wants to hoist through the type conversion. if (isa(op) && dotOpEnc.getKWidth() == 1) - return false; + return false; // Currently, these instructions are not supported during lowering of // shared -> dot_operand layout. Not all types and type conversions are @@ -55,7 +55,7 @@ bool canHoistDotOpEncV2(Operation* op, DotOperandEncodingAttr& dotOpEnc) { // Analog of canHoistDotOpEncV2, but for MMAv3 (WGMMA where operand A // is in registers). -bool canHoistDotOpEncV3(Operation* op) { +bool canHoistDotOpEncV3(Operation *op) { // Must have exactly one result and at least one operand if (op->getNumOperands() == 0 || op->getNumResults() != 1) return false; @@ -64,7 +64,8 @@ bool canHoistDotOpEncV3(Operation* op) { auto tensorTy = dyn_cast(ty); if (!tensorTy) return false; - return isa(tensorTy.getEncoding()); + return isa( + tensorTy.getEncoding()); }; // Operands and results must be of RankedTensorType and Blocked or DotOp @@ -102,24 +103,27 @@ bool canHoistDotOpEncV3(Operation* op) { // returning a tuple (newSlice, sliceMap), where newSlice is the cloned slice, // and sliceMap the IRMapping that maps the ops and result values of the // original slice to those in the cloned slice. -auto cloneSlice(PatternRewriter& rewriter, const SetVector& slice) { +auto cloneSlice(PatternRewriter &rewriter, + const SetVector &slice) { IRMapping sliceMap; - SetVector newSlice; + SetVector newSlice; - // First pass: clone ops; the result values are cloned as well, but the operands still - // refer to the original result values + // First pass: clone ops; the result values are cloned as well, but the + // operands still refer to the original result values for (Operation *op : slice) { rewriter.setInsertionPoint(op); auto newOp = rewriter.clone(*op); newSlice.insert(newOp); sliceMap.map(op, newOp); - for (auto [result, newResult] : llvm::zip(op->getResults(), newOp->getResults())) { + for (auto [result, newResult] : + llvm::zip(op->getResults(), newOp->getResults())) { assert(result != newResult); sliceMap.map(result, newResult); } } - // Second pass: replace operand references in cloned ops to point to cloned values + // Second pass: replace operand references in cloned ops to point to cloned + // values for (auto [op, newOp] : sliceMap.getOperationMap()) for (auto [oprIdx, operand] : llvm::enumerate(newOp->getOperands())) { auto defOp = operand.getDefiningOp(); @@ -405,21 +409,22 @@ struct MMAV3UseRegOperand } }; -// MMAV3's analog of HoistLayoutConversion, for operand A only; will make WarpGroupDot -// accept operand A in registers instead of shmem. +// MMAV3's analog of HoistLayoutConversion, for operand A only; will make +// WarpGroupDot accept operand A in registers instead of shmem. // // Before: load #blocked; (elementwise #blocked)+; local_alloc; warp_group_dot -// After: load #blocked; convert_layout #dot_op; (elementwise #dot_op)+; warp_group_dot +// After: load #blocked; convert_layout #dot_op; (elementwise #dot_op)+; +// warp_group_dot // -// Whereas (MMAV2) HoistLayoutConversion hoists thru one elementwise op at a time and -// requires multiple passes, this pattern will directly hoist the convert to the right -// place in one pass. +// Whereas (MMAV2) HoistLayoutConversion hoists thru one elementwise op at a +// time and requires multiple passes, this pattern will directly hoist the +// convert to the right place in one pass. // // Or, to be more precise, this pattern deletes the local_alloc op and inserts a -// convert_layout op after each load that warp_group_dot uses; so this is not simply hoisting -// a convert_layout op up as in V2, but can be considered as first changing local_alloc to -// convert_layout and then hoisting, which results in WGMMA now accepting operand A in DotOp -// layout rather than Shared. +// convert_layout op after each load that warp_group_dot uses; so this is not +// simply hoisting a convert_layout op up as in V2, but can be considered as +// first changing local_alloc to convert_layout and then hoisting, which results +// in WGMMA now accepting operand A in DotOp layout rather than Shared. struct MMAV3HoistLayoutConversion : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -429,27 +434,28 @@ struct MMAV3HoistLayoutConversion // Can only hoist operand 0 auto alloc = dotOp.getOperand(0).getDefiningOp(); if (!alloc || !alloc.getSrc()) - return rewriter.notifyMatchFailure(dotOp, - "operand A must be produced by local_alloc"); + return rewriter.notifyMatchFailure( + dotOp, "operand A must be produced by local_alloc"); auto getEncoding = [](Value v) { return cast(v.getType()).getEncoding(); }; if (!isa(getEncoding(dotOp.getOperand(0)))) - return rewriter.notifyMatchFailure(dotOp, - "requires Shared encoding for operand A"); + return rewriter.notifyMatchFailure( + dotOp, "requires Shared encoding for operand A"); // Step 1: Performs checks for early stop auto srcEnc = dyn_cast(getEncoding(alloc.getSrc())); if (!srcEnc) - return rewriter.notifyMatchFailure(alloc, - "requires src to have Blocked encoding"); + return rewriter.notifyMatchFailure( + alloc, "requires src to have Blocked encoding"); - auto dstEnc = dyn_cast(getEncoding(dotOp.getResult())); + auto dstEnc = + dyn_cast(getEncoding(dotOp.getResult())); if (!dstEnc || dstEnc.getVersionMajor() != 3) - return rewriter.notifyMatchFailure(dotOp, - "requires result in NvidiaMma encoding"); + return rewriter.notifyMatchFailure( + dotOp, "requires result in NvidiaMma encoding"); // Step 2: Obtain slice of ops between load/constant and local_alloc SetVector slice; @@ -457,9 +463,9 @@ struct MMAV3HoistLayoutConversion opt.omitBlockArguments = true; opt.filter = [&](Operation *op) { // Stop before Load, ConstantOp, or LocalLoad - return (op->getParentRegion() == alloc->getParentRegion()) - && !isa(op) - && (op->getNumOperands() != 0); + return (op->getParentRegion() == alloc->getParentRegion()) && + !isa(op) && + (op->getNumOperands() != 0); }; getBackwardSlice(alloc.getOperation(), &slice, opt); @@ -467,11 +473,11 @@ struct MMAV3HoistLayoutConversion if (slice.empty()) return rewriter.notifyMatchFailure(dotOp, "nothing to hoist through"); - // We define frontierOp as an op outside this slice whose result is used by an op in - // this slice. We must eventually convert the result of all frontierOps to - // DotOperandEncoding. This is done via the insertion of ConvertLayout after each - // frontierOp. - // We currently support frontierOp to be load or constant. + // We define frontierOp as an op outside this slice whose result is used by + // an op in this slice. We must eventually convert the result of all + // frontierOps to DotOperandEncoding. This is done via the insertion of + // ConvertLayout after each frontierOp. We currently support frontierOp to + // be load or constant. for (Operation *currOp : slice) { if (!canHoistDotOpEncV3(currOp)) return rewriter.notifyMatchFailure(currOp, "cannot hoist through"); @@ -482,7 +488,8 @@ struct MMAV3HoistLayoutConversion if (!slice.contains(defOp)) { // ensure frontierOp is load or constant if (!isa(defOp)) - return rewriter.notifyMatchFailure(defOp, "must be load or constant"); + return rewriter.notifyMatchFailure(defOp, + "must be load or constant"); } } } @@ -491,12 +498,14 @@ struct MMAV3HoistLayoutConversion auto [newSlice, sliceMap] = cloneSlice(rewriter, slice); // Step 5: Modify the cloned slice to have dotOp encoding. - // Before: load #blocked; (elementwise #blocked)+; local_alloc; warp_group_dot - // After: load #blocked; convert_layout #dot_op; (elementwise #dot_op)+; warp_group_dot + // Before: load #blocked; (elementwise #blocked)+; local_alloc; + // warp_group_dot After: load #blocked; convert_layout #dot_op; + // (elementwise #dot_op)+; warp_group_dot // - // Specifically, this step will change all value types from #blocked to #dot_op - // encoding in the cloned slice, and for those values produced by frontierOps (i.e., - // outside the slice), we will insert convert_layout's after the frontierOp. + // Specifically, this step will change all value types from #blocked to + // #dot_op encoding in the cloned slice, and for those values produced by + // frontierOps (i.e., outside the slice), we will insert convert_layout's + // after the frontierOp. auto srcTy = cast(alloc.getSrc().getType()); Type inputEltTy = srcTy.getElementType(); auto dotOperandEnc = DotOperandEncodingAttr::get( @@ -509,8 +518,8 @@ struct MMAV3HoistLayoutConversion auto defOp = operand.getDefiningOp(); - // defOp is not frontier (i.e. it's within slice); no need to convert the - // layout of its result + // defOp is not frontier (i.e. it's within slice); no need to convert + // the layout of its result if (newSlice.contains(defOp)) continue; @@ -521,7 +530,8 @@ struct MMAV3HoistLayoutConversion Type cvtTy = RankedTensorType::get( operandTy.getShape(), operandTy.getElementType(), dotOperandEnc); rewriter.setInsertionPoint(op); - auto cvt = rewriter.create(defOp->getLoc(), cvtTy, operand); + auto cvt = + rewriter.create(defOp->getLoc(), cvtTy, operand); op->setOperand(oprIdx, cvt); } @@ -533,12 +543,11 @@ struct MMAV3HoistLayoutConversion } // Step 6: replace LHS operand with alloc's parent in the cloned slice - // This changes the warpGroupDot to accept a DotOp tensor as operand A instead of - // a Shared memdesc. + // This changes the warpGroupDot to accept a DotOp tensor as operand A + // instead of a Shared memdesc. auto newDotOperand = sliceMap.lookup(alloc.getSrc()); - rewriter.modifyOpInPlace(dotOp, [&]() { - dotOp.setOperand(0, newDotOperand); - }); + rewriter.modifyOpInPlace(dotOp, + [&]() { dotOp.setOperand(0, newDotOperand); }); return success(); } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 6dcd21a307b3..2044f87245ab 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -125,10 +125,8 @@ static int createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, auto convertBlockLayout = [&](Value val, ttg::BlockedEncodingAttr enc) { auto ty = cast(val.getType()); - auto newTy = - RankedTensorType::get(ty.getShape(), ty.getElementType(), enc); - auto cvt = - builder.create(loc, newTy, val); + auto newTy = RankedTensorType::get(ty.getShape(), ty.getElementType(), enc); + auto cvt = builder.create(loc, newTy, val); return cvt.getResult(); }; @@ -169,20 +167,16 @@ static int createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, SmallVector newSizePerThread; llvm::transform(blockEnc.getSizePerThread(), - std::back_inserter(newSizePerThread), - [&](auto size) { return std::min(size, sharedVec); }); + std::back_inserter(newSizePerThread), + [&](auto size) { return std::min(size, sharedVec); }); if (newSizePerThread != blockEnc.getSizePerThread()) { auto mod = loadOp->getParentOfType(); int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod); auto newBlockEnc = ttg::BlockedEncodingAttr::get( - loadOp.getContext(), - tensorTy.getShape(), - newSizePerThread, - blockEnc.getOrder(), - numWarps, - threadsPerWarp, + loadOp.getContext(), tensorTy.getShape(), newSizePerThread, + blockEnc.getOrder(), numWarps, threadsPerWarp, blockEnc.getCTALayout()); src = convertBlockLayout(src, newBlockEnc); @@ -528,8 +522,8 @@ assignMemoryLayouts(scf::ForOp &forOp, loadInfo.usedByDot = true; loadInfo.isMMAv3Shared = mmaLoadType == MMALoadType::SharedV3; - loadInfo.isMMAv3Registers = (mmaLoadType == MMALoadType::Registers) - && warpGroupDot; + loadInfo.isMMAv3Registers = + (mmaLoadType == MMALoadType::Registers) && warpGroupDot; if (loadInfo.isMMAv3Shared) { loadInfo.sharedEncoding = @@ -771,9 +765,9 @@ createAsyncOps(scf::ForOp &forOp, auto &rhs) { return lhs.distToUse < rhs.distToUse; })->distToUse; - bool hasMMAV3 = - llvm::any_of(loadToInfo, [](auto &kv) { - return kv.second.isMMAv3Shared || kv.second.isMMAv3Registers; }); + bool hasMMAV3 = llvm::any_of(loadToInfo, [](auto &kv) { + return kv.second.isMMAv3Shared || kv.second.isMMAv3Registers; + }); if (hasMMAV3) { // For MMAv3, we need an extra buffer as this is assumed in the wgmma // pipelining post-processing. @@ -1182,14 +1176,15 @@ static void threadValuesThroughWait(ttng::WarpGroupDotWaitOp wait, // // 1. All operands that touch shared memory are multi-buffered, i.e. can't read // an incomplete value while it's being written asynchronously by a load. -// 1a. If operand A is in registers, these registers cannot be updated inside +// 1a. If operand A is in registers, these registers cannot be updated +// inside // the loop. // **Exception** if the operand is produced by a preceding WGMMA, -// then this op can be properly async. Either the f16 shortcut is possible -// and the WGMMA's can run back-to-back (see rule 3 below), or elementwise -// truncate is needed, in which case the preceding WGMMA is not async and -// a WarpGroupDotWait is inserted right after, which guarantees exclusive -// access to the operand registers. +// then this op can be properly async. Either the f16 shortcut is +// possible and the WGMMA's can run back-to-back (see rule 3 below), or +// elementwise truncate is needed, in which case the preceding WGMMA is +// not async and a WarpGroupDotWait is inserted right after, which +// guarantees exclusive access to the operand registers. // // 2. If the dot is used by any op in the loop, it must be used under an `if`, // and will be synced with a `wait 0` at the beginning of the `if` block. @@ -1228,7 +1223,8 @@ static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, // Rule 1a: Register operands must not be modified within the loop. // First, check for chained WGMMA as an exception. if (auto cvt = dyn_cast(operand.getDefiningOp())) { - return isa(cvt.getSrc().getType().getEncoding()); + return isa( + cvt.getSrc().getType().getEncoding()); } // And then, do a stricter-than-necessary check for now, that the operand // is defined outside the loop. diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 27a9a1499fde..bb60c1821ad7 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -978,10 +978,11 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { MMALoadType getMMALoadType(Operation *loadOp) { if (!loadOp->hasOneUse()) - return MMALoadType::DoNotPipeline; + return MMALoadType::DoNotPipeline; if (auto alloc = dyn_cast(*loadOp->getUsers().begin())) { - auto sharedEnc = cast(alloc.getType().getEncoding()); + auto sharedEnc = + cast(alloc.getType().getEncoding()); if (!sharedEnc.getHasLeadingOffset()) return MMALoadType::DoNotPipeline; @@ -995,8 +996,10 @@ MMALoadType getMMALoadType(Operation *loadOp) { // be changed after FuseTranspositions Pass. So we only pipeline the // load if the order of the loaded BlockedEncoding is the same as the // order of the SharedEncoding it is converted to. - return oldOrder == newOrder ? MMALoadType::SharedV3 : MMALoadType::DoNotPipeline; - } else if (auto cvt = dyn_cast(*loadOp->getUsers().begin())) { + return oldOrder == newOrder ? MMALoadType::SharedV3 + : MMALoadType::DoNotPipeline; + } else if (auto cvt = + dyn_cast(*loadOp->getUsers().begin())) { auto resTy = dyn_cast(cvt->getResultTypes()[0]); if (!resTy) { return MMALoadType::DoNotPipeline; @@ -1012,7 +1015,6 @@ MMALoadType getMMALoadType(Operation *loadOp) { } } - namespace { /// Detect dead arguments in scf.for op by assuming all the values are dead and diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index d2ebd0a1357b..2c1d6abb5d5e 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -994,4 +994,3 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : tt.return %17#0 : tensor<128x16xf32, #mma> } } - diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 07baaa3f20ad..9b1667db7083 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -284,7 +284,8 @@ DotOpMmaV3SmemLoader loadB(const LLVMTypeConverter *typeConverter, // // This ordering is decided when a tensor in DotOpEnc is lowered into llvm. // For WGMMA this happens in both SharedToDotOperand and MMAToDotOperand. -// Thus, both lowerings must obey this above ordering for the below code to be correct. +// Thus, both lowerings must obey this above ordering for the below code to be +// correct. llvm::SmallVector loadReg(ConversionPatternRewriter &rewriter, Location loc, const SmallVector &elements,