Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BE] Accumulator init optimization #4680

Merged
merged 8 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/triton/Dialect/Triton/IR/Traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ template <class ConcreteType>
class DotLike : public TraitBase<ConcreteType, DotLike> {
public:
static LogicalResult verifyTrait(Operation *op) {
if (op->getNumOperands() != 3)
return op->emitOpError("expected 3 operands");
if (op->getNumOperands() < 3)
ThomasRaoux marked this conversation as resolved.
Show resolved Hide resolved
return op->emitOpError("expected at least 3 operands");
auto aTy = cast<TensorOrMemDesc>(op->getOperand(0).getType());
auto bTy = cast<TensorOrMemDesc>(op->getOperand(1).getType());
auto cTy = cast<TensorOrMemDesc>(op->getOperand(2).getType());
Expand Down
10 changes: 10 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,14 @@ def TritonGPUCombineTensorSelectAndIf: Pass<"tritongpu-combine-tensor-select-and
"mlir::triton::TritonDialect"];
}

// Replaces zero-initialization of dot accumulators with the accumulator-use
// flag on ops that support it (e.g. Hopper warp-group dot).
def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init", "mlir::ModuleOp"> {
  // Fixed typo: "accumulater" -> "accumulator".
  let summary = "Replace accumulator zero-initialization with the flag indicating first use of the accumulator";

  let description = "For the dot operations that support accumulator-use flag this pass replaces the zero-initialization "
                    "of the accumulator with the flag indicating the first use of the accumulator.";

  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::TritonDialect"];
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,14 @@ def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods<I
let arguments = (ins TT_TensorOrMemDesc:$a,
TT_TensorOrMemDesc:$b,
TT_FpIntTensor:$c,
Optional<I1>:$useC,
DefaultValuedAttr<TT_InputPrecisionAttr, "::mlir::triton::InputPrecision::IEEE">:$inputPrecision,
DefaultValuedAttr<I32Attr, "0">:$maxNumImpreciseAcc,
DefaultValuedAttr<BoolAttr, "false">:$isAsync);

let results = (outs TT_FpIntTensor:$d);

let assemblyFormat = "$a`,` $b`,` $c attr-dict `:` type($a) `*` type($b) `->` type($d)";
let assemblyFormat = "$a`,` $b`,` $c (`,` $useC^)? attr-dict `:` type($a) `*` type($b) `->` type($d)";
}

