
[core] TorchAO Quantizer #10009

Merged: 39 commits, Dec 16, 2024

Conversation

Member

@a-r-r-o-w a-r-r-o-w commented Nov 25, 2024

What does this PR do?

Adds support for the TorchAO Quantizer.

Quantization formats

TorchAO supports a wide variety of quantizations. For a typical user, it can get quite overwhelming to understand all the different parameters and what they mean. To simplify this a bit, I've added some commonly used custom names that are easier to remember and use, while also supporting full configurability of the original arguments.

The naming conventions used are:

  • full function names as in torchao: int8_weight_only, float8_weight_only, etc. You can pass the arguments supported by each method (as described in the torchao docs) through quantization kwargs.
  • wo (weight-only) and dq (weight + activation quantization) suffixes
  • {dtype}_a{x}w{y}: shorthand notations for convenience, and because the a{x}w{y} notation is used extensively in the torchao docs
  • float8 quantization also supports per-tensor and per-row granularity. Per-tensor is suffixed with _tensor and per-row with _row. Per-axis and per-group granularity are also supported, but they involve additional constructor parameters, so power users are free to experiment with those; the shorthands provided here cover only tensor/row.

Broadly, int4, int8, uintx, fp8 and fpx quantizations are supported, with dynamic activation quantization where applicable and weight-only otherwise. Group sizes can be specified by power users via the full function names; we don't define special shorthand names for those. A small illustration follows.
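
To make this concrete, here's a small illustration of how the conventions above map onto TorchAoConfig (the group_size passthrough mirrors the tests in this PR; other kwargs follow the corresponding torchao function signatures):

from diffusers import TorchAoConfig

TorchAoConfig("int8wo")                 # shorthand: int8 weight-only
TorchAoConfig("int8dq")                 # shorthand: int8 weight + dynamic activation quantization
TorchAoConfig("float8dq_e4m3_row")      # float8 dynamic-activation quantization, per-row granularity
TorchAoConfig("int4_weight_only", group_size=32)  # full torchao function name, kwargs passed through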

Benchmarks

The following code is used for benchmarking:

Code
import argparse
import gc
import os
import pathlib
import traceback

# os.environ["TORCH_LOGS"] = "+dynamo,graph_breaks,recompiles"
# os.environ["TORCHDYNAMO_VERBOSE"] = "1"

import git
import pandas as pd
import torch
import torch.utils.benchmark as benchmark
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel, FluxPipeline, FluxTransformer2DModel, TorchAoConfig
from diffusers.training_utils import set_seed
from diffusers.utils import export_to_video
from tabulate import tabulate
from torchao.quantization.utils import recommended_inductor_config_setter

recommended_inductor_config_setter()

set_seed(42)

PROMPT = "A dramatic landscape on an exoplanet with a breathtaking view of a ringed gas giant in the sky. The planet's surface is rugged and alien, green and violet colored rocky lands and mountains, with strange rock formations. The rings of the reddish-yellow gas giant cast colorful shadows and reflections, creating a surreal and captivating environment."


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    
    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)
    
    return elapsed_time, output


def pretty_print_results(results, precision: int = 6):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def precompute_flux_embeds(dtype: torch.dtype, output_dir: pathlib.Path):
    model_id = "black-forest-labs/Flux.1-Dev"
    cache_dir = "/raid/.cache/huggingface"

    pipe = FluxPipeline.from_pretrained(
        model_id,
        transformer=None,
        vae=None,
        torch_dtype=dtype,
        cache_dir=cache_dir,
    )
    pipe.to("cuda")

    prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
        prompt=PROMPT,
        prompt_2=PROMPT,
        device="cuda",
        num_images_per_prompt=1,
        max_sequence_length=512,
    )

    torch.save(prompt_embeds, output_dir / "prompt_embeds.pt")
    torch.save(pooled_prompt_embeds, output_dir / "pooled_prompt_embeds.pt")


def precompute_cogvideox_embeds(dtype: torch.dtype, output_dir: pathlib.Path):
    model_id = "THUDM/CogVideoX1.5-5b"
    cache_dir = None

    pipe = CogVideoXPipeline.from_pretrained(
        model_id,
        transformer=None,
        vae=None,
        torch_dtype=dtype,
        cache_dir=cache_dir,
    )
    pipe.to("cuda")

    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt=PROMPT,
        negative_prompt=None,
        do_classifier_free_guidance=True,
        num_videos_per_prompt=1,
        max_sequence_length=226,
        device="cuda",
    )

    torch.save(prompt_embeds, output_dir / "prompt_embeds.pt")
    torch.save(negative_prompt_embeds, output_dir / "negative_prompt_embeds.pt")


def load_flux_embeds(dir: pathlib.Path):
    prompt_embeds = torch.load(dir / "prompt_embeds.pt", weights_only=True)
    pooled_prompt_embeds = torch.load(dir / "pooled_prompt_embeds.pt", weights_only=True)

    return {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
    }


def load_cogvideox_embeds(dir: pathlib.Path):
    prompt_embeds = torch.load(dir / "prompt_embeds.pt", weights_only=True)
    negative_prompt_embeds = torch.load(dir / "negative_prompt_embeds.pt", weights_only=True)

    return {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
    }


def prepare_flux(
    dtype: torch.dtype,
    quantization_config: TorchAoConfig,
    compile: bool = False,
    **kwargs,
):
    model_id = "black-forest-labs/Flux.1-Dev"
    cache_dir = "/raid/.cache/huggingface"

    transformer = FluxTransformer2DModel.from_pretrained(
        model_id,
        subfolder="transformer",
        quantization_config=quantization_config,
        cache_dir=cache_dir,
        torch_dtype=dtype,
    )
    pipe = FluxPipeline.from_pretrained(
        model_id,
        text_encoder=None,
        text_encoder_2=None,
        transformer=transformer,
        torch_dtype=dtype,
        cache_dir=cache_dir,
    )
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
    
    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)
    
    generation_kwargs = {
        "height": 768,
        "width": 768,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_cogvideox(
    dtype: torch.dtype,
    quantization_config: TorchAoConfig,
    compile: bool = False,
    **kwargs,
):
    model_id = "THUDM/CogVideoX1.5-5b"
    cache_dir = None

    transformer = CogVideoXTransformer3DModel.from_pretrained(
        model_id,
        subfolder="transformer",
        quantization_config=quantization_config,
        cache_dir=cache_dir,
        torch_dtype=dtype,
    )
    pipe = CogVideoXPipeline.from_pretrained(
        model_id,
        text_encoder=None,
        transformer=transformer,
        torch_dtype=dtype,
        cache_dir=cache_dir,
    )
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
    
    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)
    
    generation_kwargs = {
        "height": 768,
        "width": 1360,
        "num_frames": 81,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


def decode_cogvideox(pipe: CogVideoXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=16)
    return filename


def _generate_fpx_quantization_types(bits):
    types = []
    for ebits in range(0, bits):
        mbits = bits - ebits - 1
        types.append(f"fp{bits}_e{ebits}m{mbits}")
    return types


MODEL_MAPPING = {
    "flux": {
        "precompute": precompute_flux_embeds,
        "load": load_flux_embeds,
        "prepare": prepare_flux,
        "decode": decode_flux,
    },
    "cogvideox": {
        "precompute": precompute_cogvideox_embeds,
        "load": load_cogvideox_embeds,
        "prepare": prepare_cogvideox,
        "decode": decode_cogvideox,
    },
}

STR_TO_COMPUTE_DTYPE = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}

QUANTIZATION_TYPES_TO_TEST = [
    "none",
    "int4wo", "int4dq", "int8wo", "int8dq",
    "uint1wo", "uint2wo", "uint3wo", "uint4wo", "uint5wo", "uint6wo", "uint7wo", "uint8wo",
]

if TorchAoConfig._is_cuda_capability_atleast_8_9():
    QUANTIZATION_TYPES_TO_TEST.extend([
        "float8wo_e5m2", "float8wo_e4m3",
        "float8dq_e4m3",
        "float8dq_e4m3_tensor", "float8dq_e4m3_row",
        *_generate_fpx_quantization_types(3),
        *_generate_fpx_quantization_types(4),
        *_generate_fpx_quantization_types(5),
        *_generate_fpx_quantization_types(6),
        *_generate_fpx_quantization_types(7),
        *_generate_fpx_quantization_types(8),
    ])


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator().manual_seed(3047)
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


@torch.no_grad()
def main(model_id: str, output_dir: str, dtype: str, compile: bool = False):
    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")
    
    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"
    
    compute_dtype = STR_TO_COMPUTE_DTYPE[dtype]
    model = MODEL_MAPPING[model_id]

    model["precompute"](compute_dtype, output_dir)
    
    repo = git.Repo(path="/home/aryan/work/diffusers")
    branch = repo.active_branch

    for quantization_type in QUANTIZATION_TYPES_TO_TEST:
        try:
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.reset_accumulated_memory_stats()
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.synchronize()
            
            # 1. Prepare inputs and quantization config
            quantization_config = TorchAoConfig(quant_type=quantization_type) if quantization_type != "none" else None
            kwargs = model["load"](output_dir)
            pipe, generation_kwargs = model["prepare"](compute_dtype, quantization_config, compile, **kwargs)
            before_inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
            before_inference_max_memory = round(torch.cuda.max_memory_allocated() / 1024**3, 3)
            before_inference_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
            
            # 2. Warmup
            num_warmups = 1
            for _ in range(num_warmups):
                run_inference(pipe, generation_kwargs)
            
            # 3. Benchmark
            time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)

            after_inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
            after_inference_max_memory = round(torch.cuda.max_memory_allocated() / 1024**3, 3)
            after_inference_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

            # 4. Decode latent
            filename = output_dir / f"{model_id}---dtype-{dtype}---qtype-{quantization_type}---compile-{compile}"
            filename = model["decode"](pipe, latents, filename, height=generation_kwargs["height"], width=generation_kwargs["width"])

            # 5. Save artifacts
            info = {
                "model_id": model_id,
                "quantization_type": quantization_type,
                "compute_dtype": dtype,
                "compile": compile,
                "time": time,
                "before_inference_memory": before_inference_memory,
                "before_inference_max_memory": before_inference_max_memory,
                "before_inference_max_memory_reserved": before_inference_max_memory_reserved,
                "after_inference_memory": after_inference_memory,
                "after_inference_max_memory": after_inference_max_memory,
                "after_inference_max_memory_reserved": after_inference_max_memory_reserved,
                "branch": branch,
                "filename": filename,
                "exception": None,
            }
        except Exception as e:
            print(f"An error occurred: {e}")
            traceback.print_exc()
            
            # 5. Save artifacts
            info = {
                "model_id": model_id,
                "quantization_type": quantization_type,
                "compute_dtype": dtype,
                "compile": compile,
                "time": None,
                "before_inference_memory": None,
                "before_inference_max_memory": None,
                "before_inference_max_memory_reserved": None,
                "after_inference_memory": None,
                "after_inference_max_memory": None,
                "after_inference_max_memory_reserved": None,
                "branch": branch,
                "filename": None,
                "exception": str(e),
            }
        
        pretty_print_results(info, precision=3)

        df = pd.DataFrame([info])
        df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux", "cogvideox"],
        help="Model to run benchmark for.",
    )
    parser.add_argument("--output_dir", type=str, help="Path where the benchmark artifacts and outputs are the be saved.")
    parser.add_argument("--dtype", type=str, help="torch.dtype to use for inference")
    parser.add_argument(
        "--compile",
        action="store_true",
        default=False,
        help="Whether to torch.compile the denoiser.",
    )
    args = parser.parse_args()

    main(args.model_id, args.output_dir, args.dtype, args.compile)

You can launch it with something like:

#!/bin/bash

MODEL_IDS=("flux" "cogvideox")
COMPILE_OPTIONS=("" "--compile")

for compile in "${COMPILE_OPTIONS[@]}"; do
  for model_id in "${MODEL_IDS[@]}"; do
    cmd="python3 benchmark.py --model_id $model_id --output_dir torchao_benchmark_results --dtype bf16 $compile"

    echo "Running command: $cmd"
    eval $cmd
    echo -ne "-------------------- Finished executing script --------------------\n\n"
  done
done

Here are the time/memory results from a single H100:

Flux table (time in seconds, memory columns in GiB)
model_id quantization_type compute_dtype compile time before_inference_memory before_inference_max_memory before_inference_max_memory_reserved after_inference_memory after_inference_max_memory after_inference_max_memory_reserved branch filename exception
flux none bf16 False 6.856 22.364 22.364 22.377 22.365 22.633 23.018 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-none---compile-False.png
flux int4wo bf16 False 68.017 6.219 28.553 28.564 6.219 28.553 28.564 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-int4wo---compile-False.png
flux int4dq bf16 False 23.931 18.845 19.013 30.631 18.845 19.525 30.631 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-int4dq---compile-False.png
flux int8wo bf16 False 9.742 24.014 24.183 24.537 24.014 24.319 24.959 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-int8wo---compile-False.png
flux int8dq bf16 False 229.463 22.58 24.014 27.025 22.58 24.014 27.025 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-int8dq---compile-False.png
flux uint1wo bf16 False 19.064 13.974 22.58 26.158 13.974 22.58 26.158 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-uint1wo---compile-False.png
flux uint2wo bf16 False 16.781 6.647 13.974 16.592 6.647 13.974 16.592 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-uint2wo---compile-False.png
flux uint3wo bf16 False 24.523 9.529 9.695 9.711 9.529 10.124 10.631 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-uint3wo---compile-False.png
flux uint4wo bf16 False 16.599 12.345 12.512 12.527 12.345 12.942 13.395 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-uint4wo---compile-False.png
flux uint5wo bf16 False 22.05 15.115 15.281 15.299 15.115 15.712 16.219 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-uint5wo---compile-False.png
flux uint6wo bf16 False 20.503 17.93 18.097 18.18 17.93 18.526 19.1 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-uint6wo---compile-False.png
flux uint7wo bf16 False 27.758 20.85 21.019 21.035 20.85 21.447 22.061 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-uint7wo---compile-False.png
flux float8wo_e5m2 bf16 False 10.207 22.527 22.693 23.51 22.527 22.903 24.014 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-float8wo_e5m2---compile-False.png
flux float8wo_e4m3 bf16 False 10.31 22.534 22.702 26.08 22.534 22.909 26.08 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-float8wo_e4m3---compile-False.png
flux float8dq_e4m3 bf16 False 17.203 22.523 22.691 23.236 22.523 23.047 24.078 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-float8dq_e4m3---compile-False.png
flux float8dq_e4m3_tensor bf16 False 17.194 22.534 22.702 23.234 22.534 23.058 24.084 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-float8dq_e4m3_atwt---compile-False.png
flux float8dq_e4m3_row bf16 False 16.906 22.534 22.702 23.234 22.534 23.058 24.084 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-float8dq_e4m3_arwr---compile-False.png
flux fp3_e1m1 bf16 False 75.212 15.585 15.753 16.404 15.585 17.2 18.76 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp3_e1m1---compile-False.png
flux fp3_e2m0 bf16 False 76.124 8.628 15.585 20.33 8.628 15.585 20.33 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp3_e2m0---compile-False.png
flux fp4_e1m2 bf16 False 59.883 10.01 10.171 10.777 10.01 11.203 12.549 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp4_e1m2---compile-False.png
flux fp4_e2m1 bf16 False 50.86 11.401 11.561 14.119 11.401 13.016 14.393 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp4_e2m1---compile-False.png
flux fp4_e3m0 bf16 False 51.619 11.401 11.565 15.963 11.401 13.016 15.963 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp4_e3m0---compile-False.png
flux fp5_e1m3 bf16 False 93.817 12.749 12.913 13.676 12.749 13.942 15.447 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp5_e1m3---compile-False.png
flux fp5_e2m2 bf16 False 82.817 14.094 14.257 17.018 14.094 15.287 17.127 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp5_e2m2---compile-False.png
flux fp5_e3m1 bf16 False 73.498 14.09 14.252 18.697 14.09 15.705 18.697 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp5_e3m1---compile-False.png
flux fp5_e4m0 bf16 False 74.348 14.09 14.25 19.264 14.09 15.705 19.264 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp5_e4m0---compile-False.png
flux fp6_e1m4 bf16 False 86.639 15.589 15.75 17.119 15.589 16.782 18.891 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp6_e1m4---compile-False.png
flux fp6_e2m3 bf16 False 64.291 17.085 17.246 20.461 17.085 18.278 20.779 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp6_e2m3---compile-False.png
flux fp6_e3m2 bf16 False 53.167 17.083 17.246 22.35 17.083 18.277 22.35 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp6_e3m2---compile-False.png
flux fp6_e4m1 bf16 False 43.95 17.084 17.247 22.352 17.084 18.699 22.352 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp6_e4m1---compile-False.png
flux fp6_e5m0 bf16 False 44.676 17.082 17.245 22.918 17.082 18.698 22.918 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp6_e5m0---compile-False.png
flux fp7_e1m5 bf16 False 178.858 18.472 18.635 20.07 18.472 19.666 21.854 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp7_e1m5---compile-False.png
flux fp7_e2m4 bf16 False 135.323 19.853 20.018 23.424 19.853 21.047 23.424 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp7_e2m4---compile-False.png
flux fp7_e3m3 bf16 False 111.125 19.85 20.011 24.592 19.85 21.045 24.592 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp7_e3m3---compile-False.png
flux fp7_e4m2 bf16 False 100.502 19.852 20.014 24.592 19.852 21.046 24.592 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp7_e4m2---compile-False.png
flux fp7_e5m1 bf16 False 90.251 19.851 20.014 24.592 19.851 21.467 24.592 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp7_e5m1---compile-False.png
flux fp7_e6m0 bf16 False 91.344 19.852 20.014 25.158 19.852 21.468 25.158 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp7_e6m0---compile-False.png
flux none bf16 True 4.305 22.364 22.364 22.377 22.333 22.861 23.156 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-none---compile-True.png
flux int4wo bf16 True 64.279 6.219 28.553 29.342 6.188 28.553 29.342 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-int4wo---compile-True.png
flux int4dq bf16 True 5.188 18.845 19.013 19.893 12.793 19.013 19.893 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-int4dq---compile-True.png
flux int8wo bf16 True 5.25 24.014 24.183 25.504 11.359 24.183 26.125 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-int8wo---compile-True.png
flux int8dq bf16 True 3.663 22.58 22.749 24.242 11.359 22.749 24.723 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-int8dq---compile-True.png
flux uint1wo bf16 True 5.196 13.974 14.143 15.102 2.753 14.143 15.242 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-uint1wo---compile-True.png
flux uint2wo bf16 True 4.943 6.645 6.811 8.67 4.027 6.811 8.67 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-uint2wo---compile-True.png
flux uint3wo bf16 True 6.134 9.524 9.691 10.674 5.633 9.691 10.674 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-uint3wo---compile-True.png
flux uint4wo bf16 True 4.864 12.344 12.51 13.49 6.846 12.51 13.49 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-uint4wo---compile-True.png
flux uint5wo bf16 True 5.836 15.115 15.281 16.041 8.404 15.281 16.041 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-uint5wo---compile-True.png
flux uint6wo bf16 True 5.295 17.93 18.097 19.141 9.662 18.097 19.141 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-uint6wo---compile-True.png
flux uint7wo bf16 True 7.183 20.849 21.017 22.002 11.324 21.017 22.002 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-uint7wo---compile-True.png
flux float8wo_e5m2 bf16 True 4.736 22.525 22.69 24.457 11.335 22.69 24.457 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-float8wo_e5m2---compile-True.png
flux float8wo_e4m3 bf16 True 4.784 22.534 22.702 24.311 11.336 22.702 24.311 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-float8wo_e4m3---compile-True.png
flux float8dq_e5m2 bf16 True 3.625 22.528 22.697 24.205 11.33 22.875 24.686 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-float8dq_e5m2---compile-True.png
flux float8dq_e4m3 bf16 True 3.516 22.523 22.691 23.957 11.33 22.871 24.438 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-float8dq_e4m3---compile-True.png
flux float8dq_e5m2_tensor bf16 True 3.592 22.523 22.691 23.951 11.33 22.691 23.951 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-float8dq_e5m2_atwt---compile-True.png
flux float8dq_e5m2_row bf16 True 3.51 22.534 22.702 23.887 11.341 22.881 24.367 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-float8dq_e5m2_arwr---compile-True.png
flux float8dq_e4m3_tensor bf16 True 3.443 22.534 22.702 23.959 11.33 22.702 23.959 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-float8dq_e4m3_atwt---compile-True.png
flux float8dq_e4m3_row bf16 True 3.47 22.534 22.702 23.918 11.341 22.881 24.398 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-float8dq_e4m3_arwr---compile-True.png
flux fp3_e1m1 bf16 True 13.372 15.584 15.752 17.113 4.381 15.752 17.113 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp3_e1m1---compile-True.png
flux fp3_e2m0 bf16 True 13.257 8.628 8.79 10.754 4.379 8.79 10.754 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp3_e2m0---compile-True.png
flux fp4_e1m2 bf16 True 6.556 10.011 10.173 11.986 5.764 10.173 11.986 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp4_e1m2---compile-True.png
flux fp4_e2m1 bf16 True 6.456 11.402 11.562 13.059 5.769 11.562 13.059 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp4_e2m1---compile-True.png
flux fp4_e3m0 bf16 True 6.336 11.402 11.566 12.945 5.768 11.566 12.945 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp4_e3m0---compile-True.png
flux fp5_e1m3 bf16 True 11.859 12.75 12.915 14.664 7.117 12.915 14.664 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp5_e1m3---compile-True.png
flux fp5_e2m2 bf16 True 11.601 14.094 14.257 16.553 7.111 14.257 16.553 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp5_e2m2---compile-True.png
flux fp5_e3m1 bf16 True 11.396 14.092 14.253 16.545 7.114 14.253 16.545 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp5_e3m1---compile-True.png
flux fp5_e4m0 bf16 True 11.242 14.093 14.254 16.537 7.113 14.254 16.537 torchao-quantizer torchao_benchmark_results/flux---dtype-bf16---qtype-fp5_e4m0---compile-True.png
Flux visual results

(Image grid omitted: the bf16 baseline compared against int4wo/int4dq, int8wo/int8dq, uint1wo through uint7wo, float8wo_e5m2/e4m3, float8dq_e4m3 (default, _tensor, _row), and the fp3 through fp7 e/m variants.)
CogVideoX table (time in seconds, memory columns in GiB)
model_id quantization_type compute_dtype compile time before_inference_memory before_inference_max_memory before_inference_max_memory_reserved after_inference_memory after_inference_max_memory after_inference_max_memory_reserved branch filename exception
cogvideox none bf16 False 109.183 10.824 10.824 10.836 10.827 13.485 15.531 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-none---compile-False.mp4
cogvideox int4wo bf16 False 666.9 3.678 14.471 32.055 3.677 14.471 32.055 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-int4wo---compile-False.mp4
cogvideox int4dq bf16 False 157.475 9.516 9.935 35.67 9.517 14.842 35.67 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-int4dq---compile-False.mp4
cogvideox int8wo bf16 False 120.042 11.51 11.928 34.115 11.509 14.171 34.115 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-int8wo---compile-False.mp4
cogvideox int8dq bf16 False 338.083 10.88 11.509 33.689 10.881 14.274 33.689 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-int8dq---compile-False.mp4
cogvideox uint1wo bf16 False 109.582 6.953 10.881 34.791 6.953 10.881 34.791 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-uint1wo---compile-False.mp4
cogvideox uint2wo bf16 False 112.541 3.52 6.953 29.977 3.52 6.953 29.977 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-uint2wo---compile-False.mp4
cogvideox uint3wo bf16 False 115.799 4.861 5.271 26.996 4.861 7.523 26.996 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-uint3wo---compile-False.mp4
cogvideox uint4wo bf16 False 113.59 6.18 6.598 27.877 6.18 8.843 27.877 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-uint4wo---compile-False.mp4
cogvideox uint5wo bf16 False 115.742 7.477 7.888 29.207 7.477 10.139 29.207 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-uint5wo---compile-False.mp4
cogvideox uint6wo bf16 False 115.488 8.786 9.203 30.504 8.785 11.448 30.504 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-uint6wo---compile-False.mp4
cogvideox uint7wo bf16 False 117.26 10.174 10.587 31.883 10.174 12.835 31.883 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-uint7wo---compile-False.mp4
cogvideox float8wo_e5m2 bf16 False 110.556 10.939 11.358 11.611 10.94 13.602 16.332 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-float8wo_e5m2---compile-False.mp4
cogvideox float8wo_e4m3 bf16 False 110.689 10.844 11.259 37.793 10.843 13.505 37.793 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-float8wo_e4m3---compile-False.mp4
cogvideox float8dq_e4m3 bf16 False 116.14 10.828 11.241 11.801 10.828 15.748 19.938 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-float8dq_e4m3---compile-False.mp4
cogvideox float8dq_e4m3_tensor bf16 False 116.196 10.835 11.249 11.807 10.835 15.756 19.941 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-float8dq_e4m3_atwt---compile-False.mp4
cogvideox float8dq_e4m3_row bf16 False 118.433 10.839 11.249 11.805 10.839 15.757 19.941 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-float8dq_e4m3_arwr---compile-False.mp4
cogvideox fp3_e1m1 bf16 False 142.859 7.622 8.037 8.596 7.622 10.342 13.551 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp3_e1m1---compile-False.mp4
cogvideox fp3_e2m0 bf16 False 142.234 4.403 7.622 31.381 4.403 7.622 31.381 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp3_e2m0---compile-False.mp4
cogvideox fp4_e1m2 bf16 False 132.628 5.026 5.441 5.881 5.026 7.689 10.742 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp4_e1m2---compile-False.mp4
cogvideox fp4_e2m1 bf16 False 128.228 5.655 6.067 27.254 5.655 8.372 27.254 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp4_e2m1---compile-False.mp4
cogvideox fp4_e3m0 bf16 False 128.394 5.651 6.066 29.162 5.651 8.371 29.162 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp4_e3m0---compile-False.mp4
cogvideox fp5_e1m3 bf16 False 150.136 6.302 6.713 7.285 6.302 8.962 12.146 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp5_e1m3---compile-False.mp4
cogvideox fp5_e2m2 bf16 False 145.217 6.946 7.361 28.658 6.946 9.609 28.658 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp5_e2m2---compile-False.mp4
cogvideox fp5_e3m1 bf16 False 141.137 6.95 7.361 29.566 6.95 9.668 29.566 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp5_e3m1---compile-False.mp4
cogvideox fp5_e4m0 bf16 False 142.074 6.946 7.361 30.979 6.946 9.666 30.979 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp5_e4m0---compile-False.mp4
cogvideox fp6_e1m4 bf16 False 144.156 7.68 8.091 9.348 7.68 10.342 14.209 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp6_e1m4---compile-False.mp4
cogvideox fp6_e2m3 bf16 False 133.987 8.409 8.823 30.719 8.409 11.071 30.719 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp6_e2m3---compile-False.mp4
cogvideox fp6_e3m2 bf16 False 128.73 8.409 8.821 31.818 8.409 11.071 31.818 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp6_e3m2---compile-False.mp4
cogvideox fp6_e4m1 bf16 False 124.977 8.408 8.82 31.838 8.408 11.127 31.838 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp6_e4m1---compile-False.mp4
cogvideox fp6_e5m0 bf16 False 123.939 8.409 8.822 31.949 8.409 11.129 31.949 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp6_e5m0---compile-False.mp4
cogvideox fp7_e1m5 bf16 False 192.617 9.016 9.429 10.908 9.016 11.678 15.77 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp7_e1m5---compile-False.mp4
cogvideox fp7_e2m4 bf16 False 171.963 9.639 10.052 33.6 9.64 12.303 33.6 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp7_e2m4---compile-False.mp4
cogvideox fp7_e3m3 bf16 False 161.623 9.64 10.069 33.908 9.64 12.303 33.908 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp7_e3m3---compile-False.mp4
cogvideox fp7_e4m2 bf16 False 155.976 9.639 10.068 32.512 9.639 12.301 32.512 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp7_e4m2---compile-False.mp4
cogvideox fp7_e5m1 bf16 False 150.499 9.638 10.067 32.512 9.638 12.358 32.512 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp7_e5m1---compile-False.mp4
cogvideox fp7_e6m0 bf16 False 148.88 9.639 10.067 32.605 9.639 12.359 32.605 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp7_e6m0---compile-False.mp4
cogvideox none bf16 True 84.041 10.824 10.824 10.836 10.796 13.719 13.928 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-none---compile-True.mp4
cogvideox int4wo bf16 True 629.471 3.647 14.439 37.809 3.646 14.439 37.809 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-int4wo---compile-True.mp4
cogvideox int4dq bf16 True 87.653 9.485 9.904 31.266 6.261 10.29 31.266 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-int4dq---compile-True.mp4
cogvideox int8wo bf16 True 85.003 11.478 11.897 34.809 5.638 12.174 34.809 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-int8wo---compile-True.mp4
cogvideox int8dq bf16 True 77.878 10.848 11.269 33.518 5.633 11.269 33.518 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-int8dq---compile-True.mp4
cogvideox uint1wo bf16 True 80.105 6.924 7.34 35.025 1.709 7.34 35.025 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-uint1wo---compile-True.mp4
cogvideox uint2wo bf16 True 84.675 3.49 3.908 31.797 2.201 5.217 31.797 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-uint2wo---compile-True.mp4
cogvideox uint3wo bf16 True 85.097 4.83 5.239 29.791 3.041 6.056 29.791 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-uint3wo---compile-True.mp4
cogvideox uint4wo bf16 True 85.057 6.147 6.565 32.113 3.527 6.565 32.113 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-uint4wo---compile-True.mp4
cogvideox uint5wo bf16 True 85.073 7.446 7.856 31.158 4.331 7.856 31.158 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-uint5wo---compile-True.mp4
cogvideox uint6wo bf16 True 85.114 8.75 9.168 33.416 4.84 9.168 33.416 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-uint6wo---compile-True.mp4
cogvideox uint7wo bf16 True 85.803 10.141 10.552 32.578 5.714 10.552 32.578 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-uint7wo---compile-True.mp4
cogvideox float8wo_e5m2 bf16 True 85.035 10.902 11.319 14.775 5.608 11.319 14.775 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-float8wo_e5m2---compile-True.mp4
cogvideox float8wo_e4m3 bf16 True 85.018 10.807 11.217 33.383 5.612 11.217 33.383 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-float8wo_e4m3---compile-True.mp4
cogvideox float8dq_e5m2 bf16 True 76.829 10.803 11.218 33.088 5.608 12.044 33.088 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-float8dq_e5m2---compile-True.mp4
cogvideox float8dq_e4m3 bf16 True 74.553 10.799 11.214 33.893 5.608 12.071 33.893 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-float8dq_e4m3---compile-True.mp4
cogvideox float8dq_e5m2_tensor bf16 True 76.776 10.798 11.214 33.898 5.608 11.214 33.898 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-float8dq_e5m2_atwt---compile-True.mp4
cogvideox float8dq_e5m2_row bf16 True 76.541 10.808 11.223 15.629 5.618 12.048 17.436 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-float8dq_e5m2_arwr---compile-True.mp4
cogvideox float8dq_e4m3_tensor bf16 True 74.409 10.808 11.223 33.938 5.607 11.223 33.938 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-float8dq_e4m3_atwt---compile-True.mp4
cogvideox float8dq_e4m3_row bf16 True 75.586 10.807 11.222 15.623 5.617 12.048 17.43 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-float8dq_e4m3_arwr---compile-True.mp4
cogvideox fp3_e1m1 bf16 True 87.497 7.592 8.008 12.545 2.393 8.008 12.785 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp3_e1m1---compile-True.mp4
cogvideox fp3_e2m0 bf16 True 87.833 4.371 4.784 30.59 2.394 5.504 30.59 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp3_e2m0---compile-True.mp4
cogvideox fp4_e1m2 bf16 True 85.197 4.998 5.412 9.059 3.021 6.168 9.059 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp4_e1m2---compile-True.mp4
cogvideox fp4_e2m1 bf16 True 85.359 5.625 6.04 30.518 3.022 6.081 30.518 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp4_e2m1---compile-True.mp4
cogvideox fp4_e3m0 bf16 True 85.112 5.623 6.039 30.451 3.02 6.08 30.451 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp4_e3m0---compile-True.mp4
cogvideox fp5_e1m3 bf16 True 87.812 6.273 6.687 10.471 3.671 6.868 10.471 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp5_e1m3---compile-True.mp4
cogvideox fp5_e2m2 bf16 True 87.85 6.919 7.335 31.41 3.669 7.335 31.41 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp5_e2m2---compile-True.mp4
cogvideox fp5_e3m1 bf16 True 87.576 6.921 7.334 31.35 3.67 7.334 31.35 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp5_e3m1---compile-True.mp4
cogvideox fp5_e4m0 bf16 True 87.036 6.922 7.337 31.357 3.671 7.337 31.357 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp5_e4m0---compile-True.mp4
cogvideox fp6_e1m4 bf16 True 86.014 7.652 8.067 12.518 4.4 8.067 12.518 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp6_e1m4---compile-True.mp4
cogvideox fp6_e2m3 bf16 True 85.274 8.384 8.799 32.5 4.403 8.799 32.5 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp6_e2m3---compile-True.mp4
cogvideox fp6_e3m2 bf16 True 85.123 8.379 8.796 30.496 4.397 8.796 30.496 torchao-quantizer cogvideox_benchmark_results/cogvideox---dtype-bf16---qtype-fp6_e3m2---compile-True.mp4
CogVideoX visual results

Unfortunately, the prompt I used does not produce a very good initial video. Should have verified this in the beginning... 🫠

For some reason, GitHub does not render the videos from HF despite trying a few things. So, I'm not embedding it here. The results can be found here: https://huggingface.co/datasets/a-r-r-o-w/randoms/tree/main/cogvideox_benchmark_results

The minimal code for using the quantizer would be:

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/Flux.1-Dev"
dtype = torch.bfloat16

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
image.save("output.png")

TODO

  • save_pretrained
  • from_pretrained
  • tests
    • Memory footprint.
    • Integration tests for Flux, Cog, etc.
    • modules_to_not_convert.
    • Training.
    • torch.compile() <> torchao
  • docs

@DN6 @sayakpaul @yiyixuxu

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@jerryzh168

Looks good to me overall. I think we also want to think about how we can integrate the autoquant API: https://github.com/pytorch/ao/tree/main/torchao/quantization#autoquantization which works on the full model instead of individual linear modules.
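
For reference, a minimal sketch of the autoquant usage from the linked torchao README (here model stands in for any torch.nn.Module, e.g. a diffusers transformer):

import torch
import torchao

# autoquant wraps the whole (optionally compiled) model and picks quantization kernels per layer
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))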

@a-r-r-o-w a-r-r-o-w marked this pull request as ready for review November 28, 2024 05:06
Comment on lines 674 to 706
if device_map is not None:
    raise NotImplementedError(
        "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
    )

Member Author

@sayakpaul I'm not sure how this impacts the BnB quantizer. I assume it was disabled for BnB for some reason I'm not aware of. It works with TorchAO as expected though, so if you need some kind of torchao-specific guard here, I'll add it.

Member

> It works with TorchAO as expected though

How did you test that?

> I assume it was disabled for BnB for some reason I'm not aware of.

That is because we merge the sharded checkpoints when using bnb and using custom device_maps needs this codepath:

accelerate.load_checkpoint_and_dispatch(

This is not hit when loading quantized checkpoints at least for bitsandbytes. This will be tackled in: #10013
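
For context, a hedged sketch of the accelerate codepath referred to above (the model variable and checkpoint path are illustrative, not the exact diffusers call site):

from accelerate import load_checkpoint_and_dispatch

model = load_checkpoint_and_dispatch(
    model,                            # an already-initialized (e.g. meta-device) model
    checkpoint="path/to/checkpoint",  # hypothetical checkpoint location
    device_map="auto",                # or a custom user-provided device_map
)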

Member Author

> How did you test that?

There is a test for this in tests/quantization/torchao/test_torchao.py called test_offload that can be used to verify that cpu/disk offloading works with torchao

Member

Yeah but this is about custom user-provided device_maps. What am I missing?

Member Author

The `if device_map is not None` check was added in the BnB quantizer PR. I assume it was added because device_map is not supported in BnB. But it works perfectly fine with TorchAO (as the test checks), so I removed the check in order to do the initial testing of the TorchAO quantizer quickly.

I would like to know if there should be an error raised when BnB is the quantization method used. Something like:

if quant_method == "bitsandbytes" and device_map is not None:
    raise ValueError("`device_map` is not supported with bitsandbytes quantization.")

Does that work?

Member

yeah that works for me.

Member

@sayakpaul sayakpaul left a comment

Thank you for working on this. I love many aspects of this PR, my favorite being how we're supporting many torchao quant configs!

Apart from the comments I left in-line, I have the following additional comments:

  1. Consider testing for model memory footprint as well.
  2. Consider including integration tests for Flux, Cog, etc. At least Flux should be covered.
  3. Consider adding a note on serialization in the docs.
  4. Consider testing for modules_to_not_convert.
  5. Consider adding a test for training. Example:
    def test_training(self):

LMK if anything is unclear.

docs/source/en/quantization/torchao.md

## Usage

Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.

Member

Suggested change:
- Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
+ Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] or even load a pre-quantized model. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.

docs/source/en/quantization/torchao.md
src/diffusers/models/model_loading_utils.py

tests/quantization/torchao/test_torchao.py (three resolved review threads)
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map_offload,

Member

We assign a hf_device_map attribute to the model too, so we should also check if the quantized_hf_device_map matches the expected one.

Member

@a-r-r-o-w just checking if this is remaining to be added?

Member Author

Updated the test_offload test with a check to verify hf_device_map is the same as device_map. I don't see any quantized_hf_device_map when grepping the codebase or searching on GitHub, so I'm not sure what you are referring to. Could you help with this?

tests/quantization/torchao/test_torchao.py

Member

@sayakpaul sayakpaul left a comment

Did another pass and answered some questions.

Very important would be to have a test suite for torchao + torch.compile() (at least for some quant types) as that is a massive USP of torchao.

src/diffusers/models/model_loading_utils.py
src/diffusers/quantizers/torchao/torchao_quantizer.py
Comment on lines +170 to +171
module, tensor_name = get_module_from_name(model, param_name)
return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")

Member

We check it to do a further inspection on the shape as well as to handle any parameter creation. But I guess it's fine with torchao because of the reasons you mentioned. @SunMarc WDYT?

Comment on lines 250 to 253
@property
def is_trainable(self):
    # TODO(aryan): needs testing
    return self.quantization_config.quant_type.startswith("int8")

Member

fp8 training is orthogonal here if we are talking about peft, I think. But I feel we should be able to fine-tune an fp8-quantized model as well (with float8_weight_only, float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight). I haven't tried this though; did you see any errors when you tried it?

I think this should be checked in the torchao CI given the popularity of training quantized models? If you give us a heads up about that, we'd be more than happy to configure this here accordingly.

@BenjaminBossan could you comment on the support of torchao <> peft a bit here?

@a-r-r-o-w a-r-r-o-w requested review from DN6 and sayakpaul December 5, 2024 13:37
for param in module.parameters():
    if param.__class__.__name__ == "AffineQuantizedTensor":
        data, scale, zero_point = param.layout_tensor.get_plain()
        quantized_param_memory += data.numel() + data.element_size()

Member Author

Oh typo here... should be multiplied

Member Author

Oh interesting, missed it! Will try it out

Member

@sayakpaul sayakpaul left a comment

Looking really good!

I think it would also make sense to run the existing and important integration tests before merging to make sure there are no obvious bugs.

docs/source/en/quantization/torchao.md (five resolved review threads)
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map_offload,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@a-r-r-o-w just checking if this is remaining to be added?

tests/quantization/torchao/test_torchao.py
self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))

@staticmethod
def _get_memory_footprint(module):

Member

Does this not work?

def get_memory_footprint(self, return_buffers=True):

If not, we should consider these changes in modeling_utils.py, IMO.

Member Author

Nope, this does not return the correct size of the model weights when quantization is applied. We can consider the change in modeling_utils.py in a separate PR to account for the AQT tensors, since this is just present in the tests for the moment.

The TorchAO utility provided by Jerry here is probably better to use than what I have here.

Comment on lines +380 to +420
@staticmethod
def _get_memory_footprint(module):
    quantized_param_memory = 0.0
    unquantized_param_memory = 0.0

    for param in module.parameters():
        if param.__class__.__name__ == "AffineQuantizedTensor":
            data, scale, zero_point = param.layout_tensor.get_plain()
            quantized_param_memory += data.numel() + data.element_size()
            quantized_param_memory += scale.numel() + scale.element_size()
            quantized_param_memory += zero_point.numel() + zero_point.element_size()
        else:
            unquantized_param_memory += param.data.numel() * param.data.element_size()

    total_memory = quantized_param_memory + unquantized_param_memory
    return total_memory, quantized_param_memory, unquantized_param_memory

def test_memory_footprint(self):
    r"""
    A simple test to check if the model conversion has been done correctly by checking on the
    memory footprint of the converted model and the class type of the linear layers of the converted models
    """
    transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"]
    transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"]
    transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
    transformer_bf16 = self.get_dummy_components(None)["transformer"]

    total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo)
    total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint(
        transformer_int4wo_gs32
    )
    total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo)
    total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16)

    self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16)
    # int4wo_gs32 has smaller group size, so more groups -> more scales and zero points
    self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32)
    # int4 with default group size quantized very few linear layers compared to a smaller group size of 32
    self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32)
    # int8 quantizes more layers compared to int4 with default group size
    self.assertTrue(quantized_int8wo < quantized_int4wo)

Member

Yeah, both are separate: the memory footprint gives us a ballpark of how much we need to load, while memory usage tells us the actual memory needed for execution. Both could be considered for inclusion here.
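(A minimal sketch of the "memory usage" half of that distinction, assuming a CUDA device and a `transformer` with ready-made `inputs`; both names are placeholders.)

import torch

# Peak memory during a forward pass, as opposed to the static footprint
# computed by _get_memory_footprint above.
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    _ = transformer(**inputs)
print(f"peak memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")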

tests/quantization/torchao/test_torchao.py
Collaborator

@yiyixuxu yiyixuxu left a comment

oh thanks for the great work!
PR looks very good to me and I think we can merge this very soon. The only concern I have is the API supporting all the shorthands; IMO we should not, but I'm open to different opinions! :)

)

@classmethod
def _get_torchao_quant_type_to_method(cls):
Collaborator

ohh, I think it will be too much to support an API with all the shorthands: creating and maintaining all the related docs, and keeping the list up to date.

Is this something currently supported by TorchAO, or in their plans? If they support it, I think it will be okay/manageable for us to maintain a parallel mapping. Otherwise, I think it is unnecessary/not meaningful for us to come up with new APIs for the external libraries we integrate.

If we want to create a set of "shorthand standards" that we can use in diffusers across all the different quantization methods/libraries we support (e.g. something we can use for both bnb, torchAO, etc.), it might be meaningful, but I think it will be better if we do that after we have a few more libraries in :)

also cc @SunMarc here, because for quantization we would like to keep the API roughly consistent between transformers and diffusers

Overall, IMO, I think we should just accept passing the method name as-is; for shorthands, it is okay to support a very small list of the most commonly used ones.

But I'm open to different opinions, so let me know! cc @DN6 too


makes sense, maybe we can provide some common shorthands; these have been repeated in many libraries: https://github.com/pytorch/ao/blob/8a805d08898e5c961fb9b4f6ab61ffd5d5bdbca5/torchao/_models/llama/generate.py#L702

Member Author

Looking forward to hearing thoughts from others! I find the shorthands for uintx/fpx, and the wo and dq suffixes, rather convenient. I believe they are commonly used too. No hard preferences; okay with whatever we decide.

Collaborator

Yeah I think the number of terms here will be a bit difficult to maintain. Perhaps we support just the shorthands mentioned here? Does the community have any preference?
https://github.com/huggingface/diffusers/pull/10009/files#r1873188178

Member Author

@a-r-r-o-w a-r-r-o-w Dec 9, 2024

Thanks for the review @DN6! I've removed the documentation shorthands (anything of the form {dtype}_aXwY) after our discussion in DM.

I think we should definitely have the fully qualified function names, and the wo and dq suffixes (see the sketch below). Updated the documentation accordingly.
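(A short sketch of the two accepted spellings after this change; the group_size kwarg mirrors the test above.)

from diffusers import TorchAoConfig

# Fully qualified torchao function name, with its kwargs passed through.
config_full = TorchAoConfig("int4_weight_only", group_size=32)

# Equivalent shorthand using the `wo` (weight-only) suffix.
config_short = TorchAoConfig("int4wo", group_size=32)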

For Jerry's suggestion, hqq, marlin, sparsify, autoquant, intx (prototype), and spinquant can be tackled in a separate PR after trying them out. Let's keep this one to just the ones we have here already. Will check the generation quality with these soon.

Member

@stevhliu stevhliu left a comment

Thanks for adding, really super!

docs/source/en/quantization/overview.md
docs/source/en/quantization/torchao.md
"""This is a config class for torchao quantization/sparsity techniques.

Args:
quant_type (`str`):
Member

Yeah this is quite heavy for a docstring

docs/source/en/quantization/torchao.md

@yiyixuxu
Collaborator

@SunMarc can you do a final review if you haven't?

@a-r-r-o-w a-r-r-o-w requested a review from SunMarc December 12, 2024 10:53
@yiyixuxu yiyixuxu merged commit 9f00c61 into main Dec 16, 2024
15 checks passed
@yiyixuxu yiyixuxu deleted the torchao-quantizer branch December 16, 2024 23:35
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* torchao quantizer


---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Steven Liu <[email protected]>

[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.

Before you begin, make sure you have PyTorch 2.5+ and TorchAO installed.
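(A quick environment check before using the quantizer; the pip command in the comment is an assumed typical install, not an official pin.)

# Assumed install: pip install -U torch torchao
import torch
from packaging import version

assert version.parse(torch.__version__) >= version.parse("2.5.0"), "TorchAO backend needs PyTorch 2.5+"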


Indeed it seems PyTorch 2.5+ is required, because in

there is an import of torch.uint1 (and others), which is not available in earlier torch versions. However, diffusers seems to require torch>=1.4 (ref), so this seems inconsistent. Am I missing something?

Member Author

TorchAO will not be imported or usable unless PyTorch 2.5 or above is available. Some Diffusers models can run with torch 1.4 as well, which is why that is the minimum required version.
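(One way such guarding can look; a sketch under the assumption that the sub-byte dtype tuple is built conditionally on the torch version, not the exact diffusers code.)

import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse("2.3.0"):
    # torch.uint1 and friends only exist in newer torch releases.
    _SUB_BYTE_UINT_DTYPES = (torch.uint1, torch.uint2, torch.uint3, torch.uint4)
else:
    _SUB_BYTE_UINT_DTYPES = ()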

@fjeremic fjeremic Dec 24, 2024

I'm running into the same issue with the torch.uint1 import. It seems the TorchAO import is not guarded, according to the backtrace. The following backtrace stems from this import line:

from diffusers import StableDiffusionXLPipeline

Here is the trace, and the pip list:

    Traceback (most recent call last):
      File "/github/home/.local/lib/python3.10/site-packages/diffusers/utils/import_utils.py", line 920, in _get_module
        return importlib.import_module("." + module_name, self.__name__)
      File "/usr/local/lib/python3.10/importlib/__init__.py", line 126, in import_module
        return _bootstrap._gcd_import(name[level:], package, level)
      File "<frozen importlib._bootstrap>", line 1050, in _gcd_import
      File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
      File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
      File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
      File "<frozen importlib._bootstrap_external>", line 883, in exec_module
      File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
      File "/github/home/.local/lib/python3.10/site-packages/diffusers/loaders/single_file.py", line 24, in <module>
        from .single_file_utils import (
      File "/github/home/.local/lib/python3.10/site-packages/diffusers/loaders/single_file_utils.py", line 28, in <module>
        from ..models.modeling_utils import load_state_dict
      File "/github/home/.local/lib/python3.10/site-packages/diffusers/models/modeling_utils.py", line 35, in <module>
        from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
      File "/github/home/.local/lib/python3.10/site-packages/diffusers/quantizers/__init__.py", line 15, in <module>
        from .auto import DiffusersAutoQuantizer
      File "/github/home/.local/lib/python3.10/site-packages/diffusers/quantizers/auto.py", line 31, in <module>
        from .torchao import TorchAoHfQuantizer
      File "/github/home/.local/lib/python3.10/site-packages/diffusers/quantizers/torchao/__init__.py", line 15, in <module>
        from .torchao_quantizer import TorchAoHfQuantizer
      File "/github/home/.local/lib/python3.10/site-packages/diffusers/quantizers/torchao/torchao_quantizer.py", line 45, in <module>
        torch.uint1,
      File "/github/home/.local/lib/python3.10/site-packages/torch/__init__.py", line 1938, in __getattr__
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
    AttributeError: module 'torch' has no attribute 'uint1'

And the pip list:

pip list -v
Package                  Version     Editable project location    Location                                         Installer
------------------------ ----------- ---------------------------- ------------------------------------------------ ---------
certifi                  2024.12.14                               /github/home/.local/lib/python3.10/site-packages pip
charset-normalizer       3.4.0                                    /github/home/.local/lib/python3.10/site-packages pip
colorama                 0.4.6                                    /github/home/.local/lib/python3.10/site-packages pip
coloredlogs              15.0.1                                   /github/home/.local/lib/python3.10/site-packages pip
colorlog                 6.9.0                                    /github/home/.local/lib/python3.10/site-packages pip
coverage                 7.6.9                                    /github/home/.local/lib/python3.10/site-packages pip
diffusers                0.32.0                                   /github/home/.local/lib/python3.10/site-packages pip
exceptiongroup           1.2.2                                    /github/home/.local/lib/python3.10/site-packages pip
execnet                  2.1.1                                    /github/home/.local/lib/python3.10/site-packages pip
filelock                 3.16.1                                   /github/home/.local/lib/python3.10/site-packages pip
flatbuffers              24.12.23                                 /github/home/.local/lib/python3.10/site-packages pip
fsspec                   2024.12.0                                /github/home/.local/lib/python3.10/site-packages pip
huggingface-hub          0.27.0                                   /github/home/.local/lib/python3.10/site-packages pip
humanfriendly            10.0                                     /github/home/.local/lib/python3.10/site-packages pip
idna                     3.10                                     /github/home/.local/lib/python3.10/site-packages pip
importlib_metadata       8.5.0                                    /github/home/.local/lib/python3.10/site-packages pip
iniconfig                2.0.0                                    /github/home/.local/lib/python3.10/site-packages pip
Jinja2                   3.1.5                                    /github/home/.local/lib/python3.10/site-packages pip
markdown-it-py           3.0.0                                    /github/home/.local/lib/python3.10/site-packages pip
MarkupSafe               3.0.2                                    /github/home/.local/lib/python3.10/site-packages pip
mdurl                    0.1.2                                    /github/home/.local/lib/python3.10/site-packages pip
mpmath                   1.3.0                                    /github/home/.local/lib/python3.10/site-packages pip
networkx                 3.4.2                                    /github/home/.local/lib/python3.10/site-packages pip
numpy                    1.26.4                                   /github/home/.local/lib/python3.10/site-packages pip
nvidia-cublas-cu12       12.1.3.1                                 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cuda-cupti-cu12   12.1.105                                 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cuda-nvrtc-cu12   12.1.105                                 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cuda-runtime-cu12 12.1.105                                 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cudnn-cu12        8.9.2.26                                 /github/home/.local/lib/python3.10/site-packages pip
nvidia-cufft-cu12        11.0.2.54                                /github/home/.local/lib/python3.10/site-packages pip
nvidia-curand-cu12       10.3.2.106                               /github/home/.local/lib/python3.10/site-packages pip
nvidia-cusolver-cu12     11.4.5.107                               /github/home/.local/lib/python3.10/site-packages pip
nvidia-cusparse-cu12     12.1.0.106                               /github/home/.local/lib/python3.10/site-packages pip
nvidia-nccl-cu12         2.19.3                                   /github/home/.local/lib/python3.10/site-packages pip
nvidia-nvjitlink-cu12    12.6.85                                  /github/home/.local/lib/python3.10/site-packages pip
nvidia-nvtx-cu12         12.1.105                                 /github/home/.local/lib/python3.10/site-packages pip
onnx                     1.17.0                                   /github/home/.local/lib/python3.10/site-packages pip
onnx2torch               1.5.15                                   /github/home/.local/lib/python3.10/site-packages pip
onnxruntime              1.20.1                                   /github/home/.local/lib/python3.10/site-packages pip
onnxsim                  0.4.36                                   /github/home/.local/lib/python3.10/site-packages pip
packaging                24.2                                     /github/home/.local/lib/python3.10/site-packages pip
pandas                   2.2.3                                    /github/home/.local/lib/python3.10/site-packages pip
pillow                   11.0.0                                   /github/home/.local/lib/python3.10/site-packages pip
pip                      22.0.4                                   /usr/local/lib/python3.10/site-packages          pip
pluggy                   1.5.0                                    /github/home/.local/lib/python3.10/site-packages pip
protobuf                 5.29.2                                   /github/home/.local/lib/python3.10/site-packages pip
Pygments                 2.18.0                                   /github/home/.local/lib/python3.10/site-packages pip
pytest                   8.3.4                                    /github/home/.local/lib/python3.10/site-packages pip
pytest-xdist             3.6.1                                    /github/home/.local/lib/python3.10/site-packages pip
python-dateutil          2.9.0.post0                              /github/home/.local/lib/python3.10/site-packages pip
python-dotenv            1.0.1                                    /github/home/.local/lib/python3.10/site-packages pip
python-json-logger       3.2.1                                    /github/home/.local/lib/python3.10/site-packages pip
pytz                     2024.2                                   /github/home/.local/lib/python3.10/site-packages pip
PyYAML                   6.0.2                                    /github/home/.local/lib/python3.10/site-packages pip
regex                    2024.11.6                                /github/home/.local/lib/python3.10/site-packages pip
requests                 2.32.3                                   /github/home/.local/lib/python3.10/site-packages pip
rich                     13.9.4                                   /github/home/.local/lib/python3.10/site-packages pip
safetensors              0.4.5                                    /github/home/.local/lib/python3.10/site-packages pip
scipy                    1.14.1                                   /github/home/.local/lib/python3.10/site-packages pip
sentencepiece            0.2.0                                    /github/home/.local/lib/python3.10/site-packages pip
setuptools               58.1.0                                   /usr/local/lib/python3.10/site-packages          pip
six                      1.17.0                                   /github/home/.local/lib/python3.10/site-packages pip
sympy                    1.13.3                                   /github/home/.local/lib/python3.10/site-packages pip
tabulate                 0.9.0                                    /github/home/.local/lib/python3.10/site-packages pip
tokenizers               0.15.2                                   /github/home/.local/lib/python3.10/site-packages pip
tomli                    2.2.1                                    /github/home/.local/lib/python3.10/site-packages pip
torch                    2.2.2                                    /github/home/.local/lib/python3.10/site-packages pip
torchvision              0.17.2                                   /github/home/.local/lib/python3.10/site-packages pip
tqdm                     4.67.1                                   /github/home/.local/lib/python3.10/site-packages pip
transformers             4.38.2                                   /github/home/.local/lib/python3.10/site-packages pip
triton                   2.2.0                                    /github/home/.local/lib/python3.10/site-packages pip
typing_extensions        4.12.2                                   /github/home/.local/lib/python3.10/site-packages pip
tzdata                   2024.2                                   /github/home/.local/lib/python3.10/site-packages pip
urllib3                  2.3.0                                    /github/home/.local/lib/python3.10/site-packages pip
wheel                    0.37.1                                   /usr/local/lib/python3.10/site-packages          pip
zipp                     3.21.0                                   /github/home/.local/lib/python3.10/site-packages pip

Member Author

Thanks a lot for reporting @fjeremic! We were able to replicate this for torch <= 2.2; it does not cause the import errors for >= 2.3. We will be doing a patch release soon to fix this behaviour. Sorry for the inconvenience!


Thanks for providing a quick fix!
For completeness, I was running into the import error with torch 2.2.2 when importing AutoencoderKL

Member Author

@BeckerFelix @fjeremic The patch release is out! Hope it fixes any problems you were facing in torch < 2.3

Labels
quantization, roadmap