Flux VAE broken for float16; force bfloat16 or float32 where compatible #7213

Open

Vargol wants to merge 11 commits into main
Conversation

@Vargol (Contributor) commented Oct 27, 2024

Summary

The Flux VAE, like many VAEs, breaks when run with float16 inputs, returning black images due to NaNs.
This PR fixes the issue by forcing the VAE to run in bfloat16 where supported, falling back to float32 otherwise.
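For context, the selection mirrors what the model loader does: try bfloat16 and fall back to float32 where the backend cannot handle it. A minimal sketch of that logic (the helper name and probe tensor are illustrative, not the PR's actual code):

```python
import torch

def pick_vae_working_dtype(device: torch.device) -> torch.dtype:
    # Prefer bfloat16; fall back to float32 on backends that cannot
    # represent it (such backends raise TypeError on the bfloat16 cast).
    try:
        torch.zeros(1, device=device, dtype=torch.bfloat16)
        return torch.bfloat16
    except TypeError:
        return torch.float32
```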

Related Issues / Discussions

Fix for issue #7208

QA Instructions

Tested on macOS: the VAE works with float16 set in invoke.yaml and with the setting left at its default.
I also briefly forced it down the float32 route to check that too.
Needs testing on CUDA / ROCm.

Merge Plan

It should be a straightforward merge.

github-actions bot added the python and backend labels on Oct 27, 2024
@brandonrising (Collaborator) left a comment


lgtm

Comment on lines 315 to 320
# VAE is broken in float16, use same logic in model loading to pick bfloat16 or float32
if x.dtype == torch.float16:
    try:
        x = x.to(torch.bfloat16)
    except TypeError:
        x = x.to(torch.float32)
Collaborator
The AutoEncoder class should not be responsible for casting input dtypes - it should simply raise if an incompatible dtype is passed, which it already does.

The casting should happen in the caller (flux_vae_encode.py / flux_vae_decode.py). And rather than mirroring the model loading logic, we should simply cast the input to match the dtype of the model.
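A minimal sketch of that suggestion (hypothetical variable names; assuming `vae` is the loaded AutoEncoder and `latents` is the input tensor in flux_vae_decode.py):

```python
# Cast the input to whatever dtype the loaded model actually holds,
# rather than re-deriving the dtype from config or loader logic.
vae_dtype = next(vae.parameters()).dtype
latents = latents.to(dtype=vae_dtype)
img = vae.decode(latents)
```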

Contributor Author

Is there a sensible way to get the VAE's dtype? The original code just grabs the default / config dtype rather than the model's, and I can't see a sensible variable in AutoEncoder.

At the moment I'm looking at crap code like

vae_dtype = next(iter(vae.state_dict().items()))[1].dtype

or picking a specific state_dict key and getting its value. I'm assuming just grabbing the first value is more efficient ATM.
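For reference, if AutoEncoder is a standard torch.nn.Module, the dtype can be read without materializing the state_dict at all (a sketch, not the PR's final code):

```python
# parameters() is a lazy generator over the module's registered tensors,
# so this reads a single entry instead of building the full state_dict.
vae_dtype = next(vae.parameters()).dtype
```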

invokeai/backend/flux/modules/autoencoder.py (outdated, resolved)
invokeai/backend/model_manager/load/model_loaders/flux.py (outdated, resolved)
invokeai/backend/model_manager/load/model_loaders/flux.py (outdated, resolved)
github-actions bot added the invocations and frontend labels on Nov 7, 2024
Labels
backend, frontend, invocations, python
5 participants