Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 12, 2024
1 parent bacf268 commit 3d68951
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
9 changes: 9 additions & 0 deletions sota-implementations/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ def main(cfg: "DictConfig"): # noqa: F821
device = "cpu"
device = torch.device(device)

collector_device = cfg.collector.device
if collector_device in ("", None):
if torch.cuda.is_available():
collector_device = "cuda:0"
else:
collector_device = "cpu"
collector_device = torch.device(collector_device)
cfg.collector.device = collector_device

# Create logger
exp_name = generate_exp_name("DDPG", cfg.logger.exp_name)
logger = None
Expand Down
3 changes: 2 additions & 1 deletion sota-implementations/ddpg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def make_collector(
init_random_frames=cfg.collector.init_random_frames,
reset_at_each_iter=cfg.collector.reset_at_each_iter,
total_frames=cfg.collector.total_frames,
device=cfg.collector.device,
policy_device=cfg.collector.device,
env_device=train_env.device,
compile_policy={"mode": compile_mode} if compile else False,
cudagraph_policy=cudagraph,
)
Expand Down

0 comments on commit 3d68951

Please sign in to comment.