Add support for sharded models when TorchAO quantization is enabled #10256

Merged: 10 commits, Dec 20, 2024
src/diffusers/models/modeling_utils.py (6 additions, 0 deletions)

@@ -803,6 +803,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
    subfolder=subfolder or "",
)
if hf_quantizer is not None:
    is_torchao_quantization_method = quantization_config.quant_method == QuantizationMethod.TORCHAO
@yiyixuxu (Collaborator) commented on Dec 17, 2024:
Can we consolidate this with the bnb check (remove the bnb check and extend this check to cover any quantization method)?

is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"

This should not be specific to any quantization method, no? I ran the test below: for a non-sharded checkpoint, both work; for a sharded checkpoint, both throw the same error.

from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig, BitsAndBytesConfig
import torch

sharded_model_id = "black-forest-labs/Flux.1-Dev"
single_model_path = "/raid/yiyi/flux_model_single"
dtype = torch.bfloat16

# create a non-sharded checkpoint
# transformer = FluxTransformer2DModel.from_pretrained(
#     sharded_model_id,
#     subfolder="transformer",
#     torch_dtype=dtype,
# )
# transformer.save_pretrained(single_model_path, max_shard_size="100GB")

torch_ao_quantization_config = TorchAoConfig("int8wo")
bnb_quantization_config = BitsAndBytesConfig(load_in_8bit=True)

print(f" testing non-sharded checkpoint")
transformer = FluxTransformer2DModel.from_pretrained(
    single_model_path,
    quantization_config=torch_ao_quantization_config,
    device_map="auto",
    torch_dtype=dtype,
)

print(f"torchao hf_device_map: {transformer.hf_device_map}")

transformer = FluxTransformer2DModel.from_pretrained(
    single_model_path, 
    quantization_config=bnb_quantization_config,
    device_map="auto",
    torch_dtype=dtype,
)
print(f"bnb hf_device_map: {transformer.hf_device_map}")


print(f" testing sharded checkpoint")
## sharded checkpoint
try:
    transformer = FluxTransformer2DModel.from_pretrained(
        sharded_model_id, 
        subfolder="transformer",
        quantization_config=torch_ao_quantization_config,
        device_map="auto",
        torch_dtype=dtype,
    )
    print(f"torchao: {transformer.hf_device_map}")
except Exception as e:
    print(f"error: {e}")

try:
    transformer = FluxTransformer2DModel.from_pretrained(
        sharded_model_id,
        subfolder="transformer",
        quantization_config=bnb_quantization_config,
        device_map="auto",
        torch_dtype=dtype,
)
    print(f"bnb hf_device_map: {transformer.hf_device_map}")
except Exception as e:
    print(f"error: {e}")

@a-r-r-o-w (Member, Author) commented on Dec 17, 2024:
I don't think non-sharded works for both. A non-sharded checkpoint only seems to work with torchao at the moment. These are my results:

| method / checkpoint | sharded | non-sharded |
| --- | --- | --- |
| torchao | fails | works |
| bnb | fails | fails |

I tried with your code as well and got the following error when using BnB with an unsharded checkpoint on this branch:

NotImplementedError: Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future.

Whatever the automatic inference of `device_map` does, we are still unable to pass `device_map` manually whether the state dict is sharded or unsharded, so I would put it in the same bucket as failing.

Consolidating the checks sounds good. Will update.
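
For reference, a minimal sketch of what the consolidated check might look like (this is an assumption about the eventual shape, not the code merged in this PR; the diff continues below with the torchao-specific check as originally written):

# Hypothetical consolidation: gate on any quantizer instead of torchao only.
# `hf_quantizer`, `device_map`, `is_sharded`, `sharded_ckpt_cached_folder`,
# `sharded_metadata`, `_merge_sharded_checkpoints`, and `logger` all come from the
# surrounding `from_pretrained` code in modeling_utils.py.
if hf_quantizer is not None:
    if device_map is not None:
        raise NotImplementedError(
            "Loading sharded checkpoints, while passing `device_map`, is not supported with quantization. "
            "This will be supported in the near future."
        )
    model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
    logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
    is_sharded = False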

    if device_map is not None and is_torchao_quantization_method:
        raise NotImplementedError(
            "Loading sharded checkpoints, while passing `device_map`, is not supported with `torchao` quantization. This will be supported in the near future."
        )

    model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
    logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
    is_sharded = False
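
With this change, a sharded checkpoint can be loaded with TorchAO quantization as long as `device_map` is not passed. A minimal sketch, reusing the repo and quantization config from the test script above:

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Sharded repo from the discussion above; `device_map` is intentionally omitted,
# since passing it together with torchao quantization still raises NotImplementedError.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/Flux.1-Dev",
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8wo"),
    torch_dtype=torch.bfloat16,
)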