Commit
[MFMA] Switch between MFMA types
This PR introduces the matrix_instr_nonkdim flag to switch
between MFMA 16 and MFMA 32.
binarman committed Oct 12, 2023
1 parent 821e75a commit f43b54e
Showing 14 changed files with 350 additions and 248 deletions.
2 changes: 1 addition & 1 deletion include/triton/Analysis/Utility.h
@@ -115,7 +115,7 @@ bool maybeSharedAllocationOp(Operation *op);
bool maybeAliasOp(Operation *op);

#ifdef USE_ROCM
bool supportMFMA(triton::DotOp op, int64_t nonKDim);
bool supportMFMA(triton::DotOp op);
#endif

bool supportMMA(triton::DotOp op, int version);
4 changes: 4 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.h
@@ -16,6 +16,10 @@ std::unique_ptr<Pass> createTritonGPUStreamPipelinePass();
std::unique_ptr<Pass>
createTritonGPUAccelerateMatmulPass(int computeCapability = 80);

std::unique_ptr<Pass>
createTritonAMDGPUAccelerateMatmulPass(int matrixCoreVersion = 0,
int matrixInstructionSize = 0);

std::unique_ptr<Pass> createTritonGPUPrefetchPass();

std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
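For reference, a minimal sketch (not part of this commit) of how the new entry point might be wired into a pass pipeline; the `PassManager` setup around it is assumed:

#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

// matrixCoreVersion selects the MFMA generation (the pass only adds its
// rewrite pattern for versions 1 and 2); matrixInstructionSize = 0 lets the
// pass pick MFMA 16 or 32 per dot shape, while 16 or 32 enforces that size.
void buildAMDMatmulPipeline(mlir::PassManager &pm, int matrixCoreVersion,
                            int matrixInstructionSize) {
  pm.addPass(mlir::createTritonAMDGPUAccelerateMatmulPass(
      matrixCoreVersion, matrixInstructionSize));
}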
23 changes: 23 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -85,6 +85,29 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul
];
}

def TritonAMDGPUAccelerateMatmul : Pass<"tritonamdgpu-accelerate-matmul", "mlir::ModuleOp"> {
let summary = "accelerate matmul";

let description = [{
Optimize the input/output layouts of the `dot` instruction to make them
compatible with hardware accelerators (e.g., AMD matrix cores).
}];

let constructor = "mlir::createTritonAMDGPUAccelerateMatmulPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];

let options = [
Option<"matrixCoreVersion", "matrix-core-version",
"int32_t", /*default*/"0",
"device matrix core version">,
Option<"matrixInstructionSize", "matrix-instructio-size",
"int32_t", /*default*/"0",
"enforce matrix intrucion MN size">
];
}

def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> {
let summary = "fuse transpositions";

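Assuming the pass is registered, the options above should also be reachable through MLIR's textual pipeline syntax; a hedged sketch using the option names as defined in this file:

#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

// Parse a textual pipeline that runs the new pass with MFMA 16 enforced on
// a matrix-core-version 2 device; returns failure if parsing fails.
mlir::FailureOr<mlir::OpPassManager> buildFromSpec() {
  return mlir::parsePassPipeline(
      "builtin.module(tritonamdgpu-accelerate-matmul{"
      "matrix-core-version=2 matrix-instruction-size=16})");
}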
23 changes: 13 additions & 10 deletions lib/Analysis/Utility.cpp
@@ -378,18 +378,21 @@ bool supportMMA(triton::DotOp op, int version) {
}

#ifdef USE_ROCM
static bool supportMFMAGranularity(int m, int n, int k, int64_t nonKDim) {
static bool supportMFMAGranularity(int m, int n, int k) {
// these limitations are dtype dependent, in future we may relax them
const int granularityMN = nonKDim;
const int granularityK = nonKDim == 32 ? 8 : 16;
if (m % granularityMN != 0 || n % granularityMN != 0)
return false;
if (k % granularityK != 0)
return false;
return true;
const static std::pair<int, int> mfmaTypes[2] = {{32, 8}, {16, 16}};
for (const auto &mfmaType : mfmaTypes) {
auto [granularityMN, granularityK] = mfmaType;
if (m % granularityMN != 0 || n % granularityMN != 0)
continue;
if (k % granularityK != 0)
continue;
return true;
}
return false;
}

bool supportMFMA(triton::DotOp op, int64_t nonKDim) {
bool supportMFMA(triton::DotOp op) {
auto aTy = op.getA().getType().cast<RankedTensorType>();
auto bTy = op.getB().getType().cast<RankedTensorType>();

@@ -403,7 +406,7 @@ bool supportMFMA(triton::DotOp op, int64_t nonKDim) {
auto bShape = bTy.getShape();

assert(aShape[1] == bShape[0]);
if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1], nonKDim))
if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1]))
return false;

return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() ||
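The relaxed check means a dot now qualifies if its shape fits either MFMA tile. A small standalone sketch (names assumed, not from the diff) of the same logic, with a few shapes worked out:

#include <cassert>
#include <utility>

static bool supportsEitherMfmaTile(int m, int n, int k) {
  // {nonKDim, kDim} granularity pairs: 32x32 needs k % 8 == 0,
  // 16x16 needs k % 16 == 0, mirroring supportMFMAGranularity above
  static const std::pair<int, int> mfmaTypes[2] = {{32, 8}, {16, 16}};
  for (const auto &[granularityMN, granularityK] : mfmaTypes)
    if (m % granularityMN == 0 && n % granularityMN == 0 &&
        k % granularityK == 0)
      return true;
  return false;
}

int main() {
  assert(supportsEitherMfmaTile(32, 32, 8));  // fits the 32x32 tile
  assert(supportsEitherMfmaTile(16, 16, 16)); // fits only the 16x16 tile
  assert(!supportsEitherMfmaTile(16, 16, 8)); // k too small for 16x16
}

With the old signature, the 16x16x16 case would have been rejected whenever nonKDim was fixed at 32.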
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp
@@ -82,7 +82,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
.cast<RankedTensorType>()
.getEncoding()
.dyn_cast<MfmaEncodingAttr>();
if (!isOuter && mfmaLayout && supportMFMA(op, mfmaLayout.getNonKDim())) {
if (!isOuter && mfmaLayout && supportMFMA(op)) {
return convertMFMA(op, adaptor, getTypeConverter(), rewriter);
}
#endif
9 changes: 5 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -436,10 +436,11 @@ struct ReduceOpConversion
inputTy.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
if (inMfma && inMfma.getIsTransposed()) {
assert(numLaneToReduce == 2 || numLaneToReduce == 4);
// for mfma 32x32 adjecant threads in y dimension in transposed MFMA layout are 32
// apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
// for mfma 16x16 adjecant threads in y dimension in transposed MFMA layout are 16
// apart: [[0 0 0 0 16 16 16 16 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
// for mfma 32x32, adjacent threads in the y dimension in the transposed
// MFMA layout are 32 apart:
// [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
// for mfma 16x16, adjacent threads in the y dimension in the transposed
// MFMA layout are 16 apart:
// [[0 0 0 0 16 16 16 16 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
const int warpSize = 64;
shuffleIdx = warpSize / N / 2;
}
251 changes: 251 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp
@@ -0,0 +1,251 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/Support/Debug.h"
#include <memory>

using namespace mlir;
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace {
using tt::DotOp;
using ttg::BlockedEncodingAttr;
using ttg::ConvertLayoutOp;
using ttg::DotOperandEncodingAttr;
using ttg::MfmaEncodingAttr;
using ttg::SliceEncodingAttr;

SmallVector<unsigned, 2>
warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
// TODO: needs to be updated with appropriate shapePerWarp etc.
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
auto slices = mlir::getSlice(dotOp, filter);
for (Operation *op : slices)
if (isa<tt::DotOp>(op) && (op != dotOp))
return {(unsigned)numWarps, 1};

SmallVector<int64_t, 2> tensorShape = {shape[0], shape[1]};
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {32, 32};
bool changed = false;

do {
changed = false;
if (ret[0] * ret[1] >= numWarps)
break;
if (tensorShape[0] / (shapePerWarp[0] * 2) / ret[0] >=
tensorShape[1] / shapePerWarp[1] / ret[1]) {
if (ret[0] < tensorShape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
} else {
ret[1] *= 2;
}
} while (true);

if (ret[1] * shapePerWarp[1] > tensorShape[1]) {
return {ret[1], ret[0]};
}

return ret;
}

class BlockedToMFMA : public mlir::RewritePattern {
int mfmaVersion;
int enforcedNonKDim;

public:
BlockedToMFMA(mlir::MLIRContext *context, int mfmaVersion, int nonKDim)
: mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context),
mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim) {}

bool isChainDot(tt::DotOp &dotOp) const {
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
auto slices = mlir::getSlice(dotOp, filter);
for (Operation *op : slices) {
if (isa<tt::DotOp>(op) && (op != dotOp))
return true;
}
return false;
}

/// @brief Choose MFMA instruction parameters
/// @param dot target dot operation
/// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments
std::pair<int64_t, int64_t> chooseMfmaDimensions(tt::DotOp dot) const {
// number of matrix elements along k dim per one MFMA instruction
int64_t kDim = -1;
auto opType = dot.getA().getType().cast<RankedTensorType>();
auto elemType = opType.getElementType();

auto resType = dot.getD().getType().cast<RankedTensorType>();
auto resShape = resType.getShape();

int64_t nonKDim = -1;
if (enforcedNonKDim != 0) {
nonKDim = enforcedNonKDim;
} else {
nonKDim = (resShape[0] < 32 || resShape[1] < 32) ? 16 : 32;
}
if (nonKDim == 32) {
if (elemType.isF32())
kDim = 2;
if (elemType.isF16())
kDim = 8;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 4;
if (mfmaVersion == 2)
kDim = 8;
}
if (elemType.isInteger(8))
kDim = 8;
} else {
if (elemType.isF32())
kDim = 4;
if (elemType.isF16())
kDim = 16;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 8;
if (mfmaVersion == 2)
kDim = 16;
}
if (elemType.isInteger(8))
kDim = 16;
}
assert(kDim != -1);
assert(nonKDim != -1);
assert(resShape[0] % nonKDim == 0 && resShape[1] % nonKDim == 0);
assert(opType.getShape()[1] % kDim == 0);
return {nonKDim, kDim};
}

mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<tt::DotOp>(op);

auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (!oldRetType.getEncoding() ||
!oldRetType.getEncoding().isa<ttg::BlockedEncodingAttr>())
return failure();

if (!supportMFMA(dotOp))
return failure();

auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());

// get MFMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);

// operands
Value a = dotOp.getA();
Value b = dotOp.getB();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto ctx = oldAType.getContext();

ttg::MfmaEncodingAttr mfmaEnc;

auto [nonKDim, kDim] = chooseMfmaDimensions(dotOp);

auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps);

bool isTransposed = isChainDot(dotOp);
mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim,
warpsPerTile, isTransposed, CTALayout);

auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);

// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(oldAcc.getLoc(),
newRetType, oldAcc);
auto oldAOrder = oldAType.getEncoding()
.cast<ttg::DotOperandEncodingAttr>()
.getParent()
.cast<ttg::BlockedEncodingAttr>()
.getOrder();
auto oldBOrder = oldBType.getEncoding()
.cast<ttg::DotOperandEncodingAttr>()
.getParent()
.cast<ttg::BlockedEncodingAttr>()
.getOrder();

// kWidth is the number of consecutive elements each thread holds per
// instruction
auto kWidth = kDim;
// in the mfma 32x32 case, the argument matrix groups its elements into 2 groups
// in the mfma 16x16 case, the argument matrix groups its elements into 4 groups
if (nonKDim == 32) {
kWidth /= 2;
} else {
assert(nonKDim == 16);
kWidth /= 4;
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth));
a = rewriter.create<ttg::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<ttg::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType, a, b,
newAcc, dotOp.getAllowTF32());

rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
newDot.getResult());
return success();
}
};

} // namespace

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

class TritonAMDGPUAccelerateMatmulPass
: public TritonAMDGPUAccelerateMatmulBase<
TritonAMDGPUAccelerateMatmulPass> {
public:
TritonAMDGPUAccelerateMatmulPass() = default;
TritonAMDGPUAccelerateMatmulPass(int matrixCoreVersion,
int matrixInstructionSize) {
this->matrixCoreVersion = matrixCoreVersion;
this->matrixInstructionSize = matrixInstructionSize;
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();

mlir::RewritePatternSet patterns(context);
if (matrixCoreVersion == 1 || matrixCoreVersion == 2)
patterns.add<::BlockedToMFMA>(context, matrixCoreVersion,
matrixInstructionSize);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
}
};

std::unique_ptr<Pass>
mlir::createTritonAMDGPUAccelerateMatmulPass(int matrixCoreVersion,
int matrixInstructionSize) {
return std::make_unique<TritonAMDGPUAccelerateMatmulPass>(
matrixCoreVersion, matrixInstructionSize);
}
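To summarize the instruction-size selection in chooseMfmaDimensions and the kWidth computation in matchAndRewrite, here is a condensed sketch (hypothetical helper, not in the commit); kWidth is kDim / 2 for MFMA 32x32 and kDim / 4 for MFMA 16x16:

#include <cstdint>

struct MfmaParams {
  int64_t nonKDim; // MN size of one MFMA instruction
  int64_t kDim;    // K elements consumed per instruction
  int64_t kWidth;  // consecutive elements per thread per instruction
};

// elemKind: 0 = f32, 1 = f16, 2 = bf16, 3 = i8 (assumed encoding)
MfmaParams chooseMfmaParamsSketch(int elemKind, int mfmaVersion,
                                  int64_t nonKDim) {
  int64_t kDim = -1;
  if (nonKDim == 32) {
    if (elemKind == 0) kDim = 2;                          // f32
    if (elemKind == 1) kDim = 8;                          // f16
    if (elemKind == 2) kDim = (mfmaVersion == 1) ? 4 : 8; // bf16
    if (elemKind == 3) kDim = 8;                          // i8
  } else { // nonKDim == 16
    if (elemKind == 0) kDim = 4;
    if (elemKind == 1) kDim = 16;
    if (elemKind == 2) kDim = (mfmaVersion == 1) ? 8 : 16;
    if (elemKind == 3) kDim = 16;
  }
  return {nonKDim, kDim, nonKDim == 32 ? kDim / 2 : kDim / 4};
}

// e.g. f16 with MFMA 32x32 -> kDim = 8, kWidth = 4;
//      f16 with MFMA 16x16 -> kDim = 16, kWidth = 4.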