Skip to content

Commit

Permalink
feature(pu): add load pretrained ckpt in serial_entry_onpolicy and serial_entry
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Oct 28, 2024
1 parent d88ebe2 commit b09ffda
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
9 changes: 9 additions & 0 deletions ding/entry/serial_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ def serial_pipeline(
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])

# Load pretrained model if specified
if cfg.policy.load_path is not None:
logging.info(f'Loading model from {cfg.policy.load_path} begin...')
if cfg.policy.cuda and torch.cuda.is_available():
policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cuda'))
else:
policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
logging.info(f'Loading model from {cfg.policy.load_path} end!')

# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
Expand Down
9 changes: 9 additions & 0 deletions ding/entry/serial_entry_onpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ def serial_pipeline_onpolicy(
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])

# Load pretrained model if specified
if cfg.policy.load_path is not None:
logging.info(f'Loading model from {cfg.policy.load_path} begin...')
if cfg.policy.cuda and torch.cuda.is_available():
policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cuda'))
else:
policy.learn_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
logging.info(f'Loading model from {cfg.policy.load_path} end!')

# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
Expand Down

0 comments on commit b09ffda

Please sign in to comment.