HF2UCP: Converting a pytorch_model.bin or .safetensors checkpoint to UCP #7212
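For context, the new conversion script would be invoked roughly as follows (the file name hf_to_universal.py is assumed for illustration; the flags come from the argparse definition in the diff below):

python hf_to_universal.py --hf_checkpoint_dir /path/to/hf_model --save_dir /path/to/universal_ckpt/global_step0 --num_workers 4 --safe_serialization

Pass --safe_serialization only when the checkpoint is stored as .safetensors shards; omit it for pytorch_model.bin shards.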
@@ -0,0 +1,225 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import os
import shutil
import logging
from concurrent.futures import ProcessPoolExecutor
from deepspeed.accelerator import get_accelerator
from tqdm import tqdm
from typing import List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hard-coded constants for parameter patterns
VOCAB_PARAMETER_PATTERNS = [
    'word_embeddings',
    'embed_tokens',
    'embedding',
    'wte',  # GPT style embeddings
    'lm_head'  # Language model head, often tied with embeddings
]


def get_parameter_type(name: str) -> dict:
    """Determine parameter type and required fields based on name."""
    param_info = {
        'cat_dim': 0  # Default concatenation dimension
    }

    # Check for vocabulary tensors (embeddings, etc.)
    if any(pattern in name.lower() for pattern in VOCAB_PARAMETER_PATTERNS):
        param_info['vocab_tensor'] = True

    # TODO: figure out if we need to check for row-parallel parameters
    return param_info


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint to Universal Checkpoint format')
    parser.add_argument('--hf_checkpoint_dir',
                        type=str,
                        required=True,
                        help='Path to the HuggingFace checkpoint directory')
    parser.add_argument('--safe_serialization',
                        action='store_true',
                        default=False,
                        help='Use safetensors for serialization')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for saving checkpoints')
    parser.add_argument('--save_dir', type=str, required=True, help='Directory to save checkpoints')
    args = parser.parse_args()

    # Create a temporary directory for atomic operations
    temp_save_dir = args.save_dir + '.tmp'

    def save_parameter(name: str, param: torch.Tensor, save_dir: str):
        """Save a parameter and its optimizer states in universal format."""
        # Create parameter directory under zero/
        param_dir = os.path.join(save_dir, name)
        os.makedirs(param_dir, exist_ok=True)

        # Get parameter type and required fields
        param_info = get_parameter_type(name)

        # Save parameter in fp32 with proper dictionary structure
        param_path = os.path.join(param_dir, "fp32.pt")
        param_dict = {
            'param': param.to(torch.float32),  # Main tensor goes in 'param' field
            **param_info  # Include all determined parameter info
        }
        torch.save(param_dict, param_path)

        # Since HuggingFace checkpoints do not have optimizer states,
        # we initialize them with zeros
        for state in ("exp_avg", "exp_avg_sq"):
            state_path = os.path.join(param_dir, f"{state}.pt")
            state_dict = {
                'param': torch.zeros_like(param, dtype=torch.float32),
                **param_info  # Include same parameter info in optimizer states
            }
            torch.save(state_dict, state_path)

    def process_shard(shard_file, checkpoint_dir, save_dir, safe_serialization):
        """Process a single shard file."""
        try:
            shard_path = os.path.join(checkpoint_dir, shard_file)
            logger.info(f"Loading shard from: {shard_path}")

            if safe_serialization:
                from safetensors.torch import load_file
                shard_dict = load_file(shard_path)
            else:
                shard_dict = torch.load(shard_path, map_location='cpu')

            # Create progress bar for parameters within this shard
            pbar = tqdm(total=len(shard_dict),
                        desc=f"Processing {os.path.basename(shard_file)}",
                        position=1,
                        leave=False)

            for key, param in shard_dict.items():
                save_parameter(key, param, save_dir)
                del param
                pbar.update(1)
                pbar.set_postfix({'key': key[:20] + '...' if len(key) > 20 else key})

            pbar.close()
            del shard_dict
            get_accelerator().empty_cache()
            logger.info(f"Completed processing shard: {shard_file}")

        except Exception as e:
            logger.error(f"Error processing shard {shard_file}: {str(e)}")
            raise

    def get_shard_list(checkpoint_dir):
        """Get list of shards from index file."""
        if args.safe_serialization:
            index_file = os.path.join(checkpoint_dir, "model.safetensors.index.json")
        else:
            index_file = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")

        if os.path.exists(index_file):
            import json
            with open(index_file, 'r') as f:
                index = json.load(f)
            return list(set(index['weight_map'].values()))
        else:
            # Handle single file case
            if args.safe_serialization and os.path.exists(os.path.join(checkpoint_dir, "model.safetensors")):
                return ["model.safetensors"]
            elif os.path.exists(os.path.join(checkpoint_dir, "pytorch_model.bin")):
                return ["pytorch_model.bin"]
            else:
                raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}")

    def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: str, safe_serialization: bool):
        """Process a batch of shards in parallel."""
        with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
            futures = []
            for shard_file in shard_files:
                future = executor.submit(process_shard, shard_file, checkpoint_dir, save_dir, safe_serialization)
                futures.append((shard_file, future))

            # Create progress bar for this batch
            batch_pbar = tqdm(total=len(futures), desc="Processing shard batch", position=0, leave=True)

            # Wait for all futures to complete
            for shard_file, future in futures:
                try:
                    future.result()  # This will raise any exceptions that occurred
                    batch_pbar.update(1)
                    batch_pbar.set_postfix({'last_completed': os.path.basename(shard_file)})
                except Exception as e:
                    logger.error(f"Failed processing shard {shard_file}: {str(e)}")
                    raise

            batch_pbar.close()

    try:
        # Create zero subdirectory in temp directory
        temp_zero_dir = os.path.join(temp_save_dir, 'zero')
        if os.path.exists(temp_zero_dir):
            logger.info(f"Removing existing temp directory: {temp_zero_dir}")
            shutil.rmtree(temp_zero_dir)

        shard_files = get_shard_list(args.hf_checkpoint_dir)
        total_shards = len(shard_files)
        logger.info(f"Found {total_shards} shards to process")

        # Process shards in batches equal to the number of workers
        batch_size = args.num_workers
        for i in range(0, total_shards, batch_size):
            batch_shards = shard_files[i:i + batch_size]
            logger.info(
                f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})"
            )
            process_shard_batch(
                batch_shards,
                args.hf_checkpoint_dir,
                temp_zero_dir,  # Changed from temp_save_dir to temp_zero_dir
                args.safe_serialization)

            # Clear CUDA cache after each batch to free up memory
            get_accelerator().empty_cache()

        logger.info("All shard batches processed successfully")

        final_save_dir = os.path.join(args.save_dir, 'zero')
        if os.path.exists(final_save_dir):
            shutil.rmtree(final_save_dir)

        # Create the parent directory if it doesn't exist
        os.makedirs(os.path.dirname(final_save_dir), exist_ok=True)

        # Move the zero directory to its final location
        os.rename(temp_zero_dir, final_save_dir)

        # Clean up the temporary directory
        if os.path.exists(temp_save_dir):
            shutil.rmtree(temp_save_dir)

        # Write identifier file
        with open(os.path.join(args.save_dir, 'source.txt'), 'w') as f:
            f.write("Huggingface checkpoint")

        logger.info(f"Successfully saved checkpoint to {final_save_dir}")

        # Update latest file
        checkpoint_root_folder = os.path.dirname(args.save_dir)
        step_folder = os.path.basename(args.save_dir)
        latest_file = os.path.join(checkpoint_root_folder, 'latest_universal')
        with open(latest_file, 'w') as f:
            f.write(step_folder)

        logger.info(f"Checkpoint conversion completed successfully. Latest file updated at {latest_file}")

    except Exception as e:
        logger.error(f"Failed to process checkpoint: {str(e)}")
        if os.path.exists(temp_save_dir):
            shutil.rmtree(temp_save_dir)
        raise
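For readers who want to validate the result, here is a small sanity-check sketch (not part of the PR). It assumes only what the script above writes: one directory per parameter under <save_dir>/zero/ containing fp32.pt, exp_avg.pt and exp_avg_sq.pt, each a dict with a 'param' tensor plus 'cat_dim' and optionally 'vocab_tensor', alongside <save_dir>/source.txt and a latest_universal file next to <save_dir>. The path below is an example.

# Sanity-check sketch for a converted checkpoint.
import os
import torch

save_dir = "/path/to/universal_ckpt/global_step0"  # example path
zero_dir = os.path.join(save_dir, "zero")

for name in sorted(os.listdir(zero_dir))[:3]:  # inspect a few parameters
    for fname in ("fp32.pt", "exp_avg.pt", "exp_avg_sq.pt"):
        state = torch.load(os.path.join(zero_dir, name, fname), weights_only=False)
        assert isinstance(state, dict) and 'param' in state and 'cat_dim' in state
        print(name, fname, tuple(state['param'].shape), state.get('vocab_tensor', False))

# The converter also writes an identifier file and updates latest_universal.
print(open(os.path.join(save_dir, "source.txt")).read())
print(open(os.path.join(os.path.dirname(save_dir), "latest_universal")).read())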
@@ -21,11 +21,14 @@ class ZeROOptimizer(DeepSpeedOptimizer):
     def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
         checkpoint_dir = os.path.join(checkpoint_dir, "zero")
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
-        assert os.path.isfile(
-            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
-        optim_sd = torch.load(optim_state_path, weights_only=False)
-
-        self._load_global_state(optim_sd)
+        if os.path.isfile(optim_state_path):
+            ignore_missing_optim_state = False
+            optim_sd = torch.load(optim_state_path, weights_only=False)
+            self._load_global_state(optim_sd)
+        else:
+            logger.warning(f'{optim_state_path} containing optimizer global state is missing!')
+            ignore_missing_optim_state = True
+            optim_sd = {}

         tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
         if self.mpu is None:
@@ -35,8 +38,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
         tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
             else self.mpu.get_tensor_model_parallel_world_size()

-        for i, (param_group,
-                loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
+        for i, param_group in enumerate(self.optimizer.param_groups):
             # We have an assumption that all params in the same param_group have the same keys
             opt_keys = set()
             steps = []

Review discussion on this loop:
- tjruwase: I feel the changes in this loop can be avoided by returning early if …
- Schwidola0607: Hi @tjruwase, returning early when …
- tjruwase: @Schwidola0607, thanks for pointing that out. I overlooked some of that logic. My main concern is that readability is hurt by the following snippet in the middle of the long loop body: `if ignore_missing_optim_state: continue`. What about splitting the loop into two: (1) L42-L61, and (2) L65-L69? Then the second loop can be skipped for …
- Schwidola0607: @tjruwase, that makes sense; I will go ahead and make the fix.
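For illustration, a minimal self-contained sketch of the split tjruwase proposes, using plain dicts instead of the real optimizer objects (the function name and loop bodies here are assumptions, not the actual patch):

def restore_param_groups(param_groups, loaded_param_groups, ignore_missing_optim_state):
    # Loop 1: always runs (stand-in for restoring hp params and flat optimizer states).
    for group in param_groups:
        group['restored'] = True

    # Loop 2: hyperparameter copy-over, skipped entirely when optimizer_state.pt is missing.
    if ignore_missing_optim_state:
        return
    for group, loaded in zip(param_groups, loaded_param_groups):
        for key, value in loaded.items():
            if key == 'params':
                continue
            group[key] = value


groups = [{'params': [0, 1], 'lr': 0.0}]
restore_param_groups(groups, [{'params': [0, 1], 'lr': 1e-4}], ignore_missing_optim_state=False)
assert groups[0]['lr'] == 1e-4 and groups[0]['restored']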
@@ -58,6 +60,9 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec

             map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)

+            if ignore_missing_optim_state:
+                continue
+            loaded_param_group = optim_sd['param_groups'][i]
             for key, value in loaded_param_group.items():
                 if key == 'params':
                     continue
@@ -2864,11 +2864,13 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
         """ Load optimizer and model states from the checkpoint directory. """
         checkpoint_dir = os.path.join(checkpoint_dir, "zero")
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
-        assert os.path.isfile(
-            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
-
-        optim_sd = torch.load(optim_state_path, weights_only=False)
-        self._load_global_state_stage3(optim_sd)
+        if os.path.isfile(optim_state_path):
+            ignore_missing_optim_state = False
+            optim_sd = torch.load(optim_state_path, weights_only=False)
+            self._load_global_state_stage3(optim_sd)
+        else:
+            logger.warning(f'{optim_state_path} containing optimizer global state is missing!')
+            ignore_missing_optim_state = True

         key_list = ["fp32", "exp_avg", "exp_avg_sq"]

Review discussion:
- xylian86: Suggestion: …
- Schwidola0607: Hi @xylian86, all the code paths currently calling …
- xylian86: Got it. That's fine with me.
@@ -2880,14 +2882,13 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
             if key == "fp32":
                 self.fp32_partitioned_groups_flat[0].data.copy_(key_tensor)
                 self.optimizer.param_groups[0]['params'].append(self.fp32_partitioned_groups_flat[0])
-            else:
+            elif not ignore_missing_optim_state:
                 optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor

         if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
             self.optimizer_swapper.purge_state()

-        if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
             timer_names = set()
             self._partition_all_parameters()
@@ -2897,9 +2898,10 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
                 self._release_sub_group(sub_group_id, timer_names)
             self._post_step(timer_names)

-        self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT])
-        for param_group in self.optimizer.param_groups:
-            param_group['params'] = []
+        if not ignore_missing_optim_state:
+            self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT])
+            for param_group in self.optimizer.param_groups:
+                param_group['params'] = []

         for sub_group_id in range(len(self.fp32_partitioned_groups_flat)):
             fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
@@ -2923,7 +2925,13 @@ def load_hp_checkpoint_state(self, folder, key):
         local_rank = dist.get_local_rank()

         # Load tensors from files and reshape them to flat vectors
-        loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1)
+        loaded_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False)
+        if isinstance(loaded_state, dict):
+            loaded_checkpoint_state = loaded_state['param'].view(-1)
+        elif isinstance(loaded_state, torch.Tensor):
+            loaded_checkpoint_state = loaded_state.view(-1)
+        else:
+            raise ValueError(f"Unknown type {type(loaded_state)} for loaded state")

         # Partition the loaded data according to the local rank
         world_size = dist.get_world_size(group=self.dp_process_group)
Review discussion:
- Suggestion: Update the logic as follows: …
- Reply: See below.