[BACKEND] Add a loop unroller pass #4645
#4639 is needed to unblock a test failure where the MLIR loop unroller doesn't work on for loops with an integer IV.
This diff looks good to me!
Do you currently get performance wins from this without needing other changes?
Yes, when used with kernel override, for kernels where register pressure isn't an issue (e.g., persistent kernels). For an E2E run we'll need a frontend change.
I'm curious about the use case and if you considered using tl.static_range?
One typical use case is the GEMM persistent kernel, where the flattened loop is annotated as #4662 proposes. An unroll factor of 2 on a 128x128x128 tile gives some speedup, provided SMEM and register pressure aren't an issue. The speedup can be higher when combined with our subsequent work on branch removal, i.e., removing the useless branch checks for the prolog and epilog.
5ffcf23 to b9591ad
But GEMM kernels usually don't have a static K loop? It doesn't sound like something that can be used in general cases?
Actually, looking at the code again, I see that it supports the dynamic case. Is that what you do for GEMM? I don't understand how this helps, since the unrolled iteration will have to be guarded by an IF op, right?
Yes, something like
where
So each original iteration comes with two if-checks: one to identify the start of the original conceptual inner K-loop, and one to identify the last iteration of the K-loop. Once unrolled by 2, there'd be four if-checks, two of which are unnecessary. Yes, the unrolled version also comes with an epilog remainder loop, which can be unnecessary too. I think one more hint to the loop unroller should take care of it.
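The check counting described above can be sketched in plain Python. This is only an illustrative model, not the kernel from the PR: the function names, the `("prolog", ...)`/`("mma", ...)`/`("epilog", ...)` event tuples, and the assumption that the number of K iterations per tile (`k_tiles`) is even are all hypothetical. Under that assumption, the first copy of the unrolled body can never be the last K iteration and the second copy can never be the first, so one check per copy can be dropped and no remainder loop is needed.

```python
def flattened_loop(num_tiles, k_tiles):
    """Flattened persistent-style loop: two if-checks per iteration."""
    events, checks = [], 0
    for it in range(num_tiles * k_tiles):
        tile, ki = divmod(it, k_tiles)
        checks += 1
        if ki == 0:  # start of the conceptual inner K-loop
            events.append(("prolog", tile))
        events.append(("mma", tile, ki))
        checks += 1
        if ki == k_tiles - 1:  # last iteration of the K-loop
            events.append(("epilog", tile))
    return events, checks


def unrolled_by_two(num_tiles, k_tiles):
    """Same loop unrolled by 2 with the redundant checks removed.

    Assumption (hypothetical, for illustration): k_tiles is even, so both
    body copies always belong to the same tile and no remainder loop runs.
    """
    assert k_tiles % 2 == 0, "sketch only covers an even K trip count"
    events, checks = [], 0
    for it in range(0, num_tiles * k_tiles, 2):
        tile, ki = divmod(it, k_tiles)  # ki is even here
        checks += 1
        if ki == 0:
            events.append(("prolog", tile))
        events.append(("mma", tile, ki))
        # Epilog check dropped: an even ki cannot equal the odd k_tiles - 1.
        ki += 1
        # Prolog check dropped: an odd ki cannot be 0.
        events.append(("mma", tile, ki))
        checks += 1
        if ki == k_tiles - 1:
            events.append(("epilog", tile))
    return events, checks
```

Both functions produce the same event sequence; the unrolled variant performs half as many checks as the original loop, which is a quarter of what a naive unroll-by-2 (four checks per unrolled iteration) would execute.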
Interesting, that makes sense.
@ThomasRaoux @antiagainst Thanks for reviewing this change! How does the latest version look to you? Please let me know if you have more comments and I'm happy to address them.
Looks fine to me assuming the FE changes are ready as well
#4662 is for the FE side changes BTW.
This change exposes the scf For Loop attribute used in PR #4645 to the frontend. It does this by adding a field to tl.range (in the same way as `num_stages`), which allows setting loop unrolling factors like so:

```
@triton.jit
def _kernel(dst, v):
    pid = tl.program_id(axis=0)
    for i in tl.range(0, 10, loop_unroll_factor=2):
        tl.atomic_add(dst + pid, i + pid)
```

Unroll factors of less than 2 do nothing, but a factor of 2 or more results in the loop body being replicated that number of times (similar to a clang `#pragma unroll`).
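To make the unroll-factor semantics concrete, here is a minimal plain-Python sketch of what unrolling a counted loop by a factor means: a main loop that advances by `step * factor` with the body replicated `factor` times, followed by a remainder loop for leftover iterations. The helper name `run_unrolled` and its signature are hypothetical, a positive `step` is assumed, and the inner `for k` loop merely stands in for the replicated body copies. It models the iteration order only, not the actual pass.

```python
def run_unrolled(lo, hi, step, factor, body):
    """Call body(i) for every i in range(lo, hi, step), in unrolled order.

    Hypothetical helper for illustration; assumes step > 0.
    """
    if factor < 2:
        # Factors below 2 leave the loop unchanged.
        for i in range(lo, hi, step):
            body(i)
        return

    trip_count = len(range(lo, hi, step))
    main_trips = trip_count - trip_count % factor  # multiple of factor
    main_hi = lo + main_trips * step

    # Main loop: each pass covers `factor` original iterations.
    for i in range(lo, main_hi, step * factor):
        for k in range(factor):  # stands in for `factor` body copies
            body(i + k * step)

    # Remainder loop: iterations that do not fill a whole unrolled pass.
    for i in range(main_hi, hi, step):
        body(i)


visited = []
run_unrolled(0, 10, 1, 2, visited.append)
```

With the bounds from the example above (`0` to `10`, factor 2) the trip count divides evenly, so the remainder loop executes zero times; a factor of 3 would leave one iteration for it.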
Adds a loop unroller pass that applies only to loops with an unroll annotation.
An annotated loop will look like: