
[Dot] [MFMA] [FMA] Update Dot implementation to support upstream tests #260

Merged: 7 commits merged into ROCm:triton-mlir on Aug 3, 2023

Conversation

@binarman commented Jul 13, 2023

This PR adds:

  • support for FP16 outputs of MFMA dot operations
  • a fallback to the FMA implementation when MFMA cannot handle the input sizes
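
Editorial note: a rough sketch of how these two changes combine at dot-lowering time (illustrative Python only, not the PR's actual code; the real mfma_supported helpers and the 32-element granularity appear in the diff excerpts below):

# Illustrative sketch only: the real logic lives in the compiler
# (python/triton/language/semantic.py and the C++ TritonGPU passes).
def lower_dot_sketch(M: int, N: int, K: int, out_is_fp16: bool) -> str:
    mfma_granularity = 32  # minimum dimension size the MFMA path handles here
    if any(dim % mfma_granularity != 0 for dim in (M, N, K)):
        return "FMA"  # new fallback: shapes MFMA cannot handle use plain FMA
    if out_is_fp16:
        # new: MFMA accumulates in fp32, then the result is cast down to fp16
        return "MFMA (fp32 accumulate, cast to fp16)"
    return "MFMA"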

@zhanglx13 commented:

@micmelesse Can you try this PR to see if it resolves the issue in test_core?

@binarman (Author) commented Jul 13, 2023

Yes, it fixes the issue. I'll move the tests to test_core_amd.py soon.

-@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype",
-                         [(*shape, 2, False, False, epilogue, allow_tf32, dtype)
+@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype",
+                         [(*shape, 2, False, False, epilogue, allow_tf32, in_dtype, out_dtype)
                           for shape in [(64, 64, 64), (32, 32, 32)]
Collaborator:

Why are we not testing (16,16,16)? It is in upstream and when I try it I get a segfault.

binarman (Author):

If you run these tests on an MI100/MI200 GPU, the compiler tries to use MFMA instructions, which have a minimum M/N size of 32.
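
Editorial note: a minimal illustration of that constraint (sketch only, not the compiler's actual helper; the real check is shown further down as mfma_supported_granularity):

# Sketch of the MFMA size constraint described above.
MFMA_MIN_MN = 32  # minimum M/N tile size the MFMA path handles on MI100/MI200

def can_use_mfma(M: int, N: int) -> bool:
    return M % MFMA_MIN_MN == 0 and N % MFMA_MIN_MN == 0

assert can_use_mfma(32, 32)       # eligible for MFMA
assert not can_use_mfma(16, 16)   # outside MFMA support; segfaulted before this PR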

binarman (Author):

@micmelesse
Update: I've added a workaround with FMA instructions, so the 16x16x16 tests now work.

@binarman binarman requested a review from micmelesse July 26, 2023 14:08
@binarman binarman changed the title from "[Dot] [MFMA] Support FP16 output of MFMA dot" to "[Dot] [MFMA] [FMA] Update Dot implementation to support upstream tests" on Jul 26, 2023
Comment on lines +1292 to +1310
        if ret_cast_scalar_ty == tl.float16:
            _0 = builder.create_splat(builder.get_fp16(0), [M, N])
        else:
            _0 = builder.create_splat(builder.get_fp32(0), [M, N])
        ret_ty = tl.block_type(ret_cast_scalar_ty, [M, N])
        ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
                        ret_ty)
        return cast(ret, ret_scalar_ty, builder)
    if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32:
        # On HIP, when MFMA can handle the shape but the requested result type is
        # narrower than 32 bits, accumulate in int32/fp32 and cast the result down.
        if lhs.type.scalar.is_int():
            ret_dot_scalar_ty = tl.int32
            _0 = builder.create_splat(builder.get_int32(0), [M, N])
        else:
            ret_dot_scalar_ty = tl.float32
            _0 = builder.create_splat(builder.get_fp32(0), [M, N])
        ret_ty = tl.block_type(ret_dot_scalar_ty, [M, N])
        ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
                        ret_ty)
        return cast(ret, ret_scalar_ty, builder)
binarman (Author):

This part adds support for FP16 output of the dot operation.
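
Editorial note: a usage illustration of what this enables (sketch only; it assumes the out_dtype argument of tl.dot that the upstream in_dtype/out_dtype tests exercise, and arbitrary names and shapes):

import torch
import triton
import triton.language as tl

@triton.jit
def fp16_dot_kernel(a_ptr, b_ptr, c_ptr,
                    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    # Row-major addressing; the whole problem fits in a single tile here.
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # Request an fp16 dot result; with this PR the MFMA path accumulates in
    # fp32 and casts the result down instead of rejecting the output type.
    c = tl.dot(a, b, out_dtype=tl.float16)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

a = torch.randn((32, 32), device="cuda", dtype=torch.float16)
b = torch.randn((32, 32), device="cuda", dtype=torch.float16)
c = torch.empty((32, 32), device="cuda", dtype=torch.float16)
fp16_dot_kernel[(1,)](a, b, c, M=32, N=32, K=32)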

Comment on lines 152 to 168
static bool supportMFMAGranularity(int dim_size) {
  std::vector<int> supported_granularity{32};
  for (int granularity : supported_granularity)
    if (dim_size % granularity == 0)
      return true;
  return false;
}

bool supportMFMA(triton::DotOp op) {
-  auto aElemTy = op.getA().getType().cast<RankedTensorType>().getElementType();
-  auto bElemTy = op.getB().getType().cast<RankedTensorType>().getElementType();
+  auto aTy = op.getA().getType().cast<RankedTensorType>();
+  auto bTy = op.getB().getType().cast<RankedTensorType>();

  auto aShape = aTy.getShape();
  auto bShape = bTy.getShape();

  assert(aShape[1] == bShape[0]);
  if (!supportMFMAGranularity(aShape[0]) ||
      !supportMFMAGranularity(aShape[1]) ||
      !supportMFMAGranularity(bShape[1]))
    return false;

+  auto aElemTy = aTy.getElementType();
+  auto bElemTy = bTy.getElementType();
binarman (Author):

This part enables the fallback to the FMA implementation for small matrix sizes in the C++ part of the compiler.

Comment on lines 1241 to 1254
def mfma_supported_granularity(dim_size) -> bool:
    supported_granularity = [32]
    for granularity in supported_granularity:
        if dim_size % granularity == 0:
            return True
    return False

def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool:
    if not gpu_has_mfma():
        return False
    if not mfma_supported_granularity(M) or \
       not mfma_supported_granularity(N) or \
       not mfma_supported_granularity(K):
        return False
binarman (Author):

This part enables the fallback to the FMA implementation for small matrix sizes in the Python part of the compiler.
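
Editorial note: a rough illustration of the resulting dispatch (assumes gpu_has_mfma() is true and that the remaining dtype/allow_tf32 checks in mfma_supported, not shown in the excerpt above, also pass):

# Shapes that are multiples of 32 stay on the MFMA path; 16x16x16 fails the
# granularity check and falls back to the FMA implementation instead.
mfma_supported(32, 32, 32, allow_tf32=False, ret_scalar_ty=tl.float32)  # True
mfma_supported(16, 16, 16, allow_tf32=False, ret_scalar_ty=tl.float32)  # False -> FMA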

@binarman binarman force-pushed the mfma_dot_out_fp16 branch 2 times, most recently from 020d111 to defaa55 on August 2, 2023 13:44
@binarman binarman marked this pull request as draft August 2, 2023 14:29
@binarman binarman marked this pull request as ready for review August 2, 2023 16:46
@@ -149,13 +149,34 @@ bool supportMMA(triton::DotOp op, int version) {
}

#ifdef USE_ROCM
static bool supportMFMAGranularity(int m, int n, int k) {
  // these limitations are dtype dependent, in future we may relax them
  int granularityMN = 32;
Collaborator:

Should we define these as constants somewhere so we can change them easily if needed?

binarman (Author):

Right, this should be a constant.

As for the place: I actually can't think of a better place for these constants.

P.S. If you are worried about the duplication between the C++ and Python code: I want to refactor this and eventually remove the Python part, so these constants will live in only one place.
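
Editorial note: for instance (sketch only, not part of the PR), the magic number could become a single named constant that the Python-side granularity check reads:

# Hypothetical centralization of the constant discussed above.
MFMA_GRANULARITY_MN = 32  # minimum M/N size the MFMA path currently supports

def mfma_supported_granularity(dim_size: int) -> bool:
    return dim_size % MFMA_GRANULARITY_MN == 0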

@jayfurmanek jayfurmanek dismissed micmelesse’s stale review August 3, 2023 18:47

Michael is out; Request addressed.

@jayfurmanek jayfurmanek merged commit 86f8b64 into ROCm:triton-mlir Aug 3, 2023
1 check passed