From 3592f45446e6cc98423df2f1c28d8ca0ef2be821 Mon Sep 17 00:00:00 2001
From: Yi Su
Date: Sun, 28 Nov 2021 07:10:28 -0800
Subject: [PATCH] Fix critic network for Discrete CRR (#485)

- Fixes an inconsistency in the implementation of Discrete CRR. It now uses the `Critic` class for its critic, following the conventions of other actor-critic policies;
- Updates several offline policies to use the `ActorCritic` class for their optimizers, to eliminate randomness caused by parameter sharing between actor and critic;
- Adds `writer.flush()` in TensorboardLogger to ensure real-time results;
- Enables `test_collector=None` in 3 trainers to turn off testing during training;
- Updates the Atari offline results in README.md;
- Moves Atari offline RL examples to `examples/offline` and tests to `test/offline`, per review comments.
---
 examples/__init__.py | 0
 examples/atari/README.md | 65 +------
 examples/atari/__init__.py | 0
 examples/offline/README.md | 62 ++++++-
 examples/offline/__init__.py | 0
 examples/{atari => offline}/atari_bcq.py | 16 +-
 examples/{atari => offline}/atari_cql.py | 8 +-
 examples/{atari => offline}/atari_crr.py | 26 +--
 test/discrete/test_dqn.py | 18 +-
 test/discrete/test_qrdqn.py | 20 +--
 test/offline/gather_cartpole_data.py | 160 ++++++++++++++++++
 .../test_discrete_bcq.py} | 37 ++--
 .../test_discrete_cql.py} | 12 +-
 .../test_discrete_crr.py} | 40 +++--
 tianshou/__init__.py | 2 +-
 tianshou/policy/imitation/discrete_crr.py | 4 +-
 tianshou/trainer/offline.py | 65 ++++---
 tianshou/trainer/offpolicy.py | 74 ++++----
 tianshou/trainer/onpolicy.py | 74 ++++----
 tianshou/trainer/utils.py | 31 ++--
 tianshou/utils/logger/tensorboard.py | 1 +
 21 files changed, 458 insertions(+), 257 deletions(-)
 create mode 100644 examples/__init__.py
 create mode 100644 examples/atari/__init__.py
 create mode 100644 examples/offline/__init__.py
 rename examples/{atari => offline}/atari_bcq.py (94%)
 rename examples/{atari => offline}/atari_cql.py (97%)
 rename examples/{atari => offline}/atari_crr.py (90%)
 create mode 100644 test/offline/gather_cartpole_data.py
 rename test/{discrete/test_il_bcq.py => offline/test_discrete_bcq.py} (84%)
 rename test/{discrete/test_qrdqn_il_cql.py => offline/test_discrete_cql.py} (92%)
 rename test/{discrete/test_il_crr.py => offline/test_discrete_crr.py} (79%)

diff --git a/examples/__init__.py b/examples/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/atari/README.md b/examples/atari/README.md
index ffccecba3..51b1af931 100644
--- a/examples/atari/README.md
+++ b/examples/atari/README.md
@@ -1,4 +1,4 @@
-# Atari General
+# Atari
 The sample speed is \~3000 env step per second (\~12000 Atari frame per second in fact since we use frame_stack=4) under the normal mode (use a CNN policy and a collector, also storing data into the buffer). The main bottleneck is training the convolutional neural network.
