 from internlm.core.context import global_context as gpc
 from internlm.core.context import ParallelMode
 from internlm.train.pipeline import map_fqn_global_to_local, map_layer_attr
+from internlm.utils.common import get_current_device
 
 from torch.distributed.checkpoint.metadata import (
@@ -880,8 +881,9 @@ def read_from_files(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlan
                     bytes.seek(0)
                     planner.load_bytes(req, bytes)
                 else:
-                    tensor = cast(Tensor, torch.load(file_slice, map_location="cpu"))
+                    tensor = cast(Tensor, torch.load(file_slice, map_location="cpu"))  # att
                     tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
+                    print(f"req: {req.dest_index.fqn}, {req}", flush=True)
                     target_tensor = planner.resolve_tensor(req).detach()
 
                     assert (
@@ -892,18 +894,20 @@ def read_from_files(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlan
 
     def read_data_with_broadcast(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlanner):
         for relative_path, reqs in per_file.items():
-            if dist.get_rank(self.data_parallel_process_group) == 0:
+            # if dist.get_rank(self.data_parallel_process_group) == 0:
+            if gpc.get_local_rank(ParallelMode.DATA) == 0:
                 file_path = self._get_file_path(relative_path)
                 file = open(file_path, "rb")
             dist.barrier(self.data_parallel_process_group)
             reqs = sorted(reqs, key=lambda req: self.storage_data[req.storage_index].offset)
             for req in reqs:
-                if dist.get_rank(self.data_parallel_process_group) == 0:
+                if gpc.get_local_rank(ParallelMode.DATA) == 0:
                     item_md = self.storage_data[req.storage_index]
                     file_slice = self._slice_file(file, item_md)
 
                 if req.type == LoadItemType.BYTE_IO:
-                    if dist.get_rank(self.data_parallel_process_group) == 0:
+                    assert False
+                    if gpc.get_local_rank(ParallelMode.DATA) == 0:
                         object_list = [io.BytesIO(file_slice.read(item_md.length))]
                     else:
                         object_list = [None]
@@ -912,23 +916,23 @@ def read_data_with_broadcast(self, per_file: Dict[str, List[ReadItem]], planner:
                         object_list,
                         src=dist.get_global_rank(self.data_parallel_process_group, 0),
                         group=self.data_parallel_process_group,
-                        device=f"cuda:{torch.cuda.current_device()}",
+                        device=get_current_device(),
                     )
                     bytes = object_list[0]
                     bytes.seek(0)
                     planner.load_bytes(req, bytes)
                 else:
-                    if dist.get_rank(self.data_parallel_process_group) == 0:
+                    if gpc.get_local_rank(ParallelMode.DATA) == 0:
                         object_list = [cast(Tensor, torch.load(file_slice, map_location="cuda"))]
                     else:
                         object_list = [None]
                     dist.broadcast_object_list(
                         object_list,
                         src=dist.get_global_rank(self.data_parallel_process_group, 0),
                         group=self.data_parallel_process_group,
-                        device=f"cuda:{torch.cuda.current_device()}",
+                        device=get_current_device(),
                     )
-                    tensor = object_list[0].cpu()
+                    tensor = object_list[0].cpu()  # att
                     tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
                     target_tensor = planner.resolve_tensor(req).detach()
 
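Note on the broadcast path touched above: the underlying pattern is that only one rank of the data-parallel group reads the checkpoint shard from disk, and the deserialized object is then broadcast to every other rank of that group. The following is a minimal sketch of that pattern in plain torch.distributed, not the InternLM implementation; load_and_broadcast, path, and group are illustrative names, and an initialized process group is assumed.

import torch
import torch.distributed as dist


def load_and_broadcast(path: str, group: dist.ProcessGroup) -> torch.Tensor:
    # Only the group-local rank 0 touches the file.
    if dist.get_rank(group) == 0:
        object_list = [torch.load(path, map_location="cpu")]
    else:
        object_list = [None]

    # broadcast_object_list expects a *global* source rank, so translate
    # the group-local rank 0 before broadcasting (as the diff does with
    # dist.get_global_rank(self.data_parallel_process_group, 0)).
    dist.broadcast_object_list(
        object_list,
        src=dist.get_global_rank(group, 0),
        group=group,
    )
    return object_list[0]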