[Feature]: Memory Efficient Flash Attention for gfx1100 (7900xtx) #16
Comments
We are already working on efficient attention. However, since AOTriton uses Triton as its compiler, the actual landing of gfx1100 support for efficient attention may take longer.
Currently AOTriton is only compiled for the MI200/MI300 series datacenter GPUs (commonly known as the CDNA2/CDNA3 architectures). We are going to add Navi targets once the Triton compiler supports them.
Unfortunately this does not help, because Triton lacks proper Navi3x support, mainly WMMA compiler support. We are actively working on it and you'll see it land as soon as it is supported.
Thank you for the detailed update! I'll be following closely as I'm super interested in maximizing the utility of the 7900s and seeing more options on the market to compete against Nvidia.
It seems gfx1100 Triton support has been upstreamed, according to this issue comment: ROCm/triton#250 (comment). It also shows that the Triton flash-attention performance is not good.
I was just going to mention that I saw an issue saying compiler support for RDNA3 is available in Triton, but I can't find much info beyond that yet: triton-lang/triton#3704. In that case, though, the reporter was on an older-generation card, not a 7900 XTX.
RDNA support is a more complicated topic because it requires an upgrade to the Triton compiler, which causes quite a few compatibility problems.
cc: @jayfurmanek what is the status of Navi support on both repos?
Any update on flash attention being in released builds of AOTriton? Would love to give this a whirl soon.
We are doing experiments on Navi and fixing Triton compiler problems now (see ROCm/triton#596 for more details). A newer compiler is necessary to support Navi.
I used upstream Triton, and its Triton flash-attention implementation is not as fast as expected. The result is:
I got the following numbers by running
I also managed to use AOTriton's Flash Attention implementation with upstream Triton and got it working in SD.Next. The performance is slightly lower (from 10 it/s to 8 it/s), but it uses less VRAM (I can upscale 512x768 images to 2.5x without OOM, up from 1.6x, with VAE tiling disabled).
I wrote a Flash Attention 2 implementation using the rocWMMA library:
Amazing work!!
I packaged it as an extension and tested that it works in WSL:
Thanks a ton! This works well. Had to install Ninja.
Does this bring us anywhere closer to getting the Flash Attention forward and backward passes working on Radeon 7000 GPUs (gfx1100) with PyTorch, or is this implementation only for Stable Diffusion?
If you mean using it in the transformers library, I believe it's still quite far from the official flash-attn v2; we would have to implement all of flash_attn.flash_attn_interface and other utilities like flash_attn.bert_padding.
For anyone wondering, there's also a CK-based version for Navi3x (ROCm/flash-attention, howiejay/navi_support branch) described here: ROCm/flash-attention#27 (comment). It's fast, but it's also FA version 2.0.6 and only the forward pass works.
The howiejay CK build is a lot faster in the forward pass than Repeerc's WMMA version, going by his own numbers. By monkey-patching the torch SDPA function I can hit something like 3.8 it/s on a low-power XTX for 1024² SDXL. Repeerc's version, at a self-reported 3.5 it/s, is closer to what I get from the DaoLab Triton JIT version after adding a custom autotune config. All this, and the base 3090 is still a good deal faster from what I can tell.
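For illustration, "monkey-patching the torch SDPA function" can look roughly like the sketch below. It assumes the CK-based flash_attn build (howiejay/navi_support branch) is installed and exposes the usual flash_attn_func API; since that branch only implements the forward pass, this is only useful for inference, and the exact signature should be treated as an assumption rather than the patch used above.

```python
# Minimal sketch: route torch SDPA calls to a CK/FA2-style flash_attn_func (forward only).
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func  # assumed public API of the installed branch

_original_sdpa = F.scaled_dot_product_attention

def sdpa_flash(query, key, value, attn_mask=None, dropout_p=0.0,
               is_causal=False, scale=None):
    # Fall back to the stock implementation for cases flash_attn cannot handle
    # (explicit masks, fp32 inputs, etc.).
    if attn_mask is not None or query.dtype not in (torch.float16, torch.bfloat16):
        return _original_sdpa(query, key, value, attn_mask=attn_mask,
                              dropout_p=dropout_p, is_causal=is_causal, scale=scale)
    # SDPA uses (batch, heads, seq, dim); flash_attn expects (batch, seq, heads, dim).
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))
    out = flash_attn_func(q, k, v, dropout_p=dropout_p,
                          softmax_scale=scale, causal=is_causal)
    return out.transpose(1, 2)

F.scaled_dot_product_attention = sdpa_flash
```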
The backward pass is what interests me about Repeerc's version, since training uses more memory and torch SDPA does not have a memory-efficient fallback. Unfortunately it still seems to have stability issues.
## What's Changed
1. A whole new tuning system (referred to as `cpptune`/`cpp_tune`/`cpptuning`), based on pre-compiling all GPU kernels with the CMake option `AOTRITON_BUILD_FOR_TUNING` and on kernel selection parameters provided by all AOTriton APIs
2. GPU kernel compilation can time out (the default limit is 8 minutes), to avoid excessively long Navi31 kernel builds
3. Migrate the backward kernel away from block pointers
4. Improved backward kernel performance by using a better tuning database generated from cpptune
5. Add Navi31 to the tuning database
6. Enable Navi31 by default
7. Default to `AOTRITON_COMPRESS_KERNEL=ON`, which consequently requires zstd as a runtime dependency
8. Use `pkg-config` to search for zstd, since `find_package(zstd)` is not officially supported

## Known problems
1. No official Navi32 support. Users may want to duplicate the Navi31 tuning database entries to get Navi32 support in AOTriton.

This fixes #16
Do we have any of the FA implementations working for training (such as LoRA training via Kohya)?
@sancspro Wait for PyTorch to update to aotriton v0.7b. The PR is pytorch/pytorch#134498
Ok, thanks for the info @feffy380. So after this PR is merged, I just need to update PyTorch and then enable SDPA in the cross-attention setting in Kohya to make use of FA?
PyTorch nightly is out: torch 2.5.0.dev20240912+rocm6.2. SDP flash attention works out of the box. No tweaking, special configuration, etc. Just enable SDP and use the env var: 7800XT unleashed! :)
Which branch of flash_attention do you install after the nightly update?
Hi. I didn't use any external FA. For example, for Auto1111 I enabled SDP and saw a VRAM reduction immediately. I think they've integrated AOTriton into PyTorch, which works out of the box for some GPUs, and now they've added RDNA3 to the list of supported GPUs. Note that I also set an env variable, which I mentioned above.
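For anyone who wants to verify that the nightly really dispatches to flash attention rather than silently falling back to the math path, a quick sanity check could look like the sketch below. It assumes a recent PyTorch where torch.nn.attention.sdpa_kernel is available; the tensor shapes are purely illustrative.

```python
# Sanity check: force the flash SDPA backend and see whether the call succeeds.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

print("flash SDP enabled:", torch.backends.cuda.flash_sdp_enabled())

# Restricting SDPA to the flash backend raises a RuntimeError if flash attention
# cannot be used for these inputs on this GPU, instead of silently falling back.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)
```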
Hello @sancspro, could I ask you for some numbers with and without FA? I tested it with ComfyUI and saw a very small VRAM reduction in SD 1.5 [512x512] (about 2 GB), and in SDXL [1024x1024] I saw no memory savings at all, only a speed increase.
Yes, this experimental function works. I tested it in a pure matrix test.py, setting the os env variable in Python as you suggested, and it worked well on my 7900! But how can you enable SDP in Automatic1111? I don't know where to set the env var in Automatic yet; I tried to simply enable SDP in the Automatic1111 settings, and it always came back with a failure notice. My environment: Ubuntu 24.04 + ROCm 6.2 + Torch 2.5 + Python 3.10 + Automatic1111 v1.10.1.
OK, I managed to set the env var in launch.py directly, and now SDP works. But the effect here is NOT more speed; it saves more VRAM. It's about 50% slower than Doggettx, but uses only 1/3 of the VRAM. Not sure whether that's how it should be, or whether I haven't set something up right...
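For reference, setting the variable inside launch.py amounts to exporting it before anything imports torch. A minimal sketch follows; the thread never spells out the variable name, so TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL below is an assumption based on PyTorch's ROCm SDPA code path, not a confirmed value from this discussion.

```python
# launch.py (near the top, before torch is imported anywhere)
import os

# Assumed variable name; on ROCm builds of PyTorch this flag opts experimental
# GPU architectures (such as gfx1100) into the AOTriton-backed SDPA kernels.
os.environ.setdefault("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1")
```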
This is a known problem, which actually motivates the following decisions:
Got it, so SDP should be faster but isn't there yet. Hopefully this can be solved in the next iteration of ROCm/PyTorch and this experimental feature becomes official.
@sancspro Grad_norm goes crazy when training; the model doesn't converge.
Suggestion Description
Started using torchlearn to train models in PyTorch with my gfx1100 card, but I get a warning that torch was not compiled with memory-efficient flash attention.
I see there is a recently merged patch, pending nightly, that adds ROCm flash attention support to PyTorch, but it looks like it only targets the MI200 and MI300 cards. Any plans to support consumer/workstation cards? Can I compile this in myself today?
torchlearn is getting about 1.6 it/s, compared to the 3 it/s some people get with a 4090. While my card costs half as much as a 4090, I should be within 80-85% of its performance, and I think the lack of memory-efficient flash attention is what's holding it back.
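A quick way to see which SDPA backends a given torch build can actually run on the card (and so whether the memory-efficient or flash paths behind that warning are available) is to force each backend in turn, as in the sketch below. It assumes a PyTorch version with torch.nn.attention.sdpa_kernel; shapes are arbitrary and the error text varies between versions.

```python
# Probe which scaled_dot_product_attention backends work on this GPU/build.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 512, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

for backend in (SDPBackend.FLASH_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.MATH):
    try:
        # Restricting SDPA to a single backend raises if that backend
        # cannot handle these inputs on this hardware.
        with sdpa_kernel(backend):
            torch.nn.functional.scaled_dot_product_attention(q, k, v)
        print(backend, "available")
    except RuntimeError as err:
        print(backend, "unavailable:", err)
```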
Operating System
Ubuntu 22.04
GPU
gfx1100
ROCm Component
No response