Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MFMA] Implement MFMA 4x64 v3 #550

Draft
wants to merge 8 commits into
base: triton-mlir
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ compared to 1*64 when the hasLeadingOffset is false.

if (mfmaEnc) {
int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0;
int nonKDimNum = 1 - kDimNum;
if (needTrans)
kDimNum = 1 - kDimNum;
bool isKDimInner = (order[0] == kDimNum);
Expand All @@ -143,17 +144,51 @@ compared to 1*64 when the hasLeadingOffset is false.
int innerDimLength = shape[order[0]];
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit;

int mDim = mfmaEnc.getMDim();
int nDim = mfmaEnc.getNDim();
int nonKDim = dotOpEnc.getOpIdx() == 0 ? mDim : nDim;
if ((mDim == 4 && nDim == 64) || (nDim == 4 && mDim == 64)) {
// Operands of the layout have the following shapes:
// Large operand:
// - shape 64(non-k)x64(k) for 16 bit dtypes
// - shape 64(non-k)x16(k) for 32 bit dtypes
// Small operand:
// - shape 4(non-k)x64(k) for 16 bit dtypes
// - shape 4(non-k)x16(k) for 32 bit dtypes
int vecSize = bankBitWidth / typeWidthInBit;
if (nonKDim == 64 && vecSize < dotOpEnc.getKWidth()) {
// This heuristic introduces bank conflicts,
// but saves address computation overhead and
// reduces number of ds_read/ds_write instructions

// cannot swizzle in vectors of 4 for fp32,
// because mfma 4x64 and 64x4 shift loads by 1*(laneId/4) in the k dimension, breaking the borders of one load
vecSize = typeWidthInBit == 32 ? 1 : 4;
}
const int bankRowInBits = numBanks * bankBitWidth;
const int innerDimInBits = innerDimLength * typeWidthInBit;
const int perPhase = std::max(1, bankRowInBits / innerDimInBits);
const int phaseSanityLimit = std::min(numBanks, innerDimLength / vecSize);
const int maxPhase = std::min(phaseSanityLimit, nonKDim) / perPhase;
return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
}
int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
// vecSize is set to kWidth of the dotop layout
int vecSize = dotOpEnc.getKWidth();
// maxPhase is set to SIMDWidth / perPhase
int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize);
// TODO (zhanglx): figure out better parameters for mfma4
auto mDim = mfmaEnc.getMDim();
auto nDim = mfmaEnc.getNDim();
auto nonKDim = dotOpEnc.getOpIdx() == 0 ? mDim : nDim;
if (4 == nonKDim)
maxPhase = 4;
// if maxPhase * perPhase is larger than one block of warps,
// fallback to unswizzled tensor.
// Shared to dot op conversion requires that swizzling pattern
// fits into one block of warps.
auto warpsPerCTA = mfmaEnc.getWarpsPerCTA();
if (maxPhase * perPhase > nonKDim * warpsPerCTA[nonKDimNum]) {
assert(isKDimInner);
maxPhase = 1;
}
assert(maxPhase > 0);

return get(context, vecSize, perPhase, maxPhase, order, CTALayout);
Expand Down
6 changes: 4 additions & 2 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ struct MfmaInsnAttr {
unsigned n;
unsigned k;
// k_base refers to the number of elements per thread
unsigned k_base;
unsigned k_base_a;
unsigned k_base_b;
llvm::StringRef insn;
};

Expand Down Expand Up @@ -223,7 +224,8 @@ class MfmaInsn {
unsigned getMDim();
unsigned getNDim();
StringRef getInsnName();
unsigned getKBase();
unsigned getKBaseA();
unsigned getKBaseB();
};
} // namespace mlir

Expand Down
3 changes: 2 additions & 1 deletion lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getKWidth() == 4 &&
dotOperandLayout.getParent() == mfmaLayout &&
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16 ||
(mfmaLayout.getMDim() == 4 && mfmaLayout.getNDim() == 64)) &&
mfmaLayout.getIsTransposed() &&
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,19 +158,26 @@ llvm::SmallVector<llvm::SmallVector<Value>> computeTensorElemMappingInBlock(
if (iNonKDim == 32)
laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
else {
// In this configuration wave contains 16 copies of same data
if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) {
// shortcut for 64x64 tile size.
// In this case the warp does not wrap, so there is no need to introduce this offset
if (iNonKDim == 64)
laneHOffset = i32_val(0);
} else {
assert(iKDim * iNonKDim / numOfElems == 64 &&
"seems no all threads in wave contain unique elements");
else
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
}
}

