Skip to content

Commit

Permalink
[BE] Accumulator init optimization (#4680)
Browse files Browse the repository at this point in the history
Adding a transformation pass that skips filling the accumulator with a
zero value if the HW supports an accumulator scale or init flag. In such
a case, a flag value is created and maintained, and passed to the MMA op,
indicating whether the accumulator should be taken into account when
calculating the dot product.
The pass is intended to be generic enough to be reusable between
different HW platforms, therefore it is not placed in the Nvidia
specific folder, even though it is supporting only Hopper MMA for the
moment.
  • Loading branch information
pawelszczerbuk committed Sep 10, 2024
1 parent 58eccfc commit a0c1bc9
Show file tree
Hide file tree
Showing 14 changed files with 609 additions and 22 deletions.
4 changes: 2 additions & 2 deletions include/triton/Dialect/Triton/IR/Traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ template <class ConcreteType>
class DotLike : public TraitBase<ConcreteType, DotLike> {
public:
static LogicalResult verifyTrait(Operation *op) {
if (op->getNumOperands() != 3)
return op->emitOpError("expected 3 operands");
if (op->getNumOperands() < 3)
return op->emitOpError("expected at least 3 operands");
auto aTy = cast<TensorOrMemDesc>(op->getOperand(0).getType());
auto bTy = cast<TensorOrMemDesc>(op->getOperand(1).getType());
auto cTy = cast<TensorOrMemDesc>(op->getOperand(2).getType());
Expand Down
10 changes: 10 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,14 @@ def TritonGPUCombineTensorSelectAndIf: Pass<"tritongpu-combine-tensor-select-and
"mlir::triton::TritonDialect"];
}

def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init", "mlir::ModuleOp"> {
  // Fixed typo: "accumulater" -> "accumulator" in the user-visible summary.
  let summary = "Replace accumulator zero-initialization with the flag indicating first use of the accumulator";

  let description = "For the dot operations that support accumulator-use flag this pass replaces the zero-initialization "
                    "of the accumulator with the flag indicating the first use of the accumulator.";

  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::TritonDialect"];
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,14 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods<I
let arguments = (ins TT_TensorOrMemDesc:$a,
TT_TensorOrMemDesc:$b,
TT_FpIntTensor:$c,
Optional<I1>:$useC,
DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
DefaultValuedAttr<I32Attr, "0">:$maxNumImpreciseAcc,
DefaultValuedAttr<BoolAttr, "false">:$isAsync);

let results = (outs TT_FpIntTensor:$d);

let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
let assemblyFormat = "$a`,` $b`,` $c (`,` $useC^)? attr-dict `:` type($a) `*` type($b) `->` type($d)";
}

