diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp
index a83972ae1948..35c45fd2c069 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp
@@ -129,76 +129,6 @@ warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   return warpsPerTile(dotOp, shape, numWarps, {16, 16});
 }
 
-/**
- * @brief Convert layout and cast element type of a given tensor
- *
- * If old element type is different from new element type, this function
- * creates two new operations:
- * 1. %converted_value = layout_convert %value, newEncoding
- * 2. %casted_value = cast(fext, ftrunc, etc.) %value, newElemType
- *
- * If old element type is same as new element type, this function creates only
- * one operation: %converted_value = layout_convert %value, newEncoding
- *
- * @param rewriter
- * @param value original tensor value, which we need to convert and cast
- * @param newEncoding new encoding for the tenosr
- * @param newElemType new element type for the tensor
- * @return converted and optionaly casted tensor value
- */
-Value convertAndCastTensor(mlir::PatternRewriter &rewriter, Value value,
-                           ::mlir::Attribute newEncoding, Type newElemType) {
-  assert(newElemType.isIntOrFloat());
-
-  auto loc = value.getLoc();
-  auto oldType = value.getType().cast<RankedTensorType>();
-  auto oldElemType = oldType.getElementType();
-
-  assert(oldElemType.isIntOrFloat());
-  assert(oldElemType.isIntOrIndex() == newElemType.isIntOrIndex());
-
-  auto convertedType =
-      RankedTensorType::get(oldType.getShape(), oldElemType, newEncoding);
-
-  Value convertedTensor =
-      rewriter.create<ttg::ConvertLayoutOp>(loc, convertedType, value);
-
-  if (newElemType == oldElemType)
-    return convertedTensor;
-
-  Type castedType = convertedType.cloneWith(std::nullopt, newElemType);
-
-  Value castedTensor;
-
-  if (newElemType.isIntOrIndex()) {
-    unsigned oldWidth = oldElemType.getIntOrFloatBitWidth();
-    unsigned newWidth = newElemType.getIntOrFloatBitWidth();
-    if (oldWidth == newWidth)
-      castedTensor = rewriter.create<arith::BitcastOp>(loc, convertedType,
-                                                       convertedTensor);
-    else if (oldWidth > newWidth)
-      castedTensor = rewriter.create<arith::TruncIOp>(loc, castedType,
-                                                      convertedTensor);
-    else if (oldElemType.isSignedInteger())
-      castedTensor = rewriter.create<arith::ExtSIOp>(loc, castedType,
-                                                     convertedTensor);
-    else
-      castedTensor = rewriter.create<arith::ExtUIOp>(loc, castedType,
-                                                     convertedTensor);
-  } else {
-    if (oldElemType.isF16() && newElemType.isF32())
-      castedTensor = rewriter.create<arith::ExtFOp>(loc, castedType,
-                                                    convertedTensor);
-    else if (oldElemType.isF32() && newElemType.isF16())
-      castedTensor = rewriter.create<arith::TruncFOp>(loc, castedType,
-                                                      convertedTensor);
-    else
-      castedTensor =
-          rewriter.create<tt::FpToFpOp>(loc, castedType, convertedTensor);
-  }
-  return castedTensor;
-}
-
 class BlockedToMFMA : public mlir::RewritePattern {
   int mfmaVersion;
   int enforcedNonKDim;
@@ -310,17 +240,13 @@ class BlockedToMFMA : public mlir::RewritePattern {
           /*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
           /*instrShape*/ mDim, nDim, isTransposed);
-    // convert accumulator
-    Type mfmaAccType;
-    if (oldRetType.getElementType().isIntOrIndex())
-      mfmaAccType = rewriter.getIntegerType(32);
-    else
-      mfmaAccType = rewriter.getF32Type();
+    auto newRetType =
+        RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);
+
+    // convert accumulator
     auto oldAcc = dotOp.getOperand(2);
-    auto newAcc = convertAndCastTensor(rewriter, oldAcc, mfmaEnc, mfmaAccType);
-
-    // convert A/B operands
+    auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(oldAcc.getLoc(),
+                                                        newRetType, oldAcc);
     auto oldAOrder = oldAType.getEncoding()
                          .cast<ttg::DotOperandEncodingAttr>()
                          .getParent()
                          .cast<ttg::BlockedEncodingAttr>()
@@ -355,16 +281,12 @@ class BlockedToMFMA : public mlir::RewritePattern {
         ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth));
     a = rewriter.create<ttg::ConvertLayoutOp>(a.getLoc(), newAType, a);
     b = rewriter.create<ttg::ConvertLayoutOp>(b.getLoc(), newBType, b);
-    auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newAcc.getType(),
-                                             a, b, newAcc, dotOp.getAllowTF32(),
-                                             dotOp.getMaxNumImpreciseAcc());
-
-    Value dotOutput =
-        convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(),
-                             oldRetType.getElementType());
-
-    rewriter.replaceOp(op, dotOutput);
+    auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType, a, b,
+                                             newAcc, dotOp.getAllowTF32(),
+                                             dotOp.getMaxNumImpreciseAcc());
+    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
+                                                      newDot.getResult());
     return success();
   }
 };
 
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index ccea1b6859b3..059d727f4c1b 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1324,7 +1324,26 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
         ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, max_num_imprecise_acc),
                         ret_ty)
         return cast(ret, ret_scalar_ty, builder)
 
+    if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32,
+                                   ret_scalar_ty, builder.target) and ret_scalar_ty.primitive_bitwidth <= 32:
+        # max_num_imprecise_acc does not yet apply to hip
+        if is_hip():
+            max_num_imprecise_acc = 0
+        if max_num_imprecise_acc is None:
+            max_num_imprecise_acc = 2**30
+        if lhs.type.scalar.is_int():
+            ret_dot_scalar_ty = tl.int32
+            _0 = builder.create_splat(builder.get_int32(0), [M, N])
+        else:
+            ret_dot_scalar_ty = tl.float32
+            _0 = builder.create_splat(builder.get_fp32(0), [M, N])
+        ret_ty = tl.block_type(ret_dot_scalar_ty, [M, N])
+        ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, max_num_imprecise_acc),
+                        ret_ty)
+        return cast(ret, ret_scalar_ty, builder)
+
+    _0 = builder.create_splat(_0, [M, N])
     ret_ty = tl.block_type(ret_scalar_ty, [M, N])
     if acc is None:
         acc_handle = builder.create_splat(_0, [M, N])
@@ -1333,11 +1352,11 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
         assert acc.type == ret_ty
 
     # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
+    if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8()
+            and ret_scalar_ty.is_fp32()):
+        max_num_imprecise_acc = 0
     if max_num_imprecise_acc is None:
-        if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
-            max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default
-        else:
-            max_num_imprecise_acc = 0
+        max_num_imprecise_acc = 2**30
 
     return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc),
                      ret_ty)
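
For reference, below is a minimal sketch of a Triton kernel that exercises the dot path touched by this patch: with fp16 operands, tl.dot accumulates in fp32 (int8 operands get an int32 accumulator), and the result is cast back to the storage dtype before the store. The kernel name, tile shape, and pointer arithmetic are illustrative only and are not part of this change.

import triton
import triton.language as tl


@triton.jit
def dot_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    # One BLOCK x BLOCK fp16 tile of A and B, row-major, no masking.
    # BLOCK must be at least 16 so tl.dot can map to MFMA/MMA instructions.
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    # fp16 x fp16 dot accumulates in fp32; this is the accumulator whose
    # placement the semantic.py / BlockedToMFMA changes above decide.
    acc = tl.dot(a, b)
    # Cast back to the storage dtype before writing out.
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], acc.to(tl.float16))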