From be46ef1340ed752e27675bd8216747573b4c429d Mon Sep 17 00:00:00 2001 From: Pawel Szczerbuk Date: Mon, 26 Aug 2024 12:54:21 -0700 Subject: [PATCH 1/7] Base for perf data collection --- python/tutorials/03-matrix-multiplication.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 91f751207b8e..75e619d4c974 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -377,7 +377,7 @@ def matmul(a, b, activation=""): b = torch.randn((512, 512), device="cuda", dtype=torch.float16) a = a.to(torch.float8_e5m2) # pre-transpose b for efficiency. - b = b.T + b = b.T.contiguous().T b = b.to(torch.float8_e5m2) triton_output = matmul(a, b) torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16)) @@ -401,13 +401,13 @@ def matmul(a, b, activation=""): ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS' configs = [] -for fp8_inputs in [False, True]: +for fp8_inputs in [True]: if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()): continue configs.append( triton.testing.Benchmark( - x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot - x_vals=[128 * i for i in range(2, 33)], # Different possible values for `x_name` + x_names=["K"], # Argument names to use as an x-axis for the plot + x_vals=[512 * i for i in range(1, 17)], # Different possible values for `x_name` line_arg="provider", # Argument name whose value corresponds to a different line in the plot # Possible values for `line_arg` # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment. @@ -422,12 +422,13 @@ def matmul(a, b, activation=""): @triton.testing.perf_report(configs) -def benchmark(M, N, K, provider, fp8_inputs): +def benchmark(K, provider, fp8_inputs): + M, N = 8192, 8192 a = torch.randn((M, K), device='cuda', dtype=torch.float16) b = torch.randn((K, N), device='cuda', dtype=torch.float16) if TORCH_HAS_FP8 and fp8_inputs: a = a.to(torch.float8_e5m2) - b = b.T + b = b.T.contiguous().T b = b.to(torch.float8_e5m2) quantiles = [0.5, 0.2, 0.8] if provider == ref_lib.lower(): From 668bf993ff69f10244737e530e4054859ed2589d Mon Sep 17 00:00:00 2001 From: Pawel Szczerbuk Date: Wed, 28 Aug 2024 16:23:30 -0700 Subject: [PATCH 2/7] First hack to get AccInit optimization to work --- include/triton/Dialect/Triton/IR/Traits.h | 4 ++-- .../TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td | 3 ++- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 4 ++-- .../include/Dialect/NVGPU/IR/NVGPUOps.td | 4 ++-- .../lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp | 8 +++++-- .../DotOpToLLVM/WGMMA.cpp | 24 ++++++++++++------- 6 files changed, 29 insertions(+), 18 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/Traits.h b/include/triton/Dialect/Triton/IR/Traits.h index 6e554aed2358..93e4a0e9d2ff 100644 --- a/include/triton/Dialect/Triton/IR/Traits.h +++ b/include/triton/Dialect/Triton/IR/Traits.h @@ -67,8 +67,8 @@ template class DotLike : public TraitBase { public: static LogicalResult verifyTrait(Operation *op) { - if (op->getNumOperands() != 3) - return op->emitOpError("expected 3 operands"); + // if (op->getNumOperands() != 3) + // return op->emitOpError("expected 3 operands"); auto aTy = cast(op->getOperand(0).getType()); auto bTy = cast(op->getOperand(1).getType()); auto cTy = cast(op->getOperand(2).getType()); diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td 
b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index c96799780368..d5e531b71992 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -83,13 +83,14 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods:$useC, DefaultValuedAttr:$inputPrecision, DefaultValuedAttr:$maxNumImpreciseAcc, DefaultValuedAttr:$isAsync); let results = (outs TT_FpIntTensor:$d); - let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)"; + let assemblyFormat = "$a`,` $b`,` $c (`,` $useC^)? attr-dict `:` type($a) `*` type($b) `->` type($d)"; } def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterfaceMethods, diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 39c043695bc6..69f8343ac4fc 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -306,8 +306,8 @@ class BlockedToMMA : public mlir::OpRewritePattern { a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose); b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose); newDot = rewriter.create( - dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getInputPrecision(), - dotOp.getMaxNumImpreciseAcc(), false); + dotOp.getLoc(), newRetType, a, b, newAcc, nullptr, + dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc(), false); } else { // convert operands int minBitwidth = diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td index 7affd8840612..6f138089cde6 100644 --- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td @@ -79,12 +79,12 @@ def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType", def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">; def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { - let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, Optional:$opC, + let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, I1:$useC, Optional:$opC, I32Attr:$m, I32Attr:$n, I32Attr:$k, WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB, WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB); let results = (outs LLVM_AnyStruct:$res); - let assemblyFormat = "$opA `,` $opB (`,` $opC^)? attr-dict `:` functional-type(operands, $res)"; + let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? 
attr-dict `:` functional-type(operands, $res)"; } def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> { diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index b075ca31a407..b6fa7a1d65b7 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -348,6 +348,7 @@ class WGMMAOpPattern : public OpRewritePattern { auto opA = op.getOpA(); auto opB = op.getOpB(); auto opC = op.getOpC(); + auto opScaleD = op.getUseC(); auto typeA = opA.getType(); auto structTypeA = dyn_cast(typeA); @@ -364,6 +365,9 @@ class WGMMAOpPattern : public OpRewritePattern { // Operand B (must be `desc`) operandsAndConstraints.push_back({opB, "l"}); + + // `scale-d` + operandsAndConstraints.push_back({opScaleD, "b"}); return operandsAndConstraints; } @@ -460,8 +464,8 @@ class WGMMAOpPattern : public OpRewritePattern { // Operand B (must be `desc`) args += "$" + std::to_string(asmOpIdx++) + ", "; - // `scale-d` is 1 if we have a C operand. - args += op.getOpC() ? "1" : "0"; + // `scale-d` + args += "$" + std::to_string(asmOpIdx++); // `imm-scale-a`, and `imm-scale-b` are 1 by default only for float-based // WGMMA diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index baed96a29704..9d7aefcec041 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -357,9 +357,9 @@ static SmallVector emitWait(ConversionPatternRewriter &rewriter, LogicalResult convertDot(const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, Operation *op, Value a, Value b, Value c, Value d, - Value loadedA, Value loadedB, Value loadedC, - bool allowTF32, uint32_t maxNumImpreciseAcc, bool sync, - Value thread) { + Value useCOperand, Value loadedA, Value loadedB, + Value loadedC, bool allowTF32, + uint32_t maxNumImpreciseAcc, bool sync, Value thread) { auto aTensorTy = cast(a.getType()); auto bTensorTy = cast(b.getType()); auto dTensorTy = cast(d.getType()); @@ -436,8 +436,13 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, auto accTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); Value d; - if (!zeroAcc) + Value useC = i1_val(0); + if (!zeroAcc) { d = packLLElements(loc, typeConverter, mmaOut, rewriter, accTy); + useC = i1_val(true); + } + if (useCOperand) + useC = and_(useC, useCOperand); uint32_t numLowPrecisionAcc = 0; Value partialAcc; for (int k = 0; k < numRepK; ++k) { @@ -463,8 +468,9 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, (numLowPrecisionAcc >= maxNumImpreciseAcc || k == numRepK - 1); Value mmaAcc = needsPartialAccumulator ? 
partialAcc : d; mmaAcc = rewriter.create( - loc, accTy, a, b, mmaAcc, M, N, K, eltTypeC, eltTypeA, eltTypeB, - layoutA, layoutB); + loc, accTy, a, b, useC, mmaAcc, M, N, K, eltTypeC, eltTypeA, + eltTypeB, layoutA, layoutB); + useC = i1_val(1); if (needsPartialAccumulator) partialAcc = mmaAcc; else @@ -510,9 +516,9 @@ LogicalResult convertWGMMA(triton::nvidia_gpu::WarpGroupDotOp op, mlir::isa(AEnc)); assert(mlir::isa(BEnc) && "Operand B should use Shared layout."); - return convertDot(typeConverter, rewriter, op.getLoc(), op.getOperation(), // - op.getA(), op.getB(), op.getC(), op.getD(), // - adaptor.getA(), adaptor.getB(), adaptor.getC(), + return convertDot(typeConverter, rewriter, op.getLoc(), op.getOperation(), // + op.getA(), op.getB(), op.getC(), op.getD(), op.getUseC(), // + adaptor.getA(), adaptor.getB(), adaptor.getC(), // op.getInputPrecision() == InputPrecision::TF32, op.getMaxNumImpreciseAcc(), !op.getIsAsync(), thread); } From 4c5d950ad17e9873c0b7df49db1aa7f19ae1de93 Mon Sep 17 00:00:00 2001 From: Pawel Szczerbuk Date: Fri, 6 Sep 2024 16:24:50 -0700 Subject: [PATCH 3/7] New pass ready, working with Hopper pipelining --- include/triton/Dialect/Triton/IR/Traits.h | 4 +- .../Dialect/TritonGPU/Transforms/Passes.td | 10 + .../TritonGPU/Transforms/CMakeLists.txt | 1 + .../Transforms/OptimizeAccumulatorInit.cpp | 205 ++++++++++ python/src/passes.cc | 2 + test/TritonGPU/accumulator-init.mlir | 351 ++++++++++++++++++ third_party/nvidia/backend/compiler.py | 1 + 7 files changed, 572 insertions(+), 2 deletions(-) create mode 100644 lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp create mode 100644 test/TritonGPU/accumulator-init.mlir diff --git a/include/triton/Dialect/Triton/IR/Traits.h b/include/triton/Dialect/Triton/IR/Traits.h index 93e4a0e9d2ff..e47c023a29b0 100644 --- a/include/triton/Dialect/Triton/IR/Traits.h +++ b/include/triton/Dialect/Triton/IR/Traits.h @@ -67,8 +67,8 @@ template class DotLike : public TraitBase { public: static LogicalResult verifyTrait(Operation *op) { - // if (op->getNumOperands() != 3) - // return op->emitOpError("expected 3 operands"); + if (op->getNumOperands() < 3) + return op->emitOpError("expected at least 3 operands"); auto aTy = cast(op->getOperand(0).getType()); auto bTy = cast(op->getOperand(1).getType()); auto cTy = cast(op->getOperand(2).getType()); diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 7469e39f837b..f2b79d222a91 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -169,4 +169,14 @@ def TritonGPUCombineTensorSelectAndIf: Pass<"tritongpu-combine-tensor-select-and "mlir::triton::TritonDialect"]; } +def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init", "mlir::ModuleOp"> { + let summary = "Replace accumulater zero-initialization with the flag indicating first use of the accumulator"; + + let description = "For the dot operations that support accumulator-use flag this pass replaces the zero-initialization " + "of the accumulator with the flag indicating the first use of the accumulator."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; +} + #endif diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 9767effa5a74..99e2ac3c9660 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ 
b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_triton_library(TritonGPUTransforms F32DotTC.cpp CombineTensorSelectAndIf.cpp ReduceDataDuplication.cpp + OptimizeAccumulatorInit.cpp OptimizeDotOperands.cpp OptimizeThreadLocality.cpp Pipeliner/MatmulLoopPipeline.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp new file mode 100644 index 000000000000..b5ed64cca3ac --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp @@ -0,0 +1,205 @@ +#include "mlir/Transforms/Passes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUOPTIMIZEACCUMULATORINIT +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { +bool dotSupportsAccInitFlag(Operation *op) { + assert(op->hasTrait() && "Expected a dot-like operation"); + return isa(op); +} + +std::pair getAccumulatorUseAndDef(Operation *op) { + assert(op->hasTrait() && "Expected a dot-like operation"); + if (auto wgDotOp = dyn_cast(op)) { + return std::make_pair(wgDotOp.getC(), wgDotOp); + } + return std::make_pair(nullptr, nullptr); +} + +void setUseAccFlag(Operation *op, Value useAcc) { + assert(op->hasTrait() && "Expected a dot-like operation"); + if (auto wgDotOp = dyn_cast(op)) { + wgDotOp.getUseCMutable().assign(useAcc); + } +} + +bool isConstantZeroTensor(Value v) { + auto constOp = v.getDefiningOp(); + if (!constOp) + return false; + auto splat = mlir::dyn_cast(constOp.getValue()); + if (!splat) + return false; + return splat.getSplatValue().getValue().convertToFloat() == 0.0f; +} + +std::optional> findZeroInitOp(Value accUse, + Operation *accDef, + scf::ForOp forOp, + bool &loopArgIsZero) { + Value v = accUse; + if (auto arg = dyn_cast(v)) { + assert(arg.getOwner() == forOp.getBody()); + if (isConstantZeroTensor(forOp.getInitArgs()[arg.getArgNumber() - 1])) { + loopArgIsZero = true; + } + v = forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); + } + + auto defOp = v.getDefiningOp(); + if (!defOp) { + return std::nullopt; + } + if (auto selOp = dyn_cast(defOp)) { + if (isConstantZeroTensor(selOp.getTrueValue()) || + isConstantZeroTensor(selOp.getFalseValue())) { + return std::make_pair(selOp, 0); + } + } + if (auto ifOp = dyn_cast(defOp)) { + unsigned resultIndex = 0; + for (; resultIndex < ifOp.getNumResults(); ++resultIndex) { + if (ifOp.getResult(resultIndex) == v) + break; + } + Value thenVal = ifOp.thenYield()->getOperand(resultIndex); + Value elseVal = ifOp.elseYield()->getOperand(resultIndex); + if (isConstantZeroTensor(thenVal) || isConstantZeroTensor(elseVal)) { + // Make sure that the other value is not defined in the if itself, but + // passed from outside + if (thenVal.getParentBlock()->getParentOp() == ifOp || + elseVal.getParentBlock()->getParentOp() == ifOp) { + return std::nullopt; + } + return std::make_pair(ifOp, resultIndex); + } + } + return std::nullopt; +} + +} // namespace + +class OptimizeAccumulatorInitPass + : public impl::TritonGPUOptimizeAccumulatorInitBase< + OptimizeAccumulatorInitPass> { +public: + void runOnOperation() override { + ModuleOp m = getOperation(); + SmallVector mmaOps; + m.walk([&](Operation *op) { + if (op->hasTrait() && dotSupportsAccInitFlag(op)) { + mmaOps.push_back(op); 
+ } + }); + + // for each mma op, find where the accumulator is initialized with zero + // It can be: + // 1. A constant zero + // 2. Initialized with zero as the loop argument + // 3. Initialized with zero in the if op or with a select op in current + // or any of the previous loop iterations + for (Operation *mmaOp : mmaOps) { + Location loc = mmaOp->getLoc(); + + scf::ForOp forOp = dyn_cast(mmaOp->getParentOp()); + if (!forOp) { + continue; + } + + IRRewriter rewriter(forOp); + rewriter.setInsertionPoint(forOp); + + Value vTrue = + rewriter.create(loc, rewriter.getBoolAttr(true)); + Value vFalse = + rewriter.create(loc, rewriter.getBoolAttr(false)); + + // Find the accumulator + auto [accUse, accDef] = getAccumulatorUseAndDef(mmaOp); + if (!accUse || !accDef) { + continue; + } + if (isConstantZeroTensor(accUse)) { + setUseAccFlag(mmaOp, vFalse); + continue; + } + + bool loopArgIsZero = false; + std::optional> zeroInitOp = + findZeroInitOp(accUse, accDef, forOp, loopArgIsZero); + if (!zeroInitOp) { + continue; + } + + Value loopArgFlagValue = loopArgIsZero ? vFalse : vTrue; + scf::ForOp newForOp = + replaceForOpWithNewSignature(rewriter, forOp, {loopArgFlagValue}); + forOp.erase(); + forOp = newForOp; + loopArgFlagValue = + forOp.getRegionIterArg(forOp.getNumRegionIterArgs() - 1); + + Value condition = nullptr; + Value oldValue = nullptr; + Value zeroValue = nullptr; + bool thenInitsToZero = false; + if (auto selOp = dyn_cast(zeroInitOp->first)) { + condition = selOp.getCondition(); + oldValue = isConstantZeroTensor(selOp.getTrueValue()) + ? selOp.getFalseValue() + : selOp.getTrueValue(); + zeroValue = isConstantZeroTensor(selOp.getTrueValue()) + ? selOp.getTrueValue() + : selOp.getFalseValue(); + thenInitsToZero = isConstantZeroTensor(selOp.getTrueValue()); + } else { + assert(isa(*zeroInitOp->first) && "Expected an if op"); + auto ifOp = cast(zeroInitOp->first); + unsigned resultIndex = zeroInitOp->second; + condition = ifOp.getCondition(); + Value thenVal = ifOp.thenYield()->getOperand(resultIndex); + Value elseVal = ifOp.elseYield()->getOperand(resultIndex); + oldValue = isConstantZeroTensor(thenVal) ? elseVal : thenVal; + zeroValue = isConstantZeroTensor(thenVal) ? thenVal : elseVal; + thenInitsToZero = isConstantZeroTensor(thenVal); + } + + // Create a select op that updates the flag + rewriter.setInsertionPoint(zeroInitOp->first); + bool zeroingBeforeMMA = zeroInitOp->first->isBeforeInBlock(mmaOp); + Value prevFlagValue = zeroingBeforeMMA ? loopArgFlagValue : vTrue; + auto selectFlagOp = rewriter.create( + loc, condition, thenInitsToZero ? vFalse : prevFlagValue, + thenInitsToZero ? prevFlagValue : vFalse); + setUseAccFlag(mmaOp, zeroingBeforeMMA ? selectFlagOp : loopArgFlagValue); + auto forYield = cast(forOp.getBody()->getTerminator()); + forYield->insertOperands(forYield->getNumOperands(), + {zeroingBeforeMMA ? vTrue : selectFlagOp}); + + // Stop clearing out the accumulator with zero + if (auto selOp = dyn_cast(zeroInitOp->first)) { + rewriter.setInsertionPoint(selOp); + rewriter.replaceOp(selOp, oldValue); + } else { + auto ifOp = cast(zeroInitOp->first); + int resultIndex = zeroInitOp->second; + auto zeroingYield = + thenInitsToZero ? 
ifOp.thenYield() : ifOp.elseYield(); + zeroingYield.setOperand(resultIndex, oldValue); + } + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/python/src/passes.cc b/python/src/passes.cc index 513e811d28ad..50cc402365b0 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -65,6 +65,8 @@ void init_triton_passes_ttgpuir(py::module &&m) { createAllocateSharedMemoryPass); ADD_PASS_WRAPPER_0("add_combine_tensor_select_and_if", createTritonGPUCombineTensorSelectAndIf); + ADD_PASS_WRAPPER_0("add_optimize_accumulator_init", + createTritonGPUOptimizeAccumulatorInit); } void init_triton_passes_convert(py::module &&m) { diff --git a/test/TritonGPU/accumulator-init.mlir b/test/TritonGPU/accumulator-init.mlir new file mode 100644 index 000000000000..72ef11dcafd7 --- /dev/null +++ b/test/TritonGPU/accumulator-init.mlir @@ -0,0 +1,351 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-optimize-accumulator-init | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + +// CHECK-LABEL: @constant_init +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]] + tt.func @constant_init(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %cst_2 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: else +// CHECK: scf.yield 
%[[ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @if_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_after_mma_invert +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[TRUE]], %[[FALSE]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @if_after_mma_invert(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %acc : tensor<128x16xf32, #mma1> + } else { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_before_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: 
scf.yield %[[ACC]] +// CHECK: else +// CHECK: scf.yield %[[ACC]] +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @if_before_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %arg4 : tensor<128x16xf32, #mma1> + } + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_before_mma_invert +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[USE_ACC]], %[[FALSE]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[ACC]] +// CHECK: else +// CHECK: scf.yield %[[ACC]] +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @if_before_mma_invert(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %arg4 : tensor<128x16xf32, #mma1> + } else { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @sel_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// 
CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @sel_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @sel_before_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @sel_before_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xf32, #mma1> + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + + +// Check that we look only at the zeroing directly preceding the mma + +// CHECK-LABEL: @if_before_and_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[ACC]] +// 
CHECK: else +// CHECK: scf.yield %[[ACC]] +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[C0_TENSOR]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @if_before_and_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_0 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %arg4 : tensor<128x16xf32, #mma1> + } + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_0 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_1 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + scf.yield %acc_1: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @two_ifs_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[C0_TENSOR]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[ACC_CND]] +// CHECK: else +// CHECK: scf.yield %[[ACC_CND]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @two_ifs_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_0 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + %acc_1 = scf.if %cnd -> 
(tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc_0 : tensor<128x16xf32, #mma1> + } + scf.yield %acc_1: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// Check that we bail out in unsupported cases + +// CHECK-LABEL: @non_zero_init +// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc + tt.func @non_zero_init(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @zero_init_dist_2 +// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc + tt.func @zero_init_dist_2(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %cst_2) -> (tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg5 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> + scf.yield %acc_, %arg4: tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_defines_alternative +// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc + tt.func @if_defines_alternative(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_3 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = 
arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + %acc_alt = arith.addf %acc, %cst_3 : tensor<128x16xf32, #mma1> + scf.yield %acc_alt : tensor<128x16xf32, #mma1> + } + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @non_cond_override +// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc + tt.func @non_cond_override(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_3 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = arith.addf %acc, %cst_3 : tensor<128x16xf32, #mma1> + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } +} diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 5dd75e530fec..66c94d88ce3f 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -202,6 +202,7 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) passes.common.add_cse(pm) if capability // 10 >= 8: + passes.ttgpuir.add_optimize_accumulator_init(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.ttgpuir.add_pipeline(pm, opt.num_stages) passes.ttgpuir.add_prefetch(pm) From b0d7fe1ab8d84631965a9cd137beace2200c876a Mon Sep 17 00:00:00 2001 From: Pawel Szczerbuk Date: Mon, 9 Sep 2024 09:40:19 -0700 Subject: [PATCH 4/7] Updating lowering lit tests --- test/Conversion/nvgpu_to_llvm.mlir | 5 +++-- test/Conversion/tritongpu_to_llvm_hopper.mlir | 3 ++- third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/test/Conversion/nvgpu_to_llvm.mlir b/test/Conversion/nvgpu_to_llvm.mlir index cef074bb6935..beaa4c952d4e 100644 --- a/test/Conversion/nvgpu_to_llvm.mlir +++ b/test/Conversion/nvgpu_to_llvm.mlir @@ -71,7 +71,8 @@ llvm.func @st_matrix(%i: i32, %ptr: !llvm.ptr<3>) { // CHECK-LABEL: @wgmma llvm.func @wgmma(%desc: i64, %in: !struct_64xf32) { // CHECK: wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 -%acc0 = nvgpu.wgmma %desc, %desc { +%false = llvm.mlir.constant(false) : i1 +%acc0 = nvgpu.wgmma %desc, %desc, %false { eltTypeA = 3 : i32, eltTypeB = 3 : i32, eltTypeC = 7 : i32, @@ -80,7 
+81,7 @@ llvm.func @wgmma(%desc: i64, %in: !struct_64xf32) { m = 64 : i32, n = 256 : i32, k = 32 : i32 -} : (i64, i64) -> !struct_128xf32 +} : (i64, i64, i1) -> !struct_128xf32 // CHECK: // wait for regs: $0,$1,$2,{{.*}},$127 // CHECK: wgmma.wait_group.sync.aligned 0; diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 7ecee2eba11b..cf4c19361650 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -97,8 +97,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !tt.memdesc<64x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %false = arith.constant false : i1 %opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> - %m = triton_nvidia_gpu.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }: + %m = triton_nvidia_gpu.warp_group_dot %opA, %b, %false, %cst { inputPrecision = 0 : i32 }: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return } diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td index 6f138089cde6..31b2646db81d 100644 --- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td @@ -79,7 +79,7 @@ def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType", def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">; def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { - let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, I1:$useC, Optional:$opC, + let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, I1:$useC, Optional:$opC, I32Attr:$m, I32Attr:$n, I32Attr:$k, WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB, WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB); From 445a16e921a40fae820a0edec26ff5171abe52a5 Mon Sep 17 00:00:00 2001 From: Pawel Szczerbuk Date: Mon, 9 Sep 2024 09:47:20 -0700 Subject: [PATCH 5/7] Reverting tutorial change --- python/tutorials/03-matrix-multiplication.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 75e619d4c974..91f751207b8e 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -377,7 +377,7 @@ def matmul(a, b, activation=""): b = torch.randn((512, 512), device="cuda", dtype=torch.float16) a = a.to(torch.float8_e5m2) # pre-transpose b for efficiency. 
- b = b.T.contiguous().T + b = b.T b = b.to(torch.float8_e5m2) triton_output = matmul(a, b) torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16)) @@ -401,13 +401,13 @@ def matmul(a, b, activation=""): ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS' configs = [] -for fp8_inputs in [True]: +for fp8_inputs in [False, True]: if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()): continue configs.append( triton.testing.Benchmark( - x_names=["K"], # Argument names to use as an x-axis for the plot - x_vals=[512 * i for i in range(1, 17)], # Different possible values for `x_name` + x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot + x_vals=[128 * i for i in range(2, 33)], # Different possible values for `x_name` line_arg="provider", # Argument name whose value corresponds to a different line in the plot # Possible values for `line_arg` # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment. @@ -422,13 +422,12 @@ def matmul(a, b, activation=""): @triton.testing.perf_report(configs) -def benchmark(K, provider, fp8_inputs): - M, N = 8192, 8192 +def benchmark(M, N, K, provider, fp8_inputs): a = torch.randn((M, K), device='cuda', dtype=torch.float16) b = torch.randn((K, N), device='cuda', dtype=torch.float16) if TORCH_HAS_FP8 and fp8_inputs: a = a.to(torch.float8_e5m2) - b = b.T.contiguous().T + b = b.T b = b.to(torch.float8_e5m2) quantiles = [0.5, 0.2, 0.8] if provider == ref_lib.lower(): From e774fe9e485c90554b52a1af876a0559021b4571 Mon Sep 17 00:00:00 2001 From: Pawel Szczerbuk Date: Mon, 9 Sep 2024 14:21:36 -0700 Subject: [PATCH 6/7] Updating lit tests --- test/Conversion/tritongpu_to_llvm_hopper.mlir | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index cf4c19361650..c7c63301fac2 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -93,13 +93,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: @dot_reg_operand_A // Generate a wgmma where the first operand is a struct. 
- // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !tt.memdesc<64x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> - %false = arith.constant false : i1 %opA = triton_gpu.convert_layout %a : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> - %m = triton_nvidia_gpu.warp_group_dot %opA, %b, %false, %cst { inputPrecision = 0 : i32 }: + %m = triton_nvidia_gpu.warp_group_dot %opA, %b, %cst { inputPrecision = 0 : i32 }: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<64x64xf16, #shared> -> tensor<128x64xf32, #mma> tt.return } @@ -113,7 +112,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: @dot_reg_operand_A_fp8 // Generate a wgmma where the first operand is a struct. 
- // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1> From 4153c5e43f1cacc7434287d1c63f305b727f91f3 Mon Sep 17 00:00:00 2001 From: Pawel Szczerbuk Date: Tue, 10 Sep 2024 14:18:19 -0700 Subject: [PATCH 7/7] Handle the case where the WGMMAOp has no C operand set --- third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index b6fa7a1d65b7..8de0efefca84 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -367,7 +367,9 @@ class WGMMAOpPattern : public OpRewritePattern { operandsAndConstraints.push_back({opB, "l"}); // `scale-d` - operandsAndConstraints.push_back({opScaleD, "b"}); + if (op.getOpC()) + operandsAndConstraints.push_back({opScaleD, "b"}); + return operandsAndConstraints; } @@ -465,7 +467,10 @@ class WGMMAOpPattern : public OpRewritePattern { args += "$" + std::to_string(asmOpIdx++) + ", "; // `scale-d` - args += "$" + std::to_string(asmOpIdx++); + if (op.getOpC()) + args += "$" + std::to_string(asmOpIdx++); + else + args += "0"; // `imm-scale-a`, and `imm-scale-b` are 1 by default only for float-based // WGMMA