torch.compile support, bug fixes & more
Pre-built binary wheels require PyTorch 2.4.0
Added
- fMHA: PagedBlockDiagonalGappyKeysMask
- fMHA: heterogeneous queries in triton_splitk
- fMHA: support for paged attention in flash
- fMHA: Added backwards pass for merge_attentions
- fMHA: Added torch.compile support for 3 biases (LowerTriangularMask, LowerTriangularMaskWithTensorBias and BlockDiagonalMask) - some might require PyTorch 2.4
- fMHA: Added torch.compile support in memory_efficient_attention when passing the flash operator explicitly (e.g. memory_efficient_attention(..., op=(flash.FwOp, flash.BwOp)))
- fMHA: memory_efficient_attention now expects its attn_bias argument to be on the same device as the other input tensors. Previously, it would move the bias to the right device automatically.
- fMHA: AttentionBias subclasses are now constructed by default on the CUDA device if one is available. Previously, they were created on the CPU.
- 2:4 sparsity: Added xformers.ops.sp24.sparsify24_ste for Straight Through Estimator (STE) with options to rescale the gradient differently for masked out/kept values
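To make the last item concrete, here is a minimal pure-Python sketch of the idea behind 2:4 sparsity with a Straight Through Estimator. It does not call xformers.ops.sp24.sparsify24_ste. The helper names (sparsify_2_4, ste_backward), the scale_masked / scale_kept parameters and the "keep the two largest magnitudes in each group of four" selection rule are assumptions made for illustration; the library's actual selection algorithm and argument names may differ.

```python
from typing import List, Tuple


def sparsify_2_4(values: List[float]) -> Tuple[List[float], List[bool]]:
    """Keep the 2 largest-magnitude entries in every group of 4; zero the rest.

    Hypothetical helper for illustration, not the xformers implementation.
    """
    if len(values) % 4 != 0:
        raise ValueError("length must be a multiple of 4")
    sparse: List[float] = []
    mask: List[bool] = []
    for start in range(0, len(values), 4):
        group = values[start:start + 4]
        # Indices of the two entries with the largest absolute value.
        # sorted() is stable, so ties go to the earlier index.
        keep = set(sorted(range(4), key=lambda i: -abs(group[i]))[:2])
        for i, v in enumerate(group):
            kept = i in keep
            mask.append(kept)
            sparse.append(v if kept else 0.0)
    return sparse, mask


def ste_backward(
    grad_out: List[float],
    mask: List[bool],
    scale_masked: float = 1.0,
    scale_kept: float = 1.0,
) -> List[float]:
    """STE backward: the gradient stays dense, rescaled per position.

    scale_masked applies to pruned positions, scale_kept to kept ones
    (both names are assumptions, not the library's parameter names).
    """
    return [g * (scale_kept if kept else scale_masked)
            for g, kept in zip(grad_out, mask)]


x = [0.1, -3.0, 2.0, 0.5, 1.0, 1.0, -0.2, 4.0]
sparse_x, keep_mask = sparsify_2_4(x)
grad_x = ste_backward([1.0] * len(x), keep_mask, scale_masked=0.5, scale_kept=2.0)
print(sparse_x)
print(grad_x)
```

With both scales left at 1.0 this reduces to a plain STE: the forward pass is sparsified, and the backward pass behaves as if no pruning had happened.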
Improved
- fMHA: Fixed out-of-bounds reading for Split-K triton implementation
- Profiler: Fixed a bug with modules that take a single tuple as argument
- Profiler: Added manual trigger for a profiling step, by creating a trigger file in the profiling directory
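The trigger-file mechanism in the last item can be pictured with a small standard-library sketch. This is not the xformers profiler code. The file name "trigger", the consume_trigger helper and the choice to delete the file after it fires are all assumptions for illustration; check the profiler source for the actual file name and behavior.

```python
import tempfile
from pathlib import Path
from typing import List

# Hypothetical file name, chosen for illustration only.
TRIGGER_NAME = "trigger"


def consume_trigger(profile_dir: Path) -> bool:
    """Return True if a trigger file was present, deleting it so it fires once."""
    try:
        (profile_dir / TRIGGER_NAME).unlink()
    except FileNotFoundError:
        return False
    return True


def run_steps(profile_dir: Path, num_steps: int, request_at: int) -> List[int]:
    """Simulate a training loop that profiles a step only when asked to."""
    profiled: List[int] = []
    for step in range(num_steps):
        if step == request_at:
            # Stand-in for a user creating the file from a shell mid-run.
            (profile_dir / TRIGGER_NAME).touch()
        if consume_trigger(profile_dir):
            # A real loop would start a profiling step here.
            profiled.append(step)
    return profiled


with tempfile.TemporaryDirectory() as tmp:
    profiled_steps = run_steps(Path(tmp), num_steps=5, request_at=2)
print(profiled_steps)
```

Polling for a file keeps the trigger independent of the training process: any shell with access to the profiling directory can request a profile without signals or extra ports.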
Removed
- Removed support for PyTorch versions older than 2.2.0