Skip to content

Commit

Permalink
Revert "[MFMA] Move operand casts to AccelerateMatMul pass (#477)" (#500)
Browse files Browse the repository at this point in the history

This reverts commit a9c0bdb.
  • Loading branch information
zhanglx13 authored Feb 6, 2024
1 parent a9c0bdb commit 1181155
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 92 deletions.
98 changes: 10 additions & 88 deletions lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,76 +129,6 @@ warpsPerTileWMMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
return warpsPerTile(dotOp, shape, numWarps, {16, 16});
}

/**
 * @brief Convert layout and cast element type of a given tensor
 *
 * If old element type is different from new element type, this function
 * creates two new operations:
 * 1. %converted_value = layout_convert %value, newEncoding
 * 2. %casted_value = cast(fext, ftrunc, etc.) %value, newElemType
 *
 * If old element type is same as new element type, this function creates only
 * one operation: %converted_value = layout_convert %value, newEncoding
 *
 * @param rewriter pattern rewriter used to create the new operations
 * @param value original tensor value, which we need to convert and cast
 * @param newEncoding new encoding for the tensor
 * @param newElemType new element type for the tensor
 * @return converted and optionally casted tensor value
 */
Value convertAndCastTensor(mlir::PatternRewriter &rewriter, Value value,
                           ::mlir::Attribute newEncoding, Type newElemType) {
  assert(newElemType.isIntOrFloat());

  auto loc = value.getLoc();
  auto oldType = value.getType().cast<RankedTensorType>();
  auto oldElemType = oldType.getElementType();

  assert(oldElemType.isIntOrFloat());
  // Only int->int or float->float casts are supported here.
  assert(oldElemType.isIntOrIndex() == newElemType.isIntOrIndex());

  // Same shape and element type as the input, but with the new encoding.
  auto convertedType =
      RankedTensorType::get(oldType.getShape(), oldElemType, newEncoding);

  Value convertedTensor =
      rewriter.create<ttg::ConvertLayoutOp>(loc, convertedType, value);

  // No element-type change requested: the layout conversion is all we need.
  if (newElemType == oldElemType)
    return convertedTensor;

  // Same encoding/shape as convertedType, but with the new element type.
  Type castedType = convertedType.cloneWith(std::nullopt, newElemType);

  Value castedTensor;

  if (newElemType.isIntOrIndex()) {
    unsigned oldWidth = oldElemType.getIntOrFloatBitWidth();
    unsigned newWidth = newElemType.getIntOrFloatBitWidth();
    if (oldWidth == newWidth)
      // BUGFIX: the result type must carry the NEW element type. The
      // original code passed convertedType (old element type) here, so the
      // bitcast result silently kept the old type, breaking this function's
      // contract for same-width integer casts.
      castedTensor = rewriter.create<mlir::arith::BitcastOp>(loc, castedType,
                                                             convertedTensor);
    else if (oldWidth > newWidth)
      castedTensor = rewriter.create<mlir::arith::TruncIOp>(loc, castedType,
                                                            convertedTensor);
    else if (oldElemType.isSignedInteger())
      castedTensor = rewriter.create<mlir::arith::ExtSIOp>(loc, castedType,
                                                           convertedTensor);
    else
      castedTensor = rewriter.create<mlir::arith::ExtUIOp>(loc, castedType,
                                                           convertedTensor);
  } else {
    if (oldElemType.isF16() && newElemType.isF32())
      castedTensor = rewriter.create<mlir::arith::ExtFOp>(loc, castedType,
                                                          convertedTensor);
    else if (oldElemType.isF32() && newElemType.isF16())
      castedTensor = rewriter.create<mlir::arith::TruncFOp>(loc, castedType,
                                                            convertedTensor);
    else
      // Fallback for fp8 and other exotic float-to-float conversions.
      castedTensor =
          rewriter.create<tt::FpToFpOp>(loc, castedType, convertedTensor);
  }
  return castedTensor;
}

class BlockedToMFMA : public mlir::RewritePattern {
int mfmaVersion;
int enforcedNonKDim;
Expand Down Expand Up @@ -310,17 +240,13 @@ class BlockedToMFMA : public mlir::RewritePattern {
/*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
/*instrShape*/ mDim, nDim, isTransposed);

// convert accumulator
Type mfmaAccType;
if (oldRetType.getElementType().isIntOrIndex())
mfmaAccType = rewriter.getIntegerType(32);
else
mfmaAccType = rewriter.getF32Type();
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);

// convert accumulator
auto oldAcc = dotOp.getOperand(2);
auto newAcc = convertAndCastTensor(rewriter, oldAcc, mfmaEnc, mfmaAccType);

// convert A/B operands
auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(oldAcc.getLoc(),
newRetType, oldAcc);
auto oldAOrder = oldAType.getEncoding()
.cast<ttg::DotOperandEncodingAttr>()
.getParent()
Expand Down Expand Up @@ -355,16 +281,12 @@ class BlockedToMFMA : public mlir::RewritePattern {
ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth));
a = rewriter.create<ttg::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<ttg::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newAcc.getType(),
a, b, newAcc, dotOp.getAllowTF32(),
dotOp.getMaxNumImpreciseAcc());

Value dotOutput =
convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(),
oldRetType.getElementType());

rewriter.replaceOp(op, dotOutput);
auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType, a, b,
newAcc, dotOp.getAllowTF32(),
dotOp.getMaxNumImpreciseAcc());

rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
newDot.getResult());
return success();
}
};
Expand Down
27 changes: 23 additions & 4 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,7 +1324,26 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, max_num_imprecise_acc),
ret_ty)
return cast(ret, ret_scalar_ty, builder)
if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32,
ret_scalar_ty, builder.target) and ret_scalar_ty.primitive_bitwidth <= 32:
# max_num_imprecise_acc does not yet apply to hip
if is_hip():
max_num_imprecise_acc = 0
if max_num_imprecise_acc is None:
max_num_imprecise_acc = 2**30

if lhs.type.scalar.is_int():
ret_dot_scalar_ty = tl.int32
_0 = builder.create_splat(builder.get_int32(0), [M, N])
else:
ret_dot_scalar_ty = tl.float32
_0 = builder.create_splat(builder.get_fp32(0), [M, N])
ret_ty = tl.block_type(ret_dot_scalar_ty, [M, N])
ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, max_num_imprecise_acc),
ret_ty)
return cast(ret, ret_scalar_ty, builder)

_0 = builder.create_splat(_0, [M, N])
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
if acc is None:
acc_handle = builder.create_splat(_0, [M, N])
Expand All @@ -1333,11 +1352,11 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target):
assert acc.type == ret_ty

# max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8()
and ret_scalar_ty.is_fp32()):
max_num_imprecise_acc = 0
if max_num_imprecise_acc is None:
if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
max_num_imprecise_acc = builder.options.max_num_imprecise_acc_default
else:
max_num_imprecise_acc = 0
max_num_imprecise_acc = 2**30

return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty)

Expand Down

0 comments on commit 1181155

Please sign in to comment.