1
1
import multiprocessing
2
-
3
- backup_ForkingPickler = multiprocessing .reduction .ForkingPickler
4
- backup_dump = multiprocessing .reduction .dump
5
2
import os
6
3
from functools import partial
7
4
25
22
reset_singletons ,
26
23
)
27
24
25
+ backup_ForkingPickler = multiprocessing .reduction .ForkingPickler
26
+ backup_dump = multiprocessing .reduction .dump
27
+
28
28
# (TOTAL_STEP, CKPT_EVERY, SNAPSHOT_EVERY)
29
29
step_info_list = [(8 , 4 , 2 ), (3 , 4 , 2 ), (1 , 6 , 3 )]
30
30
ckpt_config_list = [
@@ -201,8 +201,8 @@ def return_latest_save_path(save_ckpt_folder, total_step, snapshot_freq, ckpt_fr
201
201
@pytest .mark .parametrize ("step_info" , step_info_list )
202
202
@pytest .mark .parametrize ("ckpt_config" , ckpt_config_list )
203
203
def test_ckpt_mm (step_info , ckpt_config , init_dist_and_model ): # noqa # pylint: disable=unused-import
204
- from internlm .core .context import global_context as gpc
205
204
from internlm .checkpoint .checkpoint_manager import CheckpointLoadMask
205
+ from internlm .core .context import global_context as gpc
206
206
207
207
ckpt_config = Config (ckpt_config )
208
208
total_step , checkpoint_every , oss_snapshot_freq = step_info
@@ -222,6 +222,8 @@ def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint:
222
222
)
223
223
224
224
model , opim = init_dist_and_model
225
+ gpc .config ._add_item ("ckpt" , dict ())
226
+ gpc .config .ckpt ._add_item ("universal_ckpt" , dict (enable = False , aysnc_save = True , broadcast_load = False ))
225
227
train_state = TrainState (gpc .config , None )
226
228
if isinstance (opim , HybridZeroOptimizer ):
227
229
print ("Is HybridZeroOptimizer!" , flush = True )
@@ -297,9 +299,9 @@ def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint:
297
299
298
300
299
301
def query_quit_file (rank , world_size = 2 ):
302
+ from internlm .checkpoint .checkpoint_manager import CheckpointSaveType
300
303
from internlm .core .context import global_context as gpc
301
304
from internlm .initialize import initialize_distributed_env
302
- from internlm .checkpoint .checkpoint_manager import CheckpointSaveType
303
305
304
306
ckpt_config = Config (
305
307
dict (
@@ -348,8 +350,6 @@ def query_quit_file(rank, world_size=2):
348
350
349
351
350
352
def test_quit_siganl_handler (): # noqa # pylint: disable=unused-import
351
- import multiprocessing
352
-
353
353
# We apply a hack here to work around a bug in the third-party library dill, which only occurs in this unit test:
354
354
# https://github.com/uqfoundation/dill/issues/380
355
355
multiprocessing .reduction .ForkingPickler = backup_ForkingPickler
0 commit comments