[MFMA] Switch between MFMA types #352

Merged
2 changes: 1 addition & 1 deletion include/triton/Analysis/Utility.h
@@ -115,7 +115,7 @@ bool maybeSharedAllocationOp(Operation *op);
bool maybeAliasOp(Operation *op);

#ifdef USE_ROCM
bool supportMFMA(triton::DotOp op, int64_t nonKDim);
bool supportMFMA(triton::DotOp op);
#endif

bool supportMMA(triton::DotOp op, int version);
4 changes: 4 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.h
@@ -16,6 +16,10 @@ std::unique_ptr<Pass> createTritonGPUStreamPipelinePass();
std::unique_ptr<Pass>
createTritonGPUAccelerateMatmulPass(int computeCapability = 80);

std::unique_ptr<Pass>
createTritonAMDGPUAccelerateMatmulPass(int matrixCoreVersion = 0,
int matrixInstructionSize = 0);

std::unique_ptr<Pass> createTritonGPUPrefetchPass();

std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
23 changes: 23 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -85,6 +85,29 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul
];
}

def TritonAMDGPUAccelerateMatmul : Pass<"tritonamdgpu-accelerate-matmul", "mlir::ModuleOp"> {
let summary = "accelerate matmul";

let description = [{
Optimize the input/output layouts of the `dot` instruction to make them compatible with hardware accelerators
(e.g., AMD matrix cores).
}];

let constructor = "mlir::createTritonAMDGPUAccelerateMatmulPass()";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];

let options = [
Option<"matrixCoreVersion", "matrix-core-version",
"int32_t", /*default*/"0",
"device matrix core version">,
Option<"matrixInstructionSize", "matrix-instruction-size",
"int32_t", /*default*/"0",
"enforce matrix instruction MN size">
];
}

def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> {
let summary = "fuse transpositions";

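To show how the new pass declaration and its two options are meant to be used, here is a minimal sketch of scheduling the pass from C++. The PassManager setup, the helper name, and the concrete option values are illustrative assumptions, not part of this PR.

#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"

// Sketch only: option values are examples. matrixCoreVersion enables the MFMA
// rewrite (the pass only adds its pattern for versions 1 and 2), and
// matrixInstructionSize enforces the MFMA MN size (0 lets the pass choose).
void buildAmdMatmulPipeline(mlir::PassManager &pm) {
  pm.addPass(mlir::createTritonAMDGPUAccelerateMatmulPass(
      /*matrixCoreVersion=*/2, /*matrixInstructionSize=*/16));
}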
23 changes: 13 additions & 10 deletions lib/Analysis/Utility.cpp
@@ -378,18 +378,21 @@ bool supportMMA(triton::DotOp op, int version) {
}

#ifdef USE_ROCM
static bool supportMFMAGranularity(int m, int n, int k, int64_t nonKDim) {
static bool supportMFMAGranularity(int m, int n, int k) {
// these limitations are dtype dependent, in future we may relax them
const int granularityMN = nonKDim;
const int granularityK = nonKDim == 32 ? 8 : 16;
if (m % granularityMN != 0 || n % granularityMN != 0)
return false;
if (k % granularityK != 0)
return false;
return true;
const static std::pair<int, int> mfmaTypes[2] = {{32, 8}, {16, 16}};
for (const auto &mfmaType : mfmaTypes) {
auto [granularityMN, granularityK] = mfmaType;
if (m % granularityMN != 0 || n % granularityMN != 0)
continue;
if (k % granularityK != 0)
continue;
return true;
}
return false;
}

bool supportMFMA(triton::DotOp op, int64_t nonKDim) {
bool supportMFMA(triton::DotOp op) {
auto aTy = op.getA().getType().cast<RankedTensorType>();
auto bTy = op.getB().getType().cast<RankedTensorType>();

@@ -403,7 +406,7 @@ bool supportMFMA(triton::DotOp op, int64_t nonKDim) {
auto bShape = bTy.getShape();

assert(aShape[1] == bShape[0]);
if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1], nonKDim))
if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1]))
return false;

return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() ||
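As a standalone illustration of the relaxed check above (the helper name and the example shapes are mine, not from the diff), a dot shape is now accepted if it fits either the {nonKDim, kDim} = {32, 8} or the {16, 16} granularity:

#include <cassert>
#include <utility>

// Mirrors the new supportMFMAGranularity: a shape is supported if it matches
// at least one MFMA type, {granularityMN, granularityK} = {32, 8} or {16, 16}.
static bool supportsAnyMfmaType(int m, int n, int k) {
  constexpr std::pair<int, int> mfmaTypes[2] = {{32, 8}, {16, 16}};
  for (auto [granularityMN, granularityK] : mfmaTypes)
    if (m % granularityMN == 0 && n % granularityMN == 0 &&
        k % granularityK == 0)
      return true;
  return false;
}

int main() {
  assert(supportsAnyMfmaType(32, 32, 8));   // fits the 32x32 MFMA type
  assert(supportsAnyMfmaType(16, 16, 16));  // too small for 32x32, fits 16x16
  assert(!supportsAnyMfmaType(16, 16, 8));  // k = 8 is too small for the 16x16 type
  assert(!supportsAnyMfmaType(8, 32, 8));   // m = 8 fits neither type
  return 0;
}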
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp
@@ -82,7 +82,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
.cast<RankedTensorType>()
.getEncoding()
.dyn_cast<MfmaEncodingAttr>();
if (!isOuter && mfmaLayout && supportMFMA(op, mfmaLayout.getNonKDim())) {
if (!isOuter && mfmaLayout && supportMFMA(op)) {
return convertMFMA(op, adaptor, getTypeConverter(), rewriter);
}
#endif
9 changes: 5 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -436,10 +436,11 @@ struct ReduceOpConversion
inputTy.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
if (inMfma && inMfma.getIsTransposed()) {
assert(numLaneToReduce == 2 || numLaneToReduce == 4);
// for mfma 32x32 adjecant threads in y dimension in transposed MFMA layout are 32
// apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
// for mfma 16x16 adjecant threads in y dimension in transposed MFMA layout are 16
// apart: [[0 0 0 0 16 16 16 16 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
// for mfma 32x32 adjacent threads in y dimension in transposed MFMA
// layout are 32 apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33
// ...] ...]. for mfma 16x16 adjacent threads in y dimension in
// transposed MFMA layout are 16 apart: [[0 0 0 0 16 16 16 16 32 32 32
// 32 ...] [1 1 1 1 33 33 33 33 ...] ...].

Review comment: Is it possible to get the waveSize from the gpu dialect or mfma layout?

Author: Unfortunately no... However, the MFMA layout appears in the IR only if the target is a CDNA architecture, which only has a 64-wide wavefront mode. I think it should be safe to use a constant here. In my opinion we should report an MFMA layout on a non-CDNA GPU as an error.

Reply: If you really want, it is possible to infer waveSize from mfmaLayout by computing a product of mfmaLayout.threadsPerWarp. But that is a little "ugly" in my opinion.

const int warpSize = 64;
shuffleIdx = warpSize / N / 2;
}
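Following up on the review thread above, a minimal sketch of how the wave size could be derived from the layout's threads-per-warp vector instead of the hard-coded 64; the getThreadsPerWarp() accessor mentioned in the usage comment is an assumption, not something this diff adds.

#include "llvm/ADT/ArrayRef.h"
#include <cstdint>

// Sketch: the wave size is the product of the layout's threads-per-warp dims.
static int64_t productOfThreadsPerWarp(llvm::ArrayRef<unsigned> threadsPerWarp) {
  int64_t waveSize = 1;
  for (unsigned d : threadsPerWarp)
    waveSize *= d;
  return waveSize;
}
// Hypothetical usage, assuming such an accessor exists on the MFMA layout:
//   int64_t warpSize = productOfThreadsPerWarp(inMfma.getThreadsPerWarp());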
251 changes: 251 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp
@@ -0,0 +1,251 @@
#include "mlir/IR/TypeUtilities.h"
Author comment: @zhanglx13 I've separated this code from the common AccelerateMatmul pass, so I can add an additional option to it. Do you think it is ok to do this in this PR or is it better to separate it?

#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/Support/Debug.h"
#include <memory>

using namespace mlir;
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace {
using tt::DotOp;
using ttg::BlockedEncodingAttr;
using ttg::ConvertLayoutOp;
using ttg::DotOperandEncodingAttr;
using ttg::MfmaEncodingAttr;
using ttg::SliceEncodingAttr;

SmallVector<unsigned, 2>
warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
// TODO: needs to be updated with appropriate shapePerWarp etc.
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
auto slices = mlir::getSlice(dotOp, filter);
for (Operation *op : slices)
if (isa<tt::DotOp>(op) && (op != dotOp))
return {(unsigned)numWarps, 1};

SmallVector<int64_t, 2> tensorShape = {shape[0], shape[1]};
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {32, 32};
bool changed = false;

do {
changed = false;
if (ret[0] * ret[1] >= numWarps)
break;
if (tensorShape[0] / (shapePerWarp[0] * 2) / ret[0] >=
tensorShape[1] / shapePerWarp[1] / ret[1]) {
if (ret[0] < tensorShape[0] / shapePerWarp[0]) {
ret[0] *= 2;
} else
ret[1] *= 2;
} else {
ret[1] *= 2;
}
} while (true);

if (ret[1] * shapePerWarp[1] > tensorShape[1]) {
return {ret[1], ret[0]};
}

return ret;
}

class BlockedToMFMA : public mlir::RewritePattern {
int mfmaVersion;
int enforcedNonKDim;

public:
BlockedToMFMA(mlir::MLIRContext *context, int mfmaVersion, int nonKDim)
: mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context),
mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim) {}

bool isChainDot(tt::DotOp &dotOp) const {
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
auto slices = mlir::getSlice(dotOp, filter);
for (Operation *op : slices) {
if (isa<tt::DotOp>(op) && (op != dotOp))
return true;
}
return false;
}

/// @brief Choose MFMA instruction parameters
/// @param dot target dot operation
/// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments
std::pair<int64_t, int64_t> chooseMfmaDimensions(tt::DotOp dot) const {
// number of matrix elements along k dim per one MFMA instruction
int64_t kDim = -1;
auto opType = dot.getA().getType().cast<RankedTensorType>();
auto elemType = opType.getElementType();

auto resType = dot.getD().getType().cast<RankedTensorType>();
auto resShape = resType.getShape();

int64_t nonKDim = -1;
if (enforcedNonKDim != 0) {
nonKDim = enforcedNonKDim;
} else {
nonKDim = (resShape[0] < 32 || resShape[1] < 32) ? 16 : 32;
}
if (nonKDim == 32) {
if (elemType.isF32())
kDim = 2;
if (elemType.isF16())
kDim = 8;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 4;
if (mfmaVersion == 2)
kDim = 8;
}
if (elemType.isInteger(8))
kDim = 8;
} else {
if (elemType.isF32())
kDim = 4;
if (elemType.isF16())
kDim = 16;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 8;
if (mfmaVersion == 2)
kDim = 16;
}
if (elemType.isInteger(8))
kDim = 16;
}
assert(kDim != -1);
assert(nonKDim != -1);
assert(resShape[0] % nonKDim == 0 && resShape[1] % nonKDim == 0);
assert(opType.getShape()[1] % kDim == 0);
return {nonKDim, kDim};
}

mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto dotOp = cast<tt::DotOp>(op);

auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
if (!oldRetType.getEncoding() ||
!oldRetType.getEncoding().isa<ttg::BlockedEncodingAttr>())
return failure();

if (!supportMFMA(dotOp))
return failure();

auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());

// get MFMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);

// operands
Value a = dotOp.getA();
Value b = dotOp.getB();
auto oldAType = a.getType().cast<RankedTensorType>();
auto oldBType = b.getType().cast<RankedTensorType>();
auto ctx = oldAType.getContext();

ttg::MfmaEncodingAttr mfmaEnc;

auto [nonKDim, kDim] = chooseMfmaDimensions(dotOp);

auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps);

bool isTransposed = isChainDot(dotOp);
mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim,
warpsPerTile, isTransposed, CTALayout);

auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);

// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(oldAcc.getLoc(),
newRetType, oldAcc);
auto oldAOrder = oldAType.getEncoding()
.cast<ttg::DotOperandEncodingAttr>()
.getParent()
.cast<ttg::BlockedEncodingAttr>()
.getOrder();
auto oldBOrder = oldBType.getEncoding()
.cast<ttg::DotOperandEncodingAttr>()
.getParent()
.cast<ttg::BlockedEncodingAttr>()
.getOrder();

// kWidth is a number of consecutive elements per one instruction per one
// thread
auto kWidth = kDim;
// in the mfma 32x32 case the elements of an argument matrix are grouped into 2 groups
// in the mfma 16x16 case they are grouped into 4 groups
if (nonKDim == 32) {
kWidth /= 2;
} else {
assert(nonKDim == 16);
kWidth /= 4;
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth));
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth));
a = rewriter.create<ttg::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<ttg::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType, a, b,
newAcc, dotOp.getAllowTF32());

rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
newDot.getResult());
return success();
}
};

} // namespace

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

class TritonAMDGPUAccelerateMatmulPass
: public TritonAMDGPUAccelerateMatmulBase<
TritonAMDGPUAccelerateMatmulPass> {
public:
TritonAMDGPUAccelerateMatmulPass() = default;
TritonAMDGPUAccelerateMatmulPass(int matrixCoreVersion,
int matrixInstructionSize) {
this->matrixCoreVersion = matrixCoreVersion;
this->matrixInstructionSize = matrixInstructionSize;
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();

mlir::RewritePatternSet patterns(context);
if (matrixCoreVersion == 1 || matrixCoreVersion == 2)
patterns.add<::BlockedToMFMA>(context, matrixCoreVersion,
matrixInstructionSize);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
}
};

std::unique_ptr<Pass>
mlir::createTritonAMDGPUAccelerateMatmulPass(int matrixCoreVersion,
int matrixInstructionSize) {
return std::make_unique<TritonAMDGPUAccelerateMatmulPass>(
matrixCoreVersion, matrixInstructionSize);
}
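For reference, a small worked example (my own sketch, not part of the PR) of what chooseMfmaDimensions and the kWidth computation above produce for f16 operands when no instruction size is enforced:

#include <cassert>
#include <cstdint>
#include <tuple>

// Returns {nonKDim, kDim, kWidth} for an f16 dot with result shape m x n,
// following the logic of BlockedToMFMA above with matrixInstructionSize == 0.
static std::tuple<int64_t, int64_t, int64_t> chooseF16MfmaDims(int64_t m,
                                                               int64_t n) {
  int64_t nonKDim = (m < 32 || n < 32) ? 16 : 32;
  int64_t kDim = (nonKDim == 32) ? 8 : 16;           // f16 k elements per instruction
  int64_t kWidth = kDim / ((nonKDim == 32) ? 2 : 4); // consecutive elements per thread
  return {nonKDim, kDim, kWidth};
}

int main() {
  assert(chooseF16MfmaDims(64, 64) == std::make_tuple(32, 8, 4));   // large tiles -> mfma 32x32x8
  assert(chooseF16MfmaDims(16, 128) == std::make_tuple(16, 16, 4)); // narrow M -> mfma 16x16x16
  return 0;
}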