# train_agent.py (forked from NTT123/a0-jax)
"""
AlphaZero training script.
Train agent by self-play only.
"""
import os
import pickle
import random
from functools import partial

import chex
import click
import fire
import jax
import jax.numpy as jnp
import jax.tools.colab_tpu
import numpy as np
import opax
import optax
import pax

from env import Enviroment
from play import agent_vs_agent_multiple_games
from tree_search import improve_policy_with_mcts, recurrent_fn
from utils import batched_policy, env_step, import_class, replicate, reset_env

EPSILON = 1e-9  # a very small positive value


@chex.dataclass(frozen=True)
class TrainingExample:
    """AlphaZero training example.

    state: the current state of the game.
    action_weights: the target action probabilities from the MCTS policy.
    value: the target value from the self-play result.
    """

    state: chex.Array
    action_weights: chex.Array
    value: chex.Array


@chex.dataclass(frozen=True)
class MoveOutput:
    """The output of a single self-play move.

    state: the current state of the game.
    reward: the reward after executing the action chosen by the MCTS policy.
    terminated: whether the current state is a terminal state (such states are
        skipped when preparing training data).
    action_weights: the action probabilities from the MCTS policy.
    """

    state: chex.Array
    reward: chex.Array
    terminated: chex.Array
    action_weights: chex.Array
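
# Added note: the pmap on the next function broadcasts `agent` and `env` to every
# local device (in_axes=None), maps `rng_key` over the leading device axis
# (in_axes=0), and treats `batch_size` and `num_simulations_per_move` as static
# arguments (static_broadcasted_argnums). A minimal, hypothetical call sketch:
#
#   keys = jax.random.split(jax.random.PRNGKey(0), jax.local_device_count())
#   data = collect_batched_self_play_data(agent, env, keys, 16, 32)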


@partial(jax.pmap, in_axes=(None, None, 0), static_broadcasted_argnums=(3, 4))
def collect_batched_self_play_data(
    agent,
    env: Enviroment,
    rng_key: chex.Array,
    batch_size: int,
    num_simulations_per_move: int,
):
    """Collect a batch of self-play data using MCTS."""

    def single_move(prev, inputs):
        """Execute one self-play move using MCTS.

        This function is designed to be compatible with jax.scan.
        """
        env, rng_key, step = prev
        del inputs
        rng_key, rng_key_next = jax.random.split(rng_key, 2)
        state = jax.vmap(lambda e: e.canonical_observation())(env)
        terminated = env.is_terminated()
        policy_output = improve_policy_with_mcts(
            agent,
            env,
            rng_key,
            recurrent_fn,
            num_simulations_per_move,
        )
        env, reward = jax.vmap(env_step)(env, policy_output.action)
        return (env, rng_key_next, step + 1), MoveOutput(
            state=state,
            action_weights=policy_output.action_weights,
            reward=reward,
            terminated=terminated,
        )

    env = reset_env(env)
    env = replicate(env, batch_size)
    step = jnp.array(1)
    _, self_play_data = pax.scan(
        single_move,
        (env, rng_key, step),
        None,
        length=env.max_num_steps(),
        time_major=False,
    )
    return self_play_data


def prepare_training_data(data: MoveOutput, env):
    """Preprocess the data collected from self-play.

    1. Remove states after the environment is terminated.
    2. Compute the target value at each state.
    """
    buffer = []
    N = len(data.terminated)
    for i in range(N):
        state = data.state[i]
        is_terminated = data.terminated[i]
        action_weights = data.action_weights[i]
        reward = data.reward[i]
        L = len(is_terminated)
        value = None
        for idx in reversed(range(L)):
            if is_terminated[idx]:
                continue
            # Walk the game backwards: the last non-terminated step takes its own
            # reward; earlier steps flip the sign because the player to move
            # alternates every turn.
            value = reward[idx] if value is None else -value
            s = np.copy(state[idx])
            a = np.copy(action_weights[idx])
            for augmented_s, augmented_a in env.symmetries(s, a):
                buffer.append(
                    TrainingExample(
                        state=augmented_s,
                        action_weights=augmented_a,
                        value=np.array(value, dtype=np.float32),
                    )
                )
    return buffer
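
# Added note: a small worked example of the backward value pass above, assuming a
# two-player game where the last mover receives reward +1 (a win). The first
# non-terminated step seen from the end gets value +1, and the sign flips at every
# earlier step; once the final reward is picked up, earlier rewards are ignored and
# only the game outcome propagates backwards.
#
#   rewards (by move) : [ 0,  0,  0, +1]
#   value targets     : [-1, +1, -1, +1]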


def collect_self_play_data(
    agent,
    env,
    rng_key: chex.Array,
    batch_size: int,
    data_size: int,
    num_simulations_per_move: int,
):
    """Collect self-play data for training."""
    N = data_size // batch_size
    devices = jax.local_devices()
    num_devices = len(devices)
    rng_keys = jax.random.split(rng_key, N * num_devices)
    rng_keys = jnp.stack(rng_keys).reshape((N, num_devices, -1))
    data = []

    with click.progressbar(range(N), label=" self play ") as bar:
        for i in bar:
            batch = collect_batched_self_play_data(
                agent,
                env,
                rng_keys[i],
                batch_size // num_devices,
                num_simulations_per_move,
            )
            batch = jax.device_get(batch)
            batch = jax.tree_util.tree_map(
                lambda x: x.reshape((-1, *x.shape[2:])), batch
            )
            data.extend(prepare_training_data(batch, env=env))
    return data
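
# Added note: assuming pax.scan with time_major=False puts the game axis first,
# each pmapped batch above comes back with a leading shape of
# [num_devices, batch_size // num_devices, max_num_steps, ...]; the tree_map
# reshape merges the first two axes so `prepare_training_data` sees one flat game
# axis of size batch_size.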


def loss_fn(net, data: TrainingExample):
    """Sum of value loss and policy loss."""
    net, (action_logits, value) = batched_policy(net, data.state)

    # value loss (mse)
    mse_loss = optax.l2_loss(value, data.value)
    mse_loss = jnp.mean(mse_loss)

    # policy loss: KL(target_policy || agent_policy)
    target_pr = data.action_weights
    # to avoid log(0) = nan
    target_pr = jnp.where(target_pr == 0, EPSILON, target_pr)
    action_logits = jax.nn.log_softmax(action_logits, axis=-1)
    kl_loss = jnp.sum(target_pr * (jnp.log(target_pr) - action_logits), axis=-1)
    kl_loss = jnp.mean(kl_loss)

    # return the total loss
    return mse_loss + kl_loss, (net, (mse_loss, kl_loss))
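
# Added note: the total loss above is
#     L = mean(0.5 * (v_pred - z)^2) + mean(KL(pi_target || softmax(logits)))
# where z is the self-play value target and pi_target the MCTS action weights
# (optax.l2_loss includes the 0.5 factor). EPSILON only guards log(0) for actions
# the target never visits; it does not renormalize the target distribution.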


@partial(jax.pmap, axis_name="i")
def train_step(net, optim, data: TrainingExample):
    """A training step."""
    (_, (net, losses)), grads = jax.value_and_grad(loss_fn, has_aux=True)(net, data)
    grads = jax.lax.pmean(grads, axis_name="i")
    net, optim = opax.apply_gradients(net, optim, grads)
    return net, optim, losses
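
# Added note: `train_step` is pmapped with axis_name="i"; each device computes
# gradients on its own shard of the batch and jax.lax.pmean averages them before
# the optimizer update, so every replica applies the same update (data-parallel
# training with identical parameters on all devices).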


def train(
    game_class="connect_two_game.Connect2Game",
    agent_class="mlp_policy.MlpPolicyValueNet",
    selfplay_batch_size: int = 128,
    training_batch_size: int = 128,
    num_iterations: int = 100,
    num_simulations_per_move: int = 32,
    num_self_plays_per_iteration: int = 128 * 100,
    learning_rate: float = 0.01,
    ckpt_filename: str = "./agent.ckpt",
    random_seed: int = 42,
    weight_decay: float = 1e-4,
    lr_decay_steps: int = 100_000,
):
    """Train an agent by self-play."""
    env = import_class(game_class)()
    agent = import_class(agent_class)(
        input_dims=env.observation().shape,
        num_actions=env.num_actions(),
    )

    def lr_schedule(step):
        # halve the learning rate every `lr_decay_steps` steps
        e = jnp.floor(step * 1.0 / lr_decay_steps)
        return learning_rate * jnp.exp2(-e)

    optim = opax.chain(
        opax.add_decayed_weights(weight_decay),
        opax.sgd(lr_schedule, momentum=0.9),
    ).init(agent.parameters())

    if os.path.isfile(ckpt_filename):
        print("Loading weights at", ckpt_filename)
        with open(ckpt_filename, "rb") as f:
            dic = pickle.load(f)
            agent = agent.load_state_dict(dic["agent"])
            optim = optim.load_state_dict(dic["optim"])
            start_iter = dic["iter"] + 1
    else:
        start_iter = 0

    rng_key = jax.random.PRNGKey(random_seed)
    shuffler = random.Random(random_seed)
    devices = jax.local_devices()
    num_devices = jax.local_device_count()

    def _stack_and_reshape(*xs):
        x = np.stack(xs)
        x = np.reshape(x, (num_devices, -1) + x.shape[1:])
        return x

    for iteration in range(start_iter, num_iterations):
        print(f"Iteration {iteration}")
        rng_key_1, rng_key_2, rng_key_3, rng_key = jax.random.split(rng_key, 4)
        agent = agent.eval()
        data = collect_self_play_data(
            agent,
            env,
            rng_key_1,
            selfplay_batch_size,
            num_self_plays_per_iteration,
            num_simulations_per_move,
        )
        data = list(data)
        shuffler.shuffle(data)
        old_agent = jax.tree_util.tree_map(lambda x: jnp.copy(x), agent)
        agent, losses = agent.train(), []
        agent, optim = jax.device_put_replicated((agent, optim), devices)
        ids = range(0, len(data) - training_batch_size, training_batch_size)
        with click.progressbar(ids, label=" train agent ") as progressbar:
            for idx in progressbar:
                batch = data[idx : (idx + training_batch_size)]
                batch = jax.tree_util.tree_map(_stack_and_reshape, *batch)
                agent, optim, loss = train_step(agent, optim, batch)
                losses.append(loss)

        value_loss, policy_loss = zip(*losses)
        value_loss = np.mean(sum(jax.device_get(value_loss))) / len(value_loss)
        policy_loss = np.mean(sum(jax.device_get(policy_loss))) / len(policy_loss)
        agent, optim = jax.tree_util.tree_map(lambda x: x[0], (agent, optim))
        # evaluate the new agent against the previous one, new agent plays first
        win_count1, draw_count1, loss_count1 = agent_vs_agent_multiple_games(
            agent.eval(), old_agent, env, rng_key_2
        )
        # sides swapped: the old agent plays first; counts are unpacked from the
        # new agent's perspective
        loss_count2, draw_count2, win_count2 = agent_vs_agent_multiple_games(
            old_agent, agent.eval(), env, rng_key_3
        )
        print(
            " evaluation {} win - {} draw - {} loss".format(
                win_count1 + win_count2,
                draw_count1 + draw_count2,
                loss_count1 + loss_count2,
            )
        )
        print(
            f" value loss {value_loss:.3f}"
            f" policy loss {policy_loss:.3f}"
            f" learning rate {optim[1][-1].learning_rate:.1e}"
        )

        # save the agent's weights to disk
        with open(ckpt_filename, "wb") as f:
            dic = {
                "agent": jax.device_get(agent.state_dict()),
                "optim": jax.device_get(optim.state_dict()),
                "iter": iteration,
            }
            pickle.dump(dic, f)
    print("Done!")


if __name__ == "__main__":
    if "COLAB_TPU_ADDR" in os.environ:
        jax.tools.colab_tpu.setup_tpu()
    print("Cores:", jax.local_devices())
    fire.Fire(train)