def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose);
b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose);
newDot = rewriter.create<triton::nvidia_gpu::WarpGroupDotOp>(
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getInputPrecision(),
dotOp.getMaxNumImpreciseAcc(), false);
dotOp.getLoc(), newRetType, a, b, newAcc, nullptr,
dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc(), false);
} else {
// convert operands
int minBitwidth =
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_triton_library(TritonGPUTransforms
F32DotTC.cpp
CombineTensorSelectAndIf.cpp
ReduceDataDuplication.cpp
OptimizeAccumulatorInit.cpp
OptimizeDotOperands.cpp
OptimizeThreadLocality.cpp
Pipeliner/MatmulLoopPipeline.cpp
Expand Down
205 changes: 205 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

namespace mlir {
namespace triton {
namespace gpu {

#define GEN_PASS_DEF_TRITONGPUOPTIMIZEACCUMULATORINIT
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

namespace {
// Whether `op` is a dot-like operation whose lowering accepts an
// accumulator-init ("use accumulator") flag. Currently only the Hopper
// warp-group MMA op supports it.
bool dotSupportsAccInitFlag(Operation *op) {
  assert(op->hasTrait<OpTrait::DotLike>() && "Expected a dot-like operation");
  const bool isWarpGroupDot = isa<triton::nvidia_gpu::WarpGroupDotOp>(op);
  return isWarpGroupDot;
}

// For a dot-like op, return the accumulator value it consumes and the op
// that defines/owns that use. Returns {nullptr, nullptr} for dot ops that
// do not support the accumulator-init flag.
std::pair<Value, Operation *> getAccumulatorUseAndDef(Operation *op) {
  assert(op->hasTrait<OpTrait::DotLike>() && "Expected a dot-like operation");
  auto wgDotOp = dyn_cast<triton::nvidia_gpu::WarpGroupDotOp>(op);
  if (!wgDotOp)
    return {nullptr, nullptr};
  return {wgDotOp.getC(), wgDotOp};
}

// Attach `useAcc` as the optional $useC operand of the dot op, telling the
// hardware whether the incoming accumulator contents should be read.
// No-op for dot ops without the flag.
void setUseAccFlag(Operation *op, Value useAcc) {
  assert(op->hasTrait<OpTrait::DotLike>() && "Expected a dot-like operation");
  auto wgDotOp = dyn_cast<triton::nvidia_gpu::WarpGroupDotOp>(op);
  if (!wgDotOp)
    return;
  wgDotOp.getUseCMutable().assign(useAcc);
}

// Returns true if `v` is an arith.constant splat of floating-point zero
// (+0.0 or -0.0).
//
// Robustness fix: the accumulator operand is declared TT_FpIntTensor, so an
// integer splat is legal input here; the previous unconditional
// getSplatValue<FloatAttr>() would assert on it. Likewise,
// APFloat::convertToFloat() asserts when the splat is not f32 (e.g. f16 /
// bf16 accumulators); APFloat::isZero() works for any semantics.
// Integer-zero splats still conservatively return false (no behavior change
// for previously-working inputs).
bool isConstantZeroTensor(Value v) {
  auto constOp = v.getDefiningOp<arith::ConstantOp>();
  if (!constOp)
    return false;
  auto splat = mlir::dyn_cast<SplatElementsAttr>(constOp.getValue());
  if (!splat)
    return false;
  auto splatVal = mlir::dyn_cast<FloatAttr>(splat.getSplatValue<Attribute>());
  if (!splatVal)
    return false;
  return splatVal.getValue().isZero();
}

// Given the accumulator operand of a dot op (`accUse`, whose consuming dot op
// is `accDef`) inside loop `forOp`, locate the op that conditionally
// re-initializes the accumulator to zero. Recognized shapes are:
//   - an arith.select with a constant-zero-splat arm, or
//   - an scf.if whose then/else yield is a constant zero splat.
// Returns the zero-initializing op together with the relevant result index
// (always 0 for select), or std::nullopt if no such op is found.
// Sets `loopArgIsZero` to true when the accumulator enters the loop as a
// zero-initialized iter_arg (it is never reset to false here).
std::optional<std::pair<Operation *, int>> findZeroInitOp(Value accUse,
                                                          Operation *accDef,
                                                          scf::ForOp forOp,
                                                          bool &loopArgIsZero) {
  Value v = accUse;
  if (auto arg = dyn_cast<BlockArgument>(v)) {
    assert(arg.getOwner() == forOp.getBody());
    // Region arg 0 of scf.for is the induction variable, so iter_arg N maps
    // to init-arg / yield operand N-1.
    if (isConstantZeroTensor(forOp.getInitArgs()[arg.getArgNumber() - 1])) {
      loopArgIsZero = true;
    }
    // Follow the value carried into the next iteration through the yield.
    v = forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
  }

  auto defOp = v.getDefiningOp();
  if (!defOp) {
    // Accumulator is a plain block argument with no in-loop redefinition.
    return std::nullopt;
  }
  if (auto selOp = dyn_cast<arith::SelectOp>(defOp)) {
    if (isConstantZeroTensor(selOp.getTrueValue()) ||
        isConstantZeroTensor(selOp.getFalseValue())) {
      // arith.select has a single result; index 0.
      return std::make_pair(selOp, 0);
    }
  }
  if (auto ifOp = dyn_cast<scf::IfOp>(defOp)) {
    // Find which scf.if result feeds the accumulator.
    unsigned resultIndex = 0;
    for (; resultIndex < ifOp.getNumResults(); ++resultIndex) {
      if (ifOp.getResult(resultIndex) == v)
        break;
    }
    Value thenVal = ifOp.thenYield()->getOperand(resultIndex);
    Value elseVal = ifOp.elseYield()->getOperand(resultIndex);
    if (isConstantZeroTensor(thenVal) || isConstantZeroTensor(elseVal)) {
      // Make sure that the other value is not defined in the if itself, but
      // passed from outside
      if (thenVal.getParentBlock()->getParentOp() == ifOp ||
          elseVal.getParentBlock()->getParentOp() == ifOp) {
        return std::nullopt;
      }
      return std::make_pair(ifOp, resultIndex);
    }
  }
  return std::nullopt;
}

} // namespace

// Pass that replaces explicit zero-initialization of a dot-op accumulator
// with the hardware accumulator-use flag ($useC): when the flag is false the
// MMA ignores the incoming accumulator, so the zeroing constant/select/if can
// be removed from the loop body.
//
// Change vs. original: dropped the `zeroValue` local, which was computed in
// both branches but never read afterwards.
class OptimizeAccumulatorInitPass
    : public impl::TritonGPUOptimizeAccumulatorInitBase<
          OptimizeAccumulatorInitPass> {
public:
  void runOnOperation() override {
    ModuleOp m = getOperation();
    // Collect dot-like ops that can take an accumulator-init flag.
    SmallVector<Operation *> mmaOps;
    m.walk([&](Operation *op) {
      if (op->hasTrait<OpTrait::DotLike>() && dotSupportsAccInitFlag(op)) {
        mmaOps.push_back(op);
      }
    });

    // for each mma op, find where the accumulator is initialized with zero
    // It can be:
    // 1. A constant zero
    // 2. Initialized with zero as the loop argument
    // 3. Initialized with zero in the if op or with a select op in current
    //    or any of the previous loop iterations
    for (Operation *mmaOp : mmaOps) {
      Location loc = mmaOp->getLoc();

      // Only handle dots directly inside an scf.for.
      scf::ForOp forOp = dyn_cast<scf::ForOp>(mmaOp->getParentOp());
      if (!forOp) {
        continue;
      }

      IRRewriter rewriter(forOp);
      rewriter.setInsertionPoint(forOp);

      // Boolean constants hoisted before the loop; unused ones are dead and
      // expected to be cleaned up by later canonicalization.
      Value vTrue =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(true));
      Value vFalse =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(false));

      // Find the accumulator
      auto [accUse, accDef] = getAccumulatorUseAndDef(mmaOp);
      if (!accUse || !accDef) {
        continue;
      }
      // Trivial case: accumulator is an unconditional zero splat — just tell
      // the MMA not to read it.
      if (isConstantZeroTensor(accUse)) {
        setUseAccFlag(mmaOp, vFalse);
        continue;
      }

      bool loopArgIsZero = false;
      std::optional<std::pair<Operation *, int>> zeroInitOp =
          findZeroInitOp(accUse, accDef, forOp, loopArgIsZero);
      if (!zeroInitOp) {
        continue;
      }

      // Thread the "use accumulator" flag through the loop as a new
      // iter_arg; it starts false iff the accumulator enters the loop as
      // zero.
      Value loopArgFlagValue = loopArgIsZero ? vFalse : vTrue;
      scf::ForOp newForOp =
          replaceForOpWithNewSignature(rewriter, forOp, {loopArgFlagValue});
      forOp.erase();
      forOp = newForOp;
      loopArgFlagValue =
          forOp.getRegionIterArg(forOp.getNumRegionIterArgs() - 1);

      // Extract the zeroing condition and the non-zero ("keep") value from
      // either the select or the scf.if form.
      Value condition = nullptr;
      Value oldValue = nullptr;
      bool thenInitsToZero = false;
      if (auto selOp = dyn_cast<arith::SelectOp>(zeroInitOp->first)) {
        condition = selOp.getCondition();
        oldValue = isConstantZeroTensor(selOp.getTrueValue())
                       ? selOp.getFalseValue()
                       : selOp.getTrueValue();
        thenInitsToZero = isConstantZeroTensor(selOp.getTrueValue());
      } else {
        assert(isa<scf::IfOp>(*zeroInitOp->first) && "Expected an if op");
        auto ifOp = cast<scf::IfOp>(zeroInitOp->first);
        unsigned resultIndex = zeroInitOp->second;
        condition = ifOp.getCondition();
        Value thenVal = ifOp.thenYield()->getOperand(resultIndex);
        Value elseVal = ifOp.elseYield()->getOperand(resultIndex);
        oldValue = isConstantZeroTensor(thenVal) ? elseVal : thenVal;
        thenInitsToZero = isConstantZeroTensor(thenVal);
      }

      // Create a select op that updates the flag
      rewriter.setInsertionPoint(zeroInitOp->first);
      // If zeroing happens before the MMA in the block, this iteration's MMA
      // sees the freshly computed flag; otherwise it sees the flag carried
      // from the previous iteration and the new flag takes effect next time.
      bool zeroingBeforeMMA = zeroInitOp->first->isBeforeInBlock(mmaOp);
      Value prevFlagValue = zeroingBeforeMMA ? loopArgFlagValue : vTrue;
      auto selectFlagOp = rewriter.create<arith::SelectOp>(
          loc, condition, thenInitsToZero ? vFalse : prevFlagValue,
          thenInitsToZero ? prevFlagValue : vFalse);
      setUseAccFlag(mmaOp, zeroingBeforeMMA ? selectFlagOp : loopArgFlagValue);
      auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
      forYield->insertOperands(forYield->getNumOperands(),
                               {zeroingBeforeMMA ? vTrue : selectFlagOp});

      // Stop clearing out the accumulator with zero
      if (auto selOp = dyn_cast<arith::SelectOp>(zeroInitOp->first)) {
        rewriter.setInsertionPoint(selOp);
        rewriter.replaceOp(selOp, oldValue);
      } else {
        auto ifOp = cast<scf::IfOp>(zeroInitOp->first);
        int resultIndex = zeroInitOp->second;
        // Only the zeroing branch changes; the other branch already yields
        // the carried value.
        auto zeroingYield =
            thenInitsToZero ? ifOp.thenYield() : ifOp.elseYield();
        zeroingYield.setOperand(resultIndex, oldValue);
      }
    }
  }
};

} // namespace gpu
} // namespace triton
} // namespace mlir
2 changes: 2 additions & 0 deletions python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ void init_triton_passes_ttgpuir(py::module &&m) {
createAllocateSharedMemoryPass);
ADD_PASS_WRAPPER_0("add_combine_tensor_select_and_if",
createTritonGPUCombineTensorSelectAndIf);
ADD_PASS_WRAPPER_0("add_optimize_accumulator_init",
createTritonGPUOptimizeAccumulatorInit);
}

