diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index c7d0fcb3046e..82271ac230a5 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -362,17 +362,18 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = if is_accelerate_available(): param_device = torch.device(device) if device else torch.device("cpu") - named_buffers = model.named_buffers() - unexpected_keys = load_model_dict_into_meta( + unexpected_keys = [ + param_name for param_name in diffusers_format_checkpoint if param_name not in model.state_dict() + ] + load_model_dict_into_meta( model, diffusers_format_checkpoint, dtype=torch_dtype, - device=param_device, + device_map={"": param_device}, hf_quantizer=hf_quantizer, keep_in_fp32_modules=keep_in_fp32_modules, - named_buffers=named_buffers, + unexpected_keys=unexpected_keys, ) - else: _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 731b7b87f625..9da23d101afb 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -1588,18 +1588,9 @@ def create_diffusers_clip_model_from_ldm( raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.") if is_accelerate_available(): - unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) else: - _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) - - if model._keys_to_ignore_on_load_unexpected is not None: - for pat in model._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" - ) + model.load_state_dict(diffusers_format_checkpoint, strict=False) if torch_dtype is not None: model.to(torch_dtype) @@ -2056,16 +2047,7 @@ def create_diffusers_t5_model_from_checkpoint( diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint) if is_accelerate_available(): - unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) - if model._keys_to_ignore_on_load_unexpected is not None: - for pat in model._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" - ) - + load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) else: model.load_state_dict(diffusers_format_checkpoint) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 7e7445ef1239..793f821e1e1a 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -20,7 +20,8 @@ from array import array from collections import OrderedDict from pathlib import Path -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union +from zipfile import is_zipfile import safetensors import torch @@ -55,7 +56,7 @@ 
if is_accelerate_available(): from accelerate import infer_auto_device_map - from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device + from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device # Adapted from `transformers` (see modeling_utils.py) @@ -132,17 +133,29 @@ def _fetch_remapped_cls_from_config(config, old_class): return old_class +def _check_archive_and_maybe_raise_error(checkpoint_file, format_list): + """ + Check format of the archive + """ + with safetensors.safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + if metadata is not None and metadata.get("format") not in format_list: + raise OSError( + f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " + "you save your model with the `save_pretrained` method." + ) + + def load_state_dict( checkpoint_file: Union[str, os.PathLike], - variant: Optional[str] = None, dduf_entries: Optional[Dict[str, DDUFEntry]] = None, disable_mmap: bool = False, + map_location: Optional[Union[str, torch.device]] = None, ): """ Reads a checkpoint file, returning properly formatted errors if they arise. """ - # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change - # when refactoring the _merge_sharded_checkpoints() method later. + # TODO: maybe refactor a bit this part where we pass a dict here if isinstance(checkpoint_file, dict): return checkpoint_file try: @@ -152,19 +165,28 @@ def load_state_dict( # tensors are loaded on cpu with dduf_entries[checkpoint_file].as_mmap() as mm: return safetensors.torch.load(mm) + _check_archive_and_maybe_raise_error(checkpoint_file, format_list=["pt", "flax"]) if disable_mmap: return safetensors.torch.load(open(checkpoint_file, "rb").read()) else: - return safetensors.torch.load_file(checkpoint_file, device="cpu") + return safetensors.torch.load_file(checkpoint_file, device=map_location) elif file_extension == GGUF_FILE_EXTENSION: return load_gguf_checkpoint(checkpoint_file) else: + if map_location is None: + map_location = "cpu" + extra_args = {} weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {} - return torch.load( - checkpoint_file, - map_location="cpu", - **weights_only_kwarg, - ) + # mmap can only be used with files serialized with zipfile-based format. 
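# Illustrative sketch (not part of the patch) of the two loading paths touched above:
# with the new map_location plumbing, safetensors checkpoints can be materialized
# directly on the requested device, and zipfile-serialized torch checkpoints can be
# memory-mapped on torch >= 2.1. The file paths below are placeholders.
import torch
import safetensors.torch
from zipfile import is_zipfile

sd = safetensors.torch.load_file("weights.safetensors", device="cuda:0")  # no CPU round-trip

extra_args = {"mmap": True} if is_zipfile("weights.bin") else {}
sd_bin = torch.load("weights.bin", map_location="cpu", weights_only=True, **extra_args)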
+ if ( + isinstance(checkpoint_file, str) + and map_location != "meta" + and is_torch_version(">=", "2.1.0") + and is_zipfile(checkpoint_file) + and not disable_mmap + ): + extra_args = {"mmap": True} + return torch.load(checkpoint_file, map_location="cpu", **weights_only_kwarg, **extra_args) except Exception as e: try: with open(checkpoint_file) as f: @@ -188,23 +210,25 @@ def load_state_dict( def load_model_dict_into_meta( model, state_dict: OrderedDict, - device: Optional[Union[str, torch.device]] = None, dtype: Optional[Union[str, torch.dtype]] = None, model_name_or_path: Optional[str] = None, hf_quantizer=None, keep_in_fp32_modules=None, - named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None, + device_map=None, + unexpected_keys=None, + offload_folder=None, + offload_index=None, + state_dict_index=None, + state_dict_folder=None, ) -> List[str]: - if device is not None and not isinstance(device, (str, torch.device)): - raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.") - if hf_quantizer is None: - device = device or torch.device("cpu") - dtype = dtype or torch.float32 - is_quantized = hf_quantizer is not None + """ + This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its + params on a `meta` device. It replaces the model params with the data from the `state_dict` + """ - accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) + error_msgs = [] + is_quantized = hf_quantizer is not None empty_state_dict = model.state_dict() - unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict] for param_name, param in state_dict.items(): if param_name not in empty_state_dict: @@ -214,7 +238,7 @@ def load_model_dict_into_meta( # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params # in int/uint/bool and not cast them. # TODO: revisit cases when param.dtype == torch.float8_e4m3fn - if torch.is_floating_point(param): + if dtype is not None and torch.is_floating_point(param): if ( keep_in_fp32_modules is not None and any( @@ -223,12 +247,41 @@ def load_model_dict_into_meta( and dtype == torch.float16 ): param = param.to(torch.float32) - if accepts_dtype: - set_module_kwargs["dtype"] = torch.float32 + set_module_kwargs["dtype"] = torch.float32 else: param = param.to(dtype) - if accepts_dtype: - set_module_kwargs["dtype"] = dtype + set_module_kwargs["dtype"] = dtype + + # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which + # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. 
+ # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 + old_param = model + splits = param_name.split(".") + for split in splits: + old_param = getattr(old_param, split) + + if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)): + old_param = None + + if old_param is not None: + if dtype is None: + param = param.to(old_param.dtype) + + if old_param.is_contiguous(): + param = param.contiguous() + + module_name = param_name + + if device_map is None: + param_device = "cpu" + else: + # find next higher level module that is defined in device_map: + # bert.lm_head.weight -> bert.lm_head -> bert -> '' + while len(module_name) > 0 and module_name not in device_map: + module_name = ".".join(module_name.split(".")[:-1]) + if module_name == "" and "" not in device_map: + raise ValueError(f"{param_name} doesn't have any device set.") + param_device = device_map[module_name] # bnb params are flattened. # gguf quants have a different shape based on the type of quantization applied @@ -236,7 +289,9 @@ def load_model_dict_into_meta( if ( is_quantized and hf_quantizer.pre_quantized - and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) + and hf_quantizer.check_if_quantized_param( + model, param, param_name, state_dict, param_device=param_device + ) ): hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param) else: @@ -244,35 +299,23 @@ def load_model_dict_into_meta( raise ValueError( f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." 
) - - if is_quantized and ( - hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) - ): - hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys) - else: - if accepts_dtype: - set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs) - else: - set_module_tensor_to_device(model, param_name, device, value=param) - - if named_buffers is None: - return unexpected_keys - - for param_name, param in named_buffers: - if is_quantized and ( - hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) + if param_device == "disk": + offload_index = offload_weight(param, param_name, offload_folder, offload_index) + elif param_device == "cpu" and state_dict_index is not None: + state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index) + elif is_quantized and ( + hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device) ): - hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys) + hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys) else: - if accepts_dtype: - set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs) - else: - set_module_tensor_to_device(model, param_name, device, value=param) + set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs) - return unexpected_keys + return error_msgs, offload_index, state_dict_index -def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]: +def _load_state_dict_into_model( + model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False +) -> List[str]: # Convert old format to new format if needed from a PyTorch state_dict # copy state_dict so _load_from_state_dict can modify it state_dict = state_dict.copy() @@ -280,15 +323,17 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants # so we need to apply the function recursively. 
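# Sketch of the device lookup performed in load_model_dict_into_meta above: each
# parameter falls back to the closest module prefix present in device_map, with ""
# acting as the catch-all entry. The names in the usage line are illustrative.
def resolve_param_device(param_name, device_map):
    if device_map is None:
        return "cpu"
    module_name = param_name
    while len(module_name) > 0 and module_name not in device_map:
        module_name = ".".join(module_name.split(".")[:-1])
    if module_name == "" and "" not in device_map:
        raise ValueError(f"{param_name} doesn't have any device set.")
    return device_map[module_name]

print(resolve_param_device("mid_block.attentions.0.to_q.weight", {"": "cpu", "mid_block": 0}))  # -> 0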
- def load(module: torch.nn.Module, prefix: str = ""): - args = (state_dict, prefix, {}, True, [], [], error_msgs) + def load(module: torch.nn.Module, prefix: str = "", assign_to_params_buffers: bool = False): + local_metadata = {} + local_metadata["assign_to_params_buffers"] = assign_to_params_buffers + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) module._load_from_state_dict(*args) for name, child in module._modules.items(): if child is not None: - load(child, prefix + name + ".") + load(child, prefix + name + ".", assign_to_params_buffers) - load(model_to_load) + load(model_to_load, assign_to_params_buffers=assign_to_params_buffers) return error_msgs @@ -343,46 +388,6 @@ def _fetch_index_file( return index_file -# Adapted from -# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64 -def _merge_sharded_checkpoints( - sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None -): - weight_map = sharded_metadata.get("weight_map", None) - if weight_map is None: - raise KeyError("'weight_map' key not found in the shard index file.") - - # Collect all unique safetensors files from weight_map - files_to_load = set(weight_map.values()) - is_safetensors = all(f.endswith(".safetensors") for f in files_to_load) - merged_state_dict = {} - - # Load tensors from each unique file - for file_name in files_to_load: - part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name) - if dduf_entries: - if part_file_path not in dduf_entries: - raise FileNotFoundError(f"Part file {file_name} not found.") - else: - if not os.path.exists(part_file_path): - raise FileNotFoundError(f"Part file {file_name} not found.") - - if is_safetensors: - if dduf_entries: - with dduf_entries[part_file_path].as_mmap() as mm: - tensors = safetensors.torch.load(mm) - merged_state_dict.update(tensors) - else: - with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f: - for tensor_key in f.keys(): - if tensor_key in weight_map: - merged_state_dict[tensor_key] = f.get_tensor(tensor_key) - else: - merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu")) - - return merged_state_dict - - def _fetch_index_file_legacy( is_local, pretrained_model_name_or_path, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 4d5669e37f5a..0be7de60b796 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -20,10 +20,13 @@ import json import os import re +import shutil +import tempfile from collections import OrderedDict +from contextlib import ExitStack, contextmanager from functools import partial, wraps from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Type, Union import safetensors import torch @@ -63,16 +66,49 @@ _fetch_index_file, _fetch_index_file_legacy, _load_state_dict_into_model, - _merge_sharded_checkpoints, load_model_dict_into_meta, load_state_dict, ) +class ContextManagers: + """ + Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers` + in the `fastcore` library. 
+ """ + + def __init__(self, context_managers: List[ContextManager]): + self.context_managers = context_managers + self.stack = ExitStack() + + def __enter__(self): + for context_manager in self.context_managers: + self.stack.enter_context(context_manager) + + def __exit__(self, *args, **kwargs): + self.stack.__exit__(*args, **kwargs) + + logger = logging.get_logger(__name__) _REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}") +TORCH_INIT_FUNCTIONS = { + "uniform_": nn.init.uniform_, + "normal_": nn.init.normal_, + "trunc_normal_": nn.init.trunc_normal_, + "constant_": nn.init.constant_, + "xavier_uniform_": nn.init.xavier_uniform_, + "xavier_normal_": nn.init.xavier_normal_, + "kaiming_uniform_": nn.init.kaiming_uniform_, + "kaiming_normal_": nn.init.kaiming_normal_, + "uniform": nn.init.uniform, + "normal": nn.init.normal, + "xavier_uniform": nn.init.xavier_uniform, + "xavier_normal": nn.init.xavier_normal, + "kaiming_uniform": nn.init.kaiming_uniform, + "kaiming_normal": nn.init.kaiming_normal, +} if is_torch_version(">=", "1.9.0"): _LOW_CPU_MEM_USAGE_DEFAULT = True @@ -82,6 +118,8 @@ if is_accelerate_available(): import accelerate + from accelerate import dispatch_model + from accelerate.utils import load_offloaded_weights, save_offload_index def get_parameter_device(parameter: torch.nn.Module) -> torch.device: @@ -147,6 +185,56 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: return last_tuple[1].dtype +def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""): + """ + Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first + checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's + parameters. + + Note: We fully disable this if we are using `deepspeed` + """ + if model_to_load.device.type == "meta": + return False + + if len([key for key in state_dict if key.startswith(start_prefix)]) == 0: + return False + + # Some models explicitly do not support param buffer assignment + if not getattr(model_to_load, "_supports_param_buffer_assignment", True): + logger.debug( + f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower" + ) + return False + + # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype + first_key = next(iter(model_to_load.state_dict().keys())) + if start_prefix + first_key in state_dict: + return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype + + return False + + +@contextmanager +def no_init_weights(): + """ + Context manager to globally disable weight initialization to speed up loading large models. To do that, all the + torch.nn.init function are all replaced with skip. + """ + + def _skip_init(*args, **kwargs): + pass + + # # Save the original initialization functions + for name, init_func in TORCH_INIT_FUNCTIONS.items(): + setattr(torch.nn.init, name, _skip_init) + try: + yield + finally: + # # Restore the original initialization functions + for name, init_func in TORCH_INIT_FUNCTIONS.items(): + setattr(torch.nn.init, name, init_func) + + class ModelMixin(torch.nn.Module, PushToHubMixin): r""" Base class for all models. 
@@ -700,7 +788,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P device_map = kwargs.pop("device_map", None) max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) - offload_state_dict = kwargs.pop("offload_state_dict", False) + offload_state_dict = kwargs.pop("offload_state_dict", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) @@ -777,15 +865,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.") - # Load config if we don't provide a configuration - config_path = pretrained_model_name_or_path - user_agent = { "diffusers": __version__, "file_type": "model", "framework": "pytorch", } + unused_kwargs = {} + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + # TODO: We need to let the user pass a config in from_pretrained # load config config, unused_kwargs, commit_hash = cls.load_config( config_path, @@ -822,13 +912,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P hf_quantizer = None if hf_quantizer is not None: - if device_map is not None: - raise NotImplementedError( - "Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future." - ) - hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) + device_map = hf_quantizer.update_device_map(device_map) # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value @@ -856,10 +942,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.") else: keep_in_fp32_modules = [] - ####################################### - # Determine if we're loading from a directory of sharded checkpoints. is_sharded = False + resolved_archive_file = None + + # Determine if we're loading from a directory of sharded checkpoints. 
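# Hypothetical end-to-end call enabled by dropping the device_map restriction for
# quantized models above; the checkpoint id is illustrative, and bitsandbytes plus
# accelerate must be installed for this to run.
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=nf4_config,
    device_map={"": 0},  # place the whole quantized model on GPU 0
    torch_dtype=torch.bfloat16,
)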
+ sharded_metadata = None index_file = None is_local = os.path.isdir(pretrained_model_name_or_path) index_file_kwargs = { @@ -890,9 +978,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.") # load model - model_file = None if from_flax: - model_file = _get_model_file( + resolved_archive_file = _get_model_file( pretrained_model_name_or_path, weights_name=FLAX_WEIGHTS_NAME, cache_dir=cache_dir, @@ -910,11 +997,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Convert the weights from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model - model = load_flax_checkpoint_in_pytorch_model(model, model_file) + model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file) else: # in the case it is sharded, we have already the index if is_sharded: - sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files( + resolved_archive_file, sharded_metadata = _get_checkpoint_shard_files( pretrained_model_name_or_path, index_file, cache_dir=cache_dir, @@ -926,17 +1013,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P subfolder=subfolder or "", dduf_entries=dduf_entries, ) - # TODO: https://github.com/huggingface/diffusers/issues/10013 - if hf_quantizer is not None or dduf_entries: - model_file = _merge_sharded_checkpoints( - sharded_ckpt_cached_folder, sharded_metadata, dduf_entries=dduf_entries - ) - logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") - is_sharded = False - - elif use_safetensors and not is_sharded: + elif use_safetensors: try: - model_file = _get_model_file( + resolved_archive_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), cache_dir=cache_dir, @@ -959,8 +1038,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." ) - if model_file is None and not is_sharded: - model_file = _get_model_file( + if resolved_archive_file is None and not is_sharded: + resolved_archive_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(WEIGHTS_NAME, variant), cache_dir=cache_dir, @@ -975,159 +1054,99 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P dduf_entries=dduf_entries, ) - if low_cpu_mem_usage: - # Instantiate model with empty weights - with accelerate.init_empty_weights(): - model = cls.from_config(config, **unused_kwargs) + if not isinstance(resolved_archive_file, list): + resolved_archive_file = [resolved_archive_file] - if hf_quantizer is not None: - hf_quantizer.preprocess_model( - model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules - ) + # set dtype to instantiate the model under: + # 1. If torch_dtype is not None, we use that dtype + dtype_orig = None + if torch_dtype is not None: + if not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + dtype_orig = cls._set_default_torch_dtype(torch_dtype) - # if device_map is None, load the state dict and move the params from meta device to the cpu - if device_map is None and not is_sharded: - # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None. 
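# Condensed sketch of the instantiation pattern assembled above: the default dtype is
# switched and parameters are created directly on the meta device, both entered
# through a single ExitStack (the real code additionally stacks no_init_weights(),
# which monkey-patches torch.nn.init into no-ops). Assumes accelerate is installed.
from contextlib import ExitStack, contextmanager

import accelerate
import torch
import torch.nn as nn


@contextmanager
def default_dtype(dtype):
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


with ExitStack() as stack:
    stack.enter_context(default_dtype(torch.float16))
    stack.enter_context(accelerate.init_empty_weights())  # parameters land on "meta"
    model = nn.Linear(1024, 1024)

print(model.weight.device, model.weight.dtype)  # meta torch.float16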
- # It would error out during the `validate_environment()` call above in the absence of cuda. - if hf_quantizer is None: - param_device = "cpu" - # TODO (sayakpaul, SunMarc): remove this after model loading refactor - else: - param_device = torch.device(torch.cuda.current_device()) - state_dict = load_state_dict( - model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap - ) - model._convert_deprecated_attention_blocks(state_dict) + init_contexts = [no_init_weights()] - # move the params from meta device to cpu - missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) - if hf_quantizer is not None: - missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="") - if len(missing_keys) > 0: - raise ValueError( - f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" - f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" - " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" - " those weights or else make sure your checkpoint file is correct." - ) + if low_cpu_mem_usage: + init_contexts.append(accelerate.init_empty_weights()) - named_buffers = model.named_buffers() - - unexpected_keys = load_model_dict_into_meta( - model, - state_dict, - device=param_device, - dtype=torch_dtype, - model_name_or_path=pretrained_model_name_or_path, - hf_quantizer=hf_quantizer, - keep_in_fp32_modules=keep_in_fp32_modules, - named_buffers=named_buffers, - ) + with ContextManagers(init_contexts): + model = cls.from_config(config, **unused_kwargs) - if cls._keys_to_ignore_on_load_unexpected is not None: - for pat in cls._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" - ) + state_dict = None + if not is_sharded: + # Time to load the checkpoint + state_dict = load_state_dict( + resolved_archive_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries + ) + # We only fix it for non sharded checkpoints as we don't need it yet for sharded one. + model._fix_state_dict_keys_on_load(state_dict) - else: # else let accelerate handle loading and dispatching. - # Load weights and dispatch according to the device_map - # by default the device_map is None and the weights are loaded on the CPU - device_map = _determine_device_map( - model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer - ) - if device_map is None and is_sharded: - # we load the parameters on the cpu - device_map = {"": "cpu"} - try: - accelerate.load_checkpoint_and_dispatch( - model, - model_file if not is_sharded else index_file, - device_map, - max_memory=max_memory, - offload_folder=offload_folder, - offload_state_dict=offload_state_dict, - dtype=torch_dtype, - strict=True, - ) - except AttributeError as e: - # When using accelerate loading, we do not have the ability to load the state - # dict and rename the weight names manually. Additionally, accelerate skips - # torch loading conventions and directly writes into `module.{_buffers, _parameters}` - # (which look like they should be private variables?), so we can't use the standard hooks - # to rename parameters on load. We need to mimic the original weight names so the correct - # attributes are available. 
After we have loaded the weights, we convert the deprecated - # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert - # the weights so we don't have to do this again. - - if "'Attention' object has no attribute" in str(e): - logger.warning( - f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}" - " was saved with deprecated attention block weight names. We will load it with the deprecated attention block" - " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion," - " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint," - " please also re-upload it or open a PR on the original repository." - ) - model._temp_convert_self_to_deprecated_attention_blocks() - accelerate.load_checkpoint_and_dispatch( - model, - model_file if not is_sharded else index_file, - device_map, - max_memory=max_memory, - offload_folder=offload_folder, - offload_state_dict=offload_state_dict, - dtype=torch_dtype, - strict=True, - ) - model._undo_temp_convert_self_to_deprecated_attention_blocks() - else: - raise e - - loading_info = { - "missing_keys": [], - "unexpected_keys": [], - "mismatched_keys": [], - "error_msgs": [], - } - else: - model = cls.from_config(config, **unused_kwargs) + if is_sharded: + loaded_keys = sharded_metadata["all_checkpoint_keys"] + else: + loaded_keys = list(state_dict.keys()) - state_dict = load_state_dict( - model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap - ) - model._convert_deprecated_attention_blocks(state_dict) + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + ) - model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( - model, - state_dict, - model_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=ignore_mismatched_sizes, - ) + # Now that the model is loaded, we can determine the device_map + device_map = _determine_device_map( + model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer + ) + if hf_quantizer is not None: + hf_quantizer.validate_environment(device_map=device_map) + + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + loaded_keys, + ignore_mismatched_sizes=ignore_mismatched_sizes, + low_cpu_mem_usage=low_cpu_mem_usage, + device_map=device_map, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + dduf_entries=dduf_entries, + ) + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - "error_msgs": error_msgs, - } + # Dispatch model with hooks on all devices if necessary + if device_map is not None: + device_map_kwargs = { + "device_map": device_map, + "offload_dir": offload_folder, + "offload_index": offload_index, + } + dispatch_model(model, **device_map_kwargs) if hf_quantizer is not None: hf_quantizer.postprocess_model(model) model.hf_quantizer = hf_quantizer - if 
torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): - raise ValueError( - f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." - ) - # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will - # completely lose the effectivity of `use_keep_in_fp32_modules`. - elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules: - model = model.to(torch_dtype) - if hf_quantizer is not None: # We also make sure to purge `_pre_quantization_dtype` when we serialize # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable. @@ -1137,6 +1156,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Set model in evaluation mode to deactivate DropOut modules by default model.eval() + if output_loading_info: return model, loading_info @@ -1217,54 +1237,129 @@ def _load_pretrained_model( cls, model, state_dict: OrderedDict, - resolved_archive_file, + resolved_archive_file: List[str], pretrained_model_name_or_path: Union[str, os.PathLike], + loaded_keys, ignore_mismatched_sizes: bool = False, + assign_to_params_buffers: bool = False, + hf_quantizer=None, + low_cpu_mem_usage=None, + dtype=None, + keep_in_fp32_modules=None, + device_map=None, + offload_state_dict=None, + offload_folder=None, + dduf_entries=None, ): - # Retrieve missing & unexpected_keys model_state_dict = model.state_dict() - loaded_keys = list(state_dict.keys()) - expected_keys = list(model_state_dict.keys()) - - original_loaded_keys = loaded_keys - missing_keys = list(set(expected_keys) - set(loaded_keys)) + if hf_quantizer is not None: + missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="") unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + # # Some models may have keys that are not in the state by design, removing them before needlessly warning + # the user. + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - # Make sure we are able to load base models as well as derived models (with heads) - model_to_load = model + mismatched_keys = [] - def _find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - ignore_mismatched_sizes, - ): - mismatched_keys = [] - if ignore_mismatched_sizes: - for checkpoint_key in loaded_keys: - model_key = checkpoint_key - - if ( - model_key in model_state_dict - and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape - ): - mismatched_keys.append( - (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) - ) - del state_dict[checkpoint_key] - return mismatched_keys + assign_to_params_buffers = None + error_msgs = [] + # Deal with offload + if device_map is not None and "disk" in device_map.values(): + if offload_folder is None: + raise ValueError( + "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`" + " for them. Alternatively, make sure you have `safetensors` installed if the model you are using" + " offers the weights in this format." 
+ ) + if offload_folder is not None: + os.makedirs(offload_folder, exist_ok=True) + if offload_state_dict is None: + offload_state_dict = True + + offload_index = {} if device_map is not None and "disk" in device_map.values() else None + if offload_state_dict: + state_dict_folder = tempfile.mkdtemp() + state_dict_index = {} + else: + state_dict_folder = None + state_dict_index = None + + # TODO: not sure if this is the most elegant way of dealing with this case if state_dict is not None: - # Whole checkpoint - mismatched_keys = _find_mismatched_keys( + # load_state_dict will manage the case where we pass a dict instead of a file + # if state dict is not None, it means that we don't need to read the files from resolved_archive_file also + resolved_archive_file = [state_dict] + + if len(resolved_archive_file) > 1: + resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards") + + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries) + + def _find_mismatched_keys( state_dict, model_state_dict, - original_loaded_keys, + loaded_keys, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + # If the checkpoint is sharded, we may not have the key here. + if checkpoint_key not in state_dict: + continue + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + mismatched_keys += _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, ignore_mismatched_sizes, ) - error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if low_cpu_mem_usage: + new_error_msgs, offload_index, state_dict_index = load_model_dict_into_meta( + model, + state_dict, + device_map=device_map, + dtype=dtype, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + unexpected_keys=unexpected_keys, + offload_folder=offload_folder, + offload_index=offload_index, + state_dict_index=state_dict_index, + state_dict_folder=state_dict_folder, + ) + error_msgs += new_error_msgs + else: + if assign_to_params_buffers is None: + assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict) + + error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers) + + if offload_index is not None and len(offload_index) > 0: + save_offload_index(offload_index, offload_folder) + offload_index = None + + if offload_state_dict: + load_offloaded_weights(model, state_dict_index, state_dict_folder) + shutil.rmtree(state_dict_folder) if len(error_msgs) > 0: error_msg = "\n\t".join(error_msgs) @@ -1276,17 +1371,11 @@ def _find_mismatched_keys( if len(unexpected_keys) > 0: logger.warning( - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" - f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" - f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" - " or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a" - " BertForPreTraining model).\n- This IS NOT expected if you are initializing" - f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" - " identical (initializing a BertForSequenceClassification model from a" - " BertForSequenceClassification model)." + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" ) else: logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" @@ -1314,7 +1403,7 @@ def _find_mismatched_keys( " able to use it for predictions and inference." ) - return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs @classmethod def _get_signature_keys(cls, obj): @@ -1355,6 +1444,33 @@ def _get_no_split_modules(self, device_map: str): modules_to_check += list(module.children()) return list(_no_split_modules) + @classmethod + def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype: + """ + Change the default dtype and return the previous one. This is needed when wanting to instantiate the model + under specific dtype. + + Args: + dtype (`torch.dtype`): + a floating dtype to set to. + + Returns: + `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was + modified. If it wasn't, returns `None`. + + Note `set_default_dtype` currently only works with floating-point types and asserts if for example, + `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception. 
+ """ + if not dtype.is_floating_point: + raise ValueError( + f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype" + ) + + logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.") + dtype_orig = torch.get_default_dtype() + torch.set_default_dtype(dtype) + return dtype_orig + @property def device(self) -> torch.device: """ @@ -1452,7 +1568,13 @@ def get_memory_footprint(self, return_buffers=True): mem = mem + mem_bufs return mem - def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: + def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None: + """ + This function fix the state dict of the model to take into account some changes that were made in the model + architecture: + - deprecated attention blocks (happened before we introduced sharded checkpoint, + so this is why we apply this method only when loading non sharded checkpoints for now) + """ deprecated_attention_block_paths = [] def recursive_find_attn_block(name, module): @@ -1495,56 +1617,7 @@ def recursive_find_attn_block(name, module): state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") if f"{path}.proj_attn.bias" in state_dict: state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") - - def _temp_convert_self_to_deprecated_attention_blocks(self) -> None: - deprecated_attention_block_modules = [] - - def recursive_find_attn_block(module): - if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: - deprecated_attention_block_modules.append(module) - - for sub_module in module.children(): - recursive_find_attn_block(sub_module) - - recursive_find_attn_block(self) - - for module in deprecated_attention_block_modules: - module.query = module.to_q - module.key = module.to_k - module.value = module.to_v - module.proj_attn = module.to_out[0] - - # We don't _have_ to delete the old attributes, but it's helpful to ensure - # that _all_ the weights are loaded into the new attributes and we're not - # making an incorrect assumption that this model should be converted when - # it really shouldn't be. 
- del module.to_q - del module.to_k - del module.to_v - del module.to_out - - def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None: - deprecated_attention_block_modules = [] - - def recursive_find_attn_block(module) -> None: - if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: - deprecated_attention_block_modules.append(module) - - for sub_module in module.children(): - recursive_find_attn_block(sub_module) - - recursive_find_attn_block(self) - - for module in deprecated_attention_block_modules: - module.to_q = module.query - module.to_k = module.key - module.to_v = module.value - module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)]) - - del module.query - del module.key - del module.value - del module.proj_attn + return state_dict class LegacyModelMixin(ModelMixin): diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 13aa7d076d03..285d72cf146e 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -279,9 +279,7 @@ def __init__( act_fn="silu_fp32", ) - self.text_embedding_padding = nn.Parameter( - torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32) - ) + self.text_embedding_padding = nn.Parameter(torch.randn(text_len + text_len_t5, cross_attention_dim)) self.pos_embed = PatchEmbed( height=sample_size, diff --git a/src/diffusers/pipelines/consisid/pipeline_consisid.py b/src/diffusers/pipelines/consisid/pipeline_consisid.py index 0d4891cf17d7..1a99c2a0e9ee 100644 --- a/src/diffusers/pipelines/consisid/pipeline_consisid.py +++ b/src/diffusers/pipelines/consisid/pipeline_consisid.py @@ -48,9 +48,14 @@ >>> from huggingface_hub import snapshot_download >>> snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview") - >>> face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = ( - ... prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16) - ... ) + >>> ( + ... face_helper_1, + ... face_helper_2, + ... face_clip_model, + ... face_main_model, + ... eva_transform_mean, + ... eva_transform_std, + ... 
) = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16) >>> pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d56a2ce6eb30..e973343ef655 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -679,7 +679,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P device_map = kwargs.pop("device_map", None) max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) - offload_state_dict = kwargs.pop("offload_state_dict", False) + offload_state_dict = kwargs.pop("offload_state_dict", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) dduf_file = kwargs.pop("dduf_file", None) diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index 60c2f495fef8..ada75588a42a 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -235,18 +235,16 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": torch_dtype = torch.float16 return torch_dtype - # (sayakpaul): I think it could be better to disable custom `device_map`s - # for the first phase of the integration in the interest of simplicity. - # Commenting this for discussions on the PR. - # def update_device_map(self, device_map): - # if device_map is None: - # device_map = {"": torch.cuda.current_device()} - # logger.info( - # "The device_map was not initialized. " - # "Setting device_map to {'':torch.cuda.current_device()}. " - # "If you want to use the model for inference, please set device_map ='auto' " - # ) - # return device_map + def update_device_map(self, device_map): + if device_map is None: + device_map = {"": f"cuda:{torch.cuda.current_device()}"} + logger.info( + "The device_map was not initialized. " + "Setting device_map to {" + ": f`cuda:{torch.cuda.current_device()}`}. " + "If you want to use the model for inference, please set device_map ='auto' " + ) + return device_map def _process_model_before_weight_loading( self, @@ -289,9 +287,9 @@ def _process_model_before_weight_loading( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config ) model.config.quantization_config = self.quantization_config + model.is_loaded_in_4bit = True def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): - model.is_loaded_in_4bit = True model.is_4bit_serializable = self.is_serializable return model @@ -400,16 +398,17 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": torch_dtype = torch.float16 return torch_dtype - # # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map - # def update_device_map(self, device_map): - # if device_map is None: - # device_map = {"": torch.cuda.current_device()} - # logger.info( - # "The device_map was not initialized. " - # "Setting device_map to {'':torch.cuda.current_device()}. 
" - # "If you want to use the model for inference, please set device_map ='auto' " - # ) - # return device_map + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map + def update_device_map(self, device_map): + if device_map is None: + device_map = {"": f"cuda:{torch.cuda.current_device()}"} + logger.info( + "The device_map was not initialized. " + "Setting device_map to {" + ": f`cuda:{torch.cuda.current_device()}`}. " + "If you want to use the model for inference, please set device_map ='auto' " + ) + return device_map def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": if target_dtype != torch.int8: @@ -493,11 +492,10 @@ def create_quantized_param( # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): - model.is_loaded_in_8bit = True model.is_8bit_serializable = self.is_serializable return model - # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading with 4bit->8bit def _process_model_before_weight_loading( self, model: "ModelMixin", @@ -539,6 +537,7 @@ def _process_model_before_weight_loading( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config ) model.config.quantization_config = self.quantization_config + model.is_loaded_in_8bit = True @property # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index de587704ee17..f80f96a3425d 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -338,22 +338,6 @@ def _get_model_file( ) from e -# Adapted from -# https://github.com/huggingface/transformers/blob/1360801a69c0b169e3efdbb0cd05d9a0e72bfb70/src/transformers/utils/hub.py#L976 -# Differences are in parallelization of shard downloads and checking if shards are present. - - -def _check_if_shards_exist_locally(local_dir, subfolder, original_shard_filenames): - shards_path = os.path.join(local_dir, subfolder) - shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames] - for shard_file in shard_filenames: - if not os.path.exists(shard_file): - raise ValueError( - f"{shards_path} does not appear to have a file named {shard_file} which is " - "required according to the checkpoint index." - ) - - def _get_checkpoint_shard_files( pretrained_model_name_or_path, index_filename, @@ -396,13 +380,22 @@ def _get_checkpoint_shard_files( shards_path = os.path.join(pretrained_model_name_or_path, subfolder) # First, let's deal with local folder. 
- if os.path.isdir(pretrained_model_name_or_path): - _check_if_shards_exist_locally( - pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames - ) - return shards_path, sharded_metadata - elif dduf_entries: - return shards_path, sharded_metadata + if os.path.isdir(pretrained_model_name_or_path) or dduf_entries: + shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames] + for shard_file in shard_filenames: + if dduf_entries: + if shard_file not in dduf_entries: + raise FileNotFoundError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) + else: + if not os.path.exists(shard_file): + raise FileNotFoundError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) + return shard_filenames, sharded_metadata # At this stage pretrained_model_name_or_path is a model identifier on the Hub allow_patterns = original_shard_filenames @@ -444,7 +437,9 @@ def _get_checkpoint_shard_files( " again after checking your internet connection." ) from e - return cached_folder, sharded_metadata + cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames] + + return cached_filenames, sharded_metadata def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None): diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 05050e05bb19..a658aad01106 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -199,12 +199,12 @@ class ModelUtilsTest(unittest.TestCase): def tearDown(self): super().tearDown() - def test_accelerate_loading_error_message(self): - with self.assertRaises(ValueError) as error_context: + def test_missing_key_loading_warning_message(self): + with self.assertLogs("diffusers.models.modeling_utils", level="WARNING") as logs: UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet") # make sure that error message states what keys are missing - assert "conv_out.bias" in str(error_context.exception) + assert "conv_out.bias" in " ".join(logs.output) @parameterized.expand( [
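# Minimal, self-contained illustration of the assertLogs pattern the updated test above
# relies on: warnings routed through a named logger can be captured and asserted on
# instead of expecting a hard ValueError. Logger name and message are illustrative.
import logging
import unittest

class WarningCaptureExample(unittest.TestCase):
    def test_warning_is_captured(self):
        logger = logging.getLogger("diffusers.models.modeling_utils")
        with self.assertLogs("diffusers.models.modeling_utils", level="WARNING") as logs:
            logger.warning("Some weights were not initialized: conv_out.bias")
        self.assertIn("conv_out.bias", " ".join(logs.output))

if __name__ == "__main__":
    unittest.main()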