Add vecenv wrappers for obs_norm to support running mujoco experiment with envpool (#628)

- add VectorEnvWrapper and VectorEnvNormObs
- obs_rms store in policy save/load
- align mujoco scripts with atari: obs_norm, envpool, wandb and README
Trinkle23897 authored May 5, 2022
1 parent a7c789f commit 2a7c151
Showing 27 changed files with 920 additions and 589 deletions.
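
For orientation before the per-file diffs: the core addition is the `VectorEnvNormObs` wrapper (built on the new `VectorEnvWrapper`), whose running observation statistics are now saved and restored together with the policy. The sketch below is a minimal illustration of the intended workflow, pieced together from the changed scripts in this commit and assuming both classes are exported from `tianshou.env`; the environment name and env counts are arbitrary.

```python
# Minimal sketch of the new observation-normalization workflow.
import gym

from tianshou.env import DummyVectorEnv, VectorEnvNormObs

# training envs update the running mean/std of observations
train_envs = VectorEnvNormObs(
    DummyVectorEnv([lambda: gym.make("Ant-v3") for _ in range(4)])
)
# test envs reuse the training statistics but never update them
test_envs = VectorEnvNormObs(
    DummyVectorEnv([lambda: gym.make("Ant-v3") for _ in range(2)]),
    update_obs_rms=False,
)
test_envs.set_obs_rms(train_envs.get_obs_rms())

obs = train_envs.reset()  # observations come back already normalized
print(obs.shape)

# Checkpoints in the updated scripts bundle these statistics with the model, e.g.
#   torch.save({"model": policy.state_dict(),
#               "obs_rms": train_envs.get_obs_rms()}, "policy.pth")
# so that a resumed run can call set_obs_rms() on both train and test envs.
```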
2 changes: 1 addition & 1 deletion .github/workflows/extra_sys.yml
@@ -28,4 +28,4 @@ jobs:
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
- name: Test with pytest
run: |
pytest test/base test/continuous --cov=tianshou --durations=0 -v
pytest test/base test/continuous --cov=tianshou --durations=0 -v --color=yes
2 changes: 1 addition & 1 deletion .github/workflows/gputest.yml
@@ -24,4 +24,4 @@ jobs:
- name: Test with pytest
# ignore test/throughput which only profiles the code
run: |
pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --durations=0 -v
pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --durations=0 -v --color=yes
2 changes: 1 addition & 1 deletion .github/workflows/profile.yml
@@ -20,4 +20,4 @@ jobs:
python -m pip install ".[dev]" --upgrade
- name: Test with pytest
run: |
pytest test/throughput --durations=0 -v
pytest test/throughput --durations=0 -v --color=yes
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
@@ -27,7 +27,7 @@ jobs:
- name: Test with pytest
# ignore test/throughput which only profiles the code
run: |
pytest test --ignore-glob='*profile.py' --ignore="test/3rd_party" --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v
pytest test --ignore-glob='*profile.py' --ignore="test/3rd_party" --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v --color=yes
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
with:
2 changes: 1 addition & 1 deletion Makefile
@@ -10,7 +10,7 @@ pytest:
$(call check_install, pytest)
$(call check_install, pytest_cov)
$(call check_install, pytest_xdist)
pytest test --cov ${PROJECT_PATH} --durations 0 -v --cov-report term-missing
pytest test --cov ${PROJECT_PATH} --durations 0 -v --cov-report term-missing --color=yes

mypy:
$(call check_install, mypy)
15 changes: 3 additions & 12 deletions README.md
@@ -4,16 +4,7 @@

---

[![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/)
[![Conda](https://img.shields.io/conda/vn/conda-forge/tianshou)](https://github.com/conda-forge/tianshou-feedstock)
[![Read the Docs](https://img.shields.io/readthedocs/tianshou)](https://tianshou.readthedocs.io/en/master)
[![Read the Docs](https://img.shields.io/readthedocs/tianshou-docs-zh-cn?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://tianshou.readthedocs.io/zh/master/)
[![Unittest](https://github.com/thu-ml/tianshou/workflows/Unittest/badge.svg?branch=master)](https://github.com/thu-ml/tianshou/actions)
[![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou)
[![GitHub issues](https://img.shields.io/github/issues/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/issues)
[![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers)
[![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network)
[![GitHub license](https://img.shields.io/github/license/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/blob/master/LICENSE)
[![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/) [![Conda](https://img.shields.io/conda/vn/conda-forge/tianshou)](https://github.com/conda-forge/tianshou-feedstock) [![Read the Docs](https://img.shields.io/readthedocs/tianshou)](https://tianshou.readthedocs.io/en/master) [![Read the Docs](https://img.shields.io/readthedocs/tianshou-docs-zh-cn?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://tianshou.readthedocs.io/zh/master/) [![Unittest](https://github.com/thu-ml/tianshou/workflows/Unittest/badge.svg?branch=master)](https://github.com/thu-ml/tianshou/actions) [![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou) [![GitHub issues](https://img.shields.io/github/issues/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/issues) [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) [![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network) [![GitHub license](https://img.shields.io/github/license/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/blob/master/LICENSE)

**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed modularized framework and pythonic API for building the deep reinforcement learning agent with the least number of lines of code. The supported interface algorithms currently include:

@@ -48,7 +39,7 @@
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
- [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf)

Here is Tianshou's other features:
Here are Tianshou's other features:

- Elegant framework, using only ~4000 lines of code
- State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms
@@ -132,7 +123,7 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma

<sup>(2): not all algorithms support this feature</sup>

<sup>(3): TQC and QR-DQN in [sb3-contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) instead of main repo</sup>

<sup>(4): super fast APPO!</sup>

21 changes: 17 additions & 4 deletions examples/mujoco/README.md
@@ -21,7 +21,19 @@ Supported algorithms are listed below:
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/6426a39796db052bafb7cabe85c764db20a722b0)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/5057b5c89e6168220272c9c28a15b758a72efc32)

#### Usage
## EnvPool

We highly recommend using envpool to run the following experiments. To install it on a Linux machine, run:

```bash
pip install envpool
```

After that, `make_mujoco_env` will automatically switch to envpool's Mujoco env. EnvPool's implementation is much faster than the Python vectorized env implementation (about 2\~3x faster in pure execution speed, and about 1.5x faster on average for the overall RL training pipeline), and its behavior is consistent with gym's Mujoco env.

For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/) and [Docs](https://envpool.readthedocs.io/en/latest/api/mujoco.html).
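
As a rough sketch of how the updated scripts wire this together (the helper lives in the `mujoco_env.py` module added alongside the training scripts; the task, seed, and env counts below are arbitrary):

```python
# Sketch of the environment setup used by the updated scripts: make_mujoco_env
# prefers envpool when it is installed and otherwise falls back to the regular
# gym-based vectorized envs.
from mujoco_env import make_mujoco_env

env, train_envs, test_envs = make_mujoco_env(
    "Ant-v3",       # task
    0,              # seed
    16,             # number of training envs
    10,             # number of test envs
    obs_norm=True,  # wrap the vector envs with VectorEnvNormObs
)
```

A typical run then looks like `python mujoco_a2c.py --task Ant-v3 --logger wandb --wandb-project mujoco.benchmark`, matching the argparse flags in the diff further below.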

## Usage

Run

@@ -46,23 +58,24 @@ This will start 10 experiments with different seeds.
Now that all the experiments are finished, we can convert all tfevent files into csv files and then try plotting the results.

```bash
# geenrate csv
# generate csv
$ ./tools.py --root-dir ./results/Ant-v3/sac
# generate figures
$ ./plotter.py --root-dir ./results/Ant-v3 --shaded-std --legend-pattern "\\w+"
# generate numerical result (support multiple groups: `--root-dir ./` instead of single dir)
$ ./analysis.py --root-dir ./results --norm
```

#### Example benchmark
## Example benchmark

<img src="./benchmark/Ant-v3/offpolicy.png" width="500" height="450">

Other graphs can be found under `/examples/mujuco/benchmark/`
Other graphs can be found under `examples/mujuco/benchmark/`

For pretrained agents, detailed graphs (single agent, single game) and log details, please refer to [https://cloud.tsinghua.edu.cn/d/f45fcfc5016043bc8fbc/](https://cloud.tsinghua.edu.cn/d/f45fcfc5016043bc8fbc/).

## Offpolicy algorithms

#### Notes

1. In off-policy algorithms (DDPG, TD3, SAC), the shared hyperparameters are almost the same and, unless otherwise stated, consistent with those used for the benchmark in SpinningUp's implementations (e.g. we use batch size 256 in DDPG/TD3/SAC while SpinningUp uses 100). Minor differences also lie in `start-timesteps`, the data-collection scheme (`step_per_collect`), the method used to deal with/bootstrap steps truncated by the time limit and unfinished/collecting episodes (which contributes to performance improvement), etc.
12 changes: 6 additions & 6 deletions examples/mujoco/analysis.py
@@ -9,8 +9,8 @@
from tools import csv2numpy, find_all_files, group_files


def numerical_anysis(root_dir, xlim, norm=False):
file_pattern = re.compile(r".*/test_rew_\d+seeds.csv$")
def numerical_analysis(root_dir, xlim, norm=False):
file_pattern = re.compile(r".*/test_reward_\d+seeds.csv$")
norm_group_pattern = re.compile(r"(/|^)\w+?\-v(\d|$)")
output_group_pattern = re.compile(r".*?(?=(/|^)\w+?\-v\d)")
csv_files = find_all_files(root_dir, file_pattern)
@@ -23,13 +23,13 @@ def numerical_anysis(root_dir, xlim, norm=False):
if norm:
result = np.stack(
[
result['env_step'], result['rew'] - result['rew'][0],
result['rew:shaded']
result['env_step'], result['reward'] - result['reward'][0],
result['reward:shaded']
]
)
else:
result = np.stack(
[result['env_step'], result['rew'], result['rew:shaded']]
[result['env_step'], result['reward'], result['reward:shaded']]
)

if result[0, -1] < xlim:
@@ -96,4 +96,4 @@ def numerical_anysis(root_dir, xlim, norm=False):
help="Normalize all results according to environment."
)
args = parser.parse_args()
numerical_anysis(args.root_dir, args.xlim, norm=args.norm)
numerical_analysis(args.root_dir, args.xlim, norm=args.norm)
4 changes: 2 additions & 2 deletions examples/mujoco/gen_json.py
@@ -22,8 +22,8 @@ def merge(rootdir):
result.append(
{
'env_step': int(row['env_step']),
'rew': float(row['rew']),
'rew_std': float(row['rew:shaded']),
'rew': float(row['reward']),
'rew_std': float(row['reward:shaded']),
'Agent': algo,
}
)
130 changes: 72 additions & 58 deletions examples/mujoco/mujoco_a2c.py
@@ -5,106 +5,101 @@
import os
import pprint

import gym
import numpy as np
import torch
from mujoco_env import make_mujoco_env
from torch import nn
from torch.distributions import Independent, Normal
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import A2CPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='HalfCheetah-v3')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=4096)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--lr', type=float, default=7e-4)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=30000)
parser.add_argument('--step-per-collect', type=int, default=80)
parser.add_argument('--repeat-per-collect', type=int, default=1)
parser.add_argument("--task", type=str, default="Ant-v3")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=4096)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
parser.add_argument("--lr", type=float, default=7e-4)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=30000)
parser.add_argument("--step-per-collect", type=int, default=80)
parser.add_argument("--repeat-per-collect", type=int, default=1)
# batch-size >> step-per-collect means calculating all data in one single forward.
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument("--batch-size", type=int, default=99999)
parser.add_argument("--training-num", type=int, default=16)
parser.add_argument("--test-num", type=int, default=10)
# a2c special
parser.add_argument('--rew-norm', type=int, default=True)
parser.add_argument('--vf-coef', type=float, default=0.5)
parser.add_argument('--ent-coef', type=float, default=0.01)
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--bound-action-method', type=str, default="clip")
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--rew-norm", type=int, default=True)
parser.add_argument("--vf-coef", type=float, default=0.5)
parser.add_argument("--ent-coef", type=float, default=0.01)
parser.add_argument("--gae-lambda", type=float, default=0.95)
parser.add_argument("--bound-action-method", type=str, default="clip")
parser.add_argument("--lr-decay", type=int, default=True)
parser.add_argument("--max-grad-norm", type=float, default=0.5)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
'--watch',
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark")
parser.add_argument(
"--watch",
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
action="store_true",
help="watch the play of pre-trained policy only",
)
return parser.parse_args()


def test_a2c(args=get_args()):
env = gym.make(args.task)
env, train_envs, test_envs = make_mujoco_env(
args.task, args.seed, args.training_num, args.test_num, obs_norm=True
)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True
)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
norm_obs=True,
obs_rms=train_envs.obs_rms,
update_obs_rms=False
)

# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net_a = Net(
args.state_shape,
hidden_sizes=args.hidden_sizes,
activation=nn.Tanh,
device=args.device
device=args.device,
)
actor = ActorProb(
net_a,
args.action_shape,
max_action=args.max_action,
unbounded=True,
device=args.device
device=args.device,
).to(args.device)
net_c = Net(
args.state_shape,
hidden_sizes=args.hidden_sizes,
activation=nn.Tanh,
device=args.device
device=args.device,
)
critic = Critic(net_c, device=args.device).to(args.device)
torch.nn.init.constant_(actor.sigma_param, -0.5)
@@ -125,7 +120,7 @@ def test_a2c(args=get_args()):
list(actor.parameters()) + list(critic.parameters()),
lr=args.lr,
eps=1e-5,
alpha=0.99
alpha=0.99,
)

lr_scheduler = None
@@ -156,12 +151,15 @@ def dist(*logits):
action_scaling=True,
action_bound_method=args.bound_action_method,
lr_scheduler=lr_scheduler,
action_space=env.action_space
action_space=env.action_space,
)

# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
ckpt = torch.load(args.resume_path, map_location=args.device)
policy.load_state_dict(ckpt["model"])
train_envs.set_obs_rms(ckpt["obs_rms"])
test_envs.set_obs_rms(ckpt["obs_rms"])
print("Loaded agent from: ", args.resume_path)

# collector
@@ -171,16 +169,32 @@ def dist(*logits):
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)

# log
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_a2c'
log_path = os.path.join(args.logdir, args.task, 'a2c', log_file)
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "a2c"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)

# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)

def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()}
torch.save(state, os.path.join(log_path, "policy.pth"))

if not args.watch:
# trainer
@@ -196,7 +210,7 @@ def save_best_fn(policy):
step_per_collect=args.step_per_collect,
save_best_fn=save_best_fn,
logger=logger,
test_in_train=False
test_in_train=False,
)
pprint.pprint(result)

@@ -208,5 +222,5 @@ def save_best_fn(policy):
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


if __name__ == '__main__':
if __name__ == "__main__":
test_a2c()
(Diffs for the remaining changed files are not shown here.)
