Fixes SB-3 and RL-Games RL wrappers (#242)
# Description

This MR goes over the current implementations of the Stable-Baselines3 and
RL-Games wrappers. It fixes the wrapper implementations, as well as the
checkpoint loader, so that they work with the logging format used by these libraries.

The changes have been tested against the `Isaac-Cartpole-v0` environment
from MR #241.

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
Mayankm96 authored Nov 10, 2023
1 parent 13327b8 commit 5d44141
Showing 15 changed files with 161 additions and 61 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -28,6 +28,7 @@ extra_standard_library = [
"bpy",
"matplotlib",
"gymnasium",
"gym",
"scipy",
"hid",
"yaml",
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.5.1"
version = "0.5.2"

# Description
title = "ORBIT Environments"
11 changes: 11 additions & 0 deletions source/extensions/omni.isaac.orbit_tasks/docs/CHANGELOG.rst
@@ -1,6 +1,17 @@
Changelog
---------

0.5.2 (2023-11-08)
~~~~~~~~~~~~~~~~~~

Fixed
^^^^^

* Fixed the RL wrappers for Stable-Baselines3 and RL-Games. They now work with the most recent versions of these libraries.
* Fixed the :meth:`get_checkpoint_path` helper to allow intermediate sub-folders between the run directory and the
  checkpoint directory.


0.5.1 (2023-11-04)
~~~~~~~~~~~~~~~~~~

@@ -140,21 +140,26 @@ def parse_env_cfg(task_name: str, use_gpu: bool | None = None, num_envs: int | N


def get_checkpoint_path(
log_path: str, run_dir: str = "*", checkpoint: str = "*", sort_alphabetical: bool = True
log_path: str, run_dir: str = ".*", checkpoint: str = ".*", other_dirs: list[str] = None, sort_alpha: bool = True
) -> str:
"""Get path to the model checkpoint in input directory.
The checkpoint file is resolved as: <log_path>/<run_dir>/<checkpoint>.
If run_dir and checkpoint are regex expressions then the most recent (highest alphabetical order) run and checkpoint are selected.
The checkpoint file is resolved as: <log_path>/<run_dir>/<*other_dirs>/<checkpoint>, where the
:attr:`other_dirs` are intermediate folder names to concatenate. These cannot be regex expressions.
If :attr:`run_dir` and :attr:`checkpoint` are regex expressions then the most recent (highest alphabetical order)
run and checkpoint are selected. To disable this behavior, set the flag :attr:`sort_alpha` to False.
Args:
log_path: The log directory path to find models in.
run_dir: Regex expression for the name of the directory containing the run. Defaults to the most
run_dir: The regex expression for the name of the directory containing the run. Defaults to the most
recent directory created inside :obj:`log_dir`.
checkpoint: The model checkpoint file or directory name. Defaults to the most recent
other_dirs: The intermediate directories between the run directory and the checkpoint file. Defaults to
None, which implies that the checkpoint file is directly under the run directory.
checkpoint: The regex expression for the model checkpoint file. Defaults to the most recent
torch-model saved in the :obj:`run_dir` directory.
sort_alphabetical: Whether to sort the runs and checkpoints by alphabetical order. Defaults to True.
If False, the checkpoints are sorted by the last modified time.
sort_alpha: Whether to sort the runs by alphabetical order. Defaults to True.
If False, the folders in :attr:`run_dir` are sorted by the last modified time.
Raises:
ValueError: When no runs are found in the input directory.
@@ -173,12 +178,15 @@ def get_checkpoint_path(
os.path.join(log_path, run) for run in os.scandir(log_path) if run.is_dir() and re.match(run_dir, run.name)
]
# sort matched runs by alphabetical order (latest run should be last)
if sort_alphabetical:
if sort_alpha:
runs.sort()
else:
runs = sorted(runs, key=os.path.getmtime)
# create last run file path
run_path = runs[-1]
if other_dirs is not None:
run_path = os.path.join(runs[-1], *other_dirs)
else:
run_path = runs[-1]
except IndexError:
raise ValueError(f"No runs present in the directory: '{log_path}' match: '{run_dir}'.")

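
A minimal usage sketch of the updated helper (the import path is assumed from the workflow scripts; the log directory, run name, and checkpoint name below are hypothetical):

```python
import os

from omni.isaac.orbit_tasks.utils import get_checkpoint_path  # assumed import path

# hypothetical log directory produced by an RL-Games training run
log_root_path = os.path.abspath("logs/rl_games/cartpole")

# latest run and latest checkpoint, with the intermediate "nn" folder in between
resume_path = get_checkpoint_path(log_root_path, run_dir=".*", checkpoint=".*", other_dirs=["nn"])

# a specific run directory and a specific checkpoint file
resume_path = get_checkpoint_path(
    log_root_path, run_dir="2023-11-08_09-30-00", checkpoint="cartpole.pth", other_dirs=["nn"]
)
```
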
@@ -33,7 +33,8 @@

from __future__ import annotations

import gymnasium as gym
import gym.spaces # needed for rl-games incompatibility: https://github.com/Denys88/rl_games/issues/261
import gymnasium
import torch

from rl_games.common import env_configurations
@@ -61,13 +62,12 @@ class RlGamesVecEnvWrapper(IVecEnv):
observations. This dictionary contains "obs" and "states" which typically correspond
to the actor and critic observations respectively.
To use asymmetric actor-critic, the environment instance must have the attributes
To use asymmetric actor-critic, the environment observations from :class:`RLTaskEnv`
must have the key or group name "critic". The observation group is used to set the
:attr:`num_states` (int) and :attr:`state_space` (:obj:`gym.spaces.Box`). These are
used by the learning agent to allocate buffers in the trajectory memory. Additionally,
the method :meth:`_get_observations()` should have the key "critic" which corresponds
to the privileged observations. Since this is optional for some environments, the wrapper
checks if these attributes exist. If they don't then the wrapper defaults to zero as number
of privileged observations.
used by the learning agent in RL-Games to allocate buffers in the trajectory memory.
Since this is optional for some environments, the wrapper checks if these attributes exist.
If they don't then the wrapper defaults to zero as number of privileged observations.
.. caution::
@@ -104,19 +104,11 @@ def __init__(self, env: RLTaskEnv, rl_device: str, clip_obs: float, clip_actions
self._clip_obs = clip_obs
self._clip_actions = clip_actions
self._sim_device = env.unwrapped.device

# information about spaces for the wrapper
# note: rl-games only wants single observation and action spaces
self.rlg_observation_space = self.unwrapped.single_observation_space["policy"]
self.rlg_action_space = self.unwrapped.single_action_space
# information for privileged observations
self.rlg_state_space = self.unwrapped.single_observation_space.get("critic")
if self.rlg_state_space is not None:
if not isinstance(self.rlg_state_space, gym.spaces.Box):
raise ValueError(f"Privileged observations must be of type Box. Type: {type(self.rlg_state_space)}")
self.rlg_num_states = self.rlg_state_space.shape[0]
else:
if self.state_space is None:
self.rlg_num_states = 0
else:
self.rlg_num_states = self.state_space.shape[0]

def __str__(self):
"""Returns the wrapper name and the :attr:`env` representation string."""
@@ -142,14 +134,35 @@ def render_mode(self) -> str | None:
return self.env.render_mode

@property
def observation_space(self) -> gym.Space:
def observation_space(self) -> gym.spaces.Box:
"""Returns the :attr:`Env` :attr:`observation_space`."""
return self.env.observation_space
# note: rl-games only wants single observation space
policy_obs_space = self.unwrapped.single_observation_space["policy"]
if not isinstance(policy_obs_space, gymnasium.spaces.Box):
raise NotImplementedError(
f"The RL-Games wrapper does not currently support observation space: '{type(policy_obs_space)}'."
f" If you need to support this, please modify the wrapper: {self.__class__.__name__},"
" and if you are nice, please send a merge-request."
)
# note: maybe should check if we are a sub-set of the actual space. don't do it right now since
# in RLTaskEnv we are setting action space as (-inf, inf).
return gym.spaces.Box(-self._clip_obs, self._clip_obs, policy_obs_space.shape)

@property
def action_space(self) -> gym.Space:
"""Returns the :attr:`Env` :attr:`action_space`."""
return self.env.action_space
# note: rl-games only wants single action space
action_space = self.unwrapped.single_action_space
if not isinstance(action_space, gymnasium.spaces.Box):
raise NotImplementedError(
f"The RL-Games wrapper does not currently support action space: '{type(action_space)}'."
f" If you need to support this, please modify the wrapper: {self.__class__.__name__},"
" and if you are nice, please send a merge-request."
)
# return casted space in gym.spaces.Box (OpenAI Gym)
# note: maybe should check if we are a sub-set of the actual space. don't do it right now since
# in RLTaskEnv we are setting action space as (-inf, inf).
return gym.spaces.Box(-self._clip_actions, self._clip_actions, action_space.shape)

@classmethod
def class_name(cls) -> str:
@@ -168,16 +181,45 @@ def unwrapped(self) -> RLTaskEnv:
Properties
"""

@property
def num_envs(self) -> int:
"""Returns the number of sub-environment instances."""
return self.unwrapped.num_envs

@property
def device(self) -> str:
"""Returns the base environment simulation device."""
return self.unwrapped.device

@property
def state_space(self) -> gym.spaces.Box | None:
"""Returns the :attr:`Env` :attr:`observation_space`."""
# note: rl-games only wants single observation space
critic_obs_space = self.unwrapped.single_observation_space.get("critic")
# check if we even have a critic obs
if critic_obs_space is None:
return None
elif not isinstance(critic_obs_space, gymnasium.spaces.Box):
raise NotImplementedError(
f"The RL-Games wrapper does not currently support state space: '{type(critic_obs_space)}'."
f" If you need to support this, please modify the wrapper: {self.__class__.__name__},"
" and if you are nice, please send a merge-request."
)
# return casted space in gym.spaces.Box (OpenAI Gym)
# note: maybe should check if we are a sub-set of the actual space. don't do it right now since
# in RLTaskEnv we are setting action space as (-inf, inf).
return gym.spaces.Box(-self._clip_obs, self._clip_obs, critic_obs_space.shape)

def get_number_of_agents(self) -> int:
"""Returns number of actors in the environment."""
return getattr(self, "num_agents", 1)

def get_env_info(self) -> dict:
"""Returns the Gym spaces for the environment."""
return {
"observation_space": self.rlg_observation_space,
"action_space": self.rlg_action_space,
"state_space": self.rlg_state_space,
"observation_space": self.observation_space,
"action_space": self.action_space,
"state_space": self.state_space,
}

"""
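
As a usage sketch, the wrapper can be applied to an environment created from the registry. The import paths, clipping values, and the `cfg` keyword for environment creation follow the workflow scripts and are assumptions:

```python
import gymnasium

from omni.isaac.orbit_tasks.utils import parse_env_cfg  # assumed import path
from omni.isaac.orbit_tasks.utils.wrappers.rl_games import RlGamesVecEnvWrapper  # assumed import path

# parse the task configuration and create the environment
# (the simulation app must already be launched)
env_cfg = parse_env_cfg("Isaac-Cartpole-v0", use_gpu=True, num_envs=64)
env = gymnasium.make("Isaac-Cartpole-v0", cfg=env_cfg)

# wrap for rl-games; the clipping values are hypothetical
env = RlGamesVecEnvWrapper(env, rl_device="cuda:0", clip_obs=10.0, clip_actions=1.0)

# the spaces are now derived from the underlying RLTaskEnv and cast to clipped gym.spaces.Box
print(env.get_env_info())  # {"observation_space": ..., "action_space": ..., "state_space": ...}
```
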
@@ -17,10 +17,13 @@

from __future__ import annotations

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn # noqa: F401
from typing import Any

from stable_baselines3.common.utils import constant_fn
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn

from omni.isaac.orbit.envs import RLTaskEnv
@@ -44,16 +47,28 @@ def process_sb3_cfg(cfg: dict) -> dict:
Reference:
https://github.com/DLR-RM/rl-baselines3-zoo/blob/0e5eb145faefa33e7d79c7f8c179788574b20da5/utils/exp_manager.py#L358
"""
_direct_eval = ["policy_kwargs", "replay_buffer_class", "replay_buffer_kwargs"]

def update_dict(d):
for key, value in d.items():
def update_dict(hyperparams: dict[str, Any]) -> dict[str, Any]:
for key, value in hyperparams.items():
if isinstance(value, dict):
update_dict(value)
else:
if key in _direct_eval:
d[key] = eval(value)
return d
if key in ["policy_kwargs", "replay_buffer_class", "replay_buffer_kwargs"]:
hyperparams[key] = eval(value)
elif key in ["learning_rate", "clip_range", "clip_range_vf", "delta_std"]:
if isinstance(value, str):
_, initial_value = value.split("_")
initial_value = float(initial_value)
hyperparams[key] = lambda progress_remaining: progress_remaining * initial_value
elif isinstance(value, (float, int)):
# Negative value: ignore (ex: for clipping)
if value < 0:
continue
hyperparams[key] = constant_fn(float(value))
else:
raise ValueError(f"Invalid value for {key}: {hyperparams[key]}")

return hyperparams

# parse agent configuration and convert to classes
return update_dict(cfg)
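
A sketch of what the parser does to a raw agent configuration; the hyperparameter values are hypothetical and the import path is an assumption:

```python
from omni.isaac.orbit_tasks.utils.wrappers.sb3 import process_sb3_cfg  # assumed import path

# raw configuration as it would be read from a YAML file (values are hypothetical)
agent_cfg = {
    "policy": "MlpPolicy",
    "n_steps": 16,
    "learning_rate": "lin_0.003",  # "<prefix>_<value>" string -> lambda progress: progress * 0.003
    "clip_range": 0.2,             # positive float -> constant_fn(0.2)
    "policy_kwargs": "dict(activation_fn=nn.ELU, net_arch=[32, 32])",  # string -> eval'd into a dict
}

agent_cfg = process_sb3_cfg(agent_cfg)
# "learning_rate" and "clip_range" are now callables accepted by stable-baselines3,
# and "policy_kwargs" is a real dictionary
```
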
@@ -127,9 +142,14 @@ def __init__(self, env: RLTaskEnv):
self.num_envs = self.unwrapped.num_envs
self.sim_device = self.unwrapped.device
self.render_mode = self.unwrapped.render_mode
# initialize vec-env
# obtain gym spaces
# note: stable-baselines3 does not like when we have unbounded action space so
# we set it to some high value here. Maybe this is not general but something to think about.
observation_space = self.unwrapped.single_observation_space["policy"]
action_space = self.unwrapped.single_action_space
if isinstance(action_space, gym.spaces.Box) and action_space.is_bounded() != "both":
action_space = gym.spaces.Box(low=-100, high=100, shape=action_space.shape)
# initialize vec-env
VecEnv.__init__(self, self.num_envs, observation_space, action_space)
# add buffer for logging episodic information
self._ep_rew_buf = torch.zeros(self.num_envs, device=self.sim_device)
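
A sketch of training against the wrapped environment with Stable-Baselines3; the task, import paths, and PPO hyperparameters are the same kind of assumptions as above:

```python
import gymnasium
from stable_baselines3 import PPO

from omni.isaac.orbit_tasks.utils import parse_env_cfg  # assumed import path
from omni.isaac.orbit_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper  # assumed import path

# create the environment (the simulation app must already be launched)
env_cfg = parse_env_cfg("Isaac-Cartpole-v0", use_gpu=True, num_envs=64)
env = gymnasium.make("Isaac-Cartpole-v0", cfg=env_cfg)

# expose the stable-baselines3 VecEnv interface;
# unbounded Box action spaces are re-declared as [-100, 100] as noted above
env = Sb3VecEnvWrapper(env)

# hypothetical hyperparameters
agent = PPO("MlpPolicy", env, n_steps=16, verbose=1)
agent.learn(total_timesteps=10_000)
```
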
4 changes: 2 additions & 2 deletions source/extensions/omni.isaac.orbit_tasks/setup.py
@@ -34,8 +34,8 @@
# Extra dependencies for RL agents
EXTRAS_REQUIRE = {
"sb3": ["stable-baselines3>=2.0"],
"skrl": ["skrl>=0.10.0"],
"rl_games": ["rl-games==1.6.1"],
"skrl": ["skrl==0.10.0"],
"rl_games": ["rl-games==1.6.1", "gym"], # rl-games still needs gym :(
"rsl_rl": ["rsl_rl@git+https://github.com/leggedrobotics/rsl_rl.git"],
"robomimic": ["robomimic@git+https://github.com/ARISE-Initiative/robomimic.git"],
}
@@ -82,7 +82,7 @@ def test_random_actions(self):
with torch.inference_mode():
for _ in range(100):
# sample actions from -1 to 1
actions = 2 * torch.rand(env.action_space.shape, device=env.device) - 1
actions = 2 * torch.rand(env.num_envs, *env.action_space.shape, device=env.device) - 1
# apply actions
transition = env.step(actions)
# check signals
@@ -83,7 +83,7 @@ def test_random_actions(self):
with torch.inference_mode():
for _ in range(1000):
# sample actions from -1 to 1
actions = 2 * np.random.rand(env.num_envs, env.action_space.shape) - 1
actions = 2 * np.random.rand(env.num_envs, *env.action_space.shape) - 1
# apply actions
transition = env.step(actions)
# check signals
10 changes: 4 additions & 6 deletions source/standalone/workflows/rl_games/play.py
@@ -84,17 +84,15 @@ def main():
# find checkpoint
if args_cli.checkpoint is None:
# specify directory for logging runs
if "full_experiment_name" not in agent_cfg["params"]["config"]:
run_dir = os.path.join("*", "nn")
else:
run_dir = os.path.join(agent_cfg["params"]["config"]["full_experiment_name"], "nn")
run_dir = agent_cfg["params"]["config"].get("full_experiment_name", ".*")
# specify name of checkpoint
if args_cli.use_last_checkpoint:
checkpoint_file = None
checkpoint_file = ".*"
else:
# this loads the best checkpoint
checkpoint_file = f"{agent_cfg['params']['config']['name']}.pth"
# get path to previous checkpoint
resume_path = get_checkpoint_path(log_root_path, run_dir, checkpoint_file)
resume_path = get_checkpoint_path(log_root_path, run_dir, checkpoint_file, other_dirs=["nn"])
else:
resume_path = os.path.abspath(args_cli.checkpoint)
# load previously trained model
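
For reference, a sketch of the log-directory layout this resolution expects; the run and checkpoint file names are hypothetical:

```python
# logs/rl_games/cartpole/
# └── 2023-11-08_09-30-00/                     <- matched by `run_dir` (".*" or `full_experiment_name`)
#     └── nn/                                  <- supplied via other_dirs=["nn"]
#         ├── cartpole.pth                     <- "best" checkpoint, f"{name}.pth"
#         └── last_cartpole_ep_150_rew_250.pth <- periodic checkpoints, matched by ".*"
resume_path = get_checkpoint_path(log_root_path, run_dir, checkpoint_file, other_dirs=["nn"])
```
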
2 changes: 1 addition & 1 deletion source/standalone/workflows/rsl_rl/play.py
@@ -15,7 +15,7 @@
from omni.isaac.orbit.app import AppLauncher

# local imports
import source.standalone.workflows.rsl_rl.cli_args as cli_args # isort: skip
import cli_args # isort: skip

# add argparse arguments
parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.")
2 changes: 1 addition & 1 deletion source/standalone/workflows/rsl_rl/train.py
@@ -16,7 +16,7 @@
from omni.isaac.orbit.app import AppLauncher

# local imports
import source.standalone.workflows.rsl_rl.cli_args as cli_args # isort: skip
import cli_args # isort: skip


# add argparse arguments