[MFMA] Switch between MFMA types #352
Merged: alefimov-amd merged 3 commits into ROCm:triton-mlir from binarman:mfma16_support_kernel_parameter on Oct 18, 2023
lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp (251 additions, 0 deletions)

@@ -0,0 +1,251 @@
#include "mlir/IR/TypeUtilities.h" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've Separated this code from common AccelerateMatmul pass, so I can add an additional option to it. Do you think it is ok to do this in this PR or is it better to separate it? |
||
#include "mlir/Support/LogicalResult.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "triton/Analysis/Utility.h" | ||
#include "triton/Dialect/TritonGPU/IR/Dialect.h" | ||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h" | ||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h" | ||
#include "triton/Tools/Sys/GetEnv.hpp" | ||
#include "llvm/Support/Debug.h" | ||
#include <memory> | ||
|
||
using namespace mlir; | ||
namespace tt = mlir::triton; | ||
namespace ttg = mlir::triton::gpu; | ||
namespace { | ||
using tt::DotOp; | ||
using ttg::BlockedEncodingAttr; | ||
using ttg::ConvertLayoutOp; | ||
using ttg::DotOperandEncodingAttr; | ||
using ttg::MfmaEncodingAttr; | ||
using ttg::SliceEncodingAttr; | ||
|
||
SmallVector<unsigned, 2> | ||
warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) { | ||
// TODO: needs to be updated with appropriate shapePerWarp etc. | ||
auto filter = [&dotOp](Operation *op) { | ||
return op->getParentRegion() == dotOp->getParentRegion(); | ||
}; | ||
auto slices = mlir::getSlice(dotOp, filter); | ||
for (Operation *op : slices) | ||
if (isa<tt::DotOp>(op) && (op != dotOp)) | ||
return {(unsigned)numWarps, 1}; | ||
|
||
SmallVector<int64_t, 2> tensorShape = {shape[0], shape[1]}; | ||
SmallVector<unsigned, 2> ret = {1, 1}; | ||
SmallVector<int64_t, 2> shapePerWarp = {32, 32}; | ||
bool changed = false; | ||
|
||
do { | ||
changed = false; | ||
if (ret[0] * ret[1] >= numWarps) | ||
break; | ||
if (tensorShape[0] / (shapePerWarp[0] * 2) / ret[0] >= | ||
tensorShape[1] / shapePerWarp[1] / ret[1]) { | ||
if (ret[0] < tensorShape[0] / shapePerWarp[0]) { | ||
ret[0] *= 2; | ||
} else | ||
ret[1] *= 2; | ||
} else { | ||
ret[1] *= 2; | ||
} | ||
} while (true); | ||
|
||
if (ret[1] * shapePerWarp[1] > tensorShape[1]) { | ||
return {ret[1], ret[0]}; | ||
} | ||
|
||
return ret; | ||
} | ||
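
// Illustrative trace (hypothetical inputs): for shape = {256, 128} and
// numWarps = 4 the loop above proceeds as
//   ret = {1, 1}: 256/64/1 = 4 >= 128/32/1 = 4, and 1 < 256/32 -> ret = {2, 1}
//   ret = {2, 1}: 256/64/2 = 2 <  128/32/1 = 4                 -> ret = {2, 2}
//   ret = {2, 2}: 2 * 2 >= numWarps -> break
// The final swap does not fire (2 * 32 <= 128), so {2, 2} is returned.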

class BlockedToMFMA : public mlir::RewritePattern {
  int mfmaVersion;
  int enforcedNonKDim;

public:
  BlockedToMFMA(mlir::MLIRContext *context, int mfmaVersion, int nonKDim)
      : mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context),
        mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim) {}

  bool isChainDot(tt::DotOp &dotOp) const {
    auto filter = [&dotOp](Operation *op) {
      return op->getParentRegion() == dotOp->getParentRegion();
    };
    auto slices = mlir::getSlice(dotOp, filter);
    for (Operation *op : slices) {
      if (isa<tt::DotOp>(op) && (op != dotOp))
        return true;
    }
    return false;
  }

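  // Note (illustrative): a "chain dot" is a second tt::DotOp in the same
  // region, e.g. the back-to-back dots of an attention kernel; presumably
  // marking the MFMA layout transposed lets the first dot's result feed the
  // second dot's operand without an extra layout conversion.
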
  /// @brief Choose MFMA instruction parameters
  /// @param dot target dot operation
  /// @return pair {nonKDim, kDim} with the sizes of one MFMA instruction's arguments
  std::pair<int64_t, int64_t> chooseMfmaDimensions(tt::DotOp dot) const {
    // number of matrix elements along the k dim per one MFMA instruction
    int64_t kDim = -1;
    auto opType = dot.getA().getType().cast<RankedTensorType>();
    auto elemType = opType.getElementType();

    auto resType = dot.getD().getType().cast<RankedTensorType>();
    auto resShape = resType.getShape();

    int64_t nonKDim = -1;
    if (enforcedNonKDim != 0) {
      nonKDim = enforcedNonKDim;
    } else {
      nonKDim = (resShape[0] < 32 || resShape[1] < 32) ? 16 : 32;
    }
    if (nonKDim == 32) {
      if (elemType.isF32())
        kDim = 2;
      if (elemType.isF16())
        kDim = 8;
      if (elemType.isBF16()) {
        if (mfmaVersion == 1)
          kDim = 4;
        if (mfmaVersion == 2)
          kDim = 8;
      }
      if (elemType.isInteger(8))
        kDim = 8;
    } else {
      if (elemType.isF32())
        kDim = 4;
      if (elemType.isF16())
        kDim = 16;
      if (elemType.isBF16()) {
        if (mfmaVersion == 1)
          kDim = 8;
        if (mfmaVersion == 2)
          kDim = 16;
      }
      if (elemType.isInteger(8))
        kDim = 16;
    }
    assert(kDim != -1);
    assert(nonKDim != -1);
    assert(resShape[0] % nonKDim == 0 && resShape[1] % nonKDim == 0);
    assert(opType.getShape()[1] % kDim == 0);
    return {nonKDim, kDim};
  }
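
  // For illustration, the selection above amounts to:
  //   nonKDim 32: f32 -> kDim 2,  f16 -> 8,  bf16 -> 4 (mfma v1) / 8 (v2),  i8 -> 8
  //   nonKDim 16: f32 -> kDim 4,  f16 -> 16, bf16 -> 8 (mfma v1) / 16 (v2), i8 -> 16
  // e.g. f16 with nonKDim == 32 selects a 32x32x8 instruction shape,
  // presumably v_mfma_f32_32x32x8f16 on CDNA hardware.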

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto dotOp = cast<tt::DotOp>(op);

    auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
    if (!oldRetType.getEncoding() ||
        !oldRetType.getEncoding().isa<ttg::BlockedEncodingAttr>())
      return failure();

    if (!supportMFMA(dotOp))
      return failure();

    auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());

    // get MFMA encoding for the given number of warps
    auto retShape = oldRetType.getShape();
    auto mod = op->getParentOfType<mlir::ModuleOp>();
    int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);

    // operands
    Value a = dotOp.getA();
    Value b = dotOp.getB();
    auto oldAType = a.getType().cast<RankedTensorType>();
    auto oldBType = b.getType().cast<RankedTensorType>();
    auto ctx = oldAType.getContext();

    ttg::MfmaEncodingAttr mfmaEnc;

    auto [nonKDim, kDim] = chooseMfmaDimensions(dotOp);

    auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps);

    bool isTransposed = isChainDot(dotOp);
    mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim,
                                         warpsPerTile, isTransposed, CTALayout);

    auto newRetType =
        RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);

    // convert accumulator
    auto oldAcc = dotOp.getOperand(2);
    auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(oldAcc.getLoc(),
                                                        newRetType, oldAcc);
    auto oldAOrder = oldAType.getEncoding()
                         .cast<ttg::DotOperandEncodingAttr>()
                         .getParent()
                         .cast<ttg::BlockedEncodingAttr>()
                         .getOrder();
    auto oldBOrder = oldBType.getEncoding()
                         .cast<ttg::DotOperandEncodingAttr>()
                         .getParent()
                         .cast<ttg::BlockedEncodingAttr>()
                         .getOrder();

    // kWidth is the number of consecutive elements per instruction per thread
    auto kWidth = kDim;
    // in the MFMA 32x32 case an argument matrix groups its elements into 2
    // groups; in the MFMA 16x16 case, into 4 groups
    if (nonKDim == 32) {
      kWidth /= 2;
    } else {
      assert(nonKDim == 16);
      kWidth /= 4;
    }
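    // Worked example: for f16 operands, nonKDim == 32 gives kDim == 8 and
    // kWidth == 4; nonKDim == 16 gives kDim == 16 and kWidth == 4, i.e. four
    // consecutive f16 elements per thread per instruction in both cases.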
    auto newAType = RankedTensorType::get(
        oldAType.getShape(), oldAType.getElementType(),
        ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth));
    auto newBType = RankedTensorType::get(
        oldBType.getShape(), oldBType.getElementType(),
        ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth));
    a = rewriter.create<ttg::ConvertLayoutOp>(a.getLoc(), newAType, a);
    b = rewriter.create<ttg::ConvertLayoutOp>(b.getLoc(), newBType, b);
    auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType, a, b,
                                             newAcc, dotOp.getAllowTF32());

    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
                                                      newDot.getResult());
    return success();
  }
};

} // namespace

#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

class TritonAMDGPUAccelerateMatmulPass
    : public TritonAMDGPUAccelerateMatmulBase<
          TritonAMDGPUAccelerateMatmulPass> {
public:
  TritonAMDGPUAccelerateMatmulPass() = default;
  TritonAMDGPUAccelerateMatmulPass(int matrixCoreVersion,
                                   int matrixInstructionSize) {
    this->matrixCoreVersion = matrixCoreVersion;
    this->matrixInstructionSize = matrixInstructionSize;
  }
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();

    mlir::RewritePatternSet patterns(context);
    if (matrixCoreVersion == 1 || matrixCoreVersion == 2)
      patterns.add<::BlockedToMFMA>(context, matrixCoreVersion,
                                    matrixInstructionSize);
    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
      signalPassFailure();
    }
  }
};

std::unique_ptr<Pass>
mlir::createTritonAMDGPUAccelerateMatmulPass(int matrixCoreVersion,
                                             int matrixInstructionSize) {
  return std::make_unique<TritonAMDGPUAccelerateMatmulPass>(
      matrixCoreVersion, matrixInstructionSize);
}
Review comment: Is it possible to get the waveSize from the gpu dialect or mfma layout?
Reply: Unfortunately, no. However, an MFMA layout appears in the IR only if the target is a CDNA architecture, which has only the 64-wide wave mode, so it should be safe to use a constant here. In my opinion, we should report an MFMA layout on a non-CDNA GPU as an error.
Reply: If you really want, it is possible to infer waveSize from mfmaLayout by computing the product of mfmaLayout.threadsPerWarp, but that is a little "ugly" in my opinion.
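
A sketch of that suggestion, assuming MfmaEncodingAttr exposes a getThreadsPerWarp() accessor as the comment implies:

// Hypothetical helper: the wave size is the product of the MFMA layout's
// threads-per-warp vector (64 for CDNA MFMA layouts).
unsigned waveSizeFromLayout(ttg::MfmaEncodingAttr mfmaLayout) {
  unsigned waveSize = 1;
  for (unsigned t : mfmaLayout.getThreadsPerWarp())
    waveSize *= t;
  return waveSize;
}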
Reply: And you can also get it from the gpu dialect, like here: https://github.com/ROCmSoftwarePlatform/triton/blob/4d539d7dae055bb6b8dbb1b2b380118333250f15/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp#L589
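
Sketched, that route reads the wave size from the module-level TritonGPU setting; the getThreadsPerWarp helper is assumed to match the one used in the linked ReduceOpToLLVM.cpp:

// Hypothetical use inside a conversion pattern: query the module's
// threads-per-warp attribute instead of hard-coding 64.
int waveSize = ttg::TritonGPUDialect::getThreadsPerWarp(
    op->getParentOfType<mlir::ModuleOp>());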