[Triton] Mfma16 support (#251)
* [MFMA] Support mfma with MN size 16

This PR adds code for emitting MFMA instructions with size 16.

* Add control over mfma type with the MFMA_TYPE=16 env var
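Note: the hunks shown below do not include the env-var handling itself. A minimal sketch of how such a switch could be read, assuming MFMA_TYPE=16 selects the 16x16 path and anything else falls back to the default 32x32 path (the helper name and default are illustrative, not code from this commit):

// Hypothetical helper, not taken from this PR: pick the MFMA non-K dimension
// from the MFMA_TYPE environment variable, defaulting to 32.
#include <cstdlib>
#include <string>

static int getMfmaNonKDimFromEnv() {
  const char *env = std::getenv("MFMA_TYPE"); // e.g. export MFMA_TYPE=16
  if (env != nullptr && std::string(env) == "16")
    return 16; // emit MFMA instructions with non-K size 16
  return 32;   // default: non-K size 32
}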
binarman authored Oct 9, 2023
1 parent e801638 commit 7e34c24
Showing 11 changed files with 251 additions and 106 deletions.
2 changes: 1 addition & 1 deletion include/triton/Analysis/Utility.h
@@ -115,7 +115,7 @@ bool maybeSharedAllocationOp(Operation *op);
bool maybeAliasOp(Operation *op);

#ifdef USE_ROCM
bool supportMFMA(triton::DotOp op);
bool supportMFMA(triton::DotOp op, int64_t nonKDim);
#endif

bool supportMMA(triton::DotOp op, int version);
12 changes: 6 additions & 6 deletions lib/Analysis/Utility.cpp
@@ -378,18 +378,18 @@ bool supportMMA(triton::DotOp op, int version) {
}

#ifdef USE_ROCM
static bool supportMFMAGranularity(int m, int n, int k) {
static bool supportMFMAGranularity(int m, int n, int k, int64_t nonKDim) {
// these limitations are dtype dependent, in future we may relax them
const int granularityMN = 32;
const int granularityK = 8;
const int granularityMN = nonKDim;
const int granularityK = nonKDim == 32 ? 8 : 16;
if (m % granularityMN != 0 || n % granularityMN != 0)
return false;
if (k % granularityK != 0)
return false;
return true;
}

bool supportMFMA(triton::DotOp op) {
bool supportMFMA(triton::DotOp op, int64_t nonKDim) {
auto aTy = op.getA().getType().cast<RankedTensorType>();
auto bTy = op.getB().getType().cast<RankedTensorType>();

@@ -403,7 +403,7 @@ bool supportMFMA(triton::DotOp op) {
auto bShape = bTy.getShape();

assert(aShape[1] == bShape[0]);
if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1]))
if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1], nonKDim))
return false;

return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() ||
@@ -455,7 +455,7 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getKWidth() == 4 &&
dotOperandLayout.getParent() == mfmaLayout &&
mfmaLayout.getIsTransposed() &&
mfmaLayout.getNonKDim() == 32 && mfmaLayout.getIsTransposed() &&
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}
#endif
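For reference, the updated granularity rule can be restated outside of MLIR: with nonKDim == 32 the dot shapes must satisfy M % 32 == 0, N % 32 == 0 and K % 8 == 0, while with nonKDim == 16 they must satisfy M % 16 == 0, N % 16 == 0 and K % 16 == 0. A self-contained sketch of that check (illustrative names, not part of the patch):

#include <cassert>

// Same shape rule as the updated supportMFMAGranularity, written with plain ints.
static bool fitsMfmaGranularity(int m, int n, int k, int nonKDim) {
  const int granularityMN = nonKDim;
  const int granularityK = (nonKDim == 32) ? 8 : 16;
  return m % granularityMN == 0 && n % granularityMN == 0 &&
         k % granularityK == 0;
}

int main() {
  assert(fitsMfmaGranularity(64, 64, 32, 32));  // fine for the 32x32 path
  assert(!fitsMfmaGranularity(16, 16, 16, 32)); // too small for the 32x32 path
  assert(fitsMfmaGranularity(16, 16, 16, 16));  // now accepted with nonKDim 16
  assert(!fitsMfmaGranularity(16, 16, 8, 16));  // K must be a multiple of 16
  return 0;
}

Note also that isMfmaToDotShortcut is now restricted to nonKDim == 32 in the hunk above, so the mfma-to-dot-operand shortcut is not taken for the 16x16 layout.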
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -257,7 +257,8 @@ struct ConvertLayoutOpConversion
SmallVector<SmallVector<unsigned>> offsets;
assert(rank == 2);
SmallVector<Value> multiDimOffset(rank);
emitMfmaOffsetForCTA(mfmaLayout, offsets, multiDimCTAInRepId[0], multiDimCTAInRepId[1]);
emitMfmaOffsetForCTA(mfmaLayout, offsets, multiDimCTAInRepId[0],
multiDimCTAInRepId[1]);
multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0]));
multiDimOffset[1] = add(multiDimBase[1], i32_val(offsets[elemId][1]));
return multiDimOffset;
@@ -107,6 +107,7 @@ swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
* @param reps number of instruction repetitions to fully cover dot operand
* @param smemStrides strides in LDS tensor
* @param loadVecSize number of elements loaded by one operation
* @param iNonKDim non-K dimension of dot operand
* @return vector (i-th element corresponds to i-th load instruction) of
* 2-element vectors(tensor row and col).
*/
@@ -115,7 +116,7 @@
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, ArrayRef<Value> smemOffsets,
int loadVecSize) {
int loadVecSize, unsigned iNonKDim) {
auto numM = reps[0];
auto numK = reps[1];
const int loadsPerThread = numOfElems / loadVecSize;
@@ -124,6 +125,7 @@ computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc,

Value _0 = i32_val(0);
Value _32 = i32_val(32);
Value nonKDim = i32_val(iNonKDim);

for (int block = 0; block < numM; ++block) {
Value blockVOffset = i32_val(block * elemsPerInstr[0] * warpsPerGroup);
@@ -134,8 +136,13 @@
Value tileVOffset = _0;
Value tileHOffset = i32_val(tile * elemsPerInstr[1]);

Value laneVOffset = urem(laneId, _32);
Value laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
Value laneVOffset = urem(laneId, nonKDim);
Value laneHOffset;
if (iNonKDim == 32)
laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
else
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));

for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
Value elemVOffset = _0;
Value elemHOffset = i32_val(loadId * loadVecSize);
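The interesting part of this hunk is the lane-to-tile mapping: with a 64-lane wave, nonKDim == 32 splits the wave into two halves placed side by side along K, while nonKDim == 16 splits it into four groups of 16 lanes. A plain-integer restatement of the new offsets (numOfElems is taken as 4 purely for illustration):

#include <cstdio>

// Mirrors the new laneVOffset/laneHOffset logic with ordinary integers.
static void laneOffsets(int laneId, int nonKDim, int numOfElems,
                        int &vOff, int &hOff) {
  vOff = laneId % nonKDim;                  // row within the non-K dimension
  if (nonKDim == 32)
    hOff = (laneId >= 32) ? numOfElems : 0; // second half-wave shifts along K
  else
    hOff = (laneId / nonKDim) * numOfElems; // lane groups of 16 shift along K
}

int main() {
  int v, h;
  laneOffsets(17, 16, 4, v, h);
  std::printf("nonKDim=16, lane 17 -> row %d, col %d\n", v, h); // row 1, col 4
  laneOffsets(17, 32, 4, v, h);
  std::printf("nonKDim=32, lane 17 -> row %d, col %d\n", v, h); // row 17, col 0
  return 0;
}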
@@ -176,7 +183,7 @@ computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
SharedEncodingAttr srcLayout) {
SharedEncodingAttr srcLayout, unsigned nonKDim) {
SmallVector<Value> strides{smemObj.strides[0], smemObj.strides[1]};
SmallVector<Value> offsets{smemObj.offsets[0], smemObj.offsets[1]};

@@ -190,7 +197,7 @@ computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc,

auto mapping = computeTensorElemMapping(rewriter, loc, elemsPerInstr, waveId,
laneId, warpsPerGroup, numOfElems,
reps, offsets, vectorSize);
reps, offsets, vectorSize, nonKDim);
llvm::SmallVector<Value> aOffsets(mapping.size());
for (int i = 0; i < mapping.size(); ++i) {
Value row = mapping[i][0];
@@ -205,7 +212,7 @@ computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
SharedEncodingAttr srcLayout) {
SharedEncodingAttr srcLayout, unsigned nonKDim) {
// transpose reps and offsets, because operand B has layout equal to
// transposed operand A layout
SmallVector<int64_t> tElemsPerInstr{elemsPerInstr[1], elemsPerInstr[0]};
@@ -222,7 +229,7 @@ computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc,

auto mapping = computeTensorElemMapping(rewriter, loc, tElemsPerInstr, waveId,
laneId, warpsPerGroup, numOfElems,
tReps, toffsets, vectorSize);
tReps, toffsets, vectorSize, nonKDim);
llvm::SmallVector<Value> bOffsets(mapping.size());
for (int i = 0; i < mapping.size(); ++i) {
// swap row and col, because operand B layout is a transposed operand A
@@ -411,7 +418,8 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
TritonGPUToLLVMTypeConverter *typeConverter, Value tensor,
const SharedMemoryObject &smemObj) {
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
assert(mfmaLayout.getNonKDim() == 32);
auto nonKDim = mfmaLayout.getNonKDim();
assert(nonKDim == 32 || nonKDim == 16);
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();

auto aTensorTy = tensor.getType().cast<RankedTensorType>();
@@ -430,14 +438,15 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
auto numRepK = numReps[1];

unsigned iWaveSize = triton::gpu::getWarpSize(mfmaLayout);
assert(iWaveSize == 64);
Value waveSize = i32_val(iWaveSize);
Value wave = udiv(thread, waveSize);
Value lane = urem(thread, waveSize);

Value waveM =
getWaveM(rewriter, loc, wave, warpsPerCTA, mfmaInstrM, shape[0]);
int numOfElems =
std::max<int>(mfmaInstrM * mfmaInstrK / iWaveSize /*wave size*/, 1);
int numOfElems = mfmaInstrM * mfmaInstrK / iWaveSize;
assert(numOfElems >= 1);
unsigned int maxNumWarps = shape[0] / mfmaInstrM;
int warpsPerGroupM = std::min(warpsPerCTA[0], maxNumWarps);
aElemTy = typeConverter->convertType(aElemTy);
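A quick sanity check of the numOfElems simplification: assuming the usual fp16 instruction shapes (32x32x8 and 16x16x16), a 64-lane wave always gets at least one operand element per lane, so the old std::max(..., 1) clamp can become an assert. Illustrative arithmetic only:

#include <cassert>

int main() {
  const int waveSize = 64;                // asserted above: iWaveSize == 64
  const int elemsA32 = 32 * 8 / waveSize; // A-operand elems per lane, 32x32x8 -> 4
  const int elemsA16 = 16 * 16 / waveSize; // A-operand elems per lane, 16x16x16 -> 4
  assert(elemsA32 == 4 && elemsA16 == 4);
  return 0;
}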
@@ -498,7 +507,7 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
} else { // normal path
SmallVector<Value> offsets = computeOffsetsAType(
rewriter, loc, aElemsPerInstr, waveM, lane, warpsPerGroupM, numOfElems,
numReps, smemObj, sharedLayout);
numReps, smemObj, sharedLayout, nonKDim);

Value smemBase = computeBasePtr(rewriter, loc, smemObj);
Type resElemTy = typeConverter->convertType(aElemTy);
@@ -551,7 +560,8 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
TritonGPUToLLVMTypeConverter *typeConverter, Value tensor,
const SharedMemoryObject &smemObj) {
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
assert(mfmaLayout.getNonKDim() == 32);
auto nonKDim = mfmaLayout.getNonKDim();
assert(nonKDim == 32 || nonKDim == 16);
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();

auto bTensorTy = tensor.getType().cast<RankedTensorType>();
@@ -569,20 +579,15 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
auto numRepN = numReps[1];

unsigned iWaveSize = triton::gpu::getWarpSize(mfmaLayout);
assert(iWaveSize == 64);
Value waveSize = i32_val(iWaveSize);
Value wave = udiv(thread, waveSize);
Value lane = urem(thread, waveSize);

Value waveN =
getWaveN(rewriter, loc, wave, warpsPerCTA, mfmaInstrN, shape[1]);
int numOfElems =
std::max<int>(mfmaInstrK * mfmaInstrN / iWaveSize /*wave size*/, 1);

int macroTileM = std::max<int>(shape[0] / (warpsPerCTA[0] * 32), 1);
int wptM = std::min<int>(warpsPerCTA[0], macroTileM);
int macroTileN = std::max<int>(shape[1] / (warpsPerCTA[1] * 32), 1);
int wptN = std::min<int>(warpsPerCTA[1], macroTileN);
int wpt = std::max<int>(wptM, wptN);
int numOfElems = mfmaInstrK * mfmaInstrN / iWaveSize;
assert(numOfElems >= 1);

unsigned int maxNumWarps = shape[1] / mfmaInstrN;
int warpsPerGroupN = std::min(warpsPerCTA[1], maxNumWarps);
@@ -648,7 +653,7 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
} else { // normal path
llvm::SmallVector<Value> offsets = computeOffsetsBType(
rewriter, loc, bElemsPerInstr, waveN, lane, warpsPerGroupN, numOfElems,
numReps, smemObj, sharedLayout);
numReps, smemObj, sharedLayout, nonKDim);

Value smemBase = computeBasePtr(rewriter, loc, smemObj);
Type resElemTy = typeConverter->convertType(bElemTy);
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp
@@ -82,7 +82,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
.cast<RankedTensorType>()
.getEncoding()
.dyn_cast<MfmaEncodingAttr>();
if (!isOuter && mfmaLayout && supportMFMA(op)) {
if (!isOuter && mfmaLayout && supportMFMA(op, mfmaLayout.getNonKDim())) {
return convertMFMA(op, adaptor, getTypeConverter(), rewriter);
}
#endif

3 comments on commit 7e34c24

@lzxdn commented on 7e34c24 Oct 11, 2023


I tested matrix multiplication with mfma_16x16x4 and found duplicate calculations.

@zhanglx13


@lzxdn For now we have only enabled mfma_16x16x16. Using mfma_16x16x4 requires more work, and its performance is not as good as mfma_16x16x16.

@binarman (Author)


@lzxdn
Hi!
Could you elaborate on what you mean by mfma_16x16x4:

  • Did you use fp32xfp32 -> fp32 multiplication?
  • What matrix sizes did you use?
  • Can you point out which computations are duplicated?
