Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix the bug that for block_k=16 mma, the compilation crash on Ampere. #15

Closed
wants to merge 11 commits into from

Conversation

bingyizh233
Copy link

The origin issue is reported here: triton-lang#3435 The issue happens during compilation, when arith.sitofp (from i8 to fp16) operates on the tensor operand which has dot_op layout with the first dimension of the tensor being 16 and opidx = 1.
For example: %104 = arith.sitofp %103 : tensor<16x64xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> to tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>

Investigation shows that the bug happens in TritonGPUToLLVM pass. in the corner case (block_k = 16 and opidx = 1) extra elements will be unpacked in include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h:line 186-194. The code unpacks extra elements due to an implicit assumption in lib/Dialect/TritonGPU/IR/Dialect.h, at line 2000, at least 4 rep (e.g., i32) will be loaded.

Therefore, in our patch, extra loaded elements are dropped in the corner case (block_k = 16 and opidx = 1).

The core Triton is a small number of people, and we receive many PRs (thank
you!). To help us review your code more quickly, if you are a new
contributor (less than 3 PRs merged) we ask that you complete the following
tasks and include the filled-out checklist in your PR description.

Complete the following tasks before sending your PR, and replace [ ] with
[x] to indicate you have done them.

  • [ x] I am not making a trivial change, such as fixing a typo in a comment.

  • [ x] I have written a PR description following these
    rules.

  • [ x] I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • [x ] This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • [ x] I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

vwbaker and others added 11 commits August 26, 2024 10:11
…ng#4410)

Included the use of the non-deprecated version of createMCObjectStreamer (needed after llvm/llvm-project@f1422a8).
…ich exists in gcc-defaults. (triton-lang#4548)

The llvm build check is trying to get
http://ftp.de.debian.org/debian/pool/main/g/gcc-defaults/gcc-aarch64-linux-gnu_13.2.0-7_amd64.deb,
which does not exist and therefore fails. Updating the version to an
existing one (14.1.0-2).

[x] I am not making a trivial change, such as fixing a typo in a
comment.
[x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).
[x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
[x] This PR does not need a test because it is not a functional change,
should fix git checks builds.
[x] I have not added any `lit` tests.
The origin issue is reported here: triton-lang#3435
The issue happens during compilation, when arith.sitofp (from i8 to fp16) operates on the tensor operand which has dot_op layout with the first dimension of the tensor being 16 and opidx = 1.
For example: %104 = arith.sitofp %103 : tensor<16x64xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> to tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>

Investigation shows that the bug happens in TritonGPUToLLVM pass.
in the corner case (block_k = 16 and opidx = 1) extra elements will be unpacked in include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h:line 186-194.
The code unpack extra elements due to an implicit assumption in lib/Dialect/TritonGPU/IR/Dialect.h, at line 2000, at least 4 rep will be loaded.

Therefore, in our patch, extra loaded elements are dropped in the corner case.
@gflegar
Copy link
Member

gflegar commented Sep 20, 2024

For tracking: this has been opened as a triton-lang#4768 upstream

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants