Flux VAE broken for float16, force bfloat16 or float32 where compatible #7213
base: main
Conversation
lgtm
Force-pushed from 81e3001 to 496b02a
```python
# VAE is broken in float16, use same logic in model loading to pick bfloat16 or float32
if x.dtype == torch.float16:
    try:
        x = x.to(torch.bfloat16)
    except TypeError:
        x = x.to(torch.float32)
```
The `AutoEncoder` class should not be responsible for casting input dtypes - it should simply raise if an incompatible dtype is passed, which it already does.
The casting should happen in the caller (`flux_vae_encode.py` / `flux_vae_decode.py`). And rather than mirroring the model loading logic, we should simply cast the input to match the dtype of the model.
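A minimal sketch of that caller-side approach, assuming `vae` is the loaded `AutoEncoder` module and that it exposes a `decode` method; the names are illustrative, not the actual node code:

```python
import torch

def decode_latents(vae: torch.nn.Module, latents: torch.Tensor) -> torch.Tensor:
    # Cast the input to whatever dtype the model was loaded in, instead of
    # re-deriving the dtype from config or model-loading logic.
    vae_dtype = next(vae.parameters()).dtype
    with torch.no_grad():
        return vae.decode(latents.to(dtype=vae_dtype))
```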
Is there a sensible way to get the VAE dtype? The original code just grabs the default / config dtype rather than inspecting the model itself, and I can't see a sensible attribute on `AutoEncoder`.
At the moment I'm looking at crap code like
`vae_dtype = next(iter(vae.state_dict().items()))[1].dtype`
or picking a specific state_dict key and reading its value; I'm assuming just grabbing the first entry is the more efficient option for now.
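For a plain `torch.nn.Module`, the usual idiom is to read the dtype from the first parameter rather than from the state dict; a minimal sketch, assuming `vae` has at least one parameter:

```python
import torch

def module_dtype(module: torch.nn.Module) -> torch.dtype:
    # parameters() is a lazy generator, so this only touches the first
    # parameter and avoids materialising the whole state_dict.
    return next(module.parameters()).dtype
```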
Added a bool to allow the node user to add noise to the initial latents (the default) or to leave them alone.
Summary
The Flux VAE, like many VAEs, is broken when run with float16 inputs, returning black images due to NaNs.
This PR fixes the issue by forcing the VAE to run in bfloat16 or float32, whichever is compatible.
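As a rough illustration of the "bfloat16 where compatible, otherwise float32" selection, here is a hedged sketch using `torch.cuda.is_bf16_supported()` for CUDA; other backends would need their own checks, and this is not the exact logic used by the model loader:

```python
import torch

def vae_working_dtype(device: torch.device) -> torch.dtype:
    # float16 is avoided entirely: the Flux VAE overflows to NaN in float16
    # and produces black images. Prefer bfloat16 where the device supports it.
    if device.type == "cuda" and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    if device.type == "cpu":
        return torch.bfloat16  # bfloat16 is available on CPU in recent PyTorch
    return torch.float32  # conservative fallback for other/unknown backends
```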
Related Issues / Discussions
Fix for issue #7208
QA Instructions
Tested on macOS: the VAE works both with float16 set in invoke.yaml and with the setting left at its default.
I also briefly forced it down the float32 route to check that too.
Needs testing on CUDA / ROCm
Merge Plan
It should be a straightforward merge.