Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve FA fwd kernel with causal=True (#356)
* Attempt to absorb upstream's changes to improve causal=True * Add autotuner * Optimize for AMD MI250 - add pre_load_v as a tuning parameter - do not define N_CTX as constexpr - perform the second dot before sum - remove qk_scale out of the inner loop - add more configs in the autotuner Note that bwd kernel is disabled for now. This is because we enabled autotuning and grid becomes a function. So ctx.grid[0] no longer works. * Enable bwd kernel
- Loading branch information