-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtrain_agent.py
197 lines (171 loc) · 6.51 KB
/
train_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import jax
import sys
import jax.numpy as jnp
from jax.config import config as jax_config
from util import *
from rl.agents import create_agent_train_state, make_train_step, get_agent
from environments.rollout import GymRolloutWrapper
from environments.offline_rollout import OfflineRolloutGenerator
from diffusion.rollout_generator import MixedRolloutGenerator
def make_train(args):
    """Make the agent train function for ``args``.

    The returned ``train(rng)`` initialises the evaluation environment, the
    training batch sampler and the agent train states, then runs
    ``args.num_train_steps`` update steps.

    Args:
        args: Parsed agent arguments. Fields read here include ``env_name``,
            ``num_env_workers``, ``synthetic_experience``,
            ``num_rollout_steps``, ``normalize_obs``, ``num_train_steps``,
            ``eval_rate`` and ``synth_dataset_lifetime``.

    Returns:
        A function ``train(rng) -> (metrics, losses)``:
        - ``metrics`` holds one evaluation dict per step where
          ``step_idx % args.eval_rate == 0``.
        - ``losses`` holds the loss returned by every train step.
    """

    def _init_env(rng):
        """Create the real evaluation env and the training batch sampler."""
        # --- Initialize real environment for evaluation ---
        env = GymRolloutWrapper(args.env_name, num_env_workers=args.num_env_workers)
        if args.synthetic_experience:
            # --- Initialize mixed (synthetic + dataset) sampler ---
            rng, _rng = jax.random.split(rng)
            rollout_gen = MixedRolloutGenerator(
                _rng,
                args,
                env.obs_shape,
                env.action_dim,
                env.action_lims,
                args.num_rollout_steps,
            )
        else:
            # --- Initialize dataset sampler ---
            rollout_gen = OfflineRolloutGenerator(
                args,
                env.obs_shape,
                env.action_dim,
                env.action_lims,
                args.num_rollout_steps,
            )
        return rollout_gen, env

    def _init_agent(rng, rollout_gen):
        """Build the agent (and any auxiliary) networks and their train states.

        Returns:
            ``(train_state, aux_train_states, network, aux_networks,
            eval_apply_fn)``. ``aux_train_states`` is ``None`` when
            ``get_agent`` returns no auxiliary networks; otherwise it is a
            tuple with one train state per auxiliary network.
        """
        # --- Get agent networks ---
        network, aux_networks = get_agent(
            args,
            rollout_gen.action_dim,
            rollout_gen.action_lims,
            obs_stats=rollout_gen.obs_stats if args.normalize_obs else None,
        )
        # get_agent may return separate train/eval networks as a dict.
        if isinstance(network, dict):
            eval_apply_fn = network["eval"].apply
            network = network["train"]
        else:
            eval_apply_fn = network.apply

        # --- Create agent train states ---
        rng, _rng = jax.random.split(rng)
        train_state = create_agent_train_state(
            _rng, network, args, rollout_gen.obs_shape
        )
        if aux_networks is None:
            return train_state, None, network, aux_networks, eval_apply_fn

        # --- Create auxiliary train states ---
        aux_train_states = []
        for net in aux_networks:
            rng, _rng = jax.random.split(rng)
            ts = create_agent_train_state(
                _rng, net, args, rollout_gen.obs_shape, rollout_gen.action_dim
            )
            aux_train_states.append(ts)
        aux_train_states = tuple(aux_train_states)
        return train_state, aux_train_states, network, aux_networks, eval_apply_fn

    def train(rng):
        """Initialise env/agent and run the training loop from key ``rng``."""
        # --- Initialize environment ---
        rng, _rng = jax.random.split(rng)
        rollout_gen, env = _init_env(_rng)

        # --- Initialize agent (policy + value) and auxiliary networks ---
        rng, _rng = jax.random.split(rng)
        (
            train_state,
            aux_train_states,
            network,
            aux_networks,
            eval_apply_fn,
        ) = _init_agent(_rng, rollout_gen)
        # The sampler gets the training apply fn, the evaluation env the eval
        # apply fn (these differ only when get_agent returned a dict).
        rollout_gen.set_apply_fn(jax.jit(train_state.apply_fn))
        env.set_apply_fn(jax.jit(eval_apply_fn))
        _agent_train_step_fn = jax.jit(make_train_step(args, network, aux_networks))

        losses, metrics = [], []
        for step_idx in range(args.num_train_steps):
            # --- Sample batch and update agent ---
            rng, _rng = jax.random.split(rng)
            traj_batch = rollout_gen.batch_rollout(_rng)
            rng, _rng = jax.random.split(rng)
            # Per-step train metrics from the update are not recorded.
            train_state, aux_train_states, loss, _train_metric = _agent_train_step_fn(
                train_state, aux_train_states, traj_batch, _rng
            )
            losses.append(loss)

            # --- Evaluate agent ---
            if step_idx % args.eval_rate == 0:
                rng, _rng = jax.random.split(rng)
                # Pass the fresh subkey. Passing ``rng`` would consume the same
                # key that every later split is derived from.
                eval_traj_batch = eval_agents(
                    _rng,
                    env,
                    train_state,
                    args.num_env_workers,
                )
                info = eval_traj_batch.info
                eval_metric = {
                    "num_updates": train_state.step,
                    "returned_episode_returns": jnp.nanmean(
                        info["returned_episode_returns"]
                    ),
                    "returned_episode_scores": jnp.nanmean(
                        info["returned_episode_scores"]
                    ),
                }
                metrics.append(eval_metric)

            # --- Regenerate synthetic dataset (if not finished) ---
            if (
                args.synthetic_experience
                and step_idx % args.synth_dataset_lifetime == 0
                and step_idx != 0
            ):
                rng, _rng = jax.random.split(rng)
                rollout_gen.update_synthetic_dataset(_rng, train_state.params)
        return metrics, losses

    return train