@@ -95,66 +95,3 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| MsPacmanNoFrameskip-v4 | 3101 | ![](results/rainbow/MsPacman_rew.png) | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"` | | SeaquestNoFrameskip-v4 | 2126 | ![](results/rainbow/Seaquest_rew.png) | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"` | | SpaceInvadersNoFrameskip-v4 | 1794.5 | ![](results/rainbow/SpaceInvaders_rew.png) | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` | - -# BCQ - -To running BCQ algorithm on Atari, you need to do the following things: - -- Train an expert, by using the command listed in the above DQN section; -- Generate buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error); -- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`. - -We test our BCQ implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step): - -| Task | Online DQN | Behavioral | BCQ | -| ---------------------- | ---------- | ---------- | --------------------------------- | -| PongNoFrameskip-v4 | 21 | 7.7 | 21 (epoch 5) | -| BreakoutNoFrameskip-v4 | 303 | 61 | 167.4 (epoch 12, could be higher) | - -# CQL - -To running CQL algorithm on Atari, you need to do the following things: - -- Train an expert, by using the command listed in the above QRDQN section; -- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error); -- Train CQL: `python3 atari_cql.py --task {your_task} --load-buffer-name expert.hdf5`. 
- -We test our CQL implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step): - -| Task | Online QRDQN | Behavioral | CQL | parameters | -| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.5 | 6.8 | 19.5 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | -| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 248.3 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | - -We reduce the size of the offline data to 10% and 1% of the above and get: - -Buffer size 100000: - -| Task | Online QRDQN | Behavioral | CQL | parameters | -| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` | -| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` | - -Buffer size 10000: - -| Task | Online QRDQN | Behavioral | CQL | parameters | -| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` | -| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` | - -# CRR - -To running CRR algorithm on Atari, you need to do the following things: - -- Train an expert, by using the command listed in the above QRDQN section; -- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error); -- Train CQL: `python3 atari_crr.py --task {your_task} --load-buffer-name expert.hdf5`. 
-
-We test our CRR implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):
-
-| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters |
-| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ |
-| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 16.1 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
-| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 26.4 (epoch 12) | 125.0 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
-
-Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps.
diff --git a/examples/atari/__init__.py b/examples/atari/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/offline/README.md b/examples/offline/README.md
index 8995ee6e2..c0a07fab0 100644
--- a/examples/offline/README.md
+++ b/examples/offline/README.md
@@ -2,9 +2,11 @@ In offline reinforcement learning setting, the agent learns a policy from a fixed dataset which is collected once with any policy. And the agent does not interact with environment anymore.
-Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train offline agent. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.
+## Continuous control
-## Train
+Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train an offline agent for continuous control. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.
+
+### Train
 Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer` , and set it as the parameter `buffer` of `offline_trainer`. `offline_bcq.py` is an example of offline RL using the d4rl dataset.
@@ -26,3 +28,59 @@ After 1M steps:
 | --------------------- | --------------- |
 | halfcheetah-expert-v1 | 10624.0 ± 181.4 |
+## Discrete control
+
+For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent. In the future, we may switch to better benchmarks such as the Atari portion of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged).
+
+### Gather Data
+
+To run the offline algorithms (BCQ, CQL, CRR) on Atari, you need to do the following things:
+
+- Train an expert by using the command listed in the QRDQN section of the Atari examples: `python3 atari_qrdqn.py --task {your_task}`
+- Generate a buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that the 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
+- Train the offline model: `python3 atari_{bcq,cql,crr}.py --task {your_task} --load-buffer-name expert.hdf5`. 
+ +### BCQ + +We test our BCQ implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step): + +| Task | Online QRDQN | Behavioral | BCQ | parameters | +| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.1 (epoch 5) | `python3 atari_bcq.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 64.6 (epoch 12, could be higher) | `python3 atari_bcq.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` | + +### CQL + +We test our CQL implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step): + +| Task | Online QRDQN | Behavioral | CQL | parameters | +| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.4 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 129.4 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | + +We reduce the size of the offline data to 10% and 1% of the above and get: + +Buffer size 100000: + +| Task | Online QRDQN | Behavioral | CQL | parameters | +| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` | + +Buffer size 10000: + +| Task | Online QRDQN | Behavioral | CQL | parameters | +| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` | +| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` | + +### CRR + +We test our CRR implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step): + +| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters | +| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 17.7 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) 
| `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | + +Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps. diff --git a/examples/offline/__init__.py b/examples/offline/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/atari/atari_bcq.py b/examples/offline/atari_bcq.py similarity index 94% rename from examples/atari/atari_bcq.py rename to examples/offline/atari_bcq.py index ec89243b4..83865d18d 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -6,15 +6,16 @@ import numpy as np import torch -from atari_network import DQN -from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter +from examples.atari.atari_network import DQN +from examples.atari.atari_wrapper import wrap_deepmind from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import ShmemVectorEnv from tianshou.policy import DiscreteBCQPolicy from tianshou.trainer import offline_trainer from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor @@ -93,18 +94,17 @@ def test_discrete_bcq(args=get_args()): args.action_shape, device=args.device, hidden_sizes=args.hidden_sizes, - softmax_output=False + softmax_output=False, ).to(args.device) imitation_net = Actor( feature_net, args.action_shape, device=args.device, hidden_sizes=args.hidden_sizes, - softmax_output=False + softmax_output=False, ).to(args.device) - optim = torch.optim.Adam( - list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr - ) + actor_critic = ActorCritic(policy_net, imitation_net) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) # define policy policy = DiscreteBCQPolicy( policy_net, imitation_net, optim, args.gamma, args.n_step, @@ -171,7 +171,7 @@ def watch(): args.batch_size, stop_fn=stop_fn, save_fn=save_fn, - logger=logger + logger=logger, ) pprint.pprint(result) diff --git a/examples/atari/atari_cql.py b/examples/offline/atari_cql.py similarity index 97% rename from examples/atari/atari_cql.py rename to examples/offline/atari_cql.py index 685e006db..22ef7b253 100644 --- a/examples/atari/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -6,10 +6,10 @@ import numpy as np import torch -from atari_network import QRDQN -from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter +from examples.atari.atari_network import QRDQN +from examples.atari.atari_wrapper import wrap_deepmind from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import ShmemVectorEnv from tianshou.policy import DiscreteCQLPolicy @@ -94,7 +94,7 @@ def test_discrete_cql(args=get_args()): args.num_quantiles, args.n_step, args.target_update_freq, - min_q_weight=args.min_q_weight + min_q_weight=args.min_q_weight, ).to(args.device) # load a previous policy if args.resume_path: @@ -156,7 +156,7 @@ def watch(): args.batch_size, stop_fn=stop_fn, save_fn=save_fn, - logger=logger + logger=logger, ) pprint.pprint(result) diff --git a/examples/atari/atari_crr.py b/examples/offline/atari_crr.py similarity index 90% rename from examples/atari/atari_crr.py rename to examples/offline/atari_crr.py index 8905c7e58..0214bf2f5 100644 --- a/examples/atari/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -6,16 +6,17 @@ import numpy as np import torch -from atari_network import DQN -from 
atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter +from examples.atari.atari_network import DQN +from examples.atari.atari_wrapper import wrap_deepmind from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import ShmemVectorEnv from tianshou.policy import DiscreteCRRPolicy from tianshou.trainer import offline_trainer from tianshou.utils import TensorboardLogger -from tianshou.utils.net.discrete import Actor +from tianshou.utils.net.common import ActorCritic +from tianshou.utils.net.discrete import Actor, Critic def get_args(): @@ -91,15 +92,18 @@ def test_discrete_crr(args=get_args()): actor = Actor( feature_net, args.action_shape, + hidden_sizes=args.hidden_sizes, device=args.device, + softmax_output=False, + ).to(args.device) + critic = Critic( + feature_net, hidden_sizes=args.hidden_sizes, - softmax_output=False + last_size=np.prod(args.action_shape), + device=args.device, ).to(args.device) - critic = DQN(*args.state_shape, args.action_shape, - device=args.device).to(args.device) - optim = torch.optim.Adam( - list(actor.parameters()) + list(critic.parameters()), lr=args.lr - ) + actor_critic = ActorCritic(actor, critic) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) # define policy policy = DiscreteCRRPolicy( actor, @@ -110,7 +114,7 @@ def test_discrete_crr(args=get_args()): ratio_upper_bound=args.ratio_upper_bound, beta=args.beta, min_q_weight=args.min_q_weight, - target_update_freq=args.target_update_freq + target_update_freq=args.target_update_freq, ).to(args.device) # load a previous policy if args.resume_path: @@ -171,7 +175,7 @@ def watch(): args.batch_size, stop_fn=stop_fn, save_fn=save_fn, - logger=logger + logger=logger, ) pprint.pprint(result) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 6912a1933..c02866493 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -1,6 +1,5 @@ import argparse import os -import pickle import pprint import gym @@ -42,9 +41,6 @@ def get_args(): parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) - parser.add_argument( - '--save-buffer-name', type=str, default="./expert_DQN_CartPole-v0.pkl" - ) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' ) @@ -85,7 +81,7 @@ def test_dqn(args=get_args()): optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq + target_update_freq=args.target_update_freq, ) # buffer if args.prioritized_replay: @@ -93,7 +89,7 @@ def test_dqn(args=get_args()): args.buffer_size, buffer_num=len(train_envs), alpha=args.alpha, - beta=args.beta + beta=args.beta, ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) @@ -142,7 +138,7 @@ def test_fn(epoch, env_step): test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, - logger=logger + logger=logger, ) assert stop_fn(result['best_reward']) @@ -157,14 +153,6 @@ def test_fn(epoch, env_step): rews, lens = result["rews"], result["lens"] print(f"Final reward: {rews.mean()}, length: {lens.mean()}") - # save buffer in pickle format, for imitation learning unittest - buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) - policy.set_eps(0.2) - collector = Collector(policy, test_envs, buf, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) - pickle.dump(buf, open(args.save_buffer_name, "wb")) - 
print(result["rews"].mean()) - def test_pdqn(args=get_args()): args.prioritized_replay = True diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index cf8d22212..956cb03fd 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -1,6 +1,5 @@ import argparse import os -import pickle import pprint import gym @@ -43,9 +42,6 @@ def get_args(): parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) - parser.add_argument( - '--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl" - ) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' ) @@ -80,7 +76,7 @@ def test_qrdqn(args=get_args()): hidden_sizes=args.hidden_sizes, device=args.device, softmax=False, - num_atoms=args.num_quantiles + num_atoms=args.num_quantiles, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = QRDQNPolicy( @@ -89,7 +85,7 @@ def test_qrdqn(args=get_args()): args.gamma, args.num_quantiles, args.n_step, - target_update_freq=args.target_update_freq + target_update_freq=args.target_update_freq, ).to(args.device) # buffer if args.prioritized_replay: @@ -97,7 +93,7 @@ def test_qrdqn(args=get_args()): args.buffer_size, buffer_num=len(train_envs), alpha=args.alpha, - beta=args.beta + beta=args.beta, ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) @@ -146,7 +142,7 @@ def test_fn(epoch, env_step): stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step + update_per_step=args.update_per_step, ) assert stop_fn(result['best_reward']) @@ -161,14 +157,6 @@ def test_fn(epoch, env_step): rews, lens = result["rews"], result["lens"] print(f"Final reward: {rews.mean()}, length: {lens.mean()}") - # save buffer in pickle format, for imitation learning unittest - buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) - policy.set_eps(0.9) # 10% of expert data as demonstrated in the original paper - collector = Collector(policy, test_envs, buf, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) - pickle.dump(buf, open(args.save_buffer_name, "wb")) - print(result["rews"].mean()) - def test_pqrdqn(args=get_args()): args.prioritized_replay = True diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py new file mode 100644 index 000000000..78d85170d --- /dev/null +++ b/test/offline/gather_cartpole_data.py @@ -0,0 +1,160 @@ +import argparse +import os +import pickle + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.policy import QRDQNPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--seed', type=int, default=1) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.9) + 
parser.add_argument('--num-quantiles', type=int, default=200) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--epoch', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + parser.add_argument('--prioritized-replay', action="store_true", default=False) + parser.add_argument('--alpha', type=float, default=0.6) + parser.add_argument('--beta', type=float, default=0.4) + parser.add_argument( + '--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl" + ) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + args = parser.parse_known_args()[0] + return args + + +def gather_data(): + args = get_args() + env = gym.make(args.task) + if args.task == 'CartPole-v0': + env.spec.reward_threshold = 190 # lower the goal + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # train_envs = gym.make(args.task) + # you can also use tianshou.env.SubprocVectorEnv + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=False, + num_atoms=args.num_quantiles, + ) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + policy = QRDQNPolicy( + net, + optim, + args.gamma, + args.num_quantiles, + args.n_step, + target_update_freq=args.target_update_freq, + ).to(args.device) + # buffer + if args.prioritized_replay: + buf = PrioritizedVectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta, + ) + else: + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector + train_collector = Collector(policy, train_envs, buf, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size * args.training_num) + # log + log_path = os.path.join(args.logdir, args.task, 'qrdqn') + writer = SummaryWriter(log_path) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def train_fn(epoch, env_step): + # eps annnealing, just a demo + if env_step <= 10000: + policy.set_eps(args.eps_train) + elif env_step <= 50000: + eps = args.eps_train - (env_step - 10000) / \ + 40000 * (0.9 * args.eps_train) + policy.set_eps(eps) + else: + policy.set_eps(0.1 * args.eps_train) 
+ + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + ) + assert stop_fn(result['best_reward']) + + # save buffer in pickle format, for imitation learning unittest + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) + policy.set_eps(0.2) + collector = Collector(policy, test_envs, buf, exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + pickle.dump(buf, open(args.save_buffer_name, "wb")) + print(result["rews"].mean()) + return buf diff --git a/test/discrete/test_il_bcq.py b/test/offline/test_discrete_bcq.py similarity index 84% rename from test/discrete/test_il_bcq.py rename to test/offline/test_discrete_bcq.py index 47540dadd..460ddb304 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -13,7 +13,13 @@ from tianshou.policy import DiscreteBCQPolicy from tianshou.trainer import offline_trainer from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.discrete import Actor + +if __name__ == "__main__": + from gather_cartpole_data import gather_data +else: # pytest + from test.offline.gather_cartpole_data import gather_data def get_args(): @@ -37,7 +43,7 @@ def get_args(): parser.add_argument( "--load-buffer-name", type=str, - default="./expert_DQN_CartPole-v0.pkl", + default="./expert_QRDQN_CartPole-v0.pkl", ) parser.add_argument( "--device", @@ -65,21 +71,15 @@ def test_discrete_bcq(args=get_args()): torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - policy_net = Net( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device + net = Net(args.state_shape, args.hidden_sizes[0], device=args.device) + policy_net = Actor( + net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device ).to(args.device) - imitation_net = Net( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device + imitation_net = Actor( + net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device ).to(args.device) - optim = torch.optim.Adam( - list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr - ) + actor_critic = ActorCritic(policy_net, imitation_net) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) policy = DiscreteBCQPolicy( policy_net, @@ -93,9 +93,10 @@ def test_discrete_bcq(args=get_args()): args.imitation_logits_penalty, ) # buffer - assert os.path.exists(args.load_buffer_name), \ - "Please run test_dqn.py first to get expert's data buffer." 
- buffer = pickle.load(open(args.load_buffer_name, "rb")) + if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): + buffer = pickle.load(open(args.load_buffer_name, "rb")) + else: + buffer = gather_data() # collector test_collector = Collector(policy, test_envs, exploration_noise=True) diff --git a/test/discrete/test_qrdqn_il_cql.py b/test/offline/test_discrete_cql.py similarity index 92% rename from test/discrete/test_qrdqn_il_cql.py rename to test/offline/test_discrete_cql.py index 01b868f13..c97f45628 100644 --- a/test/discrete/test_qrdqn_il_cql.py +++ b/test/offline/test_discrete_cql.py @@ -15,6 +15,11 @@ from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net +if __name__ == "__main__": + from gather_cartpole_data import gather_data +else: # pytest + from test.offline.gather_cartpole_data import gather_data + def get_args(): parser = argparse.ArgumentParser() @@ -83,9 +88,10 @@ def test_discrete_cql(args=get_args()): min_q_weight=args.min_q_weight ).to(args.device) # buffer - assert os.path.exists(args.load_buffer_name), \ - "Please run test_qrdqn.py first to get expert's data buffer." - buffer = pickle.load(open(args.load_buffer_name, "rb")) + if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): + buffer = pickle.load(open(args.load_buffer_name, "rb")) + else: + buffer = gather_data() # collector test_collector = Collector(policy, test_envs, exploration_noise=True) diff --git a/test/discrete/test_il_crr.py b/test/offline/test_discrete_crr.py similarity index 79% rename from test/discrete/test_il_crr.py rename to test/offline/test_discrete_crr.py index 929469e8b..0b4e7c63d 100644 --- a/test/discrete/test_il_crr.py +++ b/test/offline/test_discrete_crr.py @@ -13,7 +13,13 @@ from tianshou.policy import DiscreteCRRPolicy from tianshou.trainer import offline_trainer from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.discrete import Actor, Critic + +if __name__ == "__main__": + from gather_cartpole_data import gather_data +else: # pytest + from test.offline.gather_cartpole_data import gather_data def get_args(): @@ -34,7 +40,7 @@ def get_args(): parser.add_argument( "--load-buffer-name", type=str, - default="./expert_DQN_CartPole-v0.pkl", + default="./expert_QRDQN_CartPole-v0.pkl", ) parser.add_argument( "--device", @@ -60,23 +66,22 @@ def test_discrete_crr(args=get_args()): torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - actor = Net( - args.state_shape, + net = Net(args.state_shape, args.hidden_sizes[0], device=args.device) + actor = Actor( + net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, - softmax=False + softmax_output=False ) - critic = Net( - args.state_shape, - args.action_shape, + critic = Critic( + net, hidden_sizes=args.hidden_sizes, - device=args.device, - softmax=False - ) - optim = torch.optim.Adam( - list(actor.parameters()) + list(critic.parameters()), lr=args.lr + last_size=np.prod(args.action_shape), + device=args.device ) + actor_critic = ActorCritic(actor, critic) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) policy = DiscreteCRRPolicy( actor, @@ -86,14 +91,15 @@ def test_discrete_crr(args=get_args()): target_update_freq=args.target_update_freq, ).to(args.device) # buffer - assert os.path.exists(args.load_buffer_name), \ - "Please run test_dqn.py first to get expert's data buffer." 
- buffer = pickle.load(open(args.load_buffer_name, "rb")) + if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): + buffer = pickle.load(open(args.load_buffer_name, "rb")) + else: + buffer = gather_data() # collector test_collector = Collector(policy, test_envs, exploration_noise=True) - log_path = os.path.join(args.logdir, args.task, 'discrete_cql') + log_path = os.path.join(args.logdir, args.task, 'discrete_crr') writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) diff --git a/tianshou/__init__.py b/tianshou/__init__.py index bc61c15fc..c9416847d 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,6 +1,6 @@ from tianshou import data, env, exploration, policy, trainer, utils -__version__ = "0.4.4" +__version__ = "0.4.5" __all__ = [ "env", diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 6a149509e..dd4efe78b 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -83,14 +83,14 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # type: ignor if self._target and self._iter % self._freq == 0: self.sync_weight() self.optim.zero_grad() - q_t, _ = self.critic(batch.obs) + q_t = self.critic(batch.obs) act = to_torch(batch.act, dtype=torch.long, device=q_t.device) qa_t = q_t.gather(1, act.unsqueeze(1)) # Critic loss with torch.no_grad(): target_a_t, _ = self.actor_old(batch.obs_next) target_m = Categorical(logits=target_a_t) - q_t_target, _ = self.critic_old(batch.obs_next) + q_t_target = self.critic_old(batch.obs_next) rew = to_torch_as(batch.rew, q_t_target) expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True) expected_target_q[batch.done > 0] = 0.0 diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index f5b454fa2..d2f85bc2a 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -14,7 +14,7 @@ def offline_trainer( policy: BasePolicy, buffer: ReplayBuffer, - test_collector: Collector, + test_collector: Optional[Collector], max_epoch: int, update_per_epoch: int, episode_per_test: int, @@ -33,7 +33,8 @@ def offline_trainer( The "step" in offline trainer means a gradient step. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param Collector test_collector: the collector used for testing. + :param Collector test_collector: the collector used for testing. If it's None, then + no testing will be performed. :param int max_epoch: the maximum number of epochs for training. The training process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. 
:param int update_per_epoch: the number of policy network updates, so-called @@ -73,14 +74,16 @@ def offline_trainer( start_epoch, _, gradient_step = logger.restore_data() stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() - test_collector.reset_stat() - test_result = test_episode( - policy, test_collector, test_fn, start_epoch, episode_per_test, logger, - gradient_step, reward_metric - ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] + if test_collector is not None: + test_c: Collector = test_collector + test_collector.reset_stat() + test_result = test_episode( + policy, test_c, test_fn, start_epoch, episode_per_test, logger, + gradient_step, reward_metric + ) + best_epoch = start_epoch + best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] if save_fn: save_fn(policy) @@ -97,22 +100,32 @@ def offline_trainer( data[k] = f"{losses[k]:.3f}" logger.log_update_data(losses, gradient_step) t.set_postfix(**data) - # test - test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, logger, - gradient_step, reward_metric - ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) - if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + # test + if test_collector is not None: + test_result = test_episode( + policy, test_c, test_fn, epoch, episode_per_test, logger, + gradient_step, reward_metric ) - if stop_fn and stop_fn(best_reward): - break - return gather_info(start_time, None, test_collector, best_reward, best_reward_std) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if best_epoch < 0 or best_reward < rew: + best_epoch, best_reward, best_reward_std = epoch, rew, rew_std + if save_fn: + save_fn(policy) + if verbose: + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) + if stop_fn and stop_fn(best_reward): + break + + if test_collector is None and save_fn: + save_fn(policy) + + if test_collector is None: + return gather_info(start_time, None, None, 0.0, 0.0) + else: + return gather_info( + start_time, None, test_collector, best_reward, best_reward_std + ) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index eb14dfb1f..9b8727b24 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -14,7 +14,7 @@ def offpolicy_trainer( policy: BasePolicy, train_collector: Collector, - test_collector: Collector, + test_collector: Optional[Collector], max_epoch: int, step_per_epoch: int, step_per_collect: int, @@ -38,7 +38,8 @@ def offpolicy_trainer( :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param Collector train_collector: the collector used for training. - :param Collector test_collector: the collector used for testing. + :param Collector test_collector: the collector used for testing. If it's None, then + no testing will be performed. :param int max_epoch: the maximum number of epochs for training. The training process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. :param int step_per_epoch: the number of transitions collected per epoch. 
@@ -90,14 +91,19 @@ def offpolicy_trainer( stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() - test_collector.reset_stat() - test_in_train = test_in_train and train_collector.policy == policy - test_result = test_episode( - policy, test_collector, test_fn, start_epoch, episode_per_test, logger, - env_step, reward_metric + test_in_train = test_in_train and ( + train_collector.policy == policy and test_collector is not None ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] + + if test_collector is not None: + test_c: Collector = test_collector # for mypy + test_collector.reset_stat() + test_result = test_episode( + policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step, + reward_metric + ) + best_epoch = start_epoch + best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] if save_fn: save_fn(policy) @@ -129,8 +135,8 @@ def offpolicy_trainer( if result["n/ep"] > 0: if test_in_train and stop_fn and stop_fn(result["rew"]): test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, - logger, env_step + policy, test_c, test_fn, epoch, episode_per_test, logger, + env_step ) if stop_fn(test_result["rew"]): if save_fn: @@ -156,24 +162,32 @@ def offpolicy_trainer( t.set_postfix(**data) if t.n <= t.total: t.update() - # test - test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step, - reward_metric - ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + # test + if test_collector is not None: + test_result = test_episode( + policy, test_c, test_fn, epoch, episode_per_test, logger, env_step, + reward_metric ) - if stop_fn and stop_fn(best_reward): - break - return gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std - ) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if best_epoch < 0 or best_reward < rew: + best_epoch, best_reward, best_reward_std = epoch, rew, rew_std + if save_fn: + save_fn(policy) + if verbose: + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) + if stop_fn and stop_fn(best_reward): + break + + if test_collector is None and save_fn: + save_fn(policy) + + if test_collector is None: + return gather_info(start_time, train_collector, None, 0.0, 0.0) + else: + return gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 2c539a2e7..251c55637 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -14,7 +14,7 @@ def onpolicy_trainer( policy: BasePolicy, train_collector: Collector, - test_collector: Collector, + test_collector: Optional[Collector], max_epoch: int, step_per_epoch: int, repeat_per_collect: int, @@ -39,7 +39,8 @@ def onpolicy_trainer( :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param Collector train_collector: the collector used for training. 
- :param Collector test_collector: the collector used for testing. + :param Collector test_collector: the collector used for testing. If it's None, then + no testing will be performed. :param int max_epoch: the maximum number of epochs for training. The training process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. :param int step_per_epoch: the number of transitions collected per epoch. @@ -96,14 +97,19 @@ def onpolicy_trainer( stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() - test_collector.reset_stat() - test_in_train = test_in_train and train_collector.policy == policy - test_result = test_episode( - policy, test_collector, test_fn, start_epoch, episode_per_test, logger, - env_step, reward_metric + test_in_train = test_in_train and ( + train_collector.policy == policy and test_collector is not None ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] + + if test_collector is not None: + test_c: Collector = test_collector # for mypy + test_collector.reset_stat() + test_result = test_episode( + policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step, + reward_metric + ) + best_epoch = start_epoch + best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] if save_fn: save_fn(policy) @@ -137,8 +143,8 @@ def onpolicy_trainer( if result["n/ep"] > 0: if test_in_train and stop_fn and stop_fn(result["rew"]): test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, - logger, env_step + policy, test_c, test_fn, epoch, episode_per_test, logger, + env_step ) if stop_fn(test_result["rew"]): if save_fn: @@ -172,24 +178,32 @@ def onpolicy_trainer( t.set_postfix(**data) if t.n <= t.total: t.update() - # test - test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step, - reward_metric - ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + # test + if test_collector is not None: + test_result = test_episode( + policy, test_c, test_fn, epoch, episode_per_test, logger, env_step, + reward_metric ) - if stop_fn and stop_fn(best_reward): - break - return gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std - ) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if best_epoch < 0 or best_reward < rew: + best_epoch, best_reward, best_reward_std = epoch, rew, rew_std + if save_fn: + save_fn(policy) + if verbose: + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) + if stop_fn and stop_fn(best_reward): + break + + if test_collector is None and save_fn: + save_fn(policy) + + if test_collector is None: + return gather_info(start_time, train_collector, None, 0.0, 0.0) + else: + return gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 6ad2f0f2a..9bb841248 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -36,7 +36,7 @@ def 
test_episode( def gather_info( start_time: float, train_c: Optional[Collector], - test_c: Collector, + test_c: Optional[Collector], best_reward: float, best_reward_std: float, ) -> Dict[str, Union[float, str]]: @@ -58,21 +58,32 @@ def gather_info( * ``duration`` the total elapsed time. """ duration = time.time() - start_time - model_time = duration - test_c.collect_time - test_speed = test_c.collect_step / test_c.collect_time + model_time = duration result: Dict[str, Union[float, str]] = { - "test_step": test_c.collect_step, - "test_episode": test_c.collect_episode, - "test_time": f"{test_c.collect_time:.2f}s", - "test_speed": f"{test_speed:.2f} step/s", - "best_reward": best_reward, - "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}", "duration": f"{duration:.2f}s", "train_time/model": f"{model_time:.2f}s", } + if test_c is not None: + model_time = duration - test_c.collect_time + test_speed = test_c.collect_step / test_c.collect_time + result.update( + { + "test_step": test_c.collect_step, + "test_episode": test_c.collect_episode, + "test_time": f"{test_c.collect_time:.2f}s", + "test_speed": f"{test_speed:.2f} step/s", + "best_reward": best_reward, + "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}", + "duration": f"{duration:.2f}s", + "train_time/model": f"{model_time:.2f}s", + } + ) if train_c is not None: model_time -= train_c.collect_time - train_speed = train_c.collect_step / (duration - test_c.collect_time) + if test_c is not None: + train_speed = train_c.collect_step / (duration - test_c.collect_time) + else: + train_speed = train_c.collect_step / duration result.update( { "train_step": train_c.collect_step, diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index 86e873cda..469d32765 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -35,6 +35,7 @@ def __init__( def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: for k, v in data.items(): self.writer.add_scalar(k, v, global_step=step) + self.writer.flush() # issue #482 def save_data( self,