Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Layerwise Upcasting #10347

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open

[core] Layerwise Upcasting #10347

wants to merge 7 commits into from

Conversation

a-r-r-o-w
Copy link
Member

@a-r-r-o-w a-r-r-o-w commented Dec 23, 2024

[...continuation of #9177]

Pytorch has had support for float8_e4m3fn and float8_e5m2 as storage dtypes for a while now. This allows one to store model weights in a lower precision dtype and upcast them on-the-fly when a layer is required for proceeding with computation.

Code
import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
from diffusers import AllegroPipeline, CogVideoXPipeline, CogVideoXImageToVideoPipeline, LattePipeline, FluxPipeline
from diffusers.models.hooks import LayerwiseUpcastingGranualarity, apply_layerwise_upcasting
from diffusers.utils import export_to_video, load_image
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate


repo = git.Repo(path="/home/aryan/work/diffusers")
branch = repo.active_branch


def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output


def prepare_flux(dtype: torch.dtype, compile: bool = False, **kwargs) -> None:
    model_id = "black-forest-labs/Flux.1-Dev"
    cache_dir = "/raid/.cache/huggingface"

    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 768,
        "width": 768,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_cogvideox_1_0(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "THUDM/CogVideoX-5b"
    cache_dir = None

    pipe = CogVideoXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": (
            "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
            "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
            "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
            "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
            "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
            "atmosphere of this unique musical performance."
        ),
        "height": 480,
        "width": 720,
        "num_frames": 49,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_latte(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "maxin-cn/Latte-1"
    cache_dir = None

    pipe = LattePipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "a cat wearing sunglasses and working as a lifeguard at pool.",
        "height": 512,
        "width": 512,
        "video_length": 16,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_allegro(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "rhymes-ai/Allegro"
    cache_dir = None

    pipe = AllegroPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this location might be a popular spot for docking fishing boats.",
        "height": 720,
        "width": 1280,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


def decode_cogvideox_1_0(pipe: CogVideoXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_latte(pipe: LattePipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents, video_length=kwargs["video_length"])
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_allegro(pipe: AllegroPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


MODEL_MAPPING = {
    "flux": {
        "prepare": prepare_flux,
        "decode": decode_flux,
    },
    "cogvideox-1.0": {
        "prepare": prepare_cogvideox_1_0,
        "decode": decode_cogvideox_1_0,
    },
    "latte": {
        "prepare": prepare_latte,
        "decode": decode_latte,
    },
    "allegro": {
        "prepare": prepare_allegro,
        "decode": decode_allegro,
    },
}

STR_TO_DTYPE = {
    "float8_e4m3fn": torch.float8_e4m3fn,
    "float8_e5m2": torch.float8_e5m2,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator("cuda").manual_seed(181201)
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


@torch.no_grad()
def main(model_id: str, granularity: str, output_dir: str, storage_dtype: str, compute_dtype: str, compile: bool = False):
    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    pytorch_storage_dtype = STR_TO_DTYPE[storage_dtype]
    pytorch_compute_dtype = STR_TO_DTYPE[compute_dtype]
    model = MODEL_MAPPING[model_id]

    try:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=pytorch_compute_dtype, compile=compile)

        initial_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        initial_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 2. Apply layerwise upcasting technique
        if granularity == "diffusers_model":
            apply_layerwise_upcasting(
                pipe.transformer,
                storage_dtype=pytorch_storage_dtype,
                compute_dtype=pytorch_compute_dtype,
                granularity=LayerwiseUpcastingGranualarity.DIFFUSERS_MODEL,
            )
        elif granularity == "diffusers_layer":
            apply_layerwise_upcasting(
                pipe.transformer,
                storage_dtype=pytorch_storage_dtype,
                compute_dtype=pytorch_compute_dtype,
                granularity=LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER,
                skip_modules_pattern=["pos_embed", "patch_embed", "norm"],
            )
        elif granularity == "pytorch_layer":
            apply_layerwise_upcasting(
                pipe.transformer,
                storage_dtype=pytorch_storage_dtype,
                compute_dtype=pytorch_compute_dtype,
                granularity=LayerwiseUpcastingGranualarity.PYTORCH_LAYER,
                skip_modules_pattern=["pos_embed", "patch_embed", "norm"],
            )
        elif granularity == "none":
            pass
        else:
            raise ValueError(f"Invalid {granularity=} provided.")
        
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()
        # We reset the peak memory stats to get the memory usage of the model after layerwise upcasting
        torch.cuda.reset_peak_memory_stats()

        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        model_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 3. Warmup
        num_warmups = 1
        original_num_inference_steps = generation_kwargs["num_inference_steps"]
        generation_kwargs["num_inference_steps"] = 2
        for _ in range(num_warmups):
            run_inference(pipe, generation_kwargs)
        generation_kwargs["num_inference_steps"] = original_num_inference_steps

        # 4. Benchmark
        time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 5. Decode latents
        filename = output_dir / f"{model_id}---storage_dtype-{storage_dtype}---compute_dtype-{compute_dtype}---granularity-{granularity}---compile-{compile}"
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            video_length=generation_kwargs.get("video_length", None),
        )

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "granularity": granularity,
            "storage_dtype": storage_dtype,
            "compute_dtype": compute_dtype,
            "compile": compile,
            "time": time,
            "initial_memory": initial_memory,
            "initial_max_memory_reserved": initial_max_memory_reserved,
            "model_memory_upcasted": model_memory,
            "model_max_memory_reserved_upcasted": model_max_memory_reserved,
            "inference_max_memory_reserved": inference_max_memory_reserved,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }

    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "granularity": granularity,
            "storage_dtype": storage_dtype,
            "compute_dtype": compute_dtype,
            "compile": compile,
            "time": None,
            "initial_memory": None,
            "initial_max_memory_reserved": None,
            "model_memory_upcasted": None,
            "model_max_memory_reserved_upcasted": None,
            "inference_max_memory_reserved": None,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux", "cogvideox-1.0", "latte", "allegro"],
        help="Model to run benchmark for.",
    )
    parser.add_argument(
        "--granularity",
        type=str,
        default="diffusers_model",
        choices=["diffusers_model", "diffusers_layer", "pytorch_layer", "none"],
        help="Cache method to use.",
    )
    parser.add_argument(
        "--output_dir", type=str, help="Path where the benchmark artifacts and outputs are the be saved."
    )
    parser.add_argument("--storage_dtype", type=str, choices=["float8_e4m3fn", "float8_e5m2", "bfloat16", "float16", "float32"], help="Storage torch.dtype to use for transformer")
    parser.add_argument("--compute_dtype", type=str, choices=["bfloat16", "float16", "float32"], help="Compute torch.dtype to use for transformer")
    parser.add_argument(
        "--compile",
        action="store_true",
        default=False,
        help="Whether to torch.compile the denoiser.",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(args.model_id, args.granularity, args.output_dir, args.storage_dtype, args.compute_dtype, args.compile)
model_id granularity storage_dtype compute_dtype time initial_memory initial_max_memory_reserved model_memory_upcasted model_max_memory_reserved_upcasted inference_max_memory_reserved compile
flux none bfloat16 bfloat16 16.939 31.438 31.447 31.438 31.447 32.02 False
flux diffusers_model float8_e4m3fn bfloat16 23.866 31.438 31.447 20.354 21.178 33.963 False
flux diffusers_layer float8_e4m3fn bfloat16 18.125 31.438 31.447 28.736 28.779 29.291 False
flux pytorch_layer float8_e4m3fn bfloat16 20.339 31.438 31.447 23.378 24.449 24.945 False
flux diffusers_model float8_e5m2 bfloat16 22.097 31.438 31.447 20.353 21.18 33.949 False
flux diffusers_layer float8_e5m2 bfloat16 18.013 31.438 31.447 28.735 28.797 29.309 False
flux pytorch_layer float8_e5m2 bfloat16 20.084 31.44 31.451 23.38 24.451 24.947 False
cogvideox-1.0 none bfloat16 bfloat16 244.255 19.661 19.678 19.661 19.678 24.426 False
cogvideox-1.0 diffusers_model float8_e4m3fn bfloat16 243.65 19.661 19.678 14.473 14.531 25.217 False
cogvideox-1.0 diffusers_layer float8_e4m3fn bfloat16 243.541 19.66 19.678 16.705 16.76 21.469 False
cogvideox-1.0 pytorch_layer float8_e4m3fn bfloat16 243.346 19.661 19.678 15.228 15.281 19.992 False
cogvideox-1.0 diffusers_model float8_e5m2 bfloat16 243.899 19.661 19.678 14.473 14.531 25.217 False
cogvideox-1.0 diffusers_layer float8_e5m2 bfloat16 243.182 19.661 19.678 16.705 16.76 21.469 False
cogvideox-1.0 pytorch_layer float8_e5m2 bfloat16 243.136 19.661 19.678 15.228 15.281 19.992 False
Flux visual results
Baseline
diffusers_model-float8_e4m3 diffusers_model-float8_e5m2
diffusers_layer-float8_e4m3 diffusers_layer-float8_e5m2
pytorch_layer-float8_e4m3 pytorch_layer-float8_e5m2
CogVideoX visual results
Baseline
cogvideox-1.0---storage_dtype-bfloat16---compute_dtype-bfloat16---granularity-none---compile-False.mp4
diffusers_model-float8_e4m3 diffusers_model-float8_e5m2
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_model---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-diffusers_model---compile-False.mp4
diffusers_layer-float8_e4m3 diffusers_layer-float8_e5m2
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
pytorch_layer-float8_e4m3 pytorch_layer-float8_e5m2
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4

Assumptions made so far:

  • The input to the models with a hook are not casted, and are expected to already be in compute_dtype
  • Weight casting learned parameters of normalization layers can lead to poor quality as we've seen in the past few integrations. By default, layers for normalization and modulation are not downcasted to storage_dtype.
  • Sensible default names to avoid embedding, normalization and modulation layers. This is still configurable so users can choose to typecast them if they want.

Why is there no memory savings in the initial load memory?

We are first moving weights to VRAM and then performing the lower dtype casting. We should maybe look into directly allowing loading of weights of lower dtype


Why different "granularities"?

This was mostly an experiment and we don't need to use everything in the PR. I wanted to understand the affect of typecasting all weights vs some of them vs only the pytorch primitives. As usual, image models seem to be less affected by normalization casting (from DIFFUSERS_MODEL granularity compared to video models. However, the more granular we try to go, the more times weights are casted per inference step and more synchronizations are introduced with the current implementation, leading to slow downs in inference time. Allowing different levels of applying the typecasting hooks is akin to what we have for model cpu offloading vs sequential cpu offloading, and allows for some tradeoffs that users can choose based on their use cases.


Is this compatible with torch.compile?

No, it isn't because we overwrite the forward method of underlying models to invoke a pre-hook and post-hook. Both the pre and post hook change the state of the underlying model (downcast or upcast it) per forward pass, which makes it incompatible as it does not fit with the rules of torch.compile. Using @torch._dynamo.disable(recursive=False) or similar does not seem to work.


Why a different approach from #9177?

While providing the API to use this via ModelMixin is okay, it puts a restriction that requires all implementations to derive from it to use it. As this method can be generally applied to any modeling component, at any level of granularity, implementing it independent of ModelMixin allows for its use in other modeling components like text encoders, which come from transformers, and any downstream research work or library can directly use it for their demos on Spaces without having to reimplement the wheel.

Not opposed to the idea of having enable_layerwise_upcasting in ModelMixin, but let's do it in a way that does not impose any restrictions on how it's possible to use it.

Also, the original PR typecasted all leaf nodes to storage dtype, but this may not be ideal for things like normalization and modulation, so supporting parameters like skip_modules_pattern and skip_modules_classes helps ignore a few layers. We can default to sensible values, while to maintain another parameter per class for layers to not upcast/downcast. This is also one of the places where it helps to follow a common naming convention across all our models.


Fixes #9949

cc @vladmandic @asomoza

TODOs:

  • Explore non_blocking and cuda streams for overlapping weight casting with computation without introducing many stream synchronizations on default stream
  • Try to make torch compile work
  • Figure out how to handle typecasting of inputs. Inputs can be Tensor, LongTensor, BoolTensor, etc. and we should not typecast all of them to compute_dtype, which would be incorrect
  • Test with LoRAs
  • Test with training in https://github.com/a-r-r-o-w/finetrainers
  • Test tensor caching in lower precision for methods like [core] Pyramid Attention Broadcast #9562 and [core] FasterCache #10163
  • Tests
  • Docs

Nice reading material for the interested:

@a-r-r-o-w a-r-r-o-w requested review from DN6, sayakpaul and hlky December 23, 2024 00:14
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Experimental] expose dynamic upcasting of layers as experimental APIs
2 participants