Skip to content

Commit

Permalink
Fix critic network for Discrete CRR (#485)
Browse files Browse the repository at this point in the history
- Fixes an inconsistency in the implementation of Discrete CRR. Now it uses `Critic` class for its critic, following conventions in other actor-critic policies;
- Updates several offline policies to use `ActorCritic` class for its optimizer to eliminate randomness caused by parameter sharing between actor and critic;
- Add `writer.flush()` in TensorboardLogger to ensure real-time result;
- Enable `test_collector=None` in 3 trainers to turn off testing during training;
- Updates the Atari offline results in README.md;
- Moves Atari offline RL examples to `examples/offline`; tests to `test/offline` per review comments.
  • Loading branch information
nuance1979 authored Nov 28, 2021
1 parent 5c5a3db commit 3592f45
Show file tree
Hide file tree
Showing 21 changed files with 458 additions and 257 deletions.
Empty file added examples/__init__.py
Empty file.
65 changes: 1 addition & 64 deletions examples/atari/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Atari General
# Atari

The sample speed is \~3000 env step per second (\~12000 Atari frame per second in fact since we use frame_stack=4) under the normal mode (use a CNN policy and a collector, also storing data into the buffer). The main bottleneck is training the convolutional neural network.

Expand Down Expand Up @@ -95,66 +95,3 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| MsPacmanNoFrameskip-v4 | 3101 | ![](results/rainbow/MsPacman_rew.png) | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 2126 | ![](results/rainbow/Seaquest_rew.png) | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 1794.5 | ![](results/rainbow/SpaceInvaders_rew.png) | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` |

# BCQ

To running BCQ algorithm on Atari, you need to do the following things:

- Train an expert, by using the command listed in the above DQN section;
- Generate buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error);
- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.

We test our BCQ implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):

| Task | Online DQN | Behavioral | BCQ |
| ---------------------- | ---------- | ---------- | --------------------------------- |
| PongNoFrameskip-v4 | 21 | 7.7 | 21 (epoch 5) |
| BreakoutNoFrameskip-v4 | 303 | 61 | 167.4 (epoch 12, could be higher) |

# CQL

To running CQL algorithm on Atari, you need to do the following things:

- Train an expert, by using the command listed in the above QRDQN section;
- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error);
- Train CQL: `python3 atari_cql.py --task {your_task} --load-buffer-name expert.hdf5`.

We test our CQL implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 19.5 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 248.3 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |

We reduce the size of the offline data to 10% and 1% of the above and get:

Buffer size 100000:

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |

Buffer size 10000:

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |

# CRR

To running CRR algorithm on Atari, you need to do the following things:

- Train an expert, by using the command listed in the above QRDQN section;
- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error);
- Train CQL: `python3 atari_crr.py --task {your_task} --load-buffer-name expert.hdf5`.

We test our CRR implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):

| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters |
| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 16.1 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 26.4 (epoch 12) | 125.0 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |

Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps.
Empty file added examples/atari/__init__.py
Empty file.
62 changes: 60 additions & 2 deletions examples/offline/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

In offline reinforcement learning setting, the agent learns a policy from a fixed dataset which is collected once with any policy. And the agent does not interact with environment anymore.

Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train offline agent. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.
## Continous control

## Train
Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train offline agent for continuous control. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.

### Train

Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer` , and set it as the parameter `buffer` of `offline_trainer`. `offline_bcq.py` is an example of offline RL using the d4rl dataset.

Expand All @@ -26,3 +28,59 @@ After 1M steps:
| --------------------- | --------------- |
| halfcheetah-expert-v1 | 10624.0 ± 181.4 |

## Discrete control

For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent. In the future, we can switch to better benchmarks such as the Atari portion of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged).

### Gather Data

To running CQL algorithm on Atari, you need to do the following things:

- Train an expert, by using the command listed in the QRDQN section of Atari examples: `python3 atari_qrdqn.py --task {your_task}`
- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error);
- Train offline model: `python3 atari_{bcq,cql,crr}.py --task {your_task} --load-buffer-name expert.hdf5`.

### BCQ

We test our BCQ implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):

