diff --git a/include/triton/Dialect/Triton/IR/Traits.h b/include/triton/Dialect/Triton/IR/Traits.h
index 6e554aed2358..e47c023a29b0 100644
--- a/include/triton/Dialect/Triton/IR/Traits.h
+++ b/include/triton/Dialect/Triton/IR/Traits.h
@@ -67,8 +67,8 @@ template <typename ConcreteType> class DotLike : public TraitBase<ConcreteType, DotLike> {
 public:
   static LogicalResult verifyTrait(Operation *op) {
-    if (op->getNumOperands() != 3)
-      return op->emitOpError("expected 3 operands");
+    if (op->getNumOperands() < 3)
+      return op->emitOpError("expected at least 3 operands");
     auto aTy = cast(op->getOperand(0).getType());
     auto bTy = cast(op->getOperand(1).getType());
     auto cTy = cast(op->getOperand(2).getType());
diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
index 7469e39f837b..f2b79d222a91 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td
+++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -169,4 +169,14 @@ def TritonGPUCombineTensorSelectAndIf: Pass<"tritongpu-combine-tensor-select-and
                            "mlir::triton::TritonDialect"];
 }
 
+def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init", "mlir::ModuleOp"> {
+  let summary = "Replace accumulator zero-initialization with a flag indicating the first use of the accumulator";
+
+  let description = "For dot operations that support an accumulator-use flag, this pass replaces the "
+                    "zero-initialization of the accumulator with a flag indicating the first use of the accumulator.";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::TritonDialect"];
+}
+
 #endif
diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
index d2446c4a6e51..ac7bf96f7fdd 100644
--- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
+++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
@@ -83,13 +83,14 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods
+                       Optional<I1>:$useC,
                        DefaultValuedAttr:$inputPrecision,
                        DefaultValuedAttr:$maxNumImpreciseAcc,
                        DefaultValuedAttr:$isAsync);
 
   let results = (outs TT_FpIntTensor:$d);
 
-  let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
+  let assemblyFormat = "$a`,` $b`,` $c (`,` $useC^)? 
attr-dict `:` type($a) `*` type($b) `->` type($d)"; } def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterfaceMethods, diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 6c15ce06979f..d9bbd51bd9a1 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -306,8 +306,8 @@ class BlockedToMMA : public mlir::OpRewritePattern { a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose); b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose); newDot = rewriter.create( - dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getInputPrecision(), - dotOp.getMaxNumImpreciseAcc(), false); + dotOp.getLoc(), newRetType, a, b, newAcc, nullptr, + dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc(), false); } else { // convert operands int minBitwidth = diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 9767effa5a74..99e2ac3c9660 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_triton_library(TritonGPUTransforms F32DotTC.cpp CombineTensorSelectAndIf.cpp ReduceDataDuplication.cpp + OptimizeAccumulatorInit.cpp OptimizeDotOperands.cpp OptimizeThreadLocality.cpp Pipeliner/MatmulLoopPipeline.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp new file mode 100644 index 000000000000..b5ed64cca3ac --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp @@ -0,0 +1,205 @@ +#include "mlir/Transforms/Passes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUOPTIMIZEACCUMULATORINIT +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { +bool dotSupportsAccInitFlag(Operation *op) { + assert(op->hasTrait() && "Expected a dot-like operation"); + return isa(op); +} + +std::pair getAccumulatorUseAndDef(Operation *op) { + assert(op->hasTrait() && "Expected a dot-like operation"); + if (auto wgDotOp = dyn_cast(op)) { + return std::make_pair(wgDotOp.getC(), wgDotOp); + } + return std::make_pair(nullptr, nullptr); +} + +void setUseAccFlag(Operation *op, Value useAcc) { + assert(op->hasTrait() && "Expected a dot-like operation"); + if (auto wgDotOp = dyn_cast(op)) { + wgDotOp.getUseCMutable().assign(useAcc); + } +} + +bool isConstantZeroTensor(Value v) { + auto constOp = v.getDefiningOp(); + if (!constOp) + return false; + auto splat = mlir::dyn_cast(constOp.getValue()); + if (!splat) + return false; + return splat.getSplatValue().getValue().convertToFloat() == 0.0f; +} + +std::optional> findZeroInitOp(Value accUse, + Operation *accDef, + scf::ForOp forOp, + bool &loopArgIsZero) { + Value v = accUse; + if (auto arg = dyn_cast(v)) { + assert(arg.getOwner() == forOp.getBody()); + if (isConstantZeroTensor(forOp.getInitArgs()[arg.getArgNumber() - 1])) { + loopArgIsZero = true; + } + v = forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1); + } + + auto defOp = v.getDefiningOp(); + if (!defOp) { + return std::nullopt; + } + if (auto selOp = dyn_cast(defOp)) { + if 
(isConstantZeroTensor(selOp.getTrueValue()) || + isConstantZeroTensor(selOp.getFalseValue())) { + return std::make_pair(selOp, 0); + } + } + if (auto ifOp = dyn_cast(defOp)) { + unsigned resultIndex = 0; + for (; resultIndex < ifOp.getNumResults(); ++resultIndex) { + if (ifOp.getResult(resultIndex) == v) + break; + } + Value thenVal = ifOp.thenYield()->getOperand(resultIndex); + Value elseVal = ifOp.elseYield()->getOperand(resultIndex); + if (isConstantZeroTensor(thenVal) || isConstantZeroTensor(elseVal)) { + // Make sure that the other value is not defined in the if itself, but + // passed from outside + if (thenVal.getParentBlock()->getParentOp() == ifOp || + elseVal.getParentBlock()->getParentOp() == ifOp) { + return std::nullopt; + } + return std::make_pair(ifOp, resultIndex); + } + } + return std::nullopt; +} + +} // namespace + +class OptimizeAccumulatorInitPass + : public impl::TritonGPUOptimizeAccumulatorInitBase< + OptimizeAccumulatorInitPass> { +public: + void runOnOperation() override { + ModuleOp m = getOperation(); + SmallVector mmaOps; + m.walk([&](Operation *op) { + if (op->hasTrait() && dotSupportsAccInitFlag(op)) { + mmaOps.push_back(op); + } + }); + + // for each mma op, find where the accumulator is initialized with zero + // It can be: + // 1. A constant zero + // 2. Initialized with zero as the loop argument + // 3. Initialized with zero in the if op or with a select op in current + // or any of the previous loop iterations + for (Operation *mmaOp : mmaOps) { + Location loc = mmaOp->getLoc(); + + scf::ForOp forOp = dyn_cast(mmaOp->getParentOp()); + if (!forOp) { + continue; + } + + IRRewriter rewriter(forOp); + rewriter.setInsertionPoint(forOp); + + Value vTrue = + rewriter.create(loc, rewriter.getBoolAttr(true)); + Value vFalse = + rewriter.create(loc, rewriter.getBoolAttr(false)); + + // Find the accumulator + auto [accUse, accDef] = getAccumulatorUseAndDef(mmaOp); + if (!accUse || !accDef) { + continue; + } + if (isConstantZeroTensor(accUse)) { + setUseAccFlag(mmaOp, vFalse); + continue; + } + + bool loopArgIsZero = false; + std::optional> zeroInitOp = + findZeroInitOp(accUse, accDef, forOp, loopArgIsZero); + if (!zeroInitOp) { + continue; + } + + Value loopArgFlagValue = loopArgIsZero ? vFalse : vTrue; + scf::ForOp newForOp = + replaceForOpWithNewSignature(rewriter, forOp, {loopArgFlagValue}); + forOp.erase(); + forOp = newForOp; + loopArgFlagValue = + forOp.getRegionIterArg(forOp.getNumRegionIterArgs() - 1); + + Value condition = nullptr; + Value oldValue = nullptr; + Value zeroValue = nullptr; + bool thenInitsToZero = false; + if (auto selOp = dyn_cast(zeroInitOp->first)) { + condition = selOp.getCondition(); + oldValue = isConstantZeroTensor(selOp.getTrueValue()) + ? selOp.getFalseValue() + : selOp.getTrueValue(); + zeroValue = isConstantZeroTensor(selOp.getTrueValue()) + ? selOp.getTrueValue() + : selOp.getFalseValue(); + thenInitsToZero = isConstantZeroTensor(selOp.getTrueValue()); + } else { + assert(isa(*zeroInitOp->first) && "Expected an if op"); + auto ifOp = cast(zeroInitOp->first); + unsigned resultIndex = zeroInitOp->second; + condition = ifOp.getCondition(); + Value thenVal = ifOp.thenYield()->getOperand(resultIndex); + Value elseVal = ifOp.elseYield()->getOperand(resultIndex); + oldValue = isConstantZeroTensor(thenVal) ? elseVal : thenVal; + zeroValue = isConstantZeroTensor(thenVal) ? 
thenVal : elseVal; + thenInitsToZero = isConstantZeroTensor(thenVal); + } + + // Create a select op that updates the flag + rewriter.setInsertionPoint(zeroInitOp->first); + bool zeroingBeforeMMA = zeroInitOp->first->isBeforeInBlock(mmaOp); + Value prevFlagValue = zeroingBeforeMMA ? loopArgFlagValue : vTrue; + auto selectFlagOp = rewriter.create( + loc, condition, thenInitsToZero ? vFalse : prevFlagValue, + thenInitsToZero ? prevFlagValue : vFalse); + setUseAccFlag(mmaOp, zeroingBeforeMMA ? selectFlagOp : loopArgFlagValue); + auto forYield = cast(forOp.getBody()->getTerminator()); + forYield->insertOperands(forYield->getNumOperands(), + {zeroingBeforeMMA ? vTrue : selectFlagOp}); + + // Stop clearing out the accumulator with zero + if (auto selOp = dyn_cast(zeroInitOp->first)) { + rewriter.setInsertionPoint(selOp); + rewriter.replaceOp(selOp, oldValue); + } else { + auto ifOp = cast(zeroInitOp->first); + int resultIndex = zeroInitOp->second; + auto zeroingYield = + thenInitsToZero ? ifOp.thenYield() : ifOp.elseYield(); + zeroingYield.setOperand(resultIndex, oldValue); + } + } + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/python/src/passes.cc b/python/src/passes.cc index 8cd0f5c5c1e9..98d8369d40aa 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -66,6 +66,8 @@ void init_triton_passes_ttgpuir(py::module &&m) { createAllocateSharedMemoryPass); ADD_PASS_WRAPPER_0("add_combine_tensor_select_and_if", createTritonGPUCombineTensorSelectAndIf); + ADD_PASS_WRAPPER_0("add_optimize_accumulator_init", + createTritonGPUOptimizeAccumulatorInit); } void init_triton_passes_convert(py::module &&m) { diff --git a/test/Conversion/nvgpu_to_llvm.mlir b/test/Conversion/nvgpu_to_llvm.mlir index cef074bb6935..beaa4c952d4e 100644 --- a/test/Conversion/nvgpu_to_llvm.mlir +++ b/test/Conversion/nvgpu_to_llvm.mlir @@ -71,7 +71,8 @@ llvm.func @st_matrix(%i: i32, %ptr: !llvm.ptr<3>) { // CHECK-LABEL: @wgmma llvm.func @wgmma(%desc: i64, %in: !struct_64xf32) { // CHECK: wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 -%acc0 = nvgpu.wgmma %desc, %desc { +%false = llvm.mlir.constant(false) : i1 +%acc0 = nvgpu.wgmma %desc, %desc, %false { eltTypeA = 3 : i32, eltTypeB = 3 : i32, eltTypeC = 7 : i32, @@ -80,7 +81,7 @@ llvm.func @wgmma(%desc: i64, %in: !struct_64xf32) { m = 64 : i32, n = 256 : i32, k = 32 : i32 -} : (i64, i64) -> !struct_128xf32 +} : (i64, i64, i1) -> !struct_128xf32 // CHECK: // wait for regs: $0,$1,$2,{{.*}},$127 // CHECK: wgmma.wait_group.sync.aligned 0; diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 3c16ea0260ef..d44529966274 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -93,7 +93,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: @dot_reg_operand_A // Generate a wgmma where the first operand is a struct. 
- // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !tt.memdesc<64x64xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> @@ -112,7 +112,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: @dot_reg_operand_A_fp8 // Generate a wgmma where the first operand is a struct. - // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1> diff --git a/test/TritonGPU/accumulator-init.mlir b/test/TritonGPU/accumulator-init.mlir new file mode 100644 index 000000000000..72ef11dcafd7 --- /dev/null +++ b/test/TritonGPU/accumulator-init.mlir @@ -0,0 +1,351 @@ +// RUN: triton-opt %s -split-input-file 
-tritongpu-optimize-accumulator-init | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}> +#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + +// CHECK-LABEL: @constant_init +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, %[[FALSE]] + tt.func @constant_init(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %cst_2 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @if_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, 
#triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_after_mma_invert +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[TRUE]], %[[FALSE]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @if_after_mma_invert(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %acc : tensor<128x16xf32, #mma1> + } else { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_before_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[ACC]] +// CHECK: else +// CHECK: scf.yield %[[ACC]] +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @if_before_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + 
%acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %arg4 : tensor<128x16xf32, #mma1> + } + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_before_mma_invert +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[USE_ACC]], %[[FALSE]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[ACC]] +// CHECK: else +// CHECK: scf.yield %[[ACC]] +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @if_before_mma_invert(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %arg4 : tensor<128x16xf32, #mma1> + } else { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @sel_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @sel_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + 
%cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @sel_before_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @sel_before_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xf32, #mma1> + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + scf.yield %acc: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + + +// Check that we look only at the zeroing directly preceding the mma + +// CHECK-LABEL: @if_before_and_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[TRUE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[USE_ACC]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[ACC]] +// CHECK: else +// CHECK: scf.yield %[[ACC]] +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC_CND]], %[[USE_ACC_NEXT]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[C0_TENSOR]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: scf.yield {{.*}}, %[[TRUE]] + tt.func @if_before_and_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 
= %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc_0 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %arg4 : tensor<128x16xf32, #mma1> + } + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_0 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_1 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + scf.yield %acc_1: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @two_ifs_after_mma +// CHECK-DAG: %[[C0_TENSOR:.+]] = arith.constant dense<0.000000e+00> +// CHECK-DAG: %[[TRUE:.+]] = arith.constant true +// CHECK-DAG: %[[FALSE:.+]] = arith.constant false +// CHECK: scf.for {{.*}} iter_args(%[[ACC:.+]] = %[[C0_TENSOR]], %[[USE_ACC:.+]] = %[[FALSE]]) +// CHECK: %[[CND:.+]] = arith.cmpi +// CHECK: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, %[[ACC]], %[[USE_ACC]] +// CHECK: %[[ACC_CND:.+]] = scf.if %[[CND]] +// CHECK: scf.yield %[[C0_TENSOR]] +// CHECK: else +// CHECK: scf.yield %[[ACC_NEXT]] +// CHECK: %[[USE_ACC_NEXT:.*]] = arith.select %[[CND]], %[[FALSE]], %[[TRUE]] +// CHECK: scf.if %[[CND]] +// CHECK: scf.yield %[[ACC_CND]] +// CHECK: else +// CHECK: scf.yield %[[ACC_CND]] +// CHECK: scf.yield {{.*}}, %[[USE_ACC_NEXT]] + tt.func @two_ifs_after_mma(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_0 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc : tensor<128x16xf32, #mma1> + } + %acc_1 = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + scf.yield %acc_0 : tensor<128x16xf32, #mma1> + } + scf.yield %acc_1: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// Check that we bail out in unsupported cases + +// CHECK-LABEL: @non_zero_init +// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc + tt.func @non_zero_init(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = 
arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @zero_init_dist_2 +// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc + tt.func @zero_init_dist_2(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2, %arg5 = %cst_2) -> (tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg5 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = arith.select %cnd, %cst_2, %acc : tensor<128x16xf32, #mma1> + scf.yield %acc_, %arg4: tensor<128x16xf32, #mma1>, tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @if_defines_alternative +// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc + tt.func @if_defines_alternative(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %ext: i32, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_3 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %cnd = arith.cmpi slt, %arg3, %ext : i32 + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = scf.if %cnd -> (tensor<128x16xf32, #mma1>) { + scf.yield %cst_2 : tensor<128x16xf32, #mma1> + } else { + %acc_alt = arith.addf %acc, %cst_3 : tensor<128x16xf32, #mma1> + scf.yield %acc_alt : tensor<128x16xf32, #mma1> + } + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } + +// CHECK-LABEL: @non_cond_override +// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc + tt.func @non_cond_override(%A: 
!tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %inc: tensor<64x16xi32, #blocked> {tt.divisibility = 16 : i32}) -> tensor<128x16xf32, #mma1> { + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1> + %cst_3 = arith.constant dense<1.000000e+00> : tensor<128x16xf32, #mma1> + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 { + %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %arg4 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> + %acc_ = arith.addf %acc, %cst_3 : tensor<128x16xf32, #mma1> + scf.yield %acc_: tensor<128x16xf32, #mma1> + } + tt.return %17 : tensor<128x16xf32, #mma1> + } +} diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index c2185a4b8eac..adfde57b01b0 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -219,6 +219,7 @@ def make_ttgir(mod, metadata, opt, capability): passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) passes.common.add_cse(pm) if capability // 10 >= 8: + passes.ttgpuir.add_optimize_accumulator_init(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) passes.ttgpuir.add_pipeline(pm, opt.num_stages) passes.ttgpuir.add_prefetch(pm) diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td index 7affd8840612..31b2646db81d 100644 --- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td @@ -79,12 +79,12 @@ def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType", def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">; def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { - let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, Optional:$opC, + let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, I1:$useC, Optional:$opC, I32Attr:$m, I32Attr:$n, I32Attr:$k, WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB, WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB); let results = (outs LLVM_AnyStruct:$res); - let assemblyFormat = "$opA `,` $opB (`,` $opC^)? attr-dict `:` functional-type(operands, $res)"; + let assemblyFormat = "$opA `,` $opB `,` $useC (`,` $opC^)? 
attr-dict `:` functional-type(operands, $res)"; } def NVGPU_FenceAsyncSharedOp : NVGPU_Op<"fence_async_shared", []> { diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index b075ca31a407..8de0efefca84 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -348,6 +348,7 @@ class WGMMAOpPattern : public OpRewritePattern { auto opA = op.getOpA(); auto opB = op.getOpB(); auto opC = op.getOpC(); + auto opScaleD = op.getUseC(); auto typeA = opA.getType(); auto structTypeA = dyn_cast(typeA); @@ -364,6 +365,11 @@ class WGMMAOpPattern : public OpRewritePattern { // Operand B (must be `desc`) operandsAndConstraints.push_back({opB, "l"}); + + // `scale-d` + if (op.getOpC()) + operandsAndConstraints.push_back({opScaleD, "b"}); + return operandsAndConstraints; } @@ -460,8 +466,11 @@ class WGMMAOpPattern : public OpRewritePattern { // Operand B (must be `desc`) args += "$" + std::to_string(asmOpIdx++) + ", "; - // `scale-d` is 1 if we have a C operand. - args += op.getOpC() ? "1" : "0"; + // `scale-d` + if (op.getOpC()) + args += "$" + std::to_string(asmOpIdx++); + else + args += "0"; // `imm-scale-a`, and `imm-scale-b` are 1 by default only for float-based // WGMMA diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp index 41e36503f593..c10a6e777987 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -357,9 +357,9 @@ static SmallVector emitWait(ConversionPatternRewriter &rewriter, LogicalResult convertDot(const LLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, Operation *op, Value a, Value b, Value c, Value d, - Value loadedA, Value loadedB, Value loadedC, - bool allowTF32, uint32_t maxNumImpreciseAcc, bool sync, - Value thread) { + Value useCOperand, Value loadedA, Value loadedB, + Value loadedC, bool allowTF32, + uint32_t maxNumImpreciseAcc, bool sync, Value thread) { auto aTensorTy = cast(a.getType()); auto bTensorTy = cast(b.getType()); auto dTensorTy = cast(d.getType()); @@ -436,8 +436,13 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, auto accTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); Value d; - if (!zeroAcc) + Value useC = i1_val(0); + if (!zeroAcc) { d = packLLElements(loc, typeConverter, mmaOut, rewriter, accTy); + useC = i1_val(true); + } + if (useCOperand) + useC = and_(useC, useCOperand); uint32_t numLowPrecisionAcc = 0; Value partialAcc; for (int k = 0; k < numRepK; ++k) { @@ -463,8 +468,9 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, (numLowPrecisionAcc >= maxNumImpreciseAcc || k == numRepK - 1); Value mmaAcc = needsPartialAccumulator ? 
partialAcc : d; mmaAcc = rewriter.create( - loc, accTy, a, b, mmaAcc, M, N, K, eltTypeC, eltTypeA, eltTypeB, - layoutA, layoutB); + loc, accTy, a, b, useC, mmaAcc, M, N, K, eltTypeC, eltTypeA, + eltTypeB, layoutA, layoutB); + useC = i1_val(1); if (needsPartialAccumulator) partialAcc = mmaAcc; else @@ -510,9 +516,9 @@ LogicalResult convertWGMMA(triton::nvidia_gpu::WarpGroupDotOp op, mlir::isa(AEnc)); assert(mlir::isa(BEnc) && "Operand B should use Shared layout."); - return convertDot(typeConverter, rewriter, op.getLoc(), op.getOperation(), // - op.getA(), op.getB(), op.getC(), op.getD(), // - adaptor.getA(), adaptor.getB(), adaptor.getC(), + return convertDot(typeConverter, rewriter, op.getLoc(), op.getOperation(), // + op.getA(), op.getB(), op.getC(), op.getD(), op.getUseC(), // + adaptor.getA(), adaptor.getB(), adaptor.getC(), // op.getInputPrecision() == InputPrecision::TF32, op.getMaxNumImpreciseAcc(), !op.getIsAsync(), thread); }
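
Reviewer note: for orientation, below is a minimal before/after sketch of the rewrite performed by the new -tritongpu-optimize-accumulator-init pass, distilled from the @sel_before_mma test above. It is illustrative only and not part of the patch; the value names are made up and the !tt.memdesc operand types are elided with "...". Before the pass, the loop conditionally resets the accumulator tensor to zero; after the pass, the zeroing select is removed and the condition instead drives the new optional i1 useC operand of triton_nvidia_gpu.warp_group_dot, which the WGMMA lowering maps to the scale-d flag.

    // Before: the accumulator is conditionally cleared with a zero tensor.
    %acc_in = arith.select %cnd, %cst_zero, %acc_prev : tensor<128x16xf32, #mma1>
    %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_in : !tt.memdesc<...> * !tt.memdesc<...> -> tensor<128x16xf32, #mma1>

    // After: the tensor select is gone; a per-iteration i1 flag (carried as an extra
    // loop iter_arg) tells the WGMMA whether to accumulate into C.
    %use_acc = arith.select %cnd, %false, %use_acc_prev : i1
    %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_prev, %use_acc : !tt.memdesc<...> * !tt.memdesc<...> -> tensor<128x16xf32, #mma1>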