[OPTIMIZER] Separate out kWidth layout optimization from pipelining pass (#1823)

Since the kWidth optimization was happening during software pipelining,
it was skipped whenever pipelining wasn't applied. This change also
improves separation of concerns.
ThomasRaoux committed Jun 23, 2023
1 parent 2eacaff commit 2eb7bc4
Showing 4 changed files with 149 additions and 43 deletions.
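
Background note (not part of the diff): for Ampere MMAv2 dot operands, the kWidth field of the dot-operand layout is tied to the element bit width, and the Pipeline.cpp hunk below recovers a bit width from it as 32 / kWidth; read the other way, the new optimization picks kWidth from the smallest loaded element type. A minimal standalone sketch of that relationship, with illustrative names only:

#include <cassert>
#include <cstdio>

// Illustrative sketch only: mirrors the MMAv2 relation used in the
// Pipeline.cpp hunk below (bitWidth = 32 / kWidth), read in reverse.
static unsigned kWidthForBitWidth(unsigned smallestBitWidth) {
  assert(smallestBitWidth > 0 && 32 % smallestBitWidth == 0);
  return 32 / smallestBitWidth;
}

int main() {
  std::printf("32-bit (tf32)  -> kWidth %u\n", kWidthForBitWidth(32)); // 1
  std::printf("16-bit (f16)   -> kWidth %u\n", kWidthForBitWidth(16)); // 2
  std::printf("8-bit (fp8/i8) -> kWidth %u\n", kWidthForBitWidth(8));  // 4
  return 0;
}

The updated lit test at the bottom reflects this: the all-f16 dot now expects kWidth = 2, while the mixed fp8/f16 case expects kWidth = 4 on both operands.
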
16 changes: 12 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -77,7 +77,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"Type":$eltTy), [{
"unsigned":$typeWidthInBit), [{
auto mmaEnc = dotOpEnc.getParent().dyn_cast<MmaEncodingAttr>();

if(!mmaEnc)
@@ -87,7 +87,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
int opIdx = dotOpEnc.getOpIdx();

// number of rows per phase
int perPhase = 128 / (shape[order[0]] * (eltTy.getIntOrFloatBitWidth() / 8));
int perPhase = 128 / (shape[order[0]] * (typeWidthInBit / 8));
perPhase = std::max<int>(perPhase, 1);

// index of the inner dimension in `order`
@@ -109,9 +109,9 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
// ---- begin Ampere ----
if (mmaEnc.isAmpere()) {
std::vector<size_t> matShape = {8, 8,
2 * 64 / eltTy.getIntOrFloatBitWidth()};
2 * 64 / typeWidthInBit};
// for now, disable swizzle when using transposed int8 tensor cores
if (eltTy.isInteger(8) && order[0] == inner)
if (typeWidthInBit == 8 && order[0] == inner)
return $_get(context, 1, 1, 1, order);

// --- handle A operand ---
@@ -135,6 +135,14 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /

// ---- not implemented ----
llvm_unreachable("unsupported swizzling for provided MMA version");
}]>,

AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
"ArrayRef<int64_t>":$shape,
"ArrayRef<unsigned>":$order,
"Type":$eltTy), [{
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
return get(context, dotOpEnc, shape, order, bitwidth);
}]>
];

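
To make the swizzling arithmetic in the builder above concrete, here is a standalone sketch of the rows-per-phase computation; the tile sizes are made-up examples, not values taken from the diff:

#include <algorithm>
#include <cstdio>

// Re-statement of the builder's computation above:
// perPhase = max(1, 128 / (innerDimSize * bytesPerElement)).
static int rowsPerPhase(int innerDimSize, unsigned typeWidthInBit) {
  int perPhase = 128 / (innerDimSize * (typeWidthInBit / 8));
  return std::max(perPhase, 1);
}

int main() {
  std::printf("f16, inner dim 16  -> perPhase %d\n", rowsPerPhase(16, 16));  // 4
  std::printf("i8,  inner dim 16  -> perPhase %d\n", rowsPerPhase(16, 8));   // 8
  std::printf("f16, inner dim 128 -> perPhase %d\n", rowsPerPhase(128, 16)); // clamped to 1
  return 0;
}

In this file the width now arrives as an integer (typeWidthInBit) instead of an mlir::Type, and the int8 special case is keyed on an 8-bit width rather than on the integer type; the original Type-based builder is kept and simply forwards the bit width.
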
118 changes: 118 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -14,6 +14,7 @@ using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SharedEncodingAttr;
using triton::gpu::SliceEncodingAttr;

// convert(trans(convert(arg)))
Expand Down Expand Up @@ -200,6 +201,119 @@ class MoveOpAfterLayoutConversion : public mlir::RewritePattern {

} // namespace

static bool isConvertToDotEncoding(Operation *op) {
auto convertLayout = llvm::dyn_cast<ConvertLayoutOp>(op);
if (!convertLayout)
return false;
auto tensorType =
convertLayout.getResult().getType().cast<RankedTensorType>();
return tensorType.getEncoding().isa<DotOperandEncodingAttr>();
}

static ConvertLayoutOp updateConvert(OpBuilder &builder, ConvertLayoutOp cvt,
IRMapping &mapping, Type smallestType) {
auto cvtDstTy = cvt.getResult().getType().cast<RankedTensorType>();
auto cvtDstEnc = cvtDstTy.getEncoding().cast<DotOperandEncodingAttr>();
Value operand = cvt.getOperand();
if (mapping.contains(operand))
operand = mapping.lookup(operand);
auto newDstTy = RankedTensorType::get(
cvtDstTy.getShape(), cvtDstTy.getElementType(),
DotOperandEncodingAttr::get(cvtDstEnc.getContext(), cvtDstEnc.getOpIdx(),
cvtDstEnc.getParent(), smallestType));
auto newCvt =
builder.create<ConvertLayoutOp>(cvt.getLoc(), newDstTy, operand);
mapping.map(cvt.getResult(), newCvt.getResult());
return newCvt;
}

// Update kWidth based on the smallestType found in the given convert ops and
// propagate the type change.
static void
updateDotEncodingLayout(SmallVector<ConvertLayoutOp> &convertsToDotEncoding,
Type smallestType) {
IRMapping mapping;
OpBuilder builder(smallestType.getContext());
SetVector<Operation *> slices(convertsToDotEncoding.begin(),
convertsToDotEncoding.end());
// Collect all the operations where the type needs to be propagated.
for (auto cvt : convertsToDotEncoding) {
auto filter = [&](Operation *op) {
for (Value operand : op->getOperands()) {
auto tensorType = operand.getType().dyn_cast<RankedTensorType>();
if (tensorType &&
tensorType.getEncoding().isa<DotOperandEncodingAttr>())
return true;
}
return false;
};
mlir::getForwardSlice(cvt.getResult(), &slices, {filter});
}
// Apply the type change by walking ops in topological order.
slices = mlir::topologicalSort(slices);
for (Operation *op : slices) {
builder.setInsertionPoint(op);
if (isConvertToDotEncoding(op)) {
auto cvt = cast<ConvertLayoutOp>(op);
ConvertLayoutOp newCvt =
updateConvert(builder, cvt, mapping, smallestType);
continue;
}
auto *newOp = cloneWithInferType(builder, op, mapping);
for (auto [result, newResult] :
llvm::zip(op->getResults(), newOp->getResults())) {
result.replaceUsesWithIf(newResult, [&](OpOperand &operand) {
return slices.count(operand.getOwner()) == 0;
});
}
}
for (Operation *op : llvm::reverse(slices))
op->erase();
}

// Change the layout of dotOperand layout to use the kWidth from the smallest
// loaded type. This allows better code generation for mixed-mode matmul.
static void optimizeKWidth(triton::FuncOp func) {
SmallVector<ConvertLayoutOp> convertsToDotEncoding;
Type smallestType;
func->walk([&](triton::LoadOp loadOp) {
if (!loadOp.getResult().hasOneUse())
return;
Operation *use = *loadOp.getResult().getUsers().begin();

// Advance to the first conversion as long as the use resides in shared
// memory and it has a single use itself
while (use) {
if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse())
break;
auto tensorType =
use->getResult(0).getType().dyn_cast<RankedTensorType>();
if (!tensorType || !tensorType.getEncoding().isa<SharedEncodingAttr>())
break;
use = *use->getResult(0).getUsers().begin();
}

auto convertLayout = llvm::dyn_cast<ConvertLayoutOp>(use);
if (!convertLayout)
return;
auto tensorType =
convertLayout.getResult().getType().cast<RankedTensorType>();
if (!tensorType.getEncoding().isa<DotOperandEncodingAttr>())
return;
convertsToDotEncoding.push_back(convertLayout);

// Update the smallest type.
auto ty = loadOp.getType().cast<RankedTensorType>();
Type eltTy = ty.getElementType();
if (!smallestType ||
(eltTy.getIntOrFloatBitWidth() < smallestType.getIntOrFloatBitWidth()))
smallestType = eltTy;
});
if (!smallestType)
return;
updateDotEncodingLayout(convertsToDotEncoding, smallestType);
}

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

@@ -224,6 +338,10 @@ class TritonGPUOptimizeDotOperandsPass
signalPassFailure();
if (fixupLoops(m).failed())
signalPassFailure();

// Change the layout of dotOperand layout to use the kWidth from the
// smallest loaded type.
m->walk([](triton::FuncOp func) { optimizeKWidth(func); });
}
};

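
The new optimizeKWidth walk above implements a simple rule: every convert into a dot-operand layout adopts the kWidth implied by the smallest element type among the single-use loads feeding the dots, so mixed-precision operands agree on one layout. A hedged standalone sketch of just that rule (illustrative names, not the pass API):

#include <algorithm>
#include <cstdio>
#include <vector>

// Sketch: pick one kWidth for all dot operands from the smallest loaded
// element width, mirroring the MMAv2 relation kWidth = 32 / bitWidth.
static unsigned commonKWidth(const std::vector<unsigned> &loadedBitWidths) {
  unsigned smallest =
      *std::min_element(loadedBitWidths.begin(), loadedBitWidths.end());
  return 32 / smallest;
}

int main() {
  // An fp8 A operand mixed with an f16 B operand -> both use kWidth = 4,
  // matching the updated push_elementwise1 expectations below.
  std::printf("kWidth = %u\n", commonKWidth({8, 16}));
  return 0;
}
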
38 changes: 8 additions & 30 deletions lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -111,9 +111,6 @@ class LoopPipeliner {

/// Loads to be pipelined
SetVector<Value> validLoads;
/// Smallest data-type for each load (used to optimize swizzle and
/// (create DotOpEncoding layout)
DenseMap<Value, Type> loadsSmallestType;
/// The value that each load will be mapped to (after layout conversion)
DenseMap<Value, Value> loadsMapping;
/// load => buffer
@@ -485,21 +482,6 @@ Value LoopPipeliner::lookupOrDefault(Value origin, int stage) {
}

void LoopPipeliner::createBufferTypes() {
// We need to find the smallest common dtype since this determines the layout
// of `mma.sync` operands in mixed-precision mode
Type smallestType;
for (auto loadCvt : loadsMapping) {
auto loadOp = loadCvt.first;
auto ty = loadOp.getType().cast<RankedTensorType>();
Type eltTy = ty.getElementType();
if (!smallestType ||
(eltTy.getIntOrFloatBitWidth() < smallestType.getIntOrFloatBitWidth()))
smallestType = eltTy;
}

for (auto loadCvt : loadsMapping)
loadsSmallestType[loadCvt.first] = smallestType;

for (auto loadCvt : loadsMapping) {
auto loadOp = loadCvt.first;
Value cvt = loadCvt.second;
@@ -511,9 +493,12 @@ void LoopPipeliner::createBufferTypes() {
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
ttg::getOrder(ty.getEncoding()), loadsSmallestType[loadOp]);
unsigned bitWidth = dotOpEnc.getMMAv2kWidth()
? 32 / dotOpEnc.getMMAv2kWidth()
: ty.getElementType().getIntOrFloatBitWidth();
auto sharedEnc =
ttg::SharedEncodingAttr::get(ty.getContext(), dotOpEnc, ty.getShape(),
ttg::getOrder(ty.getEncoding()), bitWidth);
loadsBufferType[loadOp] =
RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc);
}
@@ -789,19 +774,12 @@ scf::ForOp LoopPipeliner::cloneForOp(ArrayRef<Value> newLoopArgs,
// we replace the use new load use with a convert layout
size_t i = std::distance(validLoads.begin(), it);
auto cvtDstTy = op.getResult(0).getType().cast<RankedTensorType>();
auto cvtDstEnc =
cvtDstTy.getEncoding().dyn_cast<ttg::DotOperandEncodingAttr>();
if (!cvtDstEnc) {
if (!cvtDstTy.getEncoding().isa<ttg::DotOperandEncodingAttr>()) {
builder.clone(op, mapping);
continue;
}
auto newDstTy = RankedTensorType::get(
cvtDstTy.getShape(), cvtDstTy.getElementType(),
ttg::DotOperandEncodingAttr::get(
cvtDstEnc.getContext(), cvtDstEnc.getOpIdx(), cvtDstEnc.getParent(),
loadsSmallestType[op.getOperand(0)]));
auto cvt = builder.create<ttg::ConvertLayoutOp>(
op.getResult(0).getLoc(), newDstTy,
op.getResult(0).getLoc(), cvtDstTy,
newForOp.getRegionIterArgs()[loadIdx + i]);
mapping.map(op.getResult(0), cvt.getResult());
}
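
With the per-load smallest-type bookkeeping removed from the pipeliner, createBufferTypes above derives the width used for the shared-memory swizzle from the dot-operand encoding itself: 32 / kWidth when an MMAv2 kWidth is present, otherwise the tensor's element width. A standalone sketch of that fallback, with illustrative names:

#include <cstdio>

// Sketch of the bit-width selection in createBufferTypes above:
// prefer 32 / kWidth (MMAv2), otherwise fall back to the element width.
static unsigned bufferBitWidth(unsigned mmaV2KWidth, unsigned elementBitWidth) {
  return mmaV2KWidth ? 32 / mmaV2KWidth : elementBitWidth;
}

int main() {
  // An f16 tensor whose layout was narrowed to kWidth = 4 by a mixed fp8 load:
  std::printf("bitWidth = %u\n", bufferBitWidth(4, 16)); // 8
  // No MMAv2 kWidth available -> use the element type's own width:
  std::printf("bitWidth = %u\n", bufferBitWidth(0, 16)); // 16
  return 0;
}
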
20 changes: 11 additions & 9 deletions test/TritonGPU/dot-operands.mlir
@@ -15,10 +15,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {

// CHECK: tt.func @push_elementwise1
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]]
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]]
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
// CHECK: %[[C:.*]] = tt.dot %[[AF16]]
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] {{.*}} #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]] {{.*}} #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]] {{.*}} #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %{{.*}} : {{.*}} tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
// CHECK: %[[C:.*]] = tt.dot %[[AF16]], %[[BCVT]]
// CHECK-SAME: tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x16xf32, #mma>
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
tt.func @push_elementwise1(
%pa: tensor<16x16x!tt.ptr<i8>, #ALR> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
@@ -161,12 +163,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {

// CHECK: tt.func @push_convert_both_operands
// CHECK: %[[ALOAD:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #[[BA]]>
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] : (tensor<16x16xf16, #[[BA]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 1}>>
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]] : (tensor<16x16xf16, #[[BA]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>>
// CHECK: %[[BLOAD:.*]] = tt.load %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #[[BB]]>
// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : (tensor<16x16xf16, #[[BB]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 1}>>
// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 1}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 1}>>
// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 1}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 1}>>
// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 1}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 1}>> -> tensor<16x16xf32, #mma>
// CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : (tensor<16x16xf16, #[[BB]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
// CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>>
// CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>>
// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma>
tt.func @push_convert_both_operands(
%pa: tensor<16x16x!tt.ptr<f16>, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