| Task | Online QRDQN | Behavioral | BCQ | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.1 (epoch 5) | `python3 atari_bcq.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 64.6 (epoch 12, could be higher) | `python3 atari_bcq.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` |

### CQL

We test our CQL implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.4 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 129.4 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |

We reduce the size of the offline data to 10% and 1% of the above and get:

Buffer size 100000:

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |

Buffer size 10000:

| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |

### CRR

We test our CRR implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):

| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters |
| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 17.7 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |

Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps.
Empty file added examples/offline/__init__.py
Empty file.
16 changes: 8 additions & 8 deletions examples/atari/atari_bcq.py → examples/offline/atari_bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@

import numpy as np
import torch
from atari_network import DQN
from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter

from examples.atari.atari_network import DQN
from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor


Expand Down Expand Up @@ -93,18 +94,17 @@ def test_discrete_bcq(args=get_args()):
args.action_shape,
device=args.device,
hidden_sizes=args.hidden_sizes,
softmax_output=False
softmax_output=False,
).to(args.device)
imitation_net = Actor(
feature_net,
args.action_shape,
device=args.device,
hidden_sizes=args.hidden_sizes,
softmax_output=False
softmax_output=False,
).to(args.device)
optim = torch.optim.Adam(
list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr
)
actor_critic = ActorCritic(policy_net, imitation_net)
optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
# define policy
policy = DiscreteBCQPolicy(
policy_net, imitation_net, optim, args.gamma, args.n_step,
Expand Down Expand Up @@ -171,7 +171,7 @@ def watch():
args.batch_size,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger
logger=logger,
)

pprint.pprint(result)
Expand Down
8 changes: 4 additions & 4 deletions examples/atari/atari_cql.py → examples/offline/atari_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

import numpy as np
import torch
from atari_network import QRDQN
from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter

from examples.atari.atari_network import QRDQN
from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteCQLPolicy
Expand Down Expand Up @@ -94,7 +94,7 @@ def test_discrete_cql(args=get_args()):
args.num_quantiles,
args.n_step,
args.target_update_freq,
min_q_weight=args.min_q_weight
min_q_weight=args.min_q_weight,
).to(args.device)
# load a previous policy
if args.resume_path:
Expand Down Expand Up @@ -156,7 +156,7 @@ def watch():
args.batch_size,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger
logger=logger,
)

pprint.pprint(result)
Expand Down
26 changes: 15 additions & 11 deletions examples/atari/atari_crr.py → examples/offline/atari_crr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@

import numpy as np
import torch
from atari_network import DQN
from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter

from examples.atari.atari_network import DQN
from examples.atari.atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DiscreteCRRPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.discrete import Actor
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor, Critic


def get_args():
Expand Down Expand Up @@ -91,15 +92,18 @@ def test_discrete_crr(args=get_args()):
actor = Actor(
feature_net,
args.action_shape,
hidden_sizes=args.hidden_sizes,
device=args.device,
softmax_output=False,
).to(args.device)
critic = Critic(
feature_net,
hidden_sizes=args.hidden_sizes,
softmax_output=False
last_size=np.prod(args.action_shape),
device=args.device,
).to(args.device)
critic = DQN(*args.state_shape, args.action_shape,
device=args.device).to(args.device)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr
)
actor_critic = ActorCritic(actor, critic)
optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
# define policy
policy = DiscreteCRRPolicy(
actor,
Expand All @@ -110,7 +114,7 @@ def test_discrete_crr(args=get_args()):
ratio_upper_bound=args.ratio_upper_bound,
beta=args.beta,
min_q_weight=args.min_q_weight,
target_update_freq=args.target_update_freq
target_update_freq=args.target_update_freq,
).to(args.device)
# load a previous policy
if args.resume_path:
Expand Down Expand Up @@ -171,7 +175,7 @@ def watch():
args.batch_size,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger
logger=logger,
)

pprint.pprint(result)
Expand Down
18 changes: 3 additions & 15 deletions test/discrete/test_dqn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
import os
import pickle
import pprint

import gym
Expand Down Expand Up @@ -42,9 +41,6 @@ def get_args():
parser.add_argument('--prioritized-replay', action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
'--save-buffer-name', type=str, default="./expert_DQN_CartPole-v0.pkl"
)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
Expand Down Expand Up @@ -85,15 +81,15 @@ def test_dqn(args=get_args()):
optim,
args.gamma,
args.n_step,
target_update_freq=args.target_update_freq
target_update_freq=args.target_update_freq,
)
# buffer
if args.prioritized_replay:
buf = PrioritizedVectorReplayBuffer(
args.buffer_size,
buffer_num=len(train_envs),
alpha=args.alpha,
beta=args.beta
beta=args.beta,
)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
Expand Down Expand Up @@ -142,7 +138,7 @@ def test_fn(epoch, env_step):
test_fn=test_fn,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger
logger=logger,
)
assert stop_fn(result['best_reward'])

Expand All @@ -157,14 +153,6 @@ def test_fn(epoch, env_step):
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")

# save buffer in pickle format, for imitation learning unittest
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
policy.set_eps(0.2)
collector = Collector(policy, test_envs, buf, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
pickle.dump(buf, open(args.save_buffer_name, "wb"))
print(result["rews"].mean())


def test_pdqn(args=get_args()):
args.prioritized_replay = True
Expand Down
Loading

0 comments on commit 3592f45

Please sign in to comment.