From e54c54096afec9c3b64d2d8e791edc0fc510cdd7 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 17 Jan 2025 22:36:29 +0100 Subject: [PATCH 01/15] first draft model loading refactor --- scripts/convert_sana_to_diffusers.py | 4 +- scripts/convert_sd3_to_diffusers.py | 6 +- scripts/convert_stable_audio.py | 8 +- scripts/convert_stable_cascade.py | 6 +- scripts/convert_stable_cascade_lite.py | 6 +- src/diffusers/loaders/single_file_model.py | 84 +- src/diffusers/loaders/single_file_utils.py | 26 +- src/diffusers/loaders/transformer_flux.py | 6 +- src/diffusers/loaders/transformer_sd3.py | 8 +- src/diffusers/loaders/unet.py | 6 +- src/diffusers/models/model_loading_utils.py | 204 ++--- src/diffusers/models/modeling_utils.py | 789 ++++++++++-------- .../quantizers/bitsandbytes/bnb_quantizer.py | 45 +- src/diffusers/utils/hub_utils.py | 43 +- tests/models/test_modeling_common.py | 6 +- .../test_models_transformer_sana.py | 25 - 16 files changed, 625 insertions(+), 647 deletions(-) diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py index 99a9ff322251..d96b8bce38c1 100644 --- a/scripts/convert_sana_to_diffusers.py +++ b/scripts/convert_sana_to_diffusers.py @@ -18,7 +18,7 @@ SanaPipeline, SanaTransformer2DModel, ) -from diffusers.models.modeling_utils import load_model_dict_into_meta +from diffusers.models.modeling_utils import load_state_dict_into_meta_model from diffusers.utils.import_utils import is_accelerate_available @@ -189,7 +189,7 @@ def main(args): ) if is_accelerate_available(): - load_model_dict_into_meta(transformer, converted_state_dict) + load_state_dict_into_meta_model(transformer, converted_state_dict) else: transformer.load_state_dict(converted_state_dict, strict=True, assign=True) diff --git a/scripts/convert_sd3_to_diffusers.py b/scripts/convert_sd3_to_diffusers.py index 0a3569efeab0..67bc62a04431 100644 --- a/scripts/convert_sd3_to_diffusers.py +++ b/scripts/convert_sd3_to_diffusers.py @@ -7,7 +7,7 @@ from diffusers import AutoencoderKL, SD3Transformer2DModel from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint -from diffusers.models.modeling_utils import load_model_dict_into_meta +from diffusers.models.modeling_utils import load_state_dict_into_meta_model from diffusers.utils.import_utils import is_accelerate_available @@ -319,7 +319,7 @@ def main(args): dual_attention_layers=attn2_layers, ) if is_accelerate_available(): - load_model_dict_into_meta(transformer, converted_transformer_state_dict) + load_state_dict_into_meta_model(transformer, converted_transformer_state_dict) else: transformer.load_state_dict(converted_transformer_state_dict, strict=True) @@ -339,7 +339,7 @@ def main(args): ) converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config) if is_accelerate_available(): - load_model_dict_into_meta(vae, converted_vae_state_dict) + load_state_dict_into_meta_model(vae, converted_vae_state_dict) else: vae.load_state_dict(converted_vae_state_dict, strict=True) diff --git a/scripts/convert_stable_audio.py b/scripts/convert_stable_audio.py index a0f9d0f87d90..959aa125e9cb 100644 --- a/scripts/convert_stable_audio.py +++ b/scripts/convert_stable_audio.py @@ -18,7 +18,7 @@ StableAudioPipeline, StableAudioProjectionModel, ) -from diffusers.models.modeling_utils import load_model_dict_into_meta +from diffusers.models.modeling_utils import load_state_dict_into_meta_model from diffusers.utils import is_accelerate_available @@ -221,7 +221,7 @@ def 
convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay ], # assume `seconds_start` and `seconds_total` have the same min / max values. ) if is_accelerate_available(): - load_model_dict_into_meta(projection_model, projection_model_state_dict) + load_state_dict_into_meta_model(projection_model, projection_model_state_dict) else: projection_model.load_state_dict(projection_model_state_dict) @@ -242,7 +242,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay cross_attention_input_dim=model_config["cond_token_dim"], ) if is_accelerate_available(): - load_model_dict_into_meta(model, model_state_dict) + load_state_dict_into_meta_model(model, model_state_dict) else: model.load_state_dict(model_state_dict) @@ -260,7 +260,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay ) if is_accelerate_available(): - load_model_dict_into_meta(autoencoder, autoencoder_state_dict) + load_state_dict_into_meta_model(autoencoder, autoencoder_state_dict) else: autoencoder.load_state_dict(autoencoder_state_dict) diff --git a/scripts/convert_stable_cascade.py b/scripts/convert_stable_cascade.py index ce10970b0b6a..59a4b4e2280f 100644 --- a/scripts/convert_stable_cascade.py +++ b/scripts/convert_stable_cascade.py @@ -20,7 +20,7 @@ ) from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers from diffusers.models import StableCascadeUNet -from diffusers.models.modeling_utils import load_model_dict_into_meta +from diffusers.models.modeling_utils import load_state_dict_into_meta_model from diffusers.pipelines.wuerstchen import PaellaVQModel from diffusers.utils import is_accelerate_available @@ -126,7 +126,7 @@ switch_level=[False], ) if is_accelerate_available(): - load_model_dict_into_meta(prior_model, prior_state_dict) + load_state_dict_into_meta_model(prior_model, prior_state_dict) else: prior_model.load_state_dict(prior_state_dict) @@ -181,7 +181,7 @@ ) if is_accelerate_available(): - load_model_dict_into_meta(decoder, decoder_state_dict) + load_state_dict_into_meta_model(decoder, decoder_state_dict) else: decoder.load_state_dict(decoder_state_dict) diff --git a/scripts/convert_stable_cascade_lite.py b/scripts/convert_stable_cascade_lite.py index ddccaa3b2e8a..8f57bec97361 100644 --- a/scripts/convert_stable_cascade_lite.py +++ b/scripts/convert_stable_cascade_lite.py @@ -20,7 +20,7 @@ ) from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers from diffusers.models import StableCascadeUNet -from diffusers.models.modeling_utils import load_model_dict_into_meta +from diffusers.models.modeling_utils import load_state_dict_into_meta_model from diffusers.pipelines.wuerstchen import PaellaVQModel from diffusers.utils import is_accelerate_available @@ -133,7 +133,7 @@ ) if is_accelerate_available(): - load_model_dict_into_meta(prior_model, prior_state_dict) + load_state_dict_into_meta_model(prior_model, prior_state_dict) else: prior_model.load_state_dict(prior_state_dict) @@ -189,7 +189,7 @@ ) if is_accelerate_available(): - load_model_dict_into_meta(decoder, decoder_state_dict) + load_state_dict_into_meta_model(decoder, decoder_state_dict) else: decoder.load_state_dict(decoder_state_dict) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index c7d0fcb3046e..c8b81330d3fc 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -13,15 +13,11 @@ # 
limitations under the License. import importlib import inspect -import re -from contextlib import nullcontext from typing import Optional -import torch from huggingface_hub.utils import validate_hf_hub_args -from ..quantizers import DiffusersAutoQuantizer -from ..utils import deprecate, is_accelerate_available, logging +from ..utils import deprecate, logging from .single_file_utils import ( SingleFileComponentError, convert_animatediff_checkpoint_to_diffusers, @@ -49,12 +45,6 @@ logger = logging.get_logger(__name__) -if is_accelerate_available(): - from accelerate import init_empty_weights - - from ..models.modeling_utils import load_model_dict_into_meta - - SINGLE_FILE_LOADABLE_CLASSES = { "StableCascadeUNet": { "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers, @@ -234,9 +224,6 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = subfolder = kwargs.pop("subfolder", None) revision = kwargs.pop("revision", None) config_revision = kwargs.pop("config_revision", None) - torch_dtype = kwargs.pop("torch_dtype", None) - quantization_config = kwargs.pop("quantization_config", None) - device = kwargs.pop("device", None) disable_mmap = kwargs.pop("disable_mmap", False) if isinstance(pretrained_model_link_or_path_or_dict, dict): @@ -252,12 +239,6 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = revision=revision, disable_mmap=disable_mmap, ) - if quantization_config is not None: - hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) - hf_quantizer.validate_environment() - - else: - hf_quantizer = None mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name] @@ -336,62 +317,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint." 
) - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - model = cls.from_config(diffusers_model_config) - - # Check if `_keep_in_fp32_modules` is not None - use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( - (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") + return cls.from_pretrained( + pretrained_model_name_or_path=None, + state_dict=diffusers_format_checkpoint, + config=diffusers_model_config, + **kwargs, ) - if use_keep_in_fp32_modules: - keep_in_fp32_modules = cls._keep_in_fp32_modules - if not isinstance(keep_in_fp32_modules, list): - keep_in_fp32_modules = [keep_in_fp32_modules] - - else: - keep_in_fp32_modules = [] - - if hf_quantizer is not None: - hf_quantizer.preprocess_model( - model=model, - device_map=None, - state_dict=diffusers_format_checkpoint, - keep_in_fp32_modules=keep_in_fp32_modules, - ) - - if is_accelerate_available(): - param_device = torch.device(device) if device else torch.device("cpu") - named_buffers = model.named_buffers() - unexpected_keys = load_model_dict_into_meta( - model, - diffusers_format_checkpoint, - dtype=torch_dtype, - device=param_device, - hf_quantizer=hf_quantizer, - keep_in_fp32_modules=keep_in_fp32_modules, - named_buffers=named_buffers, - ) - - else: - _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) - - if model._keys_to_ignore_on_load_unexpected is not None: - for pat in model._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" - ) - - if hf_quantizer is not None: - hf_quantizer.postprocess_model(model) - model.hf_quantizer = hf_quantizer - - if torch_dtype is not None and hf_quantizer is None: - model.to(torch_dtype) - - model.eval() - - return model diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 1f52efbcc1f7..d5b1ba8097b6 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -53,7 +53,7 @@ if is_accelerate_available(): from accelerate import init_empty_weights - from ..models.modeling_utils import load_model_dict_into_meta + from ..models.modeling_utils import load_state_dict_into_meta_model logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -1588,18 +1588,9 @@ def create_diffusers_clip_model_from_ldm( raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.") if is_accelerate_available(): - unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + load_state_dict_into_meta_model(model, diffusers_format_checkpoint, dtype=torch_dtype) else: - _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) - - if model._keys_to_ignore_on_load_unexpected is not None: - for pat in model._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" - ) + model.load_state_dict(diffusers_format_checkpoint, strict=False) if torch_dtype is not None: model.to(torch_dtype) @@ -2056,16 +2047,7 @@ def 
create_diffusers_t5_model_from_checkpoint( diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint) if is_accelerate_available(): - unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) - if model._keys_to_ignore_on_load_unexpected is not None: - for pat in model._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" - ) - + load_state_dict_into_meta_model(model, diffusers_format_checkpoint, dtype=torch_dtype) else: model.load_state_dict(diffusers_format_checkpoint) diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py index 9fe712bb12e9..bef1b0398748 100644 --- a/src/diffusers/loaders/transformer_flux.py +++ b/src/diffusers/loaders/transformer_flux.py @@ -17,7 +17,7 @@ ImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import load_model_dict_into_meta +from ..models.modeling_utils import load_state_dict_into_meta_model from ..utils import ( is_accelerate_available, is_torch_version, @@ -82,7 +82,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us if not low_cpu_mem_usage: image_projection.load_state_dict(updated_state_dict, strict=True) else: - load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) + load_state_dict_into_meta_model(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) return image_projection @@ -153,7 +153,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F else: device = self.device dtype = self.dtype - load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) + load_state_dict_into_meta_model(attn_procs[name], value_dict, device=device, dtype=dtype) key_id += 1 diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py index 435d1da06ca1..5c57d04ef2dc 100644 --- a/src/diffusers/loaders/transformer_sd3.py +++ b/src/diffusers/loaders/transformer_sd3.py @@ -15,7 +15,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 from ..models.embeddings import IPAdapterTimeImageProjection -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict_into_meta_model class SD3Transformer2DLoadersMixin: @@ -59,7 +59,7 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _ if not low_cpu_mem_usage: attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True) else: - load_model_dict_into_meta( + load_state_dict_into_meta_model( attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype ) @@ -86,4 +86,6 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _ if not low_cpu_mem_usage: self.image_proj.load_state_dict(state_dict["image_proj"], strict=True) else: - load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype) + load_state_dict_into_meta_model( + self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype + ) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index c68349c36dba..a8346925af7b 100644 --- 
a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -30,7 +30,7 @@ IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict +from ..models.modeling_utils import load_state_dict, load_state_dict_into_meta_model from ..utils import ( USE_PEFT_BACKEND, _get_model_file, @@ -753,7 +753,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us if not low_cpu_mem_usage: image_projection.load_state_dict(updated_state_dict, strict=True) else: - load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) + load_state_dict_into_meta_model(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) return image_projection @@ -846,7 +846,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F else: device = next(iter(value_dict.values())).device dtype = next(iter(value_dict.values())).dtype - load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) + load_state_dict_into_meta_model(attn_procs[name], value_dict, device=device, dtype=dtype) key_id += 2 diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 0acf50b82356..0f7110a9a018 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -20,7 +20,8 @@ from array import array from collections import OrderedDict from pathlib import Path -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union +from zipfile import is_zipfile import safetensors import torch @@ -55,7 +56,7 @@ if is_accelerate_available(): from accelerate import infer_auto_device_map - from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device + from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device # Adapted from `transformers` (see modeling_utils.py) @@ -134,15 +135,14 @@ def _fetch_remapped_cls_from_config(config, old_class): def load_state_dict( checkpoint_file: Union[str, os.PathLike], - variant: Optional[str] = None, dduf_entries: Optional[Dict[str, DDUFEntry]] = None, disable_mmap: bool = False, + map_location: Optional[Union[str, torch.device]] = None, ): """ Reads a checkpoint file, returning properly formatted errors if they arise. """ - # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change - # when refactoring the _merge_sharded_checkpoints() method later. + # TODO: maybe refactor a bit this part where we pass a dict here if isinstance(checkpoint_file, dict): return checkpoint_file try: @@ -152,6 +152,14 @@ def load_state_dict( # tensors are loaded on cpu with dduf_entries[checkpoint_file].as_mmap() as mm: return safetensors.torch.load(mm) + # Check format of the archive + with safetensors.safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + if metadata is not None and metadata.get("format") not in ["pt", "flax"]: + raise OSError( + f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " + "you save your model with the `save_pretrained` method." 
+ ) if disable_mmap: return safetensors.torch.load(open(checkpoint_file, "rb").read()) else: @@ -159,12 +167,20 @@ def load_state_dict( elif file_extension == GGUF_FILE_EXTENSION: return load_gguf_checkpoint(checkpoint_file) else: + if map_location is None: + map_location = "cpu" + extra_args = {} weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {} - return torch.load( - checkpoint_file, - map_location="cpu", - **weights_only_kwarg, - ) + # mmap can only be used with files serialized with zipfile-based format. + if ( + isinstance(checkpoint_file, str) + and map_location != "meta" + and is_torch_version(">=", "2.1.0") + and is_zipfile(checkpoint_file) + and not disable_mmap + ): + extra_args = {"mmap": True} + return torch.load(checkpoint_file, map_location="cpu", **weights_only_kwarg, **extra_args) except Exception as e: try: with open(checkpoint_file) as f: @@ -185,26 +201,28 @@ def load_state_dict( ) -def load_model_dict_into_meta( +def load_state_dict_into_meta_model( model, state_dict: OrderedDict, - device: Optional[Union[str, torch.device]] = None, dtype: Optional[Union[str, torch.dtype]] = None, model_name_or_path: Optional[str] = None, hf_quantizer=None, keep_in_fp32_modules=None, - named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None, + device_map=None, + unexpected_keys=None, + offload_folder=None, + offload_index=None, + state_dict_index=None, + state_dict_folder=None, ) -> List[str]: - if device is not None and not isinstance(device, (str, torch.device)): - raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.") - if hf_quantizer is None: - device = device or torch.device("cpu") - dtype = dtype or torch.float32 - is_quantized = hf_quantizer is not None + """ + This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its + params on a `meta` device. It replaces the model params with the data from the `state_dict` + """ - accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) + error_msgs = [] + is_quantized = hf_quantizer is not None empty_state_dict = model.state_dict() - unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict] for param_name, param in state_dict.items(): if param_name not in empty_state_dict: @@ -214,7 +232,7 @@ def load_model_dict_into_meta( # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params # in int/uint/bool and not cast them. # TODO: revisit cases when param.dtype == torch.float8_e4m3fn - if torch.is_floating_point(param): + if dtype is not None and torch.is_floating_point(param): if ( keep_in_fp32_modules is not None and any( @@ -223,12 +241,41 @@ def load_model_dict_into_meta( and dtype == torch.float16 ): param = param.to(torch.float32) - if accepts_dtype: - set_module_kwargs["dtype"] = torch.float32 + set_module_kwargs["dtype"] = torch.float32 else: param = param.to(dtype) - if accepts_dtype: - set_module_kwargs["dtype"] = dtype + set_module_kwargs["dtype"] = dtype + + # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which + # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. 
+ # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 + old_param = model + splits = param_name.split(".") + for split in splits: + old_param = getattr(old_param, split) + + if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)): + old_param = None + + if old_param is not None: + if dtype is None: + param = param.to(old_param.dtype) + + if old_param.is_contiguous(): + param = param.contiguous() + + module_name = param_name + + if device_map is None: + param_device = "cpu" + else: + # find next higher level module that is defined in device_map: + # bert.lm_head.weight -> bert.lm_head -> bert -> '' + while len(module_name) > 0 and module_name not in device_map: + module_name = ".".join(module_name.split(".")[:-1]) + if module_name == "" and "" not in device_map: + raise ValueError(f"{param_name} doesn't have any device set.") + param_device = device_map[module_name] # bnb params are flattened. # gguf quants have a different shape based on the type of quantization applied @@ -236,7 +283,9 @@ def load_model_dict_into_meta( if ( is_quantized and hf_quantizer.pre_quantized - and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) + and hf_quantizer.check_if_quantized_param( + model, param, param_name, state_dict, param_device=param_device + ) ): hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param) else: @@ -244,35 +293,42 @@ def load_model_dict_into_meta( raise ValueError( f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." 
) - - if is_quantized and ( - hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) + if param_device == "disk": + offload_index = offload_weight(param, param_name, offload_folder, offload_index) + elif param_device == "cpu" and state_dict_index is not None: + state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index) + elif is_quantized and ( + hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device) ): - hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys) + hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys) else: - if accepts_dtype: - set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs) - else: - set_module_tensor_to_device(model, param_name, device, value=param) + set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs) - if named_buffers is None: - return unexpected_keys + return error_msgs, offload_index, state_dict_index - for param_name, param in named_buffers: - if is_quantized and ( - hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) - ): - hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys) - else: - if accepts_dtype: - set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs) - else: - set_module_tensor_to_device(model, param_name, device, value=param) - return unexpected_keys +def load_model_dict_into_meta( + model, + state_dict: OrderedDict, + dtype: Optional[Union[str, torch.dtype]] = None, + model_name_or_path: Optional[str] = None, + hf_quantizer=None, + keep_in_fp32_modules=None, + device_map=None, + unexpected_keys=None, + is_safetensors=None, + offload_folder=None, + offload_index=None, + state_dict_index=None, + state_dict_folder=None, +) -> List[str]: + error_msgs = [] + return error_msgs, offload_index, state_dict_index -def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]: +def _load_state_dict_into_model( + model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False +) -> List[str]: # Convert old format to new format if needed from a PyTorch state_dict # copy state_dict so _load_from_state_dict can modify it state_dict = state_dict.copy() @@ -280,15 +336,17 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants # so we need to apply the function recursively. 
- def load(module: torch.nn.Module, prefix: str = ""): - args = (state_dict, prefix, {}, True, [], [], error_msgs) + def load(module: torch.nn.Module, prefix: str = "", assign_to_params_buffers: bool = False): + local_metadata = {} + local_metadata["assign_to_params_buffers"] = assign_to_params_buffers + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) module._load_from_state_dict(*args) for name, child in module._modules.items(): if child is not None: - load(child, prefix + name + ".") + load(child, prefix + name + ".", assign_to_params_buffers) - load(model_to_load) + load(model_to_load, assign_to_params_buffers=assign_to_params_buffers) return error_msgs @@ -343,46 +401,6 @@ def _fetch_index_file( return index_file -# Adapted from -# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64 -def _merge_sharded_checkpoints( - sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None -): - weight_map = sharded_metadata.get("weight_map", None) - if weight_map is None: - raise KeyError("'weight_map' key not found in the shard index file.") - - # Collect all unique safetensors files from weight_map - files_to_load = set(weight_map.values()) - is_safetensors = all(f.endswith(".safetensors") for f in files_to_load) - merged_state_dict = {} - - # Load tensors from each unique file - for file_name in files_to_load: - part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name) - if dduf_entries: - if part_file_path not in dduf_entries: - raise FileNotFoundError(f"Part file {file_name} not found.") - else: - if not os.path.exists(part_file_path): - raise FileNotFoundError(f"Part file {file_name} not found.") - - if is_safetensors: - if dduf_entries: - with dduf_entries[part_file_path].as_mmap() as mm: - tensors = safetensors.torch.load(mm) - merged_state_dict.update(tensors) - else: - with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f: - for tensor_key in f.keys(): - if tensor_key in weight_map: - merged_state_dict[tensor_key] = f.get_tensor(tensor_key) - else: - merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu")) - - return merged_state_dict - - def _fetch_index_file_legacy( is_local, pretrained_model_name_or_path, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 1c2b9a76dd67..97f023dccb61 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -20,10 +20,13 @@ import json import os import re +import shutil +import tempfile from collections import OrderedDict +from contextlib import ExitStack, contextmanager from functools import partial, wraps from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union import safetensors import torch @@ -61,16 +64,49 @@ _fetch_index_file, _fetch_index_file_legacy, _load_state_dict_into_model, - _merge_sharded_checkpoints, - load_model_dict_into_meta, load_state_dict, + load_state_dict_into_meta_model, ) +class ContextManagers: + """ + Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers` + in the `fastcore` library. 
+ """ + + def __init__(self, context_managers: List[ContextManager]): + self.context_managers = context_managers + self.stack = ExitStack() + + def __enter__(self): + for context_manager in self.context_managers: + self.stack.enter_context(context_manager) + + def __exit__(self, *args, **kwargs): + self.stack.__exit__(*args, **kwargs) + + logger = logging.get_logger(__name__) _REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}") +TORCH_INIT_FUNCTIONS = { + "uniform_": nn.init.uniform_, + "normal_": nn.init.normal_, + "trunc_normal_": nn.init.trunc_normal_, + "constant_": nn.init.constant_, + "xavier_uniform_": nn.init.xavier_uniform_, + "xavier_normal_": nn.init.xavier_normal_, + "kaiming_uniform_": nn.init.kaiming_uniform_, + "kaiming_normal_": nn.init.kaiming_normal_, + "uniform": nn.init.uniform, + "normal": nn.init.normal, + "xavier_uniform": nn.init.xavier_uniform, + "xavier_normal": nn.init.xavier_normal, + "kaiming_uniform": nn.init.kaiming_uniform, + "kaiming_normal": nn.init.kaiming_normal, +} if is_torch_version(">=", "1.9.0"): _LOW_CPU_MEM_USAGE_DEFAULT = True @@ -80,6 +116,8 @@ if is_accelerate_available(): import accelerate + from accelerate import dispatch_model + from accelerate.utils import load_offloaded_weights, save_offload_index def get_parameter_device(parameter: torch.nn.Module) -> torch.device: @@ -134,6 +172,57 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: return last_tuple[1].dtype +def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""): + """ + Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first + checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's + parameters. + + Note: We fully disable this if we are using `deepspeed` + """ + if model_to_load.device.type == "meta": + return False + + if len([key for key in state_dict if key.startswith(start_prefix)]) == 0: + return False + + # Some models explicitly do not support param buffer assignment + if not getattr(model_to_load, "_supports_param_buffer_assignment", True): + logger.debug( + f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower" + ) + return False + + # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype + first_key = next(iter(model_to_load.state_dict().keys())) + if start_prefix + first_key in state_dict: + return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype + + # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`) + return False + + +@contextmanager +def no_init_weights(): + """ + Context manager to globally disable weight initialization to speed up loading large models. To do that, all the + torch.nn.init function are all replaced with skip. + """ + + def _skip_init(*args, **kwargs): + pass + + # # Save the original initialization functions + for name, init_func in TORCH_INIT_FUNCTIONS.items(): + setattr(torch.nn.init, name, _skip_init) + try: + yield + finally: + # # Restore the original initialization functions + for name, init_func in TORCH_INIT_FUNCTIONS.items(): + setattr(torch.nn.init, name, init_func) + + class ModelMixin(torch.nn.Module, PushToHubMixin): r""" Base class for all models. 
@@ -609,6 +698,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P quantization_config = kwargs.pop("quantization_config", None) dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) disable_mmap = kwargs.pop("disable_mmap", False) + state_dict = kwargs.pop("state_dict", None) + config = kwargs.pop("config", None) allow_pickle = False if use_safetensors is None: @@ -679,33 +770,39 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.") - # Load config if we don't provide a configuration - config_path = pretrained_model_name_or_path + if (not config and state_dict) or (config and not state_dict): + raise ValueError("You need to pass both the config and the state dict to initalize the model.") user_agent = { "diffusers": __version__, "file_type": "model", "framework": "pytorch", } + unused_kwargs = {} - # load config - config, unused_kwargs, commit_hash = cls.load_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - return_commit_hash=True, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - dduf_entries=dduf_entries, - **kwargs, - ) - # no in-place modification of the original config. - config = copy.deepcopy(config) + if config is None: + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + # TODO: We need to let the user pass a config in from_pretrained + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + dduf_entries=dduf_entries, + **kwargs, + ) + # no in-place modification of the original config. + config = copy.deepcopy(config) # determine initial quantization config. ####################################### @@ -724,13 +821,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P hf_quantizer = None if hf_quantizer is not None: - if device_map is not None: - raise NotImplementedError( - "Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future." - ) - hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) + device_map = hf_quantizer.update_device_map(device_map) # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value @@ -758,89 +851,106 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.") else: keep_in_fp32_modules = [] - ####################################### - # Determine if we're loading from a directory of sharded checkpoints. 
is_sharded = False - index_file = None - is_local = os.path.isdir(pretrained_model_name_or_path) - index_file_kwargs = { - "is_local": is_local, - "pretrained_model_name_or_path": pretrained_model_name_or_path, - "subfolder": subfolder or "", - "use_safetensors": use_safetensors, - "cache_dir": cache_dir, - "variant": variant, - "force_download": force_download, - "proxies": proxies, - "local_files_only": local_files_only, - "token": token, - "revision": revision, - "user_agent": user_agent, - "commit_hash": commit_hash, - "dduf_entries": dduf_entries, - } - index_file = _fetch_index_file(**index_file_kwargs) - # In case the index file was not found we still have to consider the legacy format. - # this becomes applicable when the variant is not None. - if variant is not None and (index_file is None or not os.path.exists(index_file)): - index_file = _fetch_index_file_legacy(**index_file_kwargs) - if index_file is not None and (dduf_entries or index_file.is_file()): - is_sharded = True - - if is_sharded and from_flax: - raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.") - - # load model - model_file = None - if from_flax: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=FLAX_WEIGHTS_NAME, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - ) - model = cls.from_config(config, **unused_kwargs) - - # Convert the weights - from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model - - model = load_flax_checkpoint_in_pytorch_model(model, model_file) - else: - # in the case it is sharded, we have already the index - if is_sharded: - sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files( + resolved_archive_file = None + if state_dict is None: + # Determine if we're loading from a directory of sharded checkpoints. + sharded_metadata = None + index_file = None + is_local = os.path.isdir(pretrained_model_name_or_path) + index_file_kwargs = { + "is_local": is_local, + "pretrained_model_name_or_path": pretrained_model_name_or_path, + "subfolder": subfolder or "", + "use_safetensors": use_safetensors, + "cache_dir": cache_dir, + "variant": variant, + "force_download": force_download, + "proxies": proxies, + "local_files_only": local_files_only, + "token": token, + "revision": revision, + "user_agent": user_agent, + "commit_hash": commit_hash, + "dduf_entries": dduf_entries, + } + index_file = _fetch_index_file(**index_file_kwargs) + # In case the index file was not found we still have to consider the legacy format. + # this becomes applicable when the variant is not None. 
+ if variant is not None and (index_file is None or not os.path.exists(index_file)): + index_file = _fetch_index_file_legacy(**index_file_kwargs) + if index_file is not None and (dduf_entries or index_file.is_file()): + is_sharded = True + + if is_sharded and from_flax: + raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.") + + # load model + if from_flax: + resolved_archive_file = _get_model_file( pretrained_model_name_or_path, - index_file, + weights_name=FLAX_WEIGHTS_NAME, cache_dir=cache_dir, + force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token, - user_agent=user_agent, revision=revision, - subfolder=subfolder or "", - dduf_entries=dduf_entries, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, ) - # TODO: https://github.com/huggingface/diffusers/issues/10013 - if hf_quantizer is not None or dduf_entries: - model_file = _merge_sharded_checkpoints( - sharded_ckpt_cached_folder, sharded_metadata, dduf_entries=dduf_entries + model = cls.from_config(config, **unused_kwargs) + + # Convert the weights + from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model + + model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file) + else: + # in the case it is sharded, we have already the index + if is_sharded: + resolved_archive_file, sharded_metadata = _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_file, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder or "", + dduf_entries=dduf_entries, ) - logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") - is_sharded = False + elif use_safetensors: + try: + resolved_archive_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + dduf_entries=dduf_entries, + ) - elif use_safetensors and not is_sharded: - try: - model_file = _get_model_file( + except IOError as e: + logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") + if not allow_pickle: + raise + logger.warning( + "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." + ) + + if resolved_archive_file is None: + resolved_archive_file = _get_model_file( pretrained_model_name_or_path, - weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + weights_name=_add_variant(WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, proxies=proxies, @@ -853,183 +963,97 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P dduf_entries=dduf_entries, ) - except IOError as e: - logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") - if not allow_pickle: - raise - logger.warning( - "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." 
- ) + if not isinstance(resolved_archive_file, list): + resolved_archive_file = [resolved_archive_file] - if model_file is None and not is_sharded: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - dduf_entries=dduf_entries, + # set dtype to instantiate the model under: + # 1. If torch_dtype is not None, we use that dtype + dtype_orig = None + if torch_dtype is not None: + if not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." ) + dtype_orig = cls._set_default_torch_dtype(torch_dtype) - if low_cpu_mem_usage: - # Instantiate model with empty weights - with accelerate.init_empty_weights(): - model = cls.from_config(config, **unused_kwargs) + init_contexts = [no_init_weights()] - if hf_quantizer is not None: - hf_quantizer.preprocess_model( - model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules - ) - - # if device_map is None, load the state dict and move the params from meta device to the cpu - if device_map is None and not is_sharded: - # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None. - # It would error out during the `validate_environment()` call above in the absence of cuda. - if hf_quantizer is None: - param_device = "cpu" - # TODO (sayakpaul, SunMarc): remove this after model loading refactor - else: - param_device = torch.device(torch.cuda.current_device()) - state_dict = load_state_dict( - model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap - ) - model._convert_deprecated_attention_blocks(state_dict) - - # move the params from meta device to cpu - missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) - if hf_quantizer is not None: - missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="") - if len(missing_keys) > 0: - raise ValueError( - f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" - f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" - " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" - " those weights or else make sure your checkpoint file is correct." 
- ) + if low_cpu_mem_usage: + init_contexts.append(accelerate.init_empty_weights()) - named_buffers = model.named_buffers() - - unexpected_keys = load_model_dict_into_meta( - model, - state_dict, - device=param_device, - dtype=torch_dtype, - model_name_or_path=pretrained_model_name_or_path, - hf_quantizer=hf_quantizer, - keep_in_fp32_modules=keep_in_fp32_modules, - named_buffers=named_buffers, - ) + with ContextManagers(init_contexts): + model = cls.from_config(config, **unused_kwargs) - if cls._keys_to_ignore_on_load_unexpected is not None: - for pat in cls._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" - ) + if not is_sharded and state_dict is None: + # Time to load the checkpoint + state_dict = load_state_dict( + resolved_archive_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries + ) - else: # else let accelerate handle loading and dispatching. - # Load weights and dispatch according to the device_map - # by default the device_map is None and the weights are loaded on the CPU - device_map = _determine_device_map( - model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer - ) - if device_map is None and is_sharded: - # we load the parameters on the cpu - device_map = {"": "cpu"} - try: - accelerate.load_checkpoint_and_dispatch( - model, - model_file if not is_sharded else index_file, - device_map, - max_memory=max_memory, - offload_folder=offload_folder, - offload_state_dict=offload_state_dict, - dtype=torch_dtype, - strict=True, - ) - except AttributeError as e: - # When using accelerate loading, we do not have the ability to load the state - # dict and rename the weight names manually. Additionally, accelerate skips - # torch loading conventions and directly writes into `module.{_buffers, _parameters}` - # (which look like they should be private variables?), so we can't use the standard hooks - # to rename parameters on load. We need to mimic the original weight names so the correct - # attributes are available. After we have loaded the weights, we convert the deprecated - # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert - # the weights so we don't have to do this again. - - if "'Attention' object has no attribute" in str(e): - logger.warning( - f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}" - " was saved with deprecated attention block weight names. We will load it with the deprecated attention block" - " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion," - " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint," - " please also re-upload it or open a PR on the original repository." 
- ) - model._temp_convert_self_to_deprecated_attention_blocks() - accelerate.load_checkpoint_and_dispatch( - model, - model_file if not is_sharded else index_file, - device_map, - max_memory=max_memory, - offload_folder=offload_folder, - offload_state_dict=offload_state_dict, - dtype=torch_dtype, - strict=True, - ) - model._undo_temp_convert_self_to_deprecated_attention_blocks() - else: - raise e - - loading_info = { - "missing_keys": [], - "unexpected_keys": [], - "mismatched_keys": [], - "error_msgs": [], - } - else: - model = cls.from_config(config, **unused_kwargs) + if is_sharded: + loaded_keys = sharded_metadata["all_checkpoint_keys"] + else: + loaded_keys = list(state_dict.keys()) + # TODO: hacky solution + loaded_keys = list(model._fix_state_dict_keys_on_load({key: "" for key in loaded_keys})) - state_dict = load_state_dict( - model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap - ) - model._convert_deprecated_attention_blocks(state_dict) + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + ) - model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( - model, - state_dict, - model_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=ignore_mismatched_sizes, - ) + # Now that the model is loaded, we can determine the device_map + device_map = _determine_device_map( + model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer + ) + if hf_quantizer is not None: + hf_quantizer.validate_environment(device_map=device_map) + + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + loaded_keys, + ignore_mismatched_sizes=ignore_mismatched_sizes, + low_cpu_mem_usage=low_cpu_mem_usage, + device_map=device_map, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + ) + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - "error_msgs": error_msgs, - } + # Dispatch model with hooks on all devices if necessary + if device_map is not None: + device_map_kwargs = { + "device_map": device_map, + "offload_dir": offload_folder, + "offload_index": offload_index, + } + dispatch_model(model, **device_map_kwargs) if hf_quantizer is not None: hf_quantizer.postprocess_model(model) model.hf_quantizer = hf_quantizer - if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): - raise ValueError( - f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." - ) - # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will - # completely lose the effectivity of `use_keep_in_fp32_modules`. - elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules: - model = model.to(torch_dtype) - if hf_quantizer is not None: # We also make sure to purge `_pre_quantization_dtype` when we serialize # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable. 
@@ -1039,6 +1063,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Set model in evaluation mode to deactivate DropOut modules by default model.eval() + if output_loading_info: return model, loading_info @@ -1119,54 +1144,133 @@ def _load_pretrained_model( cls, model, state_dict: OrderedDict, - resolved_archive_file, + resolved_archive_file: List[str], pretrained_model_name_or_path: Union[str, os.PathLike], + loaded_keys, ignore_mismatched_sizes: bool = False, + assign_to_params_buffers: bool = False, + hf_quantizer=None, + low_cpu_mem_usage=None, + dtype=None, + keep_in_fp32_modules=None, + device_map=None, + offload_state_dict=None, + offload_folder=None, ): - # Retrieve missing & unexpected_keys model_state_dict = model.state_dict() - loaded_keys = list(state_dict.keys()) - expected_keys = list(model_state_dict.keys()) - - original_loaded_keys = loaded_keys - missing_keys = list(set(expected_keys) - set(loaded_keys)) + if hf_quantizer is not None: + missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="") unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + # # Some models may have keys that are not in the state by design, removing them before needlessly warning + # the user. + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + mismatched_keys = [] + + assign_to_params_buffers = None + error_msgs = [] + + # Deal with offload + is_safetensors = False + if device_map is not None and "disk" in device_map.values(): + archive_file = ( + resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file + ) + is_safetensors = archive_file.endswith(".safetensors") + if offload_folder is None and not is_safetensors: + raise ValueError( + "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`" + " for them. Alternatively, make sure you have `safetensors` installed if the model you are using" + " offers the weights in this format." 
+ ) + if offload_folder is not None: + os.makedirs(offload_folder, exist_ok=True) + if offload_state_dict is None: + offload_state_dict = True + + offload_index = {} if device_map is not None and "disk" in device_map.values() else None + if offload_state_dict: + state_dict_folder = tempfile.mkdtemp() + state_dict_index = {} + else: + state_dict_folder = None + state_dict_index = None - # Make sure we are able to load base models as well as derived models (with heads) - model_to_load = model + # TODO: not sure if this is the most elegant way of dealing with this case + if state_dict is not None: + # load_state_dict will manage the case where we pass a dict instead of a file + # if state dict is not None, it means that we don't need to read the files from resolved_archive_file also + resolved_archive_file = [state_dict] - def _find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - ignore_mismatched_sizes, - ): - mismatched_keys = [] - if ignore_mismatched_sizes: - for checkpoint_key in loaded_keys: - model_key = checkpoint_key - - if ( - model_key in model_state_dict - and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape - ): - mismatched_keys.append( - (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) - ) - del state_dict[checkpoint_key] - return mismatched_keys + if len(resolved_archive_file) > 1: + resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards") + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file) + model._fix_state_dict_keys_on_load(state_dict) - if state_dict is not None: - # Whole checkpoint - mismatched_keys = _find_mismatched_keys( + def _find_mismatched_keys( state_dict, model_state_dict, - original_loaded_keys, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + # If the checkpoint is sharded, we may not have the key here. 
+ if checkpoint_key not in state_dict: + continue + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + mismatched_keys += _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, ignore_mismatched_sizes, ) - error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if low_cpu_mem_usage: + new_error_msgs, offload_index, state_dict_index = load_state_dict_into_meta_model( + model, + state_dict, + device_map=device_map, + dtype=dtype, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + unexpected_keys=unexpected_keys, + offload_folder=offload_folder, + offload_index=offload_index, + state_dict_index=state_dict_index, + state_dict_folder=state_dict_folder, + ) + error_msgs += new_error_msgs + else: + if assign_to_params_buffers is None: + assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict) + + error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers) + + if offload_index is not None and len(offload_index) > 0: + save_offload_index(offload_index, offload_folder) + offload_index = None + + if offload_state_dict: + load_offloaded_weights(model, state_dict_index, state_dict_folder) + shutil.rmtree(state_dict_folder) if len(error_msgs) > 0: error_msg = "\n\t".join(error_msgs) @@ -1178,17 +1282,11 @@ def _find_mismatched_keys( if len(unexpected_keys) > 0: logger.warning( - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" - f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" - f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" - " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" - " BertForPreTraining model).\n- This IS NOT expected if you are initializing" - f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" - " identical (initializing a BertForSequenceClassification model from a" - " BertForSequenceClassification model)." + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" ) else: logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" @@ -1216,7 +1314,7 @@ def _find_mismatched_keys( " able to use it for predictions and inference." ) - return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs @classmethod def _get_signature_keys(cls, obj): @@ -1257,6 +1355,33 @@ def _get_no_split_modules(self, device_map: str): modules_to_check += list(module.children()) return list(_no_split_modules) + @classmethod + def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype: + """ + Change the default dtype and return the previous one. This is needed when wanting to instantiate the model + under specific dtype. + + Args: + dtype (`torch.dtype`): + a floating dtype to set to. 
+ + Returns: + `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was + modified. If it wasn't, returns `None`. + + Note `set_default_dtype` currently only works with floating-point types and asserts if for example, + `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception. + """ + if not dtype.is_floating_point: + raise ValueError( + f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype" + ) + + logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.") + dtype_orig = torch.get_default_dtype() + torch.set_default_dtype(dtype) + return dtype_orig + @property def device(self) -> torch.device: """ @@ -1354,7 +1479,12 @@ def get_memory_footprint(self, return_buffers=True): mem = mem + mem_bufs return mem - def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: + def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None: + """ + This function fix the state dict of the model to take into account some changes that were made in the model + architecture: + - depretated attention blocks + """ deprecated_attention_block_paths = [] def recursive_find_attn_block(name, module): @@ -1397,56 +1527,7 @@ def recursive_find_attn_block(name, module): state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") if f"{path}.proj_attn.bias" in state_dict: state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") - - def _temp_convert_self_to_deprecated_attention_blocks(self) -> None: - deprecated_attention_block_modules = [] - - def recursive_find_attn_block(module): - if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: - deprecated_attention_block_modules.append(module) - - for sub_module in module.children(): - recursive_find_attn_block(sub_module) - - recursive_find_attn_block(self) - - for module in deprecated_attention_block_modules: - module.query = module.to_q - module.key = module.to_k - module.value = module.to_v - module.proj_attn = module.to_out[0] - - # We don't _have_ to delete the old attributes, but it's helpful to ensure - # that _all_ the weights are loaded into the new attributes and we're not - # making an incorrect assumption that this model should be converted when - # it really shouldn't be. 
- del module.to_q - del module.to_k - del module.to_v - del module.to_out - - def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None: - deprecated_attention_block_modules = [] - - def recursive_find_attn_block(module) -> None: - if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: - deprecated_attention_block_modules.append(module) - - for sub_module in module.children(): - recursive_find_attn_block(sub_module) - - recursive_find_attn_block(self) - - for module in deprecated_attention_block_modules: - module.to_q = module.query - module.to_k = module.key - module.to_v = module.value - module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)]) - - del module.query - del module.key - del module.value - del module.proj_attn + return state_dict class LegacyModelMixin(ModelMixin): diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index f7780b66b12b..b0b8fd09eac4 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -235,18 +235,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": torch_dtype = torch.float16 return torch_dtype - # (sayakpaul): I think it could be better to disable custom `device_map`s - # for the first phase of the integration in the interest of simplicity. - # Commenting this for discussions on the PR. - # def update_device_map(self, device_map): - # if device_map is None: - # device_map = {"": torch.cuda.current_device()} - # logger.info( - # "The device_map was not initialized. " - # "Setting device_map to {'':torch.cuda.current_device()}. " - # "If you want to use the model for inference, please set device_map ='auto' " - # ) - # return device_map + def update_device_map(self, device_map): + if device_map is None: + device_map = {"": torch.cuda.current_device()} + logger.info( + "The device_map was not initialized. " + "Setting device_map to {'':torch.cuda.current_device()}. " + "If you want to use the model for inference, please set device_map ='auto' " + ) + return device_map def _process_model_before_weight_loading( self, @@ -289,9 +286,9 @@ def _process_model_before_weight_loading( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config ) model.config.quantization_config = self.quantization_config + model.is_loaded_in_4bit = True def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): - model.is_loaded_in_4bit = True model.is_4bit_serializable = self.is_serializable return model @@ -400,16 +397,16 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": torch_dtype = torch.float16 return torch_dtype - # # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map - # def update_device_map(self, device_map): - # if device_map is None: - # device_map = {"": torch.cuda.current_device()} - # logger.info( - # "The device_map was not initialized. " - # "Setting device_map to {'':torch.cuda.current_device()}. " - # "If you want to use the model for inference, please set device_map ='auto' " - # ) - # return device_map + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map + def update_device_map(self, device_map): + if device_map is None: + device_map = {"": torch.cuda.current_device()} + logger.info( + "The device_map was not initialized. 
" + "Setting device_map to {'':torch.cuda.current_device()}. " + "If you want to use the model for inference, please set device_map ='auto' " + ) + return device_map def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": if target_dtype != torch.int8: @@ -493,7 +490,6 @@ def create_quantized_param( # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): - model.is_loaded_in_8bit = True model.is_8bit_serializable = self.is_serializable return model @@ -539,6 +535,7 @@ def _process_model_before_weight_loading( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config ) model.config.quantization_config = self.quantization_config + model.is_loaded_in_8bit = True @property # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index 839e696c0ce9..61404bcbb360 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -411,22 +411,6 @@ def _get_model_file( ) from e -# Adapted from -# https://github.com/huggingface/transformers/blob/1360801a69c0b169e3efdbb0cd05d9a0e72bfb70/src/transformers/utils/hub.py#L976 -# Differences are in parallelization of shard downloads and checking if shards are present. - - -def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames): - shards_path = os.path.join(local_dir, subfolder) - shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames] - for shard_file in shard_filenames: - if not os.path.exists(shard_file): - raise ValueError( - f"{shards_path} does not appear to have a file named {shard_file} which is " - "required according to the checkpoint index." - ) - - def _get_checkpoint_shard_files( pretrained_model_name_or_path, index_filename, @@ -469,13 +453,22 @@ def _get_checkpoint_shard_files( shards_path = os.path.join(pretrained_model_name_or_path, subfolder) # First, let's deal with local folder. - if os.path.isdir(pretrained_model_name_or_path): - _check_if_shards_exist_locally( - pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames - ) - return shards_path, sharded_metadata - elif dduf_entries: - return shards_path, sharded_metadata + if os.path.isdir(pretrained_model_name_or_path) or dduf_entries: + shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames] + for shard_file in shard_filenames: + if dduf_entries: + if shard_file not in dduf_entries: + raise FileNotFoundError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) + else: + if not os.path.exists(shard_file): + raise FileNotFoundError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) + return shard_filenames, sharded_metadata # At this stage pretrained_model_name_or_path is a model identifier on the Hub allow_patterns = original_shard_filenames @@ -517,7 +510,9 @@ def _get_checkpoint_shard_files( " again after checking your internet connection." 
) from e - return cached_folder, sharded_metadata + cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames] + + return cached_filenames, sharded_metadata def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None): diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 4fc14804475a..9e0e79e0dd9c 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -185,12 +185,12 @@ class ModelUtilsTest(unittest.TestCase): def tearDown(self): super().tearDown() - def test_accelerate_loading_error_message(self): - with self.assertRaises(ValueError) as error_context: + def test_missing_key_loading_warning_message(self): + with self.assertLogs("diffusers.models.modeling_utils", level="WARNING") as logs: UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet") # make sure that error message states what keys are missing - assert "conv_out.bias" in str(error_context.exception) + assert "conv_out.bias" in " ".join(logs.output) @parameterized.expand( [ diff --git a/tests/models/transformers/test_models_transformer_sana.py b/tests/models/transformers/test_models_transformer_sana.py index 83db153dadea..0222bef4c7c3 100644 --- a/tests/models/transformers/test_models_transformer_sana.py +++ b/tests/models/transformers/test_models_transformer_sana.py @@ -14,7 +14,6 @@ import unittest -import pytest import torch from diffusers import SanaTransformer2DModel @@ -81,27 +80,3 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"SanaTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - - @pytest.mark.xfail( - condition=torch.device(torch_device).type == "cuda", - reason="Test currently fails.", - strict=True, - ) - def test_cpu_offload(self): - return super().test_cpu_offload() - - @pytest.mark.xfail( - condition=torch.device(torch_device).type == "cuda", - reason="Test currently fails.", - strict=True, - ) - def test_disk_offload_with_safetensors(self): - return super().test_disk_offload_with_safetensors() - - @pytest.mark.xfail( - condition=torch.device(torch_device).type == "cuda", - reason="Test currently fails.", - strict=True, - ) - def test_disk_offload_without_safetensors(self): - return super().test_disk_offload_without_safetensors() From 645abc9bd8bd5d4de66faff8a3636de23cb29091 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 17 Jan 2025 23:09:15 +0100 Subject: [PATCH 02/15] revert name change --- scripts/convert_sana_to_diffusers.py | 4 ++-- scripts/convert_sd3_to_diffusers.py | 6 +++--- scripts/convert_stable_audio.py | 8 ++++---- scripts/convert_stable_cascade.py | 6 +++--- scripts/convert_stable_cascade_lite.py | 6 +++--- src/diffusers/loaders/single_file_utils.py | 6 +++--- src/diffusers/loaders/transformer_flux.py | 6 +++--- src/diffusers/loaders/transformer_sd3.py | 8 +++----- src/diffusers/loaders/unet.py | 6 +++--- src/diffusers/models/model_loading_utils.py | 2 +- src/diffusers/models/modeling_utils.py | 4 ++-- 11 files changed, 30 insertions(+), 32 deletions(-) diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py index d96b8bce38c1..9820ac7c92a0 100644 --- a/scripts/convert_sana_to_diffusers.py +++ b/scripts/convert_sana_to_diffusers.py @@ -18,7 +18,7 @@ SanaPipeline, SanaTransformer2DModel, ) -from diffusers.models.modeling_utils import 
load_state_dict_into_meta_model +from diffusers.models.modeling_utils import load_state_dict_into_meta from diffusers.utils.import_utils import is_accelerate_available @@ -189,7 +189,7 @@ def main(args): ) if is_accelerate_available(): - load_state_dict_into_meta_model(transformer, converted_state_dict) + load_state_dict_into_meta(transformer, converted_state_dict) else: transformer.load_state_dict(converted_state_dict, strict=True, assign=True) diff --git a/scripts/convert_sd3_to_diffusers.py b/scripts/convert_sd3_to_diffusers.py index 67bc62a04431..28ff7db7a550 100644 --- a/scripts/convert_sd3_to_diffusers.py +++ b/scripts/convert_sd3_to_diffusers.py @@ -7,7 +7,7 @@ from diffusers import AutoencoderKL, SD3Transformer2DModel from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint -from diffusers.models.modeling_utils import load_state_dict_into_meta_model +from diffusers.models.modeling_utils import load_state_dict_into_meta from diffusers.utils.import_utils import is_accelerate_available @@ -319,7 +319,7 @@ def main(args): dual_attention_layers=attn2_layers, ) if is_accelerate_available(): - load_state_dict_into_meta_model(transformer, converted_transformer_state_dict) + load_state_dict_into_meta(transformer, converted_transformer_state_dict) else: transformer.load_state_dict(converted_transformer_state_dict, strict=True) @@ -339,7 +339,7 @@ def main(args): ) converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config) if is_accelerate_available(): - load_state_dict_into_meta_model(vae, converted_vae_state_dict) + load_state_dict_into_meta(vae, converted_vae_state_dict) else: vae.load_state_dict(converted_vae_state_dict, strict=True) diff --git a/scripts/convert_stable_audio.py b/scripts/convert_stable_audio.py index 959aa125e9cb..8066cce50b69 100644 --- a/scripts/convert_stable_audio.py +++ b/scripts/convert_stable_audio.py @@ -18,7 +18,7 @@ StableAudioPipeline, StableAudioProjectionModel, ) -from diffusers.models.modeling_utils import load_state_dict_into_meta_model +from diffusers.models.modeling_utils import load_state_dict_into_meta from diffusers.utils import is_accelerate_available @@ -221,7 +221,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay ], # assume `seconds_start` and `seconds_total` have the same min / max values. 
) if is_accelerate_available(): - load_state_dict_into_meta_model(projection_model, projection_model_state_dict) + load_state_dict_into_meta(projection_model, projection_model_state_dict) else: projection_model.load_state_dict(projection_model_state_dict) @@ -242,7 +242,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay cross_attention_input_dim=model_config["cond_token_dim"], ) if is_accelerate_available(): - load_state_dict_into_meta_model(model, model_state_dict) + load_state_dict_into_meta(model, model_state_dict) else: model.load_state_dict(model_state_dict) @@ -260,7 +260,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay ) if is_accelerate_available(): - load_state_dict_into_meta_model(autoencoder, autoencoder_state_dict) + load_state_dict_into_meta(autoencoder, autoencoder_state_dict) else: autoencoder.load_state_dict(autoencoder_state_dict) diff --git a/scripts/convert_stable_cascade.py b/scripts/convert_stable_cascade.py index 59a4b4e2280f..3fedae65fab4 100644 --- a/scripts/convert_stable_cascade.py +++ b/scripts/convert_stable_cascade.py @@ -20,7 +20,7 @@ ) from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers from diffusers.models import StableCascadeUNet -from diffusers.models.modeling_utils import load_state_dict_into_meta_model +from diffusers.models.modeling_utils import load_state_dict_into_meta from diffusers.pipelines.wuerstchen import PaellaVQModel from diffusers.utils import is_accelerate_available @@ -126,7 +126,7 @@ switch_level=[False], ) if is_accelerate_available(): - load_state_dict_into_meta_model(prior_model, prior_state_dict) + load_state_dict_into_meta(prior_model, prior_state_dict) else: prior_model.load_state_dict(prior_state_dict) @@ -181,7 +181,7 @@ ) if is_accelerate_available(): - load_state_dict_into_meta_model(decoder, decoder_state_dict) + load_state_dict_into_meta(decoder, decoder_state_dict) else: decoder.load_state_dict(decoder_state_dict) diff --git a/scripts/convert_stable_cascade_lite.py b/scripts/convert_stable_cascade_lite.py index 8f57bec97361..4639d867ace9 100644 --- a/scripts/convert_stable_cascade_lite.py +++ b/scripts/convert_stable_cascade_lite.py @@ -20,7 +20,7 @@ ) from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers from diffusers.models import StableCascadeUNet -from diffusers.models.modeling_utils import load_state_dict_into_meta_model +from diffusers.models.modeling_utils import load_state_dict_into_meta from diffusers.pipelines.wuerstchen import PaellaVQModel from diffusers.utils import is_accelerate_available @@ -133,7 +133,7 @@ ) if is_accelerate_available(): - load_state_dict_into_meta_model(prior_model, prior_state_dict) + load_state_dict_into_meta(prior_model, prior_state_dict) else: prior_model.load_state_dict(prior_state_dict) @@ -189,7 +189,7 @@ ) if is_accelerate_available(): - load_state_dict_into_meta_model(decoder, decoder_state_dict) + load_state_dict_into_meta(decoder, decoder_state_dict) else: decoder.load_state_dict(decoder_state_dict) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index d5b1ba8097b6..8b67546aeca1 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -53,7 +53,7 @@ if is_accelerate_available(): from accelerate import init_empty_weights - from ..models.modeling_utils import load_state_dict_into_meta_model + from 
..models.modeling_utils import load_state_dict_into_meta logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -1588,7 +1588,7 @@ def create_diffusers_clip_model_from_ldm( raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.") if is_accelerate_available(): - load_state_dict_into_meta_model(model, diffusers_format_checkpoint, dtype=torch_dtype) + load_state_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) else: model.load_state_dict(diffusers_format_checkpoint, strict=False) @@ -2047,7 +2047,7 @@ def create_diffusers_t5_model_from_checkpoint( diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint) if is_accelerate_available(): - load_state_dict_into_meta_model(model, diffusers_format_checkpoint, dtype=torch_dtype) + load_state_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) else: model.load_state_dict(diffusers_format_checkpoint) diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py index bef1b0398748..0f2881a24aad 100644 --- a/src/diffusers/loaders/transformer_flux.py +++ b/src/diffusers/loaders/transformer_flux.py @@ -17,7 +17,7 @@ ImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import load_state_dict_into_meta_model +from ..models.modeling_utils import load_state_dict_into_meta from ..utils import ( is_accelerate_available, is_torch_version, @@ -82,7 +82,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us if not low_cpu_mem_usage: image_projection.load_state_dict(updated_state_dict, strict=True) else: - load_state_dict_into_meta_model(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) + load_state_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) return image_projection @@ -153,7 +153,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F else: device = self.device dtype = self.dtype - load_state_dict_into_meta_model(attn_procs[name], value_dict, device=device, dtype=dtype) + load_state_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) key_id += 1 diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py index 5c57d04ef2dc..00a89c349d1f 100644 --- a/src/diffusers/loaders/transformer_sd3.py +++ b/src/diffusers/loaders/transformer_sd3.py @@ -15,7 +15,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 from ..models.embeddings import IPAdapterTimeImageProjection -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict_into_meta_model +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict_into_meta class SD3Transformer2DLoadersMixin: @@ -59,7 +59,7 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _ if not low_cpu_mem_usage: attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True) else: - load_state_dict_into_meta_model( + load_state_dict_into_meta( attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype ) @@ -86,6 +86,4 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _ if not low_cpu_mem_usage: self.image_proj.load_state_dict(state_dict["image_proj"], strict=True) else: - load_state_dict_into_meta_model( - self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype - ) + load_state_dict_into_meta(self.image_proj, 
state_dict["image_proj"], device=self.device, dtype=self.dtype) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index a8346925af7b..1a7f82cc0d5a 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -30,7 +30,7 @@ IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import load_state_dict, load_state_dict_into_meta_model +from ..models.modeling_utils import load_state_dict, load_state_dict_into_meta from ..utils import ( USE_PEFT_BACKEND, _get_model_file, @@ -753,7 +753,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us if not low_cpu_mem_usage: image_projection.load_state_dict(updated_state_dict, strict=True) else: - load_state_dict_into_meta_model(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) + load_state_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) return image_projection @@ -846,7 +846,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F else: device = next(iter(value_dict.values())).device dtype = next(iter(value_dict.values())).dtype - load_state_dict_into_meta_model(attn_procs[name], value_dict, device=device, dtype=dtype) + load_state_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) key_id += 2 diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 0f7110a9a018..e9a6c30daaa9 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -201,7 +201,7 @@ def load_state_dict( ) -def load_state_dict_into_meta_model( +def load_state_dict_into_meta( model, state_dict: OrderedDict, dtype: Optional[Union[str, torch.dtype]] = None, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 97f023dccb61..fa5a4a8eb3c5 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -65,7 +65,7 @@ _fetch_index_file_legacy, _load_state_dict_into_model, load_state_dict, - load_state_dict_into_meta_model, + load_state_dict_into_meta, ) @@ -1244,7 +1244,7 @@ def _find_mismatched_keys( ) if low_cpu_mem_usage: - new_error_msgs, offload_index, state_dict_index = load_state_dict_into_meta_model( + new_error_msgs, offload_index, state_dict_index = load_state_dict_into_meta( model, state_dict, device_map=device_map, From bd81f50662bacad4ba2489dff3d2cb6ae77463e2 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Sat, 18 Jan 2025 00:51:49 +0100 Subject: [PATCH 03/15] fix bnb --- src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index b0b8fd09eac4..e27c259ede30 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -493,7 +493,7 @@ def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): model.is_8bit_serializable = self.is_serializable return model - # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading with 4bit->8bit def _process_model_before_weight_loading( self, model: "ModelMixin", From 
17c1be25979ab6514b87900631706a25f56bf5c2 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Sat, 18 Jan 2025 10:57:52 +0100 Subject: [PATCH 04/15] revert name --- scripts/convert_sana_to_diffusers.py | 4 ++-- scripts/convert_sd3_to_diffusers.py | 6 +++--- scripts/convert_stable_audio.py | 8 ++++---- scripts/convert_stable_cascade.py | 6 +++--- scripts/convert_stable_cascade_lite.py | 6 +++--- src/diffusers/loaders/single_file_utils.py | 6 +++--- src/diffusers/loaders/transformer_flux.py | 6 +++--- src/diffusers/loaders/transformer_sd3.py | 6 +++--- src/diffusers/loaders/unet.py | 6 +++--- src/diffusers/models/model_loading_utils.py | 2 +- src/diffusers/models/modeling_utils.py | 4 ++-- 11 files changed, 30 insertions(+), 30 deletions(-) diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py index 9820ac7c92a0..99a9ff322251 100644 --- a/scripts/convert_sana_to_diffusers.py +++ b/scripts/convert_sana_to_diffusers.py @@ -18,7 +18,7 @@ SanaPipeline, SanaTransformer2DModel, ) -from diffusers.models.modeling_utils import load_state_dict_into_meta +from diffusers.models.modeling_utils import load_model_dict_into_meta from diffusers.utils.import_utils import is_accelerate_available @@ -189,7 +189,7 @@ def main(args): ) if is_accelerate_available(): - load_state_dict_into_meta(transformer, converted_state_dict) + load_model_dict_into_meta(transformer, converted_state_dict) else: transformer.load_state_dict(converted_state_dict, strict=True, assign=True) diff --git a/scripts/convert_sd3_to_diffusers.py b/scripts/convert_sd3_to_diffusers.py index 28ff7db7a550..0a3569efeab0 100644 --- a/scripts/convert_sd3_to_diffusers.py +++ b/scripts/convert_sd3_to_diffusers.py @@ -7,7 +7,7 @@ from diffusers import AutoencoderKL, SD3Transformer2DModel from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint -from diffusers.models.modeling_utils import load_state_dict_into_meta +from diffusers.models.modeling_utils import load_model_dict_into_meta from diffusers.utils.import_utils import is_accelerate_available @@ -319,7 +319,7 @@ def main(args): dual_attention_layers=attn2_layers, ) if is_accelerate_available(): - load_state_dict_into_meta(transformer, converted_transformer_state_dict) + load_model_dict_into_meta(transformer, converted_transformer_state_dict) else: transformer.load_state_dict(converted_transformer_state_dict, strict=True) @@ -339,7 +339,7 @@ def main(args): ) converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config) if is_accelerate_available(): - load_state_dict_into_meta(vae, converted_vae_state_dict) + load_model_dict_into_meta(vae, converted_vae_state_dict) else: vae.load_state_dict(converted_vae_state_dict, strict=True) diff --git a/scripts/convert_stable_audio.py b/scripts/convert_stable_audio.py index 8066cce50b69..a0f9d0f87d90 100644 --- a/scripts/convert_stable_audio.py +++ b/scripts/convert_stable_audio.py @@ -18,7 +18,7 @@ StableAudioPipeline, StableAudioProjectionModel, ) -from diffusers.models.modeling_utils import load_state_dict_into_meta +from diffusers.models.modeling_utils import load_model_dict_into_meta from diffusers.utils import is_accelerate_available @@ -221,7 +221,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay ], # assume `seconds_start` and `seconds_total` have the same min / max values. 
) if is_accelerate_available(): - load_state_dict_into_meta(projection_model, projection_model_state_dict) + load_model_dict_into_meta(projection_model, projection_model_state_dict) else: projection_model.load_state_dict(projection_model_state_dict) @@ -242,7 +242,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay cross_attention_input_dim=model_config["cond_token_dim"], ) if is_accelerate_available(): - load_state_dict_into_meta(model, model_state_dict) + load_model_dict_into_meta(model, model_state_dict) else: model.load_state_dict(model_state_dict) @@ -260,7 +260,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay ) if is_accelerate_available(): - load_state_dict_into_meta(autoencoder, autoencoder_state_dict) + load_model_dict_into_meta(autoencoder, autoencoder_state_dict) else: autoencoder.load_state_dict(autoencoder_state_dict) diff --git a/scripts/convert_stable_cascade.py b/scripts/convert_stable_cascade.py index 3fedae65fab4..ce10970b0b6a 100644 --- a/scripts/convert_stable_cascade.py +++ b/scripts/convert_stable_cascade.py @@ -20,7 +20,7 @@ ) from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers from diffusers.models import StableCascadeUNet -from diffusers.models.modeling_utils import load_state_dict_into_meta +from diffusers.models.modeling_utils import load_model_dict_into_meta from diffusers.pipelines.wuerstchen import PaellaVQModel from diffusers.utils import is_accelerate_available @@ -126,7 +126,7 @@ switch_level=[False], ) if is_accelerate_available(): - load_state_dict_into_meta(prior_model, prior_state_dict) + load_model_dict_into_meta(prior_model, prior_state_dict) else: prior_model.load_state_dict(prior_state_dict) @@ -181,7 +181,7 @@ ) if is_accelerate_available(): - load_state_dict_into_meta(decoder, decoder_state_dict) + load_model_dict_into_meta(decoder, decoder_state_dict) else: decoder.load_state_dict(decoder_state_dict) diff --git a/scripts/convert_stable_cascade_lite.py b/scripts/convert_stable_cascade_lite.py index 4639d867ace9..ddccaa3b2e8a 100644 --- a/scripts/convert_stable_cascade_lite.py +++ b/scripts/convert_stable_cascade_lite.py @@ -20,7 +20,7 @@ ) from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers from diffusers.models import StableCascadeUNet -from diffusers.models.modeling_utils import load_state_dict_into_meta +from diffusers.models.modeling_utils import load_model_dict_into_meta from diffusers.pipelines.wuerstchen import PaellaVQModel from diffusers.utils import is_accelerate_available @@ -133,7 +133,7 @@ ) if is_accelerate_available(): - load_state_dict_into_meta(prior_model, prior_state_dict) + load_model_dict_into_meta(prior_model, prior_state_dict) else: prior_model.load_state_dict(prior_state_dict) @@ -189,7 +189,7 @@ ) if is_accelerate_available(): - load_state_dict_into_meta(decoder, decoder_state_dict) + load_model_dict_into_meta(decoder, decoder_state_dict) else: decoder.load_state_dict(decoder_state_dict) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 8b67546aeca1..4ee6544cdaaa 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -53,7 +53,7 @@ if is_accelerate_available(): from accelerate import init_empty_weights - from ..models.modeling_utils import load_state_dict_into_meta + from ..models.modeling_utils import load_model_dict_into_meta logger = 
logging.get_logger(__name__) # pylint: disable=invalid-name @@ -1588,7 +1588,7 @@ def create_diffusers_clip_model_from_ldm( raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.") if is_accelerate_available(): - load_state_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) else: model.load_state_dict(diffusers_format_checkpoint, strict=False) @@ -2047,7 +2047,7 @@ def create_diffusers_t5_model_from_checkpoint( diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint) if is_accelerate_available(): - load_state_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) else: model.load_state_dict(diffusers_format_checkpoint) diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py index 0f2881a24aad..9fe712bb12e9 100644 --- a/src/diffusers/loaders/transformer_flux.py +++ b/src/diffusers/loaders/transformer_flux.py @@ -17,7 +17,7 @@ ImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import load_state_dict_into_meta +from ..models.modeling_utils import load_model_dict_into_meta from ..utils import ( is_accelerate_available, is_torch_version, @@ -82,7 +82,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us if not low_cpu_mem_usage: image_projection.load_state_dict(updated_state_dict, strict=True) else: - load_state_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) + load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) return image_projection @@ -153,7 +153,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F else: device = self.device dtype = self.dtype - load_state_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) + load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) key_id += 1 diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py index 00a89c349d1f..435d1da06ca1 100644 --- a/src/diffusers/loaders/transformer_sd3.py +++ b/src/diffusers/loaders/transformer_sd3.py @@ -15,7 +15,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 from ..models.embeddings import IPAdapterTimeImageProjection -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict_into_meta +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta class SD3Transformer2DLoadersMixin: @@ -59,7 +59,7 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _ if not low_cpu_mem_usage: attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True) else: - load_state_dict_into_meta( + load_model_dict_into_meta( attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype ) @@ -86,4 +86,4 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _ if not low_cpu_mem_usage: self.image_proj.load_state_dict(state_dict["image_proj"], strict=True) else: - load_state_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype) + load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype) diff --git a/src/diffusers/loaders/unet.py 
b/src/diffusers/loaders/unet.py index 1a7f82cc0d5a..c47f27fbf171 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -30,7 +30,7 @@ IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import load_state_dict, load_state_dict_into_meta +from ..models.modeling_utils import load_state_dict, load_model_dict_into_meta from ..utils import ( USE_PEFT_BACKEND, _get_model_file, @@ -753,7 +753,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us if not low_cpu_mem_usage: image_projection.load_state_dict(updated_state_dict, strict=True) else: - load_state_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) + load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype) return image_projection @@ -846,7 +846,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F else: device = next(iter(value_dict.values())).device dtype = next(iter(value_dict.values())).dtype - load_state_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) + load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) key_id += 2 diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index e9a6c30daaa9..33c07a2e2f9a 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -201,7 +201,7 @@ def load_state_dict( ) -def load_state_dict_into_meta( +def load_model_dict_into_meta( model, state_dict: OrderedDict, dtype: Optional[Union[str, torch.dtype]] = None, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index fa5a4a8eb3c5..c91e1c042ecd 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -65,7 +65,7 @@ _fetch_index_file_legacy, _load_state_dict_into_model, load_state_dict, - load_state_dict_into_meta, + load_model_dict_into_meta, ) @@ -1244,7 +1244,7 @@ def _find_mismatched_keys( ) if low_cpu_mem_usage: - new_error_msgs, offload_index, state_dict_index = load_state_dict_into_meta( + new_error_msgs, offload_index, state_dict_index = load_model_dict_into_meta( model, state_dict, device_map=device_map, From 72b6259ecb9a43d2e915246a239126aca67b9a87 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Sat, 18 Jan 2025 11:04:13 +0100 Subject: [PATCH 05/15] fix dduf --- src/diffusers/loaders/unet.py | 2 +- src/diffusers/models/model_loading_utils.py | 19 ------------------- src/diffusers/models/modeling_utils.py | 6 ++++-- 3 files changed, 5 insertions(+), 22 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index c47f27fbf171..c68349c36dba 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -30,7 +30,7 @@ IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import load_state_dict, load_model_dict_into_meta +from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict from ..utils import ( USE_PEFT_BACKEND, _get_model_file, diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 33c07a2e2f9a..93b3a7fbc609 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -307,25 +307,6 @@ def load_model_dict_into_meta( return error_msgs, offload_index, state_dict_index -def 
load_model_dict_into_meta( - model, - state_dict: OrderedDict, - dtype: Optional[Union[str, torch.dtype]] = None, - model_name_or_path: Optional[str] = None, - hf_quantizer=None, - keep_in_fp32_modules=None, - device_map=None, - unexpected_keys=None, - is_safetensors=None, - offload_folder=None, - offload_index=None, - state_dict_index=None, - state_dict_folder=None, -) -> List[str]: - error_msgs = [] - return error_msgs, offload_index, state_dict_index - - def _load_state_dict_into_model( model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False ) -> List[str]: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index c91e1c042ecd..9d93a1946e88 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -64,8 +64,8 @@ _fetch_index_file, _fetch_index_file_legacy, _load_state_dict_into_model, - load_state_dict, load_model_dict_into_meta, + load_state_dict, ) @@ -1033,6 +1033,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P dtype=torch_dtype, hf_quantizer=hf_quantizer, keep_in_fp32_modules=keep_in_fp32_modules, + dduf_entries=dduf_entries, ) loading_info = { "missing_keys": missing_keys, @@ -1156,6 +1157,7 @@ def _load_pretrained_model( device_map=None, offload_state_dict=None, offload_folder=None, + dduf_entries=None, ): model_state_dict = model.state_dict() expected_keys = list(model_state_dict.keys()) @@ -1209,7 +1211,7 @@ def _load_pretrained_model( if len(resolved_archive_file) > 1: resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards") for shard_file in resolved_archive_file: - state_dict = load_state_dict(shard_file) + state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries) model._fix_state_dict_keys_on_load(state_dict) def _find_mismatched_keys( From b4e4f3b78a4264370fd430fd41eb46c5034a64d6 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Sat, 18 Jan 2025 11:45:24 +0100 Subject: [PATCH 06/15] fix huanyan --- src/diffusers/models/transformers/hunyuan_transformer_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 7f3dab220aaa..77f54b917089 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -278,7 +278,7 @@ def __init__( ) self.text_embedding_padding = nn.Parameter( - torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32) + torch.randn(text_len + text_len_t5, cross_attention_dim) ) self.pos_embed = PatchEmbed( From 5a00dc6ee3351d097d154e73550789dfeda0ead0 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Sat, 18 Jan 2025 11:46:59 +0100 Subject: [PATCH 07/15] style --- src/diffusers/models/transformers/hunyuan_transformer_2d.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 77f54b917089..d5e64cf99aa5 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -277,9 +277,7 @@ def __init__( act_fn="silu_fp32", ) - self.text_embedding_padding = nn.Parameter( - torch.randn(text_len + text_len_t5, cross_attention_dim) - ) + self.text_embedding_padding = nn.Parameter(torch.randn(text_len + text_len_t5, cross_attention_dim)) 
self.pos_embed = PatchEmbed( height=sample_size, From 2f671af63d6ec9f7acb3d43a99365959007bb928 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Mon, 20 Jan 2025 16:23:46 +0100 Subject: [PATCH 08/15] Update src/diffusers/models/model_loading_utils.py Co-authored-by: Sayak Paul --- src/diffusers/models/model_loading_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 93b3a7fbc609..be7cb9c5ba51 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -163,7 +163,7 @@ def load_state_dict( if disable_mmap: return safetensors.torch.load(open(checkpoint_file, "rb").read()) else: - return safetensors.torch.load_file(checkpoint_file, device="cpu") + return safetensors.torch.load_file(checkpoint_file, device=map_location) elif file_extension == GGUF_FILE_EXTENSION: return load_gguf_checkpoint(checkpoint_file) else: From 7273a94ec8ced73f87bb962c3ca8014f9ffa742b Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 20 Jan 2025 17:01:48 +0100 Subject: [PATCH 09/15] suggestions from reviews --- src/diffusers/models/model_loading_utils.py | 22 ++++++++++++------- src/diffusers/models/modeling_utils.py | 2 +- .../quantizers/bitsandbytes/bnb_quantizer.py | 10 +++++---- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 93b3a7fbc609..88e06dfab1ea 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -133,6 +133,19 @@ def _fetch_remapped_cls_from_config(config, old_class): return old_class +def _check_archive_and_maybe_raise_error(checkpoint_file, format_list): + """ + Check format of the archive + """ + with safetensors.safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + if metadata is not None and metadata.get("format") not in format_list: + raise OSError( + f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " + "you save your model with the `save_pretrained` method." + ) + + def load_state_dict( checkpoint_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None, @@ -152,14 +165,7 @@ def load_state_dict( # tensors are loaded on cpu with dduf_entries[checkpoint_file].as_mmap() as mm: return safetensors.torch.load(mm) - # Check format of the archive - with safetensors.safe_open(checkpoint_file, framework="pt") as f: - metadata = f.metadata() - if metadata is not None and metadata.get("format") not in ["pt", "flax"]: - raise OSError( - f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " - "you save your model with the `save_pretrained` method." 
- ) + _check_archive_and_maybe_raise_error(checkpoint_file, format_list=["pt", "flax"]) if disable_mmap: return safetensors.torch.load(open(checkpoint_file, "rb").read()) else: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 9d93a1946e88..f2b74c29bd0c 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -198,7 +198,6 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi if start_prefix + first_key in state_dict: return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype - # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`) return False @@ -1210,6 +1209,7 @@ def _load_pretrained_model( if len(resolved_archive_file) > 1: resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards") + for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries) model._fix_state_dict_keys_on_load(state_dict) diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index e27c259ede30..0a1bdd96934f 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -237,10 +237,11 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": def update_device_map(self, device_map): if device_map is None: - device_map = {"": torch.cuda.current_device()} + device_map = {"": f"cuda:{torch.cuda.current_device()}"} logger.info( "The device_map was not initialized. " - "Setting device_map to {'':torch.cuda.current_device()}. " + "Setting device_map to {" + ": f`cuda:{torch.cuda.current_device()}`}. " "If you want to use the model for inference, please set device_map ='auto' " ) return device_map @@ -400,10 +401,11 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map def update_device_map(self, device_map): if device_map is None: - device_map = {"": torch.cuda.current_device()} + device_map = {"": f"cuda:{torch.cuda.current_device()}"} logger.info( "The device_map was not initialized. " - "Setting device_map to {'':torch.cuda.current_device()}. " + "Setting device_map to {" + ": f`cuda:{torch.cuda.current_device()}`}. " "If you want to use the model for inference, please set device_map ='auto' " ) return device_map From c5da1924a2e21645db2b5e3ecfeb1f559bdaa66f Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Tue, 21 Jan 2025 10:56:40 +0100 Subject: [PATCH 10/15] Update src/diffusers/models/modeling_utils.py Co-authored-by: YiYi Xu --- src/diffusers/models/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index f2b74c29bd0c..1562b73818eb 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -946,7 +946,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." 
) - if resolved_archive_file is None: + if resolved_archive_file is None and not is_sharded: resolved_archive_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(WEIGHTS_NAME, variant), From 039eef55b1a2364c0f9a32563e3e6720fbbe6292 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Tue, 21 Jan 2025 10:59:40 +0100 Subject: [PATCH 11/15] remove safetensors check --- src/diffusers/models/modeling_utils.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index f2b74c29bd0c..b5e8c1bdfed3 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1176,13 +1176,8 @@ def _load_pretrained_model( error_msgs = [] # Deal with offload - is_safetensors = False if device_map is not None and "disk" in device_map.values(): - archive_file = ( - resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file - ) - is_safetensors = archive_file.endswith(".safetensors") - if offload_folder is None and not is_safetensors: + if offload_folder is None: raise ValueError( "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`" " for them. Alternatively, make sure you have `safetensors` installed if the model you are using" From 337b2fc7b065abde767c628ad2b7c184da5575fe Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 23 Jan 2025 15:47:09 +0100 Subject: [PATCH 12/15] fix default value --- src/diffusers/models/modeling_utils.py | 2 +- src/diffusers/pipelines/pipeline_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index a0d77a37ae29..48609f9c2410 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -690,7 +690,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P device_map = kwargs.pop("device_map", None) max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) - offload_state_dict = kwargs.pop("offload_state_dict", False) + offload_state_dict = kwargs.pop("offload_state_dict", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 3cafb77e5d63..0974feacbf9e 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -679,7 +679,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P device_map = kwargs.pop("device_map", None) max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) - offload_state_dict = kwargs.pop("offload_state_dict", False) + offload_state_dict = kwargs.pop("offload_state_dict", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) dduf_file = kwargs.pop("dduf_file", None) From 0df70106a6d255f95686cbe0fe2296b62a321fc6 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 23 Jan 2025 16:27:39 +0100 Subject: [PATCH 13/15] more fix from suggestions --- src/diffusers/models/modeling_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py 
b/src/diffusers/models/modeling_utils.py index 2646c5b41bce..374083441fa1 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1089,13 +1089,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P state_dict = load_state_dict( resolved_archive_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries ) + # We only fix it for non sharded checkpoints as we don't need it yet for sharded one. + model._fix_state_dict_keys_on_load(state_dict) if is_sharded: loaded_keys = sharded_metadata["all_checkpoint_keys"] else: loaded_keys = list(state_dict.keys()) - # TODO: hacky solution - loaded_keys = list(model._fix_state_dict_keys_on_load({key: "" for key in loaded_keys})) if hf_quantizer is not None: hf_quantizer.preprocess_model( @@ -1305,7 +1305,6 @@ def _load_pretrained_model( for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries) - model._fix_state_dict_keys_on_load(state_dict) def _find_mismatched_keys( state_dict, @@ -1578,7 +1577,8 @@ def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None: """ This function fix the state dict of the model to take into account some changes that were made in the model architecture: - - depretated attention blocks + - deprecated attention blocks (happened before we introduced sharded checkpoint, + so this is why we apply this method only when loading non sharded checkpoints for now) """ deprecated_attention_block_paths = [] From d3a7dc8ad03151b050682d517f3e6d77eff3a2e3 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 23 Jan 2025 17:58:13 +0100 Subject: [PATCH 14/15] revert logic for single file --- src/diffusers/loaders/single_file_model.py | 83 +++++++- src/diffusers/models/modeling_utils.py | 227 ++++++++++----------- 2 files changed, 188 insertions(+), 122 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index c8b81330d3fc..ac54db82da41 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -13,11 +13,15 @@ # limitations under the License. 
import importlib import inspect +import re +from contextlib import nullcontext from typing import Optional +import torch from huggingface_hub.utils import validate_hf_hub_args -from ..utils import deprecate, logging +from ..quantizers import DiffusersAutoQuantizer +from ..utils import deprecate, is_accelerate_available, logging from .single_file_utils import ( SingleFileComponentError, convert_animatediff_checkpoint_to_diffusers, @@ -45,6 +49,12 @@ logger = logging.get_logger(__name__) +if is_accelerate_available(): + from accelerate import init_empty_weights + + from ..models.modeling_utils import load_model_dict_into_meta + + SINGLE_FILE_LOADABLE_CLASSES = { "StableCascadeUNet": { "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers, @@ -224,6 +234,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = subfolder = kwargs.pop("subfolder", None) revision = kwargs.pop("revision", None) config_revision = kwargs.pop("config_revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + quantization_config = kwargs.pop("quantization_config", None) + device = kwargs.pop("device", None) disable_mmap = kwargs.pop("disable_mmap", False) if isinstance(pretrained_model_link_or_path_or_dict, dict): @@ -239,6 +252,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = revision=revision, disable_mmap=disable_mmap, ) + if quantization_config is not None: + hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) + hf_quantizer.validate_environment() + + else: + hf_quantizer = None mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name] @@ -317,9 +336,61 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint." 
) - return cls.from_pretrained( - pretrained_model_name_or_path=None, - state_dict=diffusers_format_checkpoint, - config=diffusers_model_config, - **kwargs, + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model = cls.from_config(diffusers_model_config) + + # Check if `_keep_in_fp32_modules` is not None + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") ) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = cls._keep_in_fp32_modules + if not isinstance(keep_in_fp32_modules, list): + keep_in_fp32_modules = [keep_in_fp32_modules] + + else: + keep_in_fp32_modules = [] + + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, + device_map=None, + state_dict=diffusers_format_checkpoint, + keep_in_fp32_modules=keep_in_fp32_modules, + ) + + if is_accelerate_available(): + param_device = torch.device(device) if device else torch.device("cpu") + unexpected_keys = [param_name for param_name in diffusers_format_checkpoint if param_name not in model.state_dict()] + load_model_dict_into_meta( + model, + diffusers_format_checkpoint, + dtype=torch_dtype, + device_map={"":param_device}, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + unexpected_keys=unexpected_keys, + ) + else: + _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) + + if model._keys_to_ignore_on_load_unexpected is not None: + for pat in model._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + + if hf_quantizer is not None: + hf_quantizer.postprocess_model(model) + model.hf_quantizer = hf_quantizer + + if torch_dtype is not None and hf_quantizer is None: + model.to(torch_dtype) + + model.eval() + + return model diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 374083441fa1..0be7de60b796 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -795,8 +795,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P quantization_config = kwargs.pop("quantization_config", None) dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) disable_mmap = kwargs.pop("disable_mmap", False) - state_dict = kwargs.pop("state_dict", None) - config = kwargs.pop("config", None) allow_pickle = False if use_safetensors is None: @@ -867,9 +865,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. 
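With this revert, `from_single_file` accepts `torch_dtype`, `quantization_config` and `device` directly instead of routing a prebuilt state dict through `from_pretrained`. A rough usage sketch of the new keyword arguments follows; the model class, checkpoint URL and 8-bit config are placeholders (and assume bitsandbytes is installed), not a prescribed recipe.

import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# Placeholder single-file checkpoint; any class listed in SINGLE_FILE_LOADABLE_CLASSES works the same way.
ckpt_url = "https://huggingface.co/<org>/<repo>/blob/main/<checkpoint>.safetensors"

model = FluxTransformer2DModel.from_single_file(
    ckpt_url,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # resolved via DiffusersAutoQuantizer.from_config
    torch_dtype=torch.bfloat16,
    device="cuda",  # becomes device_map={"": torch.device("cuda")} for load_model_dict_into_meta
)

When `device` is omitted the weights land on CPU, and the final `model.to(torch_dtype)` cast is skipped whenever a quantizer is attached, matching the branches above.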
raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.") - if (not config and state_dict) or (config and not state_dict): - raise ValueError("You need to pass both the config and the state dict to initalize the model.") - user_agent = { "diffusers": __version__, "file_type": "model", @@ -877,29 +872,28 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P } unused_kwargs = {} - if config is None: - # Load config if we don't provide a configuration - config_path = pretrained_model_name_or_path + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path - # TODO: We need to let the user pass a config in from_pretrained - # load config - config, unused_kwargs, commit_hash = cls.load_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - return_commit_hash=True, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - dduf_entries=dduf_entries, - **kwargs, - ) - # no in-place modification of the original config. - config = copy.deepcopy(config) + # TODO: We need to let the user pass a config in from_pretrained + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + dduf_entries=dduf_entries, + **kwargs, + ) + # no in-place modification of the original config. + config = copy.deepcopy(config) # determine initial quantization config. ####################################### @@ -951,103 +945,79 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P is_sharded = False resolved_archive_file = None - if state_dict is None: - # Determine if we're loading from a directory of sharded checkpoints. - sharded_metadata = None - index_file = None - is_local = os.path.isdir(pretrained_model_name_or_path) - index_file_kwargs = { - "is_local": is_local, - "pretrained_model_name_or_path": pretrained_model_name_or_path, - "subfolder": subfolder or "", - "use_safetensors": use_safetensors, - "cache_dir": cache_dir, - "variant": variant, - "force_download": force_download, - "proxies": proxies, - "local_files_only": local_files_only, - "token": token, - "revision": revision, - "user_agent": user_agent, - "commit_hash": commit_hash, - "dduf_entries": dduf_entries, - } - index_file = _fetch_index_file(**index_file_kwargs) - # In case the index file was not found we still have to consider the legacy format. - # this becomes applicable when the variant is not None. - if variant is not None and (index_file is None or not os.path.exists(index_file)): - index_file = _fetch_index_file_legacy(**index_file_kwargs) - if index_file is not None and (dduf_entries or index_file.is_file()): - is_sharded = True - - if is_sharded and from_flax: - raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.") - - # load model - if from_flax: - resolved_archive_file = _get_model_file( + + # Determine if we're loading from a directory of sharded checkpoints. 
+ sharded_metadata = None + index_file = None + is_local = os.path.isdir(pretrained_model_name_or_path) + index_file_kwargs = { + "is_local": is_local, + "pretrained_model_name_or_path": pretrained_model_name_or_path, + "subfolder": subfolder or "", + "use_safetensors": use_safetensors, + "cache_dir": cache_dir, + "variant": variant, + "force_download": force_download, + "proxies": proxies, + "local_files_only": local_files_only, + "token": token, + "revision": revision, + "user_agent": user_agent, + "commit_hash": commit_hash, + "dduf_entries": dduf_entries, + } + index_file = _fetch_index_file(**index_file_kwargs) + # In case the index file was not found we still have to consider the legacy format. + # this becomes applicable when the variant is not None. + if variant is not None and (index_file is None or not os.path.exists(index_file)): + index_file = _fetch_index_file_legacy(**index_file_kwargs) + if index_file is not None and (dduf_entries or index_file.is_file()): + is_sharded = True + + if is_sharded and from_flax: + raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.") + + # load model + if from_flax: + resolved_archive_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=FLAX_WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + model = cls.from_config(config, **unused_kwargs) + + # Convert the weights + from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model + + model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file) + else: + # in the case it is sharded, we have already the index + if is_sharded: + resolved_archive_file, sharded_metadata = _get_checkpoint_shard_files( pretrained_model_name_or_path, - weights_name=FLAX_WEIGHTS_NAME, + index_file, cache_dir=cache_dir, - force_download=force_download, proxies=proxies, local_files_only=local_files_only, token=token, - revision=revision, - subfolder=subfolder, user_agent=user_agent, - commit_hash=commit_hash, + revision=revision, + subfolder=subfolder or "", + dduf_entries=dduf_entries, ) - model = cls.from_config(config, **unused_kwargs) - - # Convert the weights - from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model - - model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file) - else: - # in the case it is sharded, we have already the index - if is_sharded: - resolved_archive_file, sharded_metadata = _get_checkpoint_shard_files( - pretrained_model_name_or_path, - index_file, - cache_dir=cache_dir, - proxies=proxies, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - subfolder=subfolder or "", - dduf_entries=dduf_entries, - ) - elif use_safetensors: - try: - resolved_archive_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - commit_hash=commit_hash, - dduf_entries=dduf_entries, - ) - - except IOError as e: - logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") - if not allow_pickle: - raise - logger.warning( - "Defaulting to unsafe serialization. 
Pass `allow_pickle=False` to raise an error instead." - ) - - if resolved_archive_file is None and not is_sharded: + elif use_safetensors: + try: resolved_archive_file = _get_model_file( pretrained_model_name_or_path, - weights_name=_add_variant(WEIGHTS_NAME, variant), + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, proxies=proxies, @@ -1060,6 +1030,30 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P dduf_entries=dduf_entries, ) + except IOError as e: + logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}") + if not allow_pickle: + raise + logger.warning( + "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." + ) + + if resolved_archive_file is None and not is_sharded: + resolved_archive_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + dduf_entries=dduf_entries, + ) + if not isinstance(resolved_archive_file, list): resolved_archive_file = [resolved_archive_file] @@ -1084,7 +1078,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if dtype_orig is not None: torch.set_default_dtype(dtype_orig) - if not is_sharded and state_dict is None: + state_dict = None + if not is_sharded: # Time to load the checkpoint state_dict = load_state_dict( resolved_archive_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries From fc4af164f27d14ff779203ec65d4a038178827b1 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 23 Jan 2025 17:58:39 +0100 Subject: [PATCH 15/15] style --- src/diffusers/loaders/single_file_model.py | 6 ++++-- src/diffusers/pipelines/consisid/pipeline_consisid.py | 11 ++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index ac54db82da41..82271ac230a5 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -362,12 +362,14 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = if is_accelerate_available(): param_device = torch.device(device) if device else torch.device("cpu") - unexpected_keys = [param_name for param_name in diffusers_format_checkpoint if param_name not in model.state_dict()] + unexpected_keys = [ + param_name for param_name in diffusers_format_checkpoint if param_name not in model.state_dict() + ] load_model_dict_into_meta( model, diffusers_format_checkpoint, dtype=torch_dtype, - device_map={"":param_device}, + device_map={"": param_device}, hf_quantizer=hf_quantizer, keep_in_fp32_modules=keep_in_fp32_modules, unexpected_keys=unexpected_keys, diff --git a/src/diffusers/pipelines/consisid/pipeline_consisid.py b/src/diffusers/pipelines/consisid/pipeline_consisid.py index 0d4891cf17d7..1a99c2a0e9ee 100644 --- a/src/diffusers/pipelines/consisid/pipeline_consisid.py +++ b/src/diffusers/pipelines/consisid/pipeline_consisid.py @@ -48,9 +48,14 @@ >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview") - >>> face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, 
eva_transform_std = ( - ... prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16) - ... ) + >>> ( + ... face_helper_1, + ... face_helper_2, + ... face_clip_model, + ... face_main_model, + ... eva_transform_mean, + ... eva_transform_std, + ... ) = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16) >>> pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16) >>> pipe.to("cuda")