def train_agents(args):
    """Train an agent via ``make_train`` and optionally log its eval metrics.

    Args:
        args: Parsed agent arguments.
            - ``args.seed`` seeds the PRNG.
            - ``args.log`` enables logging.
            - ``args.eval_rate`` maps each evaluation record back to the
              train-step loss recorded at the same step.
    """
    rng = jax.random.PRNGKey(args.seed)

    # --- Train agent and log metrics ---
    train_fn = make_train(args)
    metric, loss = train_fn(rng)
    if not args.log:
        return

    # --- Compute mean return and score per step ---
    returns = [met["returned_episode_returns"] for met in metric]
    scores = [met["returned_episode_scores"] for met in metric]
    num_updates = [met["num_updates"] for met in metric]

    # --- Subsample steps for logging ---
    num_evals = len(returns)
    if num_evals > MAX_LOG_STEPS:
        # linspace includes its endpoint, so stop at the last valid index
        # (num_evals - 1). Stopping at num_evals would index past the end.
        steps = jnp.linspace(0, num_evals - 1, MAX_LOG_STEPS, dtype=jnp.int32)
    else:
        steps = jnp.arange(num_evals)

    # --- Log step metrics ---
    for step in steps.tolist():
        # metric[i] was recorded at train step i * eval_rate, so the matching
        # loss entry lives at that index of the per-step loss list.
        log(
            {
                "episode_return": returns[step],
                "episode_score": scores[step],
                "step": step,
                "num_updates": num_updates[step],
                **loss[step * args.eval_rate],
            }
        )
def main(cmd_args=None):
    """Parse arguments, configure logging/debug flags and launch training.

    Args:
        cmd_args: Argument list passed to ``parse_agent_args``. Defaults to
            ``sys.argv[1:]``, read when ``main`` is called.

    Returns:
        The return value of ``train_agents(args)``.
    """
    if cmd_args is None:
        # Resolve at call time. A ``sys.argv[1:]`` default would be frozen at
        # import time and miss any later changes to ``sys.argv``.
        cmd_args = sys.argv[1:]

    # --- Parse arguments and initialize logging ---
    args = parse_agent_args(cmd_args)
    if args.log:
        wandb.init(
            config=args,
            project=args.wandb_project,
            entity=args.wandb_team,
            group=args.wandb_group,
            job_type="train_agent",
        )
    if args.debug_nans:
        jax_config.update("jax_debug_nans", True)

    # --- Launch training ---
    if args.debug:
        # Run eagerly so intermediate values can be inspected.
        with jax.disable_jit():
            return train_agents(args)
    return train_agents(args)
if __name__ == "__main__":
    # Script entry point: parse command-line arguments and run training.
    main()