Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKEND] Add a loop unroller pass #4645

Merged
merged 7 commits into from
Sep 9, 2024
Merged

Conversation

htyu
Copy link
Collaborator

@htyu htyu commented Sep 4, 2024

Adding a loop unroller pass which applies to only loops with unroll annotation.

An annotated loop will look like:

    scf.for %arg5 = %c0_i32 to %arg3 step %c32_i32 : i32 {
      ...
    } {tt.loop_unroll_factor = 2 : i32}

@htyu
Copy link
Collaborator Author

htyu commented Sep 4, 2024

#4639 is needed to unblock a test failure where the MLIR loop unroller doesn't work on for loops with integer IV.

Copy link
Collaborator

@manman-ren manman-ren left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This diff looks good to me!

Copy link
Collaborator

@manman-ren manman-ren left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you currently get performance wins from this without needing other changes?

@htyu
Copy link
Collaborator Author

htyu commented Sep 5, 2024

Do you currently get performance wins from this without needing other changes?

Yes when used with kernel override, for kernels where register pressure isn't an issue (e.g persistent kernels). For E2E run we'll need a frontend change.

Copy link
Collaborator

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious about the use case and if you considered using tl.static_range?

lib/Dialect/Triton/Transforms/LoopUnroll.cpp Outdated Show resolved Hide resolved
third_party/nvidia/backend/compiler.py Outdated Show resolved Hide resolved
lib/Dialect/Triton/Transforms/LoopUnroll.cpp Show resolved Hide resolved
test/Triton/loop-unroll.mlir Outdated Show resolved Hide resolved
@htyu
Copy link
Collaborator Author

htyu commented Sep 6, 2024

I'm curious about the use case and if you considered using tl.static_range?

One typical use case is that for the GEMM persistent kernel, like #4662 proposes, to annotate the flattened loop. An unroll factor of 2 on tile 128x128x128 gives some speedup, where SMEM and register pressure aren't an issue. The speedup can be higher when combined with our subsequent work of branch removal, i.e, removing useless branch checking for the prolog and epilog checks.

@htyu htyu force-pushed the hoy/loopunroll branch 2 times, most recently from 5ffcf23 to b9591ad Compare September 6, 2024 16:01
@ThomasRaoux
Copy link
Collaborator

I'm curious about the use case and if you considered using tl.static_range?

One typical use case is that for the GEMM persistent kernel, like #4662 proposes, to annotate the flattened loop. An unroll factor of 2 on tile 128x128x128 gives some speedup, where SMEM and register pressure aren't an issue. The speedup can be higher when combined with our subsequent work of branch removal, i.e, removing useless branch checking for the prolog and epilog checks.

but GEMM kernels usually don't have a static K loop? It doesn't sound like something that can be used in general cases?

@ThomasRaoux
Copy link
Collaborator

I'm curious about the use case and if you considered using tl.static_range?

One typical use case is that for the GEMM persistent kernel, like #4662 proposes, to annotate the flattened loop. An unroll factor of 2 on tile 128x128x128 gives some speedup, where SMEM and register pressure aren't an issue. The speedup can be higher when combined with our subsequent work of branch removal, i.e, removing useless branch checking for the prolog and epilog checks.

but GEMM kernels usually don't have a static K loop? It doesn't sound like something that can be used in general cases?

actually looking at the code again I see that it supports dynamic case. Is that what you do for GEMM? I don't understand how this helps since the unrolled iteration will have to be guarded by an IF op right?

@htyu
Copy link
Collaborator Author

htyu commented Sep 6, 2024

I'm curious about the use case and if you considered using tl.static_range?

One typical use case is that for the GEMM persistent kernel, like #4662 proposes, to annotate the flattened loop. An unroll factor of 2 on tile 128x128x128 gives some speedup, where SMEM and register pressure aren't an issue. The speedup can be higher when combined with our subsequent work of branch removal, i.e, removing useless branch checking for the prolog and epilog checks.

but GEMM kernels usually don't have a static K loop? It doesn't sound like something that can be used in general cases?

actually looking at the code again I see that it supports dynamic case. Is that what you do for GEMM?

Yes, something like

`for _ in tl.range(0, k_tiles * tiles_per_SM, unroll_factor=F):`

where F comes from the autotuner. So far we've been giving it 2.

I don't understand how this helps since the unrolled iteration will have to be guarded by an IF op right?

So each original iteration comes with two if-checks, one to identify the start of the original conceptual inner K-loop, one is to identify the last iteration of the K-loop. Once unrolled by 2, there'd be four if-checks, two of which are unnecessary.

Yes the unrolled version comes with an epilog reminder loop, which can be unnecessary too. I think one more hint to the loop unroller should get it.

@ThomasRaoux
Copy link
Collaborator

I'm curious about the use case and if you considered using tl.static_range?

One typical use case is that for the GEMM persistent kernel, like #4662 proposes, to annotate the flattened loop. An unroll factor of 2 on tile 128x128x128 gives some speedup, where SMEM and register pressure aren't an issue. The speedup can be higher when combined with our subsequent work of branch removal, i.e, removing useless branch checking for the prolog and epilog checks.

but GEMM kernels usually don't have a static K loop? It doesn't sound like something that can be used in general cases?

actually looking at the code again I see that it supports dynamic case. Is that what you do for GEMM?

Yes, something like

`for _ in tl.range(0, k_tiles * tiles_per_SM, unroll_factor=F):`

where F comes from the autotuner. So far we've been giving it 2.

I don't understand how this helps since the unrolled iteration will have to be guarded by an IF op right?

So each original iteration comes with two if-checks, one to identify the start of the original conceptual inner K-loop, one is to identify the last iteration of the K-loop. Once unrolled by 2, there'd be four if-checks, two of which are unnecessary.

Yes the unrolled version comes with an epilog reminder loop, which can be unnecessary too. I think one more hint to the loop unroller should get it.

interesting, that makes sense

test/Triton/loop-unroll.mlir Outdated Show resolved Hide resolved
lib/Dialect/Triton/Transforms/LoopUnroll.cpp Outdated Show resolved Hide resolved
lib/Dialect/Triton/Transforms/LoopUnroll.cpp Outdated Show resolved Hide resolved
@htyu
Copy link
Collaborator Author

htyu commented Sep 9, 2024

@ThomasRaoux @antiagainst Thanks for reviewing this change! How does the latest version look to you? Please let me know if you have more comments and I'm happy to address them.

Copy link
Collaborator

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks fine to me assuming the FE changes are ready as well

@htyu
Copy link
Collaborator Author

htyu commented Sep 9, 2024

Looks fine to me assuming the FE changes are ready as well

#4662 is for the FE side changes BTW.

@htyu htyu merged commit 7df871d into triton-lang:main Sep 9, 2024
7 checks passed
htyu pushed a commit that referenced this pull request Sep 21, 2024
This change exposes the scf For Loop attribute used in PR #4645 the
frontend. It does this by adding a field to tl.range (the same `as
num_stages`), this will allow setting loop unrolling factors like so:

```
@triton.jit
def _kernel(dst, v):
    pid = tl.program_id(axis=0)
    for i in tl.range(0, 10, loop_unroll_factor=2):
        tl.atomic_add(dst + pid, i + pid)
```

Unroll factors of less than 2 do nothing, but 2 or more results in the
loop body being replicated that number of times (similar to a clang
`#pragma unroll`).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants