This repository has been archived by the owner on May 26, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
executable file
·105 lines (84 loc) · 2.72 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#!/usr/bin/env python3
# SPDX-FileCopyrightText: 2020 Harish Rajagopal <[email protected]>
#
# SPDX-License-Identifier: MIT
"""Test the DQN on Pong."""
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, Namespace
from pathlib import Path
from typing import Deque
import gym
import tensorflow as tf
from model import DQN
from train import DQNTrainer
from utils import ENV_NAME, STATE_FRAMES, Config, load_config, preprocess
def test(
    env: gym.Env, model: tf.keras.Model, config: Config, log_dir: Path
) -> None:
    """Test the DQN on Pong by playing a single recorded episode.

    The environment is wrapped in a `gym.wrappers.Monitor` writing to
    `log_dir`. Until `STATE_FRAMES` preprocessed frames have been collected,
    actions are sampled from the action space; afterwards the frames are
    stacked along the last axis and passed to `model.choose_action`. The
    loop ends when the environment reports `done`.

    Args:
        env: The Atari Pong environment
        model: The model to be tested
        config: The hyper-param config (not read by this function)
        log_dir: Path where to save the video
    """
    env = gym.wrappers.Monitor(
        env,
        log_dir,
        force=True,  # overwrite existing videos
        video_callable=lambda count: True,  # force save this episode
    )

    # Always close the monitor, even if the episode fails part-way, instead of
    # relying on garbage collection to finalize the recording.
    try:
        state = Deque[tf.Tensor](maxlen=STATE_FRAMES)
        state.append(preprocess(env.reset()))  # initial state

        print("Starting testing...")
        done = False
        while not done:
            if len(state) < STATE_FRAMES:
                # Not enough frames yet to build an input for the model
                action = env.action_space.sample()
            else:
                stacked = tf.stack(state, axis=-1)
                action = model.choose_action(stacked)  # choose greedily

            state_new, _, done, _ = env.step(action)
            state.append(preprocess(state_new))
    finally:
        env.close()

    print("Testing done")
def main(args: Namespace) -> None:
    """Run the main program.

    Loads the config from `args.config`, seeds TensorFlow and the environment
    with `config.seed`, builds the DQN and loads its weights from
    `args.load_dir / DQNTrainer.MODEL_NAME`, makes sure `args.log_dir` exists
    and then runs `test`.

    Arguments:
        args: The object containing the commandline arguments
    """
    config = load_config(args.config)
    tf.keras.utils.set_random_seed(config.seed)

    env = gym.make(ENV_NAME, frameskip=config.frame_skips)
    env.seed(config.seed)

    model = DQN(env.action_space.n, config)
    model.load_weights(args.load_dir / DQNTrainer.MODEL_NAME)
    print("Loaded model")

    # `exist_ok` replaces a separate existence check, which could race with
    # another process creating the directory in between.
    args.log_dir.mkdir(parents=True, exist_ok=True)

    test(env, model, config, log_dir=args.log_dir)
if __name__ == "__main__":
    parser = ArgumentParser(
        formatter_class=ArgumentDefaultsHelpFormatter,
        description="Test the DQN on Pong",
    )

    # Command-line options as (flags, keyword options), registered in the
    # order they should appear in the help output.
    options = (
        (
            ("--load-dir",),
            {
                "type": Path,
                "default": "./checkpoints/",
                "help": "path from where to load the model and data",
            },
        ),
        (
            ("-c", "--config"),
            {
                "type": Path,
                "help": (
                    "Path to a TOML config containing hyper-parameter values"
                ),
            },
        ),
        (
            ("--log-dir",),
            {
                "type": Path,
                "default": "./logs/test/",
                "help": "path where to save the video",
            },
        ),
    )
    for flags, kwargs in options:
        parser.add_argument(*flags, **kwargs)

    main(parser.parse_args())