Fixes MARL workflows for recording videos during training/inference #1596

Open · wants to merge 4 commits into base: main
@@ -58,6 +58,7 @@ def __init__(self, env: DirectMARLEnv) -> None:
self.cfg = self.env.cfg
self.sim = self.env.sim
self.scene = self.env.scene
self.render_mode = self.env.render_mode

self.single_observation_space = gym.spaces.Dict()
if self._state_as_observation:
@@ -126,7 +127,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn:
return obs, rewards, terminated, time_outs, extras

def render(self, recompute: bool = False) -> np.ndarray | None:
self.env.render(recompute)
return self.env.render(recompute)

def close(self) -> None:
self.env.close()
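
These two wrapper changes are what make recording work at all: gym.wrappers.RecordVideo checks the wrapped env's render_mode when it is constructed and stores whatever render() returns as a frame, so a wrapper that hides render_mode or drops the return value records nothing. A minimal sketch of that dependency, paraphrasing Gymnasium's behaviour rather than quoting it:

import gymnasium as gym


def can_record(env: gym.Env) -> bool:
    """Rough stand-in for the checks gym.wrappers.RecordVideo relies on."""
    # RecordVideo warns and disables itself unless the env advertises rgb_array rendering.
    if env.render_mode != "rgb_array":
        return False
    # Each captured frame is simply the return value of render(); None yields an empty video.
    return env.render() is not None
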
@@ -22,8 +22,12 @@
kwargs={
"env_cfg_entry_point": f"{__name__}.cart_double_pendulum_env:CartDoublePendulumEnvCfg",
"rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml",
"rsl_rl_cfg_entry_point": (
f"{agents.__name__}.rsl_rl_ppo_cfg:CartDoublePendulumPPORunnerCfg"
), # FIXME: Raises errors
"skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml",
"skrl_ippo_cfg_entry_point": f"{agents.__name__}:skrl_ippo_cfg.yaml",
"skrl_mappo_cfg_entry_point": f"{agents.__name__}:skrl_mappo_cfg.yaml",
"sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml", # FIXME: Raises errors
},
)
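
For context, each *_cfg_entry_point key added here is resolved by the matching workflow script through the task registry. A hedged sketch of that lookup (load_cfg_from_registry is the helper used elsewhere in omni.isaac.lab_tasks.utils; the task id below is an assumption, since the gym.register id sits outside this hunk):

from omni.isaac.lab_tasks.utils import load_cfg_from_registry

task = "Isaac-Cart-Double-Pendulum-Direct-v0"  # assumed id; not shown in this diff

# Resolves to the CartDoublePendulumPPORunnerCfg configuration added by this PR.
rsl_rl_cfg = load_cfg_from_registry(task, "rsl_rl_cfg_entry_point")
# Resolves to the parsed contents of sb3_ppo_cfg.yaml.
sb3_cfg = load_cfg_from_registry(task, "sb3_cfg_entry_point")
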
@@ -0,0 +1,41 @@
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from omni.isaac.lab.utils import configclass

from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import (
RslRlOnPolicyRunnerCfg,
RslRlPpoActorCriticCfg,
RslRlPpoAlgorithmCfg,
)


@configclass
class CartDoublePendulumPPORunnerCfg(RslRlOnPolicyRunnerCfg):
num_steps_per_env = 16
max_iterations = 150
save_interval = 50
experiment_name = "cart_double_pendulum_direct"
empirical_normalization = False
policy = RslRlPpoActorCriticCfg(
init_noise_std=1.0,
actor_hidden_dims=[32, 32],
critic_hidden_dims=[32, 32],
activation="elu",
)
algorithm = RslRlPpoAlgorithmCfg(
value_loss_coef=1.0,
use_clipped_value_loss=True,
clip_param=0.2,
entropy_coef=0.005,
num_learning_epochs=5,
num_mini_batches=4,
learning_rate=1.0e-3,
schedule="adaptive",
gamma=0.99,
lam=0.95,
desired_kl=0.01,
max_grad_norm=1.0,
)
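
This runner config feeds rsl_rl's OnPolicyRunner through the workflow script, roughly as sketched below (a simplified view of the existing rsl_rl train.py pattern; env is assumed to already be wrapped with RslRlVecEnvWrapper). Note the FIXME in the registration above: this wiring currently still raises errors for the MARL task.

from rsl_rl.runners import OnPolicyRunner

agent_cfg = CartDoublePendulumPPORunnerCfg()
# env is assumed to be the RslRlVecEnvWrapper-wrapped environment built by train.py.
runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir="logs", device=agent_cfg.device)
runner.learn(num_learning_iterations=agent_cfg.max_iterations, init_at_random_ep_len=True)
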
@@ -0,0 +1,21 @@
# Reference: https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/ppo.yml#L32
seed: 42

n_timesteps: !!float 1e6
policy: 'MlpPolicy'
n_steps: 16
batch_size: 4096
gae_lambda: 0.95
gamma: 0.99
n_epochs: 20
ent_coef: 0.01
learning_rate: !!float 3e-4
clip_range: !!float 0.2
policy_kwargs: "dict(
activation_fn=nn.ELU,
net_arch=[32, 32],
squash_output=False,
)"
vf_coef: 1.0
max_grad_norm: 1.0
device: "cuda:0"
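
Note that policy_kwargs is a quoted Python expression, not structured YAML; the sb3 workflow evaluates it before constructing the agent, roughly as below (process_sb3_cfg comes from omni.isaac.lab_tasks.utils.wrappers.sb3; treat the exact pipeline as an assumption, and task/env as coming from the workflow script):

from stable_baselines3 import PPO

from omni.isaac.lab_tasks.utils import load_cfg_from_registry
from omni.isaac.lab_tasks.utils.wrappers.sb3 import process_sb3_cfg

agent_cfg = load_cfg_from_registry(task, "sb3_cfg_entry_point")
agent_cfg = process_sb3_cfg(agent_cfg)      # eval()s the policy_kwargs string into a dict
policy_arch = agent_cfg.pop("policy")
n_timesteps = agent_cfg.pop("n_timesteps")
agent = PPO(policy_arch, env, verbose=1, **agent_cfg)
agent.learn(total_timesteps=n_timesteps)
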
@@ -76,5 +76,5 @@ agent:
# https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html
trainer:
class: SequentialTrainer
timesteps: 1600
timesteps: 4800
environment_info: log
@@ -78,5 +78,5 @@ agent:
# https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html
trainer:
class: SequentialTrainer
timesteps: 1600
timesteps: 4800
environment_info: log
@@ -76,5 +76,5 @@ agent:
# https://skrl.readthedocs.io/en/latest/api/trainers/sequential.html
trainer:
class: SequentialTrainer
timesteps: 1600
timesteps: 4800
environment_info: log
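
The same timestep bump is applied to the ppo, ippo, and mappo configs for this task. In skrl, timesteps is the number of simulation steps the trainer loops over, so going from 1600 to 4800 triples the training length. A minimal sketch of how the trainer section is consumed (standard skrl API, simplified; env and agent come from the skrl workflow script):

from skrl.trainers.torch import SequentialTrainer

# cfg mirrors the "trainer" section of the YAML above.
trainer = SequentialTrainer(env=env, agents=agent, cfg={"timesteps": 4800, "headless": True})
trainer.train()
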
9 changes: 5 additions & 4 deletions source/standalone/workflows/rl_games/play.py
@@ -94,6 +94,11 @@ def main():

# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)

# wrap for video recording
if args_cli.video:
video_kwargs = {
@@ -106,10 +111,6 @@ def main():
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)

# wrap around environment for rl-games
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions)

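The functional change in this file, and in the other workflow scripts below, is purely the ordering: the MARL-to-single-agent conversion now runs before gym.wrappers.RecordVideo, so the recorder wraps an env that already exposes the single-agent API and the render_mode/render() behaviour fixed earlier in this PR. The resulting wrapping order, sketched with the names used in this script:

env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array")     # 1. raw (possibly multi-agent) env
if isinstance(env.unwrapped, DirectMARLEnv):
    env = multi_agent_to_single_agent(env)                              # 2. collapse to single-agent first
if args_cli.video:
    env = gym.wrappers.RecordVideo(env, **video_kwargs)                 # 3. recorder now sees render_mode
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions)      # 4. RL-library wrapper last
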
9 changes: 5 additions & 4 deletions source/standalone/workflows/rl_games/train.py
@@ -129,6 +129,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen

# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)

# wrap for video recording
if args_cli.video:
video_kwargs = {
@@ -141,10 +146,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)

# wrap around environment for rl-games
env = RlGamesVecEnvWrapper(env, rl_device, clip_obs, clip_actions)

9 changes: 5 additions & 4 deletions source/standalone/workflows/rsl_rl/play.py
@@ -74,6 +74,11 @@ def main():

# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)

# wrap for video recording
if args_cli.video:
video_kwargs = {
@@ -86,10 +91,6 @@ def main():
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)

# wrap around environment for rsl-rl
env = RslRlVecEnvWrapper(env)

8 changes: 4 additions & 4 deletions source/standalone/workflows/rsl_rl/train.py
@@ -100,6 +100,10 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)

# save resume path before creating a new log_dir
if agent_cfg.resume:
resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
@@ -116,10 +120,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)

# wrap around environment for rsl-rl
env = RslRlVecEnvWrapper(env)

6 changes: 6 additions & 0 deletions source/standalone/workflows/sb3/play.py
@@ -48,6 +48,7 @@
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize

from omni.isaac.lab.envs import DirectMARLEnv, multi_agent_to_single_agent
from omni.isaac.lab.utils.dict import print_dict

import omni.isaac.lab_tasks # noqa: F401
@@ -82,6 +83,11 @@ def main():

# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)

# wrap for video recording
if args_cli.video:
video_kwargs = {
9 changes: 5 additions & 4 deletions source/standalone/workflows/sb3/train.py
@@ -104,6 +104,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen

# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)

# wrap for video recording
if args_cli.video:
video_kwargs = {
@@ -116,10 +121,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv):
env = multi_agent_to_single_agent(env)

# wrap around environment for stable baselines
env = Sb3VecEnvWrapper(env)

9 changes: 5 additions & 4 deletions source/standalone/workflows/skrl/play.py
@@ -116,6 +116,11 @@ def main():

# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv) and algorithm in ["ppo"]:
env = multi_agent_to_single_agent(env)

# wrap for video recording
if args_cli.video:
video_kwargs = {
@@ -128,10 +133,6 @@ def main():
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv) and algorithm in ["ppo"]:
env = multi_agent_to_single_agent(env)

# wrap around environment for skrl
env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="auto")`

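Unlike the other workflows, skrl converts the env only when the selected algorithm is single-agent PPO; for IPPO and MAPPO (whose configs are registered above) the DirectMARLEnv is kept and handled by the skrl wrapper directly. A sketch of the branch, with the algorithm value assumed to come from the script's --algorithm CLI flag:

algorithm = args_cli.algorithm.lower()  # assumed CLI flag: ppo, ippo, or mappo
if isinstance(env.unwrapped, DirectMARLEnv) and algorithm in ["ppo"]:
    # Plain PPO is single-agent: collapse the MARL env before any further wrapping.
    env = multi_agent_to_single_agent(env)
# For ippo/mappo the multi-agent env is passed straight to SkrlVecEnvWrapper.
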
9 changes: 5 additions & 4 deletions source/standalone/workflows/skrl/train.py
@@ -151,6 +151,11 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen

# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv) and algorithm in ["ppo"]:
env = multi_agent_to_single_agent(env)

# wrap for video recording
if args_cli.video:
video_kwargs = {
@@ -163,10 +168,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
print_dict(video_kwargs, nesting=4)
env = gym.wrappers.RecordVideo(env, **video_kwargs)

# convert to single-agent instance if required by the RL algorithm
if isinstance(env.unwrapped, DirectMARLEnv) and algorithm in ["ppo"]:
env = multi_agent_to_single_agent(env)

# wrap around environment for skrl
env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="auto")`
