
Commit 364c594

support save cache and load broadcast
1 parent 774d32f commit 364c594

9 files changed: +59 -37 lines changed

internlm/checkpoint/checkpoint_manager.py  (+4 -4)

@@ -87,7 +87,7 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None,
    if universal_ckpt:
        from internlm.checkpoint.vescale.api import load as vescale_load
        checkpoint_state = {"model": ckpt_mm.model, "optimizer": ckpt_mm.optimizer}
-       vescale_load(load_ckpt_folder, checkpoint_state, broadcast_checkpoint=False)
+       vescale_load(load_ckpt_folder, checkpoint_state, broadcast_checkpoint=gpc.config.ckpt.universal_ckpt.broadcast_load)

    if not universal_ckpt and load_content.need_load(CheckpointLoadContent.MODEL):
        load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model)

@@ -448,7 +448,7 @@ def try_save_checkpoint(self, train_state, force=False):
                train_state=train_state,
                model_config=self.model_config,
                model_config_file=self.model_config_file,
-               universal_ckpt=gpc.config.ckpt.universal_ckpt,
+               universal_ckpt=gpc.config.ckpt.universal_ckpt.enable,
            )

            if (

@@ -591,7 +591,7 @@ def try_resume_training(self, train_state: TrainState, current_time=""):
        load_path = self.load_ckpt_info["path"]
        load_content = self.load_ckpt_info["content"]
        load_type = self.load_ckpt_info["ckpt_type"]
-       universal_ckpt = gpc.config.ckpt.universal_ckpt
+       universal_ckpt = gpc.config.ckpt.universal_ckpt.enable
        kwargs = {}

        if universal_ckpt:

@@ -656,7 +656,7 @@ def save_checkpoint(
        vescale_save(
            path=folder,
            checkpoint_state={"model": model, "optimizer": optimizer},
-           async_checkpoint=False,
+           async_checkpoint=gpc.config.ckpt.universal_ckpt.aysnc_save,
        )

internlm/checkpoint/vescale/common.py  (+1 -1)

@@ -59,7 +59,7 @@ def sort_rank_ranges(process_list: List[Tuple]) -> List[Tuple]:
    return sorted_process_list


-_MAX_CACHE_SIZE = 8
+_MAX_CACHE_SIZE = 2 # model ckpt + optm ckpt


class PlanLRUCache:

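The cache size drops from 8 to 2 because, with the plan caching enabled in save_state_dict.py below, only two plans are looked up per run: one for the model checkpoint and one for the optimizer checkpoint. As a rough illustration of the bounded-LRU pattern (a sketch only; the repository's actual PlanLRUCache implementation is not part of this diff), a capacity-limited cache keyed by plan key could look like:

from collections import OrderedDict
from typing import Any, Hashable, Optional, Tuple

_MAX_CACHE_SIZE = 2  # model ckpt + optm ckpt

class TinyPlanCache:
    """Illustrative bounded LRU cache keeping at most _MAX_CACHE_SIZE (plan, metadata) entries."""

    def __init__(self) -> None:
        self._cache: "OrderedDict[Hashable, Tuple[Any, Any]]" = OrderedDict()

    def get(self, key: Hashable) -> Optional[Tuple[Any, Any]]:
        if key not in self._cache:
            return None
        self._cache.move_to_end(key)  # mark as most recently used
        return self._cache[key]

    def put(self, key: Hashable, plan: Any, metadata: Any) -> None:
        self._cache[key] = (plan, metadata)
        self._cache.move_to_end(key)
        if len(self._cache) > _MAX_CACHE_SIZE:
            self._cache.popitem(last=False)  # evict least recently used entry

    def clear(self) -> None:
        self._cache.clear()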
internlm/checkpoint/vescale/filesystem.py  (+12 -8)

@@ -27,6 +27,7 @@
from internlm.core.context import global_context as gpc
from internlm.core.context import ParallelMode
from internlm.train.pipeline import map_fqn_global_to_local, map_layer_attr
+from internlm.utils.common import get_current_device


from torch.distributed.checkpoint.metadata import (

@@ -880,8 +881,9 @@ def read_from_files(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlan
                    bytes.seek(0)
                    planner.load_bytes(req, bytes)
                else:
-                   tensor = cast(Tensor, torch.load(file_slice, map_location="cpu"))
+                   tensor = cast(Tensor, torch.load(file_slice, map_location="cpu")) #att
                    tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
+                   print(f"req: {req.dest_index.fqn}, {req}", flush=True)
                    target_tensor = planner.resolve_tensor(req).detach()

                    assert (

@@ -892,18 +894,20 @@ def read_from_files(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlan

    def read_data_with_broadcast(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlanner):
        for relative_path, reqs in per_file.items():
-           if dist.get_rank(self.data_parallel_process_group) == 0:
+           # if dist.get_rank(self.data_parallel_process_group) == 0:
+           if gpc.get_local_rank(ParallelMode.DATA) == 0:
                file_path = self._get_file_path(relative_path)
                file = open(file_path, "rb")
            dist.barrier(self.data_parallel_process_group)
            reqs = sorted(reqs, key=lambda req: self.storage_data[req.storage_index].offset)
            for req in reqs:
-               if dist.get_rank(self.data_parallel_process_group) == 0:
+               if gpc.get_local_rank(ParallelMode.DATA)== 0:
                    item_md = self.storage_data[req.storage_index]
                    file_slice = self._slice_file(file, item_md)

                if req.type == LoadItemType.BYTE_IO:
-                   if dist.get_rank(self.data_parallel_process_group) == 0:
+                   assert False
+                   if gpc.get_local_rank(ParallelMode.DATA) == 0:
                        object_list = [io.BytesIO(file_slice.read(item_md.length))]
                    else:
                        object_list = [None]

@@ -912,23 +916,23 @@ def read_data_with_broadcast(self, per_file: Dict[str, List[ReadItem]], planner:
                        object_list,
                        src=dist.get_global_rank(self.data_parallel_process_group, 0),
                        group=self.data_parallel_process_group,
-                       device=f"cuda:{torch.cuda.current_device()}",
+                       device=get_current_device(),
                    )
                    bytes = object_list[0]
                    bytes.seek(0)
                    planner.load_bytes(req, bytes)
                else:
-                   if dist.get_rank(self.data_parallel_process_group) == 0:
+                   if gpc.get_local_rank(ParallelMode.DATA) == 0:
                        object_list = [cast(Tensor, torch.load(file_slice, map_location="cuda"))]
                    else:
                        object_list = [None]
                    dist.broadcast_object_list(
                        object_list,
                        src=dist.get_global_rank(self.data_parallel_process_group, 0),
                        group=self.data_parallel_process_group,
-                       device=f"cuda:{torch.cuda.current_device()}",
+                       device=get_current_device(),
                    )
-                   tensor = object_list[0].cpu()
+                   tensor = object_list[0].cpu() #att
                    tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
                    target_tensor = planner.resolve_tensor(req).detach()

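The read_data_with_broadcast changes make only data-parallel rank 0 open and read the checkpoint file; every other rank in the data-parallel group then receives the loaded object via dist.broadcast_object_list. A minimal sketch of that pattern, assuming an already initialized torch.distributed process group (the function name and arguments below are illustrative, not InternEvo API):

import torch
import torch.distributed as dist

def load_with_broadcast(path: str, group: dist.ProcessGroup):
    """Rank 0 of `group` reads the object from disk; the other ranks receive it
    via broadcast_object_list, so only one rank per group touches storage."""
    if dist.get_rank(group) == 0:
        object_list = [torch.load(path, map_location="cpu")]
    else:
        object_list = [None]
    dist.broadcast_object_list(
        object_list,
        src=dist.get_global_rank(group, 0),
        group=group,
    )
    return object_list[0]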
internlm/checkpoint/vescale/save_state_dict.py  (+16 -15)

@@ -53,7 +53,6 @@ def save_state_dict(
    [veScale version] Saves a distributed model in SPMD style. Fix sub-group storage.
    Args and usage is the same as `torch.distributed.checkpoint.save_state_dict`.
    """
-
    # Step 0: create distributed world based on process group and coordinator rank
    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
    if process_group:

@@ -132,6 +131,7 @@ def finish_checkpoint(all_results):

    # Wait for last write futures to finish.
    if last_write_futures:
+       print(f"last_write_futures: {last_write_futures}", flush=True)
        logger.info("Start waiting for last write events.")
        last_write_start_time = time.time()
        for fut in last_write_futures:

@@ -145,22 +145,23 @@ def finish_checkpoint(all_results):
    plan_start_time = time.time()
    cached_data = None

+   # if isinstance(planner, VeScaleSavePlanner):
+   #     central_plan = distW.reduce_scatter("plan", local_step, global_step)
+   # else:
+   #     raise AssertionError("Unsupported planner for saving checkpoint")
+
    if isinstance(planner, VeScaleSavePlanner):
-       central_plan = distW.reduce_scatter("plan", local_step, global_step)
+       cached_data = planner.lookup_plan_meta()
+       if cached_data:
+           logger.info("Plan cache hit. Reuse existing plan")
+           central_plan, _ = cached_data
+           # _ = local_step() #attn
+       else:
+           logger.info("Plan cache miss. The model/optimizer appears for the first time.")
+
+           central_plan = distW.reduce_scatter("plan", local_step, global_step)
    else:
        raise AssertionError("Unsupported planner for saving checkpoint")
-   # if isinstance(planner, VeScaleSavePlanner): #attn
-   #     cached_data = planner.lookup_plan_meta()
-   #     if cached_data:
-   #         logger.debug("Plan cache hit. Reuse existing plan")
-   #         central_plan, _ = cached_data
-   #         _ = local_step()
-   #     else:
-   #         logger.debug("Plan cache miss. The model/optimizer appears for the first time.")
-
-   #         central_plan = distW.reduce_scatter("plan", local_step, global_step)
-   # else:
-   #     raise AssertionError("Unsupported planner for saving checkpoint")

@@ -194,7 +195,7 @@ def finish_checkpoint(all_results):
        final_storage_metadata = distW.all_reduce("write", write_data, finish_checkpoint)
        assert central_plan is not None
        assert final_storage_metadata is not None
-       # planner.cache_plan_meta(central_plan, final_storage_metadata) #attn
+       planner.cache_plan_meta(central_plan, final_storage_metadata) #attn
    else:
        raise AssertionError("Unsupported planner for writing data and metadata")
    store_local_cost_time = time.time() - store_local_start_time

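Taken together, the save path now checks the planner's plan cache first and only runs the collective planning step on a miss; after the data and metadata are written, the plan is cached for the next save. A simplified sketch of that control flow (names such as planner, distW, local_step, and global_step follow the diff above; this is a sketch, not the full save_state_dict implementation):

def plan_with_cache(planner, distW, local_step, global_step):
    cached = planner.lookup_plan_meta()  # (central_plan, metadata) or None
    if cached is not None:
        central_plan, _ = cached  # cache hit: skip the collective planning step
    else:
        # cache miss: run the usual reduce_scatter to agree on a central plan
        central_plan = distW.reduce_scatter("plan", local_step, global_step)
    return central_plan

# After data and metadata are written, the plan is cached for the next save:
#     planner.cache_plan_meta(central_plan, final_storage_metadata)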
internlm/checkpoint/vescale/vescale_checkpointer.py  (+2 -2)

@@ -225,8 +225,8 @@ def load(
        # print(f"model_state {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.PIPELINE)}: {p})", flush=True)
        # Set process group
        if broadcast_checkpoint:
-           assert False
-           model_load_process_group = VESCALE_DEVICE_MESH.get_data_parallel_dim_groups()
+           # model_load_process_group = VESCALE_DEVICE_MESH.get_data_parallel_dim_groups()
+           model_load_process_group = gpc.get_group(ParallelMode.DATA)
        else:
            model_load_process_group = None
        # Load model

internlm/checkpoint/vescale/vescale_planner.py  (+20 -4)

@@ -135,17 +135,33 @@ def lookup_object(self, index: MetadataIndex, fqn=None) -> Any:
        return find_state_dict_object(self.state_dict, index, fqn)

    def lookup_plan_meta(self) -> Optional[Tuple[SavePlan, Metadata]]:
+       # if not hasattr(self, STATE_DICT_STR):
+       #     return None
+       # else:
+       #     device_mesh = VESCALE_DEVICE_MESH.get()
+       #     plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator, device_mesh))
+       #     return self._plan_cache.get(plan_key)
+
        if not hasattr(self, STATE_DICT_STR):
            return None
        else:
-           device_mesh = VESCALE_DEVICE_MESH.get()
-           plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator, device_mesh))
+           plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator))
            return self._plan_cache.get(plan_key)

    def cache_plan_meta(self, new_plan: SavePlan, new_metadata: Metadata) -> None:
-       device_mesh = VESCALE_DEVICE_MESH.get()
-       plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator, device_mesh))
+       # device_mesh = VESCALE_DEVICE_MESH.get()
+       # plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator, device_mesh))
+       # self._plan_cache.put(plan_key, new_plan, new_metadata)
+
+       print(f"new_plan {gpc.get_global_rank()}: {new_plan}", flush=True)
+       print(f"new_metadata {gpc.get_global_rank()}: {new_metadata}", flush=True)
+
+       plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator))
+       print(f"Before GPU Memory Allocated {gpc.get_global_rank()}: {torch.cuda.memory_allocated() /1024/1024} bytes", flush=True)
+       print(f"Before GPU Memory Cached {gpc.get_global_rank()}: {torch.cuda.memory_reserved() /1024/1024} bytes", flush=True)
        self._plan_cache.put(plan_key, new_plan, new_metadata)
+       print(f"After GPU Memory Allocated {gpc.get_global_rank()}: {torch.cuda.memory_allocated() /1024/1024} bytes", flush=True)
+       print(f"After GPU Memory Cached {gpc.get_global_rank()}: {torch.cuda.memory_reserved() /1024/1024} bytes", flush=True)

    def clear_cache(self) -> None:
        self._plan_cache.clear()

internlm/checkpoint/vescale/vescale_planner_helpers.py  (+1)

@@ -298,6 +298,7 @@ def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex, fq
    # if isinstance(obj, torch.Tensor): #att
    #     return find_tensor_shard(obj, index)
    if isinstance(obj, OptimizerStateSpec):
+       assert False
        return obj.local_tensor
    # elif index.offset is not None:
    #     raise ValueError(

internlm/initialize/launch.py  (+1 -1)

@@ -261,7 +261,7 @@ def args_sanity_check():
        ckpt._add_item("auto_resume", True)

    if "universal_ckpt" not in ckpt:
-       ckpt._add_item("universal_ckpt", False)
+       ckpt._add_item("universal_ckpt", dict(enable=False, aysnc_save=False, broadcast_load=False))

    if gpc.is_rank_for_log():
        logger.info("+" * 15 + " Ckpt Info " + "+" * 15)  # pylint: disable=W1201

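With this default in place, ckpt.universal_ckpt is a dict-like sub-config rather than a bare boolean, and its three flags map onto the call sites changed above: enable gates the universal-checkpoint path, aysnc_save is forwarded to vescale_save, and broadcast_load to vescale_load (the aysnc_save spelling matches the identifier used in the commit). A user config enabling it might look like this (illustrative values only):

# Excerpt of a training config (illustrative): enable universal checkpointing,
# save synchronously, and broadcast shards from data-parallel rank 0 on load.
ckpt = dict(
    universal_ckpt=dict(
        enable=True,
        aysnc_save=False,
        broadcast_load=True,
    ),
    # other existing ckpt fields (folder, checkpoint_every, auto_resume, ...) stay unchanged
)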
internlm/solver/optimizer/hybrid_zero_optim.py  (+2 -2)

@@ -1000,7 +1000,7 @@ def state_dict(self):
        optim_states = self.optim.state_dict()
        grad_scaler = self.grad_scaler.state_dict()
        states["grad_scaler"] = grad_scaler
-       if not gpc.config.ckpt.universal_ckpt:
+       if not gpc.config.ckpt.universal_ckpt.enable:
            states["base_optim_states"] = optim_states
        flat_fp32_weights = {}
        for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():

@@ -1217,7 +1217,7 @@ def state_dict(self):


    def load_state_dict(self, states, global_optimizer_state=None):
-       if not gpc.config.ckpt.universal_ckpt:
+       if not gpc.config.ckpt.universal_ckpt.enable:
            # TODO: Need to take into account the change in the number of DP.
            assert "grad_scaler" in states, "Not found grad_scaler state!"
            grad_scaler = states["grad_scaler"]
