[AMD] Refactor SharedToDotOperandMFMA (#533)
- remove unused functions and variables
- replace "group" with block to unify language
- add check for swizzling pattern compatibility with normal path
- unify normal and fast path index computation (see the sketch below)
binarman authored Mar 12, 2024
1 parent 2d501e4 commit bcde44f
Showing 1 changed file with 100 additions and 98 deletions.
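
The heart of the refactor is visible in computeOffsetsAType/BType below: the element mapping is computed once for the first wave-block, and every other block reuses it shifted by a constant amount along the non-K dimension. A minimal standalone sketch of that pattern, with plain integers standing in for MLIR Values (all names and numbers are illustrative, not taken from the diff):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int nonKDim = 32;         // MFMA instruction size along the non-K dim
  const int warpsPerBlock = 2;    // warps covering the non-K dim in one block
  const int numBlocks = 2;        // repetitions along the non-K dim
  const int64_t nonKStride = 128; // LDS stride of the non-K dimension

  // Per-thread element offsets inside the first wave-block; in the real
  // code these come from computeTensorElemMappingInBlock().
  const std::vector<int64_t> firstBlockOffsets = {0, 1, 2, 3};

  std::vector<int64_t> offsets;
  for (int block = 0; block < numBlocks; ++block) {
    // Constant shift of this wave-block along the non-K dimension.
    const int64_t blockShift =
        int64_t(block) * nonKDim * warpsPerBlock * nonKStride;
    for (int64_t off : firstBlockOffsets)
      offsets.push_back(blockShift + off);
  }

  for (int64_t off : offsets)
    std::cout << off << ' ';
  std::cout << '\n'; // prints: 0 1 2 3 8192 8193 8194 8195
}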
@@ -126,7 +126,6 @@ swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
* @param elemsPerInstr operand tile shape consumed by one MFMA instruction
* @param waveId id component of 2d wave grid along non-K axis
* @param laneId lane id in warp [0..63]
* @param warpsPerGroup number of warps in one block
* @param numOfElems number of elements accessed by thread per repetition
* @param reps number of instruction repetitions to fully cover dot operand
* @param smemStrides strides in LDS tensor
@@ -136,13 +135,11 @@ swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
* @return vector (i-th element corresponds to i-th load instruction) of
* 2-element vectors(tensor row and col).
*/
llvm::SmallVector<llvm::SmallVector<Value>>
computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, ArrayRef<Value> smemOffsets,
int loadVecSize, unsigned iNonKDim, unsigned iKDim) {
auto numM = reps[0];
llvm::SmallVector<llvm::SmallVector<Value>> computeTensorElemMappingInBlock(
ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId, Value laneId,
int numOfElems, ArrayRef<int64_t> reps, ArrayRef<Value> smemOffsets,
int loadVecSize, unsigned iNonKDim, unsigned iKDim) {
auto numK = reps[1];
const int loadsPerThread = numOfElems / loadVecSize;
llvm::SmallVector<llvm::SmallVector<Value>> mapping(numK * loadsPerThread);
@@ -204,7 +201,7 @@ Value computeOffset(ConversionPatternRewriter &rewriter, Location loc,
llvm::SmallVector<Value>
computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
Value laneId, int warpsPerBlock, int numOfElems,
ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
SharedEncodingAttr srcLayout, unsigned nonKDim,
unsigned kDim) {
@@ -219,30 +216,43 @@ computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc,
vectorSize = numOfElems;
}

auto mapping = computeTensorElemMapping(
rewriter, loc, elemsPerInstr, waveId, laneId, warpsPerGroup, numOfElems,
reps, offsets, vectorSize, nonKDim, kDim);
llvm::SmallVector<Value> aOffsets(mapping.size());
for (int i = 0; i < mapping.size(); ++i) {
Value row = mapping[i][0];
Value col = mapping[i][1];
aOffsets[i] = computeOffset(rewriter, loc, row, col, smemObj, srcLayout);
auto mapping = computeTensorElemMappingInBlock(
rewriter, loc, elemsPerInstr, waveId, laneId, numOfElems, reps, offsets,
vectorSize, nonKDim, kDim);

const auto numBlocks = reps[0];
const auto blockSize = mapping.size();
auto order = srcLayout.getOrder();
llvm::SmallVector<Value> aOffsets(blockSize * numBlocks);

for (int block = 0; block < numBlocks; ++block) {
int blockNonKOffset = block * nonKDim * warpsPerBlock;
Value offAdjust = mul(i32_val(blockNonKOffset), strides[0]);
for (int i = 0; i < blockSize; ++i) {
Value row = mapping[i][0];
Value col = mapping[i][1];
aOffsets[block * blockSize + i] =
add(offAdjust,
computeOffset(rewriter, loc, row, col, smemObj, srcLayout));
}
}
return aOffsets;
}

llvm::SmallVector<Value>
computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
Value laneId, int warpsPerBlock, int numOfElems,
ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
SharedEncodingAttr srcLayout, unsigned nonKDim,
unsigned kDim) {
// transpose reps and offsets, because operand B has layout equal to
// transposed operand A layout
// this unifies axis order, so non-K dim is 0, k dim is 1
SmallVector<int64_t> tElemsPerInstr{elemsPerInstr[1], elemsPerInstr[0]};
SmallVector<int64_t> tReps{reps[1], reps[0]};
SmallVector<Value> toffsets{smemObj.offsets[1], smemObj.offsets[0]};
SmallVector<Value> tOffsets{smemObj.offsets[1], smemObj.offsets[0]};
SmallVector<Value> tStrides{smemObj.strides[1], smemObj.strides[0]};

int vectorSize = 1;
if (srcLayout.getOrder()[0] == 0) {
@@ -252,16 +262,26 @@ computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc,
vectorSize = numOfElems;
}

auto mapping = computeTensorElemMapping(
rewriter, loc, tElemsPerInstr, waveId, laneId, warpsPerGroup, numOfElems,
tReps, toffsets, vectorSize, nonKDim, kDim);
llvm::SmallVector<Value> bOffsets(mapping.size());
for (int i = 0; i < mapping.size(); ++i) {
// swap row and col, because operand B layout is a transposed operand A
// layout
Value row = mapping[i][1];
Value col = mapping[i][0];
bOffsets[i] = computeOffset(rewriter, loc, row, col, smemObj, srcLayout);
auto mapping = computeTensorElemMappingInBlock(
rewriter, loc, tElemsPerInstr, waveId, laneId, numOfElems, tReps,
tOffsets, vectorSize, nonKDim, kDim);

const auto numBlocks = tReps[0];
const auto blockSize = mapping.size();
llvm::SmallVector<Value> bOffsets(blockSize * numBlocks);

for (int block = 0; block < numBlocks; ++block) {
int blockNonKOffset = block * nonKDim * warpsPerBlock;
Value offAdjust = mul(i32_val(blockNonKOffset), tStrides[0]);
for (int i = 0; i < mapping.size(); ++i) {
// swap row and col, because operand B layout is a transposed operand A
// layout
Value row = mapping[i][1];
Value col = mapping[i][0];
bOffsets[block * blockSize + i] =
add(offAdjust,
computeOffset(rewriter, loc, row, col, smemObj, srcLayout));
}
}
return bOffsets;
}
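
The row/col swap above undoes the transposition: the mapping was built in the transposed (B-as-A) frame, and swapping the coordinates back lets computeOffset address the original layout. The identity that makes this safe is that transposed coordinates read through transposed strides give the same linear offset; a toy check with illustrative strides (plain integers, not MLIR Values):

#include <cassert>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t strides[2] = {64, 1};                   // operand A view
  const int64_t tStrides[2] = {strides[1], strides[0]}; // transposed (B) view
  const int64_t row = 3, col = 5;

  // Same LDS element addressed through both views.
  const int64_t off = row * strides[0] + col * strides[1];
  const int64_t tOff = col * tStrides[0] + row * tStrides[1];
  assert(off == tOff);
  std::cout << off << " == " << tOff << '\n'; // prints: 197 == 197
}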
@@ -277,52 +297,6 @@ Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc,
return base;
}

/**
* @brief try to find if a value is an integer constant
*
* Traces the def-use chain and returns the integer in case we can prove it
* is constant. The current implementation can trace chains of
* insertValue->extractValue operations.
*
* @param val Value for which we want to get a constant
* @return std::optional holding the found integer value, or an empty
* std::optional
*/
std::optional<int> findConstValue(Value val) {
while (val && !val.getDefiningOp<LLVM::ConstantOp>()) {
LLVM::ExtractValueOp extractValOp =
val.getDefiningOp<LLVM::ExtractValueOp>();
if (!extractValOp)
return std::optional<int>();
auto extractPosArr = extractValOp.getPosition();
if (extractPosArr.size() > 1)
return std::optional<int>();
int extractPos = extractPosArr[0];

int insertPos = -1;
LLVM::InsertValueOp insertValOp;
Value container = extractValOp.getOperand();
do {
insertValOp = container.getDefiningOp<LLVM::InsertValueOp>();
if (!insertValOp)
return std::optional<int>();
auto insertPosArr = insertValOp.getPosition();
if (insertPosArr.size() > 1)
return std::optional<int>();
insertPos = insertPosArr[0];
container = insertValOp.getContainer();
} while (insertPos != extractPos);
val = insertValOp.getValue();
}
if (!val)
return std::optional<int>();
auto cOp = val.getDefiningOp<LLVM::ConstantOp>();
assert(cOp);
auto valAttr = cOp.getValueAttr();
auto intAttr = dyn_cast<mlir::IntegerAttr>(valAttr);
assert(intAttr);
return intAttr.getInt();
}

bool hasSwizzleEnabled(const SharedEncodingAttr &srcEncoding) {
return srcEncoding.getMaxPhase() > 1;
}
@@ -334,28 +308,28 @@ bool hasSwizzleEnabled(const SharedEncodingAttr &srcEncoding) {
// instruction
// @param waveId wave id for the "non-K" axis
// @param laneId lane id in warp [0..63]
// @param warpsPerGroup number of warps per horizontal axis
// @param warpsPerBlock number of warps per horizontal axis
// @param numOfElems number of elements accessed by threads per repetition
// @param reps number of instruction repetitions to fully cover dot operand
// @param cSwizzleOffset
llvm::SmallVector<Value>
fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
Value laneId, int warpsPerBlock, int numOfElems,
ArrayRef<int64_t> reps, Value cSwizzleOffset) {
auto numK = reps[0];
auto numN = reps[1];
SmallVector<Value> offsets(numK * numN * numOfElems);

auto iKDim = elemsPerInstr[0];
auto iNonKDim = elemsPerInstr[1];
int lineSize = warpsPerGroup * iNonKDim * numN;
int lineSize = warpsPerBlock * iNonKDim * numN;
Value _nonKDim = i32_val(iNonKDim);
Value waveOffset = mul(waveId, i32_val(iNonKDim));
Value colOffset = urem(laneId, _nonKDim);

for (int block = 0; block < numN; ++block) {
Value blockOffset = i32_val(block * iNonKDim * warpsPerGroup);
Value blockOffset = i32_val(block * iNonKDim * warpsPerBlock);
for (int tile = 0; tile < numK; ++tile) {
Value tileOffset = i32_val(tile * iKDim * lineSize);
for (int elem = 0; elem < numOfElems; ++elem) {
@@ -401,6 +375,43 @@ bool isKMajor(::llvm::ArrayRef<unsigned> order, int opIdx) {
return false;
}

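Back in fastPathComputeOffsets above, the visible terms appear to compose additively: a per-block shift along non-K, a per-tile shift along K, a wave shift, and a lane column, presumably summed with a per-element term in the truncated body. A standalone sketch of just those visible terms; the innermost per-element term is cut off in this diff, so it is deliberately omitted here (all values illustrative):

#include <iostream>

int main() {
  const int iNonKDim = 32, iKDim = 32; // MFMA tile sizes
  const int warpsPerBlock = 2;         // warps along the non-K dim
  const int numN = 2, numK = 2;        // repetitions
  const int waveId = 1, laneId = 37;

  const int lineSize = warpsPerBlock * iNonKDim * numN; // elems per K-row
  const int waveOffset = waveId * iNonKDim;
  const int colOffset = laneId % iNonKDim;

  for (int block = 0; block < numN; ++block) {
    const int blockOffset = block * iNonKDim * warpsPerBlock;
    for (int tile = 0; tile < numK; ++tile) {
      const int tileOffset = tile * iKDim * lineSize;
      // Partial offset only; the real code adds a per-element row term here.
      std::cout << blockOffset + tileOffset + waveOffset + colOffset << '\n';
    }
  }
}
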
/**
* @brief test if swizzle pattern is compatible with "normal" path of offset
* computation
*
* This function checks that swizzle pattern fits into one warp block
* and block size is a multiple of swizzle size along non-K dimension
*
* @param sharedLayout
* @param opIdx operand id 0 or 1
* @param shape tensor shape
* @param mfmaInstrNonK size of one instruction along non-k Dim (tile size)
* @param warpsPerBlockNonK number of warps along non-k Dim
* @return bool
*/
bool isSwizzlePatternNormalPathCompatible(const SharedEncodingAttr sharedLayout,
int opIdx, ArrayRef<int64_t> shape,
unsigned mfmaInstrNonK,
unsigned warpsPerBlockNonK) {
auto order = sharedLayout.getOrder();
auto rank = shape.size();
int64_t swizzleFastDimSize =
sharedLayout.getMaxPhase() * sharedLayout.getVec();
swizzleFastDimSize = std::min(swizzleFastDimSize, shape[order[0]]);
int64_t swizzleSlowDimSize =
sharedLayout.getMaxPhase() * sharedLayout.getPerPhase();
swizzleSlowDimSize = std::min(swizzleSlowDimSize, shape[order[1]]);
const auto swizzlePatternSizeK =
isKMajor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;
const auto swizzlePatternSizeNonK =
!isKMajor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;

const auto blockSizeK = opIdx == 0 ? shape[rank - 1] : shape[rank - 2];
const auto blockSizeNonK = mfmaInstrNonK * warpsPerBlockNonK;
return blockSizeK % swizzlePatternSizeK == 0 &&
blockSizeNonK % swizzlePatternSizeNonK == 0;
}

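A quick worked instance of this check, with made-up parameters (vec = 4, maxPhase = 8, perPhase = 2, a k-major operand A tile; none of these numbers come from the diff):

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t vec = 4, maxPhase = 8, perPhase = 2;
  const int64_t shape[2] = {128, 64}; // operand A: [non-K, K]
  const unsigned mfmaInstrNonK = 32, warpsPerBlockNonK = 2;

  // k-major: the swizzle "fast" dimension lies along K.
  const int64_t swizzlePatternSizeK =
      std::min(maxPhase * vec, shape[1]); // 32
  const int64_t swizzlePatternSizeNonK =
      std::min(maxPhase * perPhase, shape[0]); // 16

  const int64_t blockSizeK = shape[1];                         // 64
  const int64_t blockSizeNonK =
      int64_t(mfmaInstrNonK) * warpsPerBlockNonK;              // 64
  const bool ok = blockSizeK % swizzlePatternSizeK == 0 &&     // 64 % 32
                  blockSizeNonK % swizzlePatternSizeNonK == 0; // 64 % 16
  std::cout << (ok ? "compatible\n" : "not compatible\n"); // prints: compatible
}
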
Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
Location loc, Value tensor, DotOperandEncodingAttr encoding,
const SharedMemoryObject &smemObj,
@@ -449,7 +460,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
assert(numOfElems >= 1);

unsigned int maxNumWarps = shape[nonKDimIdx] / mfmaInstrNonK;
int warpsPerGroupNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps);
int warpsPerBlockNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps);
elemTy = typeConverter->convertType(elemTy);

SmallVector<Value> loadedValues;
@@ -467,7 +478,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
SmallVector<int64_t> elemsPerInstr{mfmaInstrK, mfmaInstrNonK};
SmallVector<int64_t> reps{numReps[1], numReps[0]};
offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr,
spatialWaveId, lane, warpsPerGroupNonK,
spatialWaveId, lane, warpsPerBlockNonK,
numOfElems, reps, cSwizzleOffset);
} else {
llvm_unreachable(
@@ -479,7 +490,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
"col major operand B should be handled in the normal path");
} else {
offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr,
spatialWaveId, lane, warpsPerGroupNonK,
spatialWaveId, lane, warpsPerBlockNonK,
numOfElems, numReps, cSwizzleOffset);
}
}
@@ -491,16 +502,15 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
// performant case
// 2. k-major + swizzling is disabled <-- for testing purpose only
// 3. non k-major + swizzling is enabled <-- for testing purpose only
//
// In this path, it requires a 2-step method to compute the offsets.
assert(isSwizzlePatternNormalPathCompatible(
sharedLayout, opIdx, shape, mfmaInstrNonK, warpsPerBlockNonK));
if (opIdx == 0) {
offsets = computeOffsetsAType(
rewriter, loc, elemsPerInstr, spatialWaveId, lane, warpsPerGroupNonK,
rewriter, loc, elemsPerInstr, spatialWaveId, lane, warpsPerBlockNonK,
numOfElems, numReps, smemObj, sharedLayout, mDim, mfmaInstrK);
} else {
assert(opIdx == 1);
offsets = computeOffsetsBType(
rewriter, loc, elemsPerInstr, spatialWaveId, lane, warpsPerGroupNonK,
rewriter, loc, elemsPerInstr, spatialWaveId, lane, warpsPerBlockNonK,
numOfElems, numReps, smemObj, sharedLayout, nDim, mfmaInstrK);
}
smemBase = computeBasePtr(rewriter, loc, smemObj);
@@ -509,27 +519,19 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
Type resElemTy = typeConverter->convertType(elemTy);
Type smemPtrTy = getShemPtrTy(elemTy);

int loadsPerThread = offsets.size() / numRepK / (isFastPath ? numRepNonK : 1);
int loadsPerThread = offsets.size() / numRepK / numRepNonK;
int elemsPerLoad = numOfElems / loadsPerThread;
assert(numOfElems % loadsPerThread == 0);

for (int nonK = 0; nonK < numRepNonK; ++nonK) {
int blockNonKOffset = nonK * mfmaInstrNonK * warpsPerGroupNonK;
Value offAdjust = i32_val(blockNonKOffset * shape[order[0]]);
for (int k = 0; k < numRepK; ++k) {
auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(elemTy, elemsPerLoad);
Value loadOffset;
if (isFastPath)
loadOffset = offsets[nonK * loadsPerThread * numRepK +
k * loadsPerThread + loadId];
else
// In the normal path, we only computed the offsets of elements
// in the first wave-block. Therefore, we update the offsets
// of elements in later wave-blocks by adding a constant stride
loadOffset = add(offAdjust, offsets[k * loadsPerThread + loadId]);
loadOffset = offsets[nonK * loadsPerThread * numRepK +
k * loadsPerThread + loadId];
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
getShemPtrTy(loadVecTy));
Value loadedValue = load(loadAddress);