HF2UCP: Converting a pytorch_model.bin or .safetensors checkpoint to UCP #7212

Open · wants to merge 11 commits into base: master

225 changes: 225 additions & 0 deletions deepspeed/checkpoint/hf_to_universal.py
@@ -0,0 +1,225 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import os
import shutil
import logging
from concurrent.futures import ProcessPoolExecutor
from deepspeed.accelerator import get_accelerator
from tqdm import tqdm
from typing import List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hard-coded constants for parameter patterns
VOCAB_PARAMETER_PATTERNS = [
    'word_embeddings',
    'embed_tokens',
    'embedding',
    'wte', # GPT style embeddings
    'lm_head' # Language model head, often tied with embeddings
]


def get_parameter_type(name: str) -> dict:
    """Determine parameter type and required fields based on name."""
    param_info = {
        'cat_dim': 0 # Default concatenation dimension
    }

    # Check for vocabulary tensors (embeddings, etc.)
    if any(pattern in name.lower() for pattern in VOCAB_PARAMETER_PATTERNS):
        param_info['vocab_tensor'] = True

    # TODO: figure out if we need to check for row-parallel parameters
    return param_info


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint to Universal Checkpoint format')
    parser.add_argument('--hf_checkpoint_dir',
                        type=str,
                        required=True,
                        help='Path to the HuggingFace checkpoint directory')
    parser.add_argument('--safe_serialization',
                        action='store_true',
                        default=False,
                        help='Use safetensors for serialization')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for saving checkpoints')
    parser.add_argument('--save_dir', type=str, required=True, help='Directory to save checkpoints')
    args = parser.parse_args()

    # Create a temporary directory for atomic operations
    temp_save_dir = args.save_dir + '.tmp'

    def save_parameter(name: str, param: torch.Tensor, save_dir: str):
        """Save a parameter and its optimizer states in universal format."""
        # Create parameter directory under zero/
        param_dir = os.path.join(save_dir, name)
        os.makedirs(param_dir, exist_ok=True)

        # Get parameter type and required fields
        param_info = get_parameter_type(name)

        # Save parameter in fp32 with proper dictionary structure
        param_path = os.path.join(param_dir, "fp32.pt")
        param_dict = {
            'param': param.to(torch.float32), # Main tensor goes in 'param' field
            **param_info # Include all determined parameter info
        }
        torch.save(param_dict, param_path)

        # Since HuggingFace checkpoints do not have optimizer states,
        # we initialize them with zeros
        for state in ("exp_avg", "exp_avg_sq"):
            state_path = os.path.join(param_dir, f"{state}.pt")
            state_dict = {
                'param': torch.zeros_like(param, dtype=torch.float32),
                **param_info # Include same parameter info in optimizer states
            }
            torch.save(state_dict, state_path)

    def process_shard(shard_file, checkpoint_dir, save_dir, safe_serialization):
        """Process a single shard file."""
        try:
            shard_path = os.path.join(checkpoint_dir, shard_file)
            logger.info(f"Loading shard from: {shard_path}")

            if safe_serialization:
                from safetensors.torch import load_file
                shard_dict = load_file(shard_path)
            else:
                shard_dict = torch.load(shard_path, map_location='cpu')

            # Create progress bar for parameters within this shard
            pbar = tqdm(total=len(shard_dict),
                        desc=f"Processing {os.path.basename(shard_file)}",
                        position=1,
                        leave=False)

            for key, param in shard_dict.items():
                save_parameter(key, param, save_dir)
                del param
                pbar.update(1)
                pbar.set_postfix({'key': key[:20] + '...' if len(key) > 20 else key})

            pbar.close()
            del shard_dict
            get_accelerator().empty_cache()
            logger.info(f"Completed processing shard: {shard_file}")

        except Exception as e:
            logger.error(f"Error processing shard {shard_file}: {str(e)}")
            raise

    def get_shard_list(checkpoint_dir):
        """Get list of shards from index file."""
        if args.safe_serialization:
            index_file = os.path.join(checkpoint_dir, "model.safetensors.index.json")
        else:
            index_file = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")

        if os.path.exists(index_file):
            import json
            with open(index_file, 'r') as f:
                index = json.load(f)
            return list(set(index['weight_map'].values()))
        else:
            # Handle single file case
            if args.safe_serialization and os.path.exists(os.path.join(checkpoint_dir, "model.safetensors")):
                return ["model.safetensors"]
            elif os.path.exists(os.path.join(checkpoint_dir, "pytorch_model.bin")):
                return ["pytorch_model.bin"]
            else:
                raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}")

    def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: str, safe_serialization: bool):
        """Process a batch of shards in parallel."""
        with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
            futures = []
            for shard_file in shard_files:
                future = executor.submit(process_shard, shard_file, checkpoint_dir, save_dir, safe_serialization)
                futures.append((shard_file, future))

            # Create progress bar for this batch
            batch_pbar = tqdm(total=len(futures), desc=f"Processing shard batch", position=0, leave=True)

            # Wait for all futures to complete
            for shard_file, future in futures:
                try:
                    future.result() # This will raise any exceptions that occurred
                    batch_pbar.update(1)
                    batch_pbar.set_postfix({'last_completed': os.path.basename(shard_file)})
                except Exception as e:
                    logger.error(f"Failed processing shard {shard_file}: {str(e)}")
                    raise

            batch_pbar.close()

    try:
        # Create zero subdirectory in temp directory
        temp_zero_dir = os.path.join(temp_save_dir, 'zero')
        if os.path.exists(temp_zero_dir):
            logger.info(f"Removing existing temp directory: {temp_zero_dir}")
            shutil.rmtree(temp_zero_dir)

        shard_files = get_shard_list(args.hf_checkpoint_dir)
        total_shards = len(shard_files)
        logger.info(f"Found {total_shards} shards to process")
        # Process shards in batches equal to the number of workers
        batch_size = args.num_workers
        for i in range(0, total_shards, batch_size):
            batch_shards = shard_files[i:i + batch_size]
            logger.info(
                f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})"
            )
            process_shard_batch(
                batch_shards,
                args.hf_checkpoint_dir,
                temp_zero_dir, # Changed from temp_save_dir to temp_zero_dir
                args.safe_serialization)

            # Clear CUDA cache after each batch to free up memory
            get_accelerator().empty_cache()

        logger.info("All shard batches processed successfully")

        final_save_dir = os.path.join(args.save_dir, 'zero')
        if os.path.exists(final_save_dir):
            shutil.rmtree(final_save_dir)

        # Create the parent directory if it doesn't exist
        os.makedirs(os.path.dirname(final_save_dir), exist_ok=True)
        # Move the zero directory to its final location
        os.rename(temp_zero_dir, final_save_dir)

        # Clean up the temporary directory
        if os.path.exists(temp_save_dir):
            shutil.rmtree(temp_save_dir)

        # Write identifier file
        with open(os.path.join(args.save_dir, 'source.txt'), 'w') as f:
            f.write("Huggingface checkpoint")

        logger.info(f"Successfully saved checkpoint to {final_save_dir}")

        # Update latest file
        checkpoint_root_folder = os.path.dirname(args.save_dir)
        step_folder = os.path.basename(args.save_dir)
        latest_file = os.path.join(checkpoint_root_folder, 'latest_universal')
        with open(latest_file, 'w') as f:
            f.write(step_folder)

        logger.info(f"Checkpoint conversion completed successfully. Latest file updated at {latest_file}")

    except Exception as e:
        logger.error(f"Failed to process checkpoint: {str(e)}")
        if os.path.exists(temp_save_dir):
            shutil.rmtree(temp_save_dir)
        raise
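
For reference, a hedged usage sketch for the converter above; the checkpoint paths and the parameter name are illustrative assumptions, not values taken from this PR:

# Convert an HF checkpoint, then spot-check one converted parameter.
# Example invocation (paths are assumptions):
#   python deepspeed/checkpoint/hf_to_universal.py \
#       --hf_checkpoint_dir /data/llama-hf --save_dir /data/ucp/global_step0
import os
import torch

save_dir = "/data/ucp/global_step0"        # assumed --save_dir value
param_name = "model.embed_tokens.weight"   # assumed parameter name from the HF shard
param_dir = os.path.join(save_dir, "zero", param_name)

# fp32.pt holds {'param': fp32 tensor, 'cat_dim': 0, plus 'vocab_tensor': True for embedding-like names}
fp32 = torch.load(os.path.join(param_dir, "fp32.pt"), map_location="cpu")
print(fp32["param"].shape, fp32["cat_dim"], fp32.get("vocab_tensor", False))

# exp_avg.pt / exp_avg_sq.pt are zero-initialized because HF checkpoints carry no optimizer state.
exp_avg = torch.load(os.path.join(param_dir, "exp_avg.pt"), map_location="cpu")
assert torch.count_nonzero(exp_avg["param"]) == 0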
19 changes: 12 additions & 7 deletions deepspeed/runtime/base_optimizer.py
@@ -20,11 +20,14 @@ class ZeROOptimizer(DeepSpeedOptimizer):
     def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
         checkpoint_dir = os.path.join(checkpoint_dir, "zero")
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
-        assert os.path.isfile(
-            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
-        optim_sd = torch.load(optim_state_path, weights_only=False)
-
-        self._load_global_state(optim_sd)
+        if os.path.isfile(optim_state_path):
+            ignore_missing_optim_state = False
+            optim_sd = torch.load(optim_state_path, weights_only=False)
+            self._load_global_state(optim_sd)
+        else:
+            logger.warning(f'{optim_state_path} containing optimizer global state is missing!')
+            ignore_missing_optim_state = True
+            optim_sd = {}
 
         tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
         if self.mpu is None:
@@ -34,8 +37,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
         tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
             else self.mpu.get_tensor_model_parallel_world_size()
 
-        for i, (param_group,
-                loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
+        for i, param_group in enumerate(self.optimizer.param_groups):
             # We have an assumption that all params in the same param_group have the same keys
             opt_keys = set()
             steps = []
@@ -57,6 +59,9 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
 
             map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)
 
+            if ignore_missing_optim_state:
+                continue
+            loaded_param_group = optim_sd['param_groups'][i]
             for key, value in loaded_param_group.items():
                 if key == 'params':
                     continue
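
For context, a hedged sketch of the two on-disk layouts this fallback distinguishes; step and parameter names are illustrative, and the HF-derived layout is the one written by hf_to_universal.py above:

# Trainer-produced universal checkpoint (optimizer global state present):
#   <step_folder>/zero/optimizer_state.pt
#   <step_folder>/zero/<param_name>/{fp32.pt, exp_avg.pt, exp_avg_sq.pt}
#
# HF-derived universal checkpoint (no optimizer_state.pt, moments zero-initialized):
#   <step_folder>/zero/<param_name>/{fp32.pt, exp_avg.pt, exp_avg_sq.pt}
#   <step_folder>/source.txt
#
# When optimizer_state.pt is absent, ignore_missing_optim_state is set and the loader
# restores the fp32 weights but skips the optimizer param_group hyperparameters.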
15 changes: 12 additions & 3 deletions deepspeed/runtime/engine.py
@@ -2923,7 +2923,7 @@ def _get_all_ckpt_names(self, checkpoints_path, tag):
 
         ckpt_files = glob.glob(ckpt_file_pattern)
         ckpt_files.sort()
-        return ckpt_files
+        return ckpt_files, ckpt_file_pattern
 
     def load_checkpoint(self,
                         load_dir,
@@ -2947,7 +2947,7 @@ def load_checkpoint(self,
 
         Returns:
             A tuple of ``load_path`` and ``client_state``.
-            *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed.
+            *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed or when loading an HF-based UCP.
             *``client_state``: State dictionary used for loading required training states in the client code.
 
         Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right
@@ -2986,6 +2986,11 @@ def load_checkpoint(self,
                                                             custom_load_fn=custom_load_fn)
 
         load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
+        if self.load_universal_checkpoint():
+            ucp_ckpt_folder = os.path.join(load_dir, tag)
+            # UCP load can ignore '*mp' files or '*model_states.pt' but ucp_ckpt_folder must exist
+            load_zero_checkpoint = os.path.isdir(ucp_ckpt_folder)
+
         if load_zero_checkpoint:
             if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint():
                 success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
@@ -3026,7 +3031,11 @@ def _load_checkpoint(self,
 
         from deepspeed.runtime.state_dict_factory import SDLoaderFactory
 
-        ckpt_list = self._get_all_ckpt_names(load_dir, tag)
+        ckpt_list, ckpt_file_pattern = self._get_all_ckpt_names(load_dir, tag)
+        if self.load_universal_checkpoint() and len(ckpt_list) == 0:
+            logger.warning(f"Unable to find {ckpt_file_pattern} files in UCP folder {load_dir}")
+            return None, {}
+
         sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine)
 
         is_pipe_parallel = isinstance(self.module, PipelineModule)
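
A hedged sketch of loading an HF-derived UCP through the engine changes above; the engine setup, paths, and config key are assumptions, not part of this PR:

# Assumes an initialized DeepSpeed engine whose config enables universal loading,
# e.g. "checkpoint": {"load_universal": true} (assumed key), and that /data/ucp holds
# the 'latest_universal' file plus the step folder written by hf_to_universal.py.
load_path, client_state = engine.load_checkpoint(
    load_dir="/data/ucp",  # assumed parent folder
    tag=None,              # assumed to resolve via the 'latest_universal' file
    load_optimizer_states=True)
# An HF-derived UCP has no '*model_states.pt' files, so load_path may be None even though
# the ZeRO parameter states were populated from <load_dir>/<tag>/zero/.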
30 changes: 19 additions & 11 deletions deepspeed/runtime/zero/stage3.py
@@ -2753,11 +2753,13 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
         """ Load optimizer and model states from the checkpoint directory. """
         checkpoint_dir = os.path.join(checkpoint_dir, "zero")
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
-        assert os.path.isfile(
-            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
-
-        optim_sd = torch.load(optim_state_path, weights_only=False)
-        self._load_global_state_stage3(optim_sd)
+        if os.path.isfile(optim_state_path):
+            ignore_missing_optim_state = False
+            optim_sd = torch.load(optim_state_path, weights_only=False)
+            self._load_global_state_stage3(optim_sd)
+        else:
+            logger.warning(f'{optim_state_path} containing optimizer global state is missing!')
+            ignore_missing_optim_state = True
 
         key_list = ["fp32", "exp_avg", "exp_avg_sq"]
 
@@ -2769,14 +2771,13 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
             if key == "fp32":
                 self.fp32_partitioned_groups_flat[0].data.copy_(key_tensor)
                 self.optimizer.param_groups[0]['params'].append(self.fp32_partitioned_groups_flat[0])
-            else:
+            elif not ignore_missing_optim_state:
                 optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor
 
         if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
             self.optimizer_swapper.purge_state()
 
-        if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
             timer_names = set()
             self._partition_all_parameters()
@@ -2786,9 +2787,10 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
                 self._release_sub_group(sub_group_id, timer_names)
         self._post_step(timer_names)
 
-        self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT])
-        for param_group in self.optimizer.param_groups:
-            param_group['params'] = []
+        if not ignore_missing_optim_state:
+            self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT])
+            for param_group in self.optimizer.param_groups:
+                param_group['params'] = []
 
         for sub_group_id in range(len(self.fp32_partitioned_groups_flat)):
             fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
@@ -2812,7 +2814,13 @@ def load_hp_checkpoint_state(self, folder, key):
         local_rank = dist.get_local_rank()
 
         # Load tensors from files and reshape them to flat vectors
-        loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1)
+        loaded_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False)
+        if isinstance(loaded_state, dict):
+            loaded_checkpoint_state = loaded_state['param'].view(-1)
+        elif isinstance(loaded_state, torch.Tensor):
+            loaded_checkpoint_state = loaded_state.view(-1)
+        else:
+            raise ValueError(f"Unknown type {type(loaded_state)} for loaded state")
 
         # Partition the loaded data according to the local rank
         world_size = dist.get_world_size(group=self.dp_process_group)