From 30c35963aeea3c86df63da8c09649c693271c256 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Mon, 26 Jun 2023 16:23:02 +0000 Subject: [PATCH 1/3] [MFMA] Switch between MFMA types This PR introduces matrix_instr_nonkdim flag to switch between MFMA 16 and MFMA 32. --- include/triton/Analysis/Utility.h | 2 +- .../Dialect/TritonGPU/Transforms/Passes.h | 4 + .../Dialect/TritonGPU/Transforms/Passes.td | 23 ++ lib/Analysis/Utility.cpp | 23 +- .../TritonGPUToLLVM/DotOpToLLVM.cpp | 2 +- .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 9 +- .../Transforms/AccelerateAMDMatmul.cpp | 251 ++++++++++++++++++ .../TritonGPU/Transforms/AccelerateMatmul.cpp | 202 -------------- .../TritonGPU/Transforms/CMakeLists.txt | 1 + python/src/triton.cc | 5 + python/test/unit/language/test_core_amd.py | 30 ++- python/triton/compiler/compiler.py | 12 +- python/triton/language/semantic.py | 22 +- python/triton/runtime/jit.py | 12 +- 14 files changed, 350 insertions(+), 248 deletions(-) create mode 100644 lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index af0f0961bc9c..ef167f5612b1 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -115,7 +115,7 @@ bool maybeSharedAllocationOp(Operation *op); bool maybeAliasOp(Operation *op); #ifdef USE_ROCM -bool supportMFMA(triton::DotOp op, int64_t nonKDim); +bool supportMFMA(triton::DotOp op); #endif bool supportMMA(triton::DotOp op, int version); diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h index 89b3d818c072..abab74741a83 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -16,6 +16,10 @@ std::unique_ptr createTritonGPUStreamPipelinePass(); std::unique_ptr createTritonGPUAccelerateMatmulPass(int computeCapability = 80); +std::unique_ptr +createTritonAMDGPUAccelerateMatmulPass(int matrixCoreVersion = 0, + int matrixInstructionSize = 0); + std::unique_ptr createTritonGPUPrefetchPass(); std::unique_ptr createTritonGPUCanonicalizeLoopsPass(); diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index 81d1b1de7ba9..b30e897207e3 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -85,6 +85,29 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul ]; } +def TritonAMDGPUAccelerateMatmul : Pass<"tritonamdgpu-accelerate-matmul", "mlir::ModuleOp"> { + let summary = "accelerate matmul"; + + let description = [{ + Optimize the input/output layout of `dot` instruction to make them compatible hardware accelerators + (e.g., AMD matrix cores) + }]; + + let constructor = "mlir::createTritonAMDGPUAccelerateMatmulPass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + + let options = [ + Option<"matrixCoreVersion", "matrix-core-version", + "int32_t", /*default*/"0", + "device matrix core version">, + Option<"matrixInstructionSize", "matrix-instructio-size", + "int32_t", /*default*/"0", + "enforce matrix intrucion MN size"> + ]; +} + def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> { let summary = "fuse transpositions"; diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 6b758414896b..602c7dc0fa0f 100644 --- 
a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -378,18 +378,21 @@ bool supportMMA(triton::DotOp op, int version) { } #ifdef USE_ROCM -static bool supportMFMAGranularity(int m, int n, int k, int64_t nonKDim) { +static bool supportMFMAGranularity(int m, int n, int k) { // these limitations are dtype dependent, in future we may relax them - const int granularityMN = nonKDim; - const int granularityK = nonKDim == 32 ? 8 : 16; - if (m % granularityMN != 0 || n % granularityMN != 0) - return false; - if (k % granularityK != 0) - return false; - return true; + const static std::pair mfmaTypes[2] = {{32, 8}, {16, 16}}; + for (const auto &mfmaType : mfmaTypes) { + auto [granularityMN, granularityK] = mfmaType; + if (m % granularityMN != 0 || n % granularityMN != 0) + continue; + if (k % granularityK != 0) + continue; + return true; + } + return false; } -bool supportMFMA(triton::DotOp op, int64_t nonKDim) { +bool supportMFMA(triton::DotOp op) { auto aTy = op.getA().getType().cast(); auto bTy = op.getB().getType().cast(); @@ -403,7 +406,7 @@ bool supportMFMA(triton::DotOp op, int64_t nonKDim) { auto bShape = bTy.getShape(); assert(aShape[1] == bShape[0]); - if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1], nonKDim)) + if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1])) return false; return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() || diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp index bcda6ccf5d7c..97f8fb4ee6f1 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp @@ -82,7 +82,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { .cast() .getEncoding() .dyn_cast(); - if (!isOuter && mfmaLayout && supportMFMA(op, mfmaLayout.getNonKDim())) { + if (!isOuter && mfmaLayout && supportMFMA(op)) { return convertMFMA(op, adaptor, getTypeConverter(), rewriter); } #endif diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index e3ee4beb84fd..01681c48b2b5 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -436,10 +436,11 @@ struct ReduceOpConversion inputTy.getEncoding().dyn_cast(); if (inMfma && inMfma.getIsTransposed()) { assert(numLaneToReduce == 2 || numLaneToReduce == 4); - // for mfma 32x32 adjecant threads in y dimension in transposed MFMA layout are 32 - // apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...]. - // for mfma 16x16 adjecant threads in y dimension in transposed MFMA layout are 16 - // apart: [[0 0 0 0 16 16 16 16 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...]. + // for mfma 32x32 adjacent threads in y dimension in transposed MFMA + // layout are 32 apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 + // ...] ...]. for mfma 16x16 adjacent threads in y dimension in + // transposed MFMA layout are 16 apart: [[0 0 0 0 16 16 16 16 32 32 32 + // 32 ...] [1 1 1 1 33 33 33 33 ...] ...]. 
const int warpSize = 64; shuffleIdx = warpSize / N / 2; } diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp new file mode 100644 index 000000000000..059dde957b2b --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp @@ -0,0 +1,251 @@ +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" +#include + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace { +using tt::DotOp; +using ttg::BlockedEncodingAttr; +using ttg::ConvertLayoutOp; +using ttg::DotOperandEncodingAttr; +using ttg::MfmaEncodingAttr; +using ttg::SliceEncodingAttr; + +SmallVector +warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef shape, int numWarps) { + // TODO: needs to be updated with appropriate shapePerWarp etc. + auto filter = [&dotOp](Operation *op) { + return op->getParentRegion() == dotOp->getParentRegion(); + }; + auto slices = mlir::getSlice(dotOp, filter); + for (Operation *op : slices) + if (isa(op) && (op != dotOp)) + return {(unsigned)numWarps, 1}; + + SmallVector tensorShape = {shape[0], shape[1]}; + SmallVector ret = {1, 1}; + SmallVector shapePerWarp = {32, 32}; + bool changed = false; + + do { + changed = false; + if (ret[0] * ret[1] >= numWarps) + break; + if (tensorShape[0] / (shapePerWarp[0] * 2) / ret[0] >= + tensorShape[1] / shapePerWarp[1] / ret[1]) { + if (ret[0] < tensorShape[0] / shapePerWarp[0]) { + ret[0] *= 2; + } else + ret[1] *= 2; + } else { + ret[1] *= 2; + } + } while (true); + + if (ret[1] * shapePerWarp[1] > tensorShape[1]) { + return {ret[1], ret[0]}; + } + + return ret; +} + +class BlockedToMFMA : public mlir::RewritePattern { + int mfmaVersion; + int enforcedNonKDim; + +public: + BlockedToMFMA(mlir::MLIRContext *context, int mfmaVersion, int nonKDim) + : mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context), + mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim) {} + + bool isChainDot(tt::DotOp &dotOp) const { + auto filter = [&dotOp](Operation *op) { + return op->getParentRegion() == dotOp->getParentRegion(); + }; + auto slices = mlir::getSlice(dotOp, filter); + for (Operation *op : slices) { + if (isa(op) && (op != dotOp)) + return true; + } + return false; + } + + /// @brief Choose MFMA instruction parameters + /// @param dot target dot operation + /// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments + std::pair chooseMfmaDimensions(tt::DotOp dot) const { + // number of matrix elements along k dim per one MFMA intruction + int64_t kDim = -1; + auto opType = dot.getA().getType().cast(); + auto elemType = opType.getElementType(); + + auto resType = dot.getD().getType().cast(); + auto resShape = resType.getShape(); + + int64_t nonKDim = -1; + if (enforcedNonKDim != 0) { + nonKDim = enforcedNonKDim; + } else { + nonKDim = (resShape[0] < 32 || resShape[1] < 32) ? 
16 : 32; + } + if (nonKDim == 32) { + if (elemType.isF32()) + kDim = 2; + if (elemType.isF16()) + kDim = 8; + if (elemType.isBF16()) { + if (mfmaVersion == 1) + kDim = 4; + if (mfmaVersion == 2) + kDim = 8; + } + if (elemType.isInteger(8)) + kDim = 8; + } else { + if (elemType.isF32()) + kDim = 4; + if (elemType.isF16()) + kDim = 16; + if (elemType.isBF16()) { + if (mfmaVersion == 1) + kDim = 8; + if (mfmaVersion == 2) + kDim = 16; + } + if (elemType.isInteger(8)) + kDim = 16; + } + assert(kDim != -1); + assert(nonKDim != -1); + assert(resShape[0] % nonKDim == 0 && resShape[1] % nonKDim == 0); + assert(opType.getShape()[1] % kDim == 0); + return {nonKDim, kDim}; + } + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto dotOp = cast(op); + + auto oldRetType = dotOp.getResult().getType().cast(); + if (!oldRetType.getEncoding() || + !oldRetType.getEncoding().isa()) + return failure(); + + if (!supportMFMA(dotOp)) + return failure(); + + auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding()); + + // get MFMA encoding for the given number of warps + auto retShape = oldRetType.getShape(); + auto mod = op->getParentOfType(); + int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); + + // operands + Value a = dotOp.getA(); + Value b = dotOp.getB(); + auto oldAType = a.getType().cast(); + auto oldBType = b.getType().cast(); + auto ctx = oldAType.getContext(); + + ttg::MfmaEncodingAttr mfmaEnc; + + auto [nonKDim, kDim] = chooseMfmaDimensions(dotOp); + + auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps); + + bool isTransposed = isChainDot(dotOp); + mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim, + warpsPerTile, isTransposed, CTALayout); + + auto newRetType = + RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc); + + // convert accumulator + auto oldAcc = dotOp.getOperand(2); + auto newAcc = rewriter.create(oldAcc.getLoc(), + newRetType, oldAcc); + auto oldAOrder = oldAType.getEncoding() + .cast() + .getParent() + .cast() + .getOrder(); + auto oldBOrder = oldBType.getEncoding() + .cast() + .getParent() + .cast() + .getOrder(); + + // kWidth is a number of consecutive elements per one instruction per one + // thread + auto kWidth = kDim; + // in mfma 32x32 case argument matrix groups elements in 2 groups + // in mfma 16x16 case argument matrix groups elements in 4 groups + if (nonKDim == 32) { + kWidth /= 2; + } else { + assert(nonKDim == 16); + kWidth /= 4; + } + auto newAType = RankedTensorType::get( + oldAType.getShape(), oldAType.getElementType(), + ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth)); + auto newBType = RankedTensorType::get( + oldBType.getShape(), oldBType.getElementType(), + ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth)); + a = rewriter.create(a.getLoc(), newAType, a); + b = rewriter.create(b.getLoc(), newBType, b); + auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, + newAcc, dotOp.getAllowTF32()); + + rewriter.replaceOpWithNewOp(op, oldRetType, + newDot.getResult()); + return success(); + } +}; + +} // namespace + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonAMDGPUAccelerateMatmulPass + : public TritonAMDGPUAccelerateMatmulBase< + TritonAMDGPUAccelerateMatmulPass> { +public: + TritonAMDGPUAccelerateMatmulPass() = default; + TritonAMDGPUAccelerateMatmulPass(int matrixCoreVersion, + int matrixInstructionSize) { + this->matrixCoreVersion = matrixCoreVersion; + 
this->matrixInstructionSize = matrixInstructionSize; + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp m = getOperation(); + + mlir::RewritePatternSet patterns(context); + if (matrixCoreVersion == 1 || matrixCoreVersion == 2) + patterns.add<::BlockedToMFMA>(context, matrixCoreVersion, + matrixInstructionSize); + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { + signalPassFailure(); + } + } +}; + +std::unique_ptr +mlir::createTritonAMDGPUAccelerateMatmulPass(int matrixCoreVersion, + int matrixInstructionSize) { + return std::make_unique( + matrixCoreVersion, matrixInstructionSize); +} diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index dcab1d44c3b6..b8060cd6ce0c 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -75,46 +75,6 @@ warpsPerTileV2(tt::DotOp dotOp, const ArrayRef shape, int numWarps) { return ret; } -#ifdef USE_ROCM -SmallVector warpsPerTileMI200(tt::DotOp dotOp, - const ArrayRef shape, - int numWarps) { - // TODO: needs to be updated with appropriate shapePerWarp etc. - auto filter = [&dotOp](Operation *op) { - return op->getParentRegion() == dotOp->getParentRegion(); - }; - auto slices = mlir::getSlice(dotOp, filter); - for (Operation *op : slices) - if (isa(op) && (op != dotOp)) - return {(unsigned)numWarps, 1}; - - SmallVector tensorShape = {shape[0], shape[1]}; - SmallVector ret = {1, 1}; - SmallVector shapePerWarp = {32, 32}; - bool changed = false; - - do { - changed = false; - if (ret[0] * ret[1] >= numWarps) - break; - if (tensorShape[0] / (shapePerWarp[0] *2 ) / ret[0] >= - tensorShape[1] / shapePerWarp[1] / ret[1]) { - if (ret[0] < tensorShape[0] / shapePerWarp[0]) { - ret[0] *= 2; - } else - ret[1] *= 2; - } else { - ret[1] *= 2; - } - } while (true); - - if (ret[1] * shapePerWarp[1] > tensorShape[1]) { - return {ret[1], ret[0]}; - } - - return ret; -} - SmallVector warpsPerTileV3(tt::DotOp dotOp, const ArrayRef shape, int numWarps, const SmallVector &instrShape) { @@ -139,161 +99,6 @@ warpsPerTileV3(tt::DotOp dotOp, const ArrayRef shape, int numWarps, return ret; } -class BlockedToMFMA : public mlir::RewritePattern { - int mfmaVersion; -public: - BlockedToMFMA(mlir::MLIRContext *context, int mfmaVersion) - : mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context), mfmaVersion(mfmaVersion) {} - - bool isChainDot(tt::DotOp &dotOp) const { - auto filter = [&dotOp](Operation *op) { - return op->getParentRegion() == dotOp->getParentRegion(); - }; - auto slices = mlir::getSlice(dotOp, filter); - for (Operation *op : slices) { - if (isa(op) && (op != dotOp)) - return true; - } - return false; - } - - /// @brief Choose MFMA instruction parameters - /// @param dot target dot operation - /// @param mfmaVersion - /// @param nonKDim - /// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments - std::pair chooseMfmaDimensions(tt::DotOp dot, - int mfmaVersion, - int64_t nonKDim) const { - // number of matrix elements along k dim per one MFMA intruction - int64_t kDim = -1; - auto opType = dot.getA().getType().cast(); - auto elemType = opType.getElementType(); - if (nonKDim == 32) { - if (elemType.isF32()) - kDim = 2; - if (elemType.isF16()) - kDim = 8; - if (elemType.isBF16()) { - if (mfmaVersion == 1) - kDim = 4; - if (mfmaVersion == 2) - kDim = 8; - } - if (elemType.isInteger(8)) - kDim = 8; - } else { - if (elemType.isF32()) - kDim = 4; - if 
(elemType.isF16()) - kDim = 16; - if (elemType.isBF16()) { - if (mfmaVersion == 1) - kDim = 8; - if (mfmaVersion == 2) - kDim = 16; - } - if (elemType.isInteger(8)) - kDim = 16; - } - assert(kDim != -1); - return {nonKDim, kDim}; - } - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto dotOp = cast(op); - - auto oldRetType = dotOp.getResult().getType().cast(); - if (!oldRetType.getEncoding() || - !oldRetType.getEncoding().isa()) - return failure(); - - // TODO replace with nonKDim with some heuristic in chooseMfmaDimensions - // function - int64_t externalNonKDim = 32; - - const char *mfmaType = std::getenv("MFMA_TYPE"); - if (mfmaType) { - externalNonKDim = std::stol(mfmaType); - assert(externalNonKDim == 32 || externalNonKDim == 16); - } - - if (!supportMFMA(dotOp, externalNonKDim)) - return failure(); - - auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding()); - - // get MFMA encoding for the given number of warps - auto retShape = oldRetType.getShape(); - auto mod = op->getParentOfType(); - int numWarps = ttg::TritonGPUDialect::getNumWarps(mod); - - // operands - Value a = dotOp.getA(); - Value b = dotOp.getB(); - auto oldAType = a.getType().cast(); - auto oldBType = b.getType().cast(); - auto ctx = oldAType.getContext(); - - ttg::MfmaEncodingAttr mfmaEnc; - - auto [nonKDim, kDim] = - chooseMfmaDimensions(dotOp, mfmaVersion, externalNonKDim); - - auto warpsPerTile = warpsPerTileMI200(dotOp, retShape, numWarps); - - bool isTransposed = isChainDot(dotOp); - mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim, - warpsPerTile, isTransposed, CTALayout); - - auto newRetType = - RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc); - - // convert accumulator - auto oldAcc = dotOp.getOperand(2); - auto newAcc = rewriter.create( - oldAcc.getLoc(), newRetType, oldAcc); - auto oldAOrder = oldAType.getEncoding() - .cast() - .getParent() - .cast() - .getOrder(); - auto oldBOrder = oldBType.getEncoding() - .cast() - .getParent() - .cast() - .getOrder(); - - // kWidth is a number of consecutive elements per one instruction per one thread - auto kWidth = kDim; - // in mfma 32x32 case argument matrix groups elements in 2 groups - // in mfma 16x16 case argument matrix groups elements in 4 groups - if (nonKDim == 32) { - kWidth /= 2; - } else { - assert(nonKDim == 16); - kWidth /= 4; - } - auto newAType = RankedTensorType::get( - oldAType.getShape(), oldAType.getElementType(), - ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth)); - auto newBType = RankedTensorType::get( - oldBType.getShape(), oldBType.getElementType(), - ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth)); - a = rewriter.create(a.getLoc(), newAType, a); - b = rewriter.create(b.getLoc(), newBType, b); - auto newDot = rewriter.create( - dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32()); - - rewriter.replaceOpWithNewOp( - op, oldRetType, newDot.getResult()); - return success(); - } -}; -#endif - class BlockedToMMA : public mlir::RewritePattern { int computeCapability; mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding @@ -535,14 +340,7 @@ class TritonGPUAccelerateMatmulPass ModuleOp m = getOperation(); mlir::RewritePatternSet patterns(context); -#ifdef USE_ROCM - if (computeCapability == 1 || computeCapability == 2) { - int mfmaVersion = computeCapability; - patterns.add<::BlockedToMFMA>(context, mfmaVersion); - } -#else patterns.add<::BlockedToMMA>(context, computeCapability); -#endif 
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) { signalPassFailure(); } diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 8a2342d3aca6..fab606807184 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(TritonGPUTransforms AccelerateMatmul.cpp + AccelerateAMDMatmul.cpp Coalesce.cpp DecomposeConversions.cpp OptimizeDotOperands.cpp diff --git a/python/src/triton.cc b/python/src/triton.cc index 29276493e8cd..27051ed84ebb 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1849,6 +1849,11 @@ void init_triton_ir(py::module &&m) { self.addPass( mlir::createTritonGPUAccelerateMatmulPass(computeCapability)); }) + .def("add_tritonamdgpu_accelerate_matmul_pass", + [](mlir::PassManager &self, int tensorCoreVersion, int instrSize) { + self.addPass(mlir::createTritonAMDGPUAccelerateMatmulPass( + tensorCoreVersion, instrSize)); + }) .def("add_tritongpu_optimize_dot_operands_pass", [](mlir::PassManager &self) { self.addPass(mlir::createTritonGPUOptimizeDotOperandsPass()); diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py index 8add0a85e289..cf7a3bc2953c 100644 --- a/python/test/unit/language/test_core_amd.py +++ b/python/test/unit/language/test_core_amd.py @@ -1455,9 +1455,9 @@ def kernel(X, stride_xm, stride_xn, # --------------- -@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype", +@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim", # FMA Test Dot tests - [(*shape, 4, False, False, epilogue, allow_tf32, in_dtype, out_dtype) + [(*shape, 4, False, False, epilogue, allow_tf32, in_dtype, out_dtype, 0) for shape in [(64, 64, 64), (16, 16, 16)] for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] for allow_tf32 in [True, False] @@ -1466,7 +1466,7 @@ def kernel(X, stride_xm, stride_xn, ('float32', 'float32')] if not (allow_tf32 and (in_dtype in ['float16']))] + - [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype) + [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype, 0) for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], @@ -1486,7 +1486,7 @@ def kernel(X, stride_xm, stride_xn, ('float32', 'float32')]] if triton.language.semantic.gpu_matrix_core_version() == 0 else # MFMA Test Dot tests - [(*shape, 2, False, False, epilogue, allow_tf32, in_dtype, out_dtype) + [(*shape, 2, False, False, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim) for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax'] for allow_tf32 in [True, False] @@ -1494,9 +1494,10 @@ def kernel(X, stride_xm, stride_xn, ('bfloat16', 'float32'), ('float16', 'float32'), ('float32', 'float32')] + for non_k_dim in [0, 16, 32] if not (allow_tf32 and (in_dtype in ['float16']))] + - [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype) + [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype, non_k_dim) for shape_nw in [[128, 128, 32, 2], [128, 16, 32, 4], [128, 128, 64, 2], @@ -1524,8 +1525,9 @@ def kernel(X, stride_xm, stride_xn, for col_a in [True, False] for col_b in [True, False] for in_dtype in ['int8', 'bfloat16', 'float16', 'float32'] - for out_dtype in [None]]) -def test_dot(M, N, K, num_warps, 
col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, device='cuda'): + for out_dtype in [None] + for non_k_dim in [0, 16, 32]]) +def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim, device='cuda'): capability = torch.cuda.get_device_capability() if torch.version.hip is not None: @@ -1539,6 +1541,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o out_dtype = "float32" else: out_dtype = "int32" + if non_k_dim == 32 and (M < 32 or N < 32): + pytest.skip("incompatible non_k_dim == 32 with MN sizes") if capability[0] < 7: pytest.skip("Only test tl.dot() on devices with sm >= 70") @@ -1652,7 +1656,8 @@ def kernel(X, stride_xm, stride_xk, DO_SOFTMAX=epilogue == 'softmax', CHAIN_DOT=epilogue == 'chain-dot', ALLOW_TF32=allow_tf32, - num_warps=num_warps) + num_warps=num_warps, + matrix_instr_nonkdim=non_k_dim) # torch result if in_dtype == 'int8': z_ref = np.matmul(x.astype(np.float32), @@ -1687,6 +1692,15 @@ def kernel(X, stride_xm, stride_xk, # added atol, to loose precision for float16xfloat16->float32 case np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) if torch.version.hip is not None: + import triton.language.semantic as sem + if sem.gpu_matrix_core_version() > 0: + ttgir = pgm.asm['ttgir'] + if non_k_dim == 16: + assert "#triton_gpu.mfma<{nonKDim = 16" in ttgir + assert "#triton_gpu.mfma<{nonKDim = 32" not in ttgir + elif non_k_dim == 32: + assert "#triton_gpu.mfma<{nonKDim = 32" in ttgir + assert "#triton_gpu.mfma<{nonKDim = 16" not in ttgir return # make sure ld/st are vectorized ptx = pgm.asm['ptx'] diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 0c93233dfe72..53fad9041ff9 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -85,7 +85,7 @@ def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, arch): def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch, - cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue): + cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_inst_type): pm = ir.pass_manager(mod.context) pm.enable_debug() pm.add_tritongpu_coalesce_pass() @@ -100,7 +100,8 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch, # TODO change interface of accelerate_matmul_pass if is_hip(): matrix_core_version = gpu_matrix_core_version() - pm.add_tritongpu_accelerate_matmul_pass(matrix_core_version) + matrix_inst_size = matrix_inst_type + pm.add_tritonamdgpu_accelerate_matmul_pass(matrix_core_version, matrix_inst_size) pm.add_tritongpu_remove_layout_conversions_pass() if optimize_epilogue: pm.add_tritongpu_optimize_epilogue_pass() @@ -310,6 +311,7 @@ def make_hash(fn, arch, env_vars, **kwargs): num_ctas = kwargs.get("num_ctas", 1) num_stages = kwargs.get("num_stages", 3) waves_per_eu = kwargs.get("waves_per_eu", 0) + matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0); enable_warp_specialization = kwargs.get("enable_warp_specialization", False) enable_persistent = kwargs.get("enable_persistent", False) debug = kwargs.get("debug", False) @@ -317,7 +319,7 @@ def make_hash(fn, arch, env_vars, **kwargs): get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8)) configs_key = [get_conf_key(conf) for conf in configs] env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())] - key = 
f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}" + key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}" return hashlib.md5(key.encode("utf-8")).hexdigest() assert isinstance(fn, str) return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest() @@ -479,6 +481,7 @@ def compile(fn, **kwargs): num_ctas = kwargs.get("num_ctas", 1) num_stages = kwargs.get("num_stages", get_arch_default_num_stages(device_type, capability=capability)) waves_per_eu = kwargs.get("waves_per_eu", 0) + matrix_instr_nonkdim = kwargs.get("matrix_instr_nonkdim", 0) # TODO[shuhaoj]: Default should be to enable warp specialization once possible enable_warp_specialization = kwargs.get("enable_warp_specialization", False) # TODO[shuhaoj]: persistent can be decoupled with warp specialization @@ -504,7 +507,7 @@ def compile(fn, **kwargs): stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch)) stages["ttgir"] = (lambda path: parse_mlir_module(path, context), - lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue)) + lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size, num_ctas, arch), num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_instr_nonkdim)) stages["llir"] = (lambda path: Path(path).read_text(), lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos, waves_per_eu)) if is_cuda: @@ -579,6 +582,7 @@ def compile(fn, **kwargs): "num_ctas": num_ctas, "num_stages": num_stages, "waves_per_eu": waves_per_eu, + "matrix_instr_nonkdim": matrix_instr_nonkdim, "enable_warp_specialization": enable_warp_specialization, "enable_persistent": enable_persistent, "constants": _get_jsonable_constants(constants), diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 43382636b4a7..c02ef5ec2dd9 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1279,18 +1279,16 @@ def gpu_matrix_core_version() -> int: return 0 def mfma_supported_granularity(m, n, k) -> bool: - granularity_mn = 32 - granularity_k = 8 - import os - if "MFMA_TYPE" in os.environ and os.environ["MFMA_TYPE"] == "16": - granularity_mn = 16 - granularity_k = 16 - - if m % granularity_mn != 0 or n % granularity_mn != 0: - return False - if k % granularity_k != 0: - return False - return True + # todo make this gran_type matrix element type sensitive + for gran_type in [(32, 8), (16, 16)]: + granularity_mn, granularity_k = gran_type + + if m % granularity_mn != 0 or n % granularity_mn != 0: + continue + if k % granularity_k != 0: + continue + return True + return False def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool: matrix_core_version = gpu_matrix_core_version() diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 6dc9de5a7c60..74390491c8e9 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -276,13 
+276,13 @@ def _make_constants(self, constexpr_key): constants = dict(zip(self.constexprs, constexpr_key)) return constants - def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs): + def _call_hook(self, key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs): if JITFunction.cache_hook is None: return False name = self.fn.__name__ module = self.fn.__module__ arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])]) - repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})" + repr = f"{name}[num_warps={num_warps}, num_ctas={num_ctas}, num_stages={num_stages}, waves_per_eu={waves_per_eu}, matrix_instr_nonkdim={matrix_instr_nonkdim}, enable_warp_specialization={enable_warp_specialization}]({arg_reprs})" key = str(key) class LegacyCompiler: @@ -364,7 +364,7 @@ def _make_launcher(self): src = f""" import triton -def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None): +def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None): from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()} constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()} @@ -406,7 +406,7 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu if num_stages is None: num_stages = get_arch_default_num_stages(device_type) - key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, self.debug) + key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, self.debug) if not extern_libs is None: key = (key, tuple(extern_libs.items())) @@ -434,8 +434,8 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu for i, arg in constants.items(): if callable(arg): raise TypeError(f"Callable constexpr at index {{i}} is not supported") - if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, enable_warp_specialization, extern_libs, configs): - bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type) + if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, waves_per_eu, matrix_instr_nonkdim, enable_warp_specialization, extern_libs, configs): + bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=matrix_instr_nonkdim, 
enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type) # Create tensormaps and append to args args = bin.assemble_tensormap_to_arg(args) if not warmup: From 3b1c273cd321a38a60697603919cd9204bda61ac Mon Sep 17 00:00:00 2001 From: Aleksandr Efimov Date: Tue, 17 Oct 2023 21:01:04 +0000 Subject: [PATCH 2/3] add license in AccelerateMatmul --- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index b8060cd6ce0c..a61faa356416 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -1,3 +1,25 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files + * (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, + * publish, distribute, sublicense, and/or sell copies of the Software, + * and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" From c0a06647aacf8e36911ed07258f6c2f66cc5f9d2 Mon Sep 17 00:00:00 2001 From: Aleksandr Efimov Date: Tue, 17 Oct 2023 21:01:20 +0000 Subject: [PATCH 3/3] review fix --- include/triton/Dialect/TritonGPU/Transforms/Passes.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index b30e897207e3..ad6720616b40 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -102,9 +102,9 @@ def TritonAMDGPUAccelerateMatmul : Pass<"tritonamdgpu-accelerate-matmul", "mlir: Option<"matrixCoreVersion", "matrix-core-version", "int32_t", /*default*/"0", "device matrix core version">, - Option<"matrixInstructionSize", "matrix-instructio-size", + Option<"matrixInstructionSize", "matrix-instruction-size", "int32_t", /*default*/"0", - "enforce matrix intrucion MN size"> + "enforce matrix instruction MN size"> ]; }
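
---

Usage note (not part of the patches above): the series exposes the MFMA shape choice as a `matrix_instr_nonkdim` launch keyword, which flows from `jit.py` through `compile()` into the new `tritonamdgpu-accelerate-matmul` pass. The sketch below shows how it could be exercised on an MFMA-capable ROCm device, mirroring the TTGIR check added in `test_core_amd.py`; the kernel, tensor names, and shapes are illustrative, not taken from the patch.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def dot_kernel(x_ptr, y_ptr, z_ptr,
               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # Load one BLOCK_M x BLOCK_K tile of x and one BLOCK_K x BLOCK_N tile of y
    # (contiguous row-major tensors, so the row strides are BLOCK_K and BLOCK_N).
    x = tl.load(x_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    y = tl.load(y_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    z = tl.dot(x, y)
    tl.store(z_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], z)


M, N, K = 64, 64, 64
x = torch.randn((M, K), device='cuda', dtype=torch.float16)
y = torch.randn((K, N), device='cuda', dtype=torch.float16)
z = torch.empty((M, N), device='cuda', dtype=torch.float32)

# matrix_instr_nonkdim=16 requests MFMA 16x16, 32 requests MFMA 32x32,
# and 0 (the default) lets the pass pick based on the output tile shape.
pgm = dot_kernel[(1,)](x, y, z, BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
                       num_warps=2, matrix_instr_nonkdim=16)
assert "#triton_gpu.mfma<{nonKDim = 16" in pgm.asm['ttgir']
```

As in the updated test, MFMA 32x32 only applies when both result dimensions are at least 32, so forcing `matrix_instr_nonkdim=32` on smaller tiles is rejected (the test skips that combination), while the default of 0 falls back to 16x16 for such shapes.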