[Dot] [MFMA] [FMA] Update Dot implementation to support upstream tests #260
Conversation
@micmelesse Can you try this PR to see if it resolves the issue in test_core?
Yes, it fixes the issue. I'll move the tests to test_core_amd.py soon.
Force-pushed from c16e849 to 851a103
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype", | ||
[(*shape, 2, False, False, epilogue, allow_tf32, dtype) | ||
@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype", | ||
[(*shape, 2, False, False, epilogue, allow_tf32, in_dtype, out_dtype) | ||
for shape in [(64, 64, 64), (32, 32, 32)] |
Why are we not testing (16,16,16)? It is in upstream and when I try it I get a segfault.
If you are running them on an MI100/MI200 GPU, the compiler tries to use MFMA instructions, which have a minimum M/N size of 32.
@micmelesse
Update: I've added a workaround that falls back to FMA instructions, so the 16x16x16 tests work now.
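For context, a minimal sketch (not taken from this PR; the kernel, test name, and tolerances are illustrative only) of how the small-tile case can be exercised once the FMA fallback is in place:

import pytest
import torch
import triton
import triton.language as tl

@triton.jit
def dot_kernel(a_ptr, b_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    rm = tl.arange(0, M)
    rn = tl.arange(0, N)
    rk = tl.arange(0, K)
    a = tl.load(a_ptr + rm[:, None] * K + rk[None, :])
    b = tl.load(b_ptr + rk[:, None] * N + rn[None, :])
    # For 16x16x16 the MFMA path is rejected and the dot lowers to FMA instead.
    c = tl.dot(a, b)
    tl.store(c_ptr + rm[:, None] * N + rn[None, :], c)

@pytest.mark.parametrize("M, N, K", [(16, 16, 16), (32, 32, 32), (64, 64, 64)])
def test_small_dot(M, N, K):
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16)
    c = torch.empty((M, N), device="cuda", dtype=torch.float32)
    dot_kernel[(1,)](a, b, c, M=M, N=N, K=K)
    torch.testing.assert_close(c, torch.matmul(a, b).float(), atol=1e-2, rtol=1e-2)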
Force-pushed from 851a103 to a429da0
        if ret_cast_scalar_ty == tl.float16:
            _0 = builder.create_splat(builder.get_fp16(0), [M, N])
        else:
            _0 = builder.create_splat(builder.get_fp32(0), [M, N])
        ret_ty = tl.block_type(ret_cast_scalar_ty, [M, N])
        ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
                        ret_ty)
        return cast(ret, ret_scalar_ty, builder)
    if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32:
        if lhs.type.scalar.is_int():
            ret_dot_scalar_ty = tl.int32
            _0 = builder.create_splat(builder.get_int32(0), [M, N])
        else:
            ret_dot_scalar_ty = tl.float32
            _0 = builder.create_splat(builder.get_fp32(0), [M, N])
        ret_ty = tl.block_type(ret_dot_scalar_ty, [M, N])
        ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
                        ret_ty)
        return cast(ret, ret_scalar_ty, builder)
This part is related to supporting FP16 output.
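At the language level this corresponds to requesting a float16 result from tl.dot while accumulation still happens in float32 and the result is cast on return. A minimal sketch, assuming a Triton build where tl.dot accepts the out_dtype keyword used by the upstream tests (kernel and names are illustrative, not from this PR):

import triton
import triton.language as tl

@triton.jit
def dot_fp16_out_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    # The dot accumulates in fp32 internally; the fp16 result comes from the final cast.
    c = tl.dot(a, b, out_dtype=tl.float16)
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)

On builds without the out_dtype keyword, an explicit tl.dot(a, b).to(tl.float16) exercises the same cast-on-output path.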
lib/Analysis/Utility.cpp (outdated diff)
+static bool supportMFMAGranularity(int dim_size) {
+  std::vector<int> supported_granularity{32};
+  for (int granularity: supported_granularity)
+    if (dim_size % granularity == 0)
+      return true;
+  return false;
+}
+
 bool supportMFMA(triton::DotOp op) {
-  auto aElemTy = op.getA().getType().cast<RankedTensorType>().getElementType();
-  auto bElemTy = op.getB().getType().cast<RankedTensorType>().getElementType();
+  auto aTy = op.getA().getType().cast<RankedTensorType>();
+  auto bTy = op.getB().getType().cast<RankedTensorType>();
+
+  auto aShape = aTy.getShape();
+  auto bShape = bTy.getShape();
+
+  assert(aShape[1] == bShape[0]);
+  if (!supportMFMAGranularity(aShape[0]) ||
+      !supportMFMAGranularity(aShape[1]) ||
+      !supportMFMAGranularity(bShape[1]))
+    return false;
+
+  auto aElemTy = aTy.getElementType();
+  auto bElemTy = bTy.getElementType();
This part enables the fallback to the FMA implementation for small matrix sizes in the C++ part of the compiler.
python/triton/language/semantic.py (outdated diff)
def mfma_supported_granularity(dim_size) -> bool:
    supported_granularity = [32]
    for granularity in supported_granularity:
        if dim_size % granularity == 0:
            return True
    return False


def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
    if not gpu_has_mfma():
        return False
    if not mfma_supported_granularity(M) or \
       not mfma_supported_granularity(N) or \
       not mfma_supported_granularity(K):
        return False
This part enables the fallback to the FMA implementation for small matrix sizes in the Python part of the compiler.
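As a quick illustration (a standalone sketch, not code from the PR) of the effect of these checks: only shapes whose M, N, and K are all multiples of 32 stay on MFMA, everything else falls back to FMA.

# Standalone mirror of mfma_supported_granularity above, for illustration only.
def mfma_supported_granularity(dim_size) -> bool:
    return any(dim_size % granularity == 0 for granularity in [32])

for m, n, k in [(16, 16, 16), (32, 32, 32), (64, 64, 64), (64, 64, 16)]:
    path = "MFMA" if all(mfma_supported_granularity(d) for d in (m, n, k)) else "FMA fallback"
    print(f"{m}x{n}x{k}: {path}")
# 16x16x16 and 64x64x16 take the FMA fallback; 32x32x32 and 64x64x64 can use MFMA.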
Force-pushed from 020d111 to defaa55
lib/Analysis/Utility.cpp (outdated diff)
@@ -149,13 +149,34 @@ bool supportMMA(triton::DotOp op, int version) {
 }

+#ifdef USE_ROCM
+static bool supportMFMAGranularity(int m, int n, int k) {
+  // these limitations are dtype dependent, in future we may relax them
+  int granularityMN = 32;
Should we define these as constants somewhere so we can change them easily if needed?
Right, this should be a constant.
As for the place: I actually cannot think of a better place for these constants.
P.S. If you are worried about the duplication between the C++ and Python code: I want to refactor this and remove the Python part eventually, so there will be only one place with these constants.
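One possible shape for that, sketched for the Python half only (the constant and function names are hypothetical, not from this PR); the C++ side would mirror them until the Python path is removed as mentioned above:

# Hypothetical module-level constants so the granularity lives in one place.
MFMA_GRANULARITY_MN = 32
MFMA_GRANULARITY_K = 32

def mfma_supported_shape(m, n, k) -> bool:
    # All dimensions must be multiples of the supported MFMA tile granularity.
    return (m % MFMA_GRANULARITY_MN == 0 and
            n % MFMA_GRANULARITY_MN == 0 and
            k % MFMA_GRANULARITY_K == 0)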
Force-pushed from 6ffd66a to f0659ee
This PR adds a cast of the output tensor to the requested data type.
Force-pushed from f0659ee to a1e8311
Michael is out; request addressed.
This PR adds: