diff --git a/examples/vizdoom/README.md b/examples/vizdoom/README.md
index ed68a46c9..ab01b7a77 100644
--- a/examples/vizdoom/README.md
+++ b/examples/vizdoom/README.md
@@ -53,10 +53,6 @@ python3 replay.py maps/D4_battle2.cfg results/c51/d4.lmp
 
 See [maps/README.md](maps/README.md)
 
-## Algorithms
-
-The setting is exactly the same as Atari. You can definitely try more algorithms listed in Atari example.
-
 ## Reward
 
 1. living reward is bad
@@ -64,3 +60,28 @@ The setting is exactly the same as Atari. You can definitely try more algorithms
 3. negative reward for health and ammo2 is really helpful for d3/d4
 4. only with positive reward for health is really helpful for d1
 5. remove MOVE_BACKWARD may converge faster but the final performance may be lower
+
+## Algorithms
+
+The setup is exactly the same as the Atari examples, so you can also try the other algorithms listed there.
+
+### C51 (single run)
+
+| task          | best reward | reward curve                           | parameters                                      |
+| ------------- | ----------- | -------------------------------------- | ----------------------------------------------- |
+| D2_navigation | 747.52      | ![](results/c51/D2_navigation_rew.png) | `python3 vizdoom_c51.py --task "D2_navigation"` |
+| D3_battle     | 1855.29     | ![](results/c51/D3_battle_rew.png)     | `python3 vizdoom_c51.py --task "D3_battle"`     |
+
+### PPO (single run)
+
+| task          | best reward | reward curve                           | parameters                                      |
+| ------------- | ----------- | -------------------------------------- | ----------------------------------------------- |
+| D2_navigation | 770.75      | ![](results/ppo/D2_navigation_rew.png) | `python3 vizdoom_ppo.py --task "D2_navigation"` |
+| D3_battle     | 320.59      | ![](results/ppo/D3_battle_rew.png)     | `python3 vizdoom_ppo.py --task "D3_battle"`     |
+
+### PPO with ICM (single run)
+
+| task          | best reward | reward curve                               | parameters                                                        |
+| ------------- | ----------- | ------------------------------------------ | ----------------------------------------------------------------- |
+| D2_navigation | 844.99      | ![](results/ppo_icm/D2_navigation_rew.png) | `python3 vizdoom_ppo.py --task "D2_navigation" --icm-lr-scale 10` |
+| D3_battle     | 547.08      | ![](results/ppo_icm/D3_battle_rew.png)     | `python3 vizdoom_ppo.py --task "D3_battle" --icm-lr-scale 10`     |
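The reward notes above describe delta-based shaping on health and ammo2. The repo's `env.py` is not part of this diff, so the sketch below only illustrates what such shaping could look like directly on top of the ViZDoom API; the `RewardShaper` class, its coefficients, and its method names are hypothetical and are not the example's actual wrapper.

```python
import vizdoom as vzd


class RewardShaper:
    """Hypothetical helper illustrating the shaping notes above (not the repo's env.py)."""

    def __init__(self, game: vzd.DoomGame, health_coef: float = 1.0, ammo_coef: float = 1.0):
        self.game = game
        self.health_coef = health_coef
        self.ammo_coef = ammo_coef
        self.last_health = 0.0
        self.last_ammo = 0.0

    def reset(self) -> None:
        self.last_health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
        self.last_ammo = self.game.get_game_variable(vzd.GameVariable.AMMO2)

    def shape(self, reward: float) -> float:
        health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
        ammo = self.game.get_game_variable(vzd.GameVariable.AMMO2)
        # D3/D4 style: penalize losing health or ammo2 (keep only negative deltas).
        reward += self.health_coef * min(health - self.last_health, 0.0)
        reward += self.ammo_coef * min(ammo - self.last_ammo, 0.0)
        # D1 style would instead keep only the positive health delta:
        # reward += self.health_coef * max(health - self.last_health, 0.0)
        self.last_health, self.last_ammo = health, ammo
        return reward
```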
diff --git a/examples/vizdoom/results/c51/D2_navigation_rew.png b/examples/vizdoom/results/c51/D2_navigation_rew.png
new file mode 100644
index 000000000..e5251366c
Binary files /dev/null and b/examples/vizdoom/results/c51/D2_navigation_rew.png differ
diff --git a/examples/vizdoom/results/c51/D3_battle_rew.png b/examples/vizdoom/results/c51/D3_battle_rew.png
new file mode 100644
index 000000000..f54570ad5
Binary files /dev/null and b/examples/vizdoom/results/c51/D3_battle_rew.png differ
diff --git a/examples/vizdoom/results/ppo/D2_navigation_rew.png b/examples/vizdoom/results/ppo/D2_navigation_rew.png
new file mode 100644
index 000000000..d96ac55b9
Binary files /dev/null and b/examples/vizdoom/results/ppo/D2_navigation_rew.png differ
diff --git a/examples/vizdoom/results/ppo/D3_battle_rew.png b/examples/vizdoom/results/ppo/D3_battle_rew.png
new file mode 100644
index 000000000..716b11b6b
Binary files /dev/null and b/examples/vizdoom/results/ppo/D3_battle_rew.png differ
diff --git a/examples/vizdoom/results/ppo_icm/D2_navigation_rew.png b/examples/vizdoom/results/ppo_icm/D2_navigation_rew.png
new file mode 100644
index 000000000..7a37d6189
Binary files /dev/null and b/examples/vizdoom/results/ppo_icm/D2_navigation_rew.png differ
diff --git a/examples/vizdoom/results/ppo_icm/D3_battle_rew.png b/examples/vizdoom/results/ppo_icm/D3_battle_rew.png
new file mode 100644
index 000000000..04fe92819
Binary files /dev/null and b/examples/vizdoom/results/ppo_icm/D3_battle_rew.png differ
diff --git a/examples/vizdoom/vizdoom_a2c_icm.py b/examples/vizdoom/vizdoom_ppo.py
similarity index 67%
rename from examples/vizdoom/vizdoom_a2c_icm.py
rename to examples/vizdoom/vizdoom_ppo.py
index 99668ce16..ec88cb914 100644
--- a/examples/vizdoom/vizdoom_a2c_icm.py
+++ b/examples/vizdoom/vizdoom_ppo.py
@@ -6,11 +6,12 @@
 import torch
 from env import Env
 from network import DQN
+from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import ShmemVectorEnv
-from tianshou.policy import A2CPolicy, ICMPolicy
+from tianshou.policy import ICMPolicy, PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import ActorCritic
@@ -21,18 +22,28 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='D2_navigation')
     parser.add_argument('--seed', type=int, default=0)
-    parser.add_argument('--buffer-size', type=int, default=2000000)
-    parser.add_argument('--lr', type=float, default=0.0001)
+    parser.add_argument('--buffer-size', type=int, default=100000)
+    parser.add_argument('--lr', type=float, default=0.00002)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=300)
     parser.add_argument('--step-per-epoch', type=int, default=100000)
-    parser.add_argument('--episode-per-collect', type=int, default=10)
-    parser.add_argument('--update-per-step', type=float, default=0.1)
-    parser.add_argument('--repeat-per-collect', type=int, default=1)
-    parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
+    parser.add_argument('--step-per-collect', type=int, default=1000)
+    parser.add_argument('--repeat-per-collect', type=int, default=4)
+    parser.add_argument('--batch-size', type=int, default=256)
+    parser.add_argument('--hidden-size', type=int, default=512)
     parser.add_argument('--training-num', type=int, default=10)
     parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--rew-norm', type=int, default=False)
+    parser.add_argument('--vf-coef', type=float, default=0.5)
+    parser.add_argument('--ent-coef', type=float, default=0.01)
+    parser.add_argument('--gae-lambda', type=float, default=0.95)
+    parser.add_argument('--lr-decay', type=int, default=True)
+    parser.add_argument('--max-grad-norm', type=float, default=0.5)
+    parser.add_argument('--eps-clip', type=float, default=0.2)
+    parser.add_argument('--dual-clip', type=float, default=None)
+    parser.add_argument('--value-clip', type=int, default=0)
+    parser.add_argument('--norm-adv', type=int, default=1)
+    parser.add_argument('--recompute-adv', type=int, default=0)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
     parser.add_argument(
@@ -75,7 +86,7 @@ def get_args():
     return parser.parse_args()
 
 
-def test_a2c(args=get_args()):
+def test_ppo(args=get_args()):
     args.cfg_path = f"maps/{args.task}.cfg"
     args.wad_path = f"maps/{args.task}.wad"
     args.res = (args.skip_num, 84, 84)
@@ -105,33 +116,65 @@ def test_a2c(args=get_args()):
     test_envs.seed(args.seed)
     # define model
     net = DQN(
-        *args.state_shape, args.action_shape, device=args.device, features_only=True
+        *args.state_shape,
+        args.action_shape,
+        device=args.device,
+        features_only=True,
+        output_dim=args.hidden_size
     )
-    actor = Actor(
-        net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
-    )
-    critic = Critic(net, hidden_sizes=args.hidden_sizes, device=args.device)
+    actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
+    critic = Critic(net, device=args.device)
     optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)
+
+    lr_scheduler = None
+    if args.lr_decay:
+        # decay learning rate to 0 linearly
+        max_update_num = np.ceil(
+            args.step_per_epoch / args.step_per_collect
+        ) * args.epoch
+
+        lr_scheduler = LambdaLR(
+            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
+        )
+
     # define policy
-    dist = torch.distributions.Categorical
-    policy = A2CPolicy(actor, critic, optim, dist).to(args.device)
+    def dist(p):
+        return torch.distributions.Categorical(logits=p)
+
+    policy = PPOPolicy(
+        actor,
+        critic,
+        optim,
+        dist,
+        discount_factor=args.gamma,
+        gae_lambda=args.gae_lambda,
+        max_grad_norm=args.max_grad_norm,
+        vf_coef=args.vf_coef,
+        ent_coef=args.ent_coef,
+        reward_normalization=args.rew_norm,
+        action_scaling=False,
+        lr_scheduler=lr_scheduler,
+        action_space=env.action_space,
+        eps_clip=args.eps_clip,
+        value_clip=args.value_clip,
+        dual_clip=args.dual_clip,
+        advantage_normalization=args.norm_adv,
+        recompute_advantage=args.recompute_adv
+    ).to(args.device)
     if args.icm_lr_scale > 0:
         feature_net = DQN(
             *args.state_shape,
             args.action_shape,
             device=args.device,
-            features_only=True
+            features_only=True,
+            output_dim=args.hidden_size
         )
         action_dim = np.prod(args.action_shape)
         feature_dim = feature_net.output_dim
         icm_net = IntrinsicCuriosityModule(
-            feature_net.net,
-            feature_dim,
-            action_dim,
-            hidden_sizes=args.hidden_sizes,
-            device=args.device
+            feature_net.net, feature_dim, action_dim, device=args.device
        )
-        icm_optim = torch.optim.adam(icm_net.parameters(), lr=args.lr)
+        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
         policy = ICMPolicy(
             policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
             args.icm_forward_loss_weight
@@ -153,7 +196,8 @@ def test_a2c(args=get_args()):
     train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'a2c')
+    log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo'
+    log_path = os.path.join(args.logdir, args.task, log_name)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)
@@ -162,10 +206,15 @@ def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
 
     def stop_fn(mean_rewards):
-        return False
+        if env.spec.reward_threshold:
+            return mean_rewards >= env.spec.reward_threshold
+        elif 'Pong' in args.task:
+            return mean_rewards >= 20
+        else:
+            return False
 
+    # watch agent's performance
     def watch():
-        # watch agent's performance
         print("Setup test envs ...")
         policy.eval()
         test_envs.seed(args.seed)
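In the `PPOPolicy` setup above, the actor is built with `softmax_output=False` and `dist` wraps the raw network output in `torch.distributions.Categorical(logits=p)`. The minimal standalone check below (made-up logits, not code from this repo) shows why that pairing works: `Categorical` normalizes logits internally, so feeding it an already-softmaxed vector would effectively apply softmax twice.

```python
import torch
from torch.distributions import Categorical

# Made-up logits standing in for the actor's raw (unsoftmaxed) output.
logits = torch.tensor([[2.0, 0.5, -1.0]])

dist = Categorical(logits=logits)  # normalizes internally via log-softmax
assert torch.allclose(dist.probs, logits.softmax(dim=-1))

action = dist.sample()             # sampled action, shape (1,)
log_prob = dist.log_prob(action)   # log pi(a|s), the quantity PPO's ratio is built from
print(action.item(), log_prob.item())
```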
@@ -210,7 +259,7 @@ def watch():
         args.repeat_per_collect,
         args.test_num,
         args.batch_size,
-        episode_per_collect=args.episode_per_collect,
+        step_per_collect=args.step_per_collect,
         stop_fn=stop_fn,
         save_fn=save_fn,
         logger=logger,
@@ -222,4 +271,4 @@ def watch():
 
 
 if __name__ == '__main__':
-    test_a2c(get_args())
+    test_ppo(get_args())
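The `--lr-decay` branch added in `vizdoom_ppo.py` builds a `LambdaLR` whose multiplier falls linearly from 1 to 0 over `ceil(step_per_epoch / step_per_collect) * epoch` updates. Below is a standalone sketch of that schedule, assuming the scheduler is stepped once per policy update and using numbers that mirror the script's defaults; it is an illustration, not part of the patch.

```python
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR

step_per_epoch, step_per_collect, epoch = 100000, 1000, 300
max_update_num = np.ceil(step_per_epoch / step_per_collect) * epoch  # 30000 updates

param = torch.nn.Parameter(torch.zeros(1))
optim = torch.optim.Adam([param], lr=2e-5)
scheduler = LambdaLR(optim, lr_lambda=lambda n: 1 - n / max_update_num)

for _ in range(int(max_update_num)):
    optim.step()      # stands in for one PPO gradient update
    scheduler.step()  # multiplier decays linearly from 1 toward 0

print(optim.param_groups[0]["lr"])  # ~0.0 after the final update
```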