Fix the bug that for block_k=16 MMA, the compilation crashes on Ampere. #4768
base: main
Conversation
The original issue is reported here: triton-lang#3435. The issue happens during compilation, when arith.sitofp (from i8 to fp16) operates on a tensor operand that has a dot_op layout with the first dimension of the tensor being 16 and opIdx = 1. For example:

%104 = arith.sitofp %103 : tensor<16x64xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> to tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>

Investigation shows that the bug happens in the TritonGPUToLLVM pass. In the corner case (block_k = 16 and opIdx = 1), extra elements are unpacked in include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h, lines 186-194. The code unpacks extra elements due to an implicit assumption in lib/Dialect/TritonGPU/IR/Dialect.h, line 2000, that at least 4 reps will be loaded. Therefore, in our patch, the extra loaded elements are dropped in this corner case.
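For context, the following is a minimal reproducer sketch of the kind of kernel that hits this path. It is not the PR's test; the kernel, shapes, and launch configuration are illustrative, and it assumes an Ampere GPU with problem sizes divisible by the block sizes (so masks are omitted). With BLOCK_K = 16, the i8 -> fp16 cast runs on the opIdx = 1 dot operand whose first dimension is 16.

import torch
import triton
import triton.language as tl

@triton.jit
def mixed_dtype_matmul(a_ptr, b_ptr, c_ptr, M, N, K,
                       stride_am, stride_ak, stride_bk, stride_bn,
                       stride_cm, stride_cn,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                       BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)                 # fp16 operand (opIdx = 0)
        b = tl.load(b_ptrs)                 # int8 operand (opIdx = 1), tile of shape (BLOCK_K, BLOCK_N)
        acc += tl.dot(a, b.to(tl.float16))  # the i8 -> fp16 cast happens on the dot_op layout
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.float16))

def run(M=128, N=128, K=128, BLOCK_M=64, BLOCK_N=16, BLOCK_K=16):
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randint(-128, 127, (K, N), device="cuda", dtype=torch.int8)
    c = torch.empty((M, N), device="cuda", dtype=torch.float16)
    grid = (M // BLOCK_M, N // BLOCK_N)
    # Compiling this with BLOCK_K = 16 used to crash during convert-triton-gpu-to-llvm on Ampere.
    mixed_dtype_matmul[grid](a, b, c, M, N, K,
                             a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                             c.stride(0), c.stride(1),
                             BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K)
    return c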
Hi, thanks for the information. However, a PR like this is not acceptable as-is. Most importantly, I don't see any test case included.
We can take a first pass at this from the Google side, since we asked Nvidia for help with this, and once it's in an acceptable state we will assign to a core maintainer for final approval. @Moerafaat can take a first look (unfortunately, I don't think I can assign him officially as the reviewer), since he's most familiar with that part of the codebase. And then @chsigg can serve as the final quality gate, as agreed between OAI & Google.
Thank you for your feedback! I’ll work on adding the test cases. Additionally, could you please let me know if there’s anything else that might be missing?
The mixed precision stuff is quite fragile as far as I know. I'm not sure, as I haven't taken a look for a while.
test/Conversion/tritongpu_to_llvm_ampere.mlir: This is the test case for converting a tensor from s8 to fp16 when the first dimension of the tensor == 16 and opIdx == 1. Previously, the compilation could crash during convert-triton-gpu-to-llvm on an NVIDIA Ampere GPU. The new code patch resolves the issue, and this test case verifies that the crash no longer occurs.
Thanks for the effort here, I will take a deeper look. Just a few notes for now:
I was also trying to find other scenarios we ran into that could be classified as the same issue. For example, looking at this sample HLO:
With your change, this will still fail with the same "size mismatch" error. I suspect that in this case the adjustment will not apply to the left-hand side because of the restriction on the opIdx. Is that restriction necessary, or is something else the reason for the failure in this example?
Another example where the current solution would fail:
This one goes through tt.sparse_dot (which only exists if you use XLA's Triton). It's fine if you can't reproduce this example, but what I could also observe with this example for some tilings (for example shape[0] = 64, shape[1] = 16 for opIdx = 0) is that it also runs into the same failure. So I'm questioning whether the condition strictly applies to shape[0].
Also make sure to run pre-commit on the change.
[128, 128, 128, 64, 16, 16, torch.float16, torch.int8]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 8, reason="Requires compute capability == 8")
def test_gemm_mixed_dtype(M, N, K, block_m, block_n, block_k, dtype_a, dtype_b):
Can you leave a comment explaining the purpose of this test? There are already existing mixed-type tests; the emphasis here should be on mixed types with small tile sizes. I would suggest also renaming the file to indicate the emphasis on small tiles.
Also, can we add more variety in the parametrization? Perhaps use fp8 types as well, instead of just int8.
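To make this suggestion concrete, here is a hedged sketch of how the parametrization could be broadened. The test name, the case list, and the fp8 rows are illustrative; the fp8 dtypes assume a PyTorch build that exposes torch.float8_e5m2 / torch.float8_e4m3fn, and the test body is elided because it would mirror the existing mixed-dtype GEMM test.

import pytest
import torch

SMALL_TILE_CASES = [
    # M,   N,   K,   block_m, block_n, block_k, dtype_a,        dtype_b
    (128, 128, 128,  64,      16,      16,      torch.float16,  torch.int8),
    (128, 128, 128,  16,      64,      16,      torch.float16,  torch.int8),
    (128, 128, 128,  32,      32,      16,      torch.float16,  torch.float8_e5m2),
    (128, 128, 128,  64,      16,      16,      torch.bfloat16, torch.float8_e4m3fn),
]

@pytest.mark.parametrize("M, N, K, block_m, block_n, block_k, dtype_a, dtype_b",
                         SMALL_TILE_CASES)
@pytest.mark.skipif(not torch.cuda.is_available()
                    or torch.cuda.get_device_capability()[0] != 8,
                    reason="Requires compute capability == 8")
def test_gemm_mixed_dtype_small_tile(M, N, K, block_m, block_n, block_k,
                                     dtype_a, dtype_b):
    # Body omitted in this sketch: it would follow the existing mixed-dtype GEMM
    # test, launching the Triton kernel with the given small tile sizes and
    # comparing against a torch.matmul reference.
    ...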
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 3072 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @amperes8tofp16conversion(%1 : tensor<16x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>) attributes {noinline = false} { |
Nit: Triton's convention on the naming here is snake_case. Please adjust.
I have adjusted the naming to follow snake_case and added a test case for 'opIdx == 1 and shape[1] == 16'.
ret.push_back(values[i + 13]);
ret.push_back(values[i + 14]);
ret.push_back(values[i + 15]);
// In the corner case where in_shape[0] == 16 and getOpIdx() == 1, extra elements
// will be loaded. It is necessary to discard these additional elements.
if (in_shape[0] == 16 && inEncoding.getOpIdx() == 1) {
As I mentioned in the comments I left on the PR, I'm questioning the condition here as some cases will not match this and fail for the same reason.
I tried to change this condition to
in_shape[0] == 16 || out_shape == 16
instead, and it passes the 2 cases I indicated previously that fail. That being said, I'm not sure if this is the correct way to handle it or not.
Thanks for the information. I have looked into this. Changing the condition to (in_shape[0] == 16 && inEncoding.getOpIdx() == 1) || (in_shape[1] == 16 && inEncoding.getOpIdx() == 0)
resolves the issues and might be a better way, including the s8 x fp16 case. I will look at the other cases you mention.
@@ -56,23 +58,37 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
}
if (inBitWidth == 8 && ouBitWidth == 16) { |
Is this fix only applicable to conversions from 8-bit to 16-bit or is this extendable to other combinations as well?
This only works for conversions from 8-bit to 16-bit. Other data types (e.g., 32-bit) have a different data layout for triton_gpu.dot_op.
Thanks for the detailed comments. I updated the commit log by adding the permalink. I will test the other cases and take a deeper look into them. Could you provide example code for running this gemm_fusion_autotuner?
Can I ask what the data tiling configuration restrictions are here, besides Triton's restrictions? As far as I know, Triton requires that block_m >= 16, block_n >= 16, block_k >= 16, and that block_m, block_n, and block_k are powers of 2. Besides the Triton restrictions and the (block_k != 16) restriction that works around the crash, what other restrictions does the autotuner have? Thanks!
You can run tests here that are purposed for this. You don't have to worry about coverage for now; I can test internally and report other issues to you if they arise.
The restriction here is Triton-specific only at the mentioned line. The idea is that block_k = 16 is specific to cases where we have 8-bit types. The way it is written here is just the generalization for any bit-width.
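Purely as an illustration of the restrictions discussed above (this is not XLA's autotuner code; the function names and the `patched` flag are made up for the sketch), the tile filter could be expressed like this:

def is_pow2(x: int) -> bool:
    return x > 0 and (x & (x - 1)) == 0

def is_valid_tiling(block_m: int, block_n: int, block_k: int,
                    operand_bits: int, patched: bool) -> bool:
    # Triton's restriction: each tile dimension is a power of two and >= 16.
    if not all(is_pow2(d) and d >= 16 for d in (block_m, block_n, block_k)):
        return False
    # Before this patch, block_k == 16 with an 8-bit operand crashed the Ampere
    # lowering, so the autotuner had to drop those configurations.
    if block_k == 16 and operand_bits == 8 and not patched:
        return False
    return True

assert is_valid_tiling(64, 16, 16, operand_bits=8, patched=True)
assert not is_valid_tiling(64, 16, 16, operand_bits=8, patched=False)
assert not is_valid_tiling(64, 16, 12, operand_bits=16, patched=True)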
ret.push_back(values[i + 13]);
ret.push_back(values[i + 14]);
ret.push_back(values[i + 15]);
if ((in_shape[0] == 16 && inEncoding.getOpIdx() == 1) || |
I think this would be more readable:
bool loadsExtraElements = in_shape[1 - inEncoding.getOpIdx()] == 16;
for (unsigned i = 0; i < values.size(); i += 16) {
  ret.push_back(values[i + 0]);
  ret.push_back(values[i + 1]);
  ret.push_back(values[i + 2]);
  ret.push_back(values[i + 3]);
  ret.push_back(values[i + 8]);
  ret.push_back(values[i + 9]);
  ret.push_back(values[i + 10]);
  ret.push_back(values[i + 11]);
  if (loadsExtraElements)
    continue; // Discard elements that aren't needed.
  ret.push_back(values[i + 4]);
  ret.push_back(values[i + 5]);
  ret.push_back(values[i + 6]);
  ret.push_back(values[i + 7]);
  ret.push_back(values[i + 12]);
  ret.push_back(values[i + 13]);
  ret.push_back(values[i + 14]);
  ret.push_back(values[i + 15]);
}
Thanks for the suggestions. I have restructured the code in the newest commit. All the tests and pre-commit checks have been run again.
After addressing the last comment from @chsigg, I think the PR would be ready for OAI to have a look.
I have restructured the code in the newest commit. All the tests and pre-commit checks have been run again.
@Jokeren, this should be ready for review now. Would you mind taking a look? Thanks!
cc @ThomasRaoux
This seems like a very ad hoc workaround for the crash. I assume you are not blocked by this on the XLA side, as you have this patch already?
IMHO, I think it's fine to merge if it's an urgent issue that prevents you guys from moving forward, and we can use your code as a reference while we're refactoring the dot operand layout.
@ThomasRaoux it is indeed working around the issue, but not in an invasive manner, I would say, as the change is small and local. We work around it in XLA by entirely omitting tiling configurations that include such cases. We really would like to include these sizes in our tiling configs, and this fix will allow us to do so. We have many cases where we maintain some fixes internally (some are hacks), but ideally, where we can, it would be great to have the changes upstream to minimize the differences. This also allows us to monitor issues and collaborate more easily if the diff is smaller.
@ThomasRaoux just a friendly ping on this thread when you have time to take a look :)
Trying to rebase and test it on Ampere. However, pytest on upstream currently has errors on the Ampere platform. I reported the issue here: #4990
ret.push_back(values[i + 6]);
ret.push_back(values[i + 7]);
if (loadsExtraElements) |
You can also repeat the extra elements.
Sorry, I did not understand what you mean by "repeat extra elements". Any clarification is appreciated. Thanks!
Push i = 0 to 7 again into ret.
ret.push_back(values[i]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 7]);
if (loadsExtraElements) {
  ret.push_back(values[i]);
  ret.push_back(values[i + 1]);
  ret.push_back(values[i + 4]);
  ret.push_back(values[i + 5]);
  ret.push_back(values[i + 2]);
  ret.push_back(values[i + 3]);
  ret.push_back(values[i + 6]);
  ret.push_back(values[i + 7]);
  continue;
}

I just tried to push i = 0 to 7 again, as above. However, it gives an error when testing python/test/unit/ampere/test_gemm_mixed_dtype_small_tile.py. Here, packing extra elements may cause a mismatch in the number of registers.
@Jokeren I'm not sure I understand what you mean. The intent of the change is to actually not pack the extra elements as they result in a size mismatch during lowering to LLVM.
Let's wait a bit until this PR has landed: #5044. Maybe a lot of the issues you saw previously will be gone.
We merged this change internally (via a patch) for now, as we are facing issues with Triton at main (including the change in #5044).
I'll look into it today
…ut types involve an 8-bit. This change is porting in this [PR](triton-lang/triton#4768). PiperOrigin-RevId: 695275077
…ut types involve an 8-bit. This change is porting in this [PR](triton-lang/triton#4768). PiperOrigin-RevId: 695311919
The original issue is reported here: #3435. The issue happens during compilation, when arith.sitofp (from i8 to fp16) operates on a tensor operand that has a dot_op layout with the first dimension of the tensor being 16 and opIdx = 1. For example:

%104 = arith.sitofp %103 : tensor<16x64xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> to tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>

Investigation shows that the bug happens in the TritonGPUToLLVM pass. In the corner case (block_k = 16 and opIdx = 1), extra elements are unpacked in include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h (here). The code unpacks extra elements due to an implicit assumption in lib/Dialect/TritonGPU/IR/Dialect.h (here) that at least 4 reps will be loaded. Therefore, in our patch, the extra loaded elements are dropped in this corner case.
The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.

Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them.

- [ ] I am not making a trivial change, such as fixing a typo in a comment.
- [ ] I have written a PR description following these rules.
- [ ] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - `/test` for `lit` tests
  - `/unittest` for C++ tests
  - `/python/test` for end-to-end tests
  - This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - I have not added any `lit` tests.
  - The `lit` tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)