diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index a44f0a2c629..30153a062dd 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -48,6 +48,15 @@ def main(cfg: "DictConfig"): # noqa: F821 device = "cpu" device = torch.device(device) + collector_device = cfg.collector.device + if collector_device in ("", None): + if torch.cuda.is_available(): + collector_device = "cuda:0" + else: + collector_device = "cpu" + collector_device = torch.device(collector_device) + cfg.collector.device = collector_device + # Create logger exp_name = generate_exp_name("DDPG", cfg.logger.exp_name) logger = None diff --git a/sota-implementations/ddpg/utils.py b/sota-implementations/ddpg/utils.py index 201ce2b39a3..19b1e3bbb7d 100644 --- a/sota-implementations/ddpg/utils.py +++ b/sota-implementations/ddpg/utils.py @@ -127,7 +127,8 @@ def make_collector( init_random_frames=cfg.collector.init_random_frames, reset_at_each_iter=cfg.collector.reset_at_each_iter, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + policy_device=cfg.collector.device, + env_device=train_env.device, compile_policy={"mode": compile_mode} if compile else False, cudagraph_policy=cudagraph, )