From 352aa6ff495c7cdca02890ec92feb01fd20c2e63 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Thu, 24 Aug 2023 13:26:29 +0000
Subject: [PATCH] refactor PPO, bump version

---
 cleanba/cleanba_impala.py |   8 +--
 cleanba/cleanba_ppo.py    | 135 ++++++++++++++++++++++----------------
 pyproject.toml            |   2 +-
 3 files changed, 84 insertions(+), 61 deletions(-)

diff --git a/cleanba/cleanba_impala.py b/cleanba/cleanba_impala.py
index 1585185..7e4da2b 100644
--- a/cleanba/cleanba_impala.py
+++ b/cleanba/cleanba_impala.py
@@ -719,10 +719,10 @@ def update_minibatch(agent_state, minibatch):
         writer.add_scalar(
             "charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
         )
-        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
-        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
-        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
-        writer.add_scalar("losses/loss", loss.item(), global_step)
+        writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
+        writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
+        writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
+        writer.add_scalar("losses/loss", loss[-1].item(), global_step)

         if learner_policy_version >= args.num_updates:
             break
diff --git a/cleanba/cleanba_ppo.py b/cleanba/cleanba_ppo.py
index 0d588fd..b9b1281 100644
--- a/cleanba/cleanba_ppo.py
+++ b/cleanba/cleanba_ppo.py
@@ -8,6 +8,7 @@
 from dataclasses import dataclass, field
 from types import SimpleNamespace
 from typing import List, NamedTuple, Optional, Sequence, Tuple
+from functools import partial

 import envpool
 import flax
@@ -61,7 +62,7 @@ class Args:
     "total timesteps of the experiments"
     learning_rate: float = 2.5e-4
     "the learning rate of the optimizer"
-    local_num_envs: int = 64
+    local_num_envs: int = 60
     "the number of parallel game environments"
     num_actor_threads: int = 2
     "the number of actor threads to use"
@@ -214,7 +215,8 @@ class Transition(NamedTuple):
     obs: list
     dones: list
     actions: list
-    logitss: list
+    logprobs: list
+    values: list
     env_ids: list
     rewards: list
     truncations: list
@@ -242,7 +244,7 @@ def rollout(
     start_time = time.time()

     @jax.jit
-    def get_action(
+    def get_action_and_value(
         params: flax.core.FrozenDict,
         next_obs: np.ndarray,
         key: jax.random.PRNGKey,
@@ -255,20 +257,22 @@ def get_action(
         key, subkey = jax.random.split(key)
         u = jax.random.uniform(subkey, shape=logits.shape)
         action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
-        return next_obs, action, logits, key
+        logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
+        value = Critic().apply(params.critic_params, hidden)
+        return next_obs, action, logprob, value.squeeze(), key

     # put data in the last index
     episode_returns = np.zeros((args.local_num_envs,), dtype=np.float32)
     returned_episode_returns = np.zeros((args.local_num_envs,), dtype=np.float32)
     episode_lengths = np.zeros((args.local_num_envs,), dtype=np.float32)
     returned_episode_lengths = np.zeros((args.local_num_envs,), dtype=np.float32)
-    envs.async_reset()

     params_queue_get_time = deque(maxlen=10)
     rollout_time = deque(maxlen=10)
     rollout_queue_put_time = deque(maxlen=10)
     actor_policy_version = 0
-    storage = []
+    next_obs = envs.reset()
+    next_done = jnp.zeros(args.local_num_envs, dtype=jax.numpy.bool_)

     @jax.jit
     def prepare_data(storage: List[Transition]) -> Transition:
@@ -281,9 +285,6 @@ def prepare_data(storage: List[Transition]) -> Transition:
         storage_time = 0
         d2h_time = 0
         env_send_time = 0
-        num_steps_with_bootstrap = (
-            args.num_steps + 1 + int(len(storage) == 0)
-        )  # num_steps + 1 to get the states for value bootstrapping.
         # NOTE: `update != 2` is actually IMPORTANT — it allows us to start running policy collection
         # concurrently with the learning process. It also ensures the actor's policy version is only 1 step
         # behind the learner's policy version
@@ -292,8 +293,8 @@ def prepare_data(storage: List[Transition]) -> Transition:
         if update != 2:
             params = params_queue.get()
             # NOTE: block here is important because otherwise this thread will call
-            # the jitted `get_action` function that hangs until the params are ready.
-            # This blocks the `get_action` function in other actor threads.
+            # the jitted `get_action_and_value` function that hangs until the params are ready.
+            # This blocks the `get_action_and_value` function in other actor threads.
             # See https://excalidraw.com/#json=hSooeQL707gE5SWY8wOSS,GeaN1eb2r24PPi75a3n14Q for a visual explanation.
             params.network_params["params"]["Dense_0"][
                 "kernel"
             ].block_until_ready()
@@ -304,22 +305,22 @@ def prepare_data(storage: List[Transition]) -> Transition:
         actor_policy_version += 1
         params_queue_get_time.append(time.time() - params_queue_get_time_start)
         rollout_time_start = time.time()
-        for _ in range(1, num_steps_with_bootstrap):
-            env_recv_time_start = time.time()
-            next_obs, next_reward, next_done, info = envs.recv()
-            env_recv_time += time.time() - env_recv_time_start
+        storage = []
+        for _ in range(0, args.num_steps):
+            cached_next_obs = next_obs
+            cached_next_done = next_done
             global_step += len(next_done) * args.num_actor_threads * len_actor_device_ids * args.world_size
-            env_id = info["env_id"]
             inference_time_start = time.time()
-            next_obs, action, logits, key = get_action(params, next_obs, key)
+            cached_next_obs, action, logprob, value, key = get_action_and_value(params, cached_next_obs, key)
             inference_time += time.time() - inference_time_start

             d2h_time_start = time.time()
             cpu_action = np.array(action)
             d2h_time += time.time() - d2h_time_start

             env_send_time_start = time.time()
-            envs.send(cpu_action, env_id)
+            next_obs, next_reward, next_done, info = envs.step(cpu_action)
+            env_id = info["env_id"]
             env_send_time += time.time() - env_send_time_start

             storage_time_start = time.time()
@@ -328,10 +329,11 @@ def prepare_data(storage: List[Transition]) -> Transition:
             truncated = info["elapsed_step"] >= envs.spec.config.max_episode_steps
             storage.append(
                 Transition(
-                    obs=next_obs,
-                    dones=next_done,
+                    obs=cached_next_obs,
+                    dones=cached_next_done,
                     actions=action,
-                    logitss=logits,
+                    logprobs=logprob,
+                    values=value,
                     env_ids=env_id,
                     rewards=next_reward,
                     truncations=truncated,
@@ -357,11 +359,16 @@ def prepare_data(storage: List[Transition]) -> Transition:
         sharded_storage = Transition(
             *list(map(lambda x: jax.device_put_sharded(x, devices=learner_devices), partitioned_storage))
         )
+        # next_obs, next_done are still in the host
+        sharded_next_obs = jax.device_put_sharded(np.split(next_obs, len(learner_devices)), devices=learner_devices)
+        sharded_next_done = jax.device_put_sharded(np.split(next_done, len(learner_devices)), devices=learner_devices)
         payload = (
             global_step,
             actor_policy_version,
             update,
             sharded_storage,
+            sharded_next_obs,
+            sharded_next_done,
             np.mean(params_queue_get_time),
             device_thread_id,
         )
@@ -369,9 +376,6 @@ def prepare_data(storage: List[Transition]) -> Transition:
         rollout_queue.put(payload)
         rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start)

-        # move bootstrapping step to the beginning of the next update
-        storage = storage[-1:]
-
         if update % args.log_frequency == 0:
             if device_thread_id == 0:
                 print(
@@ -526,16 +530,38 @@ def get_logprob_entropy_value(
         value = Critic().apply(params.critic_params, hidden).squeeze(-1)
         return logprob, entropy, value

+    def compute_gae_once(carry, inp, gamma, gae_lambda):
+        advantages = carry
+        nextdone, nextvalues, curvalues, reward = inp
+        nextnonterminal = 1.0 - nextdone
+
+        delta = reward + gamma * nextvalues * nextnonterminal - curvalues
+        advantages = delta + gamma * gae_lambda * nextnonterminal * advantages
+        return advantages, advantages
+
+    compute_gae_once = partial(compute_gae_once, gamma=args.gamma, gae_lambda=args.gae_lambda)
+
+    @jax.jit
+    def compute_gae(
+        agent_state: TrainState,
+        next_obs: np.ndarray,
+        next_done: np.ndarray,
+        storage: Transition,
+    ):
+        next_value = critic.apply(
+            agent_state.params.critic_params, network.apply(agent_state.params.network_params, next_obs)
+        ).squeeze()
+
+        advantages = jnp.zeros((args.local_num_envs,))
+        dones = jnp.concatenate([storage.dones, next_done[None, :]], axis=0)
+        values = jnp.concatenate([storage.values, next_value[None, :]], axis=0)
+        _, advantages = jax.lax.scan(
+            compute_gae_once, advantages, (dones[1:], values[1:], values[:-1], storage.rewards), reverse=True
+        )
+        return advantages, advantages + storage.values
+
     def ppo_loss(params, obs, actions, behavior_logprobs, firststeps, advantages, target_values):
-        # TODO: figure out when to use `mask`
-        # mask = 1.0 - firststeps
         newlogprob, entropy, newvalue = jax.vmap(get_logprob_entropy_value, in_axes=(None, 0, 0))(params, obs, actions)
-        behavior_logprobs = behavior_logprobs[:-1]
-        newlogprob = newlogprob[:-1]
-        entropy = entropy[:-1]
-        actions = actions[:-1]
-        # mask = mask[:-1]
-
         logratio = newlogprob - behavior_logprobs
         ratio = jnp.exp(logratio)
         approx_kl = ((ratio - 1) - logratio).mean()
@@ -546,7 +572,7 @@ def ppo_loss(params, obs, actions, behavior_logprobs, firststeps, advantages, ta
         pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean()

         # Value loss
-        v_loss = 0.5 * ((newvalue[:-1] - target_values) ** 2).mean()
+        v_loss = 0.5 * ((newvalue - target_values) ** 2).mean()
         entropy_loss = entropy.mean()
         loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
         return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
@@ -555,26 +581,15 @@ def ppo_loss(params, obs, actions, behavior_logprobs, firststeps, advantages, ta
     def single_device_update(
         agent_state: TrainState,
         sharded_storages: List,
+        sharded_next_obs: List,
+        sharded_next_done: List,
         key: jax.random.PRNGKey,
     ):
         storage = jax.tree_map(lambda *x: jnp.hstack(x), *sharded_storages)
+        next_obs = jnp.concatenate(sharded_next_obs)
+        next_done = jnp.concatenate(sharded_next_done)
         ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
-        behavior_logprobs = jax.vmap(lambda logits, action: jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action])(
-            storage.logitss, storage.actions
-        )
-        values = jax.vmap(get_value, in_axes=(None, 0))(agent_state.params, storage.obs)
-        discounts = (1.0 - storage.dones) * args.gamma
-
-        def gae_advantages(rewards: jnp.array, discounts: jnp.array, values: jnp.array) -> Tuple[jnp.ndarray, jnp.array]:
-            advantages = rlax.truncated_generalized_advantage_estimation(rewards, discounts, args.gae_lambda, values)
-            advantages = jax.lax.stop_gradient(advantages)
-            target_values = values[:-1] + advantages
-            target_values = jax.lax.stop_gradient(target_values)
-            return advantages, target_values
-
-        advantages, target_values = jax.vmap(gae_advantages, in_axes=1, out_axes=1)(
-            storage.rewards[:-1], discounts[:-1], values
-        )
+        advantages, target_values = compute_gae(agent_state, next_obs, next_done, storage)
         # NOTE: notable implementation difference: we normalize advantage at the batch level
         if args.norm_adv:
             advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
@@ -604,7 +619,7 @@ def update_minibatch(agent_state, minibatch):
             (
                 jnp.array(jnp.split(storage.obs, args.num_minibatches * args.gradient_accumulation_steps, axis=1)),
                 jnp.array(jnp.split(storage.actions, args.num_minibatches * args.gradient_accumulation_steps, axis=1)),
-                jnp.array(jnp.split(behavior_logprobs, args.num_minibatches * args.gradient_accumulation_steps, axis=1)),
+                jnp.array(jnp.split(storage.logprobs, args.num_minibatches * args.gradient_accumulation_steps, axis=1)),
                 jnp.array(jnp.split(storage.firststeps, args.num_minibatches * args.gradient_accumulation_steps, axis=1)),
                 jnp.array(jnp.split(advantages, args.num_minibatches * args.gradient_accumulation_steps, axis=1)),
                 jnp.array(jnp.split(target_values, args.num_minibatches * args.gradient_accumulation_steps, axis=1)),
@@ -661,6 +676,8 @@ def update_minibatch(agent_state, minibatch):
         learner_policy_version += 1
         rollout_queue_get_time_start = time.time()
         sharded_storages = []
+        sharded_next_obss = []
+        sharded_next_dones = []
         for d_idx, d_id in enumerate(args.actor_device_ids):
             for thread_id in range(args.num_actor_threads):
                 (
@@ -668,15 +685,21 @@ def update_minibatch(agent_state, minibatch):
                     global_step,
                     actor_policy_version,
                     update,
                     sharded_storage,
+                    sharded_next_obs,
+                    sharded_next_done,
                     avg_params_queue_get_time,
                     device_thread_id,
                 ) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
                 sharded_storages.append(sharded_storage)
+                sharded_next_obss.append(sharded_next_obs)
+                sharded_next_dones.append(sharded_next_done)
         rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
         training_time_start = time.time()
         (agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
             agent_state,
             sharded_storages,
+            sharded_next_obss,
+            sharded_next_dones,
             learner_keys,
         )
         unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
@@ -703,11 +726,11 @@ def update_minibatch(agent_state, minibatch):
         writer.add_scalar(
             "charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
         )
-        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
-        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
-        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
-        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
-        writer.add_scalar("losses/loss", loss.item(), global_step)
+        writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
+        writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
+        writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
+        writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
+        writer.add_scalar("losses/loss", loss[-1].item(), global_step)

         if learner_policy_version >= args.num_updates:
             break
diff --git a/pyproject.toml b/pyproject.toml
index ee7b359..5099a14 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "cleanba"
-version = "1.0.0b2"
+version = "1.0.0b3"
 description = ""
 authors = ["Costa Huang "]
 readme = "README.md"
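
Note on the new GAE path: `compute_gae` replaces the previous
`rlax.truncated_generalized_advantage_estimation` call with a reverse `jax.lax.scan`
over the recurrence A_t = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1},
bootstrapping from the critic's value of `next_obs`. A minimal self-contained sketch
of the same recurrence follows; the shapes, gamma/lambda values, and random inputs
are illustrative assumptions, not values taken from the patch.

    # Reverse-scan GAE sketch; T steps and N envs are assumed shapes.
    from functools import partial

    import jax
    import jax.numpy as jnp

    def compute_gae_once(advantages, inp, gamma, gae_lambda):
        nextdone, nextvalues, curvalues, reward = inp
        nextnonterminal = 1.0 - nextdone
        # TD error at step t, then A_t = delta_t + gamma * lambda * A_{t+1}
        delta = reward + gamma * nextvalues * nextnonterminal - curvalues
        advantages = delta + gamma * gae_lambda * nextnonterminal * advantages
        return advantages, advantages

    T, N = 128, 60
    key = jax.random.PRNGKey(0)
    rewards = jax.random.normal(key, (T, N))
    values = jax.random.normal(key, (T + 1, N))  # last row: bootstrap value of next_obs
    dones = jnp.zeros((T + 1, N))                # last row: next_done

    _, advantages = jax.lax.scan(
        partial(compute_gae_once, gamma=0.99, gae_lambda=0.95),
        jnp.zeros(N),  # carry: A_T starts at zero
        (dones[1:], values[1:], values[:-1], rewards),
        reverse=True,
    )
    target_values = advantages + values[:-1]  # regression targets for the value loss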
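
Note on sampling: `get_action_and_value` keeps the Gumbel-max trick from `get_action`
(adding Gumbel noise to the logits and taking the argmax yields an exact categorical
sample) and now also returns the sampled action's log-probability, so the learner no
longer re-derives behavior log-probs from stored logits. A standalone sketch, assuming
unnormalized `logits` of shape (batch, num_actions):

    import jax
    import jax.numpy as jnp

    def sample_action(logits, key):
        key, subkey = jax.random.split(key)
        u = jax.random.uniform(subkey, shape=logits.shape)
        # -log(-log(u)) is standard Gumbel noise
        action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
        logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
        return action, logprob, key

    logits = jnp.array([[0.1, 1.2, -0.3], [2.0, 0.0, 0.5]])
    action, logprob, _ = sample_action(logits, jax.random.PRNGKey(0))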