From bcde44f119b37fe438040c78913fc6455db5df26 Mon Sep 17 00:00:00 2001
From: Alexander Efimov
Date: Tue, 12 Mar 2024 20:10:13 +0100
Subject: [PATCH] [AMD] Refactor SharedToDotOperandMFMA (#533)

- remove unused functions and variables
- replace "group" with "block" to unify terminology
- add a check that the swizzle pattern is compatible with the normal path
- unify index computation between the normal and fast paths
---
 .../SharedToDotOperandMFMA.cpp | 198 +++++++++---------
 1 file changed, 100 insertions(+), 98 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
index 865e43778fcf..86a4153603b2 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
@@ -126,7 +126,6 @@ swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
  * @param elemsPerInstr operand tile shape consumed by one MFMA instruction
  * @param waveId id component of 2d wave grid along non-K axis
  * @param laneId lane id in warp [0..63]
- * @param warpsPerGroup number of warps in one block
  * @param numOfElems number of elements accessed by thread per repetition
  * @param reps number of instruction repetitions to fully cover dot operand
  * @param smemStrides strides in LDS tensor
@@ -136,13 +135,11 @@ swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
  * @return vector (i-th element corresponds to i-th load instruction) of
  * 2-element vectors (tensor row and col).
  */
-llvm::SmallVector<llvm::SmallVector<Value>>
-computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc,
-                         const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
-                         Value laneId, int warpsPerGroup, int numOfElems,
-                         ArrayRef<int64_t> reps, ArrayRef<Value> smemOffsets,
-                         int loadVecSize, unsigned iNonKDim, unsigned iKDim) {
-  auto numM = reps[0];
+llvm::SmallVector<llvm::SmallVector<Value>> computeTensorElemMappingInBlock(
+    ConversionPatternRewriter &rewriter, Location loc,
+    const ArrayRef<int64_t> &elemsPerInstr, Value waveId, Value laneId,
+    int numOfElems, ArrayRef<int64_t> reps, ArrayRef<Value> smemOffsets,
+    int loadVecSize, unsigned iNonKDim, unsigned iKDim) {
   auto numK = reps[1];
   const int loadsPerThread = numOfElems / loadVecSize;
   llvm::SmallVector<llvm::SmallVector<Value>> mapping(numK * loadsPerThread);
@@ -204,7 +201,7 @@ Value computeOffset(ConversionPatternRewriter &rewriter, Location loc,
 llvm::SmallVector<Value>
 computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc,
                     const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
-                    Value laneId, int warpsPerGroup, int numOfElems,
+                    Value laneId, int warpsPerBlock, int numOfElems,
                     ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
                     SharedEncodingAttr srcLayout, unsigned nonKDim,
                     unsigned kDim) {
@@ -219,14 +216,25 @@ computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc,
     vectorSize = numOfElems;
   }
 
-  auto mapping = computeTensorElemMapping(
-      rewriter, loc, elemsPerInstr, waveId, laneId, warpsPerGroup, numOfElems,
-      reps, offsets, vectorSize, nonKDim, kDim);
-  llvm::SmallVector<Value> aOffsets(mapping.size());
-  for (int i = 0; i < mapping.size(); ++i) {
-    Value row = mapping[i][0];
-    Value col = mapping[i][1];
-    aOffsets[i] = computeOffset(rewriter, loc, row, col, smemObj, srcLayout);
+  auto mapping = computeTensorElemMappingInBlock(
+      rewriter, loc, elemsPerInstr, waveId, laneId, numOfElems, reps, offsets,
+      vectorSize, nonKDim, kDim);
+
+  const auto numBlocks = reps[0];
+  const auto blockSize = mapping.size();
+  auto order = srcLayout.getOrder();
+  llvm::SmallVector<Value> aOffsets(blockSize * numBlocks);
+
+  for (int block = 0; block < numBlocks; ++block) {
+    int blockNonKOffset = block * nonKDim * warpsPerBlock;
+    Value offAdjust = mul(i32_val(blockNonKOffset), strides[0]);
+    for (int i = 0; i < blockSize; ++i) {
+      Value row = mapping[i][0];
+      Value col = mapping[i][1];
+      aOffsets[block * blockSize + i] =
+          add(offAdjust,
+              computeOffset(rewriter, loc, row, col, smemObj, srcLayout));
+    }
   }
   return aOffsets;
 }
@@ -234,15 +242,17 @@ computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc,
 llvm::SmallVector<Value>
 computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc,
                     const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
-                    Value laneId, int warpsPerGroup, int numOfElems,
+                    Value laneId, int warpsPerBlock, int numOfElems,
                     ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
                     SharedEncodingAttr srcLayout, unsigned nonKDim,
                     unsigned kDim) {
   // transpose reps and offsets, because operand B has layout equal to
   // transposed operand A layout
+  // this unifies the axis order, so the non-K dim is 0 and the K dim is 1
   SmallVector<int64_t> tElemsPerInstr{elemsPerInstr[1], elemsPerInstr[0]};
   SmallVector<int64_t> tReps{reps[1], reps[0]};
-  SmallVector<Value> toffsets{smemObj.offsets[1], smemObj.offsets[0]};
+  SmallVector<Value> tOffsets{smemObj.offsets[1], smemObj.offsets[0]};
+  SmallVector<Value> tStrides{smemObj.strides[1], smemObj.strides[0]};
 
   int vectorSize = 1;
   if (srcLayout.getOrder()[0] == 0) {
@@ -252,16 +262,26 @@ computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc,
     vectorSize = numOfElems;
   }
 
-  auto mapping = computeTensorElemMapping(
-      rewriter, loc, tElemsPerInstr, waveId, laneId, warpsPerGroup, numOfElems,
-      tReps, toffsets, vectorSize, nonKDim, kDim);
-  llvm::SmallVector<Value> bOffsets(mapping.size());
-  for (int i = 0; i < mapping.size(); ++i) {
-    // swap row and col, because operand B layout is a transposed operand A
-    // layout
-    Value row = mapping[i][1];
-    Value col = mapping[i][0];
-    bOffsets[i] = computeOffset(rewriter, loc, row, col, smemObj, srcLayout);
+  auto mapping = computeTensorElemMappingInBlock(
+      rewriter, loc, tElemsPerInstr, waveId, laneId, numOfElems, tReps,
+      tOffsets, vectorSize, nonKDim, kDim);
+
+  const auto numBlocks = tReps[0];
+  const auto blockSize = mapping.size();
+  llvm::SmallVector<Value> bOffsets(blockSize * numBlocks);
+
+  for (int block = 0; block < numBlocks; ++block) {
+    int blockNonKOffset = block * nonKDim * warpsPerBlock;
+    Value offAdjust = mul(i32_val(blockNonKOffset), tStrides[0]);
+    for (int i = 0; i < mapping.size(); ++i) {
+      // swap row and col, because operand B layout is a transposed operand A
+      // layout
+      Value row = mapping[i][1];
+      Value col = mapping[i][0];
+      bOffsets[block * blockSize + i] =
+          add(offAdjust,
+              computeOffset(rewriter, loc, row, col, smemObj, srcLayout));
+    }
   }
   return bOffsets;
 }
@@ -277,52 +297,6 @@ Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc,
   return base;
 }
 
-/**
- * @brief try find if value is an integer constant
- *
- * Trace def-use chain and return integer in case we can proof it is constant.
- * Current implementation can trace chains of insertValue->extractValue
- * operations.
- * - * @param val Value for that we want to get constant - * @return std::optional on found integer value or empty std::optional - */ -std::optional findConstValue(Value val) { - while (val && !val.getDefiningOp()) { - LLVM::ExtractValueOp extractValOp = - val.getDefiningOp(); - if (!extractValOp) - return std::optional(); - auto extractPosArr = extractValOp.getPosition(); - if (extractPosArr.size() > 1) - return std::optional(); - int extractPos = extractPosArr[0]; - - int insertPos = -1; - LLVM::InsertValueOp insertValOp; - Value container = extractValOp.getOperand(); - do { - insertValOp = container.getDefiningOp(); - if (!insertValOp) - return std::optional(); - auto insertPosArr = insertValOp.getPosition(); - if (insertPosArr.size() > 1) - return std::optional(); - insertPos = insertPosArr[0]; - container = insertValOp.getContainer(); - } while (insertPos != extractPos); - val = insertValOp.getValue(); - } - if (!val) - return std::optional(); - auto cOp = val.getDefiningOp(); - assert(cOp); - auto valAttr = cOp.getValueAttr(); - auto intAttr = dyn_cast(valAttr); - assert(intAttr); - return intAttr.getInt(); -} - bool hasSwizzleEnabled(const SharedEncodingAttr &srcEncoding) { return srcEncoding.getMaxPhase() > 1; } @@ -334,14 +308,14 @@ bool hasSwizzleEnabled(const SharedEncodingAttr &srcEncoding) { // instruction // @param waveId wave id for the "non K" axis // @param laneId lane id in warp [0..63] -// @param warpsPerGroup number of warps per horizontal axis +// @param warpsPerBlock number of warps per horizontal axis // @param numOfElems number of elements accessed by threads per repetition // @param reps number of instructions repretition to fully cover dot operand // @param cSwizzleOffset llvm::SmallVector fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, const ArrayRef &elemsPerInstr, Value waveId, - Value laneId, int warpsPerGroup, int numOfElems, + Value laneId, int warpsPerBlock, int numOfElems, ArrayRef reps, Value cSwizzleOffset) { auto numK = reps[0]; auto numN = reps[1]; @@ -349,13 +323,13 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, auto iKDim = elemsPerInstr[0]; auto iNonKDim = elemsPerInstr[1]; - int lineSize = warpsPerGroup * iNonKDim * numN; + int lineSize = warpsPerBlock * iNonKDim * numN; Value _nonKDim = i32_val(iNonKDim); Value waveOffset = mul(waveId, i32_val(iNonKDim)); Value colOffset = urem(laneId, _nonKDim); for (int block = 0; block < numN; ++block) { - Value blockOffset = i32_val(block * iNonKDim * warpsPerGroup); + Value blockOffset = i32_val(block * iNonKDim * warpsPerBlock); for (int tile = 0; tile < numK; ++tile) { Value tileOffset = i32_val(tile * iKDim * lineSize); for (int elem = 0; elem < numOfElems; ++elem) { @@ -401,6 +375,43 @@ bool isKMajor(::llvm::ArrayRef order, int opIdx) { return false; } +/** + * @brief test if swizzle pattern is compatible with "normal" path of offset + * computation + * + * This function checks that swizzle pattern fits into one warp block + * and block size is a multiple of swizzle size along non-K dimension + * + * @param sharedLayout + * @param opIdx operand id 0 or 1 + * @param shape tensor shape + * @param mfmaInstrNonK size of one instruction along non-k Dim (tile size) + * @param warpsPerBlockNonK number of warps along non-k Dim + * @return bool + */ +bool isSwizzlePatternNormalPathCompatible(const SharedEncodingAttr sharedLayout, + int opIdx, ArrayRef shape, + unsigned mfmaInstrNonK, + unsigned warpsPerBlockNonK) { + auto order = sharedLayout.getOrder(); + 
+  auto rank = shape.size();
+  int64_t swizzleFastDimSize =
+      sharedLayout.getMaxPhase() * sharedLayout.getVec();
+  swizzleFastDimSize = std::min(swizzleFastDimSize, shape[order[0]]);
+  int64_t swizzleSlowDimSize =
+      sharedLayout.getMaxPhase() * sharedLayout.getPerPhase();
+  swizzleSlowDimSize = std::min(swizzleSlowDimSize, shape[order[1]]);
+  const auto swizzlePatternSizeK =
+      isKMajor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;
+  const auto swizzlePatternSizeNonK =
+      !isKMajor(order, opIdx) ? swizzleFastDimSize : swizzleSlowDimSize;
+
+  const auto blockSizeK = opIdx == 0 ? shape[rank - 1] : shape[rank - 2];
+  const auto blockSizeNonK = mfmaInstrNonK * warpsPerBlockNonK;
+  return blockSizeK % swizzlePatternSizeK == 0 &&
+         blockSizeNonK % swizzlePatternSizeNonK == 0;
+}
+
 Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
                     Location loc, Value tensor, DotOperandEncodingAttr encoding,
                     const SharedMemoryObject &smemObj,
@@ -449,7 +460,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
   assert(numOfElems >= 1);
 
   unsigned int maxNumWarps = shape[nonKDimIdx] / mfmaInstrNonK;
-  int warpsPerGroupNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps);
+  int warpsPerBlockNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps);
   elemTy = typeConverter->convertType(elemTy);
 
   SmallVector<Value> loadedValues;
@@ -467,7 +478,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
       SmallVector<int64_t> elemsPerInstr{mfmaInstrK, mfmaInstrNonK};
       SmallVector<int64_t> reps{numReps[1], numReps[0]};
       offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr,
-                                       spatialWaveId, lane, warpsPerGroupNonK,
+                                       spatialWaveId, lane, warpsPerBlockNonK,
                                        numOfElems, reps, cSwizzleOffset);
     } else {
       llvm_unreachable(
@@ -479,7 +490,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
           "col major operand B should be handled in the normal path");
     } else {
       offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr,
-                                       spatialWaveId, lane, warpsPerGroupNonK,
+                                       spatialWaveId, lane, warpsPerBlockNonK,
                                        numOfElems, numReps, cSwizzleOffset);
     }
   }
@@ -491,16 +502,15 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
     //    performant case
    // 2. k-major + swizzling is disabled <-- for testing purposes only
     // 3. non k-major + swizzling is enabled <-- for testing purposes only
-    //
-    // In this path, it requires a 2-step method to compute the offsets.
+    assert(isSwizzlePatternNormalPathCompatible(
+        sharedLayout, opIdx, shape, mfmaInstrNonK, warpsPerBlockNonK));
     if (opIdx == 0) {
       offsets = computeOffsetsAType(
-          rewriter, loc, elemsPerInstr, spatialWaveId, lane, warpsPerGroupNonK,
+          rewriter, loc, elemsPerInstr, spatialWaveId, lane, warpsPerBlockNonK,
           numOfElems, numReps, smemObj, sharedLayout, mDim, mfmaInstrK);
     } else {
-      assert(opIdx == 1);
       offsets = computeOffsetsBType(
-          rewriter, loc, elemsPerInstr, spatialWaveId, lane, warpsPerGroupNonK,
+          rewriter, loc, elemsPerInstr, spatialWaveId, lane, warpsPerBlockNonK,
          numOfElems, numReps, smemObj, sharedLayout, nDim, mfmaInstrK);
     }
     smemBase = computeBasePtr(rewriter, loc, smemObj);
@@ -509,27 +519,19 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
   Type resElemTy = typeConverter->convertType(elemTy);
   Type smemPtrTy = getShemPtrTy(elemTy);
 
-  int loadsPerThread = offsets.size() / numRepK / (isFastPath ? numRepNonK : 1);
+  int loadsPerThread = offsets.size() / numRepK / numRepNonK;
   int elemsPerLoad = numOfElems / loadsPerThread;
   assert(numOfElems % loadsPerThread == 0);
 
   for (int nonK = 0; nonK < numRepNonK; ++nonK) {
-    int blockNonKOffset = nonK * mfmaInstrNonK * warpsPerGroupNonK;
-    Value offAdjust = i32_val(blockNonKOffset * shape[order[0]]);
     for (int k = 0; k < numRepK; ++k) {
       auto vecTy = vec_ty(resElemTy, numOfElems);
       Value valVec = undef(vecTy);
       for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
         auto loadVecTy = vec_ty(elemTy, elemsPerLoad);
         Value loadOffset;
-        if (isFastPath)
-          loadOffset = offsets[nonK * loadsPerThread * numRepK +
-                               k * loadsPerThread + loadId];
-        else
-          // In the normal path, we only computed the offsets of elements
-          // in the first wave-block. Therefore, we update the offsets
-          // of elements in later wave-blocks by adding a constant stride
-          loadOffset = add(offAdjust, offsets[k * loadsPerThread + loadId]);
+        loadOffset = offsets[nonK * loadsPerThread * numRepK +
+                             k * loadsPerThread + loadId];
         Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
                                     getShemPtrTy(loadVecTy));
         Value loadedValue = load(loadAddress);
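
Reviewer note (not part of the patch): the invariant that lets the load loop drop the isFastPath branch is that both computeOffsets*Type and fastPathComputeOffsets now emit offsets for all wave-blocks up front, in block-major order, so one index expression serves both paths. Below is a minimal standalone C++ sketch of that contract; the sizes (numBlocks, numRepK, loadsPerThread) are hypothetical and the stored values are stand-in tags, not real LDS offsets.

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical repetition counts; any positive values work.
  const int numBlocks = 4;      // numRepNonK: wave-block reps along non-K dim
  const int numRepK = 2;        // instruction reps along the K dim
  const int loadsPerThread = 4; // loads per thread per repetition
  const int blockSize = numRepK * loadsPerThread; // one block's mapping size

  // Producer side: fill offsets block-major, tagging each slot with the
  // (block, within-block index) pair that wrote it, as in
  // aOffsets[block * blockSize + i].
  std::vector<int> offsets(numBlocks * blockSize);
  for (int block = 0; block < numBlocks; ++block)
    for (int i = 0; i < blockSize; ++i)
      offsets[block * blockSize + i] = block * 1000 + i;

  // Consumer side: the unified load loop indexing from convertLayout.
  for (int nonK = 0; nonK < numBlocks; ++nonK)
    for (int k = 0; k < numRepK; ++k)
      for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
        int idx = nonK * loadsPerThread * numRepK + k * loadsPerThread + loadId;
        // The consumer must land in block nonK at within-block element
        // k * loadsPerThread + loadId, exactly where the producer wrote.
        assert(offsets[idx] == nonK * 1000 + k * loadsPerThread + loadId);
      }
  std::puts("block-major producer and unified consumer indexing agree");
  return 0;
}

The check relies only on blockSize == numRepK * loadsPerThread, which mirrors mapping.size() as returned by computeTensorElemMappingInBlock for a single wave-block.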