[FEAT] Model loading refactor #10604
base: main
```diff
@@ -20,7 +20,8 @@
 from array import array
 from collections import OrderedDict
 from pathlib import Path
-from typing import Dict, Iterator, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union
+from zipfile import is_zipfile

 import safetensors
 import torch
```
```diff
@@ -55,7 +56,7 @@

 if is_accelerate_available():
     from accelerate import infer_auto_device_map
-    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
+    from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device


 # Adapted from `transformers` (see modeling_utils.py)
```
```diff
@@ -132,17 +133,29 @@ def _fetch_remapped_cls_from_config(config, old_class):
     return old_class


+def _check_archive_and_maybe_raise_error(checkpoint_file, format_list):
+    """
+    Check format of the archive
+    """
+    with safetensors.safe_open(checkpoint_file, framework="pt") as f:
+        metadata = f.metadata()
+    if metadata is not None and metadata.get("format") not in format_list:
+        raise OSError(
+            f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
+            "you save your model with the `save_pretrained` method."
+        )
+
+
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
     variant: Optional[str] = None,
     dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
     disable_mmap: bool = False,
+    map_location: Optional[Union[str, torch.device]] = None,
 ):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
     """
+    # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
+    # when refactoring the _merge_sharded_checkpoints() method later.
     # TODO: maybe refactor a bit this part where we pass a dict here
     if isinstance(checkpoint_file, dict):
         return checkpoint_file

     try:
```

Review comment (on the updated `load_state_dict` signature): "But let's make sure the method is invoked properly with proper arguments."
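To make that review note concrete, here is a minimal sketch of a call site passing the new argument. The module path and checkpoint path are assumptions, and the `map_location` keyword only exists on this PR's branch; the real call sites live elsewhere in the PR.

```python
# Hypothetical call site for the updated signature (assumed paths; PR branch only).
from diffusers.models.model_loading_utils import load_state_dict

state_dict = load_state_dict(
    "unet/diffusion_pytorch_model.safetensors",  # assumed local checkpoint file
    map_location="cpu",      # new argument introduced by this diff
    disable_mmap=False,
)
```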
```diff
@@ -152,19 +165,28 @@ def load_state_dict(
                 # tensors are loaded on cpu
                 with dduf_entries[checkpoint_file].as_mmap() as mm:
                     return safetensors.torch.load(mm)
+            _check_archive_and_maybe_raise_error(checkpoint_file, format_list=["pt", "flax"])
             if disable_mmap:
                 return safetensors.torch.load(open(checkpoint_file, "rb").read())
             else:
-                return safetensors.torch.load_file(checkpoint_file, device="cpu")
+                return safetensors.torch.load_file(checkpoint_file, device=map_location)
         elif file_extension == GGUF_FILE_EXTENSION:
             return load_gguf_checkpoint(checkpoint_file)
         else:
+            if map_location is None:
+                map_location = "cpu"
+            extra_args = {}
             weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
-            return torch.load(
-                checkpoint_file,
-                map_location="cpu",
-                **weights_only_kwarg,
-            )
+            # mmap can only be used with files serialized with zipfile-based format.
+            if (
+                isinstance(checkpoint_file, str)
+                and map_location != "meta"
+                and is_torch_version(">=", "2.1.0")
+                and is_zipfile(checkpoint_file)
+                and not disable_mmap
+            ):
+                extra_args = {"mmap": True}
+            return torch.load(checkpoint_file, map_location="cpu", **weights_only_kwarg, **extra_args)
     except Exception as e:
         try:
             with open(checkpoint_file) as f:
```
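The zip-format guard above is the load-bearing detail: `torch.load(..., mmap=True)` only works for checkpoints saved with the newer zipfile-based serialization and on a recent torch (the flag landed around 2.1). A standalone sketch of the same gating, with an assumed checkpoint path and without the diffusers version helpers:

```python
# Minimal sketch of the mmap gating shown in the branch above (path is assumed).
from zipfile import is_zipfile

import torch


def load_torch_checkpoint(path: str, disable_mmap: bool = False):
    extra_args = {}
    # Legacy (non-zip) pickle checkpoints cannot be memory-mapped.
    if is_zipfile(path) and not disable_mmap:
        extra_args["mmap"] = True  # requires torch >= 2.1
    return torch.load(path, map_location="cpu", weights_only=True, **extra_args)
```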
```diff
@@ -188,23 +210,25 @@ def load_state_dict(
 def load_model_dict_into_meta(
     model,
     state_dict: OrderedDict,
-    device: Optional[Union[str, torch.device]] = None,
     dtype: Optional[Union[str, torch.dtype]] = None,
     model_name_or_path: Optional[str] = None,
     hf_quantizer=None,
     keep_in_fp32_modules=None,
-    named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
+    device_map=None,
+    unexpected_keys=None,
+    offload_folder=None,
+    offload_index=None,
+    state_dict_index=None,
+    state_dict_folder=None,
 ) -> List[str]:
-    if device is not None and not isinstance(device, (str, torch.device)):
-        raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
-    if hf_quantizer is None:
-        device = device or torch.device("cpu")
-    dtype = dtype or torch.float32
-    is_quantized = hf_quantizer is not None
+    """
+    This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
+    params on a `meta` device. It replaces the model params with the data from the `state_dict`
+    """

-    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+    error_msgs = []
+    is_quantized = hf_quantizer is not None
     empty_state_dict = model.state_dict()
-    unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]

     for param_name, param in state_dict.items():
         if param_name not in empty_state_dict:
```
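The new `offload_folder`/`offload_index` (and `state_dict_folder`/`state_dict_index`) arguments are wired into `accelerate.utils.offload_weight`, which the loop further down uses for parameters mapped to `"disk"` or staged on CPU. A minimal sketch of what that helper does, using an assumed temporary folder:

```python
# Sketch only: the accelerate helper the new arguments feed into (folder is assumed).
import tempfile

import torch
from accelerate.utils import offload_weight

offload_folder = tempfile.mkdtemp()  # assumed scratch location
offload_index = {}
offload_index = offload_weight(torch.randn(4, 4), "layer.weight", offload_folder, offload_index)
# The index now records the metadata needed to reload "layer.weight" from the folder.
print(offload_index["layer.weight"])
```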
```diff
@@ -214,7 +238,7 @@ def load_model_dict_into_meta(
         # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
         # in int/uint/bool and not cast them.
         # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
-        if torch.is_floating_point(param):
+        if dtype is not None and torch.is_floating_point(param):
             if (
                 keep_in_fp32_modules is not None
                 and any(
```
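The added `dtype is not None` guard makes casting opt-in, while integer, unsigned and bool tensors are never touched. A tiny illustration of the rule in plain torch (not the diffusers helper itself):

```python
# Illustration of the casting rule above: only floating-point tensors are cast, and
# only when an explicit dtype was requested.
import torch


def maybe_cast(tensor: torch.Tensor, dtype=None) -> torch.Tensor:
    if dtype is not None and torch.is_floating_point(tensor):
        return tensor.to(dtype)
    return tensor


print(maybe_cast(torch.ones(2, dtype=torch.int64), torch.float16).dtype)  # torch.int64 (left alone)
print(maybe_cast(torch.ones(2), torch.float16).dtype)                     # torch.float16
print(maybe_cast(torch.ones(2)).dtype)                                    # torch.float32 (no dtype requested)
```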
```diff
@@ -223,72 +247,93 @@ def load_model_dict_into_meta(
                 and dtype == torch.float16
             ):
                 param = param.to(torch.float32)
-                if accepts_dtype:
-                    set_module_kwargs["dtype"] = torch.float32
+                set_module_kwargs["dtype"] = torch.float32
             else:
                 param = param.to(dtype)
-                if accepts_dtype:
-                    set_module_kwargs["dtype"] = dtype
+                set_module_kwargs["dtype"] = dtype

+        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
+        # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
+        # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
+        old_param = model
+        splits = param_name.split(".")
+        for split in splits:
+            old_param = getattr(old_param, split)
+
+        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
+            old_param = None
+
+        if old_param is not None:
+            if dtype is None:
+                param = param.to(old_param.dtype)
+
+            if old_param.is_contiguous():
+                param = param.contiguous()
+
+        module_name = param_name
+
+        if device_map is None:
+            param_device = "cpu"
+        else:
+            # find next higher level module that is defined in device_map:
+            # bert.lm_head.weight -> bert.lm_head -> bert -> ''
+            while len(module_name) > 0 and module_name not in device_map:
+                module_name = ".".join(module_name.split(".")[:-1])
+            if module_name == "" and "" not in device_map:
+                raise ValueError(f"{param_name} doesn't have any device set.")
+            param_device = device_map[module_name]
+
         # bnb params are flattened.
         # gguf quants have a different shape based on the type of quantization applied
         if empty_state_dict[param_name].shape != param.shape:
             if (
                 is_quantized
                 and hf_quantizer.pre_quantized
-                and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+                and hf_quantizer.check_if_quantized_param(
+                    model, param, param_name, state_dict, param_device=param_device
+                )
             ):
                 hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
             else:
                 model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                 raise ValueError(
                     f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
                 )

-        if is_quantized and (
-            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
-        ):
-            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
-        else:
-            if accepts_dtype:
-                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
-            else:
-                set_module_tensor_to_device(model, param_name, device, value=param)
-
-    if named_buffers is None:
-        return unexpected_keys
```
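The `device_map` lookup added above walks a parameter name up to the nearest module that has an entry, exactly as the `bert.lm_head.weight -> bert.lm_head -> bert -> ''` comment describes. A standalone sketch of that resolution; the helper name and example parameter names are assumptions, not part of the diff:

```python
# Hypothetical helper mirroring the lookup added above.
from typing import Dict, Union


def resolve_param_device(param_name: str, device_map: Dict[str, Union[int, str]]) -> Union[int, str]:
    module_name = param_name
    # find next higher level module that is defined in device_map:
    # bert.lm_head.weight -> bert.lm_head -> bert -> ''
    while len(module_name) > 0 and module_name not in device_map:
        module_name = ".".join(module_name.split(".")[:-1])
    if module_name == "" and "" not in device_map:
        raise ValueError(f"{param_name} doesn't have any device set.")
    return device_map[module_name]


print(resolve_param_device("mid_block.attentions.0.to_q.weight", {"mid_block": 0, "": "cpu"}))  # 0
print(resolve_param_device("conv_out.weight", {"mid_block": 0, "": "cpu"}))                      # cpu
```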
Review thread (comment on lines -258 to -259, the removed `if named_buffers is None: return unexpected_keys` early return):
- Why is this going away?
- By adding [embedded code permalink]
- Maybe I am missing something but I couldn't spot [embedded code permalink]
- Nevermind, found it: diffusers/src/diffusers/models/modeling_utils.py, line 1144 in 18d61bb. It's a tad bit easier for reviewers if we could just provide these links going forward.

```diff
-    for param_name, param in named_buffers:
-        if is_quantized and (
-            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+        if param_device == "disk":
+            offload_index = offload_weight(param, param_name, offload_folder, offload_index)
+        elif param_device == "cpu" and state_dict_index is not None:
+            state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
+        elif is_quantized and (
+            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
         ):
-            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
+            hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
         else:
-            if accepts_dtype:
-                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
-            else:
-                set_module_tensor_to_device(model, param_name, device, value=param)
+            set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)

-    return unexpected_keys
+    return error_msgs, offload_index, state_dict_index
```

Review thread (on the removed `named_buffers` loop):
- We need to keep this or equivalent elsewhere, context: #10523
- The changes I did should also cover this use case. The test you added should pass with my PR. This is mainly due to adding the dispatch_model function at the end.
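For readers unfamiliar with the function named in that reply: `accelerate.dispatch_model` attaches hooks that place each submodule's weights and buffers on the mapped device (or fetch them from offload storage) at forward time, which, per the reply, is what makes the separate named-buffers pass unnecessary. A hedged sketch of how it is typically wired in; the wrapper name and device map are assumptions, not code from this PR:

```python
# Sketch only: how dispatch_model is typically applied after loading.
from accelerate import dispatch_model


def place_model(model, device_map):
    # Attaches AlignDevicesHook-style hooks so weights and buffers are moved to the
    # mapped device (or pulled from offload storage) when each submodule runs.
    return dispatch_model(model, device_map=device_map)
```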
```diff
-def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
+def _load_state_dict_into_model(
+    model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
+) -> List[str]:
     # Convert old format to new format if needed from a PyTorch state_dict
     # copy state_dict so _load_from_state_dict can modify it
     state_dict = state_dict.copy()
     error_msgs = []

     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
     # so we need to apply the function recursively.
-    def load(module: torch.nn.Module, prefix: str = ""):
-        args = (state_dict, prefix, {}, True, [], [], error_msgs)
+    def load(module: torch.nn.Module, prefix: str = "", assign_to_params_buffers: bool = False):
+        local_metadata = {}
+        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
+        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
         module._load_from_state_dict(*args)

         for name, child in module._modules.items():
             if child is not None:
-                load(child, prefix + name + ".")
+                load(child, prefix + name + ".", assign_to_params_buffers)

-    load(model_to_load)
+    load(model_to_load, assign_to_params_buffers=assign_to_params_buffers)

     return error_msgs
```
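The `assign_to_params_buffers` flag flows into each submodule's `local_metadata`, which PyTorch's `Module._load_from_state_dict` honors (since roughly torch 2.1) by assigning the incoming tensors instead of copying into existing storage; that is what makes loading into meta-initialized modules possible. A small illustration using the public equivalent, `load_state_dict(assign=True)`, which sets the same metadata key internally:

```python
# Plain PyTorch illustration (not the diffusers helper): assign-style loading replaces
# meta tensors with the incoming ones instead of copying into them.
import torch

module = torch.nn.Linear(4, 4, device="meta")
state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

module.load_state_dict(state_dict, assign=True)  # requires a recent torch (>= 2.1)
print(module.weight.device)  # cpu: the meta parameters were swapped out, not copied into
```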
```diff
@@ -343,46 +388,6 @@ def _fetch_index_file(
     return index_file


-# Adapted from
-# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
-def _merge_sharded_checkpoints(
-    sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None
-):
-    weight_map = sharded_metadata.get("weight_map", None)
-    if weight_map is None:
-        raise KeyError("'weight_map' key not found in the shard index file.")
-
-    # Collect all unique safetensors files from weight_map
-    files_to_load = set(weight_map.values())
-    is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
-    merged_state_dict = {}
-
-    # Load tensors from each unique file
-    for file_name in files_to_load:
-        part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
-        if dduf_entries:
-            if part_file_path not in dduf_entries:
-                raise FileNotFoundError(f"Part file {file_name} not found.")
-        else:
-            if not os.path.exists(part_file_path):
-                raise FileNotFoundError(f"Part file {file_name} not found.")
-
-        if is_safetensors:
-            if dduf_entries:
-                with dduf_entries[part_file_path].as_mmap() as mm:
-                    tensors = safetensors.torch.load(mm)
-                    merged_state_dict.update(tensors)
-            else:
-                with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
-                    for tensor_key in f.keys():
-                        if tensor_key in weight_map:
-                            merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
-        else:
-            merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
-
-    return merged_state_dict
-
-
 def _fetch_index_file_legacy(
     is_local,
     pretrained_model_name_or_path,
```
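For context on what the removed helper consumed: a sharded checkpoint ships an index JSON whose `weight_map` points every tensor to its shard file, and the old code eagerly merged all referenced shards into one state dict. Illustrative index content only; the file names and sizes below are invented:

```python
# Illustrative shard index (made-up names and sizes, not a real checkpoint).
sharded_metadata = {
    "metadata": {"total_size": 9876543210},
    "weight_map": {
        "conv_in.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
        "mid_block.attentions.0.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
    },
}

# What _merge_sharded_checkpoints iterated over before loading and merging each shard:
files_to_load = set(sharded_metadata["weight_map"].values())
print(sorted(files_to_load))
```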
Review thread:
- Are the single-file related changes to uniformize the use of `load_model_dict_into_meta()` (with the new signature)?
- Yeah, that's right!
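For reference, a hedged sketch of what a call site with the refactored signature looks like; the argument values are assumptions, the module path is the one this diff appears to target, and the call shape only matches this PR's branch (the return is now a tuple of error messages plus the updated offload and state-dict indices):

```python
# Hypothetical call shape against this PR's branch; argument values are assumed.
import torch
from diffusers.models.model_loading_utils import load_model_dict_into_meta


def load_into_meta_model(model, state_dict):
    error_msgs, offload_index, state_dict_index = load_model_dict_into_meta(
        model,
        state_dict,
        dtype=torch.float16,
        device_map={"": "cpu"},  # assumed: keep everything on CPU
        unexpected_keys=[],
    )
    return error_msgs
```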