[FEAT] Model loading refactor #10604

Open
wants to merge 21 commits into base: main
11 changes: 6 additions & 5 deletions src/diffusers/loaders/single_file_model.py
@@ -362,17 +362,18 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =

if is_accelerate_available():
param_device = torch.device(device) if device else torch.device("cpu")
named_buffers = model.named_buffers()
unexpected_keys = load_model_dict_into_meta(
unexpected_keys = [
Member

Are the single-file-related changes meant to uniformize the use of load_model_dict_into_meta() (with the new signature)?

Member Author

Yeah, that's right!

param_name for param_name in diffusers_format_checkpoint if param_name not in model.state_dict()
]
load_model_dict_into_meta(
model,
diffusers_format_checkpoint,
dtype=torch_dtype,
device=param_device,
device_map={"": param_device},
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
named_buffers=named_buffers,
unexpected_keys=unexpected_keys,
)

else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
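
To make the question above concrete, here is a minimal sketch of the new calling convention for load_model_dict_into_meta(): the caller now computes the unexpected keys itself and passes a device_map instead of a single device. The toy module and checkpoint are illustrative placeholders, not taken from this diff, and the sketch assumes the signature introduced in this PR plus an environment with accelerate installed and PyTorch >= 2.0.

import torch
from torch import nn
from diffusers.models.model_loading_utils import load_model_dict_into_meta

# Toy stand-ins: a module initialized on the meta device and a checkpoint
# that carries one key the model does not have.
with torch.device("meta"):
    model = nn.Linear(4, 4)
checkpoint = {
    "weight": torch.randn(4, 4),
    "bias": torch.randn(4),
    "extra.weight": torch.randn(2),  # ends up in unexpected_keys
}

param_device = torch.device("cpu")

# New convention: unexpected keys are computed by the caller, and "" in the
# device_map places the whole module on a single device.
unexpected_keys = [k for k in checkpoint if k not in model.state_dict()]
load_model_dict_into_meta(
    model,
    checkpoint,
    dtype=torch.float16,
    device_map={"": param_device},
    unexpected_keys=unexpected_keys,
)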

24 changes: 3 additions & 21 deletions src/diffusers/loaders/single_file_utils.py
@@ -1588,18 +1588,9 @@ def create_diffusers_clip_model_from_ldm(
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")

if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)

if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
model.load_state_dict(diffusers_format_checkpoint, strict=False)

if torch_dtype is not None:
model.to(torch_dtype)
@@ -2056,16 +2047,7 @@ def create_diffusers_t5_model_from_checkpoint(
diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)

if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)

load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
model.load_state_dict(diffusers_format_checkpoint)

195 changes: 100 additions & 95 deletions src/diffusers/models/model_loading_utils.py
@@ -20,7 +20,8 @@
from array import array
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile

import safetensors
import torch
@@ -55,7 +56,7 @@

if is_accelerate_available():
from accelerate import infer_auto_device_map
from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device


# Adapted from `transformers` (see modeling_utils.py)
@@ -132,17 +133,29 @@ def _fetch_remapped_cls_from_config(config, old_class):
return old_class


def _check_archive_and_maybe_raise_error(checkpoint_file, format_list):
"""
Check format of the archive
"""
with safetensors.safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata is not None and metadata.get("format") not in format_list:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."
)


def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
variant: Optional[str] = None,
Member

variant isn't used in this method anyway, so that's fine by me.

But let's make sure the method is invoked with the proper arguments.

dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
disable_mmap: bool = False,
map_location: Optional[Union[str, torch.device]] = None,
):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
# TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
# when refactoring the _merge_sharded_checkpoints() method later.
# TODO: maybe refactor a bit this part where we pass a dict here
if isinstance(checkpoint_file, dict):
return checkpoint_file
try:
@@ -152,19 +165,28 @@ def load_state_dict(
# tensors are loaded on cpu
with dduf_entries[checkpoint_file].as_mmap() as mm:
return safetensors.torch.load(mm)
_check_archive_and_maybe_raise_error(checkpoint_file, format_list=["pt", "flax"])
if disable_mmap:
return safetensors.torch.load(open(checkpoint_file, "rb").read())
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
return safetensors.torch.load_file(checkpoint_file, device=map_location)
elif file_extension == GGUF_FILE_EXTENSION:
return load_gguf_checkpoint(checkpoint_file)
else:
if map_location is None:
map_location = "cpu"
extra_args = {}
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load(
checkpoint_file,
map_location="cpu",
**weights_only_kwarg,
)
# mmap can only be used with files serialized with zipfile-based format.
if (
isinstance(checkpoint_file, str)
and map_location != "meta"
and is_torch_version(">=", "2.1.0")
and is_zipfile(checkpoint_file)
and not disable_mmap
):
extra_args = {"mmap": True}
return torch.load(checkpoint_file, map_location="cpu", **weights_only_kwarg, **extra_args)
except Exception as e:
try:
with open(checkpoint_file) as f:
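
Following up on the reviewer note above about invoking this method with the proper arguments, here is a minimal sketch of the refactored load_state_dict() call sites. The checkpoint path is a placeholder, and the keyword names assume the signature introduced in this PR.

from diffusers.models.model_loading_utils import load_state_dict

# Placeholder path; .safetensors, .gguf and zipfile-based torch checkpoints
# all go through this helper.
checkpoint_file = "model.safetensors"

# Explicit device placement for the loaded tensors.
state_dict = load_state_dict(checkpoint_file, map_location="cpu")

# Read the whole file into memory instead of memory-mapping it, e.g. on
# filesystems where mmap is slow or unavailable.
state_dict_no_mmap = load_state_dict(checkpoint_file, disable_mmap=True, map_location="cpu")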
@@ -188,23 +210,25 @@ def load_state_dict(
def load_model_dict_into_meta(
model,
state_dict: OrderedDict,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
model_name_or_path: Optional[str] = None,
hf_quantizer=None,
keep_in_fp32_modules=None,
named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
device_map=None,
unexpected_keys=None,
offload_folder=None,
offload_index=None,
state_dict_index=None,
state_dict_folder=None,
) -> List[str]:
if device is not None and not isinstance(device, (str, torch.device)):
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
if hf_quantizer is None:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
is_quantized = hf_quantizer is not None
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
params on a `meta` device. It replaces the model params with the data from the `state_dict`
"""

accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
error_msgs = []
is_quantized = hf_quantizer is not None
empty_state_dict = model.state_dict()
unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]

for param_name, param in state_dict.items():
if param_name not in empty_state_dict:
@@ -214,7 +238,7 @@ def load_model_dict_into_meta(
# We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
# TODO: revisit cases when param.dtype == torch.float8_e4m3fn
if torch.is_floating_point(param):
if dtype is not None and torch.is_floating_point(param):
if (
keep_in_fp32_modules is not None
and any(
@@ -223,72 +247,93 @@ def load_model_dict_into_meta(
and dtype == torch.float16
):
param = param.to(torch.float32)
if accepts_dtype:
set_module_kwargs["dtype"] = torch.float32
set_module_kwargs["dtype"] = torch.float32
else:
param = param.to(dtype)
if accepts_dtype:
set_module_kwargs["dtype"] = dtype
set_module_kwargs["dtype"] = dtype

# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
old_param = model
splits = param_name.split(".")
for split in splits:
old_param = getattr(old_param, split)

if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None

if old_param is not None:
if dtype is None:
param = param.to(old_param.dtype)

if old_param.is_contiguous():
param = param.contiguous()

module_name = param_name

if device_map is None:
param_device = "cpu"
else:
# find next higher level module that is defined in device_map:
# bert.lm_head.weight -> bert.lm_head -> bert -> ''
while len(module_name) > 0 and module_name not in device_map:
module_name = ".".join(module_name.split(".")[:-1])
if module_name == "" and "" not in device_map:
raise ValueError(f"{param_name} doesn't have any device set.")
param_device = device_map[module_name]

# bnb params are flattened.
# gguf quants have a different shape based on the type of quantization applied
if empty_state_dict[param_name].shape != param.shape:
if (
is_quantized
and hf_quantizer.pre_quantized
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
and hf_quantizer.check_if_quantized_param(
model, param, param_name, state_dict, param_device=param_device
)
):
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
else:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)

if is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
else:
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
else:
set_module_tensor_to_device(model, param_name, device, value=param)

if named_buffers is None:
return unexpected_keys
Comment on lines -258 to -259
Member

Why is this going away?

Member Author

By adding dispatch_model at the end, we won't have the issue with buffers anymore. The test that @hlky added is passing in this PR.

Member

Nevermind, found it:

dispatch_model(model, **device_map_kwargs)

It's a tad bit easier for reviewers if we could just provide these links going forward.


for param_name, param in named_buffers:
Collaborator

We need to keep this or equivalent elsewhere, context: #10523

Member Author

The changes I made should also cover this use case. The test you added should pass with my PR. This is mainly due to adding the dispatch_model function at the end.

if is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
if param_device == "disk":
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
elif is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
):
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
else:
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)

return unexpected_keys
return error_msgs, offload_index, state_dict_index
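
To illustrate the discussion above: the PR's rationale is that calling accelerate's dispatch_model after the weights are loaded places parameters and buffers according to the device_map, which is why the explicit named_buffers pass could be dropped. A minimal sketch of the load-then-dispatch pattern follows; the toy module and single-device device_map are illustrative, not from this diff, and the sketch assumes accelerate is installed and PyTorch >= 2.0.

import torch
from torch import nn
from accelerate import dispatch_model
from diffusers.models.model_loading_utils import load_model_dict_into_meta

# Toy module with a buffer, built on the meta device as a stand-in for a
# diffusers model created with low_cpu_mem_usage=True.
class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.register_buffer("scale", torch.ones(8))

    def forward(self, x):
        return self.proj(x) * self.scale

with torch.device("meta"):
    model = TinyBlock()

state_dict = {
    "proj.weight": torch.randn(8, 8),
    "proj.bias": torch.randn(8),
    "scale": torch.ones(8),
}

# Load the checkpoint into the meta-initialized module, then let
# dispatch_model place everything (parameters and buffers) according to the
# device_map.
device_map = {"": "cpu"}
load_model_dict_into_meta(model, state_dict, device_map=device_map)
model = dispatch_model(model, device_map=device_map)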


def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
def _load_state_dict_into_model(
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
# Convert old format to new format if needed from a PyTorch state_dict
# copy state_dict so _load_from_state_dict can modify it
state_dict = state_dict.copy()
error_msgs = []

# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: torch.nn.Module, prefix: str = ""):
args = (state_dict, prefix, {}, True, [], [], error_msgs)
def load(module: torch.nn.Module, prefix: str = "", assign_to_params_buffers: bool = False):
local_metadata = {}
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
module._load_from_state_dict(*args)

for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(child, prefix + name + ".", assign_to_params_buffers)

load(model_to_load)
load(model_to_load, assign_to_params_buffers=assign_to_params_buffers)

return error_msgs
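
For context on the new assign_to_params_buffers flag above: in recent PyTorch releases (roughly 2.1+), Module._load_from_state_dict honors local_metadata["assign_to_params_buffers"] and assigns the checkpoint tensors to the module instead of copying them into pre-allocated storage; the public equivalent is load_state_dict(..., assign=True). A small illustrative comparison with a toy module, not taken from this diff:

import torch
from torch import nn

model = nn.Linear(4, 4)
checkpoint = {k: v.clone() for k, v in model.state_dict().items()}

# Default copy semantics: the existing parameter storage is kept and the
# checkpoint values are copied into it.
model.load_state_dict(checkpoint)
print(model.weight.data_ptr() == checkpoint["weight"].data_ptr())  # expected: False

# Assign semantics: the checkpoint tensors themselves become the module's
# parameters, skipping the extra copy. This is what the per-submodule
# assign_to_params_buffers metadata flag enables.
model.load_state_dict(checkpoint, assign=True)
print(model.weight.data_ptr() == checkpoint["weight"].data_ptr())  # expected: True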

@@ -343,46 +388,6 @@ def _fetch_index_file(
return index_file


# Adapted from
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
def _merge_sharded_checkpoints(
sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None
):
weight_map = sharded_metadata.get("weight_map", None)
if weight_map is None:
raise KeyError("'weight_map' key not found in the shard index file.")

# Collect all unique safetensors files from weight_map
files_to_load = set(weight_map.values())
is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
merged_state_dict = {}

# Load tensors from each unique file
for file_name in files_to_load:
part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
if dduf_entries:
if part_file_path not in dduf_entries:
raise FileNotFoundError(f"Part file {file_name} not found.")
else:
if not os.path.exists(part_file_path):
raise FileNotFoundError(f"Part file {file_name} not found.")

if is_safetensors:
if dduf_entries:
with dduf_entries[part_file_path].as_mmap() as mm:
tensors = safetensors.torch.load(mm)
merged_state_dict.update(tensors)
else:
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
for tensor_key in f.keys():
if tensor_key in weight_map:
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
else:
merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))

return merged_state_dict


def _fetch_index_file_legacy(
is_local,
pretrained_model_name_or_path,