def TTNG_WarpGroupDotWaitOp : TTNG_Op<"warp_group_dot_wait", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose);
b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose);
newDot = rewriter.create<triton::nvidia_gpu::WarpGroupDotOp>(
dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getInputPrecision(),
dotOp.getMaxNumImpreciseAcc(), false);
dotOp.getLoc(), newRetType, a, b, newAcc, nullptr,
dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc(), false);
} else {
// convert operands
int minBitwidth =
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_triton_library(TritonGPUTransforms
F32DotTC.cpp
CombineTensorSelectAndIf.cpp
ReduceDataDuplication.cpp
OptimizeAccumulatorInit.cpp
OptimizeDotOperands.cpp
OptimizeThreadLocality.cpp
Pipeliner/MatmulLoopPipeline.cpp
Expand Down
205 changes: 205 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

namespace mlir {
namespace triton {
namespace gpu {

#define GEN_PASS_DEF_TRITONGPUOPTIMIZEACCUMULATORINIT
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

namespace {
// Returns true when the given dot-like operation supports an explicit
// accumulator-use flag (currently only the Hopper warp-group dot does).
bool dotSupportsAccInitFlag(Operation *op) {
  assert(op->hasTrait<OpTrait::DotLike>() && "Expected a dot-like operation");
  if (isa<triton::nvidia_gpu::WarpGroupDotOp>(op))
    return true;
  return false;
}

// For a dot-like op, returns the value feeding the accumulator operand ($c)
// together with the op itself; {nullptr, nullptr} for unsupported ops.
std::pair<Value, Operation *> getAccumulatorUseAndDef(Operation *op) {
  assert(op->hasTrait<OpTrait::DotLike>() && "Expected a dot-like operation");
  auto wgDotOp = dyn_cast<triton::nvidia_gpu::WarpGroupDotOp>(op);
  if (!wgDotOp)
    return {nullptr, nullptr};
  return {wgDotOp.getC(), wgDotOp};
}

// Attaches `useAcc` as the accumulator-use flag on ops that support it;
// silently a no-op for any other dot-like op.
void setUseAccFlag(Operation *op, Value useAcc) {
  assert(op->hasTrait<OpTrait::DotLike>() && "Expected a dot-like operation");
  auto wgDotOp = dyn_cast<triton::nvidia_gpu::WarpGroupDotOp>(op);
  if (wgDotOp)
    wgDotOp.getUseCMutable().assign(useAcc);
}

// Returns true if `v` is produced by an arith.constant whose value is a
// splat of zeros.
//
// Fix: the original read the splat unconditionally as a FloatAttr
// (getSplatValue<FloatAttr>()), which fails a cast assertion for integer
// element types — and the WarpGroupDotOp accumulator is TT_FpIntTensor, so
// integer accumulators are legal. It also compared via
// convertToFloat() == 0.0f, which is only valid for single-precision
// semantics. Use APFloat/APInt::isZero() and handle both attribute kinds.
bool isConstantZeroTensor(Value v) {
  auto constOp = v.getDefiningOp<arith::ConstantOp>();
  if (!constOp)
    return false;
  auto splat = mlir::dyn_cast<SplatElementsAttr>(constOp.getValue());
  if (!splat)
    return false;
  Attribute splatVal = splat.getSplatValue<Attribute>();
  if (auto floatVal = mlir::dyn_cast<FloatAttr>(splatVal))
    return floatVal.getValue().isZero();
  if (auto intVal = mlir::dyn_cast<IntegerAttr>(splatVal))
    return intVal.getValue().isZero();
  // Unknown element attribute kind — conservatively not a zero tensor.
  return false;
}

// Walks backwards from the accumulator operand `accUse` and looks for the
// operation that conditionally (re-)initializes it to zero: either an
// arith.select with a constant-zero branch, or an scf.if yielding a
// constant zero on one side. Returns {op, result index} of that zeroing op,
// or std::nullopt if none is found. Sets `loopArgIsZero` to true when the
// accumulator enters the loop as a zero-initialized iter_arg.
// NOTE(review): `accDef` is currently unused in this function.
std::optional<std::pair<Operation *, int>> findZeroInitOp(Value accUse,
                                                          Operation *accDef,
                                                          scf::ForOp forOp,
                                                          bool &loopArgIsZero) {
  Value v = accUse;
  if (auto arg = dyn_cast<BlockArgument>(v)) {
    assert(arg.getOwner() == forOp.getBody());
    // getArgNumber() - 1 skips the induction variable when indexing the
    // for-loop init args / yield operands.
    if (isConstantZeroTensor(forOp.getInitArgs()[arg.getArgNumber() - 1])) {
      loopArgIsZero = true;
    }
    // Follow the loop-carried value: the accumulator yielded by the
    // previous iteration.
    v = forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
  }

  auto defOp = v.getDefiningOp();
  if (!defOp) {
    return std::nullopt;
  }
  if (auto selOp = dyn_cast<arith::SelectOp>(defOp)) {
    // select %cond, %zero, %acc (or the mirrored form) — single result,
    // hence result index 0.
    if (isConstantZeroTensor(selOp.getTrueValue()) ||
        isConstantZeroTensor(selOp.getFalseValue())) {
      return std::make_pair(selOp, 0);
    }
  }
  if (auto ifOp = dyn_cast<scf::IfOp>(defOp)) {
    // Find which scf.if result feeds the accumulator.
    unsigned resultIndex = 0;
    for (; resultIndex < ifOp.getNumResults(); ++resultIndex) {
      if (ifOp.getResult(resultIndex) == v)
        break;
    }
    Value thenVal = ifOp.thenYield()->getOperand(resultIndex);
    Value elseVal = ifOp.elseYield()->getOperand(resultIndex);
    if (isConstantZeroTensor(thenVal) || isConstantZeroTensor(elseVal)) {
      // Make sure that the other value is not defined in the if itself, but
      // passed from outside
      if (thenVal.getParentBlock()->getParentOp() == ifOp ||
          elseVal.getParentBlock()->getParentOp() == ifOp) {
        return std::nullopt;
      }
      return std::make_pair(ifOp, resultIndex);
    }
  }
  return std::nullopt;
}

} // namespace

// Pass that replaces zero-initialization of dot accumulators with the
// accumulator-use (useC) flag on dot ops that support it, threading the
// flag through the enclosing scf.for as an extra loop-carried value.
class OptimizeAccumulatorInitPass
    : public impl::TritonGPUOptimizeAccumulatorInitBase<
          OptimizeAccumulatorInitPass> {
public:
  void runOnOperation() override {
    ModuleOp m = getOperation();
    // Collect every dot-like op that can take the accumulator-use flag.
    SmallVector<Operation *> mmaOps;
    m.walk([&](Operation *op) {
      if (op->hasTrait<OpTrait::DotLike>() && dotSupportsAccInitFlag(op)) {
        mmaOps.push_back(op);
      }
    });

    // for each mma op, find where the accumulator is initialized with zero
    // It can be:
    // 1. A constant zero
    // 2. Initialized with zero as the loop argument
    // 3. Initialized with zero in the if op or with a select op in current
    //    or any of the previous loop iterations
    for (Operation *mmaOp : mmaOps) {
      Location loc = mmaOp->getLoc();

      // Only handle mma ops that sit directly inside an scf.for.
      scf::ForOp forOp = dyn_cast<scf::ForOp>(mmaOp->getParentOp());
      if (!forOp) {
        continue;
      }

      IRRewriter rewriter(forOp);
      rewriter.setInsertionPoint(forOp);

      // Boolean constants hoisted above the loop.
      // NOTE(review): these are created before the early-continue checks
      // below, so they may be left as dead constants when the op is skipped.
      Value vTrue =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(true));
      Value vFalse =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(false));

      // Find the accumulator
      auto [accUse, accDef] = getAccumulatorUseAndDef(mmaOp);
      if (!accUse || !accDef) {
        continue;
      }
      // Case 1: accumulator is a plain constant zero — just clear the flag.
      if (isConstantZeroTensor(accUse)) {
        setUseAccFlag(mmaOp, vFalse);
        continue;
      }

      // Cases 2/3: locate the conditional zeroing (select or scf.if), and
      // whether the loop's init value for the accumulator is zero.
      bool loopArgIsZero = false;
      std::optional<std::pair<Operation *, int>> zeroInitOp =
          findZeroInitOp(accUse, accDef, forOp, loopArgIsZero);
      if (!zeroInitOp) {
        continue;
      }

      // Thread the flag through the loop as a new iter_arg, seeded with
      // false when the accumulator starts out zeroed.
      Value loopArgFlagValue = loopArgIsZero ? vFalse : vTrue;
      scf::ForOp newForOp =
          replaceForOpWithNewSignature(rewriter, forOp, {loopArgFlagValue});
      forOp.erase();
      forOp = newForOp;
      loopArgFlagValue =
          forOp.getRegionIterArg(forOp.getNumRegionIterArgs() - 1);

      // Extract, from the zeroing op, the condition under which the
      // accumulator is zeroed, the surviving (non-zero) value, and which
      // branch held the zero.
      Value condition = nullptr;
      Value oldValue = nullptr;
      Value zeroValue = nullptr;
      bool thenInitsToZero = false;
      if (auto selOp = dyn_cast<arith::SelectOp>(zeroInitOp->first)) {
        condition = selOp.getCondition();
        oldValue = isConstantZeroTensor(selOp.getTrueValue())
                       ? selOp.getFalseValue()
                       : selOp.getTrueValue();
        zeroValue = isConstantZeroTensor(selOp.getTrueValue())
                        ? selOp.getTrueValue()
                        : selOp.getFalseValue();
        thenInitsToZero = isConstantZeroTensor(selOp.getTrueValue());
      } else {
        assert(isa<scf::IfOp>(*zeroInitOp->first) && "Expected an if op");
        auto ifOp = cast<scf::IfOp>(zeroInitOp->first);
        unsigned resultIndex = zeroInitOp->second;
        condition = ifOp.getCondition();
        Value thenVal = ifOp.thenYield()->getOperand(resultIndex);
        Value elseVal = ifOp.elseYield()->getOperand(resultIndex);
        oldValue = isConstantZeroTensor(thenVal) ? elseVal : thenVal;
        zeroValue = isConstantZeroTensor(thenVal) ? thenVal : elseVal;
        thenInitsToZero = isConstantZeroTensor(thenVal);
      }

      // Create a select op that updates the flag
      // If the zeroing happens before the mma in the same iteration, the
      // mma reads the freshly-computed flag; otherwise it reads the
      // loop-carried flag and the select feeds the next iteration.
      rewriter.setInsertionPoint(zeroInitOp->first);
      bool zeroingBeforeMMA = zeroInitOp->first->isBeforeInBlock(mmaOp);
      Value prevFlagValue = zeroingBeforeMMA ? loopArgFlagValue : vTrue;
      auto selectFlagOp = rewriter.create<arith::SelectOp>(
          loc, condition, thenInitsToZero ? vFalse : prevFlagValue,
          thenInitsToZero ? prevFlagValue : vFalse);
      setUseAccFlag(mmaOp, zeroingBeforeMMA ? selectFlagOp : loopArgFlagValue);
      auto forYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
      forYield->insertOperands(forYield->getNumOperands(),
                               {zeroingBeforeMMA ? vTrue : selectFlagOp});

      // Stop clearing out the accumulator with zero
      if (auto selOp = dyn_cast<arith::SelectOp>(zeroInitOp->first)) {
        rewriter.setInsertionPoint(selOp);
        rewriter.replaceOp(selOp, oldValue);
      } else {
        // For scf.if, keep the op but yield the surviving value on the
        // branch that used to yield zero.
        auto ifOp = cast<scf::IfOp>(zeroInitOp->first);
        int resultIndex = zeroInitOp->second;
        auto zeroingYield =
            thenInitsToZero ? ifOp.thenYield() : ifOp.elseYield();
        zeroingYield.setOperand(resultIndex, oldValue);
      }
    }
  }
};

} // namespace gpu
} // namespace triton
} // namespace mlir
2 changes: 2 additions & 0 deletions python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ void init_triton_passes_ttgpuir(py::module &&m) {
createAllocateSharedMemoryPass);
ADD_PASS_WRAPPER_0("add_combine_tensor_select_and_if",
createTritonGPUCombineTensorSelectAndIf);
ADD_PASS_WRAPPER_0("add_optimize_accumulator_init",
createTritonGPUOptimizeAccumulatorInit);
}