void init_triton_passes_convert(py::module &&m) {
Expand Down
5 changes: 3 additions & 2 deletions test/Conversion/nvgpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ llvm.func @st_matrix(%i: i32, %ptr: !llvm.ptr<3>) {
// CHECK-LABEL: @wgmma
llvm.func @wgmma(%desc: i64, %in: !struct_64xf32) {
// CHECK: wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2
%acc0 = nvgpu.wgmma %desc, %desc {
%false = llvm.mlir.constant(false) : i1
%acc0 = nvgpu.wgmma %desc, %desc, %false {
eltTypeA = 3 : i32,
eltTypeB = 3 : i32,
eltTypeC = 7 : i32,
Expand All @@ -80,7 +81,7 @@ llvm.func @wgmma(%desc: i64, %in: !struct_64xf32) {
m = 64 : i32,
n = 256 : i32,
k = 32 : i32
} : (i64, i64) -> !struct_128xf32
} : (i64, i64, i1) -> !struct_128xf32

// CHECK: // wait for regs: $0,$1,$2,{{.*}},$127
// CHECK: wgmma.wait_group.sync.aligned 0;
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/tritongpu_to_llvm_hopper.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: @dot_reg_operand_A
// Generate a wgmma where the first operand is a struct.
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !tt.memdesc<64x64xf16, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
Expand All @@ -112,7 +112,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: @dot_reg_operand_A_fp8
// Generate a wgmma where the first operand is a struct.
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32}
tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1>
Expand Down
Loading

0 comments on commit a0c1bc9

Please sign in to comment.