Add support for sharded models when TorchAO quantization is enabled #10256

Merged
10 commits merged on Dec 20, 2024
2 changes: 1 addition & 1 deletion src/diffusers/models/modeling_utils.py
@@ -802,7 +802,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=revision,
subfolder=subfolder or "",
)
if hf_quantizer is not None:
if hf_quantizer is not None and is_bnb_quantization_method:
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
is_sharded = False
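For context, a simplified sketch of what the merge branch above does conceptually: when the quantizer is bitsandbytes, all shards are collapsed into a single checkpoint before loading, while with this change torchao-quantized models skip the merge and keep their shards. The helper below is illustrative only, not the actual `_merge_sharded_checkpoints` implementation.

```python
import glob
import os

from safetensors.torch import load_file, save_file


def merge_shards(shard_folder: str, out_file: str) -> None:
    # Collapse every *.safetensors shard in a folder into one state dict
    # (simplified sketch of what "merging sharded checkpoints" means here).
    merged = {}
    for shard in sorted(glob.glob(os.path.join(shard_folder, "*.safetensors"))):
        merged.update(load_file(shard))  # shards carry disjoint parameter keys
    save_file(merged, out_file)
```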
10 changes: 9 additions & 1 deletion src/diffusers/pipelines/pipeline_utils.py
Expand Up @@ -45,6 +45,7 @@
from ..models.attention_processor import FusedAttnProcessor2_0
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
from ..quantizers.bitsandbytes.utils import _check_bnb_status
from ..quantizers.torchao.utils import _check_torchao_status
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
CONFIG_NAME,
@@ -388,6 +389,7 @@ def to(self, *args, **kwargs):

device = device or device_arg
pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
pipeline_has_torchao = any(_check_torchao_status(module) for _, module in self.components.items())
Collaborator:
do you want to skip the error message for now?

The issue of not being able to use `pipe.to("cuda")` when the pipe has a device-mapped module has nothing to do with torchAO. If we don't throw an error here, and we then make sure not to move the device-mapped module later on (here), it would work:

if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):

Some refactoring is needed, though, and I don't think it needs to be done in this PR.
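A hypothetical sketch of the refactor described here (the helper name and skip condition are assumptions, not part of this PR): skip device-mapped modules instead of raising when the pipeline is moved.

```python
import torch


def move_components_skipping_device_mapped(pipe, device) -> None:
    # Move every pipeline component to `device`, but leave modules that were
    # loaded with their own device map where accelerate placed them.
    for _, module in pipe.components.items():
        if not isinstance(module, torch.nn.Module):
            continue
        if getattr(module, "hf_device_map", None) is not None and len(module.hf_device_map) > 1:
            continue  # device-mapped module: do not move it
        module.to(device)
```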


# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
def module_is_sequentially_offloaded(module):
@@ -411,7 +413,7 @@ def module_is_offloaded(module):
module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
if device and torch.device(device).type == "cuda":
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
if pipeline_is_sequentially_offloaded and not (pipeline_has_bnb or pipeline_has_torchao):
raise ValueError(
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
)
@@ -420,6 +422,12 @@ def module_is_offloaded(module):
raise ValueError(
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
)
elif pipeline_has_torchao:
Collaborator:

Suggested change:
elif pipeline_has_torchao:
elif pipeline_is_sequentially_offloaded and pipeline_has_torchao:

raise ValueError(
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `torchao`. This is not supported. There are two options on what could be done to fix this error:\n"
"1. Move the individual components of the model to the desired device directly using `.to()` on each.\n"
'2. Pass `device_map="balanced"` when initializing the pipeline to let `accelerate` handle the device placement.'
)

is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
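For users who hit the new error above, a sketch of the two suggested workarounds; the model id and quantization settings below are illustrative assumptions, not part of the diff.

```python
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/FLUX.1-dev"  # illustrative checkpoint
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8_weight_only"),
    torch_dtype=torch.bfloat16,
)

# Option 1: move the individual components directly instead of calling pipe.to("cuda").
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
for component in (pipe.text_encoder, pipe.text_encoder_2, pipe.vae, pipe.transformer):
    component.to("cuda")

# Option 2 (mechanism only): pass device_map="balanced" at pipeline init and let
# accelerate decide the placement.
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="balanced")
```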
20 changes: 20 additions & 0 deletions src/diffusers/quantizers/torchao/utils.py
@@ -0,0 +1,20 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..quantization_config import QuantizationMethod


def _check_torchao_status(module) -> bool:
is_loaded_in_torchao = getattr(module, "quantization_method", None) == QuantizationMethod.TORCHAO
return is_loaded_in_torchao
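A small usage sketch of the helper above, assuming (as the implementation implies) that torchao-quantized models expose a `quantization_method` attribute after loading:

```python
import torch

from diffusers import FluxTransformer2DModel, TorchAoConfig
from diffusers.quantizers.torchao.utils import _check_torchao_status

model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8_weight_only"),
    torch_dtype=torch.bfloat16,
)
print(_check_torchao_status(model))  # expected: True for a torchao-quantized model
```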
70 changes: 47 additions & 23 deletions tests/quantization/torchao/test_torchao.py
@@ -278,13 +278,14 @@ def test_int4wo_quant_bfloat16_conversion(self):
self.assertEqual(weight.quant_max, 15)
self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))

def test_offload(self):
def test_device_map(self):
"""
Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies
that the device map is correctly set (in the `hf_device_map` attribute of the model).
Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
correctly set (in the `hf_device_map` attribute of the model).
"""

device_map_offload = {
custom_device_map_dict = {
"time_text_embed": torch_device,
"context_embedder": torch_device,
"x_embedder": torch_device,
@@ -293,27 +294,50 @@ def test_offload(self):
"norm_out": torch_device,
"proj_out": "cpu",
}
device_maps = ["auto", custom_device_map_dict]

inputs = self.get_dummy_tensor_inputs(torch_device)

with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map_offload,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)

self.assertTrue(quantized_model.hf_device_map == device_map_offload)

output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()

expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])

for device_map in device_maps:
device_map_to_compare = {"": 0} if device_map == "auto" else device_map

# Test non-sharded model
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)

self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)

output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

# Test sharded model
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)

self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)

output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()

self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))

def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
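End to end, the new sharded-model test corresponds roughly to the following user flow (a sketch mirroring the test above, not code from the diff):

```python
import torch

from diffusers import FluxTransformer2DModel, TorchAoConfig

quantized_model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-sharded",
    subfolder="transformer",
    quantization_config=TorchAoConfig("int4_weight_only", group_size=64),
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print(quantized_model.hf_device_map)  # {"": 0} when everything fits on one GPU
```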