[core] TorchAO Quantizer #10009

Merged
merged 39 commits into from
Dec 16, 2024
Changes from 11 commits
Commits
64cbf11
torchao quantizer
a-r-r-o-w Nov 24, 2024
b78a36c
make style
a-r-r-o-w Nov 24, 2024
355509e
update
a-r-r-o-w Nov 24, 2024
cbb0da4
update
a-r-r-o-w Nov 24, 2024
ee084a5
cuda capability check
a-r-r-o-w Nov 24, 2024
748a002
update
a-r-r-o-w Nov 24, 2024
bc006f2
fix
a-r-r-o-w Nov 24, 2024
956f3bf
fix
a-r-r-o-w Nov 25, 2024
2c6beef
Merge branch 'main' into torchao-quantizer
a-r-r-o-w Nov 25, 2024
cfdb94f
update
a-r-r-o-w Nov 25, 2024
8e214e2
Merge branch 'main' into torchao-quantizer
a-r-r-o-w Nov 25, 2024
1d9f832
Merge branch 'main' into torchao-quantizer
a-r-r-o-w Nov 28, 2024
01b2b42
update tests
a-r-r-o-w Nov 28, 2024
b17cf35
device map changes
a-r-r-o-w Nov 28, 2024
250ccf4
update; apply suggestions from review
a-r-r-o-w Dec 1, 2024
50946a9
fix
a-r-r-o-w Dec 1, 2024
edae34b
Merge branch 'main' into torchao-quantizer
a-r-r-o-w Dec 1, 2024
8f09bdf
remove slow marker
a-r-r-o-w Dec 1, 2024
7c79b8e
remove pytest deprecation warnings
a-r-r-o-w Dec 1, 2024
820ac88
Merge branch 'main' into torchao-quantizer
sayakpaul Dec 5, 2024
f9f1535
Merge branch 'main' into torchao-quantizer
a-r-r-o-w Dec 5, 2024
747bd7d
apply review suggestions
a-r-r-o-w Dec 5, 2024
25d3cf8
add torch compile test
a-r-r-o-w Dec 5, 2024
10deb16
add more tests; add expected slices
a-r-r-o-w Dec 5, 2024
f3771a8
fix
a-r-r-o-w Dec 5, 2024
55d6155
Merge branch 'main' into torchao-quantizer
a-r-r-o-w Dec 5, 2024
de97a51
improve test check
a-r-r-o-w Dec 5, 2024
101d10c
update docs
a-r-r-o-w Dec 5, 2024
edd98db
bnb device map check
a-r-r-o-w Dec 5, 2024
2677e0c
Merge branch 'main' into torchao-quantizer
a-r-r-o-w Dec 5, 2024
cc70887
update docs
a-r-r-o-w Dec 5, 2024
5f75db2
Apply suggestions from code review
a-r-r-o-w Dec 6, 2024
9704daa
Merge branch 'main' into torchao-quantizer
a-r-r-o-w Dec 6, 2024
b227189
Merge branch 'main' into torchao-quantizer
a-r-r-o-w Dec 9, 2024
7d9d1dc
address review comments
a-r-r-o-w Dec 9, 2024
e9fccb6
update
a-r-r-o-w Dec 9, 2024
bc874fc
add nightly marker for torch.compile test
a-r-r-o-w Dec 9, 2024
29ec905
Merge branch 'main' into torchao-quantizer
a-r-r-o-w Dec 12, 2024
7ca64fd
Merge branch 'main' into torchao-quantizer
a-r-r-o-w Dec 15, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -157,6 +157,8 @@
title: Getting Started
- local: quantization/bitsandbytes
title: bitsandbytes
- local: quantization/torchao
title: torchao
title: Quantization Methods
- sections:
- local: optimization/fp16
4 changes: 4 additions & 0 deletions docs/source/en/api/quantization.md
@@ -28,6 +28,10 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui

[[autodoc]] BitsAndBytesConfig

## TorchAoConfig

[[autodoc]] TorchAoConfig

## DiffusersQuantizer

[[autodoc]] quantizers.base.DiffusersQuantizer
2 changes: 1 addition & 1 deletion docs/source/en/quantization/overview.md
@@ -32,4 +32,4 @@ If you are new to the quantization field, we recommend you to check out these be

## When to use what?

This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes` and `torchao`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
31 changes: 31 additions & 0 deletions docs/source/en/quantization/torchao.md
@@ -0,0 +1,31 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# torchao

[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, and it composes with native PyTorch features such as `torch.compile` and FSDP. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).

Before you begin, make sure you have PyTorch 2.5 or later and TorchAO installed:

```bash
pip install -U torch torchao
```
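
The version requirement above can also be checked programmatically. The following is a minimal sketch using only the standard library; the helper names (`parse_major_minor`, `installed_version`, `torchao_environment_ok`) are hypothetical and not part of Diffusers or TorchAO.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# Minimum PyTorch version stated in the docs text above.
MIN_TORCH = (2, 5)


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from a version string such as '2.5.1+cu121'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def installed_version(package: str) -> Optional[str]:
    """Return the installed distribution version, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def torchao_environment_ok() -> bool:
    """True only if torch >= MIN_TORCH and torchao are both installed."""
    torch_version = installed_version("torch")
    if torch_version is None or installed_version("torchao") is None:
        return False
    return parse_major_minor(torch_version) >= MIN_TORCH
```

Comparing tuples rather than strings avoids the usual pitfall where `"2.10"` sorts before `"2.5"`.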

## Usage

Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
Suggested change
Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
Now you can quantize a model by passing a [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] or even load a pre-quantized model. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
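
As a rough illustration of the usage sentence above, the sketch below passes a `TorchAoConfig` to `from_pretrained`. Several details are assumptions not stated in this diff: the model class `FluxTransformer2DModel`, the checkpoint id, and the `"int8wo"` quantization type string are taken from the library as released and may not match this revision of the PR. The imports sit inside the function so that defining it does not require the heavy dependencies; calling it needs a CUDA device and downloads weights.

```python
def load_quantized_flux_transformer(quant_type: str = "int8wo"):
    """Load a transformer with torchao quantization applied at load time."""
    # Imported lazily so the sketch can be defined without torch/diffusers installed.
    import torch
    from diffusers import FluxTransformer2DModel, TorchAoConfig

    # Assumption: the config takes the quantization type as its first argument.
    quantization_config = TorchAoConfig(quant_type)
    return FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint, not named in the text above
        subfolder="transformer",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
    )
```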


## Usage

## Resources

- [TorchAO Quantization API]()
- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)
4 changes: 2 additions & 2 deletions src/diffusers/__init__.py
@@ -31,7 +31,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig"],
"quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
@@ -551,7 +551,7 @@

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
from .quantizers.quantization_config import BitsAndBytesConfig
from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig

try:
if not is_onnx_available():
9 changes: 3 additions & 6 deletions src/diffusers/models/model_loading_utils.py
@@ -25,7 +25,6 @@
import torch
from huggingface_hub.utils import EntryNotFoundError

from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
@@ -176,11 +175,9 @@ def load_model_dict_into_meta(
hf_quantizer=None,
keep_in_fp32_modules=None,
) -> List[str]:
if hf_quantizer is None:
device = device or torch.device("cpu")
device = device or torch.device("cpu")
dtype = dtype or torch.float32
is_quantized = hf_quantizer is not None
is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES

accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
empty_state_dict = model.state_dict()
@@ -213,12 +210,12 @@ def load_model_dict_into_meta(
# bnb params are flattened.
if empty_state_dict[param_name].shape != param.shape:
if (
is_quant_method_bnb
is_quantized
and hf_quantizer.pre_quantized
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
elif not is_quant_method_bnb:
else:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
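
The hunk above generalizes the shape-mismatch branch from a bitsandbytes-only check to any quantizer. The sketch below shows that control flow in isolation, under stated assumptions: `StubQuantizer` and `resolve_shape_mismatch` are hypothetical names, and the stub's `check_if_quantized_param` takes fewer arguments than the real method shown in the diff.

```python
from typing import Optional, Tuple


class StubQuantizer:
    """Hypothetical stand-in exposing only what the branch reads."""

    def __init__(self, pre_quantized: bool, handles_param: bool):
        self.pre_quantized = pre_quantized
        self.handles_param = handles_param

    def check_if_quantized_param(self, param_name: str) -> bool:
        # Simplified signature; the real method also receives the model and state dict.
        return self.handles_param

    def check_quantized_param_shape(self, param_name, expected_shape, loaded_shape) -> bool:
        # A real backend would validate its own packed layout here.
        return True


def resolve_shape_mismatch(
    param_name: str,
    expected_shape: Tuple[int, ...],
    loaded_shape: Tuple[int, ...],
    hf_quantizer: Optional[StubQuantizer] = None,
) -> str:
    """Mirror the branch: defer to the quantizer if it owns the param, else raise."""
    if expected_shape == loaded_shape:
        return "match"
    is_quantized = hf_quantizer is not None
    if is_quantized and hf_quantizer.pre_quantized and hf_quantizer.check_if_quantized_param(param_name):
        hf_quantizer.check_quantized_param_shape(param_name, expected_shape, loaded_shape)
        return "deferred to quantizer"
    raise ValueError(f"{param_name} expected shape {expected_shape}, but got {loaded_shape}")
```

One consequence of replacing `elif not is_quant_method_bnb` with a plain `else` is that a quantized model whose parameter is not recognized by its quantizer now raises instead of silently passing through.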
5 changes: 1 addition & 4 deletions src/diffusers/models/modeling_utils.py
@@ -829,13 +829,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if device_map is None and not is_sharded:
# `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
# It would error out during the `validate_environment()` call above in the absence of cuda.
is_quant_method_bnb = (
getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
)
if hf_quantizer is None:
param_device = "cpu"
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
elif is_quant_method_bnb:
else:
param_device = torch.cuda.current_device()
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
5 changes: 4 additions & 1 deletion src/diffusers/quantizers/auto.py
@@ -19,17 +19,20 @@
from typing import Dict, Optional, Union

from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod
from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig
from .torchao import TorchAoHfQuantizer


AUTO_QUANTIZER_MAPPING = {
"bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
"bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
"torchao": TorchAoHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
"bitsandbytes_4bit": BitsAndBytesConfig,
"bitsandbytes_8bit": BitsAndBytesConfig,
"torchao": TorchAoConfig,
}
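
The two mappings above act as a registry keyed by quantization method. A minimal sketch of how such a lookup might dispatch, using hypothetical stub classes in place of the real config classes; the function name `config_class_for` and the error wording are assumptions, not the code in `auto.py`.

```python
from typing import Dict


class BitsAndBytesConfigStub:
    """Hypothetical placeholder for BitsAndBytesConfig."""


class TorchAoConfigStub:
    """Hypothetical placeholder for TorchAoConfig."""


# Same keys as AUTO_QUANTIZATION_CONFIG_MAPPING in the hunk above.
CONFIG_MAPPING: Dict[str, type] = {
    "bitsandbytes_4bit": BitsAndBytesConfigStub,
    "bitsandbytes_8bit": BitsAndBytesConfigStub,
    "torchao": TorchAoConfigStub,
}


def config_class_for(quant_method: str) -> type:
    """Look up the config class registered for a quantization method key."""
    try:
        return CONFIG_MAPPING[quant_method]
    except KeyError:
        supported = ", ".join(sorted(CONFIG_MAPPING))
        raise ValueError(
            f"Unknown quantization method {quant_method!r}; supported: {supported}"
        ) from None
```

With this layout, adding a backend is a matter of registering one key in each mapping, which is what the `"torchao"` entries in the diff do.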

