Add TD3 and SAC support for multiple envs #481

Open · wants to merge 5 commits into master (showing changes from 2 commits)
48 changes: 38 additions & 10 deletions cleanrl/td3_continuous_action.py
@@ -47,6 +47,8 @@ class Args:
"""total timesteps of the experiments"""
learning_rate: float = 3e-4
"""the learning rate of the optimizer"""
num_envs: int = 2
Collaborator:
I would default this to 1

Author:
Sure. Probably also worth discussing how to handle total_timesteps with multiple environments.

Author:
@pseudo-rnd-thoughts It seems like in sb3 they do it like this:
Let's say total_timesteps is 100_000.
Then they actually run 100_000 * num_envs environment steps in total, because each iteration steps all num_envs environments at once.
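For reference, a minimal standalone sketch (not part of this diff) of that counting convention; the values of total_timesteps and num_envs are just illustrative:

```python
# Sketch of the sb3-style convention described above: the outer loop runs
# `total_timesteps` iterations, and each iteration steps every sub-environment
# once, so roughly total_timesteps * num_envs environment steps are collected.
total_timesteps = 100_000
num_envs = 2

env_steps_collected = 0
for global_step in range(total_timesteps):
    # envs.step(actions) would advance all `num_envs` environments by one step here
    env_steps_collected += num_envs

assert env_steps_collected == total_timesteps * num_envs  # 200_000 env steps
```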

"""the number of parallel game environments"""
buffer_size: int = int(1e6)
"""the replay memory buffer size"""
gamma: float = 0.99
@@ -85,7 +87,10 @@ def thunk():
class QNetwork(nn.Module):
def __init__(self, env):
super().__init__()
self.fc1 = nn.Linear(np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape), 256)
self.fc1 = nn.Linear(
np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape),
256,
)
self.fc2 = nn.Linear(256, 256)
self.fc3 = nn.Linear(256, 1)

@@ -105,10 +110,18 @@ def __init__(self, env):
self.fc_mu = nn.Linear(256, np.prod(env.single_action_space.shape))
# action rescaling
self.register_buffer(
"action_scale", torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32)
"action_scale",
torch.tensor(
(env.single_action_space.high - env.single_action_space.low) / 2.0,
dtype=torch.float32,
),
)
self.register_buffer(
"action_bias", torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32)
"action_bias",
torch.tensor(
(env.single_action_space.high + env.single_action_space.low) / 2.0,
dtype=torch.float32,
),
)

def forward(self, x):
@@ -157,7 +170,9 @@ def forward(self, x):
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
)
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

actor = Actor(envs).to(device)
@@ -178,6 +193,7 @@ def forward(self, x):
envs.single_observation_space,
envs.single_action_space,
device,
n_envs=args.num_envs,
handle_timeout_termination=False,
)
start_time = time.time()
@@ -200,10 +216,11 @@ def forward(self, x):
# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
for info in infos["final_info"]:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break
if info is not None:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
real_next_obs = next_obs.copy()
@@ -264,7 +281,11 @@ def forward(self, x):
writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
writer.add_scalar(
"charts/SPS",
int(global_step / (time.time() - start_time)),
global_step,
)

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
@@ -290,7 +311,14 @@ def forward(self, x):

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "TD3", f"runs/{run_name}", f"videos/{run_name}-eval")
push_to_hub(
args,
episodic_returns,
repo_id,
"TD3",
f"runs/{run_name}",
f"videos/{run_name}-eval",
)

envs.close()
writer.close()