[...continuation of #9177]
PyTorch has had support for `float8_e4m3fn` and `float8_e5m2` as storage dtypes for a while now. This allows one to store model weights in a lower-precision dtype and upcast them on-the-fly when a layer is needed to proceed with computation.

Code
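A minimal sketch of the idea using standard PyTorch forward hooks; the helper name `apply_layerwise_upcasting` and the toy model are illustrative, not the exact API added in this PR:

```python
# Minimal sketch: keep weights in a float8 storage dtype while idle, upcast
# them just-in-time for each forward pass, and downcast right afterwards.
import torch
import torch.nn as nn

def apply_layerwise_upcasting(module: nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> nn.Module:
    def upcast_before_forward(mod, args):
        # Upcast weights so the layer computes in compute_dtype.
        mod.to(compute_dtype)

    def downcast_after_forward(mod, args, output):
        # Return the weights to the low-precision storage dtype immediately.
        mod.to(storage_dtype)
        return output

    module.register_forward_pre_hook(upcast_before_forward)
    module.register_forward_hook(downcast_after_forward)
    module.to(storage_dtype)  # weights occupy float8-sized memory while idle
    return module

# Usage on a toy model; in practice this would be a diffusers transformer/UNet.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
apply_layerwise_upcasting(model, torch.float8_e4m3fn, torch.bfloat16)
with torch.no_grad():
    out = model(torch.randn(1, 16, dtype=torch.bfloat16))
```

The same pair of hooks can be registered on the whole model or on individual submodules, which is where the granularity discussion below comes from.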
Flux visual results
CogVideoX visual results
cogvideox-1.0---storage_dtype-bfloat16---compute_dtype-bfloat16---granularity-none---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_model---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-diffusers_model---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
Assumptions made so far:

- The weights are kept in `storage_dtype` and upcast on-the-fly to `compute_dtype` only while a layer runs its forward pass.

**Why is there no memory savings in the initial load memory?**
We are first moving the weights to VRAM and then performing the lower-dtype casting. We should maybe look into directly allowing loading of weights in the lower dtype.
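For illustration, a rough sketch of the two load orders (assuming a CUDA device is available); the second order is the potential improvement mentioned above, not something implemented here:

```python
import torch

layer = torch.nn.Linear(4096, 4096)

# Current order: full-precision weights hit VRAM first, so the initial load
# peak is the fp32/bf16 footprint, and only afterwards do we cast down.
current = layer.to("cuda").to(torch.float8_e4m3fn)

# Possible improvement: cast (or load the checkpoint directly) in the lower
# dtype before the device transfer, so only the float8 copy ever occupies VRAM.
layer2 = torch.nn.Linear(4096, 4096)
improved = layer2.to(torch.float8_e4m3fn).to("cuda")
```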
**Why different "granularities"?**
This was mostly an experiment, and we don't need to use everything from it in the PR. I wanted to understand the effect of typecasting all weights vs. some of them vs. only the PyTorch primitives. As usual, image models seem to be less affected by normalization casting (from the `DIFFUSERS_MODEL` granularity) compared to video models. However, the more granular we try to go, the more times weights are cast per inference step and the more synchronizations are introduced with the current implementation, leading to slowdowns in inference time. Allowing different levels at which the typecasting hooks are applied is akin to what we have for model CPU offloading vs. sequential CPU offloading, and allows for some tradeoffs that users can choose based on their use cases.
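A rough sketch of the coarser vs. finer granularity, reusing the hypothetical `apply_layerwise_upcasting` helper from the sketch above (the function names and leaf-module heuristic are illustrative, not the PR's enum values):

```python
import torch.nn as nn

def apply_model_granularity(model: nn.Module, storage_dtype, compute_dtype):
    # One pair of hooks on the root module: a single up/downcast per forward pass.
    apply_layerwise_upcasting(model, storage_dtype, compute_dtype)

def apply_layer_granularity(model: nn.Module, storage_dtype, compute_dtype):
    # Hooks on every leaf module: only one layer is upcast at a time, but the
    # weights are cast many more times per inference step, which can slow things down.
    for submodule in model.modules():
        if len(list(submodule.children())) == 0:
            apply_layerwise_upcasting(submodule, storage_dtype, compute_dtype)
```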
**Is this compatible with `torch.compile`?**

No, it isn't, because we overwrite the forward method of the underlying models to invoke a pre-hook and a post-hook. Both the pre- and post-hook change the state of the underlying model (downcast or upcast it) on every forward pass, which makes it incompatible, as it does not fit the rules of `torch.compile`. Using `@torch._dynamo.disable(recursive=False)` or similar does not seem to work.

**Why a different approach from #9177?**
While providing the API to use this via `ModelMixin` is okay, it puts a restriction that requires all implementations to derive from it in order to use it. As this method can be generally applied to any modeling component, at any level of granularity, implementing it independently of `ModelMixin` allows its use in other modeling components like text encoders, which come from transformers, and any downstream research work or library can use it directly for their demos on Spaces without having to reinvent the wheel. Not opposed to the idea of having `enable_layerwise_upcasting` in `ModelMixin`, but let's do it in a way that does not impose any restrictions on how it can be used.

Also, the original PR typecast all leaf nodes to the storage dtype, but this may not be ideal for things like normalization and modulation layers, so supporting parameters like `skip_modules_pattern` and `skip_modules_classes` helps ignore a few layers (see the sketch below). We can default to sensible values, while maintaining another parameter per class for layers that should not be upcast/downcast. This is also one of the places where it helps to follow a common naming convention across all our models.

Fixes #9949
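A sketch of how such skip parameters could filter which leaf modules get the casting hooks; the regex-on-module-name matching and the default values are assumptions for illustration, not the final API:

```python
# Builds on the hypothetical apply_layerwise_upcasting helper sketched earlier.
import re
import torch.nn as nn

def apply_layer_granularity_with_skips(
    model: nn.Module,
    storage_dtype,
    compute_dtype,
    skip_modules_pattern=("norm", "modulation"),
    skip_modules_classes=(nn.LayerNorm, nn.GroupNorm),
):
    for name, submodule in model.named_modules():
        if len(list(submodule.children())) > 0:
            continue  # only consider leaf modules
        if isinstance(submodule, skip_modules_classes):
            continue  # keep e.g. normalization layers in the compute dtype
        if any(re.search(pattern, name) for pattern in skip_modules_pattern):
            continue  # skip by naming convention, hence the note about common names
        apply_layerwise_upcasting(submodule, storage_dtype, compute_dtype)
```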
cc @vladmandic @asomoza
TODOs:

- `non_blocking` and CUDA streams for overlapping weight casting with computation, without introducing many stream synchronizations on the default stream
- Module tensors can be `Tensor`, `LongTensor`, `BoolTensor`, etc., and we should not typecast all of them to `compute_dtype`, which would be incorrect

Nice reading material for the interested: