Skip to content

Commit

Permalink
Improve FA fwd kernel with causal=True (#356)
Browse files Browse the repository at this point in the history
* Attempt to absorb upstream's changes to improve causal=True

* Add autotuner

* Optimize for AMD MI250

- add pre_load_v as a tuning parameter
- do not define N_CTX as constexpr
- perform the second dot before sum
- remove qk_scale out of the inner loop
- add more configs in the autotuner

Note that bwd kernel is disabled for now. This is because we enabled
autotuning and grid becomes a function. So ctx.grid[0] no longer works.

* Enable bwd kernel
  • Loading branch information
zhanglx13 authored Oct 12, 2023
1 parent 6f073a4 commit 821e75a
Showing 1 changed file with 206 additions and 102 deletions.
Loading

0 comments on commit 821e75a

Please sign in to comment.