-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add offline trainer and discrete BCQ algorithm (#263)
The result needs to be tuned after `done` issue fixed. Co-authored-by: n+e <[email protected]>
- Loading branch information
1 parent
a633a6a
commit a511cb4
Showing
26 changed files
with
628 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -145,3 +145,5 @@ MUJOCO_LOG.TXT | |
*.zip | ||
*.pstats | ||
*.swp | ||
*.pkl | ||
*.hdf5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
import os | ||
import torch | ||
import pickle | ||
import pprint | ||
import argparse | ||
import numpy as np | ||
from torch.utils.tensorboard import SummaryWriter | ||
|
||
from tianshou.env import SubprocVectorEnv | ||
from tianshou.trainer import offline_trainer | ||
from tianshou.utils.net.discrete import Actor | ||
from tianshou.policy import DiscreteBCQPolicy | ||
from tianshou.data import Collector, ReplayBuffer | ||
|
||
from atari_network import DQN | ||
from atari_wrapper import wrap_deepmind | ||
|
||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") | ||
parser.add_argument("--seed", type=int, default=1626) | ||
parser.add_argument("--eps-test", type=float, default=0.001) | ||
parser.add_argument("--lr", type=float, default=6.25e-5) | ||
parser.add_argument("--gamma", type=float, default=0.99) | ||
parser.add_argument("--n-step", type=int, default=3) | ||
parser.add_argument("--target-update-freq", type=int, default=8000) | ||
parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) | ||
parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) | ||
parser.add_argument("--epoch", type=int, default=100) | ||
parser.add_argument("--step-per-epoch", type=int, default=10000) | ||
parser.add_argument("--batch-size", type=int, default=32) | ||
parser.add_argument('--hidden-sizes', type=int, | ||
nargs='*', default=[512]) | ||
parser.add_argument("--test-num", type=int, default=100) | ||
parser.add_argument('--frames_stack', type=int, default=4) | ||
parser.add_argument("--logdir", type=str, default="log") | ||
parser.add_argument("--render", type=float, default=0.) | ||
parser.add_argument("--resume-path", type=str, default=None) | ||
parser.add_argument("--watch", default=False, action="store_true", | ||
help="watch the play of pre-trained policy only") | ||
parser.add_argument("--log-interval", type=int, default=1000) | ||
parser.add_argument( | ||
"--load-buffer-name", type=str, | ||
default="./expert_DQN_PongNoFrameskip-v4.hdf5", | ||
) | ||
parser.add_argument( | ||
"--device", type=str, | ||
default="cuda" if torch.cuda.is_available() else "cpu", | ||
) | ||
args = parser.parse_known_args()[0] | ||
return args | ||
|
||
|
||
def make_atari_env(args): | ||
return wrap_deepmind(args.task, frame_stack=args.frames_stack) | ||
|
||
|
||
def make_atari_env_watch(args): | ||
return wrap_deepmind(args.task, frame_stack=args.frames_stack, | ||
episode_life=False, clip_rewards=False) | ||
|
||
|
||
def test_discrete_bcq(args=get_args()): | ||
# envs | ||
env = make_atari_env(args) | ||
args.state_shape = env.observation_space.shape or env.observation_space.n | ||
args.action_shape = env.action_space.shape or env.action_space.n | ||
# should be N_FRAMES x H x W | ||
print("Observations shape:", args.state_shape) | ||
print("Actions shape:", args.action_shape) | ||
# make environments | ||
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) | ||
for _ in range(args.test_num)]) | ||
# seed | ||
np.random.seed(args.seed) | ||
torch.manual_seed(args.seed) | ||
test_envs.seed(args.seed) | ||
# model | ||
feature_net = DQN(*args.state_shape, args.action_shape, | ||
device=args.device, features_only=True).to(args.device) | ||
policy_net = Actor(feature_net, args.action_shape, | ||
hidden_sizes=args.hidden_sizes).to(args.device) | ||
imitation_net = Actor(feature_net, args.action_shape, | ||
hidden_sizes=args.hidden_sizes).to(args.device) | ||
optim = torch.optim.Adam( | ||
set(policy_net.parameters()).union(imitation_net.parameters()), | ||
lr=args.lr, | ||
) | ||
# define policy | ||
policy = DiscreteBCQPolicy( | ||
policy_net, imitation_net, optim, args.gamma, args.n_step, | ||
args.target_update_freq, args.eps_test, | ||
args.unlikely_action_threshold, args.imitation_logits_penalty, | ||
) | ||
# load a previous policy | ||
if args.resume_path: | ||
policy.load_state_dict(torch.load( | ||
args.resume_path, map_location=args.device | ||
)) | ||
print("Loaded agent from: ", args.resume_path) | ||
# buffer | ||
assert os.path.exists(args.load_buffer_name), \ | ||
"Please run atari_dqn.py first to get expert's data buffer." | ||
if args.load_buffer_name.endswith('.pkl'): | ||
buffer = pickle.load(open(args.load_buffer_name, "rb")) | ||
elif args.load_buffer_name.endswith('.hdf5'): | ||
buffer = ReplayBuffer.load_hdf5(args.load_buffer_name) | ||
else: | ||
print(f"Unknown buffer format: {args.load_buffer_name}") | ||
exit(0) | ||
|
||
# collector | ||
test_collector = Collector(policy, test_envs) | ||
|
||
log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') | ||
writer = SummaryWriter(log_path) | ||
|
||
def save_fn(policy): | ||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) | ||
|
||
def stop_fn(mean_rewards): | ||
return False | ||
|
||
# watch agent's performance | ||
def watch(): | ||
print("Setup test envs ...") | ||
policy.eval() | ||
policy.set_eps(args.eps_test) | ||
test_envs.seed(args.seed) | ||
print("Testing agent ...") | ||
test_collector.reset() | ||
result = test_collector.collect(n_episode=[1] * args.test_num, | ||
render=args.render) | ||
pprint.pprint(result) | ||
|
||
if args.watch: | ||
watch() | ||
exit(0) | ||
|
||
result = offline_trainer( | ||
policy, buffer, test_collector, | ||
args.epoch, args.step_per_epoch, args.test_num, args.batch_size, | ||
stop_fn=stop_fn, save_fn=save_fn, writer=writer, | ||
log_interval=args.log_interval, | ||
) | ||
|
||
pprint.pprint(result) | ||
watch() | ||
|
||
|
||
if __name__ == "__main__": | ||
test_discrete_bcq(get_args()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.