From b09ffda34c288960b105e14fa8e8e618f940f047 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Mon, 28 Oct 2024 12:12:56 +0800 Subject: [PATCH] feature(pu): add load pretrained ckpt in serial_entry_onpolicy and serial_entry --- ding/entry/serial_entry.py | 9 +++++++++ ding/entry/serial_entry_onpolicy.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/ding/entry/serial_entry.py b/ding/entry/serial_entry.py index 929c83a219..5c879162b6 100644 --- a/ding/entry/serial_entry.py +++ b/ding/entry/serial_entry.py @@ -60,6 +60,15 @@ def serial_pipeline( set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command']) + # Load pretrained model if specified + if cfg.policy.load_path is not None: + logging.info(f'Loading model from {cfg.policy.load_path} begin...') + if cfg.policy.cuda and torch.cuda.is_available(): + policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cuda')) + else: + policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu')) + logging.info(f'Loading model from {cfg.policy.load_path} end!') + # Create worker components: learner, collector, evaluator, replay buffer, commander. tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) diff --git a/ding/entry/serial_entry_onpolicy.py b/ding/entry/serial_entry_onpolicy.py index 22e9cf74f9..112f326249 100644 --- a/ding/entry/serial_entry_onpolicy.py +++ b/ding/entry/serial_entry_onpolicy.py @@ -58,6 +58,15 @@ def serial_pipeline_onpolicy( set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command']) + # Load pretrained model if specified + if cfg.policy.load_path is not None: + logging.info(f'Loading model from {cfg.policy.load_path} begin...') + if cfg.policy.cuda and torch.cuda.is_available(): + policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cuda')) + else: + policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu')) + logging.info(f'Loading model from {cfg.policy.load_path} end!') + # Create worker components: learner, collector, evaluator, replay buffer, commander. tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)