Commit
Merge branch 'main' into ds-support-sd3-lora
sayakpaul authored Dec 25, 2024
2 parents ed6e374 + cd991d1 commit 5e760cd
Showing 5 changed files with 350 additions and 125 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/nightly_tests.yml
@@ -359,6 +359,8 @@ jobs:
test_location: "bnb"
- backend: "gguf"
test_location: "gguf"
- backend: "torchao"
test_location: "torchao"
runs-on:
group: aws-g6e-xlarge-plus
container:
62 changes: 62 additions & 0 deletions docs/source/en/quantization/torchao.md
@@ -25,6 +25,7 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]
The example below only quantizes the weights to int8.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/FLUX.1-dev"
@@ -44,6 +45,10 @@ pipe = FluxPipeline.from_pretrained(
)
pipe.to("cuda")

# Without quantization: ~31.447 GB
# With quantization: ~20.40 GB
print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
@@ -88,6 +93,63 @@ Some quantization methods are aliases (for example, `int8wo` is the commonly use

Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options.
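
As a quick sketch of the alias note above (assuming, as that sentence suggests, that `int8wo` and the longer torchao name `int8_weight_only` select the same weight-only int8 scheme), the two spellings can be passed interchangeably:

```python
from diffusers import TorchAoConfig

# Both configs are assumed to describe the same weight-only int8 quantization;
# `quant_type` simply stores the string that was passed in.
short_form = TorchAoConfig("int8wo")
long_form = TorchAoConfig("int8_weight_only")
print(short_form.quant_type, long_form.quant_type)
```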

## Serializing and Deserializing quantized models

To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
```

To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```

Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when loading the models, even though saving them works as expected. To work around this, load the state dict manually into the model. Note, however, that this requires passing `weights_only=False` to `torch.load`, so it should only be done if the weights were obtained from a trusted source.

```python
import torch
from accelerate import init_empty_weights
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

# Serialize the model
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=TorchAoConfig("uint4wo"),
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
# ...

# Load the model
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
with init_empty_weights():
    transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
transformer.load_state_dict(state_dict, strict=True, assign=True)
```
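
A possible continuation of the workaround above, not part of the committed docs: once the state dict has been assigned, the transformer can be placed into a pipeline as usual. This assumes the `uint4wo` tensor subclasses support being moved to CUDA.

```python
# `torch`, `FluxPipeline`, and `transformer` come from the snippet above.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = pipe("A cat holding a sign that says hello world", num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```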

## Resources

- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
8 changes: 4 additions & 4 deletions src/diffusers/models/modeling_utils.py
@@ -718,10 +718,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
            hf_quantizer = None

        if hf_quantizer is not None:
-           is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
-           if is_bnb_quantization_method and device_map is not None:
+           if device_map is not None:
                raise NotImplementedError(
-                   "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
+                   "Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
                )

            hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
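
A small illustrative sketch of the user-facing effect of this hunk (model id and backend chosen only for illustration): with the broadened check, any quantized load that also passes a `device_map` is expected to raise `NotImplementedError`, not just bitsandbytes loads.

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

try:
    FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=TorchAoConfig("int8wo"),
        torch_dtype=torch.bfloat16,
        device_map="auto",  # now rejected for every quantized load, per the check above
    )
except NotImplementedError as err:
    print(err)
```
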
@@ -820,7 +819,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                    revision=revision,
                    subfolder=subfolder or "",
                )
-               if hf_quantizer is not None and is_bnb_quantization_method:
+               # TODO: https://github.com/huggingface/diffusers/issues/10013
+               if hf_quantizer is not None:
                    model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
                    logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
                    is_sharded = False
2 changes: 1 addition & 1 deletion src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -132,7 +132,7 @@ def validate_environment(self, *args, **kwargs):
    def update_torch_dtype(self, torch_dtype):
        quant_type = self.quantization_config.quant_type

-       if quant_type.startswith("int"):
+       if quant_type.startswith("int") or quant_type.startswith("uint"):
            if torch_dtype is not None and torch_dtype != torch.bfloat16:
                logger.warning(
                    f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
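
A short sketch of what the widened `startswith` check covers (values illustrative): quant types beginning with `uint`, such as `uint4wo`, now follow the same dtype handling as `int` types, so a non-bfloat16 `torch_dtype` should trigger the warning shown above.

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Expected to log the warning above: "uint4wo" now matches the startswith("uint") branch
# and torch.float16 != torch.bfloat16. Whether the dtype is then coerced is handled further
# down in update_torch_dtype, outside the lines shown in this hunk.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=TorchAoConfig("uint4wo"),
    torch_dtype=torch.float16,
)
```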