for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
Value elemVOffset = _0;
Value elemHOffset = i32_val(loadId * loadVecSize);
Value elemHOffset;
if (iNonKDim == 64) {
Value groupId = udiv(laneId, i32_val(4));
Value groupShift = mul(groupId, i32_val(iKDim / 16));
Value loadShift = add(groupShift, i32_val(loadId * loadVecSize));
Value wrappedLoadShift = urem(loadShift, i32_val(iKDim));
elemHOffset = wrappedLoadShift;
} else {
elemHOffset = i32_val(loadId * loadVecSize);
}

Value sliceVOffset =
add(add(add(tileVOffset, laneVOffset), elemVOffset), waveVOffset);
Expand Down Expand Up @@ -328,33 +335,44 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc,
Value waveOffset = mul(waveId, i32_val(iNonKDim));
Value colOffset = urem(laneId, _nonKDim);

// halfWaveOffset is an offset related to the wrapping of the wave in the tile.
// for example, mfma 32 case (mapping of tensor elements to lane ids in
// wave):
//
// 0 1 2 3 ... 31
// 0 1 2 3 ... 31
// 0 1 2 3 ... 31
// 0 1 2 3 ... 31
// 32 33 34 35 ... 63 <- at this point wave is wrapping
// 32 33 34 35 ... 63
// 32 33 34 35 ... 63
// 32 33 34 35 ... 63
Value halfWaveOffset;
if (iNonKDim == 64)
halfWaveOffset = i32_val(0);
else
halfWaveOffset =
mul(udiv(laneId, _nonKDim), i32_val(numOfElems * lineSize));

// sum of the offsets that depend on lane id and warp id across the non-k dim
Value baseThreadOffset = add(add(halfWaveOffset, colOffset), waveOffset);

for (int block = 0; block < numN; ++block) {
Value blockOffset = i32_val(block * iNonKDim * warpsPerBlock);
int blockOffset = block * iNonKDim * warpsPerBlock;
for (int tile = 0; tile < numK; ++tile) {
Value tileOffset = i32_val(tile * iKDim * lineSize);
int tileOffset = tile * iKDim * lineSize;
for (int elem = 0; elem < numOfElems; ++elem) {
// halfOffset is an offset related to wrapping of wave in the tile.
// for example, mfma 32 case (mapping of tensor elements to lane ids in
// wave):
//
// 0 1 2 3 ... 31
// 0 1 2 3 ... 31
// 0 1 2 3 ... 31
// 0 1 2 3 ... 31
// 32 33 34 35 ... 63 <- at this point wave is wrapping
// 32 33 34 35 ... 63
// 32 33 34 35 ... 63
// 32 33 34 35 ... 63
Value halfOffset;
if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4)
halfOffset = i32_val(0);
else
halfOffset =
mul(udiv(laneId, _nonKDim), i32_val(numOfElems * lineSize));
Value rowOffset = add(i32_val(elem * lineSize), halfOffset);
Value elemOffset = add(rowOffset, colOffset);
Value offset =
add(add(add(waveOffset, blockOffset), tileOffset), elemOffset);
Value rowOffset;
if (iNonKDim == 64) {
Value groupId = udiv(laneId, i32_val(4));
Value groupShift = mul(groupId, i32_val(iKDim / 16));
Value elemShift = add(groupShift, i32_val(elem));
Value wrappedElemShift = urem(elemShift, i32_val(iKDim));
rowOffset = mul(wrappedElemShift, i32_val(lineSize));
} else {
rowOffset = i32_val(elem * lineSize);
}
Value offset = add(add(baseThreadOffset, rowOffset), i32_val(blockOffset + tileOffset));
offsets[numK * numOfElems * block + numOfElems * tile + elem] = offset;
}
}
Expand Down Expand Up @@ -456,6 +474,8 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
int numSubBlocks = 1;
if ((mfmaInstrK == 4 || mfmaInstrK == 1) && mfmaInstrNonK == 4)
numSubBlocks = 16;
assert(numSubBlocks == 1 &&
"after reworking layout, there should be no redundency");
int numOfElems = mfmaInstrNonK * mfmaInstrK * numSubBlocks / iWaveSize;
assert(numOfElems >= 1);

Expand Down
Loading