From 36d1d3a5cd11b8c0ce37afb551cf36877ee53ca0 Mon Sep 17 00:00:00 2001 From: Noah Farr Date: Tue, 27 Aug 2024 22:28:43 +0200 Subject: [PATCH 1/5] Add support for multiple envs --- cleanrl/td3_continuous_action.py | 125 ++++++++++++++++++++++++------- 1 file changed, 97 insertions(+), 28 deletions(-) diff --git a/cleanrl/td3_continuous_action.py b/cleanrl/td3_continuous_action.py index 418c4b3b3..70a3b1e4a 100644 --- a/cleanrl/td3_continuous_action.py +++ b/cleanrl/td3_continuous_action.py @@ -47,6 +47,8 @@ class Args: """total timesteps of the experiments""" learning_rate: float = 3e-4 """the learning rate of the optimizer""" + num_envs: int = 2 + """the number of parallel game environments""" buffer_size: int = int(1e6) """the replay memory buffer size""" gamma: float = 0.99 @@ -85,7 +87,11 @@ def thunk(): class QNetwork(nn.Module): def __init__(self, env): super().__init__() - self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256) + self.fc1 = nn.Linear( + np.array(env.single_observation_space.shape).prod() + + np.prod(env.single_action_space.shape), + 256, + ) self.fc2 = nn.Linear(256, 256) self.fc3 = nn.Linear(256, 1) @@ -105,10 +111,18 @@ def __init__(self, env): self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape)) # action rescaling self.register_buffer( - "action_scale", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32) + "action_scale", + torch.tensor( + (env.single_action_space.high - env.single_action_space.low) / 2.0, + dtype=torch.float32, + ), ) self.register_buffer( - "action_bias", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32) + "action_bias", + torch.tensor( + (env.single_action_space.high + env.single_action_space.low) / 2.0, + dtype=torch.float32, + ), ) def forward(self, x): @@ -145,7 +159,8 @@ def forward(self, x): writer = SummaryWriter(f"runs/{run_name}") writer.add_text( "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + "|param|value|\n|-|-|\n%s" + % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), ) # TRY NOT TO MODIFY: seeding @@ -157,8 +172,15 @@ def forward(self, x): device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup - envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) - assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + envs = gym.vector.SyncVectorEnv( + [ + make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) + for i in range(args.num_envs) + ] + ) + assert isinstance( + envs.single_action_space, gym.spaces.Box + ), "only continuous action space is supported" actor = Actor(envs).to(device) qf1 = QNetwork(envs).to(device) @@ -169,7 +191,9 @@ def forward(self, x): target_actor.load_state_dict(actor.state_dict()) qf1_target.load_state_dict(qf1.state_dict()) qf2_target.load_state_dict(qf2.state_dict()) - q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate) + q_optimizer = optim.Adam( + list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate + ) actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.learning_rate) envs.single_observation_space.dtype = np.float32 @@ -178,6 +202,7 @@ def forward(self, x): envs.single_observation_space, envs.single_action_space, device, + 
n_envs=args.num_envs, handle_timeout_termination=False, ) start_time = time.time() @@ -187,12 +212,18 @@ def forward(self, x): for global_step in range(args.total_timesteps): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: - actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) + actions = np.array( + [envs.single_action_space.sample() for _ in range(envs.num_envs)] + ) else: with torch.no_grad(): actions = actor(torch.Tensor(obs).to(device)) actions += torch.normal(0, actor.action_scale * args.exploration_noise) - actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high) + actions = ( + actions.cpu() + .numpy() + .clip(envs.single_action_space.low, envs.single_action_space.high) + ) # TRY NOT TO MODIFY: execute the game and log data. next_obs, rewards, terminations, truncations, infos = envs.step(actions) @@ -200,10 +231,17 @@ def forward(self, x): # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: for info in infos["final_info"]: - print(f"global_step={global_step}, episodic_return={info['episode']['r']}") - writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) - writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) - break + if info is not None: + print( + f"global_step={global_step}, episodic_return={info['episode']['r']}" + ) + writer.add_scalar( + "charts/episodic_return", info["episode"]["r"], global_step + ) + writer.add_scalar( + "charts/episodic_length", info["episode"]["l"], global_step + ) + break # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() @@ -219,17 +257,21 @@ def forward(self, x): if global_step > args.learning_starts: data = rb.sample(args.batch_size) with torch.no_grad(): - clipped_noise = (torch.randn_like(data.actions, device=device) * args.policy_noise).clamp( - -args.noise_clip, args.noise_clip - ) * target_actor.action_scale + clipped_noise = ( + torch.randn_like(data.actions, device=device) * args.policy_noise + ).clamp(-args.noise_clip, args.noise_clip) * target_actor.action_scale - next_state_actions = (target_actor(data.next_observations) + clipped_noise).clamp( + next_state_actions = ( + target_actor(data.next_observations) + clipped_noise + ).clamp( envs.single_action_space.low[0], envs.single_action_space.high[0] ) qf1_next_target = qf1_target(data.next_observations, next_state_actions) qf2_next_target = qf2_target(data.next_observations, next_state_actions) min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1) + next_q_value = data.rewards.flatten() + ( + 1 - data.dones.flatten() + ) * args.gamma * (min_qf_next_target).view(-1) qf1_a_values = qf1(data.observations, data.actions).view(-1) qf2_a_values = qf2(data.observations, data.actions).view(-1) @@ -249,22 +291,42 @@ def forward(self, x): actor_optimizer.step() # update the target network - for param, target_param in zip(actor.parameters(), target_actor.parameters()): - target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) - for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): - target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) - for param, target_param in zip(qf2.parameters(), qf2_target.parameters()): - target_param.data.copy_(args.tau * param.data + (1 - 
args.tau) * target_param.data) + for param, target_param in zip( + actor.parameters(), target_actor.parameters() + ): + target_param.data.copy_( + args.tau * param.data + (1 - args.tau) * target_param.data + ) + for param, target_param in zip( + qf1.parameters(), qf1_target.parameters() + ): + target_param.data.copy_( + args.tau * param.data + (1 - args.tau) * target_param.data + ) + for param, target_param in zip( + qf2.parameters(), qf2_target.parameters() + ): + target_param.data.copy_( + args.tau * param.data + (1 - args.tau) * target_param.data + ) if global_step % 100 == 0: - writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) - writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step) + writer.add_scalar( + "losses/qf1_values", qf1_a_values.mean().item(), global_step + ) + writer.add_scalar( + "losses/qf2_values", qf2_a_values.mean().item(), global_step + ) writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step) writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step) writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step) writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step) print("SPS:", int(global_step / (time.time() - start_time))) - writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + writer.add_scalar( + "charts/SPS", + int(global_step / (time.time() - start_time)), + global_step, + ) if args.save_model: model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model" @@ -290,7 +352,14 @@ def forward(self, x): repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}" repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name - push_to_hub(args, episodic_returns, repo_id, "TD3", f"runs/{run_name}", f"videos/{run_name}-eval") + push_to_hub( + args, + episodic_returns, + repo_id, + "TD3", + f"runs/{run_name}", + f"videos/{run_name}-eval", + ) envs.close() writer.close() From d348e9f78eddad51fe0978a48851c7e6ca87c26c Mon Sep 17 00:00:00 2001 From: Noah Farr Date: Tue, 27 Aug 2024 22:31:26 +0200 Subject: [PATCH 2/5] Run pre-commit --- cleanrl/td3_continuous_action.py | 87 +++++++++----------------------- 1 file changed, 23 insertions(+), 64 deletions(-) diff --git a/cleanrl/td3_continuous_action.py b/cleanrl/td3_continuous_action.py index 70a3b1e4a..3877c0e7c 100644 --- a/cleanrl/td3_continuous_action.py +++ b/cleanrl/td3_continuous_action.py @@ -88,8 +88,7 @@ class QNetwork(nn.Module): def __init__(self, env): super().__init__() self.fc1 = nn.Linear( - np.array(env.single_observation_space.shape).prod() - + np.prod(env.single_action_space.shape), + np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256, ) self.fc2 = nn.Linear(256, 256) @@ -159,8 +158,7 @@ def forward(self, x): writer = SummaryWriter(f"runs/{run_name}") writer.add_text( "hyperparameters", - "|param|value|\n|-|-|\n%s" - % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), ) # TRY NOT TO MODIFY: seeding @@ -173,14 +171,9 @@ def forward(self, x): # env setup envs = gym.vector.SyncVectorEnv( - [ - make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) - for i in range(args.num_envs) - ] + [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] ) - assert isinstance( - envs.single_action_space, gym.spaces.Box - ), "only continuous action 
space is supported" + assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" actor = Actor(envs).to(device) qf1 = QNetwork(envs).to(device) @@ -191,9 +184,7 @@ def forward(self, x): target_actor.load_state_dict(actor.state_dict()) qf1_target.load_state_dict(qf1.state_dict()) qf2_target.load_state_dict(qf2.state_dict()) - q_optimizer = optim.Adam( - list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate - ) + q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate) actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.learning_rate) envs.single_observation_space.dtype = np.float32 @@ -212,18 +203,12 @@ def forward(self, x): for global_step in range(args.total_timesteps): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: - actions = np.array( - [envs.single_action_space.sample() for _ in range(envs.num_envs)] - ) + actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) else: with torch.no_grad(): actions = actor(torch.Tensor(obs).to(device)) actions += torch.normal(0, actor.action_scale * args.exploration_noise) - actions = ( - actions.cpu() - .numpy() - .clip(envs.single_action_space.low, envs.single_action_space.high) - ) + actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high) # TRY NOT TO MODIFY: execute the game and log data. next_obs, rewards, terminations, truncations, infos = envs.step(actions) @@ -232,15 +217,9 @@ def forward(self, x): if "final_info" in infos: for info in infos["final_info"]: if info is not None: - print( - f"global_step={global_step}, episodic_return={info['episode']['r']}" - ) - writer.add_scalar( - "charts/episodic_return", info["episode"]["r"], global_step - ) - writer.add_scalar( - "charts/episodic_length", info["episode"]["l"], global_step - ) + print(f"global_step={global_step}, episodic_return={info['episode']['r']}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) break # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` @@ -257,21 +236,17 @@ def forward(self, x): if global_step > args.learning_starts: data = rb.sample(args.batch_size) with torch.no_grad(): - clipped_noise = ( - torch.randn_like(data.actions, device=device) * args.policy_noise - ).clamp(-args.noise_clip, args.noise_clip) * target_actor.action_scale + clipped_noise = (torch.randn_like(data.actions, device=device) * args.policy_noise).clamp( + -args.noise_clip, args.noise_clip + ) * target_actor.action_scale - next_state_actions = ( - target_actor(data.next_observations) + clipped_noise - ).clamp( + next_state_actions = (target_actor(data.next_observations) + clipped_noise).clamp( envs.single_action_space.low[0], envs.single_action_space.high[0] ) qf1_next_target = qf1_target(data.next_observations, next_state_actions) qf2_next_target = qf2_target(data.next_observations, next_state_actions) min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - next_q_value = data.rewards.flatten() + ( - 1 - data.dones.flatten() - ) * args.gamma * (min_qf_next_target).view(-1) + next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1) qf1_a_values = qf1(data.observations, data.actions).view(-1) qf2_a_values = qf2(data.observations, data.actions).view(-1) @@ -291,32 +266,16 @@ def forward(self, x): 
actor_optimizer.step() # update the target network - for param, target_param in zip( - actor.parameters(), target_actor.parameters() - ): - target_param.data.copy_( - args.tau * param.data + (1 - args.tau) * target_param.data - ) - for param, target_param in zip( - qf1.parameters(), qf1_target.parameters() - ): - target_param.data.copy_( - args.tau * param.data + (1 - args.tau) * target_param.data - ) - for param, target_param in zip( - qf2.parameters(), qf2_target.parameters() - ): - target_param.data.copy_( - args.tau * param.data + (1 - args.tau) * target_param.data - ) + for param, target_param in zip(actor.parameters(), target_actor.parameters()): + target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): + target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + for param, target_param in zip(qf2.parameters(), qf2_target.parameters()): + target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) if global_step % 100 == 0: - writer.add_scalar( - "losses/qf1_values", qf1_a_values.mean().item(), global_step - ) - writer.add_scalar( - "losses/qf2_values", qf2_a_values.mean().item(), global_step - ) + writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) + writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step) writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step) writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step) writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step) From 180e3027d205015a5817c76bd9f3ecb28ea323ba Mon Sep 17 00:00:00 2001 From: Noah Farr Date: Wed, 28 Aug 2024 20:02:01 +0200 Subject: [PATCH 3/5] Default num_envs to 1 --- cleanrl/td3_continuous_action.py | 89 +++++++++++++++++++++++--------- 1 file changed, 65 insertions(+), 24 deletions(-) diff --git a/cleanrl/td3_continuous_action.py b/cleanrl/td3_continuous_action.py index 3877c0e7c..80c2e469b 100644 --- a/cleanrl/td3_continuous_action.py +++ b/cleanrl/td3_continuous_action.py @@ -47,7 +47,7 @@ class Args: """total timesteps of the experiments""" learning_rate: float = 3e-4 """the learning rate of the optimizer""" - num_envs: int = 2 + num_envs: int = 1 """the number of parallel game environments""" buffer_size: int = int(1e6) """the replay memory buffer size""" @@ -88,7 +88,8 @@ class QNetwork(nn.Module): def __init__(self, env): super().__init__() self.fc1 = nn.Linear( - np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), + np.array(env.single_observation_space.shape).prod() + + np.prod(env.single_action_space.shape), 256, ) self.fc2 = nn.Linear(256, 256) @@ -158,7 +159,8 @@ def forward(self, x): writer = SummaryWriter(f"runs/{run_name}") writer.add_text( "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + "|param|value|\n|-|-|\n%s" + % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), ) # TRY NOT TO MODIFY: seeding @@ -171,9 +173,14 @@ def forward(self, x): # env setup envs = gym.vector.SyncVectorEnv( - [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] + [ + make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) + for i in range(args.num_envs) + ] ) - assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + assert 
isinstance( + envs.single_action_space, gym.spaces.Box + ), "only continuous action space is supported" actor = Actor(envs).to(device) qf1 = QNetwork(envs).to(device) @@ -184,7 +191,9 @@ def forward(self, x): target_actor.load_state_dict(actor.state_dict()) qf1_target.load_state_dict(qf1.state_dict()) qf2_target.load_state_dict(qf2.state_dict()) - q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate) + q_optimizer = optim.Adam( + list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate + ) actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.learning_rate) envs.single_observation_space.dtype = np.float32 @@ -203,12 +212,18 @@ def forward(self, x): for global_step in range(args.total_timesteps): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: - actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) + actions = np.array( + [envs.single_action_space.sample() for _ in range(envs.num_envs)] + ) else: with torch.no_grad(): actions = actor(torch.Tensor(obs).to(device)) actions += torch.normal(0, actor.action_scale * args.exploration_noise) - actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high) + actions = ( + actions.cpu() + .numpy() + .clip(envs.single_action_space.low, envs.single_action_space.high) + ) # TRY NOT TO MODIFY: execute the game and log data. next_obs, rewards, terminations, truncations, infos = envs.step(actions) @@ -217,9 +232,15 @@ def forward(self, x): if "final_info" in infos: for info in infos["final_info"]: if info is not None: - print(f"global_step={global_step}, episodic_return={info['episode']['r']}") - writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) - writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + print( + f"global_step={global_step}, episodic_return={info['episode']['r']}" + ) + writer.add_scalar( + "charts/episodic_return", info["episode"]["r"], global_step + ) + writer.add_scalar( + "charts/episodic_length", info["episode"]["l"], global_step + ) break # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` @@ -236,17 +257,21 @@ def forward(self, x): if global_step > args.learning_starts: data = rb.sample(args.batch_size) with torch.no_grad(): - clipped_noise = (torch.randn_like(data.actions, device=device) * args.policy_noise).clamp( - -args.noise_clip, args.noise_clip - ) * target_actor.action_scale + clipped_noise = ( + torch.randn_like(data.actions, device=device) * args.policy_noise + ).clamp(-args.noise_clip, args.noise_clip) * target_actor.action_scale - next_state_actions = (target_actor(data.next_observations) + clipped_noise).clamp( + next_state_actions = ( + target_actor(data.next_observations) + clipped_noise + ).clamp( envs.single_action_space.low[0], envs.single_action_space.high[0] ) qf1_next_target = qf1_target(data.next_observations, next_state_actions) qf2_next_target = qf2_target(data.next_observations, next_state_actions) min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1) + next_q_value = data.rewards.flatten() + ( + 1 - data.dones.flatten() + ) * args.gamma * (min_qf_next_target).view(-1) qf1_a_values = qf1(data.observations, data.actions).view(-1) qf2_a_values = qf2(data.observations, data.actions).view(-1) @@ -266,16 +291,32 @@ def forward(self, x): actor_optimizer.step() 
# update the target network - for param, target_param in zip(actor.parameters(), target_actor.parameters()): - target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) - for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): - target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) - for param, target_param in zip(qf2.parameters(), qf2_target.parameters()): - target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + for param, target_param in zip( + actor.parameters(), target_actor.parameters() + ): + target_param.data.copy_( + args.tau * param.data + (1 - args.tau) * target_param.data + ) + for param, target_param in zip( + qf1.parameters(), qf1_target.parameters() + ): + target_param.data.copy_( + args.tau * param.data + (1 - args.tau) * target_param.data + ) + for param, target_param in zip( + qf2.parameters(), qf2_target.parameters() + ): + target_param.data.copy_( + args.tau * param.data + (1 - args.tau) * target_param.data + ) if global_step % 100 == 0: - writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) - writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step) + writer.add_scalar( + "losses/qf1_values", qf1_a_values.mean().item(), global_step + ) + writer.add_scalar( + "losses/qf2_values", qf2_a_values.mean().item(), global_step + ) writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step) writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step) writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step) From 3f0cd76d47091aa73abc1636fdedf420d8c10e26 Mon Sep 17 00:00:00 2001 From: Noah Farr Date: Wed, 28 Aug 2024 20:04:16 +0200 Subject: [PATCH 4/5] Run pre-commit --- cleanrl/td3_continuous_action.py | 87 +++++++++----------------------- 1 file changed, 23 insertions(+), 64 deletions(-) diff --git a/cleanrl/td3_continuous_action.py b/cleanrl/td3_continuous_action.py index 80c2e469b..832853911 100644 --- a/cleanrl/td3_continuous_action.py +++ b/cleanrl/td3_continuous_action.py @@ -88,8 +88,7 @@ class QNetwork(nn.Module): def __init__(self, env): super().__init__() self.fc1 = nn.Linear( - np.array(env.single_observation_space.shape).prod() - + np.prod(env.single_action_space.shape), + np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256, ) self.fc2 = nn.Linear(256, 256) @@ -159,8 +158,7 @@ def forward(self, x): writer = SummaryWriter(f"runs/{run_name}") writer.add_text( "hyperparameters", - "|param|value|\n|-|-|\n%s" - % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), ) # TRY NOT TO MODIFY: seeding @@ -173,14 +171,9 @@ def forward(self, x): # env setup envs = gym.vector.SyncVectorEnv( - [ - make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) - for i in range(args.num_envs) - ] + [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)] ) - assert isinstance( - envs.single_action_space, gym.spaces.Box - ), "only continuous action space is supported" + assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" actor = Actor(envs).to(device) qf1 = QNetwork(envs).to(device) @@ -191,9 +184,7 @@ def forward(self, x): target_actor.load_state_dict(actor.state_dict()) qf1_target.load_state_dict(qf1.state_dict()) 
qf2_target.load_state_dict(qf2.state_dict()) - q_optimizer = optim.Adam( - list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate - ) + q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate) actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.learning_rate) envs.single_observation_space.dtype = np.float32 @@ -212,18 +203,12 @@ def forward(self, x): for global_step in range(args.total_timesteps): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: - actions = np.array( - [envs.single_action_space.sample() for _ in range(envs.num_envs)] - ) + actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) else: with torch.no_grad(): actions = actor(torch.Tensor(obs).to(device)) actions += torch.normal(0, actor.action_scale * args.exploration_noise) - actions = ( - actions.cpu() - .numpy() - .clip(envs.single_action_space.low, envs.single_action_space.high) - ) + actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high) # TRY NOT TO MODIFY: execute the game and log data. next_obs, rewards, terminations, truncations, infos = envs.step(actions) @@ -232,15 +217,9 @@ def forward(self, x): if "final_info" in infos: for info in infos["final_info"]: if info is not None: - print( - f"global_step={global_step}, episodic_return={info['episode']['r']}" - ) - writer.add_scalar( - "charts/episodic_return", info["episode"]["r"], global_step - ) - writer.add_scalar( - "charts/episodic_length", info["episode"]["l"], global_step - ) + print(f"global_step={global_step}, episodic_return={info['episode']['r']}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) break # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` @@ -257,21 +236,17 @@ def forward(self, x): if global_step > args.learning_starts: data = rb.sample(args.batch_size) with torch.no_grad(): - clipped_noise = ( - torch.randn_like(data.actions, device=device) * args.policy_noise - ).clamp(-args.noise_clip, args.noise_clip) * target_actor.action_scale + clipped_noise = (torch.randn_like(data.actions, device=device) * args.policy_noise).clamp( + -args.noise_clip, args.noise_clip + ) * target_actor.action_scale - next_state_actions = ( - target_actor(data.next_observations) + clipped_noise - ).clamp( + next_state_actions = (target_actor(data.next_observations) + clipped_noise).clamp( envs.single_action_space.low[0], envs.single_action_space.high[0] ) qf1_next_target = qf1_target(data.next_observations, next_state_actions) qf2_next_target = qf2_target(data.next_observations, next_state_actions) min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - next_q_value = data.rewards.flatten() + ( - 1 - data.dones.flatten() - ) * args.gamma * (min_qf_next_target).view(-1) + next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1) qf1_a_values = qf1(data.observations, data.actions).view(-1) qf2_a_values = qf2(data.observations, data.actions).view(-1) @@ -291,32 +266,16 @@ def forward(self, x): actor_optimizer.step() # update the target network - for param, target_param in zip( - actor.parameters(), target_actor.parameters() - ): - target_param.data.copy_( - args.tau * param.data + (1 - args.tau) * target_param.data - ) - for param, target_param in zip( - qf1.parameters(), qf1_target.parameters() - ): - 
target_param.data.copy_( - args.tau * param.data + (1 - args.tau) * target_param.data - ) - for param, target_param in zip( - qf2.parameters(), qf2_target.parameters() - ): - target_param.data.copy_( - args.tau * param.data + (1 - args.tau) * target_param.data - ) + for param, target_param in zip(actor.parameters(), target_actor.parameters()): + target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + for param, target_param in zip(qf1.parameters(), qf1_target.parameters()): + target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) + for param, target_param in zip(qf2.parameters(), qf2_target.parameters()): + target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data) if global_step % 100 == 0: - writer.add_scalar( - "losses/qf1_values", qf1_a_values.mean().item(), global_step - ) - writer.add_scalar( - "losses/qf2_values", qf2_a_values.mean().item(), global_step - ) + writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step) + writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step) writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step) writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step) writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step) From bc29059567b9c5b97307afe2f0089f5e52f3212e Mon Sep 17 00:00:00 2001 From: Noah Farr Date: Wed, 28 Aug 2024 20:09:59 +0200 Subject: [PATCH 5/5] Add multi env support for sac_continuous_action --- cleanrl/sac_continuous_action.py | 39 ++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/cleanrl/sac_continuous_action.py b/cleanrl/sac_continuous_action.py index d28cb8e2f..9c7fa5ff1 100644 --- a/cleanrl/sac_continuous_action.py +++ b/cleanrl/sac_continuous_action.py @@ -39,6 +39,8 @@ class Args: """the environment id of the task""" total_timesteps: int = 1000000 """total timesteps of the experiments""" + num_envs: int = 1 + """the number of parallel game environments""" buffer_size: int = int(1e6) """the replay memory buffer size""" gamma: float = 0.99 @@ -81,7 +83,10 @@ def thunk(): class SoftQNetwork(nn.Module): def __init__(self, env): super().__init__() - self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256) + self.fc1 = nn.Linear( + np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), + 256, + ) self.fc2 = nn.Linear(256, 256) self.fc3 = nn.Linear(256, 1) @@ -106,10 +111,18 @@ def __init__(self, env): self.fc_logstd = nn.Linear(256, np.prod(env.single_action_space.shape)) # action rescaling self.register_buffer( - "action_scale", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32) + "action_scale", + torch.tensor( + (env.single_action_space.high - env.single_action_space.low) / 2.0, + dtype=torch.float32, + ), ) self.register_buffer( - "action_bias", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32) + "action_bias", + torch.tensor( + (env.single_action_space.high + env.single_action_space.low) / 2.0, + dtype=torch.float32, + ), ) def forward(self, x): @@ -176,7 +189,9 @@ def get_action(self, x): device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") # env setup - envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) + envs = gym.vector.SyncVectorEnv( + [make_env(args.env_id, args.seed + i, 
i, args.capture_video, run_name) for i in range(args.num_envs)] + ) assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" max_action = float(envs.single_action_space.high[0]) @@ -206,6 +221,7 @@ def get_action(self, x): envs.single_observation_space, envs.single_action_space, device, + n_envs=args.num_envs, handle_timeout_termination=False, ) start_time = time.time() @@ -226,10 +242,11 @@ def get_action(self, x): # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: for info in infos["final_info"]: - print(f"global_step={global_step}, episodic_return={info['episode']['r']}") - writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) - writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) - break + if info is not None: + print(f"global_step={global_step}, episodic_return={info['episode']['r']}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + break # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() @@ -302,7 +319,11 @@ def get_action(self, x): writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step) writer.add_scalar("losses/alpha", alpha, global_step) print("SPS:", int(global_step / (time.time() - start_time))) - writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + writer.add_scalar( + "charts/SPS", + int(global_step / (time.time() - start_time)), + global_step, + ) if args.autotune: writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)
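
Notes on what this series actually changes, with small illustrative sketches. The sketches are placeholders under stated assumptions, not excerpts from the patches: the environment id, seeds, run name, and num_envs=2 below are illustrative, and the make_env body is modeled on the helper already present in td3_continuous_action.py (its body is outside this diff). Patch 5 applies the same treatment to sac_continuous_action.py (a num_envs argument, per-env seeds of args.seed + i, action rescaling from single_action_space, n_envs on the replay buffer, and the final_info None guard), so the notes apply to both files.

First, the environment setup: each of the num_envs workers gets its own seed offset and index, and the rest of the script only reads shapes from envs.single_observation_space / envs.single_action_space.

import gymnasium as gym
import numpy as np

def make_env(env_id, seed, idx, capture_video, run_name):
    # Modeled on the existing CleanRL helper: one thunk per sub-environment.
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env
    return thunk

num_envs = 2  # placeholder; patch 3 later defaults args.num_envs to 1
envs = gym.vector.SyncVectorEnv(
    [make_env("Pendulum-v1", 1 + i, i, False, "demo-run") for i in range(num_envs)]
)
obs, _ = envs.reset(seed=1)
print(obs.shape)                        # (num_envs, obs_dim) = (2, 3)
print(envs.single_action_space.shape)   # per-env shape, (1,)
print(envs.action_space.shape)          # batched shape, (2, 1)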
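
The only replay-buffer change is passing n_envs=args.num_envs to stable_baselines3's ReplayBuffer, so that a single rb.add(...) call stores one transition per sub-environment while rb.sample(...) still returns a flat batch. A minimal continuation of the sketch above; the buffer size, warm-up length, and batch size are placeholders.

import torch
from stable_baselines3.common.buffers import ReplayBuffer

device = torch.device("cpu")
rb = ReplayBuffer(
    10_000,
    envs.single_observation_space,
    envs.single_action_space,
    device,
    n_envs=num_envs,                     # the new argument in this series
    handle_timeout_termination=False,
)

obs, _ = envs.reset(seed=1)
for _ in range(200):
    # random warm-up actions, one per sub-environment
    actions = np.array([envs.single_action_space.sample() for _ in range(num_envs)])
    next_obs, rewards, terminations, truncations, infos = envs.step(actions)
    # one call stores num_envs transitions; infos is not used when
    # handle_timeout_termination=False
    rb.add(obs, next_obs, actions, rewards, terminations, infos)
    obs = next_obs

data = rb.sample(32)
print(data.observations.shape, data.actions.shape)   # (32, 3) and (32, 1)

Two consequences worth keeping in mind: in the SB3 versions CleanRL uses, the buffer divides the requested buffer_size across environments (roughly buffer_size // n_envs slots per env), and each iteration of the unchanged `for global_step in range(args.total_timesteps)` loop now performs num_envs environment steps while still doing one gradient update, so a run with num_envs > 1 collects proportionally more transitions per logged global_step.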
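
The new `if info is not None` guard in the episodic logging exists because, with the pre-1.0 gymnasium vector API that CleanRL pins, infos["final_info"] holds one entry per sub-environment and the entries for environments that did not finish an episode on that step are None. As in the original scripts, the `break` after logging still means at most one finished episode is logged per step. A small reproduction, continuing from the envs above:

obs, _ = envs.reset(seed=1)
for step in range(250):   # Pendulum-v1 truncates episodes after 200 steps
    actions = np.array([envs.single_action_space.sample() for _ in range(num_envs)])
    obs, rewards, terminations, truncations, infos = envs.step(actions)
    if "final_info" in infos:
        for info in infos["final_info"]:
            if info is not None:   # None for sub-environments still mid-episode
                print(f"step={step}, episodic_return={info['episode']['r']}")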
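
The switch from env.action_space to env.single_action_space in the actor's action_scale / action_bias buffers (in both TD3 and SAC) is required rather than cosmetic: for a vector env, envs.action_space is the batched space of shape (num_envs, act_dim), so buffers built from it would carry a leading num_envs dimension and fail to broadcast against replay-sampled action batches of shape (batch_size, act_dim) once num_envs > 1. A shape check, continuing from the envs above:

import torch

low, high = envs.single_action_space.low, envs.single_action_space.high
action_scale = torch.tensor((high - low) / 2.0, dtype=torch.float32)
action_bias = torch.tensor((high + low) / 2.0, dtype=torch.float32)
print(envs.action_space.shape)                # (2, 1): batched, depends on num_envs
print(envs.single_action_space.shape)         # (1,):   per-environment
print(action_scale.shape, action_bias.shape)  # (1,) each -- broadcasts over any batch size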
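
The remaining diff churn around the Q-target computation and the target-network updates is pre-commit re-wrapping only; the numerics are unchanged. For reference, the soft update those loops perform is theta_target <- tau * theta + (1 - tau) * theta_target. A self-contained sketch with a throwaway linear layer standing in for the actor/Q networks:

import copy

import torch
import torch.nn as nn

def soft_update(net: nn.Module, target: nn.Module, tau: float) -> None:
    # Polyak averaging: the same update the scripts apply to actor/qf1/qf2 targets.
    for param, target_param in zip(net.parameters(), target.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)

qf = nn.Linear(4, 1)
qf_target = copy.deepcopy(qf)
with torch.no_grad():
    qf.weight += 1.0          # pretend a gradient step moved the online network
before = (qf.weight - qf_target.weight).abs().max().item()
soft_update(qf, qf_target, tau=0.005)
after = (qf.weight - qf_target.weight).abs().max().item()
print(before, after)          # the gap shrinks by a factor of (1 - tau)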