From da5040d5a5010becc7a83ea6466dd2913e17beef Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Thu, 25 Jan 2024 13:20:40 +0000 Subject: [PATCH 1/8] [MFMA][FRONTEND] Add more options for forced mfma layout sizes This PR: - adds an `matrix_instr_nonkdim` options to force MFMA 64x4 and 4x64 layout: 464 corresponds 4(M)x64(N), 644 corresponds 64(M)x4(N) - adds tests for this option - fixes swizzling patter in some cases MFMA size heuristic now looks like this: 1. If kernel specific option is set, pick it 2. If the result tile shape is larger than 32x32, pick mfma32 3. If the tile shape is smaller than 32x32 but larger than 16x16, pick mfma16 4. if the tile shape is smaller than 4x64 or 64x4, pick mfma4x4 5. Otherwise, pick mfma4x64 or mfma64x4, depending on what tile fits into matrices --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 10 ++++++++++ .../Transforms/AccelerateAMDMatmul.cpp | 16 ++++++++++++++-- python/test/unit/language/test_core_amd.py | 18 +++++++++++++++++- 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index c8fdc1c70f5c..54bc607a4ede 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -131,6 +131,7 @@ compared to 1*64 when the hasLeadingOffset is false. if (mfmaEnc) { int kDimNum = dotOpEnc.getOpIdx() == 0 ? 1 : 0; + int nonKDimNum = 1 - kDimNum; if (needTrans) kDimNum = 1 - kDimNum; bool isKDimInner = (order[0] == kDimNum); @@ -154,6 +155,15 @@ compared to 1*64 when the hasLeadingOffset is false. auto nonKDim = dotOpEnc.getOpIdx() == 0 ? mDim : nDim; if (4 == nonKDim) maxPhase = 4; + // if maxPhase * perPhase is larger than one block of warps, + // fallback to unswizzled tensor. + // Shared to dot op conversion requires that swizzling patern + // fits into one block of warps. + auto warpsPerCTA = mfmaEnc.getWarpsPerCTA(); + if (maxPhase * perPhase > nonKDim * warpsPerCTA[nonKDimNum]) { + assert(isKDimInner); + maxPhase = 1; + } assert(maxPhase > 0); return get(context, vecSize, perPhase, maxPhase, order, CTALayout); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp index 12fdbf23e4a4..3f39248597bd 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp @@ -175,8 +175,20 @@ class BlockedToMFMA : public mlir::RewritePattern { unsigned mDim = 0; unsigned nDim = 0; if (enforcedNonKDim != 0) { - mDim = enforcedNonKDim; - nDim = enforcedNonKDim; + if (enforcedNonKDim == 32 || enforcedNonKDim == 16 || + enforcedNonKDim == 4) { + mDim = enforcedNonKDim; + nDim = enforcedNonKDim; + } else if (enforcedNonKDim == 464) { + mDim = 4; + nDim = 64; + } else if (enforcedNonKDim == 644) { + mDim = 64; + nDim = 4; + } else { + llvm::report_fatal_error("Invalid MFMA nonKDim option, supported " + "values are: 32, 16, 4, 464, 644"); + } } else { int minSize = std::min(resShape[0], resShape[1]); if (minSize >= 32) { diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py index 2bf5c63dd613..0a451d539453 100644 --- a/python/test/unit/language/test_core_amd.py +++ b/python/test/unit/language/test_core_amd.py @@ -1665,6 +1665,19 @@ def kernel(X, stride_xm, stride_xn, for non_k_dim in [0, 4, 16, 32] if not (allow_tf32 and (in_dtype in ['float16']))] + + [(*shape, warps, False, False, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim, 1) + for shape in [(64, 16, 128), (16, 64, 128)] + for warps in [1, 4] + for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax'] + for allow_tf32 in [False] + for in_dtype, out_dtype in [('float16', 'float16'), + ('bfloat16', 'float32'), + ('float8e5m2fnuz', 'float32'), + ('float8e4m3fnuz', 'float32'), + ('float16', 'float32'), + ('float32', 'float32')] + for non_k_dim in [464, 644]] + + [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype, non_k_dim, kpack) for shape_nw in [[128, 128, 32, 2], [128, 16, 32, 4], @@ -1728,6 +1741,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o pytest.skip("incompatible non_k_dim == 4 with K size") if non_k_dim == 4 and (M > 16 or N > 16): pytest.skip("skipping large matrices for non_k_dim == 4 to speedup testing") + if (non_k_dim == 464 and N < 64) or (non_k_dim == 644 and M < 64): + pytest.skip(f"skipping non_k_dim={non_k_dim} specific test with incompatible matrix sizes") + if capability[0] < 7: pytest.skip("Only test tl.dot() on devices with sm >= 70") @@ -1852,7 +1868,7 @@ def kernel(X, stride_xm, stride_xk, z_tri = to_triton(z, device=device) if epilogue == 'trans': - z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1]) + z_tri = torch.as_strided(z_tri, (M, N), [1, M]) if out_dtype == 'int8': out_dtype = tl.int8 From fde46d8fd07675f12227fbafd0fbc1f3595b0a44 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Wed, 13 Mar 2024 19:29:00 +0000 Subject: [PATCH 2/8] [MFMA] MFMA 4x64 64x4 version 2 Extend K dimension of mfma4x64 and mfma64x4 dot operand layout from 4 to 64. --- .../Dialect/TritonGPU/Transforms/Utility.h | 6 +- lib/Analysis/Utility.cpp | 3 +- .../SharedToDotOperandMFMA.cpp | 14 +- .../TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp | 234 +++++++++++++----- lib/Dialect/TritonGPU/IR/Dialect.cpp | 36 ++- .../Transforms/AccelerateAMDMatmul.cpp | 41 ++- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 149 +++++------ python/test/unit/language/test_core_amd.py | 7 +- .../generate_accelerate_matmul_tests.py | 182 ++++++++++++++ test/TritonGPU/accelerate-matmul-cdna1.mlir | 113 ++++----- test/TritonGPU/accelerate-matmul-cdna2.mlir | 113 ++++----- test/TritonGPU/accelerate-matmul-cdna3.mlir | 113 ++++----- 12 files changed, 667 insertions(+), 344 deletions(-) create mode 100755 scripts/amd/lit_tests/generate_accelerate_matmul_tests.py diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 9cb4de97d744..5afa922665ef 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -175,7 +175,8 @@ struct MfmaInsnAttr { unsigned n; unsigned k; // k_base refers to the number of elements per thread - unsigned k_base; + unsigned k_base_a; + unsigned k_base_b; llvm::StringRef insn; }; @@ -223,7 +224,8 @@ class MfmaInsn { unsigned getMDim(); unsigned getNDim(); StringRef getInsnName(); - unsigned getKBase(); + unsigned getKBaseA(); + unsigned getKBaseB(); }; } // namespace mlir diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 6dbc10b943a6..51c7ed03b1a9 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -571,7 +571,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { dotOperandLayout.getOpIdx() == 0 && dotOperandLayout.getKWidth() == 4 && dotOperandLayout.getParent() == mfmaLayout && - (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) && + (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16 || + (mfmaLayout.getMDim() == 4 && mfmaLayout.getNDim() == 64)) && mfmaLayout.getIsTransposed() && (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()); } diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index 86a4153603b2..77d5f6ca5160 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -158,14 +158,12 @@ llvm::SmallVector> computeTensorElemMappingInBlock( if (iNonKDim == 32) laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0); else { - // In this configuration wave contains 16 copies of same data - if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) { + // shortcut for 64x64 tile size. + // In this case warp do not wrap, so no need to introduce this offset + if (iNonKDim == 64) laneHOffset = i32_val(0); - } else { - assert(iKDim * iNonKDim / numOfElems == 64 && - "seems no all threads in wave contain unique elements"); + else laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems)); - } } for (int loadId = 0; loadId < loadsPerThread; ++loadId) { @@ -346,7 +344,7 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, // 32 33 34 35 ... 63 // 32 33 34 35 ... 63 Value halfOffset; - if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) + if (iNonKDim == 64) halfOffset = i32_val(0); else halfOffset = @@ -456,6 +454,8 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int numSubBlocks = 1; if ((mfmaInstrK == 4 || mfmaInstrK == 1) && mfmaInstrNonK == 4) numSubBlocks = 16; + assert(numSubBlocks == 1 && + "after reworking layout, there should be no redundency"); int numOfElems = mfmaInstrNonK * mfmaInstrK * numSubBlocks / iWaveSize; assert(numOfElems >= 1); diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp index 10bec3614969..94809cbdf360 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -37,7 +37,12 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::MfmaEncodingAttr; using ::mlir::triton::gpu::SharedEncodingAttr; -using ValueTable = std::map, Value>; +// mapping from touple to vector of values +// vector contains single element for MFMA32, MFMA16 and MFMA4 layouts +// for MFMA 4x64 and 64x4 layouts there are 16 vectors for one of the arguments, +// because each repetition in these layouts requires 16 mfma operations +using ValueTable = std::map, + llvm::SmallVector>; struct DotOpMFMAConversionHelper { MfmaEncodingAttr mfmaLayout; @@ -60,16 +65,114 @@ struct DotOpMFMAConversionHelper { return rewriter.create(loc, i32_ty, tid); } + /** + * @param mfmaInsnName + * @param valA + * @param valB + * @param valC + * @param cbsz Control Broadcast Size modifier + * @param abid A-matrix Broadcast Identifier + * @param blgp B-matrix Lane Group Pattern modifier + */ Value generateMFMAOp(StringRef mfmaInsnName, Value valA, Value valB, - Value valC) const { + Value valC, int cbsz = 0, int abid = 0, + int blgp = 0) const { + assert(cbsz >= 0 && cbsz <= 4); + assert(abid >= 0 && abid <= 15); + assert(blgp >= 0 && blgp <= 7); auto resType = valC.getType(); - Value zeroFlag = i32_val(0); + Value zeroVal = i32_val(0); + Value cbszFlag = cbsz != 0 ? i32_val(cbsz) : zeroVal; + Value abidFlag = abid != 0 ? i32_val(abid) : zeroVal; + Value blgpFlag = blgp != 0 ? i32_val(blgp) : zeroVal; OperationState loweredOp(loc, mfmaInsnName); loweredOp.addTypes(resType); - loweredOp.addOperands({valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + loweredOp.addOperands({valA, valB, valC, cbszFlag, abidFlag, blgpFlag}); return rewriter.create(loweredOp)->getResult(0); } + Value broadcastGroup(Value val, int groupId, int numGroups) const { + constexpr int waveSize = 64; + const int groupSize = waveSize / numGroups; + + Value lane = getThreadId(); + // Multiply by 4, because permute requires offset in bytes + Value laneOffset = mul(urem(lane, i32_val(groupSize)), i32_val(4)); + Value permuteAddr = add(laneOffset, i32_val(groupId * groupSize * 4)); + Type valType = val.getType(); + Value broadcasted; + if (valType.isInteger(32)) + broadcasted = rewriter.create(loc, val.getType(), + permuteAddr, val); + if (valType.isF32()) { + val = bitcast(val, i32_ty); + broadcasted = rewriter.create(loc, val.getType(), + permuteAddr, val); + broadcasted = bitcast(broadcasted, f32_ty); + } + if (valType.isa()) { + auto vecTy = valType.dyn_cast(); + auto vecBitSize = vecTy.getElementType().getIntOrFloatBitWidth() * + vecTy.getNumElements(); + const int int32VecSize = vecBitSize / 32; + + Type int32VecTy = vec_ty(i32_ty, int32VecSize); + Value int32Val = bitcast(val, int32VecTy); + Value int32Broadcasted = undef(int32VecTy); + for (int i = 0; i < int32VecSize; ++i) { + Value int32Chunk = extract_element(i32_ty, int32Val, i32_val(i)); + Value broadcastedChunk = rewriter.create( + loc, i32_ty, permuteAddr, int32Chunk); + int32Broadcasted = insert_element(int32VecTy, int32Broadcasted, + broadcastedChunk, i32_val(i)); + } + broadcasted = bitcast(int32Broadcasted, valType); + } + assert(broadcasted); + return broadcasted; + } + + Value generateMFMATile(StringRef mfmaInsnName, SmallVector valA, + SmallVector valB, Value valC, int mDim, + int nDim, bool transpose) const { + + Value acc; + if (mDim == nDim) { + assert(valA.size() == 1 && valB.size() == 1); + acc = transpose ? generateMFMAOp(mfmaInsnName, valB[0], valA[0], valC) + : generateMFMAOp(mfmaInsnName, valA[0], valB[0], valC); + } + if (mDim == 4 && nDim == 64 || mDim == 64 && nDim == 4) { + // broadcast selected kRep A operand matrix to all A matrices(2^4=16) + constexpr int broadcastCtrl = 4; + constexpr int numRepeats = 16; + acc = valC; + for (int kRep = 0; kRep < numRepeats; kRep++) { + if (mDim == 4 && !transpose) { + assert(valA.size() == 1 && valB.size() == 16); + acc = generateMFMAOp(mfmaInsnName, valA[0], valB[kRep], acc, + broadcastCtrl, kRep); + } + if (mDim == 4 && transpose) { + assert(valA.size() == 1 && valB.size() == 16); + Value broadcastValA = broadcastGroup(valA[0], kRep, numRepeats); + acc = generateMFMAOp(mfmaInsnName, valB[kRep], broadcastValA, acc); + } + if (nDim == 4 && !transpose) { + assert(valA.size() == 16 && valB.size() == 1); + Value broadcastValB = broadcastGroup(valB[0], kRep, numRepeats); + acc = generateMFMAOp(mfmaInsnName, valA[kRep], broadcastValB, acc); + } + if (nDim == 4 && transpose) { + assert(valA.size() == 16 && valB.size() == 1); + acc = generateMFMAOp(mfmaInsnName, valB[0], valA[kRep], acc, + broadcastCtrl, kRep); + } + } + } + return acc; + } + int getNumSubmatrices(Type elementType, int mDim, int nDim) const { if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64) return 1; @@ -187,13 +290,14 @@ struct DotOpMFMAConversionHelper { llvm::report_fatal_error("No match found in MFMA database\n"); mfmaInsnName = (*maybeMfmaInsn).getInsnName(); - unsigned k_base = (*maybeMfmaInsn).getKBase(); + unsigned kBaseA = (*maybeMfmaInsn).getKBaseA(); + unsigned kBaseB = (*maybeMfmaInsn).getKBaseB(); auto aEncoding = aTensorTy.getEncoding().cast(); auto bEncoding = bTensorTy.getEncoding().cast(); - auto kWidth = aEncoding.getKWidth(); - assert(kWidth == bEncoding.getKWidth()); + auto kWidthA = aEncoding.getKWidth(); + auto kWidthB = bEncoding.getKWidth(); auto repA = aEncoding.getMFMARep(aTensorTy.getShape()); auto repB = bEncoding.getMFMARep(bTensorTy.getShape()); @@ -209,9 +313,9 @@ struct DotOpMFMAConversionHelper { auto numRepK = repA[1]; auto operandA = getValuesFromDotOperandLayoutStruct( - loadedA, numRepM, numRepK, kWidth, k_base, aTensorTy.getElementType()); + loadedA, numRepM, numRepK, kWidthA, kBaseA, aTensorTy.getElementType()); auto operandB = getValuesFromDotOperandLayoutStruct( - loadedB, numRepN, numRepK, kWidth, k_base, aTensorTy.getElementType()); + loadedB, numRepN, numRepK, kWidthB, kBaseB, aTensorTy.getElementType()); auto dstElemTy = dTensorTy.getElementType(); auto fc = @@ -236,12 +340,10 @@ struct DotOpMFMAConversionHelper { acc = zeroAuxiliarBlocks(subBlocks, acc); for (size_t k = 0; k < numRepK; k++) - for (int kpack = 0; kpack < kWidth / k_base; ++kpack) - acc = mfmaLayout.getIsTransposed() - ? generateMFMAOp(mfmaInsnName, operandB[kpack][{n, k}], - operandA[kpack][{m, k}], acc) - : generateMFMAOp(mfmaInsnName, operandA[kpack][{m, k}], - operandB[kpack][{n, k}], acc); + for (int kpack = 0; kpack < kWidthA / kBaseA; ++kpack) + acc = generateMFMATile(mfmaInsnName, operandA[{kpack, m, k}], + operandB[{kpack, n, k}], acc, mDim, nDim, + mfmaLayout.getIsTransposed()); acc = reduceSubBlocks(subBlocks, acc); for (unsigned v = 0; v < elemsPerVec; ++v) { fc[m * numRepN * elemsPerVec + n * elemsPerVec + v] = @@ -260,30 +362,39 @@ struct DotOpMFMAConversionHelper { } /** - * @brief extract vector from rawElems based on kWidth and k_base + * @brief extract vector from rawElems based on kWidth and kBase * rawElems is a vector of kWidth elements. We need to prepare vector(s) of - * k_base elements for each mfma instruction + * kBase elements for each mfma instruction + * + * @param rawElems vector of "raw" elements for one mfma tile + * @param k id in k-pack + * @param kPack size of k-pack + * @param numIntrinsics number of operands we need to extract + * @param type type mfma intrinsic requires + * + * @return elements converted for one repetition */ - SmallVector extractOperands(Value rawElems, int kWidth, int k_base, - Type type) const { - int kpack = kWidth / k_base; + SmallVector extractOperands(Value rawElems, int k, int kPack, + int numIntrinsics, Type type) const { + assert(numIntrinsics == 1 || numIntrinsics == 16); + auto rawTy = rawElems.getType().cast(); + auto rawElemTy = rawTy.getElementType(); + // number of elements required by one mfma intrinsic + int intrinsicK = rawTy.getNumElements() / numIntrinsics / kPack; + int kBase = rawTy.getNumElements() / kPack; + SmallVector results; - auto vecTy = vec_ty(type, k_base); - for (int k = 0; k < kpack; ++k) { - Value vec = undef(vecTy); - for (int elemId = 0; elemId < k_base; ++elemId) { - auto val = - extract_element(type, rawElems, i32_val(elemId + k * k_base)); - vec = insert_element(vecTy, vec, val, i32_val(elemId)); + // extract needed elements in original dtype + auto typedVecTy = vec_ty(rawElemTy, intrinsicK); + for (int intrinsic = 0; intrinsic < numIntrinsics; ++intrinsic) { + Value typedVec = undef(typedVecTy); + for (int elemId = 0; elemId < intrinsicK; ++elemId) { + int elemOff = elemId + intrinsic * intrinsicK + k * kBase; + auto val = extract_element(rawElemTy, rawElems, i32_val(elemOff)); + typedVec = insert_element(typedVecTy, typedVec, val, i32_val(elemId)); } - if (type.getIntOrFloatBitWidth() == 8) { - if (4 == k_base) - // This is for int8 on pre- MI300 GPUs - results.push_back(bitcast(vec, i32_ty)); - if (8 == k_base) - results.push_back(bitcast(vec, i64_ty)); - } else - results.push_back(vec); + Value castedVec = bitcast(typedVec, type); + results.push_back(castedVec); } return results; } @@ -292,35 +403,38 @@ struct DotOpMFMAConversionHelper { * @brief Converts dot operand structure to value table and converts types * appropriate for mfma instructions */ - SmallVector - getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1, int kWidth, - int k_base, Type type) const { + ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1, + int kWidth, int kBase, + Type type) const { auto elems = typeConverter->unpackLLElements(loc, value, rewriter, type); - ValueTable vals; - ValueTable vals1; - int kpack = kWidth / k_base; - SmallVector dotOpVals(kpack); + int kpack = kWidth / kBase; + // "Wide operand" means that this operand is for mfma 4x64 layout + // This operand is 64x64 for fp16, bf16 and int8 data types and + // 16x64 for fp32 + bool wideOperand = kWidth >= 16; + // How many rocdl intrinsics will process one tile + int numIntrinsics = wideOperand ? 16 : 1; + int intrinsicKWidth = wideOperand ? kBase / numIntrinsics : kBase; + Type intrinsicDType; + if (type.isF32()) + intrinsicDType = f32_ty; + if (type.getIntOrFloatBitWidth() == 8) + intrinsicDType = rewriter.getIntegerType(intrinsicKWidth * 8); + if (type.isBF16()) + intrinsicDType = vec_ty(i16_ty, intrinsicKWidth); + if (type.isF16()) + intrinsicDType = vec_ty(f16_ty, intrinsicKWidth); + assert(intrinsicDType); + + ValueTable dotOpVals; for (int i = 0; i < n0; i++) { for (int j = 0; j < n1; j++) { auto rawElems = elems[n1 * i + j]; - - if (type.isF32()) { - for (int k = 0; k < kpack; ++k) { - dotOpVals[k][{i, j}] = extract_element(type, rawElems, i32_val(k)); - } - } else { - SmallVector vals; - if (type.getIntOrFloatBitWidth() == 8) { - vals = extractOperands(rawElems, kWidth, k_base, i8_ty); - } else if (type.isBF16()) { - vals = extractOperands(rawElems, kWidth, k_base, i16_ty); - } else { - assert(type.isF16() && "Unsupported data type"); - vals = extractOperands(rawElems, kWidth, k_base, f16_ty); - } - for (int k = 0; k < kpack; ++k) { - dotOpVals[k][{i, j}] = vals[k]; - } + for (int k = 0; k < kpack; k++) { + SmallVector vals = extractOperands( + rawElems, k, kpack, numIntrinsics, intrinsicDType); + assert(vals.size() == numIntrinsics); + dotOpVals[{k, i, j}] = vals; } } } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 87e8bb218bc9..48171dc43822 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -304,12 +304,17 @@ SmallVector getSizePerThread(Attribute layout) { llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); return {}; } - } else if (parentLayout.isa()) { + } else if (auto mfmaLayout = parentLayout.dyn_cast()) { auto opIdx = dotLayout.getOpIdx(); + auto kWidth = dotLayout.getKWidth(); if (opIdx == 0) { - return {4, 1}; + int repeats = + (mfmaLayout.getMDim() == 64 && mfmaLayout.getNDim() == 4) ? 16 : 1; + return {1, kWidth * repeats}; } else if (opIdx == 1) { - return {1, 4}; + int repeats = + (mfmaLayout.getMDim() == 4 && mfmaLayout.getNDim() == 64) ? 16 : 1; + return {kWidth * repeats, 1}; } else { assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1"); return {}; @@ -458,6 +463,8 @@ SmallVector getShapePerCTATile(Attribute layout, auto parentShapePerCTA = getShapePerCTATile(parentLayout, tensorShape); auto opIdx = dotLayout.getOpIdx(); + assert(parentMfmaLayout.getMDim() == 32); + if (opIdx == 0) { return {parentShapePerCTA[0], 32}; } else if (opIdx == 1) { @@ -1102,16 +1109,13 @@ DotOperandEncodingAttr::getMFMAElemsPerInstr() const { (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); int64_t kWidth = getKWidth(); constexpr int waveSize = 64; // MFMA is used on wave64 architectures only - int kGroups = -1; - if (mDim == nDim) - kGroups = waveSize / mDim; - if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64) - kGroups = 1; + auto nonKDim = getOpIdx() == 0 ? mDim : nDim; + int kGroups = waveSize / nonKDim; int64_t kDim = kWidth * kGroups; if (getOpIdx() == 0) - return {mDim, kDim}; + return {nonKDim, kDim}; else - return {kDim, nDim}; + return {kDim, nonKDim}; } SmallVector @@ -1902,6 +1906,18 @@ struct TritonGPUInferLayoutInterface // Verify that the encodings are valid. if (!aEncoding || !bEncoding) return op->emitError("mismatching encoding between A and B operands"); +#ifdef USE_ROCM + auto aParentEncoding = + aEncoding.getParent().dyn_cast_or_null(); + auto bParentEncoding = + bEncoding.getParent().dyn_cast_or_null(); + if (aParentEncoding != bParentEncoding) + return op->emitError( + "mismatching parent encoding between A and B operands"); + if (aParentEncoding != nullptr && + aParentEncoding.getMDim() != aParentEncoding.getNDim()) + return success(); +#endif // USE_ROCM if (aEncoding.getKWidth() != bEncoding.getKWidth()) return op->emitError("mismatching kWidth between A and B operands"); return success(); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp index 3f39248597bd..32303ca748bc 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp @@ -158,9 +158,8 @@ class BlockedToMFMA : public mlir::RewritePattern { /// @brief Choose MFMA instruction parameters /// @param dot target dot operation - /// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments - std::tuple - chooseMfmaDimensions(tt::DotOp dot) const { + /// @return selected mfma instruction + MfmaInsn chooseMfmaDimensions(tt::DotOp dot) const { // number of matrix elements along k dim per one MFMA intruction unsigned kDim = 0; auto opType = dot.getA().getType().cast(); @@ -200,6 +199,8 @@ class BlockedToMFMA : public mlir::RewritePattern { nDim = 16; } if (minSize < 16) { + assert(opType.getShape()[1] >= 64 && + "k should be at least 64 to use this layout"); if (resShape[0] < 16 && resShape[1] >= 64) { mDim = 4; nDim = 64; @@ -207,8 +208,6 @@ class BlockedToMFMA : public mlir::RewritePattern { mDim = 64; nDim = 4; } else { - assert(opType.getShape()[1] >= 64 && - "k should be at least 64 to use this layout"); mDim = 4; nDim = 4; } @@ -227,7 +226,7 @@ class BlockedToMFMA : public mlir::RewritePattern { assert(mDim != 0 && nDim != 0); assert(resShape[0] % mDim == 0 && resShape[1] % nDim == 0); assert(opType.getShape()[1] % kDim == 0); - return {mDim, nDim, kDim}; + return maybeMfmaInsn.value(); } mlir::LogicalResult @@ -259,7 +258,10 @@ class BlockedToMFMA : public mlir::RewritePattern { ttg::MfmaEncodingAttr mfmaEnc; - auto [mDim, nDim, kDim] = chooseMfmaDimensions(dotOp); + auto instr = chooseMfmaDimensions(dotOp); + auto mDim = instr.getMDim(); + auto nDim = instr.getNDim(); + auto kDim = instr.getKDim(); auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim}); @@ -290,33 +292,24 @@ class BlockedToMFMA : public mlir::RewritePattern { // kWidth is initialized as k_base, which is the number of elements hold by // one thread per mfma instruction - auto kWidth = -1; - // in mfma 32x32 case argument matrix groups elements in 2 groups - // in mfma 16x16 case argument matrix groups elements in 4 groups - // in mfma 4x4 case argument matrix groups in 16 groups - if (mDim == 32 && nDim == 32) - kWidth = kDim / 2; - if (mDim == 16 && nDim == 16) - kWidth = kDim / 4; - if (mDim == 4 && nDim == 4) - kWidth = kDim / 16; - if (mDim == 4 && nDim == 64 || mDim == 64 && nDim == 4) - kWidth = kDim; - assert(kWidth != -1); + auto kWidthA = instr.getKBaseA(); + auto kWidthB = instr.getKBaseB(); // We want to extend kWidth by kpack (kpack=1 means no extension) // to increase ds_read vector size // However, in FA, the second dot can only use kWidth = k_bse since it's // limited by the result of the first dot, which is of mfmaLayout. - if (!isSecondDot(dotOp)) - kWidth *= kpack; + if (!isSecondDot(dotOp)) { + kWidthA *= kpack; + kWidthB *= kpack; + } auto newAType = RankedTensorType::get( oldAType.getShape(), oldAType.getElementType(), - ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth)); + ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidthA)); auto newBType = RankedTensorType::get( oldBType.getShape(), oldBType.getElementType(), - ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth)); + ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidthB)); a = rewriter.create(a.getLoc(), newAType, a); b = rewriter.create(b.getLoc(), newBType, b); auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 23f1befd2617..5a5046b1f8f0 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -669,173 +669,183 @@ using MfmaInsnGroupMap = llvm::DenseMap const MfmaInsnGroupMap & { static MfmaInsnGroupMap MfmaInsnMap{ + // MFMA tile description: + // M N K k_base_a k_base_b instr_name // f32 // mfma_f32_32x32x2f32 {{32, 32, MfmaTypeId::Fp32TyId, 1}, - {32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, + {32, 32, 2, 1, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, {{32, 32, MfmaTypeId::Fp32TyId, 2}, - {32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, + {32, 32, 2, 1, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, {{32, 32, MfmaTypeId::Fp32TyId, 3}, - {32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, + {32, 32, 2, 1, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, // mfma_f32_16x16x4f32 {{16, 16, MfmaTypeId::Fp32TyId, 1}, - {16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, + {16, 16, 4, 1, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, {{16, 16, MfmaTypeId::Fp32TyId, 2}, - {16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, + {16, 16, 4, 1, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, {{16, 16, MfmaTypeId::Fp32TyId, 3}, - {16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, + {16, 16, 4, 1, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, // mfma_f32_4x4x1f32 {{4, 4, MfmaTypeId::Fp32TyId, 1}, - {4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 4, 16, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{4, 4, MfmaTypeId::Fp32TyId, 2}, - {4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 4, 16, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{4, 64, MfmaTypeId::Fp32TyId, 1}, - {4, 64, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 64, 16, 1, 16, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{4, 64, MfmaTypeId::Fp32TyId, 2}, - {4, 64, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 64, 16, 1, 16, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{64, 4, MfmaTypeId::Fp32TyId, 1}, - {64, 4, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {64, 4, 16, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{64, 4, MfmaTypeId::Fp32TyId, 2}, - {64, 4, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {64, 4, 16, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, // mfma_f32_4x4x1_16B_f32 {{4, 4, MfmaTypeId::Fp32TyId, 3}, - {4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 4, 16, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{4, 64, MfmaTypeId::Fp32TyId, 3}, - {4, 64, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 64, 16, 1, 16, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{64, 4, MfmaTypeId::Fp32TyId, 3}, - {64, 4, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {64, 4, 16, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, // f16 // mfma_f32_32x32x8f16 {{32, 32, MfmaTypeId::Fp16TyId, 1}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, {{32, 32, MfmaTypeId::Fp16TyId, 2}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, {{32, 32, MfmaTypeId::Fp16TyId, 3}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, // mfma_f32_16x16x16xf16 {{16, 16, MfmaTypeId::Fp16TyId, 1}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, {{16, 16, MfmaTypeId::Fp16TyId, 2}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, {{16, 16, MfmaTypeId::Fp16TyId, 3}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, // mfma_f32_4x4x4f16 {{4, 4, MfmaTypeId::Fp16TyId, 1}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 4, MfmaTypeId::Fp16TyId, 2}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 4, MfmaTypeId::Fp16TyId, 3}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 64, MfmaTypeId::Fp16TyId, 1}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 64, MfmaTypeId::Fp16TyId, 2}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 64, MfmaTypeId::Fp16TyId, 3}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{64, 4, MfmaTypeId::Fp16TyId, 1}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{64, 4, MfmaTypeId::Fp16TyId, 2}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{64, 4, MfmaTypeId::Fp16TyId, 3}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, // bf16 // mfma_f32_32x32x4_bf16 {{32, 32, MfmaTypeId::Bf16TyId, 1}, - {32, 32, 4, 2, ROCDL::mfma_f32_32x32x4bf16::getOperationName()}}, + {32, 32, 4, 2, 2, ROCDL::mfma_f32_32x32x4bf16::getOperationName()}}, // mfma_f32_32x32x8_bf16_1K {{32, 32, MfmaTypeId::Bf16TyId, 2}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}}, {{32, 32, MfmaTypeId::Bf16TyId, 3}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}}, // mfma_f32_16x16x8_bf16 {{16, 16, MfmaTypeId::Bf16TyId, 1}, - {16, 16, 8, 2, ROCDL::mfma_f32_16x16x8bf16::getOperationName()}}, + {16, 16, 8, 2, 2, ROCDL::mfma_f32_16x16x8bf16::getOperationName()}}, // mfma_f32_16x16x16_bf16_1K {{16, 16, MfmaTypeId::Bf16TyId, 2}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}}, {{16, 16, MfmaTypeId::Bf16TyId, 3}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}}, // mfma_f32_4x4x2_bf16 {{4, 4, MfmaTypeId::Bf16TyId, 1}, - {4, 4, 32, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, + {4, 4, 32, 2, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, {{4, 64, MfmaTypeId::Bf16TyId, 1}, - {4, 64, 2, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, + {4, 64, 32, 2, 32, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, {{64, 4, MfmaTypeId::Bf16TyId, 1}, - {64, 4, 2, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, + {64, 4, 32, 32, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, // mfma_f32_4x4x4_bf16_1K {{4, 4, MfmaTypeId::Bf16TyId, 2}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{4, 4, MfmaTypeId::Bf16TyId, 3}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{4, 64, MfmaTypeId::Bf16TyId, 2}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{4, 64, MfmaTypeId::Bf16TyId, 3}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{64, 4, MfmaTypeId::Bf16TyId, 2}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{64, 4, MfmaTypeId::Bf16TyId, 3}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, // int8 // mfma_i32_32x32x8i8 {{32, 32, MfmaTypeId::I8TyId, 1}, - {32, 32, 8, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}}, {{32, 32, MfmaTypeId::I8TyId, 2}, - {32, 32, 8, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}}, // mfma_i32_32x32x16i8 {{32, 32, MfmaTypeId::I8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_i32_32x32x16_i8::getOperationName()}}, + {32, 32, 16, 8, 8, ROCDL::mfma_i32_32x32x16_i8::getOperationName()}}, // mfma_i32_16x16x16i8 {{16, 16, MfmaTypeId::I8TyId, 1}, - {16, 16, 16, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}}, {{16, 16, MfmaTypeId::I8TyId, 2}, - {16, 16, 16, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}}, // mfma_i32_16x16x32i8 {{16, 16, MfmaTypeId::I8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_i32_16x16x32_i8::getOperationName()}}, + {16, 16, 32, 8, 8, ROCDL::mfma_i32_16x16x32_i8::getOperationName()}}, // mfma_i32_4x4x4i8 {{4, 4, MfmaTypeId::I8TyId, 1}, - {4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 4, MfmaTypeId::I8TyId, 2}, - {4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 4, MfmaTypeId::I8TyId, 3}, - {4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 64, MfmaTypeId::I8TyId, 1}, - {4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 64, MfmaTypeId::I8TyId, 2}, - {4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 64, MfmaTypeId::I8TyId, 3}, - {4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{64, 4, MfmaTypeId::I8TyId, 1}, - {64, 4, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{64, 4, MfmaTypeId::I8TyId, 2}, - {64, 4, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{64, 4, MfmaTypeId::I8TyId, 3}, - {64, 4, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, // fp8 * pf8 // mfma_f32_32x32x16_FP8_FP8 {{32, 32, MfmaTypeId::Fp8Fp8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName()}}, + {32, 32, 16, 8, 8, + ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName()}}, // mfma_f32_16x16x32_FP8_FP8 {{16, 16, MfmaTypeId::Fp8Fp8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName()}}, + {16, 16, 32, 8, 8, + ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName()}}, // mfma_f32_32x32x16_FP8_BF8 {{32, 32, MfmaTypeId::Fp8Bf8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName()}}, + {32, 32, 16, 8, 8, + ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName()}}, // mfma_f32_16x16x32_FP8_BF8 {{16, 16, MfmaTypeId::Fp8Bf8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName()}}, + {16, 16, 32, 8, 8, + ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName()}}, // mfma_f32_32x32x16_BF8_FP8 {{32, 32, MfmaTypeId::Bf8Fp8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName()}}, + {32, 32, 16, 8, 8, + ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName()}}, // mfma_f32_16x16x32_BF8_FP8 {{16, 16, MfmaTypeId::Bf8Fp8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName()}}, + {16, 16, 32, 8, 8, + ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName()}}, // mfma_f32_32x32x16_BF8_BF8 {{32, 32, MfmaTypeId::Bf8Bf8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName()}}, + {32, 32, 16, 8, 8, + ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName()}}, // mfma_f32_16x16x32_BF8_BF8 {{16, 16, MfmaTypeId::Bf8Bf8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName()}}}; + {16, 16, 32, 8, 8, + ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName()}}}; return MfmaInsnMap; }; @@ -859,6 +869,7 @@ unsigned MfmaInsn::getKDim() { return attr.k; } unsigned MfmaInsn::getMDim() { return attr.m; } unsigned MfmaInsn::getNDim() { return attr.n; } StringRef MfmaInsn::getInsnName() { return attr.insn; } -unsigned MfmaInsn::getKBase() { return attr.k_base;} +unsigned MfmaInsn::getKBaseA() { return attr.k_base_a; } +unsigned MfmaInsn::getKBaseB() { return attr.k_base_b; } } // namespace mlir diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py index 0a451d539453..f24bf301c6d0 100644 --- a/python/test/unit/language/test_core_amd.py +++ b/python/test/unit/language/test_core_amd.py @@ -1666,7 +1666,7 @@ def kernel(X, stride_xm, stride_xn, if not (allow_tf32 and (in_dtype in ['float16']))] + [(*shape, warps, False, False, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim, 1) - for shape in [(64, 16, 128), (16, 64, 128)] + for shape in [(64, 64, 128), (16, 64, 128)] for warps in [1, 4] for epilogue in ['none', 'trans', 'add-matrix', 'chain-dot', 'softmax'] for allow_tf32 in [False] @@ -1706,8 +1706,9 @@ def kernel(X, stride_xm, stride_xn, [4, 32, 64, 4], [32, 4, 64, 2], [16, 4, 64, 8], - [64, 4, 16, 1], - [4, 64, 16, 1], + [64, 4, 64, 1], + [4, 64, 64, 1], + [4, 64, 64, 4], ] for allow_tf32 in [False, True] for col_a in [True, False] diff --git a/scripts/amd/lit_tests/generate_accelerate_matmul_tests.py b/scripts/amd/lit_tests/generate_accelerate_matmul_tests.py new file mode 100755 index 000000000000..5bcc266634d1 --- /dev/null +++ b/scripts/amd/lit_tests/generate_accelerate_matmul_tests.py @@ -0,0 +1,182 @@ +import argparse +import sys + +# M N K a_ty b_ty c_ty +configs = [[32, 32, 32, "f16", "f16", "f32"], + [32, 32, 32, "bf16", "bf16", "f32"], + [32, 32, 32, "f32", "f32", "f32"], + [32, 32, 32, "i8", "i8", "i32"], + [32, 32, 32, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"], + [32, 32, 32, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"], + [32, 32, 32, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"], + [32, 32, 32, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"], + + [16, 16, 32, "f16", "f16", "f32"], + [16, 16, 32, "bf16", "bf16", "f32"], + [16, 16, 32, "f32", "f32", "f32"], + [16, 16, 32, "i8", "i8", "i32"], + [16, 16, 32, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"], + [16, 16, 32, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"], + [16, 16, 32, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"], + [16, 16, 32, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"], + + [4, 4, 64, "f16", "f16", "f32"], + [4, 4, 64, "bf16", "bf16", "f32"], + [4, 4, 64, "f32", "f32", "f32"], + [4, 4, 64, "i8", "i8", "i32"], + [4, 4, 64, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"], + [4, 4, 64, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"], + [4, 4, 64, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"], + [4, 4, 64, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"], + + [64, 4, 64, "f16", "f16", "f32"], + [64, 4, 64, "bf16", "bf16", "f32"], + [64, 4, 64, "f32", "f32", "f32"], + [64, 4, 64, "i8", "i8", "i32"], + [64, 4, 64, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"], + [64, 4, 64, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"], + [64, 4, 64, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"], + [64, 4, 64, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"], + + [4, 64, 64, "f16", "f16", "f32"], + [4, 64, 64, "bf16", "bf16", "f32"], + [4, 64, 64, "f32", "f32", "f32"], + [4, 64, 64, "i8", "i8", "i32"], + [4, 64, 64, "f8E4M3FNUZ", "f8E4M3FNUZ", "f32"], + [4, 64, 64, "f8E4M3FNUZ", "f8E5M2FNUZ", "f32"], + [4, 64, 64, "f8E5M2FNUZ", "f8E4M3FNUZ", "f32"], + [4, 64, 64, "f8E5M2FNUZ", "f8E5M2FNUZ", "f32"] + ] + +def generate(cdna_version, output_file): + arch_names = {0:"", 1: "gfx908", 2: "gfx90a", 3: "gfx940"} + arch_name = arch_names[cdna_version] + print(f"// This file is generated: $ python3 {' '.join(sys.argv)}", file=output_file) + print(f"// RUN: (! triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul=arch-generation-name={arch_name} --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null) | FileCheck --check-prefixes=CHECK %s", file=output_file) + + for cfg_id in range(len(configs)): + cfg = configs[cfg_id] + + cfg_name = "_".join([str(item) for item in cfg]) + + M, N, K, a_ty, b_ty, c_ty = cfg + if "i" in c_ty: + cst_val = "0" + else: + cst_val = "0.000000e+00" + + supported = True + if cdna_version < 3 and ("f8" in a_ty or "f8" in b_ty): + supported = False + + if M >= 32 and N >= 32: + m_dim = 32 + n_dim = 32 + elif M >= 16 and N >= 16: + m_dim = 16 + n_dim = 16 + elif M >= 64 and N < 16: + m_dim = 64 + n_dim = 4 + elif M < 16 and N >= 64: + m_dim = 4 + n_dim = 64 + elif M < 16 and N < 16: + m_dim = 4 + n_dim = 4 + if ("f8" in a_ty or "f8" in b_ty) and min(m_dim, n_dim) == 4: + supported = False + + if cdna_version == 1: + if a_ty == "f16": + k_width0 = 4 + k_width1 = 4 + if a_ty == "bf16": + k_width0 = 2 + k_width1 = 2 + if a_ty == "i8": + k_width0 = 4 + k_width1 = 4 + if a_ty == "f32": + k_width0 = 1 + k_width1 = 1 + if cdna_version == 2: + if a_ty == "f16": + k_width0 = 4 + k_width1 = 4 + if a_ty == "bf16": + k_width0 = 4 + k_width1 = 4 + if a_ty == "i8": + k_width0 = 4 + k_width1 = 4 + if a_ty == "f32": + k_width0 = 1 + k_width1 = 1 + if cdna_version == 3: + if "f8" in a_ty: + k_width0 = 8 + k_width1 = 8 + if a_ty == "f16": + k_width0 = 4 + k_width1 = 4 + if a_ty == "bf16": + k_width0 = 4 + k_width1 = 4 + if a_ty == "i8": + if min(m_dim, n_dim) == 4: + k_width0 = 4 + k_width1 = 4 + else: + k_width0 = 8 + k_width1 = 8 + if a_ty == "f32": + k_width0 = 1 + k_width1 = 1 + if m_dim == 64: + k_width0 *= 16 + if n_dim == 64: + k_width1 *= 16 + + if supported: + mfma_check = f"// CHECK: #mfma = #triton_gpu.mfma<{{versionMajor = {cdna_version}, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [{m_dim}, {n_dim}], isTransposed = false}}>" + label_check = f"// CHECK: convert_dot_{cfg_name}" + checks =f"""// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #blocked>) -> tensor<{{{{.*}}}}, #mfma> +// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 0, parent = #blocked}}>>) -> tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 0, parent = #mfma, kWidth = {k_width0}}}>> +// CHECK: triton_gpu.convert_layout {{{{.*}}}} : (tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 1, parent = #blocked}}>>) -> tensor<{{{{.*}}}}, #triton_gpu.dot_op<{{opIdx = 1, parent = #mfma, kWidth = {k_width1}}}>>""" + else: + mfma_check = "" + label_check = f"// CHECK-NOT: convert_dot_{cfg_name}" + checks = "" + + case_text = f''' +!a_ty = {a_ty} +!b_ty = {b_ty} +!c_ty = {c_ty} +#blocked = #triton_gpu.blocked<{{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}}> +#dot_operand_a = #triton_gpu.dot_op<{{opIdx=0, parent=#blocked}}> +#dot_operand_b = #triton_gpu.dot_op<{{opIdx=1, parent=#blocked}}> +module attributes {{"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32}} {{ +{mfma_check} +{label_check} + tt.func @convert_dot_{cfg_name}(%a: tensor<{M}x{K}x!a_ty, #dot_operand_a>, %b: tensor<{K}x{N}x!b_ty, #dot_operand_b>) -> tensor<{M}x{N}x!c_ty, #blocked> {{ + %cst_c = arith.constant dense<{cst_val}> : tensor<{M}x{N}x!c_ty, #blocked> +{checks} + %D = tt.dot %a, %b, %cst_c {{allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false}} : tensor<{M}x{K}x!a_ty, #dot_operand_a> * tensor<{K}x{N}x!b_ty, #dot_operand_b> -> tensor<{M}x{N}x!c_ty, #blocked> + tt.return %D: tensor<{M}x{N}x!c_ty, #blocked> + }} +}} + +''' + if cfg_id == len(configs) - 1: + print(case_text, end="", file=output_file) + else: + print(case_text, end="// -----\n", file=output_file) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("cdna_version", type=int) + parser.add_argument("output_file", type=str) + args = parser.parse_args() + with open(args.output_file, "w") as f: + generate(cdna_version=args.cdna_version, output_file=f) diff --git a/test/TritonGPU/accelerate-matmul-cdna1.mlir b/test/TritonGPU/accelerate-matmul-cdna1.mlir index 51956c590035..07886f2f2e21 100644 --- a/test/TritonGPU/accelerate-matmul-cdna1.mlir +++ b/test/TritonGPU/accelerate-matmul-cdna1.mlir @@ -1,3 +1,4 @@ +// This file is generated: $ python3 ../scripts/amd/lit_tests/generate_accelerate_matmul_tests.py 1 ../test/TritonGPU/accelerate-matmul-cdna1.mlir // RUN: (! triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul=arch-generation-name=gfx908 --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null) | FileCheck --check-prefixes=CHECK %s !a_ty = f16 @@ -488,13 +489,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_f16_f16_f32 - tt.func @convert_dot_64_4_4_f16_f16_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_f16_f16_f32 + tt.func @convert_dot_64_4_64_f16_f16_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -509,13 +510,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_bf16_bf16_f32 - tt.func @convert_dot_64_4_4_bf16_bf16_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_bf16_bf16_f32 + tt.func @convert_dot_64_4_64_bf16_bf16_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 2}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 32}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 2}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -530,13 +531,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_f32_f32_f32 - tt.func @convert_dot_64_4_4_f32_f32_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_f32_f32_f32 + tt.func @convert_dot_64_4_64_f32_f32_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 1}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 16}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 1}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -551,13 +552,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_i8_i8_i32 - tt.func @convert_dot_64_4_4_i8_i8_i32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_i8_i8_i32 + tt.func @convert_dot_64_4_64_i8_i8_i32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -572,11 +573,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E4M3FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E4M3FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -591,11 +592,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E4M3FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E4M3FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -610,11 +611,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E5M2FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E5M2FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -629,11 +630,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E5M2FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E5M2FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -648,13 +649,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_f16_f16_f32 - tt.func @convert_dot_4_64_4_f16_f16_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_f16_f16_f32 + tt.func @convert_dot_4_64_64_f16_f16_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -669,13 +670,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_bf16_bf16_f32 - tt.func @convert_dot_4_64_4_bf16_bf16_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_bf16_bf16_f32 + tt.func @convert_dot_4_64_64_bf16_bf16_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 2}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 2}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 32}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -690,13 +691,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_f32_f32_f32 - tt.func @convert_dot_4_64_4_f32_f32_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_f32_f32_f32 + tt.func @convert_dot_4_64_64_f32_f32_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 1}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 1}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 16}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -711,13 +712,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_i8_i8_i32 - tt.func @convert_dot_4_64_4_i8_i8_i32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_i8_i8_i32 + tt.func @convert_dot_4_64_64_i8_i8_i32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -732,11 +733,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E4M3FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E4M3FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -751,11 +752,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E4M3FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E4M3FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -770,11 +771,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E5M2FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E5M2FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -789,11 +790,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E5M2FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E5M2FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } diff --git a/test/TritonGPU/accelerate-matmul-cdna2.mlir b/test/TritonGPU/accelerate-matmul-cdna2.mlir index 0d853186bac3..4226f92ff0fd 100644 --- a/test/TritonGPU/accelerate-matmul-cdna2.mlir +++ b/test/TritonGPU/accelerate-matmul-cdna2.mlir @@ -1,3 +1,4 @@ +// This file is generated: $ python3 ../scripts/amd/lit_tests/generate_accelerate_matmul_tests.py 2 ../test/TritonGPU/accelerate-matmul-cdna2.mlir // RUN: (! triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul=arch-generation-name=gfx90a --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null) | FileCheck --check-prefixes=CHECK %s !a_ty = f16 @@ -488,13 +489,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_f16_f16_f32 - tt.func @convert_dot_64_4_4_f16_f16_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_f16_f16_f32 + tt.func @convert_dot_64_4_64_f16_f16_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -509,13 +510,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_bf16_bf16_f32 - tt.func @convert_dot_64_4_4_bf16_bf16_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_bf16_bf16_f32 + tt.func @convert_dot_64_4_64_bf16_bf16_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -530,13 +531,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_f32_f32_f32 - tt.func @convert_dot_64_4_4_f32_f32_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_f32_f32_f32 + tt.func @convert_dot_64_4_64_f32_f32_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 1}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 16}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 1}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -551,13 +552,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_i8_i8_i32 - tt.func @convert_dot_64_4_4_i8_i8_i32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_i8_i8_i32 + tt.func @convert_dot_64_4_64_i8_i8_i32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -572,11 +573,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E4M3FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E4M3FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -591,11 +592,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E4M3FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E4M3FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -610,11 +611,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E5M2FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E5M2FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -629,11 +630,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E5M2FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E5M2FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -648,13 +649,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_f16_f16_f32 - tt.func @convert_dot_4_64_4_f16_f16_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_f16_f16_f32 + tt.func @convert_dot_4_64_64_f16_f16_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -669,13 +670,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_bf16_bf16_f32 - tt.func @convert_dot_4_64_4_bf16_bf16_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_bf16_bf16_f32 + tt.func @convert_dot_4_64_64_bf16_bf16_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -690,13 +691,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_f32_f32_f32 - tt.func @convert_dot_4_64_4_f32_f32_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_f32_f32_f32 + tt.func @convert_dot_4_64_64_f32_f32_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 1}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 1}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 16}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -711,13 +712,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_i8_i8_i32 - tt.func @convert_dot_4_64_4_i8_i8_i32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_i8_i8_i32 + tt.func @convert_dot_4_64_64_i8_i8_i32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -732,11 +733,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E4M3FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E4M3FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -751,11 +752,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E4M3FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E4M3FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -770,11 +771,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E5M2FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E5M2FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -789,11 +790,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E5M2FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E5M2FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } diff --git a/test/TritonGPU/accelerate-matmul-cdna3.mlir b/test/TritonGPU/accelerate-matmul-cdna3.mlir index 24d8ee993615..02c096550345 100644 --- a/test/TritonGPU/accelerate-matmul-cdna3.mlir +++ b/test/TritonGPU/accelerate-matmul-cdna3.mlir @@ -1,3 +1,4 @@ +// This file is generated: $ python3 ../scripts/amd/lit_tests/generate_accelerate_matmul_tests.py 3 ../test/TritonGPU/accelerate-matmul-cdna3.mlir // RUN: (! triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul=arch-generation-name=gfx940 --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null) | FileCheck --check-prefixes=CHECK %s !a_ty = f16 @@ -504,13 +505,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_f16_f16_f32 - tt.func @convert_dot_64_4_4_f16_f16_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_f16_f16_f32 + tt.func @convert_dot_64_4_64_f16_f16_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -525,13 +526,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_bf16_bf16_f32 - tt.func @convert_dot_64_4_4_bf16_bf16_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_bf16_bf16_f32 + tt.func @convert_dot_64_4_64_bf16_bf16_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -546,13 +547,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_f32_f32_f32 - tt.func @convert_dot_64_4_4_f32_f32_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_f32_f32_f32 + tt.func @convert_dot_64_4_64_f32_f32_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 1}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 16}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 1}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -567,13 +568,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [64, 4], isTransposed = false}> -// CHECK: convert_dot_64_4_4_i8_i8_i32 - tt.func @convert_dot_64_4_4_i8_i8_i32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK: convert_dot_64_4_64_i8_i8_i32 + tt.func @convert_dot_64_4_64_i8_i8_i32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0> : tensor<64x4x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 64}>> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -588,11 +589,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E4M3FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E4M3FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -607,11 +608,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E4M3FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E4M3FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -626,11 +627,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E5M2FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E5M2FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -645,11 +646,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_64_4_4_f8E5M2FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_64_4_4_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x4x!a_ty, #dot_operand_a>, %b: tensor<4x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_64_4_64_f8E5M2FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_64_4_64_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<64x64x!a_ty, #dot_operand_a>, %b: tensor<64x4x!b_ty, #dot_operand_b>) -> tensor<64x4x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<64x4x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x4x!a_ty, #dot_operand_a> * tensor<4x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64x!a_ty, #dot_operand_a> * tensor<64x4x!b_ty, #dot_operand_b> -> tensor<64x4x!c_ty, #blocked> tt.return %D: tensor<64x4x!c_ty, #blocked> } } @@ -664,13 +665,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_f16_f16_f32 - tt.func @convert_dot_4_64_4_f16_f16_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_f16_f16_f32 + tt.func @convert_dot_4_64_64_f16_f16_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -685,13 +686,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_bf16_bf16_f32 - tt.func @convert_dot_4_64_4_bf16_bf16_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_bf16_bf16_f32 + tt.func @convert_dot_4_64_64_bf16_bf16_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -706,13 +707,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_f32_f32_f32 - tt.func @convert_dot_4_64_4_f32_f32_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_f32_f32_f32 + tt.func @convert_dot_4_64_64_f32_f32_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 1}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 1}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 16}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -727,13 +728,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { // CHECK: #mfma = #triton_gpu.mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [4, 64], isTransposed = false}> -// CHECK: convert_dot_4_64_4_i8_i8_i32 - tt.func @convert_dot_4_64_4_i8_i8_i32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK: convert_dot_4_64_64_i8_i8_i32 + tt.func @convert_dot_4_64_64_i8_i8_i32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0> : tensor<4x64x!c_ty, #blocked> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #blocked>) -> tensor<{{.*}}, #mfma> // CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> -// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> +// CHECK: triton_gpu.convert_layout {{.*}} : (tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>) -> tensor<{{.*}}, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 64}>> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -748,11 +749,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E4M3FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E4M3FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E4M3FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -767,11 +768,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E4M3FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E4M3FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E4M3FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -786,11 +787,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E5M2FNUZ_f8E4M3FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E5M2FNUZ_f8E4M3FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E5M2FNUZ_f8E4M3FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } @@ -805,11 +806,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { -// CHECK-NOT: convert_dot_4_64_4_f8E5M2FNUZ_f8E5M2FNUZ_f32 - tt.func @convert_dot_4_64_4_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x4x!a_ty, #dot_operand_a>, %b: tensor<4x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { +// CHECK-NOT: convert_dot_4_64_64_f8E5M2FNUZ_f8E5M2FNUZ_f32 + tt.func @convert_dot_4_64_64_f8E5M2FNUZ_f8E5M2FNUZ_f32(%a: tensor<4x64x!a_ty, #dot_operand_a>, %b: tensor<64x64x!b_ty, #dot_operand_b>) -> tensor<4x64x!c_ty, #blocked> { %cst_c = arith.constant dense<0.000000e+00> : tensor<4x64x!c_ty, #blocked> - %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x4x!a_ty, #dot_operand_a> * tensor<4x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> + %D = tt.dot %a, %b, %cst_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<4x64x!a_ty, #dot_operand_a> * tensor<64x64x!b_ty, #dot_operand_b> -> tensor<4x64x!c_ty, #blocked> tt.return %D: tensor<4x64x!c_ty, #blocked> } } From beb2f664841041d6d5b00b5bf7af66f7b37ab770 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Thu, 21 Mar 2024 19:32:10 +0000 Subject: [PATCH 3/8] change swizzling pattern --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 54bc607a4ede..82f1eb8b9283 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -144,15 +144,28 @@ compared to 1*64 when the hasLeadingOffset is false. int innerDimLength = shape[order[0]]; int elemsPerOneBanksRow = (numBanks * bankBitWidth) / typeWidthInBit; + int mDim = mfmaEnc.getMDim(); + int nDim = mfmaEnc.getNDim(); + int nonKDim = dotOpEnc.getOpIdx() == 0 ? mDim : nDim; + if ((mDim == 4 && nDim == 64) || (nDim == 4 && mDim == 64)) { + // Operands of the layout have following shapes + // Large operand: + // - shape 64(non-k)x64(k) for 16 bit dtypes + // - shape 64(non-k)x16(k) for 32 bit dtypes + // Small operand: + // - shape 4(non-k)x64(k) for 16 bit dtypes + // - shape 4(non-k)x16(k) for 32 bit dtypes + const int vecSize = bankBitWidth / typeWidthInBit; + const int perPhase = std::max(1, numBanks / innerDimLength); + const int maxPhase = std::min(numBanks, nonKDim) / perPhase; + return get(context, vecSize, perPhase, maxPhase, order, CTALayout); + } int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); // vecSize is set to kWidth of the dotop layout int vecSize = dotOpEnc.getKWidth(); // maxPhase is set to SIMDWidth / perPhase int maxPhase = std::min(SIMDWidth / perPhase, innerDimLength / vecSize); // TODO (zhanglx): figure out better parameters for mfma4 - auto mDim = mfmaEnc.getMDim(); - auto nDim = mfmaEnc.getNDim(); - auto nonKDim = dotOpEnc.getOpIdx() == 0 ? mDim : nDim; if (4 == nonKDim) maxPhase = 4; // if maxPhase * perPhase is larger than one block of warps, From 1fdf1ac0980a371c4413f1982a6cc639f7355109 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Fri, 22 Mar 2024 16:28:06 +0000 Subject: [PATCH 4/8] add small attention script --- python/06-attention-decode.py | 853 ++++++++++++++++++++++++++++++++++ 1 file changed, 853 insertions(+) create mode 100644 python/06-attention-decode.py diff --git a/python/06-attention-decode.py b/python/06-attention-decode.py new file mode 100644 index 000000000000..04d985405a58 --- /dev/null +++ b/python/06-attention-decode.py @@ -0,0 +1,853 @@ +from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple +import pytest +import torch +import sys + +import triton +import triton.language as tl + +def _strides(x: torch.Tensor, *stride_names: str): + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + + +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Seq_len, + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qk, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kk, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vk, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + Z, + N_CTX_Q, + N_CTX_K, + BLOCK_N_PER_SPLIT, + H: tl.constexpr, + G: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_SEQ_LEN: tl.constexpr, + PACKED_PER_VAL: tl.constexpr = 1, + N_GROUPS: tl.constexpr = 1, +): + """This kernel can accept non-quantized or int4-quantized keys/values. + PACKED_PER_VAL determines the quantization type: + - PACKED_PER_VAL == 1 means no quantization + - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32) + For the quantized case K/V should be int32 tensors. + Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8. + Quantization coefficients are stored at the beginning of the row along the last dimension of K/V + So K[B, H, M, :] has a form + [ quant_coef0, quant_coef1, ...| + group0_quant_value0, group0_quant_value1,... | + group1_quant_value0, group1_quant_value1,...] + where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset. + + """ + tl.static_assert( + (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32)) + or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)), + f"Only 4-bit quantization is supported, K/V should have dtype int32 in " + f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}", + ) + tl.static_assert( + (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8), + "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.", + ) + + QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1 + PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS + D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS + + start_m = tl.program_id(0) + off_zhg = tl.program_id(1) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + splitk_idx = tl.program_id(2) + + lo = splitk_idx * BLOCK_N_PER_SPLIT + if USE_SEQ_LEN: + kv_len = tl.load(Seq_len + off_z) + else: + kv_len = N_CTX_K + hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) + + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + + k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg + # Additional shift by 1 along the last dimension in the quantized case, since + # the first element along that dim contains packed quantization coefficients. + K_block_ptr = tl.make_block_ptr( + base=k_base + stride_kk * QUANTIZED * N_GROUPS, + shape=(PACKED_D_PER_GROUP, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(PACKED_D_PER_GROUP, BLOCK_N), + order=(0, 1), + ) + v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg + V_block_ptr = tl.make_block_ptr( + base=v_base + stride_vk * QUANTIZED * N_GROUPS, + shape=(hi, PACKED_D_PER_GROUP), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, PACKED_D_PER_GROUP), + order=(1, 0), + ) + + if QUANTIZED: + # Pointers to quantization coefficients. Even those they are 1D, + # we have to use block pointers, since usual pointers + # don't support boundary checks + K_scale_shift_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(1, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(1, BLOCK_N), + order=(0, 1), + ) + V_scale_shift_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(hi, 1), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, 1), + order=(1, 0), + ) + else: + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + acc = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821 + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + # q: "VAR_ARGS_ARRAY" # noqa: F821 + # for i in range(elem_num): # noqa: F821 + q = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0,) + ) + q = (q * qk_scale).to(tl.float16) + + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + # k: "VAR_ARGS_ARRAY" # noqa: F821 + # v: "VAR_ARGS_ARRAY" # noqa: F821 + # for i in range(len(acc)): # noqa: F821 + k, v = load_dequantize_k_v_group( # noqa: F821 + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N, + PACKED_PER_VAL, + PACKED_D_PER_GROUP, + Q.dtype.element_ty, + 0, + ) + # k.append(k_tmp) + # v.append(v_tmp) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # for i in range(elem_num): # noqa: F821 + qk += tl.dot(q, k) # noqa: F821 + #qk += tl.dot(q, k) # noqa: F821 + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? + if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + # for i in range(elem_num): # noqa: F821 + acc *= alpha[:, None] # noqa: F821 + #acc += tl.dot(p, v) # noqa: F821 + acc += tl.dot(p, v) # noqa: F821 + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + if PACKED_PER_VAL > 1: + K_scale_shift_block_ptr = tl.advance( + K_scale_shift_block_ptr, (0, BLOCK_N) + ) + V_scale_shift_block_ptr = tl.advance( + V_scale_shift_block_ptr, (BLOCK_N, 0) + ) + + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + # for i in range(elem_num): # noqa: F821 + tl.store( + tl.advance(O_block_ptr, (0, 0)), + acc, # noqa: F821 + boundary_check=(0,), + ) + # Write metadata for split-K reduction + Metadata_ptr = ( + Metadata + + off_zhg * stride_mzhg + + splitk_idx * stride_ms + + start_m * BLOCK_M + + tl.arange(0, BLOCK_M) + ) + tl.store(Metadata_ptr, m_i) + tl.store(Metadata_ptr + stride_m2, l_i) + +@triton.jit +def load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + PACKED_D_PER_GROUP: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, +): + #Load K/V for a given block. In case of int4-quantized K/V, + # dequantize them after loading. If quantization is group-wise, + # use group_id to advance the pointers to the current group. + + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()) + + if PACKED_PER_VAL > 1: + # K/V are quantized, load quantization coefficients and dequantize + + K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0)) + V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id)) + + k_scale_shift = tl.load( + K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else () + ) + v_scale_shift = tl.load( + V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else () + ) + + k_scale, k_shift = cast_uint32_to_half2(k_scale_shift) + v_scale, v_shift = cast_uint32_to_half2(v_scale_shift) + v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype) + k_t = dequantize( + tl.trans(k), + tl.trans(k_scale), + tl.trans(k_shift), + PACKED_PER_VAL, + ).to(dtype) + k = tl.trans(k_t) + return k, v + +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + #and x_ of type int32, PACKED_PER_VAL is 8. + + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = ( + x_[:, None, :] >> offsets[None, :, None] + ) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view( + quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL) + ) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + +@triton.jit +def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, M, K] + LSE, # [B, H, M] + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, + stride_lse_zhg, + stride_lse_m, + M_ceil:tl.constexpr, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + split_k:tl.constexpr, + splitK_pow2:tl.constexpr, + use_mask:tl.constexpr, +): + off_zhg = tl.program_id(0) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + off_m = tl.program_id(1) + off_k = tl.program_id(2) + + # read chunk + spk_idx = tl.arange(0, splitK_pow2) + kidx = tl.arange(0, BLOCK_SIZE) + + Metadata_ptr = ( + Metadata + + stride_mzhg * off_zhg + + spk_idx * stride_ms + + off_m * stride_mm + ) + + o_ptr = ( + Out_splitK + + off_zhg * stride_osk_zhg + + stride_osk_m * off_m + + off_k * BLOCK_SIZE + + stride_osk_s * spk_idx[:, None] + + kidx[None, :] * stride_osk_k + ) + + # read max values of each splitK + if use_mask: + spk_mask = spk_idx < split_k + l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf")) + l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0) + acc = tl.load(o_ptr, mask=spk_mask[:,None], other=0.0) + else: + l_m = tl.load(Metadata_ptr) + l_sum = tl.load(Metadata_ptr + stride_m2) + acc = tl.load(o_ptr) + + g_m = tl.max(l_m, axis=0) + alpha = tl.math.exp2(l_m - g_m) + + # read sum + l_sum *= alpha + g_sum = tl.sum(l_sum, axis=0) + + alpha = tl.math.exp2(l_m - g_m) + acc = acc * alpha[:, None] + acc_out = tl.sum(acc, axis=0) / g_sum + Out_ptr = ( + Out + + stride_oz * off_z + + stride_oh * off_h + + stride_og * off_g + + stride_om * off_m + + off_k * BLOCK_SIZE + + tl.arange(0, BLOCK_SIZE) + ) + tl.store(Out_ptr, acc_out) + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504) + + +def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + # Scale and shift are such that quantization linearly maps + # int4 values range [0..15] to input values range min(k)..max(k) + # individually for every row + k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups) + max_vals = torch.max(k, dim=-1, keepdim=True).values + min_vals = torch.min(k, dim=-1, keepdim=True).values + scale_k: torch.Tensor = (max_vals - min_vals) / 15 + + shift_k = torch.min(k, dim=-1, keepdim=True).values + scale_k = scale_k.to(torch.float16) + shift_k = shift_k.to(torch.float16) + + in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5 + in_bytes = in_bytes.to(torch.uint8) + in_int4 = in_bytes & 0xF + in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) + scale_shift = torch.concat( + [scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1 + ) + k_quant = torch.concat( + [ + scale_shift.flatten(start_dim=-2), + in_int4_packed.flatten(start_dim=-2), + ], + dim=-1, + ).view(torch.int16) + return k_quant + + +def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + k_i16 = quant_k.view(torch.int16) + k_ui8 = k_i16.view(torch.uint8) + + ss_size = num_groups * 4 + scale_shift_ui8 = k_ui8[...,0:ss_size] + scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4) + scale = scale_shift_ui8[...,0:2].view(torch.float16) + shift = scale_shift_ui8[...,2:4].view(torch.float16) + + kv_ui8 = k_ui8[...,ss_size:] + k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1) + k1_i4 = k_ui8 & 0xF + k2_i4 = (k_ui8 & 0xF0) >> 4 + k_shape = k1_i4.shape + k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + + out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device) + out[...,::2] = k1_f16 + out[...,1::2] = k2_f16 + out = out.reshape(*k_shape[:-2], -1) + + return out + + +def get_split_k(B: int, G: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + while B * H * G * split_k >= 1024: + split_k = split_k // 2 + split_k = min(split_k, 512) + split_k = max(split_k, 1) + return split_k + + +class _attention(torch.autograd.Function): + + OPERATOR = _fwd_kernel_splitK + SUPPORTED_DEVICES = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES = { + torch.half, + torch.bfloat16, + } # Those are dtypes of Q. In the quantized case K/V has dtype int32 + SUPPORTED_MAX_K = 128 + # SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { + # type(None), + # BlockDiagonalCausalWithOffsetPaddedKeysMask, + # } + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + NAME = "triton_splitKF" + + @staticmethod + def forward(cls, q, k, v, scale_float): + + cls.SPLIT_K: Optional[int] = None + cls.BLOCK_M = 16 + cls.BLOCK_N = 64 + + cls.NUM_GROUPS = 1 # Default quantization is row-wise + + # attn_bias = inp.attn_bias + seq_len = None + + # Transpose in the case of MQA/GQA + mqa_swap_seqlen_head = False + if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0: + mqa_swap_seqlen_head = True + assert q.shape[1] == 1 + q = q.transpose(1, 3) + k = k[:, :, :, :1] + v = v[:, :, :, :1] + + if k.dtype == torch.int32: + # Quantized K/V + PACKED_PER_VAL = 8 + Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8 + else: + Lk = k.shape[-1] + PACKED_PER_VAL = 1 + + B, Mk, G, H, Kkv = k.shape + B, M, G, H, Kq = q.shape + assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}" + # print(f"B = {B}, M = {M}, G = {G}, H = {H}, Kkv = {Kkv}, Kq = {Kq}") + + BLOCK_M = cls.BLOCK_M + BLOCK_N = cls.BLOCK_N + if cls.SPLIT_K is not None: + split_k = cls.SPLIT_K + else: + # Use heuristics + split_k = get_split_k(B, G, H, Mk) + + M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M + o_splitk = torch.empty( + [B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device + ) + metadata = torch.empty( + [B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device + ) + lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32) + grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k) + + split_size = (Mk + split_k - 1) // split_k + use_seq_len = seq_len is not None + + #print(f"B = {B}, G = {G}, H = {H}, split_k = {split_k}, M_ceil = {M_ceil}, Kq = {Kq}, num_of_wgs = {G * G * H * split_k}") + #print(grid) + #print(_strides(k, "kz", "kn", "kg", "kh", "kk")) + #print("BLOCK_N", BLOCK_N) + + pgm = _fwd_kernel_splitK[grid]( + Q=q, + K=k, + V=v, + sm_scale=scale_float, + Out_splitK=o_splitk, + Metadata=metadata, + Seq_len=seq_len, + **_strides(q, "qz", "qm", "qg", "qh", "qk"), + **_strides(k, "kz", "kn", "kg", "kh", "kk"), + **_strides(v, "vz", "vn", "vg", "vh", "vk"), + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + Z=B, + H=H, + G=G, + N_CTX_Q=M, + N_CTX_K=Mk, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=Lk, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len, + USE_SEQ_LEN=use_seq_len, + num_warps=4, + num_stages=1, + PACKED_PER_VAL=PACKED_PER_VAL, + N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1, + matrix_instr_nonkdim=464, + waves_per_eu=1 + ) + #print(f"kernel run B = {B}, G = {G}, H = {H}, split_k = {split_k}, M_ceil = {M_ceil}, Kq = {Kq}\n", pgm.asm["amdgcn"]) + + if mqa_swap_seqlen_head: + out = torch.empty( + (B, H, G, M, Kq), device=q.device, dtype=q.dtype + ).transpose(1, 3) + else: + out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype) + + # Merge together + splitK_pow2 = triton.next_power_of_2(split_k) + use_mask = splitK_pow2 > split_k + if B * G * H * M >= 512: + k_block_num = 1 + else: + k_block_num = 2 + assert out.shape[-1] % k_block_num == 0 + k_block_size = out.shape[-1] // k_block_num + grid = (B * G * H, M, k_block_num) + #print("reduce split", split_k, k_block_size, k_block_num) + _splitK_reduce[grid]( + o_splitk, + metadata, + out, + lse, + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(out, "oz", "om", "og", "oh", "ok"), + **_strides(lse, "lse_zhg", "lse_m"), + M_ceil=M_ceil, + BLOCK_SIZE=k_block_size, + G=G, + H=H, + # TODO: Tune num_warps + split_k=split_k, + splitK_pow2=splitK_pow2, + use_mask=use_mask, + num_warps=4 + ) + + lse = lse.reshape([B, G, H, M]) + if mqa_swap_seqlen_head: + # H/M dimensions have been swapped + out = out.transpose(1, 3) + lse = lse.transpose(2, 3) + if q.ndim == 4: + # BMGHK -> BMHK + assert G == 1 + out = out[:, :, 0] + lse = lse[:, 0] + if Mk == 0: + out.zero_() + if mqa_swap_seqlen_head: + out = out.reshape(B, -1, M * G, Kq).transpose(1, 2).contiguous() + else: + out = out.reshape(B, H * G, -1, Kq).contiguous() + + return out + +attention = _attention.apply + +def get_input_shapes(): +# cases = [ +# (max(1, 2 ** (16 - i)), 1, 2**i, 16, 1, 128) +# for i in range(13, 14) +# ] +# return cases + cases = [ + (max(1, 2 ** (16 - i)), 1, 2**i, 16, 1, 128) + for i in range(8, 14) + ] + [ + (max(1, 2 ** (16 - i)), 1, 2**i, 16, 2, 128) + for i in range(8, 14) + ] + cases += [(4, 1, 8192, 16, 1, 128)] + + return cases + + +@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', + get_input_shapes()) +def test_op_fwd(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): + torch.manual_seed(20) + q = ( + torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + v = ( + torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda") + .normal_(mean=0., std=0.5) + .requires_grad_() + ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + scale = 1 / K**0.5 + tri_out = attention(q, k, v, scale) + + q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) + k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) + ref_out = attn @ v + + # compare + torch.testing.assert_close(ref_out, tri_out, atol=1e-3, rtol=0.01) + + +@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', + get_input_shapes()) +def test_op_fwd_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): + torch.manual_seed(2) + q = ( + torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, device="cuda") + .normal_(mean=1.0, std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda") + .normal_(mean=1.0, std=0.5) + .requires_grad_() + ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + v = ( + torch.empty((B, Mkv, Hkv, 1, K), dtype=dtype, device="cuda") + .normal_(mean=1.0, std=0.5) + .requires_grad_() + ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) + + num_groups = 1 + quant_k = ( + quantize_kv_int4(k, num_groups=num_groups) + .contiguous() + .view(torch.int32) + ) + quant_v = ( + quantize_kv_int4(v, num_groups=num_groups) + .contiguous() + .view(torch.int32) + ) + scale = 1 / K**0.5 + tri_out = attention(q, quant_k, quant_v, scale) + + q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) + k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + v = v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) + ref_out = attn @ v + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2.1e-2, rtol=0) + + # since quantization introduces rounding error, use the + # dequantized kv as inputs to the ref implementation to reduce + # the tolerance to 1e-3 + dqk = dequantize_kv_fp16(quant_k, num_groups=num_groups) + dqv = dequantize_kv_fp16(quant_v, num_groups=num_groups) + dqk = dqk.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + dqv = dqv.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + dq_attn = (q @ dqk.transpose(-1, -2) * scale).softmax(-1) + dq_ref_out = dq_attn @ dqv + torch.testing.assert_close(dq_ref_out, tri_out, atol=1e-3, rtol=0) + + +def test_quantization(): + a = torch.randn((2, 4, 32), dtype=torch.float16, device='cuda') + qa = quantize_kv_int4(a, num_groups=4) + dqa = dequantize_kv_fp16(qa, num_groups=4) + torch.testing.assert_close(a, dqa, atol=1.5e-1, rtol=1e-1) + + +try: + from flash_attn.flash_attn_interface import \ + flash_attn_qkvpacked_func as flash_attn_func + FLASH_VER = 2 +except BaseException: + try: + from flash_attn.flash_attn_interface import flash_attn_func + FLASH_VER = 1 + except BaseException: + FLASH_VER = None +HAS_FLASH = FLASH_VER is not None + +configs = [] +for mode in ['fwd']: + # for D_HEAD in [128]: + for causal in [False]: + configs.append(triton.testing.Benchmark( + x_names=['B', 'Mq','Mkv', 'Hq', 'Hkv', 'K'], + x_vals=get_input_shapes(), + line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-d{128}-{mode}-causal={causal}', + args={ + # 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + 'causal': causal}) + ) + + +@triton.testing.perf_report(configs) +def bench_flash_attention(B, Mq, Mkv, Hq, Hkv, K, causal, mode, provider, dtype=torch.float16, device="cuda"): + assert mode in ['fwd', 'bwd'] + warmup = 100 + rep = 400 + ms = 0 + if provider == "triton": + q = torch.randn( + [B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=False + ) + k = torch.randn( + [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=False + ).expand(-1, -1, -1, Hq // Hkv, -1) + v = torch.randn( + [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=False + ).expand(-1, -1, -1, Hq // Hkv, -1) + + sm_scale = 1.3 + fn = lambda: attention(q, k, v, sm_scale) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + + flops_per_matmul = 2 * B * Hq * (Mq * K * Mkv + Mq * Mkv * K) + total_flops = 2 * flops_per_matmul + totalBytes = ((B * Mkv * Hkv * K * 2) + (B * Mq * Hq * K) + (B * Mq * Hq * K)) * 2 + + # return totalBytes / ms * 1e-9 + return ms * 1000 + + +def main(): + bench_flash_attention.run(save_path='.', print_data=True) + +if __name__ == '__main__': + sys.exit(main()) From c7cc839f7bfa9aeb2966c975d35f32088c91b041 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Mon, 25 Mar 2024 21:15:58 +0000 Subject: [PATCH 5/8] adjust tuning script for mfma464 tuning command: `python3 tune_gemm.py --compare -m 16 -n 64 -k 64 -dtype_a "fp16" -dtype_b "fp16" -dtype_c "fp32" --num_threads=16 --keep` --- scripts/amd/gemm/one_config.py | 16 ++++++++++++++-- scripts/amd/gemm/tune_gemm.py | 14 +++++++------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/scripts/amd/gemm/one_config.py b/scripts/amd/gemm/one_config.py index 1aa228f69a3e..b746b8268956 100644 --- a/scripts/amd/gemm/one_config.py +++ b/scripts/amd/gemm/one_config.py @@ -24,8 +24,10 @@ def parse_args(): parser.add_argument("--num_warps", type=int, default=0) parser.add_argument("--num_stages", type=int, default=0) parser.add_argument("--waves_per_eu", type=int, default=0) + parser.add_argument("--kpack", type=int, default=0) + parser.add_argument("--mfma", type=int, default=0) - parser.add_argument("--config_str", type=str, default="", help="can take from gemm_tune.py script output, looks like M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0") + parser.add_argument("--config_str", type=str, default="", help="can take from gemm_tune.py script output, looks like M16_N8_K128_BM64_BN64_BK64_GM1_SK2_nW2_nS0_EU0_kP1_mfma16") args = parser.parse_args() return args @@ -44,6 +46,8 @@ def parse_config(cfg_str): "nW": "num_warps", "nS": "num_stages", "EU": "waves_per_eu", + "kP": "kpack", + "mfma": "matrix_instr_nonkdim", } config = {} for val in values: @@ -70,8 +74,16 @@ def main(): "num_warps": args.num_warps, "num_stages": args.num_stages, "waves_per_eu": args.waves_per_eu, + "kpack": args.kpack, + "matrix_instr_nonkdim": args.mfma, } - tune_gemm.test_correctness(config["M"], config["N"], config["K"], config, verbose=True) + col_a = False + col_b = False + dtype_a = "fp16" + dtype_b = "fp16" + dtype_c = "fp32" + init_type = "randn" + tune_gemm.test_correctness(config["M"], config["N"], config["K"], col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, verbose=True) if __name__ == "__main__": diff --git a/scripts/amd/gemm/tune_gemm.py b/scripts/amd/gemm/tune_gemm.py index 51ecdd285d45..8295723d526a 100644 --- a/scripts/amd/gemm/tune_gemm.py +++ b/scripts/amd/gemm/tune_gemm.py @@ -30,7 +30,7 @@ def get_full_tuning_space(): # other values in the future num_stage_range = [0] waves_per_eu_range = [0] - matrix_instr_nonkdim_range = [16, 32] + matrix_instr_nonkdim_range = [464, 16] kpack_range = [1, 2] for block_m in block_mn_range: @@ -68,9 +68,7 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): num_warps = config.get("num_warps") matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") kpack = config.get("kpack") - if matrix_instr_nonkdim > mfma: - continue - if mfma == 4 and BLOCK_SIZE_K < 64: + if matrix_instr_nonkdim == 464 and BLOCK_SIZE_K < 64: continue # some layouts could not work properly in case # number elemens per thread is less 1 @@ -78,11 +76,13 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): continue SPLIT_K = config.get("SPLIT_K") GROUP_M = config.get("GROUP_SIZE_M") - if BLOCK_SIZE_M < matrix_instr_nonkdim or BLOCK_SIZE_N < matrix_instr_nonkdim: + matrix_instr_m = 4 if matrix_instr_nonkdim > 32 else matrix_instr_nonkdim + matrix_instr_n = 64 if matrix_instr_nonkdim > 32 else matrix_instr_nonkdim + if BLOCK_SIZE_M < matrix_instr_m or BLOCK_SIZE_N < matrix_instr_n: continue - if M <= matrix_instr_nonkdim and BLOCK_SIZE_M != matrix_instr_nonkdim: + if M <= matrix_instr_m and BLOCK_SIZE_M != matrix_instr_m: continue - if N <= matrix_instr_nonkdim and BLOCK_SIZE_N != matrix_instr_nonkdim: + if N <= matrix_instr_n and BLOCK_SIZE_N != matrix_instr_n: continue # Skip BLOCK_SIZE that is too large compare to M/N # unless BLOCK_SIZE is already small enough From cd393f025a66b306c80329336e6561446c25b31d Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Wed, 27 Mar 2024 14:07:31 +0000 Subject: [PATCH 6/8] simplify fast path computations in shared to mfma conversion --- .../SharedToDotOperandMFMA.cpp | 50 ++++++++++--------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index 77d5f6ca5160..d803160fa181 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -326,33 +326,35 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, Value waveOffset = mul(waveId, i32_val(iNonKDim)); Value colOffset = urem(laneId, _nonKDim); + // halfOffset is an offset related to wrapping of wave in the tile. + // for example, mfma 32 case (mapping of tensor elements to lane ids in + // wave): + // + // 0 1 2 3 ... 31 + // 0 1 2 3 ... 31 + // 0 1 2 3 ... 31 + // 0 1 2 3 ... 31 + // 32 33 34 35 ... 63 <- at this point wave is wrapping + // 32 33 34 35 ... 63 + // 32 33 34 35 ... 63 + // 32 33 34 35 ... 63 + Value halfWaveOffset; + if (iNonKDim == 64) + halfWaveOffset = i32_val(0); + else + halfWaveOffset = + mul(udiv(laneId, _nonKDim), i32_val(numOfElems * lineSize)); + + // sum of offsets dependent from lane id and warp Id across non-k dim + Value baseThreadOffset = add(add(halfWaveOffset, colOffset), waveOffset); + for (int block = 0; block < numN; ++block) { - Value blockOffset = i32_val(block * iNonKDim * warpsPerBlock); + int blockOffset = block * iNonKDim * warpsPerBlock; for (int tile = 0; tile < numK; ++tile) { - Value tileOffset = i32_val(tile * iKDim * lineSize); + int tileOffset = tile * iKDim * lineSize; for (int elem = 0; elem < numOfElems; ++elem) { - // halfOffset is an offset related to wrapping of wave in the tile. - // for example, mfma 32 case (mapping of tensor elements to lane ids in - // wave): - // - // 0 1 2 3 ... 31 - // 0 1 2 3 ... 31 - // 0 1 2 3 ... 31 - // 0 1 2 3 ... 31 - // 32 33 34 35 ... 63 <- at this point wave is wrapping - // 32 33 34 35 ... 63 - // 32 33 34 35 ... 63 - // 32 33 34 35 ... 63 - Value halfOffset; - if (iNonKDim == 64) - halfOffset = i32_val(0); - else - halfOffset = - mul(udiv(laneId, _nonKDim), i32_val(numOfElems * lineSize)); - Value rowOffset = add(i32_val(elem * lineSize), halfOffset); - Value elemOffset = add(rowOffset, colOffset); - Value offset = - add(add(add(waveOffset, blockOffset), tileOffset), elemOffset); + int rowOffset = elem * lineSize; + Value offset = add(baseThreadOffset, i32_val(blockOffset + tileOffset + rowOffset)); offsets[numK * numOfElems * block + numOfElems * tile + elem] = offset; } } From aa959e7546bb2e08229e3e49d180a32030120f98 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Wed, 27 Mar 2024 20:49:17 +0000 Subject: [PATCH 7/8] widen vec size in swizzling patern for large operand --- include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 82f1eb8b9283..f5a3faed871c 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -155,7 +155,13 @@ compared to 1*64 when the hasLeadingOffset is false. // Small operand: // - shape 4(non-k)x64(k) for 16 bit dtypes // - shape 4(non-k)x16(k) for 32 bit dtypes - const int vecSize = bankBitWidth / typeWidthInBit; + int vecSize = bankBitWidth / typeWidthInBit; + if (nonKDim == 64 && vecSize < dotOpEnc.getKWidth()) { + // This heuristic introduces bank conflicts, + // but saves address computation overhead and + // reduces number of ds_read/ds_write instructions + vecSize = 4; + } const int perPhase = std::max(1, numBanks / innerDimLength); const int maxPhase = std::min(numBanks, nonKDim) / perPhase; return get(context, vecSize, perPhase, maxPhase, order, CTALayout); From f005ec02947cf92665a23c15de02c921c8d7d34e Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Thu, 28 Mar 2024 21:07:39 +0000 Subject: [PATCH 8/8] [MFMA] Implement MFMA 4x64 v3 This implements version 3 mfma 4x64 layout with swizzled operands --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 12 +++- .../SharedToDotOperandMFMA.cpp | 24 ++++++- .../TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp | 72 ++++++++++++++----- python/06-attention-decode.py | 15 ++-- 4 files changed, 95 insertions(+), 28 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index f5a3faed871c..50bdbaff8cea 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -160,10 +160,16 @@ compared to 1*64 when the hasLeadingOffset is false. // This heuristic introduces bank conflicts, // but saves address computation overhead and // reduces number of ds_read/ds_write instructions - vecSize = 4; + + // can not swizzle in vectors of 4 for fp32, + // because mfma 4x64 and 64x4 shift loads by 1*(laneId/4)in k dimension, breaking borders of one load + vecSize = typeWidthInBit == 32 ? 1 : 4; } - const int perPhase = std::max(1, numBanks / innerDimLength); - const int maxPhase = std::min(numBanks, nonKDim) / perPhase; + const int bankRowInBits = numBanks * bankBitWidth; + const int innerDimInBits = innerDimLength * typeWidthInBit; + const int perPhase = std::max(1, bankRowInBits / innerDimInBits); + const int phaseSanityLimit = std::min(numBanks, innerDimLength / vecSize); + const int maxPhase = std::min(phaseSanityLimit, nonKDim) / perPhase; return get(context, vecSize, perPhase, maxPhase, order, CTALayout); } int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index d803160fa181..50418173f09c 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -168,7 +168,16 @@ llvm::SmallVector> computeTensorElemMappingInBlock( for (int loadId = 0; loadId < loadsPerThread; ++loadId) { Value elemVOffset = _0; - Value elemHOffset = i32_val(loadId * loadVecSize); + Value elemHOffset; + if (iNonKDim == 64) { + Value groupId = udiv(laneId, i32_val(4)); + Value groupShift = mul(groupId, i32_val(iKDim / 16)); + Value loadShift = add(groupShift, i32_val(loadId * loadVecSize)); + Value wrappedLoadShift = urem(loadShift, i32_val(iKDim)); + elemHOffset = wrappedLoadShift; + } else { + elemHOffset = i32_val(loadId * loadVecSize); + } Value sliceVOffset = add(add(add(tileVOffset, laneVOffset), elemVOffset), waveVOffset); @@ -353,8 +362,17 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, for (int tile = 0; tile < numK; ++tile) { int tileOffset = tile * iKDim * lineSize; for (int elem = 0; elem < numOfElems; ++elem) { - int rowOffset = elem * lineSize; - Value offset = add(baseThreadOffset, i32_val(blockOffset + tileOffset + rowOffset)); + Value rowOffset; + if (iNonKDim == 64) { + Value groupId = udiv(laneId, i32_val(4)); + Value groupShift = mul(groupId, i32_val(iKDim / 16)); + Value elemShift = add(groupShift, i32_val(elem)); + Value wrappedElemShift = urem(elemShift, i32_val(iKDim)); + rowOffset = mul(wrappedElemShift, i32_val(lineSize)); + } else { + rowOffset = i32_val(elem * lineSize); + } + Value offset = add(add(baseThreadOffset, rowOffset), i32_val(blockOffset + tileOffset)); offsets[numK * numOfElems * block + numOfElems * tile + elem] = offset; } } diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp index 94809cbdf360..65f09b0e1bac 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -132,6 +132,50 @@ struct DotOpMFMAConversionHelper { return broadcasted; } + Value rollShiftOperand(Value val) const { + constexpr int waveSize = 64; + constexpr int numGroups = 16; + const int groupSize = waveSize / numGroups; + + Value lane = getThreadId(); + Value targetLane = urem(add(lane, i32_val(groupSize)), i32_val(waveSize)); + + // Multiply by 4, because permute requires offset in bytes + Value permuteAddr = mul(targetLane, i32_val(4)); + + Type valType = val.getType(); + Value rolled; + if (valType.isInteger(32)) + rolled = rewriter.create(loc, val.getType(), + permuteAddr, val); + if (valType.isF32()) { + val = bitcast(val, i32_ty); + rolled = rewriter.create(loc, val.getType(), + permuteAddr, val); + rolled = bitcast(rolled, f32_ty); + } + if (valType.isa()) { + auto vecTy = valType.dyn_cast(); + auto vecBitSize = vecTy.getElementType().getIntOrFloatBitWidth() * + vecTy.getNumElements(); + const int int32VecSize = vecBitSize / 32; + + Type int32VecTy = vec_ty(i32_ty, int32VecSize); + Value int32Val = bitcast(val, int32VecTy); + Value int32Rolled = undef(int32VecTy); + for (int i = 0; i < int32VecSize; ++i) { + Value int32Chunk = extract_element(i32_ty, int32Val, i32_val(i)); + Value rolledChunk = rewriter.create( + loc, i32_ty, permuteAddr, int32Chunk); + int32Rolled = insert_element(int32VecTy, int32Rolled, + rolledChunk, i32_val(i)); + } + rolled = bitcast(int32Rolled, valType); + } + assert(rolled); + return rolled; + } + Value generateMFMATile(StringRef mfmaInsnName, SmallVector valA, SmallVector valB, Value valC, int mDim, int nDim, bool transpose) const { @@ -148,25 +192,21 @@ struct DotOpMFMAConversionHelper { constexpr int numRepeats = 16; acc = valC; for (int kRep = 0; kRep < numRepeats; kRep++) { - if (mDim == 4 && !transpose) { - assert(valA.size() == 1 && valB.size() == 16); - acc = generateMFMAOp(mfmaInsnName, valA[0], valB[kRep], acc, - broadcastCtrl, kRep); - } - if (mDim == 4 && transpose) { + if (mDim == 4) { assert(valA.size() == 1 && valB.size() == 16); - Value broadcastValA = broadcastGroup(valA[0], kRep, numRepeats); - acc = generateMFMAOp(mfmaInsnName, valB[kRep], broadcastValA, acc); - } - if (nDim == 4 && !transpose) { - assert(valA.size() == 16 && valB.size() == 1); - Value broadcastValB = broadcastGroup(valB[0], kRep, numRepeats); - acc = generateMFMAOp(mfmaInsnName, valA[kRep], broadcastValB, acc); + if (!transpose) + acc = generateMFMAOp(mfmaInsnName, valA[0], valB[kRep], acc); + else + acc = generateMFMAOp(mfmaInsnName, valB[kRep], valA[0], acc); + valA[0] = rollShiftOperand(valA[0]); } - if (nDim == 4 && transpose) { + if (nDim == 4) { assert(valA.size() == 16 && valB.size() == 1); - acc = generateMFMAOp(mfmaInsnName, valB[0], valA[kRep], acc, - broadcastCtrl, kRep); + if (!transpose) + acc = generateMFMAOp(mfmaInsnName, valA[kRep], valB[0], acc); + else + acc = generateMFMAOp(mfmaInsnName, valB[0], valA[kRep], acc); + valB[0] = rollShiftOperand(valB[0]); } } } diff --git a/python/06-attention-decode.py b/python/06-attention-decode.py index 04d985405a58..21aa031cd1ae 100644 --- a/python/06-attention-decode.py +++ b/python/06-attention-decode.py @@ -522,7 +522,7 @@ class _attention(torch.autograd.Function): NAME = "triton_splitKF" @staticmethod - def forward(cls, q, k, v, scale_float): + def forward(cls, q, k, v, scale_float, matrix_instr_nonkdim=0): cls.SPLIT_K: Optional[int] = None cls.BLOCK_M = 16 @@ -609,7 +609,7 @@ def forward(cls, q, k, v, scale_float): num_stages=1, PACKED_PER_VAL=PACKED_PER_VAL, N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1, - matrix_instr_nonkdim=464, + matrix_instr_nonkdim=matrix_instr_nonkdim, waves_per_eu=1 ) #print(f"kernel run B = {B}, G = {G}, H = {H}, split_k = {split_k}, M_ceil = {M_ceil}, Kq = {Kq}\n", pgm.asm["amdgcn"]) @@ -691,9 +691,12 @@ def get_input_shapes(): return cases -@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K', - get_input_shapes()) -def test_op_fwd(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): +@pytest.mark.parametrize('B, Mq, Mkv, Hq, Hkv, K, matrix_instr_nonkdim', + [(*shape, matrix_instr_nonkdim) + for shape in get_input_shapes() + for matrix_instr_nonkdim in [16, 464] + ]) +def test_op_fwd(B, Mq, Mkv, Hq, Hkv, K, matrix_instr_nonkdim, dtype=torch.float16): torch.manual_seed(20) q = ( torch.empty((B, Mq, Hkv, (Hq + Hkv - 1) // Hkv, K), dtype=dtype, device="cuda") @@ -711,7 +714,7 @@ def test_op_fwd(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16): .requires_grad_() ).expand(-1, -1, -1, (Hq + Hkv - 1) // Hkv, -1) scale = 1 / K**0.5 - tri_out = attention(q, k, v, scale) + tri_out = attention(q, k, v, scale, matrix_instr_nonkdim) q = q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) k = k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)