Remove AMD workarounds as they do more harm than good on recent ROCm releases. #8289
base: master
Conversation
PyTorch ROCm 6.4 is still nightly, so I can't merge some of these changes as is, or it's going to break people using the stable PyTorch, which is on 6.3.
Ok, fair, it can wait. I was using the PyTorch wheels from AMD (https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4.1/) to validate those changes.
It causes crashes for big sizes even without PyTorch attention, and for reasonable sizes it is significantly faster. This reverts commit 1cd6cd6.
It is significantly faster: 8 it/s vs 12 it/s on a 9070 XT with ROCm 6.4.1.
It works fine on ROCm 6.4.1. It is also faster and avoids OOMs in VAE Decode. Fixes: comfyanonymous#7400 Fixes: ROCm/ROCm#4742
I've rebased. I think we should merge this. We could add some version checks, but it's not clear to me on which versions it is broken. (Also, upstream ROCm 6.4 will probably only land in the next PyTorch release, which can take months.) The main issue currently is that those …
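For illustration, a version gate could look roughly like this. This is a sketch, not the PR's actual code: `torch.version.hip` is the ROCm/HIP version string on ROCm builds of PyTorch and `None` otherwise, and the 6.4 cutoff is an assumption based on this thread, since nobody has pinned down exactly which versions are broken.

```python
import torch

def rocm_version():
    # torch.version.hip looks like "6.4.xxxxx-..." on ROCm builds
    # of PyTorch and is None on CUDA/CPU builds.
    hip = getattr(torch.version, "hip", None)
    if hip is None:
        return None
    major, minor = hip.split(".")[:2]
    return (int(major), int(minor))

def amd_workarounds_needed():
    # Keep the old workarounds only on ROCm builds older than 6.4
    # (assumed cutoff; adjust once the broken versions are known).
    ver = rocm_version()
    return ver is not None and ver < (6, 4)
```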
IMO, it is much safer if we have a flag to enable the new is_amd checking logic until we actually benchmark and see that there are no performance regressions on older AMD cards.
It's already possible to force the VAE type, so there is no need for hardcoded is_amd cases. Also, the workaround was reported to cause OOMs on both new and old GPUs.
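As a sketch of what that flag-driven approach looks like, here is a minimal example of selecting the VAE dtype from command-line flags instead of hardcoded is_amd checks. The flag names mirror ComfyUI's existing --fp32-vae / --bf16-vae options; the selection logic itself is an illustration, not the PR's code.

```python
import argparse
import torch

# Let the user force the VAE dtype explicitly; no device-vendor
# special-casing is involved in the decision.
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--fp32-vae", action="store_true")
group.add_argument("--bf16-vae", action="store_true")
args = parser.parse_args()

def vae_dtype():
    if args.fp32_vae:
        return torch.float32
    if args.bf16_vae:
        return torch.bfloat16
    # Default; a real implementation would pick this per device.
    return torch.float16
```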
@kasper93 sorry to bother you. I tried your fork with this commit and the massive memory usage during VAE Decode persists. Also, adding the following flag: …
Well, yes. It still uses lots of memory, but it's less than before, and previously the fallback to tiled VAE would take ages because of the conversions; this seems snappier too.
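For context, the fallback pattern under discussion is roughly the following. This is an illustration, not ComfyUI's actual code, and `decode_tiled` stands in for the tiled decode path (ComfyUI exposes it as the "VAE Decode (Tiled)" node):

```python
import torch

def decode_with_fallback(vae, latent):
    # Try the regular full-resolution decode first.
    try:
        return vae.decode(latent)
    except torch.cuda.OutOfMemoryError:
        # Free the failed allocation, then retry with the slower
        # tiled decode (hypothetical decode_tiled helper).
        torch.cuda.empty_cache()
        return vae.decode_tiled(latent)
```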
Maybe, but it's a lot faster. The main goal of this PR is to remove the hardcoded conditions and make it possible to evaluate the other components in the pipeline.
Thanks for the fast reply and the explanations. I forgot to mention that I haven't noticed any problems using BF16 so far, if that helps.
Tested on gfx1201, ROCm 6.4.1. It fixes the VAE Decode issues, and performance is generally better with PyTorch flash attention. bfloat16 is working fine.
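A quick way to sanity-check that bf16 flash attention actually works on a given ROCm build (assuming PyTorch >= 2.3, which provides torch.nn.attention.sdpa_kernel):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Run scaled_dot_product_attention in bf16 with only the flash
# backend allowed; this raises a RuntimeError if no flash kernel
# is available for this dtype/device combination.
q = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, q, q)
print(out.shape, out.dtype)  # torch.Size([1, 8, 128, 64]) torch.bfloat16
```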