void init_triton_passes_convert(py::module &&m) {
Expand Down
5 changes: 3 additions & 2 deletions test/Conversion/nvgpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ llvm.func @st_matrix(%i: i32, %ptr: !llvm.ptr<3>) {
// CHECK-LABEL: @wgmma
llvm.func @wgmma(%desc: i64, %in: !struct_64xf32) {
// CHECK: wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2
%acc0 = nvgpu.wgmma %desc, %desc {
%false = llvm.mlir.constant(false) : i1
%acc0 = nvgpu.wgmma %desc, %desc, %false {
eltTypeA = 3 : i32,
eltTypeB = 3 : i32,
eltTypeC = 7 : i32,
Expand All @@ -80,7 +81,7 @@ llvm.func @wgmma(%desc: i64, %in: !struct_64xf32) {
m = 64 : i32,
n = 256 : i32,
k = 32 : i32
} : (i64, i64) -> !struct_128xf32
} : (i64, i64, i1) -> !struct_128xf32

// CHECK: // wait for regs: $0,$1,$2,{{.*}},$127
// CHECK: wgmma.wait_group.sync.aligned 0;
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/tritongpu_to_llvm_hopper.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: @dot_reg_operand_A
// Generate a wgmma where the first operand is a struct.
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: !tt.memdesc<64x64xf16, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
Expand All @@ -112,7 +112,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: @dot_reg_operand_A_fp8
// Generate a wgmma where the first operand is a struct.
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK: nvgpu.wgmma_wait_group %{{.*}} {pendings = 0 : i32}
tt.func @dot_reg_operand_A_fp8(%a: tensor<128x128xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>, %b: !tt.memdesc<128x256xf8E5M2, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma1>
Expand Down
Loading
Loading