Skip to content

Commit e5ec237

Browse files
Merge pull request #2 from Schwidola0607/xylian_dev
[fix] format
2 parents 915bef8 + 30b33f9 commit e5ec237

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

deepspeed/checkpoint/hf_to_universal.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -166,18 +166,17 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s
166166
shard_files = get_shard_list(args.hf_checkpoint_dir)
167167
total_shards = len(shard_files)
168168
logger.info(f"Found {total_shards} shards to process")
169-
170-
# Process shards in batches equal to number of workers
169+
# Process shards in batches equal to the number of workers
171170
batch_size = args.num_workers
172171
for i in range(0, total_shards, batch_size):
173172
batch_shards = shard_files[i:i + batch_size]
174173
logger.info(f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})")
175174
process_shard_batch(batch_shards,
176-
args.hf_checkpoint_dir,
177-
temp_zero_dir, # Changed from temp_save_dir to temp_zero_dir
178-
args.safe_serialization)
175+
args.hf_checkpoint_dir,
176+
temp_zero_dir, # Changed from temp_save_dir to temp_zero_dir
177+
args.safe_serialization)
179178

180-
# Force garbage collection after each batch
179+
# Clear CUDA cache after each batch to free up memory
181180
torch.cuda.empty_cache()
182181

183182
logger.info("All shard batches processed successfully")

deepspeed/runtime/engine.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2962,7 +2962,6 @@ def load_checkpoint(self,
29622962
custom_load_fn=custom_load_fn)
29632963

29642964
load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
2965-
# import pdb; pdb.set_trace()
29662965
if self.load_universal_checkpoint():
29672966
ucp_ckpt_folder = os.path.join(load_dir, tag)
29682967
# UCP load can ignore '*mp' files or '*model_states.pt' but ucp_ckpt_folder must exist

0 commit comments

Comments
 (0)