From 7fe570c27f68f15e303a69de4cf9bc6341218052 Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Sun, 16 Jun 2024 22:48:37 +0100 Subject: [PATCH 01/38] Convert ffm memoroid to flax, it seems to work --- stoix/networks/ffm.py | 143 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 stoix/networks/ffm.py diff --git a/stoix/networks/ffm.py b/stoix/networks/ffm.py new file mode 100644 index 00000000..477d6768 --- /dev/null +++ b/stoix/networks/ffm.py @@ -0,0 +1,143 @@ +from typing import Tuple +import flax.linen as nn +import jax +import jax.numpy as jnp + +class Gate(nn.Module): + output_size: int + + @nn.compact + def __call__(self, x): + x = nn.Dense(self.output_size)(x) + x = nn.sigmoid(x) + return x + +def init_deterministic( + memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 +) -> Tuple[jax.Array, jax.Array]: + a_low = 1e-6 + a_high = 0.5 + a = jnp.linspace(a_low, a_high, memory_size) + b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) + return a, b + +class FFM(nn.Module): + """Feedforward Memory Network.""" + input_size: int + trace_size: int + context_size: int + output_size: int + + def setup(self): + self.pre = nn.Dense(self.trace_size) + self.gate_in = Gate(self.trace_size) + self.gate_out = Gate(self.output_size) + self.skip = nn.Dense(self.output_size) + a, b = init_deterministic(self.trace_size, self.context_size) + self.ffa_params = ( + self.param('ffa_a', lambda rng: a), + self.param('ffa_b', lambda rng: b) + ) + self.mix = nn.Dense(self.output_size) + self.ln = nn.LayerNorm(use_scale=False, use_bias=False) + + def log_gamma(self, t: jax.Array) -> jax.Array: + a, b = self.ffa_params + a = -jnp.abs(a).reshape((1, self.trace_size, 1)) + b = b.reshape(1, 1, self.context_size) + ab = jax.lax.complex(a, b) + return ab * t.reshape(t.shape[0], 1, 1) + + def gamma(self, t: jax.Array) -> jax.Array: + return jnp.exp(self.log_gamma(t)) + + def unwrapped_associative_update( + self, + carry: Tuple[jax.Array, jax.Array, jax.Array], + incoming: Tuple[jax.Array, jax.Array, jax.Array], + ) -> Tuple[jax.Array, jax.Array, jax.Array]: + state, i, = carry + x, j = incoming + state = state * self.gamma(j) + x + return state, j + i + + def wrapped_associative_update(self, carry, incoming): + prev_start, state, i = carry + start, x, j = incoming + # Reset all elements in the carry if we are starting a new episode + state = state * jnp.logical_not(start) + j = j * jnp.logical_not(start) + incoming = x, j + carry = (state, i) + out = self.unwrapped_associative_update(carry, incoming) + start_out = jnp.logical_or(start, prev_start) + return (start_out, *out) + + def scan( + self, + x: jax.Array, + state: jax.Array, + start: jax.Array, + ) -> jax.Array: + """Given an input and recurrent state, this will update the recurrent state. This is equivalent + to the inner-function g in the paper.""" + # x: [T, memory_size] + # memory: [1, memory_size, context_size] + T = x.shape[0] + timestep = jnp.ones(T + 1, dtype=jnp.int32).reshape(-1, 1, 1) + # Add context dim + start = start.reshape(T, 1, 1) + + # Now insert previous recurrent state + x = jnp.concatenate([state, x], axis=0) + start = jnp.concatenate([jnp.zeros_like(start[:1]), start], axis=0) + + # This is not executed during inference -- method will just return x if size is 1 + _, new_state, _ = jax.lax.associative_scan( + self.wrapped_associative_update, + (start, x, timestep), + axis=0, + ) + return new_state[1:] + + def map_to_h(self, x): + gate_in = self.gate_in(x) + pre = self.pre(x) + gated_x = pre * gate_in + scan_input = jnp.repeat(jnp.expand_dims(gated_x, 2), self.context_size, axis=2) + return scan_input + + def map_from_h(self, recurrent_state, x): + z_in = jnp.concatenate([jnp.real(recurrent_state), jnp.imag(recurrent_state)], axis=-1).reshape( + recurrent_state.shape[0], -1 + ) + z = self.mix(z_in) + gate_out = self.gate_out(x) + skip = self.skip(x) + out = self.ln(z * gate_out) + skip * (1 - gate_out) + return out + + def __call__(self, x, recurrent_state, start): + z = self.map_to_h(x) + recurrent_state = self.scan(z, recurrent_state, start) + out = self.map_from_h(recurrent_state, x) + final_state = recurrent_state[-1:] + return out, final_state + + def initial_state(self, shape=tuple()): + return jnp.zeros((*shape, 1, self.trace_size, self.context_size), dtype=jnp.complex64) + + + +if __name__ == "__main__": + m = FFM( + input_size=2, + output_size=4, + trace_size=5, + context_size=6, + ) + s = m.initial_state() + x = jnp.ones((10, 2)) + start = jnp.zeros(10, dtype=bool) + params = m.init(jax.random.PRNGKey(0), x, s, start) + out = m.apply(params, x, s, start) From 229e7f1322da800253ab47403ba124062b9fcdf2 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Fri, 21 Jun 2024 17:03:49 +0000 Subject: [PATCH 02/38] temporary work --- stoix/networks/base.py | 74 +++ stoix/networks/ffm.py | 44 +- stoix/systems/ppo/rec_ppo_temp_ffm.py | 763 ++++++++++++++++++++++++++ 3 files changed, 871 insertions(+), 10 deletions(-) create mode 100644 stoix/systems/ppo/rec_ppo_temp_ffm.py diff --git a/stoix/networks/base.py b/stoix/networks/base.py index 8a8afec3..1bc190d8 100644 --- a/stoix/networks/base.py +++ b/stoix/networks/base.py @@ -9,6 +9,7 @@ from flax import linen as nn from stoix.base_types import Observation, RNNObservation +from stoix.networks.ffm import FFM from stoix.networks.inputs import ObservationInput from stoix.networks.utils import parse_rnn_cell @@ -182,3 +183,76 @@ def __call__( critic_output = self.critic_head(critic_output) return critic_hidden_state, critic_output + + +class RecurrentActorFFM(nn.Module): + """Recurrent Actor Network.""" + + action_head: nn.Module + post_torso: nn.Module + hidden_state_dim: int + cell_type: str + pre_torso: nn.Module + input_layer: nn.Module = ObservationInput() + + @nn.compact + def __call__( + self, + policy_hidden_state: chex.Array, + observation_done: RNNObservation, + ) -> Tuple[chex.Array, distrax.DistributionLike]: + + observation, done = observation_done + + observation = self.input_layer(observation) + policy_embedding = self.pre_torso(observation) + policy_rnn_input = (policy_embedding, done) + BatchFFM = nn.vmap( + FFM, + in_axes=1, out_axes=1, + variable_axes={'params': None}, + split_rngs={'params': False}) + policy_hidden_state, policy_embedding = BatchFFM(self.hidden_state_dim, self.hidden_state_dim, self.hidden_state_dim)( + policy_hidden_state, policy_rnn_input + ) + actor_logits = self.post_torso(policy_embedding) + pi = self.action_head(actor_logits) + + return policy_hidden_state, pi + + +class RecurrentCriticFFM(nn.Module): + """Recurrent Critic Network.""" + + critic_head: nn.Module + post_torso: nn.Module + hidden_state_dim: int + cell_type: str + pre_torso: nn.Module + input_layer: nn.Module = ObservationInput() + + @nn.compact + def __call__( + self, + critic_hidden_state: Tuple[chex.Array, chex.Array], + observation_done: RNNObservation, + ) -> Tuple[chex.Array, chex.Array]: + + observation, done = observation_done + + observation = self.input_layer(observation) + + critic_embedding = self.pre_torso(observation) + critic_rnn_input = (critic_embedding, done) + BatchFFM = nn.vmap( + FFM, + in_axes=1, out_axes=1, + variable_axes={'params': None}, + split_rngs={'params': False}) + critic_hidden_state, critic_embedding = BatchFFM(self.hidden_state_dim, self.hidden_state_dim, self.hidden_state_dim)( + critic_hidden_state, critic_rnn_input + ) + critic_output = self.post_torso(critic_embedding) + critic_output = self.critic_head(critic_output) + + return critic_hidden_state, critic_output \ No newline at end of file diff --git a/stoix/networks/ffm.py b/stoix/networks/ffm.py index 477d6768..2404b777 100644 --- a/stoix/networks/ffm.py +++ b/stoix/networks/ffm.py @@ -23,7 +23,6 @@ def init_deterministic( class FFM(nn.Module): """Feedforward Memory Network.""" - input_size: int trace_size: int context_size: int output_size: int @@ -117,27 +116,52 @@ def map_from_h(self, recurrent_state, x): out = self.ln(z * gate_out) + skip * (1 - gate_out) return out - def __call__(self, x, recurrent_state, start): + def __call__(self, recurrent_state, inputs): + x, resets = inputs z = self.map_to_h(x) - recurrent_state = self.scan(z, recurrent_state, start) + recurrent_state = self.scan(z, recurrent_state, resets) out = self.map_from_h(recurrent_state, x) final_state = recurrent_state[-1:] - return out, final_state + return final_state, out - def initial_state(self, shape=tuple()): - return jnp.zeros((*shape, 1, self.trace_size, self.context_size), dtype=jnp.complex64) + def initialize_carry(self, batch_size: int = None): + if batch_size is None: + return jnp.zeros((1, self.trace_size, self.context_size), dtype=jnp.complex64) + + return jnp.zeros((1, batch_size, self.trace_size, self.context_size), dtype=jnp.complex64) + if __name__ == "__main__": m = FFM( - input_size=2, output_size=4, trace_size=5, context_size=6, ) - s = m.initial_state() + s = m.initialize_carry() x = jnp.ones((10, 2)) start = jnp.zeros(10, dtype=bool) - params = m.init(jax.random.PRNGKey(0), x, s, start) - out = m.apply(params, x, s, start) + params = m.init(jax.random.PRNGKey(0), s, (x, start)) + out_state, out = m.apply(params, s, (x, start)) + + + BatchFFM = nn.vmap( + FFM, + in_axes=1, out_axes=1, + variable_axes={'params': None}, + split_rngs={'params': False}) + + m = BatchFFM( + output_size=4, + trace_size=5, + context_size=6, + ) + + s = m.initialize_carry(8) + x = jnp.ones((10, 8, 2)) + start = jnp.zeros((10, 8), dtype=bool) + params = m.init(jax.random.PRNGKey(0), s, (x, start)) + out_state, out = m.apply(params, s, (x, start)) + + diff --git a/stoix/systems/ppo/rec_ppo_temp_ffm.py b/stoix/systems/ppo/rec_ppo_temp_ffm.py new file mode 100644 index 00000000..29513eb6 --- /dev/null +++ b/stoix/systems/ppo/rec_ppo_temp_ffm.py @@ -0,0 +1,763 @@ +import copy +import time +from typing import Any, Dict, Tuple + +import chex +import flax +import hydra +import jax +import jax.numpy as jnp +import optax +from colorama import Fore, Style +from flax.core.frozen_dict import FrozenDict +from jumanji.env import Environment +from omegaconf import DictConfig, OmegaConf +from rich.pretty import pprint + +from stoix.base_types import ( + ActorCriticOptStates, + ActorCriticParams, + ExperimentOutput, + LearnerFn, + RecActorApply, + RecCriticApply, + RNNLearnerState, +) +from stoix.evaluator import evaluator_setup, get_rec_distribution_act_fn +from stoix.networks.base import RecurrentActor, RecurrentActorFFM, RecurrentCritic, RecurrentCriticFFM, ScannedRNN +from stoix.networks.ffm import FFM +from stoix.systems.ppo.ppo_types import ActorCriticHiddenStates, RNNPPOTransition +from stoix.utils import make_env as environments +from stoix.utils.checkpointing import Checkpointer +from stoix.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims +from stoix.utils.logger import LogEvent, StoixLogger +from stoix.utils.loss import clipped_value_loss, ppo_clip_loss +from stoix.utils.multistep import batch_truncated_generalized_advantage_estimation +from stoix.utils.total_timestep_checker import check_total_timesteps +from stoix.utils.training import make_learning_rate +from stoix.wrappers.episode_metrics import get_final_step_metrics + + +def get_learner_fn( + env: Environment, + apply_fns: Tuple[RecActorApply, RecCriticApply], + update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], + config: DictConfig, +) -> LearnerFn[RNNLearnerState]: + """Get the learner function.""" + + actor_apply_fn, critic_apply_fn = apply_fns + actor_update_fn, critic_update_fn = update_fns + + def _update_step(learner_state: RNNLearnerState, _: Any) -> Tuple[RNNLearnerState, Tuple]: + """A single update of the network. + + This function steps the environment and records the trajectory batch for + training. It then calculates advantages and targets based on the recorded + trajectory and updates the actor and critic networks based on the calculated + losses. + + Args: + learner_state (NamedTuple): + - params (ActorCriticParams): The current model parameters. + - opt_states (OptStates): The current optimizer states. + - key (PRNGKey): The random number generator state. + - env_state (State): The environment state. + - last_timestep (TimeStep): The last timestep in the current trajectory. + - dones (bool): Whether the last timestep was a terminal state. + - hstates (ActorCriticHiddenStates): The current hidden states of the RNN. + _ (Any): The current metrics info. + """ + + def _env_step( + learner_state: RNNLearnerState, _: Any + ) -> Tuple[RNNLearnerState, RNNPPOTransition]: + """Step the environment.""" + ( + params, + opt_states, + key, + env_state, + last_timestep, + last_done, + last_truncated, + hstates, + ) = learner_state + + key, policy_key = jax.random.split(key) + + # Add a batch dimension to the observation. + batched_observation = jax.tree_util.tree_map( + lambda x: x[jnp.newaxis, :], last_timestep.observation + ) + ac_in = ( + batched_observation, + last_done[jnp.newaxis, :], + ) + + jax.debug.print("ac_in {x} {y}", x=ac_in[0].agent_view.shape, y=ac_in[1].shape) + + # Run the network. + policy_hidden_state, actor_policy = actor_apply_fn( + params.actor_params, hstates.policy_hidden_state, ac_in + ) + critic_hidden_state, value = critic_apply_fn( + params.critic_params, hstates.critic_hidden_state, ac_in + ) + + # Sample action from the policy and squeeze out the batch dimension. + action = actor_policy.sample(seed=policy_key) + log_prob = actor_policy.log_prob(action) + value, action, log_prob = ( + value.squeeze(0), + action.squeeze(0), + log_prob.squeeze(0), + ) + + # Step the environment. + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) + + # log episode return and length + done = (timestep.discount == 0.0).reshape(-1) + truncated = (timestep.last() & (timestep.discount != 0.0)).reshape(-1) + info = timestep.extras["episode_metrics"] + + hstates = ActorCriticHiddenStates(policy_hidden_state, critic_hidden_state) + transition = RNNPPOTransition( + last_done, + last_truncated, + action, + value, + timestep.reward, + log_prob, + last_timestep.observation, + hstates, + info, + ) + learner_state = RNNLearnerState( + params, + opt_states, + key, + env_state, + timestep, + done, + truncated, + hstates, + ) + return learner_state, transition + + # INITIALISE RNN STATE + initial_hstates = learner_state.hstates + + # STEP ENVIRONMENT FOR ROLLOUT LENGTH + learner_state, traj_batch = jax.lax.scan( + _env_step, learner_state, None, config.system.rollout_length + ) + + # CALCULATE ADVANTAGE + ( + params, + opt_states, + key, + env_state, + last_timestep, + last_done, + last_truncated, + hstates, + ) = learner_state + + # Add a batch dimension to the observation. + batched_last_observation = jax.tree_util.tree_map( + lambda x: x[jnp.newaxis, :], last_timestep.observation + ) + ac_in = ( + batched_last_observation, + last_done[jnp.newaxis, :], + ) + + # Run the network. + _, last_val = critic_apply_fn(params.critic_params, hstates.critic_hidden_state, ac_in) + # Squeeze out the batch dimension and mask out the value of terminal states. + last_val = last_val.squeeze(0) + last_val = jnp.where(last_done, jnp.zeros_like(last_val), last_val) + + r_t = traj_batch.reward + v_t = jnp.concatenate([traj_batch.value, last_val[None, ...]], axis=0) + d_t = 1.0 - traj_batch.done.astype(jnp.float32) + d_t = (d_t * config.system.gamma).astype(jnp.float32) + advantages, targets = batch_truncated_generalized_advantage_estimation( + r_t, + d_t, + config.system.gae_lambda, + v_t, + time_major=True, + standardize_advantages=config.system.standardize_advantages, + truncation_flags=traj_batch.truncated, + ) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: + """Update the network for a single minibatch.""" + + params, opt_states = train_state + ( + traj_batch, + advantages, + targets, + ) = batch_info + + def _actor_loss_fn( + actor_params: FrozenDict, + traj_batch: RNNPPOTransition, + gae: chex.Array, + ) -> Tuple: + """Calculate the actor loss.""" + # RERUN NETWORK + + obs_and_done = (traj_batch.obs, traj_batch.done) + policy_hidden_state = jax.tree_util.tree_map( + lambda x: x[0], traj_batch.hstates.policy_hidden_state + ) + _, actor_policy = actor_apply_fn( + actor_params, policy_hidden_state, obs_and_done + ) + log_prob = actor_policy.log_prob(traj_batch.action) + + loss_actor = ppo_clip_loss( + log_prob, traj_batch.log_prob, gae, config.system.clip_eps + ) + entropy = actor_policy.entropy().mean() + + total_loss = loss_actor - config.system.ent_coef * entropy + loss_info = { + "actor_loss": loss_actor, + "entropy": entropy, + } + return total_loss, loss_info + + def _critic_loss_fn( + critic_params: FrozenDict, + traj_batch: RNNPPOTransition, + targets: chex.Array, + ) -> Tuple: + """Calculate the critic loss.""" + # RERUN NETWORK + obs_and_done = (traj_batch.obs, traj_batch.done) + critic_hidden_state = jax.tree_util.tree_map( + lambda x: x[0], traj_batch.hstates.critic_hidden_state + ) + _, value = critic_apply_fn(critic_params, critic_hidden_state, obs_and_done) + + # CALCULATE VALUE LOSS + value_loss = clipped_value_loss( + value, traj_batch.value, targets, config.system.clip_eps + ) + + total_loss = config.system.vf_coef * value_loss + loss_info = { + "value_loss": value_loss, + } + return total_loss, loss_info + + # CALCULATE ACTOR LOSS + actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True) + actor_grads, actor_loss_info = actor_grad_fn( + params.actor_params, traj_batch, advantages + ) + + # CALCULATE CRITIC LOSS + critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True) + critic_grads, critic_loss_info = critic_grad_fn( + params.critic_params, traj_batch, targets + ) + + # Compute the parallel mean (pmean) over the batch. + # This calculation is inspired by the Anakin architecture demo notebook. + # available at https://tinyurl.com/26tdzs5x + # This pmean could be a regular mean as the batch axis is on the same device. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), axis_name="batch" + ) + # pmean over devices. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), axis_name="device" + ) + + critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="batch" + ) + # pmean over devices. + critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="device" + ) + + # UPDATE ACTOR PARAMS AND OPTIMISER STATE + actor_updates, actor_new_opt_state = actor_update_fn( + actor_grads, opt_states.actor_opt_state + ) + actor_new_params = optax.apply_updates(params.actor_params, actor_updates) + + # UPDATE CRITIC PARAMS AND OPTIMISER STATE + critic_updates, critic_new_opt_state = critic_update_fn( + critic_grads, opt_states.critic_opt_state + ) + critic_new_params = optax.apply_updates(params.critic_params, critic_updates) + + new_params = ActorCriticParams(actor_new_params, critic_new_params) + new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state) + + # PACK LOSS INFO + loss_info = { + **actor_loss_info, + **critic_loss_info, + } + + return (new_params, new_opt_state), loss_info + + ( + params, + opt_states, + init_hstates, + traj_batch, + advantages, + targets, + key, + ) = update_state + key, shuffle_key = jax.random.split(key) + + # SHUFFLE MINIBATCHES + batch = (traj_batch, advantages, targets) + num_recurrent_chunks = ( + config.system.rollout_length // config.system.recurrent_chunk_size + ) + batch = jax.tree_util.tree_map( + lambda x: x.reshape( + config.system.recurrent_chunk_size, + config.arch.num_envs * num_recurrent_chunks, + *x.shape[2:], + ), + batch, + ) + permutation = jax.random.permutation( + shuffle_key, config.arch.num_envs * num_recurrent_chunks + ) + shuffled_batch = jax.tree_util.tree_map( + lambda x: jnp.take(x, permutation, axis=1), batch + ) + reshaped_batch = jax.tree_util.tree_map( + lambda x: jnp.reshape( + x, (x.shape[0], config.system.num_minibatches, -1, *x.shape[2:]) + ), + shuffled_batch, + ) + minibatches = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 0), reshaped_batch) + + # UPDATE MINIBATCHES + (params, opt_states), loss_info = jax.lax.scan( + _update_minibatch, (params, opt_states), minibatches + ) + + update_state = ( + params, + opt_states, + init_hstates, + traj_batch, + advantages, + targets, + key, + ) + return update_state, loss_info + + init_hstates = jax.tree_util.tree_map(lambda x: x[None, :], initial_hstates) + update_state = ( + params, + opt_states, + init_hstates, + traj_batch, + advantages, + targets, + key, + ) + + # UPDATE EPOCHS + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.epochs + ) + + params, opt_states, _, traj_batch, advantages, targets, key = update_state + learner_state = RNNLearnerState( + params, + opt_states, + key, + env_state, + last_timestep, + last_done, + last_truncated, + hstates, + ) + metric = traj_batch.info + return learner_state, (metric, loss_info) + + def learner_fn(learner_state: RNNLearnerState) -> ExperimentOutput[RNNLearnerState]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized over a batch of inputs. + + Args: + learner_state (NamedTuple): + - params (ActorCriticParams): The initial model parameters. + - opt_states (OptStates): The initial optimizer states. + - key (chex.PRNGKey): The random number generator state. + - env_state (LogEnvState): The environment state. + - timesteps (TimeStep): The initial timestep in the initial trajectory. + - dones (bool): Whether the initial timestep was a terminal state. + - hstateS (ActorCriticHiddenStates): The initial hidden states of the RNN. + """ + + batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") + + learner_state, (episode_info, loss_info) = jax.lax.scan( + batched_update_step, learner_state, None, config.arch.num_updates_per_eval + ) + return ExperimentOutput( + learner_state=learner_state, + episode_metrics=episode_info, + train_metrics=loss_info, + ) + + return learner_fn + + +def learner_setup( + env: Environment, keys: chex.Array, config: DictConfig +) -> Tuple[LearnerFn[RNNLearnerState], RecurrentActor, ScannedRNN, RNNLearnerState]: + """Initialise learner_fn, network, optimiser, environment and states.""" + # Get available TPU cores. + n_devices = len(jax.devices()) + + # Get number/dimension of actions. + num_actions = int(env.action_spec().num_values) + config.system.action_dim = num_actions + + # PRNG keys. + key, actor_net_key, critic_net_key = keys + + # Define network and optimisers. + actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) + actor_action_head = hydra.utils.instantiate( + config.network.actor_network.action_head, action_dim=num_actions + ) + critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) + critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) + critic_head = hydra.utils.instantiate(config.network.critic_network.critic_head) + + actor_network = RecurrentActorFFM( + pre_torso=actor_pre_torso, + hidden_state_dim=config.network.critic_network.rnn_layer.hidden_state_dim, + cell_type=config.network.critic_network.rnn_layer.cell_type, + post_torso=actor_post_torso, + action_head=actor_action_head, + ) + critic_network = RecurrentCriticFFM( + pre_torso=critic_pre_torso, + hidden_state_dim=config.network.critic_network.rnn_layer.hidden_state_dim, + cell_type=config.network.critic_network.rnn_layer.cell_type, + post_torso=critic_post_torso, + critic_head=critic_head, + ) + actor_rnn = FFM( + config.network.actor_network.rnn_layer.hidden_state_dim, + config.network.actor_network.rnn_layer.hidden_state_dim, + config.network.actor_network.rnn_layer.hidden_state_dim, + ) + critic_rnn = FFM( + config.network.critic_network.rnn_layer.hidden_state_dim, + config.network.critic_network.rnn_layer.hidden_state_dim, + config.network.critic_network.rnn_layer.hidden_state_dim + ) + + actor_lr = make_learning_rate( + config.system.actor_lr, config, config.system.epochs, config.system.num_minibatches + ) + critic_lr = make_learning_rate( + config.system.critic_lr, config, config.system.epochs, config.system.num_minibatches + ) + + actor_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(actor_lr, eps=1e-5), + ) + critic_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(critic_lr, eps=1e-5), + ) + + # Initialise observation + init_obs = env.observation_spec().generate_value() + init_obs = jax.tree_util.tree_map( + lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0), + init_obs, + ) + init_obs = jax.tree_util.tree_map(lambda x: x[None, ...], init_obs) + init_done = jnp.zeros((1, config.arch.num_envs), dtype=bool) + init_x = (init_obs, init_done) + + # Initialise hidden states. + init_policy_hstate = actor_rnn.initialize_carry(config.arch.num_envs) + init_critic_hstate = critic_rnn.initialize_carry(config.arch.num_envs) + + # initialise params and optimiser state. + actor_params = actor_network.init(actor_net_key, init_policy_hstate, init_x) + actor_opt_state = actor_optim.init(actor_params) + critic_params = critic_network.init(critic_net_key, init_critic_hstate, init_x) + critic_opt_state = critic_optim.init(critic_params) + + actor_network_apply_fn = jax.vmap(actor_network.apply, in_axes=(None, 1, 1)) + critic_network_apply_fn = jax.vmap(critic_network.apply, in_axes=(None, 1, 1)) + + # Get network apply functions and optimiser updates. + apply_fns = (actor_network_apply_fn, critic_network_apply_fn) + update_fns = (actor_optim.update, critic_optim.update) + + # Get batched iterated update and replicate it to pmap it over cores. + learn = get_learner_fn(env, apply_fns, update_fns, config) + learn = jax.pmap(learn, axis_name="device") + + # Pack params and initial states. + params = ActorCriticParams(actor_params, critic_params) + hstates = ActorCriticHiddenStates(init_policy_hstate, init_critic_hstate) + + # Load model from checkpoint if specified. + if config.logger.checkpointing.load_model: + loaded_checkpoint = Checkpointer( + model_name=config.system.system_name, + **config.logger.checkpointing.load_args, # Other checkpoint args + ) + # Restore the learner state from the checkpoint + restored_params, restored_hstates = loaded_checkpoint.restore_params(restore_hstates=True) + # Update the params and hstates + params = restored_params + hstates = restored_hstates if restored_hstates else hstates + + # Initialise environment states and timesteps: across devices and batches. + key, *env_keys = jax.random.split( + key, n_devices * config.arch.update_batch_size * config.arch.num_envs + 1 + ) + env_states, timesteps = jax.vmap(env.reset, in_axes=(0))( + jnp.stack(env_keys), + ) + reshape_states = lambda x: x.reshape( + (n_devices, config.arch.update_batch_size, config.arch.num_envs) + x.shape[1:] + ) + # (devices, update batch size, num_envs, ...) + env_states = jax.tree_util.tree_map(reshape_states, env_states) + timesteps = jax.tree_util.tree_map(reshape_states, timesteps) + + # Define params to be replicated across devices and batches. + dones = jnp.zeros( + (config.arch.num_envs,), + dtype=bool, + ) + truncated = jnp.zeros( + (config.arch.num_envs,), + dtype=bool, + ) + key, step_key = jax.random.split(key) + step_keys = jax.random.split(step_key, n_devices * config.arch.update_batch_size) + reshape_keys = lambda x: x.reshape((n_devices, config.arch.update_batch_size) + x.shape[1:]) + step_keys = reshape_keys(jnp.stack(step_keys)) + opt_states = ActorCriticOptStates(actor_opt_state, critic_opt_state) + replicate_learner = (params, opt_states, hstates, dones, truncated) + + # Duplicate learner for update_batch_size. + broadcast = lambda x: jnp.broadcast_to(x, (config.arch.update_batch_size,) + x.shape) + replicate_learner = jax.tree_util.tree_map(broadcast, replicate_learner) + + # Duplicate learner across devices. + replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices()) + + # Initialise learner state. + params, opt_states, hstates, dones, truncated = replicate_learner + init_learner_state = RNNLearnerState( + params=params, + opt_states=opt_states, + key=step_keys, + env_state=env_states, + timestep=timesteps, + done=dones, + truncated=truncated, + hstates=hstates, + ) + return learn, actor_network_apply_fn, actor_rnn, init_learner_state + + +def run_experiment(_config: DictConfig) -> float: + """Runs experiment.""" + config = copy.deepcopy(_config) + + # Calculate total timesteps. + n_devices = len(jax.devices()) + config.num_devices = n_devices + config = check_total_timesteps(config) + assert ( + config.arch.num_updates > config.arch.num_evaluation + ), "Number of updates per evaluation must be less than total number of updates." + + # Set recurrent chunk size. + if config.system.recurrent_chunk_size is None: + config.system.recurrent_chunk_size = config.system.rollout_length + else: + assert ( + config.system.rollout_length % config.system.recurrent_chunk_size == 0 + ), "Rollout length must be divisible by recurrent chunk size." + + # Create the environments for train and eval. + env, eval_env = environments.make(config) + + # PRNG keys. + key, key_e, actor_net_key, critic_net_key = jax.random.split( + jax.random.PRNGKey(config.arch.seed), num=4 + ) + + # Setup learner. + learn, actor_network_apply_fn, actor_rnn, learner_state = learner_setup( + env, (key, actor_net_key, critic_net_key), config + ) + + # Setup evaluator. + evaluator, absolute_metric_evaluator, (trained_params, eval_keys) = evaluator_setup( + eval_env=eval_env, + key_e=key_e, + eval_act_fn=get_rec_distribution_act_fn(config, actor_network_apply_fn), + params=learner_state.params.actor_params, + config=config, + use_recurrent_net=True, + scanned_rnn=actor_rnn, + ) + + # Calculate number of updates per evaluation. + config.arch.num_updates_per_eval = config.arch.num_updates // config.arch.num_evaluation + steps_per_rollout = ( + n_devices + * config.arch.num_updates_per_eval + * config.system.rollout_length + * config.arch.update_batch_size + * config.arch.num_envs + ) + + # Logger setup + logger = StoixLogger(config) + cfg: Dict = OmegaConf.to_container(config, resolve=True) + cfg["arch"]["devices"] = jax.devices() + pprint(cfg) + + # Set up checkpointer + save_checkpoint = config.logger.checkpointing.save_model + if save_checkpoint: + checkpointer = Checkpointer( + metadata=config, # Save all config as metadata in the checkpoint + model_name=config.system.system_name, + **config.logger.checkpointing.save_args, # Checkpoint args + ) + + # Run experiment for a total number of evaluations. + max_episode_return = jnp.float32(-1e7) + best_params = None + for eval_step in range(config.arch.num_evaluation): + # Train. + start_time = time.time() + learner_output = learn(learner_state) + jax.block_until_ready(learner_output) + + # Log the results of the training. + elapsed_time = time.time() - start_time + t = int(steps_per_rollout * (eval_step + 1)) + episode_metrics, ep_completed = get_final_step_metrics(learner_output.episode_metrics) + episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + + # Separately log timesteps, actoring metrics and training metrics. + logger.log({"timestep": t}, t, eval_step, LogEvent.MISC) + if ep_completed: # only log episode metrics if an episode was completed in the rollout. + logger.log(episode_metrics, t, eval_step, LogEvent.ACT) + logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) + + # Prepare for evaluation. + start_time = time.time() + trained_params = unreplicate_batch_dim(learner_output.learner_state.params.actor_params) + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + # Evaluate. + evaluator_output = evaluator(trained_params, eval_keys) + jax.block_until_ready(evaluator_output) + + # Log the results of the evaluation. + elapsed_time = time.time() - start_time + episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"]) + + steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) + evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL) + + if save_checkpoint: + # Save checkpoint of learner state + checkpointer.save( + timestep=int(steps_per_rollout * (eval_step + 1)), + unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), + episode_return=episode_return, + ) + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params = copy.deepcopy(trained_params) + max_episode_return = episode_return + + # Update runner state to continue training. + learner_state = learner_output.learner_state + + # Measure absolute metric. + if config.arch.absolute_metric: + start_time = time.time() + + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + evaluator_output = absolute_metric_evaluator(best_params, eval_keys) + jax.block_until_ready(evaluator_output) + + elapsed_time = time.time() - start_time + + t = int(steps_per_rollout * (eval_step + 1)) + steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) + evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE) + + # Stop the logger. + logger.stop() + # Record the performance for the final evaluation run. If the absolute metric is not + # calculated, this will be the final evaluation run. + eval_performance = float(jnp.mean(evaluator_output.episode_metrics[config.env.eval_metric])) + return eval_performance + + +@hydra.main(config_path="../../configs", config_name="default_rec_ppo.yaml", version_base="1.2") +def hydra_entry_point(cfg: DictConfig) -> float: + """Experiment entry point.""" + # Allow dynamic attributes. + OmegaConf.set_struct(cfg, False) + + # Run experiment. + eval_performance = run_experiment(cfg) + + print(f"{Fore.CYAN}{Style.BRIGHT}Recurrent PPO experiment completed{Style.RESET_ALL}") + return eval_performance + + +if __name__ == "__main__": + hydra_entry_point() From c2f3e9ad4e59a3e83c961acca7f1039e0bf3e434 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Fri, 21 Jun 2024 19:05:21 +0000 Subject: [PATCH 03/38] temporary work - hacky solution to get running --- stoix/networks/base.py | 22 ++++---- stoix/networks/ffm.py | 53 ++++++++++---------- stoix/systems/ppo/rec_ppo_temp_ffm.py | 72 +++++++++++++++------------ 3 files changed, 76 insertions(+), 71 deletions(-) diff --git a/stoix/networks/base.py b/stoix/networks/base.py index 1bc190d8..c56fbf53 100644 --- a/stoix/networks/base.py +++ b/stoix/networks/base.py @@ -208,13 +208,11 @@ def __call__( policy_embedding = self.pre_torso(observation) policy_rnn_input = (policy_embedding, done) BatchFFM = nn.vmap( - FFM, - in_axes=1, out_axes=1, - variable_axes={'params': None}, - split_rngs={'params': False}) - policy_hidden_state, policy_embedding = BatchFFM(self.hidden_state_dim, self.hidden_state_dim, self.hidden_state_dim)( - policy_hidden_state, policy_rnn_input + FFM, in_axes=1, out_axes=1, variable_axes={"params": None}, split_rngs={"params": False} ) + policy_hidden_state, policy_embedding = BatchFFM( + self.hidden_state_dim, self.hidden_state_dim, self.hidden_state_dim + )(policy_hidden_state, policy_rnn_input) actor_logits = self.post_torso(policy_embedding) pi = self.action_head(actor_logits) @@ -245,14 +243,12 @@ def __call__( critic_embedding = self.pre_torso(observation) critic_rnn_input = (critic_embedding, done) BatchFFM = nn.vmap( - FFM, - in_axes=1, out_axes=1, - variable_axes={'params': None}, - split_rngs={'params': False}) - critic_hidden_state, critic_embedding = BatchFFM(self.hidden_state_dim, self.hidden_state_dim, self.hidden_state_dim)( - critic_hidden_state, critic_rnn_input + FFM, in_axes=1, out_axes=1, variable_axes={"params": None}, split_rngs={"params": False} ) + critic_hidden_state, critic_embedding = BatchFFM( + self.hidden_state_dim, self.hidden_state_dim, self.hidden_state_dim + )(critic_hidden_state, critic_rnn_input) critic_output = self.post_torso(critic_embedding) critic_output = self.critic_head(critic_output) - return critic_hidden_state, critic_output \ No newline at end of file + return critic_hidden_state, critic_output diff --git a/stoix/networks/ffm.py b/stoix/networks/ffm.py index 2404b777..0e480aff 100644 --- a/stoix/networks/ffm.py +++ b/stoix/networks/ffm.py @@ -1,28 +1,33 @@ from typing import Tuple + import flax.linen as nn import jax import jax.numpy as jnp + class Gate(nn.Module): - output_size: int + output_size: int - @nn.compact + @nn.compact def __call__(self, x): x = nn.Dense(self.output_size)(x) x = nn.sigmoid(x) return x + def init_deterministic( - memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 + memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 ) -> Tuple[jax.Array, jax.Array]: a_low = 1e-6 - a_high = 0.5 + a_high = 0.5 a = jnp.linspace(a_low, a_high, memory_size) b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) return a, b + class FFM(nn.Module): """Feedforward Memory Network.""" + trace_size: int context_size: int output_size: int @@ -33,10 +38,7 @@ def setup(self): self.gate_out = Gate(self.output_size) self.skip = nn.Dense(self.output_size) a, b = init_deterministic(self.trace_size, self.context_size) - self.ffa_params = ( - self.param('ffa_a', lambda rng: a), - self.param('ffa_b', lambda rng: b) - ) + self.ffa_params = (self.param("ffa_a", lambda rng: a), self.param("ffa_b", lambda rng: b)) self.mix = nn.Dense(self.output_size) self.ln = nn.LayerNorm(use_scale=False, use_bias=False) @@ -55,16 +57,19 @@ def unwrapped_associative_update( carry: Tuple[jax.Array, jax.Array, jax.Array], incoming: Tuple[jax.Array, jax.Array, jax.Array], ) -> Tuple[jax.Array, jax.Array, jax.Array]: - state, i, = carry + ( + state, + i, + ) = carry x, j = incoming state = state * self.gamma(j) + x return state, j + i - def wrapped_associative_update(self, carry, incoming): + def wrapped_associative_update(self, carry, incoming): prev_start, state, i = carry start, x, j = incoming # Reset all elements in the carry if we are starting a new episode - state = state * jnp.logical_not(start) + state = state * jnp.logical_not(start) j = j * jnp.logical_not(start) incoming = x, j carry = (state, i) @@ -107,9 +112,9 @@ def map_to_h(self, x): return scan_input def map_from_h(self, recurrent_state, x): - z_in = jnp.concatenate([jnp.real(recurrent_state), jnp.imag(recurrent_state)], axis=-1).reshape( - recurrent_state.shape[0], -1 - ) + z_in = jnp.concatenate( + [jnp.real(recurrent_state), jnp.imag(recurrent_state)], axis=-1 + ).reshape(recurrent_state.shape[0], -1) z = self.mix(z_in) gate_out = self.gate_out(x) skip = self.skip(x) @@ -127,12 +132,10 @@ def __call__(self, recurrent_state, inputs): def initialize_carry(self, batch_size: int = None): if batch_size is None: return jnp.zeros((1, self.trace_size, self.context_size), dtype=jnp.complex64) - + return jnp.zeros((1, batch_size, self.trace_size, self.context_size), dtype=jnp.complex64) - - if __name__ == "__main__": m = FFM( output_size=4, @@ -144,24 +147,22 @@ def initialize_carry(self, batch_size: int = None): start = jnp.zeros(10, dtype=bool) params = m.init(jax.random.PRNGKey(0), s, (x, start)) out_state, out = m.apply(params, s, (x, start)) - - + BatchFFM = nn.vmap( - FFM, - in_axes=1, out_axes=1, - variable_axes={'params': None}, - split_rngs={'params': False}) - + FFM, in_axes=1, out_axes=1, variable_axes={"params": None}, split_rngs={"params": False} + ) + m = BatchFFM( output_size=4, trace_size=5, context_size=6, ) - + s = m.initialize_carry(8) x = jnp.ones((10, 8, 2)) start = jnp.zeros((10, 8), dtype=bool) params = m.init(jax.random.PRNGKey(0), s, (x, start)) out_state, out = m.apply(params, s, (x, start)) - + print(out.shape) + print(out_state.shape) diff --git a/stoix/systems/ppo/rec_ppo_temp_ffm.py b/stoix/systems/ppo/rec_ppo_temp_ffm.py index 29513eb6..d1116e11 100644 --- a/stoix/systems/ppo/rec_ppo_temp_ffm.py +++ b/stoix/systems/ppo/rec_ppo_temp_ffm.py @@ -24,7 +24,13 @@ RNNLearnerState, ) from stoix.evaluator import evaluator_setup, get_rec_distribution_act_fn -from stoix.networks.base import RecurrentActor, RecurrentActorFFM, RecurrentCritic, RecurrentCriticFFM, ScannedRNN +from stoix.networks.base import ( + RecurrentActor, + RecurrentActorFFM, + RecurrentCritic, + RecurrentCriticFFM, + ScannedRNN, +) from stoix.networks.ffm import FFM from stoix.systems.ppo.ppo_types import ActorCriticHiddenStates, RNNPPOTransition from stoix.utils import make_env as environments @@ -94,8 +100,8 @@ def _env_step( batched_observation, last_done[jnp.newaxis, :], ) - - jax.debug.print("ac_in {x} {y}", x=ac_in[0].agent_view.shape, y=ac_in[1].shape) + + # jax.debug.print("ac_in {x} {y}", x=ac_in[0].agent_view.shape, y=ac_in[1].shape) # Run the network. policy_hidden_state, actor_policy = actor_apply_fn( @@ -131,7 +137,7 @@ def _env_step( timestep.reward, log_prob, last_timestep.observation, - hstates, + jax.tree_map(lambda x: x.squeeze(0), hstates), info, ) learner_state = RNNLearnerState( @@ -215,10 +221,10 @@ def _actor_loss_fn( ) -> Tuple: """Calculate the actor loss.""" # RERUN NETWORK - + # jax.debug.print("{x}",x= traj_batch.hstates.policy_hidden_state.shape) obs_and_done = (traj_batch.obs, traj_batch.done) policy_hidden_state = jax.tree_util.tree_map( - lambda x: x[0], traj_batch.hstates.policy_hidden_state + lambda x: x[0][jnp.newaxis, ...], traj_batch.hstates.policy_hidden_state ) _, actor_policy = actor_apply_fn( actor_params, policy_hidden_state, obs_and_done @@ -235,6 +241,7 @@ def _actor_loss_fn( "actor_loss": loss_actor, "entropy": entropy, } + return total_loss, loss_info def _critic_loss_fn( @@ -246,7 +253,7 @@ def _critic_loss_fn( # RERUN NETWORK obs_and_done = (traj_batch.obs, traj_batch.done) critic_hidden_state = jax.tree_util.tree_map( - lambda x: x[0], traj_batch.hstates.critic_hidden_state + lambda x: x[0][jnp.newaxis, ...], traj_batch.hstates.critic_hidden_state ) _, value = critic_apply_fn(critic_params, critic_hidden_state, obs_and_done) @@ -259,6 +266,7 @@ def _critic_loss_fn( loss_info = { "value_loss": value_loss, } + return total_loss, loss_info # CALCULATE ACTOR LOSS @@ -478,7 +486,7 @@ def learner_setup( critic_rnn = FFM( config.network.critic_network.rnn_layer.hidden_state_dim, config.network.critic_network.rnn_layer.hidden_state_dim, - config.network.critic_network.rnn_layer.hidden_state_dim + config.network.critic_network.rnn_layer.hidden_state_dim, ) actor_lr = make_learning_rate( @@ -517,8 +525,8 @@ def learner_setup( critic_params = critic_network.init(critic_net_key, init_critic_hstate, init_x) critic_opt_state = critic_optim.init(critic_params) - actor_network_apply_fn = jax.vmap(actor_network.apply, in_axes=(None, 1, 1)) - critic_network_apply_fn = jax.vmap(critic_network.apply, in_axes=(None, 1, 1)) + actor_network_apply_fn = actor_network.apply + critic_network_apply_fn = critic_network.apply # Get network apply functions and optimiser updates. apply_fns = (actor_network_apply_fn, critic_network_apply_fn) @@ -694,28 +702,28 @@ def run_experiment(_config: DictConfig) -> float: eval_keys = eval_keys.reshape(n_devices, -1) # Evaluate. - evaluator_output = evaluator(trained_params, eval_keys) - jax.block_until_ready(evaluator_output) - - # Log the results of the evaluation. - elapsed_time = time.time() - start_time - episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"]) - - steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) - evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL) - - if save_checkpoint: - # Save checkpoint of learner state - checkpointer.save( - timestep=int(steps_per_rollout * (eval_step + 1)), - unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), - episode_return=episode_return, - ) - - if config.arch.absolute_metric and max_episode_return <= episode_return: - best_params = copy.deepcopy(trained_params) - max_episode_return = episode_return + # evaluator_output = evaluator(trained_params, eval_keys) + # jax.block_until_ready(evaluator_output) + + # # Log the results of the evaluation. + # elapsed_time = time.time() - start_time + # episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"]) + + # steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) + # evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + # logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL) + + # if save_checkpoint: + # # Save checkpoint of learner state + # checkpointer.save( + # timestep=int(steps_per_rollout * (eval_step + 1)), + # unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), + # episode_return=episode_return, + # ) + + # if config.arch.absolute_metric and max_episode_return <= episode_return: + # best_params = copy.deepcopy(trained_params) + # max_episode_return = episode_return # Update runner state to continue training. learner_state = learner_output.learner_state From 5344323846e2fea52dc119eb645b4111dc4a8524 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Sat, 22 Jun 2024 12:08:26 +0000 Subject: [PATCH 04/38] chore: cleanup and edit stacked_rnn arch --- stoix/configs/arch/anakin.yaml | 2 +- stoix/configs/network/muzero.yaml | 3 +- stoix/configs/network/rnn.yaml | 4 +- stoix/networks/base.py | 126 +---- stoix/networks/layers.py | 12 +- stoix/networks/model_based.py | 25 +- stoix/systems/ppo/rec_ppo.py | 20 +- stoix/systems/ppo/rec_ppo_temp_ffm.py | 771 -------------------------- 8 files changed, 30 insertions(+), 933 deletions(-) delete mode 100644 stoix/systems/ppo/rec_ppo_temp_ffm.py diff --git a/stoix/configs/arch/anakin.yaml b/stoix/configs/arch/anakin.yaml index f6092512..a5025dee 100644 --- a/stoix/configs/arch/anakin.yaml +++ b/stoix/configs/arch/anakin.yaml @@ -3,7 +3,7 @@ # --- Training --- seed: 42 # RNG seed. update_batch_size: 1 # Number of vectorised gradient updates per device. -total_num_envs: 1024 # Total Number of vectorised environments across all devices and batched_updates. Needs to be divisible by n_devices*update_batch_size. +total_num_envs: 512 # Total Number of vectorised environments across all devices and batched_updates. Needs to be divisible by n_devices*update_batch_size. total_timesteps: 1e7 # Set the total environment steps. # If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value. num_updates: ~ # Number of updates diff --git a/stoix/configs/network/muzero.yaml b/stoix/configs/network/muzero.yaml index 57f8beb9..6c376e23 100644 --- a/stoix/configs/network/muzero.yaml +++ b/stoix/configs/network/muzero.yaml @@ -30,8 +30,7 @@ wm_network: activation: silu # This can be seen as the dyanmics network. - rnn_size: 256 - num_stacked_rnn_layers: 2 + rnn_sizes: [256, 256] rnn_cell_type: "gru" recurrent_activation: "sigmoid" diff --git a/stoix/configs/network/rnn.yaml b/stoix/configs/network/rnn.yaml index 0a86e231..14801c2d 100644 --- a/stoix/configs/network/rnn.yaml +++ b/stoix/configs/network/rnn.yaml @@ -7,7 +7,7 @@ actor_network: use_layer_norm: False activation: silu rnn_layer: - _target_: stoix.networks.base.ScannedRNN + _target_: stoix.networks.recurrent.ScannedRNN cell_type: gru hidden_state_dim: 128 post_torso: @@ -25,7 +25,7 @@ critic_network: use_layer_norm: False activation: silu rnn_layer: - _target_: stoix.networks.base.ScannedRNN + _target_: stoix.networks.recurrent.ScannedRNN cell_type: gru hidden_state_dim: 128 post_torso: diff --git a/stoix/networks/base.py b/stoix/networks/base.py index c56fbf53..f30f72d5 100644 --- a/stoix/networks/base.py +++ b/stoix/networks/base.py @@ -1,17 +1,12 @@ -import functools from typing import Sequence, Tuple, Union import chex import distrax -import jax import jax.numpy as jnp -import numpy as np from flax import linen as nn from stoix.base_types import Observation, RNNObservation -from stoix.networks.ffm import FFM from stoix.networks.inputs import ObservationInput -from stoix.networks.utils import parse_rnn_cell class FeedForwardActor(nn.Module): @@ -84,51 +79,12 @@ def __call__( return concatenated -class ScannedRNN(nn.Module): - hidden_state_dim: int - cell_type: str - - @functools.partial( - nn.scan, - variable_broadcast="params", - in_axes=0, - out_axes=0, - split_rngs={"params": False}, - ) - @nn.compact - def __call__(self, rnn_state: chex.Array, x: chex.Array) -> Tuple[chex.Array, chex.Array]: - """Applies the module.""" - ins, resets = x - hidden_state_reset_fn = lambda reset_state, current_state: jnp.where( - resets[:, np.newaxis], - reset_state, - current_state, - ) - rnn_state = jax.tree_util.tree_map( - hidden_state_reset_fn, - self.initialize_carry(ins.shape[0]), - rnn_state, - ) - new_rnn_state, y = parse_rnn_cell(self.cell_type)(features=self.hidden_state_dim)( - rnn_state, ins - ) - return new_rnn_state, y - - @nn.nowrap - def initialize_carry(self, batch_size: int) -> chex.Array: - """Initializes the carry state.""" - # Use a dummy key since the default state init fn is just zeros. - cell = parse_rnn_cell(self.cell_type)(features=self.hidden_state_dim) - return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, self.hidden_state_dim)) - - class RecurrentActor(nn.Module): """Recurrent Actor Network.""" action_head: nn.Module post_torso: nn.Module - hidden_state_dim: int - cell_type: str + rnn: nn.Module pre_torso: nn.Module input_layer: nn.Module = ObservationInput() @@ -144,9 +100,7 @@ def __call__( observation = self.input_layer(observation) policy_embedding = self.pre_torso(observation) policy_rnn_input = (policy_embedding, done) - policy_hidden_state, policy_embedding = ScannedRNN(self.hidden_state_dim, self.cell_type)( - policy_hidden_state, policy_rnn_input - ) + policy_hidden_state, policy_embedding = self.rnn(policy_hidden_state, policy_rnn_input) actor_logits = self.post_torso(policy_embedding) pi = self.action_head(actor_logits) @@ -158,74 +112,7 @@ class RecurrentCritic(nn.Module): critic_head: nn.Module post_torso: nn.Module - hidden_state_dim: int - cell_type: str - pre_torso: nn.Module - input_layer: nn.Module = ObservationInput() - - @nn.compact - def __call__( - self, - critic_hidden_state: Tuple[chex.Array, chex.Array], - observation_done: RNNObservation, - ) -> Tuple[chex.Array, chex.Array]: - - observation, done = observation_done - - observation = self.input_layer(observation) - - critic_embedding = self.pre_torso(observation) - critic_rnn_input = (critic_embedding, done) - critic_hidden_state, critic_embedding = ScannedRNN(self.hidden_state_dim, self.cell_type)( - critic_hidden_state, critic_rnn_input - ) - critic_output = self.post_torso(critic_embedding) - critic_output = self.critic_head(critic_output) - - return critic_hidden_state, critic_output - - -class RecurrentActorFFM(nn.Module): - """Recurrent Actor Network.""" - - action_head: nn.Module - post_torso: nn.Module - hidden_state_dim: int - cell_type: str - pre_torso: nn.Module - input_layer: nn.Module = ObservationInput() - - @nn.compact - def __call__( - self, - policy_hidden_state: chex.Array, - observation_done: RNNObservation, - ) -> Tuple[chex.Array, distrax.DistributionLike]: - - observation, done = observation_done - - observation = self.input_layer(observation) - policy_embedding = self.pre_torso(observation) - policy_rnn_input = (policy_embedding, done) - BatchFFM = nn.vmap( - FFM, in_axes=1, out_axes=1, variable_axes={"params": None}, split_rngs={"params": False} - ) - policy_hidden_state, policy_embedding = BatchFFM( - self.hidden_state_dim, self.hidden_state_dim, self.hidden_state_dim - )(policy_hidden_state, policy_rnn_input) - actor_logits = self.post_torso(policy_embedding) - pi = self.action_head(actor_logits) - - return policy_hidden_state, pi - - -class RecurrentCriticFFM(nn.Module): - """Recurrent Critic Network.""" - - critic_head: nn.Module - post_torso: nn.Module - hidden_state_dim: int - cell_type: str + rnn: nn.Module pre_torso: nn.Module input_layer: nn.Module = ObservationInput() @@ -242,12 +129,7 @@ def __call__( critic_embedding = self.pre_torso(observation) critic_rnn_input = (critic_embedding, done) - BatchFFM = nn.vmap( - FFM, in_axes=1, out_axes=1, variable_axes={"params": None}, split_rngs={"params": False} - ) - critic_hidden_state, critic_embedding = BatchFFM( - self.hidden_state_dim, self.hidden_state_dim, self.hidden_state_dim - )(critic_hidden_state, critic_rnn_input) + critic_hidden_state, critic_embedding = self.rnn(critic_hidden_state, critic_rnn_input) critic_output = self.post_torso(critic_embedding) critic_output = self.critic_head(critic_output) diff --git a/stoix/networks/layers.py b/stoix/networks/layers.py index a328b358..2289ff1e 100644 --- a/stoix/networks/layers.py +++ b/stoix/networks/layers.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Tuple import chex import jax @@ -24,19 +24,17 @@ class StackedRNN(nn.Module): activation_fn (str): The activation function to use in each RNN cell (default is "tanh"). """ - rnn_size: int + rnn_sizes: Sequence[int] rnn_cls: nn.Module - num_layers: int activation_fn: str = "sigmoid" def setup(self) -> None: """Set up the RNN cells for the stacked RNN.""" self.cells = [ - self.rnn_cls( - features=self.rnn_size, activation_fn=parse_activation_fn(self.activation_fn) - ) - for _ in range(self.num_layers) + self.rnn_cls(features=size, activation_fn=parse_activation_fn(self.activation_fn)) + for size in self.rnn_sizes ] + self.num_layers = len(self.cells) def __call__( self, all_rnn_states: List[chex.ArrayTree], x: chex.Array diff --git a/stoix/networks/model_based.py b/stoix/networks/model_based.py index 418b7dc9..0a4da83f 100644 --- a/stoix/networks/model_based.py +++ b/stoix/networks/model_based.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import List +from typing import List, Sequence import chex import jax @@ -16,9 +16,8 @@ class RewardBasedWorldModel(nn.Module): obs_encoder: nn.Module reward_torso: nn.Module reward_head: nn.Module - rnn_size: int + rnn_sizes: Sequence[int] action_dim: int - num_stacked_rnn_layers: int normalize_hidden_state: bool = True rnn_cell_type: str = "lstm" recurrent_activation: str = "tanh" @@ -34,16 +33,14 @@ def setup(self) -> None: rnn_cell_cls = parse_rnn_cell(self.rnn_cell_type) - self._core = StackedRNN( - self.rnn_size, rnn_cell_cls, self.num_stacked_rnn_layers, self.recurrent_activation - ) + self._core = StackedRNN(self.rnn_sizes, rnn_cell_cls, self.recurrent_activation) @cached_property def hidden_state_size(self) -> int: if self.rnn_cell_type in ("gru", "simple"): - hidden_state_size = sum([self.rnn_size] * self.num_stacked_rnn_layers) + hidden_state_size = sum(self.rnn_sizes) elif self.rnn_cell_type in ("lstm", "optimised_lstm"): - hidden_state_size = sum([self.rnn_size * self.num_stacked_rnn_layers]) * 2 + hidden_state_size = sum(self.rnn_sizes) * 2 return hidden_state_size def _rnn_to_flat(self, state: List[chex.ArrayTree]) -> chex.Array: @@ -60,16 +57,16 @@ def _flat_to_rnn(self, state: chex.Array) -> List[chex.ArrayTree]: """Maps flat vector to RNN state.""" tensors = [] cur_idx = 0 - for _ in range(self.num_stacked_rnn_layers): + for size in self.rnn_sizes: if self.rnn_cell_type in ("gru", "simple"): - states = state[Ellipsis, cur_idx : cur_idx + self.rnn_size] - cur_idx += self.rnn_size + states = state[Ellipsis, cur_idx : cur_idx + size] + cur_idx += size elif self.rnn_cell_type in ("lstm", "optimised_lstm"): states = ( - state[Ellipsis, cur_idx : cur_idx + self.rnn_size], - state[Ellipsis, cur_idx + self.rnn_size : cur_idx + 2 * self.rnn_size], + state[Ellipsis, cur_idx : cur_idx + size], + state[Ellipsis, cur_idx + size : cur_idx + 2 * size], ) - cur_idx += 2 * self.rnn_size + cur_idx += 2 * size tensors.append(states) assert cur_idx == state.shape[-1] return tensors diff --git a/stoix/systems/ppo/rec_ppo.py b/stoix/systems/ppo/rec_ppo.py index 8cbd814c..3e00b41d 100644 --- a/stoix/systems/ppo/rec_ppo.py +++ b/stoix/systems/ppo/rec_ppo.py @@ -24,7 +24,7 @@ RNNLearnerState, ) from stoix.evaluator import evaluator_setup, get_rec_distribution_act_fn -from stoix.networks.base import RecurrentActor, RecurrentCritic, ScannedRNN +from stoix.networks.base import RecurrentActor, RecurrentCritic from stoix.systems.ppo.ppo_types import ActorCriticHiddenStates, RNNPPOTransition from stoix.utils import make_env as environments from stoix.utils.checkpointing import Checkpointer @@ -431,7 +431,7 @@ def learner_fn(learner_state: RNNLearnerState) -> ExperimentOutput[RNNLearnerSta def learner_setup( env: Environment, keys: chex.Array, config: DictConfig -) -> Tuple[LearnerFn[RNNLearnerState], RecurrentActor, ScannedRNN, RNNLearnerState]: +) -> Tuple[LearnerFn[RNNLearnerState], RecurrentActor, Any, RNNLearnerState]: """Initialise learner_fn, network, optimiser, environment and states.""" # Get available TPU cores. n_devices = len(jax.devices()) @@ -445,36 +445,28 @@ def learner_setup( # Define network and optimisers. actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + actor_rnn = hydra.utils.instantiate(config.network.actor_network.rnn_layer) actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) actor_action_head = hydra.utils.instantiate( config.network.actor_network.action_head, action_dim=num_actions ) critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) + critic_rnn = hydra.utils.instantiate(config.network.critic_network.rnn_layer) critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) critic_head = hydra.utils.instantiate(config.network.critic_network.critic_head) actor_network = RecurrentActor( pre_torso=actor_pre_torso, - hidden_state_dim=config.network.critic_network.rnn_layer.hidden_state_dim, - cell_type=config.network.critic_network.rnn_layer.cell_type, + rnn=actor_rnn, post_torso=actor_post_torso, action_head=actor_action_head, ) critic_network = RecurrentCritic( pre_torso=critic_pre_torso, - hidden_state_dim=config.network.critic_network.rnn_layer.hidden_state_dim, - cell_type=config.network.critic_network.rnn_layer.cell_type, + rnn=critic_rnn, post_torso=critic_post_torso, critic_head=critic_head, ) - actor_rnn = ScannedRNN( - hidden_state_dim=config.network.actor_network.rnn_layer.hidden_state_dim, - cell_type=config.network.actor_network.rnn_layer.cell_type, - ) - critic_rnn = ScannedRNN( - hidden_state_dim=config.network.critic_network.rnn_layer.hidden_state_dim, - cell_type=config.network.critic_network.rnn_layer.cell_type, - ) actor_lr = make_learning_rate( config.system.actor_lr, config, config.system.epochs, config.system.num_minibatches diff --git a/stoix/systems/ppo/rec_ppo_temp_ffm.py b/stoix/systems/ppo/rec_ppo_temp_ffm.py deleted file mode 100644 index d1116e11..00000000 --- a/stoix/systems/ppo/rec_ppo_temp_ffm.py +++ /dev/null @@ -1,771 +0,0 @@ -import copy -import time -from typing import Any, Dict, Tuple - -import chex -import flax -import hydra -import jax -import jax.numpy as jnp -import optax -from colorama import Fore, Style -from flax.core.frozen_dict import FrozenDict -from jumanji.env import Environment -from omegaconf import DictConfig, OmegaConf -from rich.pretty import pprint - -from stoix.base_types import ( - ActorCriticOptStates, - ActorCriticParams, - ExperimentOutput, - LearnerFn, - RecActorApply, - RecCriticApply, - RNNLearnerState, -) -from stoix.evaluator import evaluator_setup, get_rec_distribution_act_fn -from stoix.networks.base import ( - RecurrentActor, - RecurrentActorFFM, - RecurrentCritic, - RecurrentCriticFFM, - ScannedRNN, -) -from stoix.networks.ffm import FFM -from stoix.systems.ppo.ppo_types import ActorCriticHiddenStates, RNNPPOTransition -from stoix.utils import make_env as environments -from stoix.utils.checkpointing import Checkpointer -from stoix.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims -from stoix.utils.logger import LogEvent, StoixLogger -from stoix.utils.loss import clipped_value_loss, ppo_clip_loss -from stoix.utils.multistep import batch_truncated_generalized_advantage_estimation -from stoix.utils.total_timestep_checker import check_total_timesteps -from stoix.utils.training import make_learning_rate -from stoix.wrappers.episode_metrics import get_final_step_metrics - - -def get_learner_fn( - env: Environment, - apply_fns: Tuple[RecActorApply, RecCriticApply], - update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], - config: DictConfig, -) -> LearnerFn[RNNLearnerState]: - """Get the learner function.""" - - actor_apply_fn, critic_apply_fn = apply_fns - actor_update_fn, critic_update_fn = update_fns - - def _update_step(learner_state: RNNLearnerState, _: Any) -> Tuple[RNNLearnerState, Tuple]: - """A single update of the network. - - This function steps the environment and records the trajectory batch for - training. It then calculates advantages and targets based on the recorded - trajectory and updates the actor and critic networks based on the calculated - losses. - - Args: - learner_state (NamedTuple): - - params (ActorCriticParams): The current model parameters. - - opt_states (OptStates): The current optimizer states. - - key (PRNGKey): The random number generator state. - - env_state (State): The environment state. - - last_timestep (TimeStep): The last timestep in the current trajectory. - - dones (bool): Whether the last timestep was a terminal state. - - hstates (ActorCriticHiddenStates): The current hidden states of the RNN. - _ (Any): The current metrics info. - """ - - def _env_step( - learner_state: RNNLearnerState, _: Any - ) -> Tuple[RNNLearnerState, RNNPPOTransition]: - """Step the environment.""" - ( - params, - opt_states, - key, - env_state, - last_timestep, - last_done, - last_truncated, - hstates, - ) = learner_state - - key, policy_key = jax.random.split(key) - - # Add a batch dimension to the observation. - batched_observation = jax.tree_util.tree_map( - lambda x: x[jnp.newaxis, :], last_timestep.observation - ) - ac_in = ( - batched_observation, - last_done[jnp.newaxis, :], - ) - - # jax.debug.print("ac_in {x} {y}", x=ac_in[0].agent_view.shape, y=ac_in[1].shape) - - # Run the network. - policy_hidden_state, actor_policy = actor_apply_fn( - params.actor_params, hstates.policy_hidden_state, ac_in - ) - critic_hidden_state, value = critic_apply_fn( - params.critic_params, hstates.critic_hidden_state, ac_in - ) - - # Sample action from the policy and squeeze out the batch dimension. - action = actor_policy.sample(seed=policy_key) - log_prob = actor_policy.log_prob(action) - value, action, log_prob = ( - value.squeeze(0), - action.squeeze(0), - log_prob.squeeze(0), - ) - - # Step the environment. - env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - - # log episode return and length - done = (timestep.discount == 0.0).reshape(-1) - truncated = (timestep.last() & (timestep.discount != 0.0)).reshape(-1) - info = timestep.extras["episode_metrics"] - - hstates = ActorCriticHiddenStates(policy_hidden_state, critic_hidden_state) - transition = RNNPPOTransition( - last_done, - last_truncated, - action, - value, - timestep.reward, - log_prob, - last_timestep.observation, - jax.tree_map(lambda x: x.squeeze(0), hstates), - info, - ) - learner_state = RNNLearnerState( - params, - opt_states, - key, - env_state, - timestep, - done, - truncated, - hstates, - ) - return learner_state, transition - - # INITIALISE RNN STATE - initial_hstates = learner_state.hstates - - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( - _env_step, learner_state, None, config.system.rollout_length - ) - - # CALCULATE ADVANTAGE - ( - params, - opt_states, - key, - env_state, - last_timestep, - last_done, - last_truncated, - hstates, - ) = learner_state - - # Add a batch dimension to the observation. - batched_last_observation = jax.tree_util.tree_map( - lambda x: x[jnp.newaxis, :], last_timestep.observation - ) - ac_in = ( - batched_last_observation, - last_done[jnp.newaxis, :], - ) - - # Run the network. - _, last_val = critic_apply_fn(params.critic_params, hstates.critic_hidden_state, ac_in) - # Squeeze out the batch dimension and mask out the value of terminal states. - last_val = last_val.squeeze(0) - last_val = jnp.where(last_done, jnp.zeros_like(last_val), last_val) - - r_t = traj_batch.reward - v_t = jnp.concatenate([traj_batch.value, last_val[None, ...]], axis=0) - d_t = 1.0 - traj_batch.done.astype(jnp.float32) - d_t = (d_t * config.system.gamma).astype(jnp.float32) - advantages, targets = batch_truncated_generalized_advantage_estimation( - r_t, - d_t, - config.system.gae_lambda, - v_t, - time_major=True, - standardize_advantages=config.system.standardize_advantages, - truncation_flags=traj_batch.truncated, - ) - - def _update_epoch(update_state: Tuple, _: Any) -> Tuple: - """Update the network for a single epoch.""" - - def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: - """Update the network for a single minibatch.""" - - params, opt_states = train_state - ( - traj_batch, - advantages, - targets, - ) = batch_info - - def _actor_loss_fn( - actor_params: FrozenDict, - traj_batch: RNNPPOTransition, - gae: chex.Array, - ) -> Tuple: - """Calculate the actor loss.""" - # RERUN NETWORK - # jax.debug.print("{x}",x= traj_batch.hstates.policy_hidden_state.shape) - obs_and_done = (traj_batch.obs, traj_batch.done) - policy_hidden_state = jax.tree_util.tree_map( - lambda x: x[0][jnp.newaxis, ...], traj_batch.hstates.policy_hidden_state - ) - _, actor_policy = actor_apply_fn( - actor_params, policy_hidden_state, obs_and_done - ) - log_prob = actor_policy.log_prob(traj_batch.action) - - loss_actor = ppo_clip_loss( - log_prob, traj_batch.log_prob, gae, config.system.clip_eps - ) - entropy = actor_policy.entropy().mean() - - total_loss = loss_actor - config.system.ent_coef * entropy - loss_info = { - "actor_loss": loss_actor, - "entropy": entropy, - } - - return total_loss, loss_info - - def _critic_loss_fn( - critic_params: FrozenDict, - traj_batch: RNNPPOTransition, - targets: chex.Array, - ) -> Tuple: - """Calculate the critic loss.""" - # RERUN NETWORK - obs_and_done = (traj_batch.obs, traj_batch.done) - critic_hidden_state = jax.tree_util.tree_map( - lambda x: x[0][jnp.newaxis, ...], traj_batch.hstates.critic_hidden_state - ) - _, value = critic_apply_fn(critic_params, critic_hidden_state, obs_and_done) - - # CALCULATE VALUE LOSS - value_loss = clipped_value_loss( - value, traj_batch.value, targets, config.system.clip_eps - ) - - total_loss = config.system.vf_coef * value_loss - loss_info = { - "value_loss": value_loss, - } - - return total_loss, loss_info - - # CALCULATE ACTOR LOSS - actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True) - actor_grads, actor_loss_info = actor_grad_fn( - params.actor_params, traj_batch, advantages - ) - - # CALCULATE CRITIC LOSS - critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True) - critic_grads, critic_loss_info = critic_grad_fn( - params.critic_params, traj_batch, targets - ) - - # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x - # This pmean could be a regular mean as the batch axis is on the same device. - actor_grads, actor_loss_info = jax.lax.pmean( - (actor_grads, actor_loss_info), axis_name="batch" - ) - # pmean over devices. - actor_grads, actor_loss_info = jax.lax.pmean( - (actor_grads, actor_loss_info), axis_name="device" - ) - - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="batch" - ) - # pmean over devices. - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="device" - ) - - # UPDATE ACTOR PARAMS AND OPTIMISER STATE - actor_updates, actor_new_opt_state = actor_update_fn( - actor_grads, opt_states.actor_opt_state - ) - actor_new_params = optax.apply_updates(params.actor_params, actor_updates) - - # UPDATE CRITIC PARAMS AND OPTIMISER STATE - critic_updates, critic_new_opt_state = critic_update_fn( - critic_grads, opt_states.critic_opt_state - ) - critic_new_params = optax.apply_updates(params.critic_params, critic_updates) - - new_params = ActorCriticParams(actor_new_params, critic_new_params) - new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state) - - # PACK LOSS INFO - loss_info = { - **actor_loss_info, - **critic_loss_info, - } - - return (new_params, new_opt_state), loss_info - - ( - params, - opt_states, - init_hstates, - traj_batch, - advantages, - targets, - key, - ) = update_state - key, shuffle_key = jax.random.split(key) - - # SHUFFLE MINIBATCHES - batch = (traj_batch, advantages, targets) - num_recurrent_chunks = ( - config.system.rollout_length // config.system.recurrent_chunk_size - ) - batch = jax.tree_util.tree_map( - lambda x: x.reshape( - config.system.recurrent_chunk_size, - config.arch.num_envs * num_recurrent_chunks, - *x.shape[2:], - ), - batch, - ) - permutation = jax.random.permutation( - shuffle_key, config.arch.num_envs * num_recurrent_chunks - ) - shuffled_batch = jax.tree_util.tree_map( - lambda x: jnp.take(x, permutation, axis=1), batch - ) - reshaped_batch = jax.tree_util.tree_map( - lambda x: jnp.reshape( - x, (x.shape[0], config.system.num_minibatches, -1, *x.shape[2:]) - ), - shuffled_batch, - ) - minibatches = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 0), reshaped_batch) - - # UPDATE MINIBATCHES - (params, opt_states), loss_info = jax.lax.scan( - _update_minibatch, (params, opt_states), minibatches - ) - - update_state = ( - params, - opt_states, - init_hstates, - traj_batch, - advantages, - targets, - key, - ) - return update_state, loss_info - - init_hstates = jax.tree_util.tree_map(lambda x: x[None, :], initial_hstates) - update_state = ( - params, - opt_states, - init_hstates, - traj_batch, - advantages, - targets, - key, - ) - - # UPDATE EPOCHS - update_state, loss_info = jax.lax.scan( - _update_epoch, update_state, None, config.system.epochs - ) - - params, opt_states, _, traj_batch, advantages, targets, key = update_state - learner_state = RNNLearnerState( - params, - opt_states, - key, - env_state, - last_timestep, - last_done, - last_truncated, - hstates, - ) - metric = traj_batch.info - return learner_state, (metric, loss_info) - - def learner_fn(learner_state: RNNLearnerState) -> ExperimentOutput[RNNLearnerState]: - """Learner function. - - This function represents the learner, it updates the network parameters - by iteratively applying the `_update_step` function for a fixed number of - updates. The `_update_step` function is vectorized over a batch of inputs. - - Args: - learner_state (NamedTuple): - - params (ActorCriticParams): The initial model parameters. - - opt_states (OptStates): The initial optimizer states. - - key (chex.PRNGKey): The random number generator state. - - env_state (LogEnvState): The environment state. - - timesteps (TimeStep): The initial timestep in the initial trajectory. - - dones (bool): Whether the initial timestep was a terminal state. - - hstateS (ActorCriticHiddenStates): The initial hidden states of the RNN. - """ - - batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") - - learner_state, (episode_info, loss_info) = jax.lax.scan( - batched_update_step, learner_state, None, config.arch.num_updates_per_eval - ) - return ExperimentOutput( - learner_state=learner_state, - episode_metrics=episode_info, - train_metrics=loss_info, - ) - - return learner_fn - - -def learner_setup( - env: Environment, keys: chex.Array, config: DictConfig -) -> Tuple[LearnerFn[RNNLearnerState], RecurrentActor, ScannedRNN, RNNLearnerState]: - """Initialise learner_fn, network, optimiser, environment and states.""" - # Get available TPU cores. - n_devices = len(jax.devices()) - - # Get number/dimension of actions. - num_actions = int(env.action_spec().num_values) - config.system.action_dim = num_actions - - # PRNG keys. - key, actor_net_key, critic_net_key = keys - - # Define network and optimisers. - actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) - actor_action_head = hydra.utils.instantiate( - config.network.actor_network.action_head, action_dim=num_actions - ) - critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) - critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) - critic_head = hydra.utils.instantiate(config.network.critic_network.critic_head) - - actor_network = RecurrentActorFFM( - pre_torso=actor_pre_torso, - hidden_state_dim=config.network.critic_network.rnn_layer.hidden_state_dim, - cell_type=config.network.critic_network.rnn_layer.cell_type, - post_torso=actor_post_torso, - action_head=actor_action_head, - ) - critic_network = RecurrentCriticFFM( - pre_torso=critic_pre_torso, - hidden_state_dim=config.network.critic_network.rnn_layer.hidden_state_dim, - cell_type=config.network.critic_network.rnn_layer.cell_type, - post_torso=critic_post_torso, - critic_head=critic_head, - ) - actor_rnn = FFM( - config.network.actor_network.rnn_layer.hidden_state_dim, - config.network.actor_network.rnn_layer.hidden_state_dim, - config.network.actor_network.rnn_layer.hidden_state_dim, - ) - critic_rnn = FFM( - config.network.critic_network.rnn_layer.hidden_state_dim, - config.network.critic_network.rnn_layer.hidden_state_dim, - config.network.critic_network.rnn_layer.hidden_state_dim, - ) - - actor_lr = make_learning_rate( - config.system.actor_lr, config, config.system.epochs, config.system.num_minibatches - ) - critic_lr = make_learning_rate( - config.system.critic_lr, config, config.system.epochs, config.system.num_minibatches - ) - - actor_optim = optax.chain( - optax.clip_by_global_norm(config.system.max_grad_norm), - optax.adam(actor_lr, eps=1e-5), - ) - critic_optim = optax.chain( - optax.clip_by_global_norm(config.system.max_grad_norm), - optax.adam(critic_lr, eps=1e-5), - ) - - # Initialise observation - init_obs = env.observation_spec().generate_value() - init_obs = jax.tree_util.tree_map( - lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0), - init_obs, - ) - init_obs = jax.tree_util.tree_map(lambda x: x[None, ...], init_obs) - init_done = jnp.zeros((1, config.arch.num_envs), dtype=bool) - init_x = (init_obs, init_done) - - # Initialise hidden states. - init_policy_hstate = actor_rnn.initialize_carry(config.arch.num_envs) - init_critic_hstate = critic_rnn.initialize_carry(config.arch.num_envs) - - # initialise params and optimiser state. - actor_params = actor_network.init(actor_net_key, init_policy_hstate, init_x) - actor_opt_state = actor_optim.init(actor_params) - critic_params = critic_network.init(critic_net_key, init_critic_hstate, init_x) - critic_opt_state = critic_optim.init(critic_params) - - actor_network_apply_fn = actor_network.apply - critic_network_apply_fn = critic_network.apply - - # Get network apply functions and optimiser updates. - apply_fns = (actor_network_apply_fn, critic_network_apply_fn) - update_fns = (actor_optim.update, critic_optim.update) - - # Get batched iterated update and replicate it to pmap it over cores. - learn = get_learner_fn(env, apply_fns, update_fns, config) - learn = jax.pmap(learn, axis_name="device") - - # Pack params and initial states. - params = ActorCriticParams(actor_params, critic_params) - hstates = ActorCriticHiddenStates(init_policy_hstate, init_critic_hstate) - - # Load model from checkpoint if specified. - if config.logger.checkpointing.load_model: - loaded_checkpoint = Checkpointer( - model_name=config.system.system_name, - **config.logger.checkpointing.load_args, # Other checkpoint args - ) - # Restore the learner state from the checkpoint - restored_params, restored_hstates = loaded_checkpoint.restore_params(restore_hstates=True) - # Update the params and hstates - params = restored_params - hstates = restored_hstates if restored_hstates else hstates - - # Initialise environment states and timesteps: across devices and batches. - key, *env_keys = jax.random.split( - key, n_devices * config.arch.update_batch_size * config.arch.num_envs + 1 - ) - env_states, timesteps = jax.vmap(env.reset, in_axes=(0))( - jnp.stack(env_keys), - ) - reshape_states = lambda x: x.reshape( - (n_devices, config.arch.update_batch_size, config.arch.num_envs) + x.shape[1:] - ) - # (devices, update batch size, num_envs, ...) - env_states = jax.tree_util.tree_map(reshape_states, env_states) - timesteps = jax.tree_util.tree_map(reshape_states, timesteps) - - # Define params to be replicated across devices and batches. - dones = jnp.zeros( - (config.arch.num_envs,), - dtype=bool, - ) - truncated = jnp.zeros( - (config.arch.num_envs,), - dtype=bool, - ) - key, step_key = jax.random.split(key) - step_keys = jax.random.split(step_key, n_devices * config.arch.update_batch_size) - reshape_keys = lambda x: x.reshape((n_devices, config.arch.update_batch_size) + x.shape[1:]) - step_keys = reshape_keys(jnp.stack(step_keys)) - opt_states = ActorCriticOptStates(actor_opt_state, critic_opt_state) - replicate_learner = (params, opt_states, hstates, dones, truncated) - - # Duplicate learner for update_batch_size. - broadcast = lambda x: jnp.broadcast_to(x, (config.arch.update_batch_size,) + x.shape) - replicate_learner = jax.tree_util.tree_map(broadcast, replicate_learner) - - # Duplicate learner across devices. - replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices()) - - # Initialise learner state. - params, opt_states, hstates, dones, truncated = replicate_learner - init_learner_state = RNNLearnerState( - params=params, - opt_states=opt_states, - key=step_keys, - env_state=env_states, - timestep=timesteps, - done=dones, - truncated=truncated, - hstates=hstates, - ) - return learn, actor_network_apply_fn, actor_rnn, init_learner_state - - -def run_experiment(_config: DictConfig) -> float: - """Runs experiment.""" - config = copy.deepcopy(_config) - - # Calculate total timesteps. - n_devices = len(jax.devices()) - config.num_devices = n_devices - config = check_total_timesteps(config) - assert ( - config.arch.num_updates > config.arch.num_evaluation - ), "Number of updates per evaluation must be less than total number of updates." - - # Set recurrent chunk size. - if config.system.recurrent_chunk_size is None: - config.system.recurrent_chunk_size = config.system.rollout_length - else: - assert ( - config.system.rollout_length % config.system.recurrent_chunk_size == 0 - ), "Rollout length must be divisible by recurrent chunk size." - - # Create the environments for train and eval. - env, eval_env = environments.make(config) - - # PRNG keys. - key, key_e, actor_net_key, critic_net_key = jax.random.split( - jax.random.PRNGKey(config.arch.seed), num=4 - ) - - # Setup learner. - learn, actor_network_apply_fn, actor_rnn, learner_state = learner_setup( - env, (key, actor_net_key, critic_net_key), config - ) - - # Setup evaluator. - evaluator, absolute_metric_evaluator, (trained_params, eval_keys) = evaluator_setup( - eval_env=eval_env, - key_e=key_e, - eval_act_fn=get_rec_distribution_act_fn(config, actor_network_apply_fn), - params=learner_state.params.actor_params, - config=config, - use_recurrent_net=True, - scanned_rnn=actor_rnn, - ) - - # Calculate number of updates per evaluation. - config.arch.num_updates_per_eval = config.arch.num_updates // config.arch.num_evaluation - steps_per_rollout = ( - n_devices - * config.arch.num_updates_per_eval - * config.system.rollout_length - * config.arch.update_batch_size - * config.arch.num_envs - ) - - # Logger setup - logger = StoixLogger(config) - cfg: Dict = OmegaConf.to_container(config, resolve=True) - cfg["arch"]["devices"] = jax.devices() - pprint(cfg) - - # Set up checkpointer - save_checkpoint = config.logger.checkpointing.save_model - if save_checkpoint: - checkpointer = Checkpointer( - metadata=config, # Save all config as metadata in the checkpoint - model_name=config.system.system_name, - **config.logger.checkpointing.save_args, # Checkpoint args - ) - - # Run experiment for a total number of evaluations. - max_episode_return = jnp.float32(-1e7) - best_params = None - for eval_step in range(config.arch.num_evaluation): - # Train. - start_time = time.time() - learner_output = learn(learner_state) - jax.block_until_ready(learner_output) - - # Log the results of the training. - elapsed_time = time.time() - start_time - t = int(steps_per_rollout * (eval_step + 1)) - episode_metrics, ep_completed = get_final_step_metrics(learner_output.episode_metrics) - episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time - - # Separately log timesteps, actoring metrics and training metrics. - logger.log({"timestep": t}, t, eval_step, LogEvent.MISC) - if ep_completed: # only log episode metrics if an episode was completed in the rollout. - logger.log(episode_metrics, t, eval_step, LogEvent.ACT) - logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) - - # Prepare for evaluation. - start_time = time.time() - trained_params = unreplicate_batch_dim(learner_output.learner_state.params.actor_params) - key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) - eval_keys = jnp.stack(eval_keys) - eval_keys = eval_keys.reshape(n_devices, -1) - - # Evaluate. - # evaluator_output = evaluator(trained_params, eval_keys) - # jax.block_until_ready(evaluator_output) - - # # Log the results of the evaluation. - # elapsed_time = time.time() - start_time - # episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"]) - - # steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) - # evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - # logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL) - - # if save_checkpoint: - # # Save checkpoint of learner state - # checkpointer.save( - # timestep=int(steps_per_rollout * (eval_step + 1)), - # unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), - # episode_return=episode_return, - # ) - - # if config.arch.absolute_metric and max_episode_return <= episode_return: - # best_params = copy.deepcopy(trained_params) - # max_episode_return = episode_return - - # Update runner state to continue training. - learner_state = learner_output.learner_state - - # Measure absolute metric. - if config.arch.absolute_metric: - start_time = time.time() - - key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) - eval_keys = jnp.stack(eval_keys) - eval_keys = eval_keys.reshape(n_devices, -1) - - evaluator_output = absolute_metric_evaluator(best_params, eval_keys) - jax.block_until_ready(evaluator_output) - - elapsed_time = time.time() - start_time - - t = int(steps_per_rollout * (eval_step + 1)) - steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) - evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE) - - # Stop the logger. - logger.stop() - # Record the performance for the final evaluation run. If the absolute metric is not - # calculated, this will be the final evaluation run. - eval_performance = float(jnp.mean(evaluator_output.episode_metrics[config.env.eval_metric])) - return eval_performance - - -@hydra.main(config_path="../../configs", config_name="default_rec_ppo.yaml", version_base="1.2") -def hydra_entry_point(cfg: DictConfig) -> float: - """Experiment entry point.""" - # Allow dynamic attributes. - OmegaConf.set_struct(cfg, False) - - # Run experiment. - eval_performance = run_experiment(cfg) - - print(f"{Fore.CYAN}{Style.BRIGHT}Recurrent PPO experiment completed{Style.RESET_ALL}") - return eval_performance - - -if __name__ == "__main__": - hydra_entry_point() From 73b5a83e7a6ee046d34dafcfca9fef2745572f96 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Sat, 22 Jun 2024 12:10:11 +0000 Subject: [PATCH 05/38] chore: move around files --- stoix/networks/recurrent.py | 48 +++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 stoix/networks/recurrent.py diff --git a/stoix/networks/recurrent.py b/stoix/networks/recurrent.py new file mode 100644 index 00000000..8c657b76 --- /dev/null +++ b/stoix/networks/recurrent.py @@ -0,0 +1,48 @@ +import functools +from typing import Tuple + +import chex +import jax +import jax.numpy as jnp +import numpy as np +from flax import linen as nn + +from stoix.networks.utils import parse_rnn_cell + + +class ScannedRNN(nn.Module): + hidden_state_dim: int + cell_type: str + + @functools.partial( + nn.scan, + variable_broadcast="params", + in_axes=0, + out_axes=0, + split_rngs={"params": False}, + ) + @nn.compact + def __call__(self, rnn_state: chex.Array, x: chex.Array) -> Tuple[chex.Array, chex.Array]: + """Applies the module.""" + ins, resets = x + hidden_state_reset_fn = lambda reset_state, current_state: jnp.where( + resets[:, np.newaxis], + reset_state, + current_state, + ) + rnn_state = jax.tree_util.tree_map( + hidden_state_reset_fn, + self.initialize_carry(ins.shape[0]), + rnn_state, + ) + new_rnn_state, y = parse_rnn_cell(self.cell_type)(features=self.hidden_state_dim)( + rnn_state, ins + ) + return new_rnn_state, y + + @nn.nowrap + def initialize_carry(self, batch_size: int) -> chex.Array: + """Initializes the carry state.""" + # Use a dummy key since the default state init fn is just zeros. + cell = parse_rnn_cell(self.cell_type)(features=self.hidden_state_dim) + return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, self.hidden_state_dim)) From 77e4a21bc3ce77478391c8e7f804e188f0eab37d Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Sat, 22 Jun 2024 16:15:58 +0100 Subject: [PATCH 06/38] large ffm refactor, runs but possibly introduced bugs --- stoix/networks/ffm.py | 183 ++++++++++++++++++++++++++---------------- 1 file changed, 112 insertions(+), 71 deletions(-) diff --git a/stoix/networks/ffm.py b/stoix/networks/ffm.py index 0e480aff..7c87890e 100644 --- a/stoix/networks/ffm.py +++ b/stoix/networks/ffm.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Tuple import flax.linen as nn @@ -6,6 +7,7 @@ class Gate(nn.Module): + """Sigmoidal gating""" output_size: int @nn.compact @@ -18,6 +20,7 @@ def __call__(self, x): def init_deterministic( memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 ) -> Tuple[jax.Array, jax.Array]: + """Deterministic initialization of the FFM parameters.""" a_low = 1e-6 a_high = 0.5 a = jnp.linspace(a_low, a_high, memory_size) @@ -25,25 +28,18 @@ def init_deterministic( return a, b -class FFM(nn.Module): - """Feedforward Memory Network.""" - +class FFMCell(nn.Module): + """The binary associative update function for the FFM.""" trace_size: int context_size: int output_size: int def setup(self): - self.pre = nn.Dense(self.trace_size) - self.gate_in = Gate(self.trace_size) - self.gate_out = Gate(self.output_size) - self.skip = nn.Dense(self.output_size) a, b = init_deterministic(self.trace_size, self.context_size) - self.ffa_params = (self.param("ffa_a", lambda rng: a), self.param("ffa_b", lambda rng: b)) - self.mix = nn.Dense(self.output_size) - self.ln = nn.LayerNorm(use_scale=False, use_bias=False) + self.params = (self.param("ffa_a", lambda rng: a), self.param("ffa_b", lambda rng: b)) def log_gamma(self, t: jax.Array) -> jax.Array: - a, b = self.ffa_params + a, b = self.params a = -jnp.abs(a).reshape((1, self.trace_size, 1)) b = b.reshape(1, 1, self.context_size) ab = jax.lax.complex(a, b) @@ -52,11 +48,13 @@ def log_gamma(self, t: jax.Array) -> jax.Array: def gamma(self, t: jax.Array) -> jax.Array: return jnp.exp(self.log_gamma(t)) - def unwrapped_associative_update( - self, - carry: Tuple[jax.Array, jax.Array, jax.Array], - incoming: Tuple[jax.Array, jax.Array, jax.Array], - ) -> Tuple[jax.Array, jax.Array, jax.Array]: + def initialize_carry(self, batch_size: int = None): + if batch_size is None: + return jnp.zeros((1, self.trace_size, self.context_size), dtype=jnp.complex64), jnp.ones((1,), dtype=jnp.int32) + + return jnp.zeros((1, batch_size, self.trace_size, self.context_size), dtype=jnp.complex64), jnp.ones((1, batch_size), dtype=jnp.int32) + + def __call__(self, carry, incoming): ( state, i, @@ -65,44 +63,81 @@ def unwrapped_associative_update( state = state * self.gamma(j) + x return state, j + i - def wrapped_associative_update(self, carry, incoming): - prev_start, state, i = carry - start, x, j = incoming - # Reset all elements in the carry if we are starting a new episode - state = state * jnp.logical_not(start) - j = j * jnp.logical_not(start) - incoming = x, j - carry = (state, i) - out = self.unwrapped_associative_update(carry, incoming) - start_out = jnp.logical_or(start, prev_start) - return (start_out, *out) + +class MemoroidResetWrapper(nn.Module): + """A wrapper around memoroid cells like FFM, LRU, etc that resets + the recurrent state upon a reset signal.""" + cell: nn.Module + + def __call__(self, carry, incoming): + states, prev_start = carry + xs, start = incoming + + def reset_state(start, current_state, initial_state): + # Expand to reset all dims of state: [B, 1, 1, ...] + expanded_start = start.reshape(-1, *([1] * (current_state.ndim - 1))) + out = current_state * jnp.logical_not(expanded_start) + initial_state + return out + + initial_states = self.cell.initialize_carry() + states = jax.tree.map(partial(reset_state, start), states, initial_states) + out = self.cell(states, xs) + start_carry = jnp.logical_or(start, prev_start) + + return out, start_carry + + def initialize_carry(self, batch_size: int = None): + if batch_size is None: + # TODO: Should this be one or zero? + return self.cell.initialize_carry(batch_size), jnp.zeros((1,), dtype=bool) + + return self.cell.initialize_carry(batch_size), jnp.zeros((batch_size,), dtype=bool) + + + +class FFM(nn.Module): + """Fast and Forgetful Memory""" + + trace_size: int + context_size: int + output_size: int + cell: nn.Module + + def setup(self): + self.pre = nn.Dense(self.trace_size) + self.gate_in = Gate(self.trace_size) + self.ffa = FFMCell(self.trace_size, self.context_size, self.output_size) + self.gate_out = Gate(self.output_size) + self.skip = nn.Dense(self.output_size) + self.mix = nn.Dense(self.output_size) + self.ln = nn.LayerNorm(use_scale=False, use_bias=False) def scan( self, - x: jax.Array, state: jax.Array, - start: jax.Array, + inputs: jax.Array, ) -> jax.Array: - """Given an input and recurrent state, this will update the recurrent state. This is equivalent - to the inner-function g in the paper.""" - # x: [T, memory_size] - # memory: [1, memory_size, context_size] - T = x.shape[0] - timestep = jnp.ones(T + 1, dtype=jnp.int32).reshape(-1, 1, 1) - # Add context dim - start = start.reshape(T, 1, 1) - - # Now insert previous recurrent state - x = jnp.concatenate([state, x], axis=0) - start = jnp.concatenate([jnp.zeros_like(start[:1]), start], axis=0) - - # This is not executed during inference -- method will just return x if size is 1 - _, new_state, _ = jax.lax.associative_scan( - self.wrapped_associative_update, - (start, x, timestep), + """Execute the associative scan to update the recurrent state. + + Note that we do a trick here by concatenating the previou state to the inputs. + This is allowed since the scan is associative. This ensures that the previous + recurrent state feeds information into the scan. Without this method, we need + separate methods for rollouts and training.""" + + # Concatenate the prevous state to the inputs and scan over the result + # This ensures the previous recurrent state contributes to the current batch + # state: [start, (x, j)] + # inputs: [start, (x, j)] + scan_inputs = jax.tree.map(lambda x, s: jnp.concatenate([s, x], axis=0), inputs, state) + new_state = jax.lax.associative_scan( + self.cell, + scan_inputs, axis=0, ) - return new_state[1:] + # The zeroth index corresponds to the previous recurrent state + # We just use it to ensure continuity + # We do not actually want to use these values, so slice them away + return jax.tree.map(lambda x: x[1:], new_state) def map_to_h(self, x): gate_in = self.gate_in(x) @@ -124,16 +159,16 @@ def map_from_h(self, recurrent_state, x): def __call__(self, recurrent_state, inputs): x, resets = inputs z = self.map_to_h(x) - recurrent_state = self.scan(z, recurrent_state, resets) - out = self.map_from_h(recurrent_state, x) + # Relative timestep + ts = jnp.ones(x.shape[0], dtype=jnp.int32) + recurrent_state = self.scan(recurrent_state, ((z, ts), resets)) + # recurrent_state is ((state, timestep), reset) + out = self.map_from_h(recurrent_state[0][0], x) final_state = recurrent_state[-1:] return final_state, out def initialize_carry(self, batch_size: int = None): - if batch_size is None: - return jnp.zeros((1, self.trace_size, self.context_size), dtype=jnp.complex64) - - return jnp.zeros((1, batch_size, self.trace_size, self.context_size), dtype=jnp.complex64) + return self.cell.initialize_carry(batch_size) if __name__ == "__main__": @@ -141,6 +176,11 @@ def initialize_carry(self, batch_size: int = None): output_size=4, trace_size=5, context_size=6, + cell=MemoroidResetWrapper( + cell=FFMCell( + output_size=4,trace_size=5,context_size=6 + ) + ) ) s = m.initialize_carry() x = jnp.ones((10, 2)) @@ -148,21 +188,22 @@ def initialize_carry(self, batch_size: int = None): params = m.init(jax.random.PRNGKey(0), s, (x, start)) out_state, out = m.apply(params, s, (x, start)) - BatchFFM = nn.vmap( - FFM, in_axes=1, out_axes=1, variable_axes={"params": None}, split_rngs={"params": False} - ) - - m = BatchFFM( - output_size=4, - trace_size=5, - context_size=6, - ) - - s = m.initialize_carry(8) - x = jnp.ones((10, 8, 2)) - start = jnp.zeros((10, 8), dtype=bool) - params = m.init(jax.random.PRNGKey(0), s, (x, start)) - out_state, out = m.apply(params, s, (x, start)) - - print(out.shape) - print(out_state.shape) + # BatchFFM = nn.vmap( + # FFM, in_axes=1, out_axes=1, variable_axes={"params": None}, split_rngs={"params": False} + # ) + + # m = BatchFFM( + # trace_size=4, + # context_size=5, + # output_size=6, + # cell=MemoroidResetWrapper(cell=FFMCell(4,5,6)) + # ) + + # s = m.initialize_carry(8) + # x = jnp.ones((10, 8, 2)) + # start = jnp.zeros((10, 8), dtype=bool) + # params = m.init(jax.random.PRNGKey(0), s, (x, start)) + # out_state, out = m.apply(params, s, (x, start)) + + # print(out.shape) + # print(out_state.shape) From 46f2d261e2fb03289701d3b47d0110806e050838 Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Sat, 22 Jun 2024 16:37:25 +0100 Subject: [PATCH 07/38] comments, further cleanup, and only return final state --- stoix/networks/ffm.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/stoix/networks/ffm.py b/stoix/networks/ffm.py index 7c87890e..bfc86f1e 100644 --- a/stoix/networks/ffm.py +++ b/stoix/networks/ffm.py @@ -94,7 +94,6 @@ def initialize_carry(self, batch_size: int = None): return self.cell.initialize_carry(batch_size), jnp.zeros((batch_size,), dtype=bool) - class FFM(nn.Module): """Fast and Forgetful Memory""" @@ -139,17 +138,23 @@ def scan( # We do not actually want to use these values, so slice them away return jax.tree.map(lambda x: x[1:], new_state) - def map_to_h(self, x): + def map_to_h(self, inputs): + """Map from the input space to the recurrent state space""" + x, resets = inputs gate_in = self.gate_in(x) pre = self.pre(x) gated_x = pre * gate_in - scan_input = jnp.repeat(jnp.expand_dims(gated_x, 2), self.context_size, axis=2) - return scan_input + # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous + ts = jnp.ones(x.shape[0], dtype=jnp.int32) + z = jnp.repeat(jnp.expand_dims(gated_x, 2), self.context_size, axis=2) + return (z, ts), resets def map_from_h(self, recurrent_state, x): + """Map from the recurrent space to the Markov space""" + (state, ts), reset = recurrent_state z_in = jnp.concatenate( - [jnp.real(recurrent_state), jnp.imag(recurrent_state)], axis=-1 - ).reshape(recurrent_state.shape[0], -1) + [jnp.real(state), jnp.imag(state)], axis=-1 + ).reshape(state.shape[0], -1) z = self.mix(z_in) gate_out = self.gate_out(x) skip = self.skip(x) @@ -157,15 +162,16 @@ def map_from_h(self, recurrent_state, x): return out def __call__(self, recurrent_state, inputs): - x, resets = inputs - z = self.map_to_h(x) - # Relative timestep - ts = jnp.ones(x.shape[0], dtype=jnp.int32) - recurrent_state = self.scan(recurrent_state, ((z, ts), resets)) + # Recurrent state should be ((state, timestep), reset) + # Inputs should be (x, reset) + h = self.map_to_h(inputs) + recurrent_state = self.scan(recurrent_state, h) # recurrent_state is ((state, timestep), reset) - out = self.map_from_h(recurrent_state[0][0], x) - final_state = recurrent_state[-1:] - return final_state, out + out = self.map_from_h(recurrent_state, x) + + # TODO: Remove this when we want to return all recurrent states instead of just the last one + final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) + return final_recurrent_state, out def initialize_carry(self, batch_size: int = None): return self.cell.initialize_carry(batch_size) From 369b1cfa8003c2d839f0701478c942ecaee7f22a Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Sat, 22 Jun 2024 16:49:02 +0100 Subject: [PATCH 08/38] factor out recurrent associative scan --- stoix/networks/ffm.py | 57 ++++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/stoix/networks/ffm.py b/stoix/networks/ffm.py index bfc86f1e..bb22a34b 100644 --- a/stoix/networks/ffm.py +++ b/stoix/networks/ffm.py @@ -6,6 +6,34 @@ import jax.numpy as jnp +def recurrent_associative_scan( + cell: nn.Module, + state: jax.Array, + inputs: jax.Array, + axis: int = 0, +) -> jax.Array: + """Execute the associative scan to update the recurrent state. + + Note that we do a trick here by concatenating the previous state to the inputs. + This is allowed since the scan is associative. This ensures that the previous + recurrent state feeds information into the scan. Without this method, we need + separate methods for rollouts and training.""" + + # Concatenate the prevous state to the inputs and scan over the result + # This ensures the previous recurrent state contributes to the current batch + # state: [start, (x, j)] + # inputs: [start, (x, j)] + scan_inputs = jax.tree.map(lambda x, s: jnp.concatenate([s, x], axis=0), inputs, state) + new_state = jax.lax.associative_scan( + cell, + scan_inputs, + axis=axis, + ) + # The zeroth index corresponds to the previous recurrent state + # We just use it to ensure continuity + # We do not actually want to use these values, so slice them away + return jax.tree.map(lambda x: x[1:], new_state) + class Gate(nn.Module): """Sigmoidal gating""" output_size: int @@ -111,33 +139,6 @@ def setup(self): self.mix = nn.Dense(self.output_size) self.ln = nn.LayerNorm(use_scale=False, use_bias=False) - def scan( - self, - state: jax.Array, - inputs: jax.Array, - ) -> jax.Array: - """Execute the associative scan to update the recurrent state. - - Note that we do a trick here by concatenating the previou state to the inputs. - This is allowed since the scan is associative. This ensures that the previous - recurrent state feeds information into the scan. Without this method, we need - separate methods for rollouts and training.""" - - # Concatenate the prevous state to the inputs and scan over the result - # This ensures the previous recurrent state contributes to the current batch - # state: [start, (x, j)] - # inputs: [start, (x, j)] - scan_inputs = jax.tree.map(lambda x, s: jnp.concatenate([s, x], axis=0), inputs, state) - new_state = jax.lax.associative_scan( - self.cell, - scan_inputs, - axis=0, - ) - # The zeroth index corresponds to the previous recurrent state - # We just use it to ensure continuity - # We do not actually want to use these values, so slice them away - return jax.tree.map(lambda x: x[1:], new_state) - def map_to_h(self, inputs): """Map from the input space to the recurrent state space""" x, resets = inputs @@ -165,7 +166,7 @@ def __call__(self, recurrent_state, inputs): # Recurrent state should be ((state, timestep), reset) # Inputs should be (x, reset) h = self.map_to_h(inputs) - recurrent_state = self.scan(recurrent_state, h) + recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h) # recurrent_state is ((state, timestep), reset) out = self.map_from_h(recurrent_state, x) From 72e4b7a82388683f0f5f563e37e10f4c3ce0844c Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Sat, 22 Jun 2024 17:46:34 +0100 Subject: [PATCH 09/38] Add simplified FFM and stacked simplified FFM --- stoix/networks/ffm.py | 126 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 122 insertions(+), 4 deletions(-) diff --git a/stoix/networks/ffm.py b/stoix/networks/ffm.py index bb22a34b..5cde89bc 100644 --- a/stoix/networks/ffm.py +++ b/stoix/networks/ffm.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Tuple +from typing import Any, List, Tuple import flax.linen as nn import jax @@ -55,15 +55,31 @@ def init_deterministic( b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) return a, b +def init_random( + memory_size: int, context_size: int, min_period: int = 1, max_period: int = 10_000, *, key +) -> Tuple[jax.Array, jax.Array]: + _, k1, k2 = jax.random.split(key, 3) + a_low = 1e-6 + a_high = 0.1 + a = jax.random.uniform(k1, (memory_size,), minval=a_low, maxval=a_high) + b = 2 * jnp.pi / jnp.exp(jax.random.uniform(k2, (context_size,), minval=jnp.log(min_period), maxval=jnp.log(max_period))) + return a, b + class FFMCell(nn.Module): """The binary associative update function for the FFM.""" trace_size: int context_size: int output_size: int + deterministic_init: bool = True def setup(self): - a, b = init_deterministic(self.trace_size, self.context_size) + if self.deterministic_init: + a, b = init_deterministic(self.trace_size, self.context_size) + else: + # TODO: Will this result in the same keys for multiple FFMCells? + key = self.make_rng("ffa_params") + a, b = init_random(self.trace_size, self.context_size, key=key) self.params = (self.param("ffa_a", lambda rng: a), self.param("ffa_b", lambda rng: b)) def log_gamma(self, t: jax.Array) -> jax.Array: @@ -150,9 +166,10 @@ def map_to_h(self, inputs): z = jnp.repeat(jnp.expand_dims(gated_x, 2), self.context_size, axis=2) return (z, ts), resets - def map_from_h(self, recurrent_state, x): + def map_from_h(self, recurrent_state, inputs): """Map from the recurrent space to the Markov space""" (state, ts), reset = recurrent_state + (x, start) = inputs z_in = jnp.concatenate( [jnp.real(state), jnp.imag(state)], axis=-1 ).reshape(state.shape[0], -1) @@ -168,7 +185,7 @@ def __call__(self, recurrent_state, inputs): h = self.map_to_h(inputs) recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h) # recurrent_state is ((state, timestep), reset) - out = self.map_from_h(recurrent_state, x) + out = self.map_from_h(recurrent_state, inputs) # TODO: Remove this when we want to return all recurrent states instead of just the last one final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) @@ -178,6 +195,89 @@ def initialize_carry(self, batch_size: int = None): return self.cell.initialize_carry(batch_size) +class SFFM(nn.Module): + """Simplified Fast and Forgetful Memory""" + + trace_size: int + context_size: int + hidden_size: int + cell: nn.Module + + def setup(self): + self.W_trace = nn.Dense(self.trace_size) + self.W_context = Gate(self.context_size) + self.ffa = FFMCell(self.trace_size, self.context_size, self.hidden_size, deterministic_init=False) + self.post = nn.Sequential([ + # Default init but with smaller weights + nn.Dense(self.hidden_size, kernel_init=nn.initializers.variance_scaling(0.01, "fan_in", "truncated_normal")), + nn.LayerNorm(), + nn.leaky_relu, + nn.Dense(self.hidden_size), + nn.LayerNorm(), + nn.leaky_relu, + ]) + + def map_to_h(self, inputs): + x, resets = inputs + pre = jnp.abs(jnp.einsum("bi, bj -> bij", self.W_trace(x), self.W_context(x))) + pre = pre / jnp.sum(pre, axis=(-2,-1), keepdims=True) + # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous + ts = jnp.ones(x.shape[0], dtype=jnp.int32) + return (pre, ts), resets + + def map_from_h(self, recurrent_state, inputs): + x, resets = inputs + (state, ts), reset = recurrent_state + s = state.reshape(state.shape[0], self.context_size * self.trace_size) + eps = s.real + (s.real==0 + jnp.sign(s.real)) * 0.01 + s = s + eps + scaled = jnp.concatenate([ + jnp.log(1 + jnp.abs(s)) * jnp.sin(jnp.angle(s)), + jnp.log(1 + jnp.abs(s)) * jnp.cos(jnp.angle(s)), + ], axis=-1) + z = self.post(scaled) + return z + + def __call__(self, recurrent_state, inputs): + # Recurrent state should be ((state, timestep), reset) + # Inputs should be (x, reset) + h = self.map_to_h(inputs) + recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h) + # recurrent_state is ((state, timestep), reset) + out = self.map_from_h(recurrent_state, inputs) + + # TODO: Remove this when we want to return all recurrent states instead of just the last one + final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) + return final_recurrent_state, out + + def initialize_carry(self, batch_size: int = None): + return self.cell.initialize_carry(batch_size) + +class StackedSFFM(nn.Module): + """A multilayer version of SFFM""" + cells: List[nn.Module] + + def setup(self): + self.project = nn.Dense(cells[0].hidden_size) + + + def __call__( + self, recurrent_state: jax.Array, inputs: Any + ) -> Tuple[jax.Array, jax.Array]: + x, start = inputs + x = self.project(x) + inputs = x, start + for i, cell in enumerate(self.cells): + s, y = cell(recurrent_state[i], inputs) + x = x + y + recurrent_state[i] = s + return y, recurrent_state + + def initialize_carry(self, batch_size: int = None): + return [ + c.initialize_carry(batch_size) for c in self.cells + ] + if __name__ == "__main__": m = FFM( output_size=4, @@ -214,3 +314,21 @@ def initialize_carry(self, batch_size: int = None): # print(out.shape) # print(out_state.shape) + + # TODO: Initialize cells with different random streams so the weights are not identical + cells = [ + SFFM( + trace_size=4, + context_size=5, + hidden_size=6, + cell=MemoroidResetWrapper(cell=FFMCell(4,5,6)) + ) + for i in range(3) + ] + s2fm = StackedSFFM(cells=cells) + + s = s2fm.initialize_carry() + x = jnp.ones((10, 2)) + start = jnp.zeros(10, dtype=bool) + params = s2fm.init(jax.random.PRNGKey(0), s, (x, start)) + out_state, out = s2fm.apply(params, s, (x, start)) \ No newline at end of file From 2faac71fe2ce0526dfd1def9acb5bf9b494474c9 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Sat, 22 Jun 2024 17:06:22 +0000 Subject: [PATCH 10/38] chore: separate files - add baseclass etc --- stoix/networks/ffm_edan.py | 233 +++++++++++++++++++++++++++++++++++++ 1 file changed, 233 insertions(+) create mode 100644 stoix/networks/ffm_edan.py diff --git a/stoix/networks/ffm_edan.py b/stoix/networks/ffm_edan.py new file mode 100644 index 00000000..05a985fc --- /dev/null +++ b/stoix/networks/ffm_edan.py @@ -0,0 +1,233 @@ +from functools import partial +from typing import Tuple + +import chex +import flax.linen as nn +import jax +import jax.numpy as jnp + +Carry = chex.ArrayTree + + +class LRUCellBase(nn.Module): + """LRU cell base class.""" + + def map_to_h(self, inputs): + """Map from the input space to the recurrent state space""" + raise NotImplementedError + + def map_from_h(self, recurrent_state, x): + """Map from the recurrent space to the Markov space""" + raise NotImplementedError + + @nn.nowrap + def initialize_carry(self, rng: chex.PRNGKey, input_shape: Tuple[int, ...]) -> Carry: + """Initialize the LRU cell carry. + + Args: + rng: random number generator passed to the init_fn. + input_shape: a tuple providing the shape of the input to the cell. + + Returns: + An initialized carry for the given RNN cell. + """ + raise NotImplementedError + + @property + def num_feature_axes(self) -> int: + """Returns the number of feature axes of the LRU cell.""" + raise NotImplementedError + + +def recurrent_associative_scan( + cell: nn.Module, + state: jax.Array, + inputs: jax.Array, + axis: int = 0, +) -> jax.Array: + """Execute the associative scan to update the recurrent state. + + Note that we do a trick here by concatenating the previous state to the inputs. + This is allowed since the scan is associative. This ensures that the previous + recurrent state feeds information into the scan. Without this method, we need + separate methods for rollouts and training.""" + + # Concatenate the previous state to the inputs and scan over the result + # This ensures the previous recurrent state contributes to the current batch + # state: [start, (x, j)] + # inputs: [start, (x, j)] + scan_inputs = jax.tree.map(lambda x, s: jnp.concatenate([s, x], axis=0), inputs, state) + new_state = jax.lax.associative_scan( + cell, + scan_inputs, + axis=axis, + ) + # The zeroth index corresponds to the previous recurrent state + # We just use it to ensure continuity + # We do not actually want to use these values, so slice them away + return jax.tree.map(lambda x: x[1:], new_state) + + +class Gate(nn.Module): + """Sigmoidal gating""" + + output_size: int + + @nn.compact + def __call__(self, x): + x = nn.Dense(self.output_size)(x) + x = nn.sigmoid(x) + return x + + +def init_deterministic( + memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 +) -> Tuple[jax.Array, jax.Array]: + """Deterministic initialization of the FFM parameters.""" + a_low = 1e-6 + a_high = 0.5 + a = jnp.linspace(a_low, a_high, memory_size) + b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) + return a, b + + +class FFMCell(LRUCellBase): + """The binary associative update function for the FFM.""" + + trace_size: int + context_size: int + output_size: int + + def setup(self): + a, b = init_deterministic(self.trace_size, self.context_size) + self.params = (self.param("ffa_a", lambda rng: a), self.param("ffa_b", lambda rng: b)) + + # Mapping from input space to recurrent state space + self.pre = nn.Dense(self.trace_size) + self.gate_in = Gate(self.trace_size) + self.gate_out = Gate(self.output_size) + self.skip = nn.Dense(self.output_size) + self.mix = nn.Dense(self.output_size) + self.ln = nn.LayerNorm(use_scale=False, use_bias=False) + + def map_to_h(self, inputs): + """Map from the input space to the recurrent state space""" + x, resets = inputs + gate_in = self.gate_in(x) + pre = self.pre(x) + gated_x = pre * gate_in + # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous + ts = jnp.ones(x.shape[0], dtype=jnp.int32) + z = jnp.repeat(jnp.expand_dims(gated_x, 2), self.context_size, axis=2) + return (z, ts), resets + + def map_from_h(self, recurrent_state, x): + """Map from the recurrent space to the Markov space""" + (state, ts), reset = recurrent_state + z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape( + state.shape[0], -1 + ) + z = self.mix(z_in) + gate_out = self.gate_out(x) + skip = self.skip(x) + out = self.ln(z * gate_out) + skip * (1 - gate_out) + return out + + def log_gamma(self, t: jax.Array) -> jax.Array: + a, b = self.params + a = -jnp.abs(a).reshape((1, self.trace_size, 1)) + b = b.reshape(1, 1, self.context_size) + ab = jax.lax.complex(a, b) + return ab * t.reshape(t.shape[0], 1, 1) + + def gamma(self, t: jax.Array) -> jax.Array: + return jnp.exp(self.log_gamma(t)) + + def initialize_carry(self, batch_size: int = None): + if batch_size is None: + return jnp.zeros( + (1, self.trace_size, self.context_size), dtype=jnp.complex64 + ), jnp.ones((1,), dtype=jnp.int32) + + return jnp.zeros( + (1, batch_size, self.trace_size, self.context_size), dtype=jnp.complex64 + ), jnp.ones((1, batch_size), dtype=jnp.int32) + + def __call__(self, carry, incoming): + ( + state, + i, + ) = carry + x, j = incoming + state = state * self.gamma(j) + x + return state, j + i + + +class MemoroidResetWrapper(LRUCellBase): + """A wrapper around memoroid cells like FFM, LRU, etc that resets + the recurrent state upon a reset signal.""" + + cell: nn.Module + + def __call__(self, carry, incoming): + states, prev_start = carry + xs, start = incoming + + def reset_state(start, current_state, initial_state): + # Expand to reset all dims of state: [B, 1, 1, ...] + expanded_start = start.reshape(-1, *([1] * (current_state.ndim - 1))) + out = current_state * jnp.logical_not(expanded_start) + initial_state + return out + + initial_states = self.cell.initialize_carry() + states = jax.tree.map(partial(reset_state, start), states, initial_states) + out = self.cell(states, xs) + start_carry = jnp.logical_or(start, prev_start) + + return out, start_carry + + def map_to_h(self, inputs): + return self.cell.map_to_h(inputs) + + def map_from_h(self, recurrent_state, x): + return self.cell.map_from_h(recurrent_state, x) + + def initialize_carry(self, batch_size: int = None): + if batch_size is None: + # TODO: Should this be one or zero? + return self.cell.initialize_carry(batch_size), jnp.zeros((1,), dtype=bool) + + return self.cell.initialize_carry(batch_size), jnp.zeros((1, batch_size), dtype=bool) + + +class ScannedLRU(nn.Module): + cell: nn.Module + + @nn.compact + def __call__(self, recurrent_state, inputs): + # Recurrent state should be ((state, timestep), reset) + # Inputs should be (x, reset) + h = self.cell.map_to_h(inputs) + recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h) + # recurrent_state is ((state, timestep), reset) + out = self.cell.map_from_h(recurrent_state, x) + + # TODO: Remove this when we want to return all recurrent states instead of just the last one + final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) + return final_recurrent_state, out + + def initialize_carry(self, batch_size: int = None): + return self.cell.initialize_carry(batch_size) + + +if __name__ == "__main__": + m = ScannedLRU( + cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)) + ) + s = m.initialize_carry() + x = jnp.ones((10, 2)) + start = jnp.zeros(10, dtype=bool) + params = m.init(jax.random.PRNGKey(0), s, (x, start)) + out_state, out = m.apply(params, s, (x, start)) + + print(out) From dc25d6cf34f7711377635a94050c984a0e931b36 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Sat, 22 Jun 2024 17:12:58 +0000 Subject: [PATCH 11/38] chore: edit comments --- stoix/networks/ffm.py | 109 ++++++++++++++++++++++--------------- stoix/networks/ffm_edan.py | 11 +++- 2 files changed, 75 insertions(+), 45 deletions(-) diff --git a/stoix/networks/ffm.py b/stoix/networks/ffm.py index 5cde89bc..817d2275 100644 --- a/stoix/networks/ffm.py +++ b/stoix/networks/ffm.py @@ -13,7 +13,7 @@ def recurrent_associative_scan( axis: int = 0, ) -> jax.Array: """Execute the associative scan to update the recurrent state. - + Note that we do a trick here by concatenating the previous state to the inputs. This is allowed since the scan is associative. This ensures that the previous recurrent state feeds information into the scan. Without this method, we need @@ -24,18 +24,20 @@ def recurrent_associative_scan( # state: [start, (x, j)] # inputs: [start, (x, j)] scan_inputs = jax.tree.map(lambda x, s: jnp.concatenate([s, x], axis=0), inputs, state) - new_state = jax.lax.associative_scan( + new_state = jax.lax.associative_scan( cell, scan_inputs, axis=axis, ) # The zeroth index corresponds to the previous recurrent state - # We just use it to ensure continuity + # We just use it to ensure continuity # We do not actually want to use these values, so slice them away return jax.tree.map(lambda x: x[1:], new_state) + class Gate(nn.Module): """Sigmoidal gating""" + output_size: int @nn.compact @@ -55,6 +57,7 @@ def init_deterministic( b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) return a, b + def init_random( memory_size: int, context_size: int, min_period: int = 1, max_period: int = 10_000, *, key ) -> Tuple[jax.Array, jax.Array]: @@ -62,19 +65,28 @@ def init_random( a_low = 1e-6 a_high = 0.1 a = jax.random.uniform(k1, (memory_size,), minval=a_low, maxval=a_high) - b = 2 * jnp.pi / jnp.exp(jax.random.uniform(k2, (context_size,), minval=jnp.log(min_period), maxval=jnp.log(max_period))) + b = ( + 2 + * jnp.pi + / jnp.exp( + jax.random.uniform( + k2, (context_size,), minval=jnp.log(min_period), maxval=jnp.log(max_period) + ) + ) + ) return a, b class FFMCell(nn.Module): """The binary associative update function for the FFM.""" + trace_size: int context_size: int output_size: int deterministic_init: bool = True def setup(self): - if self.deterministic_init: + if self.deterministic_init: a, b = init_deterministic(self.trace_size, self.context_size) else: # TODO: Will this result in the same keys for multiple FFMCells? @@ -94,9 +106,13 @@ def gamma(self, t: jax.Array) -> jax.Array: def initialize_carry(self, batch_size: int = None): if batch_size is None: - return jnp.zeros((1, self.trace_size, self.context_size), dtype=jnp.complex64), jnp.ones((1,), dtype=jnp.int32) + return jnp.zeros( + (1, self.trace_size, self.context_size), dtype=jnp.complex64 + ), jnp.ones((1,), dtype=jnp.int32) - return jnp.zeros((1, batch_size, self.trace_size, self.context_size), dtype=jnp.complex64), jnp.ones((1, batch_size), dtype=jnp.int32) + return jnp.zeros( + (1, batch_size, self.trace_size, self.context_size), dtype=jnp.complex64 + ), jnp.ones((1, batch_size), dtype=jnp.int32) def __call__(self, carry, incoming): ( @@ -111,6 +127,7 @@ def __call__(self, carry, incoming): class MemoroidResetWrapper(nn.Module): """A wrapper around memoroid cells like FFM, LRU, etc that resets the recurrent state upon a reset signal.""" + cell: nn.Module def __call__(self, carry, incoming): @@ -132,8 +149,8 @@ def reset_state(start, current_state, initial_state): def initialize_carry(self, batch_size: int = None): if batch_size is None: - # TODO: Should this be one or zero? - return self.cell.initialize_carry(batch_size), jnp.zeros((1,), dtype=bool) + # TODO: Should this be one or zero? + return self.cell.initialize_carry(batch_size), jnp.zeros((1,), dtype=bool) return self.cell.initialize_carry(batch_size), jnp.zeros((batch_size,), dtype=bool) @@ -170,9 +187,9 @@ def map_from_h(self, recurrent_state, inputs): """Map from the recurrent space to the Markov space""" (state, ts), reset = recurrent_state (x, start) = inputs - z_in = jnp.concatenate( - [jnp.real(state), jnp.imag(state)], axis=-1 - ).reshape(state.shape[0], -1) + z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape( + state.shape[0], -1 + ) z = self.mix(z_in) gate_out = self.gate_out(x) skip = self.skip(x) @@ -206,21 +223,30 @@ class SFFM(nn.Module): def setup(self): self.W_trace = nn.Dense(self.trace_size) self.W_context = Gate(self.context_size) - self.ffa = FFMCell(self.trace_size, self.context_size, self.hidden_size, deterministic_init=False) - self.post = nn.Sequential([ - # Default init but with smaller weights - nn.Dense(self.hidden_size, kernel_init=nn.initializers.variance_scaling(0.01, "fan_in", "truncated_normal")), - nn.LayerNorm(), - nn.leaky_relu, - nn.Dense(self.hidden_size), - nn.LayerNorm(), - nn.leaky_relu, - ]) + self.ffa = FFMCell( + self.trace_size, self.context_size, self.hidden_size, deterministic_init=False + ) + self.post = nn.Sequential( + [ + # Default init but with smaller weights + nn.Dense( + self.hidden_size, + kernel_init=nn.initializers.variance_scaling( + 0.01, "fan_in", "truncated_normal" + ), + ), + nn.LayerNorm(), + nn.leaky_relu, + nn.Dense(self.hidden_size), + nn.LayerNorm(), + nn.leaky_relu, + ] + ) def map_to_h(self, inputs): x, resets = inputs pre = jnp.abs(jnp.einsum("bi, bj -> bij", self.W_trace(x), self.W_context(x))) - pre = pre / jnp.sum(pre, axis=(-2,-1), keepdims=True) + pre = pre / jnp.sum(pre, axis=(-2, -1), keepdims=True) # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous ts = jnp.ones(x.shape[0], dtype=jnp.int32) return (pre, ts), resets @@ -229,12 +255,15 @@ def map_from_h(self, recurrent_state, inputs): x, resets = inputs (state, ts), reset = recurrent_state s = state.reshape(state.shape[0], self.context_size * self.trace_size) - eps = s.real + (s.real==0 + jnp.sign(s.real)) * 0.01 + eps = s.real + (s.real == 0 + jnp.sign(s.real)) * 0.01 s = s + eps - scaled = jnp.concatenate([ - jnp.log(1 + jnp.abs(s)) * jnp.sin(jnp.angle(s)), - jnp.log(1 + jnp.abs(s)) * jnp.cos(jnp.angle(s)), - ], axis=-1) + scaled = jnp.concatenate( + [ + jnp.log(1 + jnp.abs(s)) * jnp.sin(jnp.angle(s)), + jnp.log(1 + jnp.abs(s)) * jnp.cos(jnp.angle(s)), + ], + axis=-1, + ) z = self.post(scaled) return z @@ -253,17 +282,16 @@ def __call__(self, recurrent_state, inputs): def initialize_carry(self, batch_size: int = None): return self.cell.initialize_carry(batch_size) + class StackedSFFM(nn.Module): """A multilayer version of SFFM""" + cells: List[nn.Module] def setup(self): self.project = nn.Dense(cells[0].hidden_size) - - def __call__( - self, recurrent_state: jax.Array, inputs: Any - ) -> Tuple[jax.Array, jax.Array]: + def __call__(self, recurrent_state: jax.Array, inputs: Any) -> Tuple[jax.Array, jax.Array]: x, start = inputs x = self.project(x) inputs = x, start @@ -271,23 +299,18 @@ def __call__( s, y = cell(recurrent_state[i], inputs) x = x + y recurrent_state[i] = s - return y, recurrent_state + return y, recurrent_state def initialize_carry(self, batch_size: int = None): - return [ - c.initialize_carry(batch_size) for c in self.cells - ] + return [c.initialize_carry(batch_size) for c in self.cells] + if __name__ == "__main__": m = FFM( output_size=4, trace_size=5, context_size=6, - cell=MemoroidResetWrapper( - cell=FFMCell( - output_size=4,trace_size=5,context_size=6 - ) - ) + cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)), ) s = m.initialize_carry() x = jnp.ones((10, 2)) @@ -321,7 +344,7 @@ def initialize_carry(self, batch_size: int = None): trace_size=4, context_size=5, hidden_size=6, - cell=MemoroidResetWrapper(cell=FFMCell(4,5,6)) + cell=MemoroidResetWrapper(cell=FFMCell(4, 5, 6)), ) for i in range(3) ] @@ -331,4 +354,4 @@ def initialize_carry(self, batch_size: int = None): x = jnp.ones((10, 2)) start = jnp.zeros(10, dtype=bool) params = s2fm.init(jax.random.PRNGKey(0), s, (x, start)) - out_state, out = s2fm.apply(params, s, (x, start)) \ No newline at end of file + out_state, out = s2fm.apply(params, s, (x, start)) diff --git a/stoix/networks/ffm_edan.py b/stoix/networks/ffm_edan.py index 05a985fc..b9886028 100644 --- a/stoix/networks/ffm_edan.py +++ b/stoix/networks/ffm_edan.py @@ -99,10 +99,15 @@ class FFMCell(LRUCellBase): output_size: int def setup(self): + + # Create the parameters that are explicitly used in the cells core computation a, b = init_deterministic(self.trace_size, self.context_size) self.params = (self.param("ffa_a", lambda rng: a), self.param("ffa_b", lambda rng: b)) - # Mapping from input space to recurrent state space + # Create the networks and parameters that are used when + # mapping from input space to recurrent state space + # This is used in the map_to_h method and is used in the + # associative scan outer loop self.pre = nn.Dense(self.trace_size) self.gate_in = Gate(self.trace_size) self.gate_out = Gate(self.output_size) @@ -111,7 +116,9 @@ def setup(self): self.ln = nn.LayerNorm(use_scale=False, use_bias=False) def map_to_h(self, inputs): - """Map from the input space to the recurrent state space""" + """Map from the input space to the recurrent state space - unlike the call function + this explicitly expects a shape including the sequence dimension. This is used in the + outer network that uses the associative scan.""" x, resets = inputs gate_in = self.gate_in(x) pre = self.pre(x) From da03f005f1fb3d5afd694257f00cbb1046efca13 Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Sat, 22 Jun 2024 19:26:48 +0100 Subject: [PATCH 12/38] modify ffm edan to work for batched/vmapped memoroids --- stoix/networks/ffm_edan.py | 87 ++++++++++++++++++++++++-------------- 1 file changed, 56 insertions(+), 31 deletions(-) diff --git a/stoix/networks/ffm_edan.py b/stoix/networks/ffm_edan.py index b9886028..7605331e 100644 --- a/stoix/networks/ffm_edan.py +++ b/stoix/networks/ffm_edan.py @@ -9,8 +9,11 @@ Carry = chex.ArrayTree -class LRUCellBase(nn.Module): - """LRU cell base class.""" +def debug_shape(x): + return jax.tree.map(lambda x: x.shape, x) + +class MemoroidCellBase(nn.Module): + """Memoroid cell base class.""" def map_to_h(self, inputs): """Map from the input space to the recurrent state space""" @@ -21,12 +24,13 @@ def map_from_h(self, recurrent_state, x): raise NotImplementedError @nn.nowrap - def initialize_carry(self, rng: chex.PRNGKey, input_shape: Tuple[int, ...]) -> Carry: - """Initialize the LRU cell carry. + def initialize_carry(self, rng: chex.PRNGKey, batch_shape: Tuple[int, ...]) -> Carry: + """Initialize the Memoroid cell carry. Args: rng: random number generator passed to the init_fn. - input_shape: a tuple providing the shape of the input to the cell. + batch_shape: a tuple providing the shape of the input to the cell, + excluding any time or feature dimension(s). Returns: An initialized carry for the given RNN cell. @@ -35,7 +39,7 @@ def initialize_carry(self, rng: chex.PRNGKey, input_shape: Tuple[int, ...]) -> C @property def num_feature_axes(self) -> int: - """Returns the number of feature axes of the LRU cell.""" + """Returns the number of feature axes of the cell.""" raise NotImplementedError @@ -56,7 +60,7 @@ def recurrent_associative_scan( # This ensures the previous recurrent state contributes to the current batch # state: [start, (x, j)] # inputs: [start, (x, j)] - scan_inputs = jax.tree.map(lambda x, s: jnp.concatenate([s, x], axis=0), inputs, state) + scan_inputs = jax.tree.map(lambda s, x: jnp.concatenate([s, x], axis=axis), state, inputs) new_state = jax.lax.associative_scan( cell, scan_inputs, @@ -65,7 +69,11 @@ def recurrent_associative_scan( # The zeroth index corresponds to the previous recurrent state # We just use it to ensure continuity # We do not actually want to use these values, so slice them away - return jax.tree.map(lambda x: x[1:], new_state) + return jax.tree.map( + lambda x: jax.lax.slice_in_dim( + x, start_index=1, limit_index=None, axis=axis + ), new_state + ) class Gate(nn.Module): @@ -91,7 +99,7 @@ def init_deterministic( return a, b -class FFMCell(LRUCellBase): +class FFMCell(MemoroidCellBase): """The binary associative update function for the FFM.""" trace_size: int @@ -150,15 +158,13 @@ def log_gamma(self, t: jax.Array) -> jax.Array: def gamma(self, t: jax.Array) -> jax.Array: return jnp.exp(self.log_gamma(t)) - def initialize_carry(self, batch_size: int = None): - if batch_size is None: - return jnp.zeros( - (1, self.trace_size, self.context_size), dtype=jnp.complex64 - ), jnp.ones((1,), dtype=jnp.int32) - + @nn.nowrap + def initialize_carry(self, rng: chex.PRNGKey, batch_shape: Tuple[int, ...]) -> Carry: + # inputs should be of shape [*batch, time, feature] + # recurrent states should be of shape [*batch, 1, feature] return jnp.zeros( - (1, batch_size, self.trace_size, self.context_size), dtype=jnp.complex64 - ), jnp.ones((1, batch_size), dtype=jnp.int32) + (*batch_shape, 1, self.trace_size, self.context_size), dtype=jnp.complex64 + ), jnp.ones((*batch_shape, 1), dtype=jnp.int32) def __call__(self, carry, incoming): ( @@ -170,23 +176,25 @@ def __call__(self, carry, incoming): return state, j + i -class MemoroidResetWrapper(LRUCellBase): +class MemoroidResetWrapper(MemoroidCellBase): """A wrapper around memoroid cells like FFM, LRU, etc that resets the recurrent state upon a reset signal.""" cell: nn.Module - def __call__(self, carry, incoming): + def __call__(self, carry, incoming, rng=None): states, prev_start = carry xs, start = incoming def reset_state(start, current_state, initial_state): # Expand to reset all dims of state: [B, 1, 1, ...] + assert initial_state.ndim == current_state.ndim expanded_start = start.reshape(-1, *([1] * (current_state.ndim - 1))) out = current_state * jnp.logical_not(expanded_start) + initial_state return out - initial_states = self.cell.initialize_carry() + # Add an extra dim, as start will be [Batch] while intialize carry expects [Batch, Feature] + initial_states = self.cell.initialize_carry(rng, ()) states = jax.tree.map(partial(reset_state, start), states, initial_states) out = self.cell(states, xs) start_carry = jnp.logical_or(start, prev_start) @@ -199,15 +207,14 @@ def map_to_h(self, inputs): def map_from_h(self, recurrent_state, x): return self.cell.map_from_h(recurrent_state, x) - def initialize_carry(self, batch_size: int = None): - if batch_size is None: - # TODO: Should this be one or zero? - return self.cell.initialize_carry(batch_size), jnp.zeros((1,), dtype=bool) - - return self.cell.initialize_carry(batch_size), jnp.zeros((1, batch_size), dtype=bool) + @nn.nowrap + def initialize_carry(self, rng: chex.PRNGKey, batch_shape: Tuple[int, ...]) -> Carry: + # inputs should be of shape [*batch, time, feature] + # recurrent states should be of shape [*batch, 1, feature] + return self.cell.initialize_carry(rng, batch_shape), jnp.zeros((*batch_shape, 1), dtype=bool) -class ScannedLRU(nn.Module): +class ScannedMemoroid(nn.Module): cell: nn.Module @nn.compact @@ -223,18 +230,36 @@ def __call__(self, recurrent_state, inputs): final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) return final_recurrent_state, out - def initialize_carry(self, batch_size: int = None): - return self.cell.initialize_carry(batch_size) + @nn.nowrap + def initialize_carry(self, rng: chex.PRNGKey, batch_shape: Tuple[int, ...]) -> Carry: + return self.cell.initialize_carry(rng, batch_shape) if __name__ == "__main__": - m = ScannedLRU( + m = ScannedMemoroid( cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)) ) - s = m.initialize_carry() x = jnp.ones((10, 2)) + s = m.initialize_carry(None, ()) start = jnp.zeros(10, dtype=bool) params = m.init(jax.random.PRNGKey(0), s, (x, start)) out_state, out = m.apply(params, s, (x, start)) print(out) + + BatchFFM = nn.vmap( + ScannedMemoroid, in_axes=0, out_axes=0, variable_axes={"params": None}, split_rngs={"params": False} + ) + + m = BatchFFM( + cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)) + ) + + x = jnp.ones((8, 10, 2)) + s = m.initialize_carry(None, (8,)) + start = jnp.zeros((8, 10), dtype=bool) + params = m.init(jax.random.PRNGKey(0), s, (x, start)) + out_state, out = m.apply(params, s, (x, start)) + + print(out.shape) + print(debug_shape(out_state)) From 9eacd4f8968735f4172785e82c4681ac31f437fc Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Sun, 23 Jun 2024 12:43:42 +0000 Subject: [PATCH 13/38] feat: make edits to ffm to make batch dim after sequence dim and add back the training script that can use ffm --- stoix/networks/ffm_edan.py | 75 ++- stoix/systems/ppo/rec_ppo_temp_ffm.py | 773 ++++++++++++++++++++++++++ 2 files changed, 820 insertions(+), 28 deletions(-) create mode 100644 stoix/systems/ppo/rec_ppo_temp_ffm.py diff --git a/stoix/networks/ffm_edan.py b/stoix/networks/ffm_edan.py index 7605331e..f3a9b14f 100644 --- a/stoix/networks/ffm_edan.py +++ b/stoix/networks/ffm_edan.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Tuple +from typing import Optional, Tuple import chex import flax.linen as nn @@ -12,6 +12,7 @@ def debug_shape(x): return jax.tree.map(lambda x: x.shape, x) + class MemoroidCellBase(nn.Module): """Memoroid cell base class.""" @@ -24,13 +25,14 @@ def map_from_h(self, recurrent_state, x): raise NotImplementedError @nn.nowrap - def initialize_carry(self, rng: chex.PRNGKey, batch_shape: Tuple[int, ...]) -> Carry: + def initialize_carry( + self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None + ) -> Carry: """Initialize the Memoroid cell carry. Args: - rng: random number generator passed to the init_fn. - batch_shape: a tuple providing the shape of the input to the cell, - excluding any time or feature dimension(s). + batch_size: the batch size of the carry. + rng: random number generator passed to the init_fn. Returns: An initialized carry for the given RNN cell. @@ -70,9 +72,7 @@ def recurrent_associative_scan( # We just use it to ensure continuity # We do not actually want to use these values, so slice them away return jax.tree.map( - lambda x: jax.lax.slice_in_dim( - x, start_index=1, limit_index=None, axis=axis - ), new_state + lambda x: jax.lax.slice_in_dim(x, start_index=1, limit_index=None, axis=axis), new_state ) @@ -159,12 +159,17 @@ def gamma(self, t: jax.Array) -> jax.Array: return jnp.exp(self.log_gamma(t)) @nn.nowrap - def initialize_carry(self, rng: chex.PRNGKey, batch_shape: Tuple[int, ...]) -> Carry: + def initialize_carry( + self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None + ) -> Carry: # inputs should be of shape [*batch, time, feature] # recurrent states should be of shape [*batch, 1, feature] - return jnp.zeros( - (*batch_shape, 1, self.trace_size, self.context_size), dtype=jnp.complex64 - ), jnp.ones((*batch_shape, 1), dtype=jnp.int32) + carry_shape = (1, self.trace_size, self.context_size) + t_shape = (1,) + if batch_size is not None: + carry_shape = (carry_shape[0], batch_size, *carry_shape[1:]) + t_shape = (*t_shape, batch_size) + return jnp.zeros(carry_shape, dtype=jnp.complex64), jnp.ones(t_shape, dtype=jnp.int32) def __call__(self, carry, incoming): ( @@ -194,7 +199,7 @@ def reset_state(start, current_state, initial_state): return out # Add an extra dim, as start will be [Batch] while intialize carry expects [Batch, Feature] - initial_states = self.cell.initialize_carry(rng, ()) + initial_states = self.cell.initialize_carry(rng) states = jax.tree.map(partial(reset_state, start), states, initial_states) out = self.cell(states, xs) start_carry = jnp.logical_or(start, prev_start) @@ -208,10 +213,15 @@ def map_from_h(self, recurrent_state, x): return self.cell.map_from_h(recurrent_state, x) @nn.nowrap - def initialize_carry(self, rng: chex.PRNGKey, batch_shape: Tuple[int, ...]) -> Carry: + def initialize_carry( + self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None + ) -> Carry: # inputs should be of shape [*batch, time, feature] # recurrent states should be of shape [*batch, 1, feature] - return self.cell.initialize_carry(rng, batch_shape), jnp.zeros((*batch_shape, 1), dtype=bool) + start_shape = (1,) + if batch_size is not None: + start_shape = (*start_shape, batch_size) + return self.cell.initialize_carry(batch_size, rng), jnp.zeros(start_shape, dtype=bool) class ScannedMemoroid(nn.Module): @@ -221,6 +231,7 @@ class ScannedMemoroid(nn.Module): def __call__(self, recurrent_state, inputs): # Recurrent state should be ((state, timestep), reset) # Inputs should be (x, reset) + x, _ = inputs h = self.cell.map_to_h(inputs) recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h) # recurrent_state is ((state, timestep), reset) @@ -231,35 +242,43 @@ def __call__(self, recurrent_state, inputs): return final_recurrent_state, out @nn.nowrap - def initialize_carry(self, rng: chex.PRNGKey, batch_shape: Tuple[int, ...]) -> Carry: - return self.cell.initialize_carry(rng, batch_shape) + def initialize_carry( + self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None + ) -> Carry: + return self.cell.initialize_carry(batch_size, rng) if __name__ == "__main__": m = ScannedMemoroid( cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)) ) - x = jnp.ones((10, 2)) - s = m.initialize_carry(None, ()) + y = jnp.ones((10, 2)) + s = m.initialize_carry(None) start = jnp.zeros(10, dtype=bool) - params = m.init(jax.random.PRNGKey(0), s, (x, start)) - out_state, out = m.apply(params, s, (x, start)) + params = m.init(jax.random.PRNGKey(0), s, (y, start)) + out_state, out = m.apply(params, s, (y, start)) print(out) BatchFFM = nn.vmap( - ScannedMemoroid, in_axes=0, out_axes=0, variable_axes={"params": None}, split_rngs={"params": False} + ScannedMemoroid, + in_axes=1, + out_axes=1, + variable_axes={"params": None}, + split_rngs={"params": False}, ) m = BatchFFM( cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)) ) - x = jnp.ones((8, 10, 2)) - s = m.initialize_carry(None, (8,)) - start = jnp.zeros((8, 10), dtype=bool) - params = m.init(jax.random.PRNGKey(0), s, (x, start)) - out_state, out = m.apply(params, s, (x, start)) + y = jnp.ones((10, 8, 2)) + s = m.initialize_carry(8) + start = jnp.zeros((10, 8), dtype=bool) + params = m.init(jax.random.PRNGKey(0), s, (y, start)) + out_state, out = m.apply(params, s, (y, start)) + + out = jnp.swapaxes(out, 0, 1) - print(out.shape) + print(out) print(debug_shape(out_state)) diff --git a/stoix/systems/ppo/rec_ppo_temp_ffm.py b/stoix/systems/ppo/rec_ppo_temp_ffm.py new file mode 100644 index 00000000..a662f659 --- /dev/null +++ b/stoix/systems/ppo/rec_ppo_temp_ffm.py @@ -0,0 +1,773 @@ +import copy +import time +from typing import Any, Dict, Tuple + +import chex +import flax +import hydra +import jax +import jax.numpy as jnp +import optax +from colorama import Fore, Style +from flax.core.frozen_dict import FrozenDict +from jumanji.env import Environment +from omegaconf import DictConfig, OmegaConf +from rich.pretty import pprint + +from stoix.base_types import ( + ActorCriticOptStates, + ActorCriticParams, + ExperimentOutput, + LearnerFn, + RecActorApply, + RecCriticApply, + RNNLearnerState, +) +from stoix.evaluator import evaluator_setup, get_rec_distribution_act_fn +from stoix.networks.base import RecurrentActor, RecurrentCritic +from stoix.networks.ffm_edan import FFMCell, MemoroidResetWrapper, ScannedMemoroid +from stoix.systems.ppo.ppo_types import ActorCriticHiddenStates, RNNPPOTransition +from stoix.utils import make_env as environments +from stoix.utils.checkpointing import Checkpointer +from stoix.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims +from stoix.utils.logger import LogEvent, StoixLogger +from stoix.utils.loss import clipped_value_loss, ppo_clip_loss +from stoix.utils.multistep import batch_truncated_generalized_advantage_estimation +from stoix.utils.total_timestep_checker import check_total_timesteps +from stoix.utils.training import make_learning_rate +from stoix.wrappers.episode_metrics import get_final_step_metrics + + +def get_learner_fn( + env: Environment, + apply_fns: Tuple[RecActorApply, RecCriticApply], + update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], + config: DictConfig, +) -> LearnerFn[RNNLearnerState]: + """Get the learner function.""" + + actor_apply_fn, critic_apply_fn = apply_fns + actor_update_fn, critic_update_fn = update_fns + + def _update_step(learner_state: RNNLearnerState, _: Any) -> Tuple[RNNLearnerState, Tuple]: + """A single update of the network. + + This function steps the environment and records the trajectory batch for + training. It then calculates advantages and targets based on the recorded + trajectory and updates the actor and critic networks based on the calculated + losses. + + Args: + learner_state (NamedTuple): + - params (ActorCriticParams): The current model parameters. + - opt_states (OptStates): The current optimizer states. + - key (PRNGKey): The random number generator state. + - env_state (State): The environment state. + - last_timestep (TimeStep): The last timestep in the current trajectory. + - dones (bool): Whether the last timestep was a terminal state. + - hstates (ActorCriticHiddenStates): The current hidden states of the RNN. + _ (Any): The current metrics info. + """ + + def _env_step( + learner_state: RNNLearnerState, _: Any + ) -> Tuple[RNNLearnerState, RNNPPOTransition]: + """Step the environment.""" + ( + params, + opt_states, + key, + env_state, + last_timestep, + last_done, + last_truncated, + hstates, + ) = learner_state + + key, policy_key = jax.random.split(key) + + # Add a batch dimension to the observation. + batched_observation = jax.tree_util.tree_map( + lambda x: x[jnp.newaxis, :], last_timestep.observation + ) + ac_in = ( + batched_observation, + last_done[jnp.newaxis, :], + ) + + # Run the network. + policy_hidden_state, actor_policy = actor_apply_fn( + params.actor_params, hstates.policy_hidden_state, ac_in + ) + critic_hidden_state, value = critic_apply_fn( + params.critic_params, hstates.critic_hidden_state, ac_in + ) + + # Sample action from the policy and squeeze out the batch dimension. + action = actor_policy.sample(seed=policy_key) + log_prob = actor_policy.log_prob(action) + value, action, log_prob = ( + value.squeeze(0), + action.squeeze(0), + log_prob.squeeze(0), + ) + + # Step the environment. + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) + + # log episode return and length + done = (timestep.discount == 0.0).reshape(-1) + truncated = (timestep.last() & (timestep.discount != 0.0)).reshape(-1) + info = timestep.extras["episode_metrics"] + + hstates = ActorCriticHiddenStates(policy_hidden_state, critic_hidden_state) + transition = RNNPPOTransition( + last_done, + last_truncated, + action, + value, + timestep.reward, + log_prob, + last_timestep.observation, + jax.tree.map(lambda x: x.squeeze(0), hstates), + info, + ) + learner_state = RNNLearnerState( + params, + opt_states, + key, + env_state, + timestep, + done, + truncated, + hstates, + ) + return learner_state, transition + + # INITIALISE RNN STATE + initial_hstates = learner_state.hstates + + # STEP ENVIRONMENT FOR ROLLOUT LENGTH + learner_state, traj_batch = jax.lax.scan( + _env_step, learner_state, None, config.system.rollout_length + ) + + # CALCULATE ADVANTAGE + ( + params, + opt_states, + key, + env_state, + last_timestep, + last_done, + last_truncated, + hstates, + ) = learner_state + + # Add a batch dimension to the observation. + batched_last_observation = jax.tree_util.tree_map( + lambda x: x[jnp.newaxis, :], last_timestep.observation + ) + ac_in = ( + batched_last_observation, + last_done[jnp.newaxis, :], + ) + + # Run the network. + _, last_val = critic_apply_fn(params.critic_params, hstates.critic_hidden_state, ac_in) + # Squeeze out the batch dimension and mask out the value of terminal states. + last_val = last_val.squeeze(0) + last_val = jnp.where(last_done, jnp.zeros_like(last_val), last_val) + + r_t = traj_batch.reward + v_t = jnp.concatenate([traj_batch.value, last_val[None, ...]], axis=0) + d_t = 1.0 - traj_batch.done.astype(jnp.float32) + d_t = (d_t * config.system.gamma).astype(jnp.float32) + advantages, targets = batch_truncated_generalized_advantage_estimation( + r_t, + d_t, + config.system.gae_lambda, + v_t, + time_major=True, + standardize_advantages=config.system.standardize_advantages, + truncation_flags=traj_batch.truncated, + ) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: + """Update the network for a single minibatch.""" + + params, opt_states = train_state + ( + traj_batch, + advantages, + targets, + ) = batch_info + + def _actor_loss_fn( + actor_params: FrozenDict, + traj_batch: RNNPPOTransition, + gae: chex.Array, + ) -> Tuple: + """Calculate the actor loss.""" + # RERUN NETWORK + + obs_and_done = (traj_batch.obs, traj_batch.done) + policy_hidden_state = jax.tree_util.tree_map( + lambda x: x[0][jnp.newaxis, ...], traj_batch.hstates.policy_hidden_state + ) + _, actor_policy = actor_apply_fn( + actor_params, policy_hidden_state, obs_and_done + ) + log_prob = actor_policy.log_prob(traj_batch.action) + + loss_actor = ppo_clip_loss( + log_prob, traj_batch.log_prob, gae, config.system.clip_eps + ) + entropy = actor_policy.entropy().mean() + + total_loss = loss_actor - config.system.ent_coef * entropy + loss_info = { + "actor_loss": loss_actor, + "entropy": entropy, + } + return total_loss, loss_info + + def _critic_loss_fn( + critic_params: FrozenDict, + traj_batch: RNNPPOTransition, + targets: chex.Array, + ) -> Tuple: + """Calculate the critic loss.""" + # RERUN NETWORK + obs_and_done = (traj_batch.obs, traj_batch.done) + critic_hidden_state = jax.tree_util.tree_map( + lambda x: x[0][jnp.newaxis, ...], traj_batch.hstates.critic_hidden_state + ) + _, value = critic_apply_fn(critic_params, critic_hidden_state, obs_and_done) + + # CALCULATE VALUE LOSS + value_loss = clipped_value_loss( + value, traj_batch.value, targets, config.system.clip_eps + ) + + total_loss = config.system.vf_coef * value_loss + loss_info = { + "value_loss": value_loss, + } + return total_loss, loss_info + + # CALCULATE ACTOR LOSS + actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True) + actor_grads, actor_loss_info = actor_grad_fn( + params.actor_params, traj_batch, advantages + ) + + # CALCULATE CRITIC LOSS + critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True) + critic_grads, critic_loss_info = critic_grad_fn( + params.critic_params, traj_batch, targets + ) + + # Compute the parallel mean (pmean) over the batch. + # This calculation is inspired by the Anakin architecture demo notebook. + # available at https://tinyurl.com/26tdzs5x + # This pmean could be a regular mean as the batch axis is on the same device. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), axis_name="batch" + ) + # pmean over devices. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), axis_name="device" + ) + + critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="batch" + ) + # pmean over devices. + critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="device" + ) + + # UPDATE ACTOR PARAMS AND OPTIMISER STATE + actor_updates, actor_new_opt_state = actor_update_fn( + actor_grads, opt_states.actor_opt_state + ) + actor_new_params = optax.apply_updates(params.actor_params, actor_updates) + + # UPDATE CRITIC PARAMS AND OPTIMISER STATE + critic_updates, critic_new_opt_state = critic_update_fn( + critic_grads, opt_states.critic_opt_state + ) + critic_new_params = optax.apply_updates(params.critic_params, critic_updates) + + new_params = ActorCriticParams(actor_new_params, critic_new_params) + new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state) + + # PACK LOSS INFO + loss_info = { + **actor_loss_info, + **critic_loss_info, + } + + return (new_params, new_opt_state), loss_info + + ( + params, + opt_states, + init_hstates, + traj_batch, + advantages, + targets, + key, + ) = update_state + key, shuffle_key = jax.random.split(key) + + # SHUFFLE MINIBATCHES + batch = (traj_batch, advantages, targets) + num_recurrent_chunks = ( + config.system.rollout_length // config.system.recurrent_chunk_size + ) + batch = jax.tree_util.tree_map( + lambda x: x.reshape( + config.system.recurrent_chunk_size, + config.arch.num_envs * num_recurrent_chunks, + *x.shape[2:], + ), + batch, + ) + permutation = jax.random.permutation( + shuffle_key, config.arch.num_envs * num_recurrent_chunks + ) + shuffled_batch = jax.tree_util.tree_map( + lambda x: jnp.take(x, permutation, axis=1), batch + ) + reshaped_batch = jax.tree_util.tree_map( + lambda x: jnp.reshape( + x, (x.shape[0], config.system.num_minibatches, -1, *x.shape[2:]) + ), + shuffled_batch, + ) + minibatches = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 0), reshaped_batch) + + # UPDATE MINIBATCHES + (params, opt_states), loss_info = jax.lax.scan( + _update_minibatch, (params, opt_states), minibatches + ) + + update_state = ( + params, + opt_states, + init_hstates, + traj_batch, + advantages, + targets, + key, + ) + return update_state, loss_info + + init_hstates = jax.tree_util.tree_map(lambda x: x[None, :], initial_hstates) + update_state = ( + params, + opt_states, + init_hstates, + traj_batch, + advantages, + targets, + key, + ) + + # UPDATE EPOCHS + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.epochs + ) + + params, opt_states, _, traj_batch, advantages, targets, key = update_state + learner_state = RNNLearnerState( + params, + opt_states, + key, + env_state, + last_timestep, + last_done, + last_truncated, + hstates, + ) + metric = traj_batch.info + return learner_state, (metric, loss_info) + + def learner_fn(learner_state: RNNLearnerState) -> ExperimentOutput[RNNLearnerState]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized over a batch of inputs. + + Args: + learner_state (NamedTuple): + - params (ActorCriticParams): The initial model parameters. + - opt_states (OptStates): The initial optimizer states. + - key (chex.PRNGKey): The random number generator state. + - env_state (LogEnvState): The environment state. + - timesteps (TimeStep): The initial timestep in the initial trajectory. + - dones (bool): Whether the initial timestep was a terminal state. + - hstateS (ActorCriticHiddenStates): The initial hidden states of the RNN. + """ + + batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") + + learner_state, (episode_info, loss_info) = jax.lax.scan( + batched_update_step, learner_state, None, config.arch.num_updates_per_eval + ) + return ExperimentOutput( + learner_state=learner_state, + episode_metrics=episode_info, + train_metrics=loss_info, + ) + + return learner_fn + + +def learner_setup( + env: Environment, keys: chex.Array, config: DictConfig +) -> Tuple[LearnerFn[RNNLearnerState], RecurrentActor, Any, RNNLearnerState]: + """Initialise learner_fn, network, optimiser, environment and states.""" + # Get available TPU cores. + n_devices = len(jax.devices()) + + # Get number/dimension of actions. + num_actions = int(env.action_spec().num_values) + config.system.action_dim = num_actions + + # PRNG keys. + key, actor_net_key, critic_net_key = keys + + # FFM CHANGES HERE + output_size = 64 + trace_size = 64 + context_size = 64 + BatchFFM = flax.linen.vmap( + ScannedMemoroid, + in_axes=1, + out_axes=1, + variable_axes={"params": None}, + split_rngs={"params": False}, + ) + + # Define network and optimisers. + actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + actor_rnn = BatchFFM( + cell=MemoroidResetWrapper( + cell=FFMCell(output_size=output_size, trace_size=trace_size, context_size=context_size) + ) + ) + actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) + actor_action_head = hydra.utils.instantiate( + config.network.actor_network.action_head, action_dim=num_actions + ) + critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) + critic_rnn = BatchFFM( + cell=MemoroidResetWrapper( + cell=FFMCell(output_size=output_size, trace_size=trace_size, context_size=context_size) + ) + ) + critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) + critic_head = hydra.utils.instantiate(config.network.critic_network.critic_head) + + actor_network = RecurrentActor( + pre_torso=actor_pre_torso, + rnn=actor_rnn, + post_torso=actor_post_torso, + action_head=actor_action_head, + ) + critic_network = RecurrentCritic( + pre_torso=critic_pre_torso, + rnn=critic_rnn, + post_torso=critic_post_torso, + critic_head=critic_head, + ) + + actor_lr = make_learning_rate( + config.system.actor_lr, config, config.system.epochs, config.system.num_minibatches + ) + critic_lr = make_learning_rate( + config.system.critic_lr, config, config.system.epochs, config.system.num_minibatches + ) + + actor_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(actor_lr, eps=1e-5), + ) + critic_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(critic_lr, eps=1e-5), + ) + + # Initialise observation + init_obs = env.observation_spec().generate_value() + init_obs = jax.tree_util.tree_map( + lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0), + init_obs, + ) + init_obs = jax.tree_util.tree_map(lambda x: x[None, ...], init_obs) + init_done = jnp.zeros((1, config.arch.num_envs), dtype=bool) + init_x = (init_obs, init_done) + + # Initialise hidden states. + init_policy_hstate = actor_rnn.initialize_carry(config.arch.num_envs) + init_critic_hstate = critic_rnn.initialize_carry(config.arch.num_envs) + + # initialise params and optimiser state. + actor_params = actor_network.init(actor_net_key, init_policy_hstate, init_x) + actor_opt_state = actor_optim.init(actor_params) + critic_params = critic_network.init(critic_net_key, init_critic_hstate, init_x) + critic_opt_state = critic_optim.init(critic_params) + + actor_network_apply_fn = actor_network.apply + critic_network_apply_fn = critic_network.apply + + # Get network apply functions and optimiser updates. + apply_fns = (actor_network_apply_fn, critic_network_apply_fn) + update_fns = (actor_optim.update, critic_optim.update) + + # Get batched iterated update and replicate it to pmap it over cores. + learn = get_learner_fn(env, apply_fns, update_fns, config) + learn = jax.pmap(learn, axis_name="device") + + # Pack params and initial states. + params = ActorCriticParams(actor_params, critic_params) + hstates = ActorCriticHiddenStates(init_policy_hstate, init_critic_hstate) + + # Load model from checkpoint if specified. + if config.logger.checkpointing.load_model: + loaded_checkpoint = Checkpointer( + model_name=config.system.system_name, + **config.logger.checkpointing.load_args, # Other checkpoint args + ) + # Restore the learner state from the checkpoint + restored_params, restored_hstates = loaded_checkpoint.restore_params(restore_hstates=True) + # Update the params and hstates + params = restored_params + hstates = restored_hstates if restored_hstates else hstates + + # Initialise environment states and timesteps: across devices and batches. + key, *env_keys = jax.random.split( + key, n_devices * config.arch.update_batch_size * config.arch.num_envs + 1 + ) + env_states, timesteps = jax.vmap(env.reset, in_axes=(0))( + jnp.stack(env_keys), + ) + reshape_states = lambda x: x.reshape( + (n_devices, config.arch.update_batch_size, config.arch.num_envs) + x.shape[1:] + ) + # (devices, update batch size, num_envs, ...) + env_states = jax.tree_util.tree_map(reshape_states, env_states) + timesteps = jax.tree_util.tree_map(reshape_states, timesteps) + + # Define params to be replicated across devices and batches. + dones = jnp.zeros( + (config.arch.num_envs,), + dtype=bool, + ) + truncated = jnp.zeros( + (config.arch.num_envs,), + dtype=bool, + ) + key, step_key = jax.random.split(key) + step_keys = jax.random.split(step_key, n_devices * config.arch.update_batch_size) + reshape_keys = lambda x: x.reshape((n_devices, config.arch.update_batch_size) + x.shape[1:]) + step_keys = reshape_keys(jnp.stack(step_keys)) + opt_states = ActorCriticOptStates(actor_opt_state, critic_opt_state) + replicate_learner = (params, opt_states, hstates, dones, truncated) + + # Duplicate learner for update_batch_size. + broadcast = lambda x: jnp.broadcast_to(x, (config.arch.update_batch_size,) + x.shape) + replicate_learner = jax.tree_util.tree_map(broadcast, replicate_learner) + + # Duplicate learner across devices. + replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices()) + + # Initialise learner state. + params, opt_states, hstates, dones, truncated = replicate_learner + init_learner_state = RNNLearnerState( + params=params, + opt_states=opt_states, + key=step_keys, + env_state=env_states, + timestep=timesteps, + done=dones, + truncated=truncated, + hstates=hstates, + ) + return learn, actor_network, actor_rnn, init_learner_state + + +def run_experiment(_config: DictConfig) -> float: + """Runs experiment.""" + config = copy.deepcopy(_config) + + # Calculate total timesteps. + n_devices = len(jax.devices()) + config.num_devices = n_devices + config = check_total_timesteps(config) + assert ( + config.arch.num_updates > config.arch.num_evaluation + ), "Number of updates per evaluation must be less than total number of updates." + + # Set recurrent chunk size. + if config.system.recurrent_chunk_size is None: + config.system.recurrent_chunk_size = config.system.rollout_length + else: + assert ( + config.system.rollout_length % config.system.recurrent_chunk_size == 0 + ), "Rollout length must be divisible by recurrent chunk size." + + # Create the environments for train and eval. + env, eval_env = environments.make(config) + + # PRNG keys. + key, key_e, actor_net_key, critic_net_key = jax.random.split( + jax.random.PRNGKey(config.arch.seed), num=4 + ) + + # Setup learner. + learn, actor_network, actor_rnn, learner_state = learner_setup( + env, (key, actor_net_key, critic_net_key), config + ) + + # Setup evaluator. + evaluator, absolute_metric_evaluator, (trained_params, eval_keys) = evaluator_setup( + eval_env=eval_env, + key_e=key_e, + eval_act_fn=get_rec_distribution_act_fn(config, actor_network.apply), + params=learner_state.params.actor_params, + config=config, + use_recurrent_net=True, + scanned_rnn=actor_rnn, + ) + + # Calculate number of updates per evaluation. + config.arch.num_updates_per_eval = config.arch.num_updates // config.arch.num_evaluation + steps_per_rollout = ( + n_devices + * config.arch.num_updates_per_eval + * config.system.rollout_length + * config.arch.update_batch_size + * config.arch.num_envs + ) + + # Logger setup + logger = StoixLogger(config) + cfg: Dict = OmegaConf.to_container(config, resolve=True) + cfg["arch"]["devices"] = jax.devices() + pprint(cfg) + + # Set up checkpointer + save_checkpoint = config.logger.checkpointing.save_model + if save_checkpoint: + checkpointer = Checkpointer( + metadata=config, # Save all config as metadata in the checkpoint + model_name=config.system.system_name, + **config.logger.checkpointing.save_args, # Checkpoint args + ) + + # Run experiment for a total number of evaluations. + max_episode_return = jnp.float32(-1e7) + best_params = None + for eval_step in range(config.arch.num_evaluation): + # Train. + start_time = time.time() + learner_output = learn(learner_state) + jax.block_until_ready(learner_output) + + # Log the results of the training. + elapsed_time = time.time() - start_time + t = int(steps_per_rollout * (eval_step + 1)) + episode_metrics, ep_completed = get_final_step_metrics(learner_output.episode_metrics) + episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + + # Separately log timesteps, actoring metrics and training metrics. + logger.log({"timestep": t}, t, eval_step, LogEvent.MISC) + if ep_completed: # only log episode metrics if an episode was completed in the rollout. + logger.log(episode_metrics, t, eval_step, LogEvent.ACT) + logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) + + # EVALUATION DOESNT CURRENTLY WORK YET... + + # Prepare for evaluation. + # start_time = time.time() + # trained_params = unreplicate_batch_dim(learner_output.learner_state.params.actor_params) + # key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + # eval_keys = jnp.stack(eval_keys) + # eval_keys = eval_keys.reshape(n_devices, -1) + + # # Evaluate. + # evaluator_output = evaluator(trained_params, eval_keys) + # jax.block_until_ready(evaluator_output) + + # # Log the results of the evaluation. + # elapsed_time = time.time() - start_time + # episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"]) + + # steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) + # evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + # logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL) + + # if save_checkpoint: + # # Save checkpoint of learner state + # checkpointer.save( + # timestep=int(steps_per_rollout * (eval_step + 1)), + # unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), + # episode_return=episode_return, + # ) + + # if config.arch.absolute_metric and max_episode_return <= episode_return: + # best_params = copy.deepcopy(trained_params) + # max_episode_return = episode_return + + # Update runner state to continue training. + learner_state = learner_output.learner_state + + # Measure absolute metric. + # if config.arch.absolute_metric: + # start_time = time.time() + + # key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + # eval_keys = jnp.stack(eval_keys) + # eval_keys = eval_keys.reshape(n_devices, -1) + + # evaluator_output = absolute_metric_evaluator(best_params, eval_keys) + # jax.block_until_ready(evaluator_output) + + # elapsed_time = time.time() - start_time + + # t = int(steps_per_rollout * (eval_step + 1)) + # steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) + # evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + # logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE) + + # Stop the logger. + logger.stop() + # Record the performance for the final evaluation run. If the absolute metric is not + # calculated, this will be the final evaluation run. + # eval_performance = float(jnp.mean(evaluator_output.episode_metrics[config.env.eval_metric])) + # return eval_performance + + +@hydra.main(config_path="../../configs", config_name="default_rec_ppo.yaml", version_base="1.2") +def hydra_entry_point(cfg: DictConfig) -> float: + """Experiment entry point.""" + # Allow dynamic attributes. + OmegaConf.set_struct(cfg, False) + + # Run experiment. + eval_performance = run_experiment(cfg) + + print(f"{Fore.CYAN}{Style.BRIGHT}Recurrent PPO experiment completed{Style.RESET_ALL}") + return eval_performance + + +if __name__ == "__main__": + hydra_entry_point() From 15473a421522f68826b80cadf0e156d57e05d2c8 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Mon, 24 Jun 2024 14:08:25 +0000 Subject: [PATCH 14/38] chore: change scanned memorid to expect non sequence dimension carry and return a non sequence dimension carry --- stoix/networks/ffm_edan.py | 15 +++- stoix/systems/ppo/rec_ppo_temp_ffm.py | 98 +++++++++++++-------------- 2 files changed, 60 insertions(+), 53 deletions(-) diff --git a/stoix/networks/ffm_edan.py b/stoix/networks/ffm_edan.py index f3a9b14f..00232da4 100644 --- a/stoix/networks/ffm_edan.py +++ b/stoix/networks/ffm_edan.py @@ -231,6 +231,10 @@ class ScannedMemoroid(nn.Module): def __call__(self, recurrent_state, inputs): # Recurrent state should be ((state, timestep), reset) # Inputs should be (x, reset) + + # Unsqueeze the recurrent state to add the sequence dimension of size 1 + recurrent_state = jax.tree.map(lambda x: jnp.expand_dims(x, 0), recurrent_state) + x, _ = inputs h = self.cell.map_to_h(inputs) recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h) @@ -239,13 +243,18 @@ def __call__(self, recurrent_state, inputs): # TODO: Remove this when we want to return all recurrent states instead of just the last one final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) + + # Squeeze the sequence dimension of 1 out + final_recurrent_state = jax.tree.map(lambda x: jnp.squeeze(x, 0), final_recurrent_state) + return final_recurrent_state, out @nn.nowrap def initialize_carry( self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None ) -> Carry: - return self.cell.initialize_carry(batch_size, rng) + # We squeeze the sequence dim of 1 out. + return jax.tree.map(lambda x: x.squeeze(0), self.cell.initialize_carry(batch_size, rng)) if __name__ == "__main__": @@ -262,8 +271,8 @@ def initialize_carry( BatchFFM = nn.vmap( ScannedMemoroid, - in_axes=1, - out_axes=1, + in_axes=(((0, 0), 0), 1), + out_axes=(((0, 0), 0), 1), variable_axes={"params": None}, split_rngs={"params": False}, ) diff --git a/stoix/systems/ppo/rec_ppo_temp_ffm.py b/stoix/systems/ppo/rec_ppo_temp_ffm.py index a662f659..9f6c692f 100644 --- a/stoix/systems/ppo/rec_ppo_temp_ffm.py +++ b/stoix/systems/ppo/rec_ppo_temp_ffm.py @@ -129,7 +129,7 @@ def _env_step( timestep.reward, log_prob, last_timestep.observation, - jax.tree.map(lambda x: x.squeeze(0), hstates), + hstates, info, ) learner_state = RNNLearnerState( @@ -216,7 +216,7 @@ def _actor_loss_fn( obs_and_done = (traj_batch.obs, traj_batch.done) policy_hidden_state = jax.tree_util.tree_map( - lambda x: x[0][jnp.newaxis, ...], traj_batch.hstates.policy_hidden_state + lambda x: x[0], traj_batch.hstates.policy_hidden_state ) _, actor_policy = actor_apply_fn( actor_params, policy_hidden_state, obs_and_done @@ -244,7 +244,7 @@ def _critic_loss_fn( # RERUN NETWORK obs_and_done = (traj_batch.obs, traj_batch.done) critic_hidden_state = jax.tree_util.tree_map( - lambda x: x[0][jnp.newaxis, ...], traj_batch.hstates.critic_hidden_state + lambda x: x[0], traj_batch.hstates.critic_hidden_state ) _, value = critic_apply_fn(critic_params, critic_hidden_state, obs_and_done) @@ -450,8 +450,8 @@ def learner_setup( context_size = 64 BatchFFM = flax.linen.vmap( ScannedMemoroid, - in_axes=1, - out_axes=1, + in_axes=(((0, 0), 0), 1), + out_axes=(((0, 0), 0), 1), variable_axes={"params": None}, split_rngs={"params": False}, ) @@ -694,66 +694,64 @@ def run_experiment(_config: DictConfig) -> float: logger.log(episode_metrics, t, eval_step, LogEvent.ACT) logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) - # EVALUATION DOESNT CURRENTLY WORK YET... - # Prepare for evaluation. - # start_time = time.time() - # trained_params = unreplicate_batch_dim(learner_output.learner_state.params.actor_params) - # key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) - # eval_keys = jnp.stack(eval_keys) - # eval_keys = eval_keys.reshape(n_devices, -1) - - # # Evaluate. - # evaluator_output = evaluator(trained_params, eval_keys) - # jax.block_until_ready(evaluator_output) - - # # Log the results of the evaluation. - # elapsed_time = time.time() - start_time - # episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"]) - - # steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) - # evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - # logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL) - - # if save_checkpoint: - # # Save checkpoint of learner state - # checkpointer.save( - # timestep=int(steps_per_rollout * (eval_step + 1)), - # unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), - # episode_return=episode_return, - # ) - - # if config.arch.absolute_metric and max_episode_return <= episode_return: - # best_params = copy.deepcopy(trained_params) - # max_episode_return = episode_return + start_time = time.time() + trained_params = unreplicate_batch_dim(learner_output.learner_state.params.actor_params) + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + # Evaluate. + evaluator_output = evaluator(trained_params, eval_keys) + jax.block_until_ready(evaluator_output) + + # Log the results of the evaluation. + elapsed_time = time.time() - start_time + episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"]) + + steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) + evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL) + + if save_checkpoint: + # Save checkpoint of learner state + checkpointer.save( + timestep=int(steps_per_rollout * (eval_step + 1)), + unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), + episode_return=episode_return, + ) + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params = copy.deepcopy(trained_params) + max_episode_return = episode_return # Update runner state to continue training. learner_state = learner_output.learner_state # Measure absolute metric. - # if config.arch.absolute_metric: - # start_time = time.time() + if config.arch.absolute_metric: + start_time = time.time() - # key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) - # eval_keys = jnp.stack(eval_keys) - # eval_keys = eval_keys.reshape(n_devices, -1) + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) - # evaluator_output = absolute_metric_evaluator(best_params, eval_keys) - # jax.block_until_ready(evaluator_output) + evaluator_output = absolute_metric_evaluator(best_params, eval_keys) + jax.block_until_ready(evaluator_output) - # elapsed_time = time.time() - start_time + elapsed_time = time.time() - start_time - # t = int(steps_per_rollout * (eval_step + 1)) - # steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) - # evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - # logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE) + t = int(steps_per_rollout * (eval_step + 1)) + steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) + evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time + logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE) # Stop the logger. logger.stop() # Record the performance for the final evaluation run. If the absolute metric is not # calculated, this will be the final evaluation run. - # eval_performance = float(jnp.mean(evaluator_output.episode_metrics[config.env.eval_metric])) - # return eval_performance + eval_performance = float(jnp.mean(evaluator_output.episode_metrics[config.env.eval_metric])) + return eval_performance @hydra.main(config_path="../../configs", config_name="default_rec_ppo.yaml", version_base="1.2") From 74e35b3d16f0214f98da2a191dc63c1d03fa885d Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Mon, 24 Jun 2024 15:19:14 +0000 Subject: [PATCH 15/38] feat: add explicit batch dimension and network config - rec_ppo now works with memoroid conf --- stoix/configs/network/memoroid.yaml | 47 ++ stoix/networks/{ffm_edan.py => memoroid.py} | 62 +- stoix/systems/ppo/rec_ppo_temp_ffm.py | 771 -------------------- 3 files changed, 77 insertions(+), 803 deletions(-) create mode 100644 stoix/configs/network/memoroid.yaml rename stoix/networks/{ffm_edan.py => memoroid.py} (85%) delete mode 100644 stoix/systems/ppo/rec_ppo_temp_ffm.py diff --git a/stoix/configs/network/memoroid.yaml b/stoix/configs/network/memoroid.yaml new file mode 100644 index 00000000..8afa9227 --- /dev/null +++ b/stoix/configs/network/memoroid.yaml @@ -0,0 +1,47 @@ +# ---Recurrent Structure Networks for PPO --- + +actor_network: + pre_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [128] + use_layer_norm: False + activation: silu + rnn_layer: + _target_: stoix.networks.memoroid.ScannedMemoroid + cell: + _target_: stoix.networks.memoroid.MemoroidResetWrapper + cell: + _target_: stoix.networks.memoroid.FFMCell + trace_size: 128 + context_size: 128 + output_size: 128 + post_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [128] + use_layer_norm: False + activation: silu + action_head: + _target_: stoix.networks.heads.CategoricalHead + +critic_network: + pre_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [128] + use_layer_norm: False + activation: silu + rnn_layer: + _target_: stoix.networks.memoroid.ScannedMemoroid + cell: + _target_: stoix.networks.memoroid.MemoroidResetWrapper + cell: + _target_: stoix.networks.memoroid.FFMCell + trace_size: 128 + context_size: 128 + output_size: 128 + post_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [128] + use_layer_norm: False + activation: silu + critic_head: + _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/networks/ffm_edan.py b/stoix/networks/memoroid.py similarity index 85% rename from stoix/networks/ffm_edan.py rename to stoix/networks/memoroid.py index 00232da4..393a370b 100644 --- a/stoix/networks/ffm_edan.py +++ b/stoix/networks/memoroid.py @@ -6,7 +6,15 @@ import jax import jax.numpy as jnp +# Typing aliases Carry = chex.ArrayTree +Timestep = chex.Array +MemoroidRecurrentState = Tuple[Tuple[chex.Array, chex.Array], chex.Array] +Inputs = Tuple[chex.Array, chex.Array] +Outputs = Tuple[chex.Array, chex.Array] +RecurrentState = Tuple[chex.Array, chex.Array] +CarryState = Tuple[RecurrentState, chex.Array] +ScanInputs = Tuple[RecurrentState, Inputs] def debug_shape(x): @@ -47,8 +55,8 @@ def num_feature_axes(self) -> int: def recurrent_associative_scan( cell: nn.Module, - state: jax.Array, - inputs: jax.Array, + state: chex.Array, + inputs: chex.Array, axis: int = 0, ) -> jax.Array: """Execute the associative scan to update the recurrent state. @@ -132,15 +140,15 @@ def map_to_h(self, inputs): pre = self.pre(x) gated_x = pre * gate_in # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous - ts = jnp.ones(x.shape[0], dtype=jnp.int32) - z = jnp.repeat(jnp.expand_dims(gated_x, 2), self.context_size, axis=2) + ts = jnp.ones(x.shape[0:2], dtype=jnp.int32) + z = jnp.repeat(jnp.expand_dims(gated_x, 3), self.context_size, axis=3) return (z, ts), resets def map_from_h(self, recurrent_state, x): """Map from the recurrent space to the Markov space""" (state, ts), reset = recurrent_state z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape( - state.shape[0], -1 + state.shape[0], state.shape[1], -1 ) z = self.mix(z_in) gate_out = self.gate_out(x) @@ -150,10 +158,10 @@ def map_from_h(self, recurrent_state, x): def log_gamma(self, t: jax.Array) -> jax.Array: a, b = self.params - a = -jnp.abs(a).reshape((1, self.trace_size, 1)) - b = b.reshape(1, 1, self.context_size) + a = -jnp.abs(a).reshape((1, 1, self.trace_size, 1)) + b = b.reshape(1, 1, 1, self.context_size) ab = jax.lax.complex(a, b) - return ab * t.reshape(t.shape[0], 1, 1) + return ab * t.reshape(t.shape[0], t.shape[1], 1, 1) def gamma(self, t: jax.Array) -> jax.Array: return jnp.exp(self.log_gamma(t)) @@ -192,14 +200,14 @@ def __call__(self, carry, incoming, rng=None): xs, start = incoming def reset_state(start, current_state, initial_state): - # Expand to reset all dims of state: [B, 1, 1, ...] + # Expand to reset all dims of state: [1, B, 1, ...] assert initial_state.ndim == current_state.ndim - expanded_start = start.reshape(-1, *([1] * (current_state.ndim - 1))) + expanded_start = start.reshape(-1, start.shape[1], *([1] * (current_state.ndim - 2))) out = current_state * jnp.logical_not(expanded_start) + initial_state return out # Add an extra dim, as start will be [Batch] while intialize carry expects [Batch, Feature] - initial_states = self.cell.initialize_carry(rng) + initial_states = self.cell.initialize_carry(rng=rng, batch_size=start.shape[1]) states = jax.tree.map(partial(reset_state, start), states, initial_states) out = self.cell(states, xs) start_carry = jnp.logical_or(start, prev_start) @@ -229,6 +237,9 @@ class ScannedMemoroid(nn.Module): @nn.compact def __call__(self, recurrent_state, inputs): + """Apply the ScannedMemoroid. + This takes in a sequence of batched states and inputs. + The recurrent state that is used requires no sequence dimension but does require a batch dimension.""" # Recurrent state should be ((state, timestep), reset) # Inputs should be (x, reset) @@ -253,37 +264,24 @@ def __call__(self, recurrent_state, inputs): def initialize_carry( self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None ) -> Carry: + """Initialize the carry for the ScannedMemoroid. This returns the carry in the shape [Batch, ...] i.e. it contains no sequence dimension""" # We squeeze the sequence dim of 1 out. return jax.tree.map(lambda x: x.squeeze(0), self.cell.initialize_carry(batch_size, rng)) if __name__ == "__main__": - m = ScannedMemoroid( - cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)) - ) - y = jnp.ones((10, 2)) - s = m.initialize_carry(None) - start = jnp.zeros(10, dtype=bool) - params = m.init(jax.random.PRNGKey(0), s, (y, start)) - out_state, out = m.apply(params, s, (y, start)) - - print(out) - - BatchFFM = nn.vmap( - ScannedMemoroid, - in_axes=(((0, 0), 0), 1), - out_axes=(((0, 0), 0), 1), - variable_axes={"params": None}, - split_rngs={"params": False}, - ) + BatchFFM = ScannedMemoroid m = BatchFFM( cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)) ) - y = jnp.ones((10, 8, 2)) - s = m.initialize_carry(8) - start = jnp.zeros((10, 8), dtype=bool) + batch_size = 8 + time_steps = 10 + + y = jnp.ones((time_steps, batch_size, 2)) + s = m.initialize_carry(batch_size) + start = jnp.zeros((time_steps, batch_size), dtype=bool) params = m.init(jax.random.PRNGKey(0), s, (y, start)) out_state, out = m.apply(params, s, (y, start)) diff --git a/stoix/systems/ppo/rec_ppo_temp_ffm.py b/stoix/systems/ppo/rec_ppo_temp_ffm.py deleted file mode 100644 index 9f6c692f..00000000 --- a/stoix/systems/ppo/rec_ppo_temp_ffm.py +++ /dev/null @@ -1,771 +0,0 @@ -import copy -import time -from typing import Any, Dict, Tuple - -import chex -import flax -import hydra -import jax -import jax.numpy as jnp -import optax -from colorama import Fore, Style -from flax.core.frozen_dict import FrozenDict -from jumanji.env import Environment -from omegaconf import DictConfig, OmegaConf -from rich.pretty import pprint - -from stoix.base_types import ( - ActorCriticOptStates, - ActorCriticParams, - ExperimentOutput, - LearnerFn, - RecActorApply, - RecCriticApply, - RNNLearnerState, -) -from stoix.evaluator import evaluator_setup, get_rec_distribution_act_fn -from stoix.networks.base import RecurrentActor, RecurrentCritic -from stoix.networks.ffm_edan import FFMCell, MemoroidResetWrapper, ScannedMemoroid -from stoix.systems.ppo.ppo_types import ActorCriticHiddenStates, RNNPPOTransition -from stoix.utils import make_env as environments -from stoix.utils.checkpointing import Checkpointer -from stoix.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims -from stoix.utils.logger import LogEvent, StoixLogger -from stoix.utils.loss import clipped_value_loss, ppo_clip_loss -from stoix.utils.multistep import batch_truncated_generalized_advantage_estimation -from stoix.utils.total_timestep_checker import check_total_timesteps -from stoix.utils.training import make_learning_rate -from stoix.wrappers.episode_metrics import get_final_step_metrics - - -def get_learner_fn( - env: Environment, - apply_fns: Tuple[RecActorApply, RecCriticApply], - update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], - config: DictConfig, -) -> LearnerFn[RNNLearnerState]: - """Get the learner function.""" - - actor_apply_fn, critic_apply_fn = apply_fns - actor_update_fn, critic_update_fn = update_fns - - def _update_step(learner_state: RNNLearnerState, _: Any) -> Tuple[RNNLearnerState, Tuple]: - """A single update of the network. - - This function steps the environment and records the trajectory batch for - training. It then calculates advantages and targets based on the recorded - trajectory and updates the actor and critic networks based on the calculated - losses. - - Args: - learner_state (NamedTuple): - - params (ActorCriticParams): The current model parameters. - - opt_states (OptStates): The current optimizer states. - - key (PRNGKey): The random number generator state. - - env_state (State): The environment state. - - last_timestep (TimeStep): The last timestep in the current trajectory. - - dones (bool): Whether the last timestep was a terminal state. - - hstates (ActorCriticHiddenStates): The current hidden states of the RNN. - _ (Any): The current metrics info. - """ - - def _env_step( - learner_state: RNNLearnerState, _: Any - ) -> Tuple[RNNLearnerState, RNNPPOTransition]: - """Step the environment.""" - ( - params, - opt_states, - key, - env_state, - last_timestep, - last_done, - last_truncated, - hstates, - ) = learner_state - - key, policy_key = jax.random.split(key) - - # Add a batch dimension to the observation. - batched_observation = jax.tree_util.tree_map( - lambda x: x[jnp.newaxis, :], last_timestep.observation - ) - ac_in = ( - batched_observation, - last_done[jnp.newaxis, :], - ) - - # Run the network. - policy_hidden_state, actor_policy = actor_apply_fn( - params.actor_params, hstates.policy_hidden_state, ac_in - ) - critic_hidden_state, value = critic_apply_fn( - params.critic_params, hstates.critic_hidden_state, ac_in - ) - - # Sample action from the policy and squeeze out the batch dimension. - action = actor_policy.sample(seed=policy_key) - log_prob = actor_policy.log_prob(action) - value, action, log_prob = ( - value.squeeze(0), - action.squeeze(0), - log_prob.squeeze(0), - ) - - # Step the environment. - env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - - # log episode return and length - done = (timestep.discount == 0.0).reshape(-1) - truncated = (timestep.last() & (timestep.discount != 0.0)).reshape(-1) - info = timestep.extras["episode_metrics"] - - hstates = ActorCriticHiddenStates(policy_hidden_state, critic_hidden_state) - transition = RNNPPOTransition( - last_done, - last_truncated, - action, - value, - timestep.reward, - log_prob, - last_timestep.observation, - hstates, - info, - ) - learner_state = RNNLearnerState( - params, - opt_states, - key, - env_state, - timestep, - done, - truncated, - hstates, - ) - return learner_state, transition - - # INITIALISE RNN STATE - initial_hstates = learner_state.hstates - - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( - _env_step, learner_state, None, config.system.rollout_length - ) - - # CALCULATE ADVANTAGE - ( - params, - opt_states, - key, - env_state, - last_timestep, - last_done, - last_truncated, - hstates, - ) = learner_state - - # Add a batch dimension to the observation. - batched_last_observation = jax.tree_util.tree_map( - lambda x: x[jnp.newaxis, :], last_timestep.observation - ) - ac_in = ( - batched_last_observation, - last_done[jnp.newaxis, :], - ) - - # Run the network. - _, last_val = critic_apply_fn(params.critic_params, hstates.critic_hidden_state, ac_in) - # Squeeze out the batch dimension and mask out the value of terminal states. - last_val = last_val.squeeze(0) - last_val = jnp.where(last_done, jnp.zeros_like(last_val), last_val) - - r_t = traj_batch.reward - v_t = jnp.concatenate([traj_batch.value, last_val[None, ...]], axis=0) - d_t = 1.0 - traj_batch.done.astype(jnp.float32) - d_t = (d_t * config.system.gamma).astype(jnp.float32) - advantages, targets = batch_truncated_generalized_advantage_estimation( - r_t, - d_t, - config.system.gae_lambda, - v_t, - time_major=True, - standardize_advantages=config.system.standardize_advantages, - truncation_flags=traj_batch.truncated, - ) - - def _update_epoch(update_state: Tuple, _: Any) -> Tuple: - """Update the network for a single epoch.""" - - def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: - """Update the network for a single minibatch.""" - - params, opt_states = train_state - ( - traj_batch, - advantages, - targets, - ) = batch_info - - def _actor_loss_fn( - actor_params: FrozenDict, - traj_batch: RNNPPOTransition, - gae: chex.Array, - ) -> Tuple: - """Calculate the actor loss.""" - # RERUN NETWORK - - obs_and_done = (traj_batch.obs, traj_batch.done) - policy_hidden_state = jax.tree_util.tree_map( - lambda x: x[0], traj_batch.hstates.policy_hidden_state - ) - _, actor_policy = actor_apply_fn( - actor_params, policy_hidden_state, obs_and_done - ) - log_prob = actor_policy.log_prob(traj_batch.action) - - loss_actor = ppo_clip_loss( - log_prob, traj_batch.log_prob, gae, config.system.clip_eps - ) - entropy = actor_policy.entropy().mean() - - total_loss = loss_actor - config.system.ent_coef * entropy - loss_info = { - "actor_loss": loss_actor, - "entropy": entropy, - } - return total_loss, loss_info - - def _critic_loss_fn( - critic_params: FrozenDict, - traj_batch: RNNPPOTransition, - targets: chex.Array, - ) -> Tuple: - """Calculate the critic loss.""" - # RERUN NETWORK - obs_and_done = (traj_batch.obs, traj_batch.done) - critic_hidden_state = jax.tree_util.tree_map( - lambda x: x[0], traj_batch.hstates.critic_hidden_state - ) - _, value = critic_apply_fn(critic_params, critic_hidden_state, obs_and_done) - - # CALCULATE VALUE LOSS - value_loss = clipped_value_loss( - value, traj_batch.value, targets, config.system.clip_eps - ) - - total_loss = config.system.vf_coef * value_loss - loss_info = { - "value_loss": value_loss, - } - return total_loss, loss_info - - # CALCULATE ACTOR LOSS - actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True) - actor_grads, actor_loss_info = actor_grad_fn( - params.actor_params, traj_batch, advantages - ) - - # CALCULATE CRITIC LOSS - critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True) - critic_grads, critic_loss_info = critic_grad_fn( - params.critic_params, traj_batch, targets - ) - - # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x - # This pmean could be a regular mean as the batch axis is on the same device. - actor_grads, actor_loss_info = jax.lax.pmean( - (actor_grads, actor_loss_info), axis_name="batch" - ) - # pmean over devices. - actor_grads, actor_loss_info = jax.lax.pmean( - (actor_grads, actor_loss_info), axis_name="device" - ) - - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="batch" - ) - # pmean over devices. - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="device" - ) - - # UPDATE ACTOR PARAMS AND OPTIMISER STATE - actor_updates, actor_new_opt_state = actor_update_fn( - actor_grads, opt_states.actor_opt_state - ) - actor_new_params = optax.apply_updates(params.actor_params, actor_updates) - - # UPDATE CRITIC PARAMS AND OPTIMISER STATE - critic_updates, critic_new_opt_state = critic_update_fn( - critic_grads, opt_states.critic_opt_state - ) - critic_new_params = optax.apply_updates(params.critic_params, critic_updates) - - new_params = ActorCriticParams(actor_new_params, critic_new_params) - new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state) - - # PACK LOSS INFO - loss_info = { - **actor_loss_info, - **critic_loss_info, - } - - return (new_params, new_opt_state), loss_info - - ( - params, - opt_states, - init_hstates, - traj_batch, - advantages, - targets, - key, - ) = update_state - key, shuffle_key = jax.random.split(key) - - # SHUFFLE MINIBATCHES - batch = (traj_batch, advantages, targets) - num_recurrent_chunks = ( - config.system.rollout_length // config.system.recurrent_chunk_size - ) - batch = jax.tree_util.tree_map( - lambda x: x.reshape( - config.system.recurrent_chunk_size, - config.arch.num_envs * num_recurrent_chunks, - *x.shape[2:], - ), - batch, - ) - permutation = jax.random.permutation( - shuffle_key, config.arch.num_envs * num_recurrent_chunks - ) - shuffled_batch = jax.tree_util.tree_map( - lambda x: jnp.take(x, permutation, axis=1), batch - ) - reshaped_batch = jax.tree_util.tree_map( - lambda x: jnp.reshape( - x, (x.shape[0], config.system.num_minibatches, -1, *x.shape[2:]) - ), - shuffled_batch, - ) - minibatches = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 0), reshaped_batch) - - # UPDATE MINIBATCHES - (params, opt_states), loss_info = jax.lax.scan( - _update_minibatch, (params, opt_states), minibatches - ) - - update_state = ( - params, - opt_states, - init_hstates, - traj_batch, - advantages, - targets, - key, - ) - return update_state, loss_info - - init_hstates = jax.tree_util.tree_map(lambda x: x[None, :], initial_hstates) - update_state = ( - params, - opt_states, - init_hstates, - traj_batch, - advantages, - targets, - key, - ) - - # UPDATE EPOCHS - update_state, loss_info = jax.lax.scan( - _update_epoch, update_state, None, config.system.epochs - ) - - params, opt_states, _, traj_batch, advantages, targets, key = update_state - learner_state = RNNLearnerState( - params, - opt_states, - key, - env_state, - last_timestep, - last_done, - last_truncated, - hstates, - ) - metric = traj_batch.info - return learner_state, (metric, loss_info) - - def learner_fn(learner_state: RNNLearnerState) -> ExperimentOutput[RNNLearnerState]: - """Learner function. - - This function represents the learner, it updates the network parameters - by iteratively applying the `_update_step` function for a fixed number of - updates. The `_update_step` function is vectorized over a batch of inputs. - - Args: - learner_state (NamedTuple): - - params (ActorCriticParams): The initial model parameters. - - opt_states (OptStates): The initial optimizer states. - - key (chex.PRNGKey): The random number generator state. - - env_state (LogEnvState): The environment state. - - timesteps (TimeStep): The initial timestep in the initial trajectory. - - dones (bool): Whether the initial timestep was a terminal state. - - hstateS (ActorCriticHiddenStates): The initial hidden states of the RNN. - """ - - batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") - - learner_state, (episode_info, loss_info) = jax.lax.scan( - batched_update_step, learner_state, None, config.arch.num_updates_per_eval - ) - return ExperimentOutput( - learner_state=learner_state, - episode_metrics=episode_info, - train_metrics=loss_info, - ) - - return learner_fn - - -def learner_setup( - env: Environment, keys: chex.Array, config: DictConfig -) -> Tuple[LearnerFn[RNNLearnerState], RecurrentActor, Any, RNNLearnerState]: - """Initialise learner_fn, network, optimiser, environment and states.""" - # Get available TPU cores. - n_devices = len(jax.devices()) - - # Get number/dimension of actions. - num_actions = int(env.action_spec().num_values) - config.system.action_dim = num_actions - - # PRNG keys. - key, actor_net_key, critic_net_key = keys - - # FFM CHANGES HERE - output_size = 64 - trace_size = 64 - context_size = 64 - BatchFFM = flax.linen.vmap( - ScannedMemoroid, - in_axes=(((0, 0), 0), 1), - out_axes=(((0, 0), 0), 1), - variable_axes={"params": None}, - split_rngs={"params": False}, - ) - - # Define network and optimisers. - actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) - actor_rnn = BatchFFM( - cell=MemoroidResetWrapper( - cell=FFMCell(output_size=output_size, trace_size=trace_size, context_size=context_size) - ) - ) - actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso) - actor_action_head = hydra.utils.instantiate( - config.network.actor_network.action_head, action_dim=num_actions - ) - critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) - critic_rnn = BatchFFM( - cell=MemoroidResetWrapper( - cell=FFMCell(output_size=output_size, trace_size=trace_size, context_size=context_size) - ) - ) - critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso) - critic_head = hydra.utils.instantiate(config.network.critic_network.critic_head) - - actor_network = RecurrentActor( - pre_torso=actor_pre_torso, - rnn=actor_rnn, - post_torso=actor_post_torso, - action_head=actor_action_head, - ) - critic_network = RecurrentCritic( - pre_torso=critic_pre_torso, - rnn=critic_rnn, - post_torso=critic_post_torso, - critic_head=critic_head, - ) - - actor_lr = make_learning_rate( - config.system.actor_lr, config, config.system.epochs, config.system.num_minibatches - ) - critic_lr = make_learning_rate( - config.system.critic_lr, config, config.system.epochs, config.system.num_minibatches - ) - - actor_optim = optax.chain( - optax.clip_by_global_norm(config.system.max_grad_norm), - optax.adam(actor_lr, eps=1e-5), - ) - critic_optim = optax.chain( - optax.clip_by_global_norm(config.system.max_grad_norm), - optax.adam(critic_lr, eps=1e-5), - ) - - # Initialise observation - init_obs = env.observation_spec().generate_value() - init_obs = jax.tree_util.tree_map( - lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0), - init_obs, - ) - init_obs = jax.tree_util.tree_map(lambda x: x[None, ...], init_obs) - init_done = jnp.zeros((1, config.arch.num_envs), dtype=bool) - init_x = (init_obs, init_done) - - # Initialise hidden states. - init_policy_hstate = actor_rnn.initialize_carry(config.arch.num_envs) - init_critic_hstate = critic_rnn.initialize_carry(config.arch.num_envs) - - # initialise params and optimiser state. - actor_params = actor_network.init(actor_net_key, init_policy_hstate, init_x) - actor_opt_state = actor_optim.init(actor_params) - critic_params = critic_network.init(critic_net_key, init_critic_hstate, init_x) - critic_opt_state = critic_optim.init(critic_params) - - actor_network_apply_fn = actor_network.apply - critic_network_apply_fn = critic_network.apply - - # Get network apply functions and optimiser updates. - apply_fns = (actor_network_apply_fn, critic_network_apply_fn) - update_fns = (actor_optim.update, critic_optim.update) - - # Get batched iterated update and replicate it to pmap it over cores. - learn = get_learner_fn(env, apply_fns, update_fns, config) - learn = jax.pmap(learn, axis_name="device") - - # Pack params and initial states. - params = ActorCriticParams(actor_params, critic_params) - hstates = ActorCriticHiddenStates(init_policy_hstate, init_critic_hstate) - - # Load model from checkpoint if specified. - if config.logger.checkpointing.load_model: - loaded_checkpoint = Checkpointer( - model_name=config.system.system_name, - **config.logger.checkpointing.load_args, # Other checkpoint args - ) - # Restore the learner state from the checkpoint - restored_params, restored_hstates = loaded_checkpoint.restore_params(restore_hstates=True) - # Update the params and hstates - params = restored_params - hstates = restored_hstates if restored_hstates else hstates - - # Initialise environment states and timesteps: across devices and batches. - key, *env_keys = jax.random.split( - key, n_devices * config.arch.update_batch_size * config.arch.num_envs + 1 - ) - env_states, timesteps = jax.vmap(env.reset, in_axes=(0))( - jnp.stack(env_keys), - ) - reshape_states = lambda x: x.reshape( - (n_devices, config.arch.update_batch_size, config.arch.num_envs) + x.shape[1:] - ) - # (devices, update batch size, num_envs, ...) - env_states = jax.tree_util.tree_map(reshape_states, env_states) - timesteps = jax.tree_util.tree_map(reshape_states, timesteps) - - # Define params to be replicated across devices and batches. - dones = jnp.zeros( - (config.arch.num_envs,), - dtype=bool, - ) - truncated = jnp.zeros( - (config.arch.num_envs,), - dtype=bool, - ) - key, step_key = jax.random.split(key) - step_keys = jax.random.split(step_key, n_devices * config.arch.update_batch_size) - reshape_keys = lambda x: x.reshape((n_devices, config.arch.update_batch_size) + x.shape[1:]) - step_keys = reshape_keys(jnp.stack(step_keys)) - opt_states = ActorCriticOptStates(actor_opt_state, critic_opt_state) - replicate_learner = (params, opt_states, hstates, dones, truncated) - - # Duplicate learner for update_batch_size. - broadcast = lambda x: jnp.broadcast_to(x, (config.arch.update_batch_size,) + x.shape) - replicate_learner = jax.tree_util.tree_map(broadcast, replicate_learner) - - # Duplicate learner across devices. - replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices()) - - # Initialise learner state. - params, opt_states, hstates, dones, truncated = replicate_learner - init_learner_state = RNNLearnerState( - params=params, - opt_states=opt_states, - key=step_keys, - env_state=env_states, - timestep=timesteps, - done=dones, - truncated=truncated, - hstates=hstates, - ) - return learn, actor_network, actor_rnn, init_learner_state - - -def run_experiment(_config: DictConfig) -> float: - """Runs experiment.""" - config = copy.deepcopy(_config) - - # Calculate total timesteps. - n_devices = len(jax.devices()) - config.num_devices = n_devices - config = check_total_timesteps(config) - assert ( - config.arch.num_updates > config.arch.num_evaluation - ), "Number of updates per evaluation must be less than total number of updates." - - # Set recurrent chunk size. - if config.system.recurrent_chunk_size is None: - config.system.recurrent_chunk_size = config.system.rollout_length - else: - assert ( - config.system.rollout_length % config.system.recurrent_chunk_size == 0 - ), "Rollout length must be divisible by recurrent chunk size." - - # Create the environments for train and eval. - env, eval_env = environments.make(config) - - # PRNG keys. - key, key_e, actor_net_key, critic_net_key = jax.random.split( - jax.random.PRNGKey(config.arch.seed), num=4 - ) - - # Setup learner. - learn, actor_network, actor_rnn, learner_state = learner_setup( - env, (key, actor_net_key, critic_net_key), config - ) - - # Setup evaluator. - evaluator, absolute_metric_evaluator, (trained_params, eval_keys) = evaluator_setup( - eval_env=eval_env, - key_e=key_e, - eval_act_fn=get_rec_distribution_act_fn(config, actor_network.apply), - params=learner_state.params.actor_params, - config=config, - use_recurrent_net=True, - scanned_rnn=actor_rnn, - ) - - # Calculate number of updates per evaluation. - config.arch.num_updates_per_eval = config.arch.num_updates // config.arch.num_evaluation - steps_per_rollout = ( - n_devices - * config.arch.num_updates_per_eval - * config.system.rollout_length - * config.arch.update_batch_size - * config.arch.num_envs - ) - - # Logger setup - logger = StoixLogger(config) - cfg: Dict = OmegaConf.to_container(config, resolve=True) - cfg["arch"]["devices"] = jax.devices() - pprint(cfg) - - # Set up checkpointer - save_checkpoint = config.logger.checkpointing.save_model - if save_checkpoint: - checkpointer = Checkpointer( - metadata=config, # Save all config as metadata in the checkpoint - model_name=config.system.system_name, - **config.logger.checkpointing.save_args, # Checkpoint args - ) - - # Run experiment for a total number of evaluations. - max_episode_return = jnp.float32(-1e7) - best_params = None - for eval_step in range(config.arch.num_evaluation): - # Train. - start_time = time.time() - learner_output = learn(learner_state) - jax.block_until_ready(learner_output) - - # Log the results of the training. - elapsed_time = time.time() - start_time - t = int(steps_per_rollout * (eval_step + 1)) - episode_metrics, ep_completed = get_final_step_metrics(learner_output.episode_metrics) - episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time - - # Separately log timesteps, actoring metrics and training metrics. - logger.log({"timestep": t}, t, eval_step, LogEvent.MISC) - if ep_completed: # only log episode metrics if an episode was completed in the rollout. - logger.log(episode_metrics, t, eval_step, LogEvent.ACT) - logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) - - # Prepare for evaluation. - start_time = time.time() - trained_params = unreplicate_batch_dim(learner_output.learner_state.params.actor_params) - key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) - eval_keys = jnp.stack(eval_keys) - eval_keys = eval_keys.reshape(n_devices, -1) - - # Evaluate. - evaluator_output = evaluator(trained_params, eval_keys) - jax.block_until_ready(evaluator_output) - - # Log the results of the evaluation. - elapsed_time = time.time() - start_time - episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"]) - - steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) - evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL) - - if save_checkpoint: - # Save checkpoint of learner state - checkpointer.save( - timestep=int(steps_per_rollout * (eval_step + 1)), - unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), - episode_return=episode_return, - ) - - if config.arch.absolute_metric and max_episode_return <= episode_return: - best_params = copy.deepcopy(trained_params) - max_episode_return = episode_return - - # Update runner state to continue training. - learner_state = learner_output.learner_state - - # Measure absolute metric. - if config.arch.absolute_metric: - start_time = time.time() - - key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) - eval_keys = jnp.stack(eval_keys) - eval_keys = eval_keys.reshape(n_devices, -1) - - evaluator_output = absolute_metric_evaluator(best_params, eval_keys) - jax.block_until_ready(evaluator_output) - - elapsed_time = time.time() - start_time - - t = int(steps_per_rollout * (eval_step + 1)) - steps_per_eval = int(jnp.sum(evaluator_output.episode_metrics["episode_length"])) - evaluator_output.episode_metrics["steps_per_second"] = steps_per_eval / elapsed_time - logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE) - - # Stop the logger. - logger.stop() - # Record the performance for the final evaluation run. If the absolute metric is not - # calculated, this will be the final evaluation run. - eval_performance = float(jnp.mean(evaluator_output.episode_metrics[config.env.eval_metric])) - return eval_performance - - -@hydra.main(config_path="../../configs", config_name="default_rec_ppo.yaml", version_base="1.2") -def hydra_entry_point(cfg: DictConfig) -> float: - """Experiment entry point.""" - # Allow dynamic attributes. - OmegaConf.set_struct(cfg, False) - - # Run experiment. - eval_performance = run_experiment(cfg) - - print(f"{Fore.CYAN}{Style.BRIGHT}Recurrent PPO experiment completed{Style.RESET_ALL}") - return eval_performance - - -if __name__ == "__main__": - hydra_entry_point() From 2446b9c4c0c9ff020a818882dfbf8d9ccfc51f18 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Mon, 24 Jun 2024 17:26:55 +0000 Subject: [PATCH 16/38] chore: remove reliance on start variable being inside recurrent state --- stoix/configs/arch/anakin.yaml | 2 +- stoix/configs/network/memoroid.yaml | 8 +-- stoix/networks/memoroid.py | 89 +++++++++++++++-------------- 3 files changed, 52 insertions(+), 47 deletions(-) diff --git a/stoix/configs/arch/anakin.yaml b/stoix/configs/arch/anakin.yaml index a5025dee..f6092512 100644 --- a/stoix/configs/arch/anakin.yaml +++ b/stoix/configs/arch/anakin.yaml @@ -3,7 +3,7 @@ # --- Training --- seed: 42 # RNG seed. update_batch_size: 1 # Number of vectorised gradient updates per device. -total_num_envs: 512 # Total Number of vectorised environments across all devices and batched_updates. Needs to be divisible by n_devices*update_batch_size. +total_num_envs: 1024 # Total Number of vectorised environments across all devices and batched_updates. Needs to be divisible by n_devices*update_batch_size. total_timesteps: 1e7 # Set the total environment steps. # If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value. num_updates: ~ # Number of updates diff --git a/stoix/configs/network/memoroid.yaml b/stoix/configs/network/memoroid.yaml index 8afa9227..0a45a3d0 100644 --- a/stoix/configs/network/memoroid.yaml +++ b/stoix/configs/network/memoroid.yaml @@ -12,8 +12,8 @@ actor_network: _target_: stoix.networks.memoroid.MemoroidResetWrapper cell: _target_: stoix.networks.memoroid.FFMCell - trace_size: 128 - context_size: 128 + trace_size: 64 + context_size: 4 output_size: 128 post_torso: _target_: stoix.networks.torso.MLPTorso @@ -35,8 +35,8 @@ critic_network: _target_: stoix.networks.memoroid.MemoroidResetWrapper cell: _target_: stoix.networks.memoroid.FFMCell - trace_size: 128 - context_size: 128 + trace_size: 64 + context_size: 4 output_size: 128 post_torso: _target_: stoix.networks.torso.MLPTorso diff --git a/stoix/networks/memoroid.py b/stoix/networks/memoroid.py index 393a370b..466042c0 100644 --- a/stoix/networks/memoroid.py +++ b/stoix/networks/memoroid.py @@ -8,13 +8,15 @@ # Typing aliases Carry = chex.ArrayTree + +HiddenState = chex.Array Timestep = chex.Array -MemoroidRecurrentState = Tuple[Tuple[chex.Array, chex.Array], chex.Array] -Inputs = Tuple[chex.Array, chex.Array] -Outputs = Tuple[chex.Array, chex.Array] -RecurrentState = Tuple[chex.Array, chex.Array] -CarryState = Tuple[RecurrentState, chex.Array] -ScanInputs = Tuple[RecurrentState, Inputs] +Reset = chex.Array + +RecurrentState = Tuple[HiddenState, Timestep] + +InputEmbedding = chex.Array +Inputs = Tuple[InputEmbedding, Reset] def debug_shape(x): @@ -24,18 +26,18 @@ def debug_shape(x): class MemoroidCellBase(nn.Module): """Memoroid cell base class.""" - def map_to_h(self, inputs): + def map_to_h(self, inputs: Inputs) -> RecurrentState: """Map from the input space to the recurrent state space""" raise NotImplementedError - def map_from_h(self, recurrent_state, x): + def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> HiddenState: """Map from the recurrent space to the Markov space""" raise NotImplementedError @nn.nowrap def initialize_carry( self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None - ) -> Carry: + ) -> RecurrentState: """Initialize the Memoroid cell carry. Args: @@ -55,10 +57,10 @@ def num_feature_axes(self) -> int: def recurrent_associative_scan( cell: nn.Module, - state: chex.Array, - inputs: chex.Array, + state: RecurrentState, + inputs: RecurrentState, axis: int = 0, -) -> jax.Array: +) -> RecurrentState: """Execute the associative scan to update the recurrent state. Note that we do a trick here by concatenating the previous state to the inputs. @@ -68,14 +70,21 @@ def recurrent_associative_scan( # Concatenate the previous state to the inputs and scan over the result # This ensures the previous recurrent state contributes to the current batch - # state: [start, (x, j)] - # inputs: [start, (x, j)] + + # We need to add a dummy start signal to the inputs + dummy_start = jnp.zeros(inputs[-1].shape[1:], dtype=bool)[jnp.newaxis, ...] + # Add it to the state i.e. (state, timestep) -> ((state, time), reset) + state = (state, dummy_start) scan_inputs = jax.tree.map(lambda s, x: jnp.concatenate([s, x], axis=axis), state, inputs) new_state = jax.lax.associative_scan( cell, scan_inputs, axis=axis, ) + + # Get rid of the reset signal i.e. ((state, time), reset) -> (state, time) + new_state, _ = new_state + # The zeroth index corresponds to the previous recurrent state # We just use it to ensure continuity # We do not actually want to use these values, so slice them away @@ -98,7 +107,7 @@ def __call__(self, x): def init_deterministic( memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 -) -> Tuple[jax.Array, jax.Array]: +) -> Tuple[chex.Array, chex.Array]: """Deterministic initialization of the FFM parameters.""" a_low = 1e-6 a_high = 0.5 @@ -131,22 +140,21 @@ def setup(self): self.mix = nn.Dense(self.output_size) self.ln = nn.LayerNorm(use_scale=False, use_bias=False) - def map_to_h(self, inputs): + def map_to_h(self, x: InputEmbedding) -> RecurrentState: """Map from the input space to the recurrent state space - unlike the call function this explicitly expects a shape including the sequence dimension. This is used in the outer network that uses the associative scan.""" - x, resets = inputs gate_in = self.gate_in(x) pre = self.pre(x) gated_x = pre * gate_in # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous ts = jnp.ones(x.shape[0:2], dtype=jnp.int32) z = jnp.repeat(jnp.expand_dims(gated_x, 3), self.context_size, axis=3) - return (z, ts), resets + return (z, ts) - def map_from_h(self, recurrent_state, x): + def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> HiddenState: """Map from the recurrent space to the Markov space""" - (state, ts), reset = recurrent_state + state, _ = recurrent_state z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape( state.shape[0], state.shape[1], -1 ) @@ -156,20 +164,20 @@ def map_from_h(self, recurrent_state, x): out = self.ln(z * gate_out) + skip * (1 - gate_out) return out - def log_gamma(self, t: jax.Array) -> jax.Array: + def log_gamma(self, t: chex.Array) -> chex.Array: a, b = self.params a = -jnp.abs(a).reshape((1, 1, self.trace_size, 1)) b = b.reshape(1, 1, 1, self.context_size) ab = jax.lax.complex(a, b) return ab * t.reshape(t.shape[0], t.shape[1], 1, 1) - def gamma(self, t: jax.Array) -> jax.Array: + def gamma(self, t: chex.Array) -> chex.Array: return jnp.exp(self.log_gamma(t)) @nn.nowrap def initialize_carry( self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None - ) -> Carry: + ) -> RecurrentState: # inputs should be of shape [*batch, time, feature] # recurrent states should be of shape [*batch, 1, feature] carry_shape = (1, self.trace_size, self.context_size) @@ -179,7 +187,7 @@ def initialize_carry( t_shape = (*t_shape, batch_size) return jnp.zeros(carry_shape, dtype=jnp.complex64), jnp.ones(t_shape, dtype=jnp.int32) - def __call__(self, carry, incoming): + def __call__(self, carry: RecurrentState, incoming): ( state, i, @@ -199,7 +207,7 @@ def __call__(self, carry, incoming, rng=None): states, prev_start = carry xs, start = incoming - def reset_state(start, current_state, initial_state): + def reset_state(start: Reset, current_state, initial_state): # Expand to reset all dims of state: [1, B, 1, ...] assert initial_state.ndim == current_state.ndim expanded_start = start.reshape(-1, start.shape[1], *([1] * (current_state.ndim - 2))) @@ -214,42 +222,39 @@ def reset_state(start, current_state, initial_state): return out, start_carry - def map_to_h(self, inputs): - return self.cell.map_to_h(inputs) + def map_to_h(self, x: InputEmbedding) -> RecurrentState: + return self.cell.map_to_h(x) - def map_from_h(self, recurrent_state, x): + def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> HiddenState: return self.cell.map_from_h(recurrent_state, x) @nn.nowrap def initialize_carry( self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None - ) -> Carry: - # inputs should be of shape [*batch, time, feature] - # recurrent states should be of shape [*batch, 1, feature] - start_shape = (1,) - if batch_size is not None: - start_shape = (*start_shape, batch_size) - return self.cell.initialize_carry(batch_size, rng), jnp.zeros(start_shape, dtype=bool) + ) -> RecurrentState: + return self.cell.initialize_carry(batch_size, rng) class ScannedMemoroid(nn.Module): cell: nn.Module @nn.compact - def __call__(self, recurrent_state, inputs): + def __call__( + self, recurrent_state: RecurrentState, inputs: Inputs + ) -> Tuple[RecurrentState, HiddenState]: """Apply the ScannedMemoroid. This takes in a sequence of batched states and inputs. The recurrent state that is used requires no sequence dimension but does require a batch dimension.""" - # Recurrent state should be ((state, timestep), reset) + # Recurrent state should be (state, timestep) # Inputs should be (x, reset) # Unsqueeze the recurrent state to add the sequence dimension of size 1 recurrent_state = jax.tree.map(lambda x: jnp.expand_dims(x, 0), recurrent_state) - x, _ = inputs - h = self.cell.map_to_h(inputs) - recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h) - # recurrent_state is ((state, timestep), reset) + x, resets = inputs + h = self.cell.map_to_h(x) + recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, (h, resets)) + # recurrent_state is (state, timestep) out = self.cell.map_from_h(recurrent_state, x) # TODO: Remove this when we want to return all recurrent states instead of just the last one @@ -263,7 +268,7 @@ def __call__(self, recurrent_state, inputs): @nn.nowrap def initialize_carry( self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None - ) -> Carry: + ) -> RecurrentState: """Initialize the carry for the ScannedMemoroid. This returns the carry in the shape [Batch, ...] i.e. it contains no sequence dimension""" # We squeeze the sequence dim of 1 out. return jax.tree.map(lambda x: x.squeeze(0), self.cell.initialize_carry(batch_size, rng)) From aa763251f7befcff895410fc27a6c9eb72718a3a Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Mon, 24 Jun 2024 22:24:24 +0100 Subject: [PATCH 17/38] fix reset zero error --- stoix/networks/memoroid.py | 45 ++++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/stoix/networks/memoroid.py b/stoix/networks/memoroid.py index 466042c0..76632582 100644 --- a/stoix/networks/memoroid.py +++ b/stoix/networks/memoroid.py @@ -185,7 +185,7 @@ def initialize_carry( if batch_size is not None: carry_shape = (carry_shape[0], batch_size, *carry_shape[1:]) t_shape = (*t_shape, batch_size) - return jnp.zeros(carry_shape, dtype=jnp.complex64), jnp.ones(t_shape, dtype=jnp.int32) + return jnp.zeros(carry_shape, dtype=jnp.complex64), jnp.zeros(t_shape, dtype=jnp.int32) def __call__(self, carry: RecurrentState, incoming): ( @@ -204,7 +204,7 @@ class MemoroidResetWrapper(MemoroidCellBase): cell: nn.Module def __call__(self, carry, incoming, rng=None): - states, prev_start = carry + states, prev_carry_reset_flag = carry xs, start = incoming def reset_state(start: Reset, current_state, initial_state): @@ -218,9 +218,9 @@ def reset_state(start: Reset, current_state, initial_state): initial_states = self.cell.initialize_carry(rng=rng, batch_size=start.shape[1]) states = jax.tree.map(partial(reset_state, start), states, initial_states) out = self.cell(states, xs) - start_carry = jnp.logical_or(start, prev_start) + carry_reset_flag = jnp.logical_or(start, prev_carry_reset_flag) - return out, start_carry + return out, carry_reset_flag def map_to_h(self, x: InputEmbedding) -> RecurrentState: return self.cell.map_to_h(x) @@ -274,6 +274,41 @@ def initialize_carry( return jax.tree.map(lambda x: x.squeeze(0), self.cell.initialize_carry(batch_size, rng)) +def test_reset_wrapper(): + """Validate that the reset wrapper works as expected""" + BatchFFM = ScannedMemoroid + + m = BatchFFM( + cell=MemoroidResetWrapper(cell=FFMCell(output_size=2, trace_size=2, context_size=3)) + ) + + batch_size = 4 + time_steps = 100 + # Have a batched version with one episode per batch + # and collapse it into a single episode with a single batch (but same start/resets) + # results should be identical + batched_starts = jnp.ones([batch_size], dtype=bool) + batched_starts = jnp.concatenate([ + batched_starts.reshape(1, -1), + jnp.zeros([time_steps - 1, batch_size], dtype=bool) + ], axis=0) + contig_starts = jnp.swapaxes(batched_starts, 1, 0).reshape(-1, 1) + + x_batched = jnp.arange(time_steps * batch_size * 2).reshape((time_steps, batch_size, 2)) + x_contig = jnp.swapaxes(x_batched, 1, 0).reshape(-1, 1, 2) + batched_s = m.initialize_carry(batch_size) + contig_s = m.initialize_carry(1) + params = m.init(jax.random.PRNGKey(0), batched_s, (x_batched, batched_starts)) + + + (batched_out_state, _), _ = m.apply(params, batched_s, (x_batched, batched_starts)) + (contig_out_state, _), _ = m.apply(params, contig_s, (x_contig, contig_starts)) + + # This should be nearly zero (1e-10 or something) + error = jnp.linalg.norm(contig_out_state - batched_out_state[-1]) + print(error) + + if __name__ == "__main__": BatchFFM = ScannedMemoroid @@ -294,3 +329,5 @@ def initialize_carry( print(out) print(debug_shape(out_state)) + + test_reset_wrapper() From c5816dead37e09a429d7f424f20d0928cd21ccb7 Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Mon, 24 Jun 2024 22:41:26 +0100 Subject: [PATCH 18/38] better tests --- stoix/networks/memoroid.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/stoix/networks/memoroid.py b/stoix/networks/memoroid.py index 76632582..b6a4353d 100644 --- a/stoix/networks/memoroid.py +++ b/stoix/networks/memoroid.py @@ -301,12 +301,14 @@ def test_reset_wrapper(): params = m.init(jax.random.PRNGKey(0), batched_s, (x_batched, batched_starts)) - (batched_out_state, _), _ = m.apply(params, batched_s, (x_batched, batched_starts)) - (contig_out_state, _), _ = m.apply(params, contig_s, (x_contig, contig_starts)) + (batched_out_state, _), batched_out = m.apply(params, batched_s, (x_batched, batched_starts)) + (contig_out_state, _), contig_out = m.apply(params, contig_s, (x_contig, contig_starts)) # This should be nearly zero (1e-10 or something) - error = jnp.linalg.norm(contig_out_state - batched_out_state[-1]) - print(error) + state_error = jnp.linalg.norm(contig_out_state - batched_out_state[-1], axis=-1).sum() + print("state error", state_error) + state_error = jnp.linalg.norm(batched_out - jnp.swapaxes(contig_out.reshape(batch_size, time_steps, -1), 1, 0), axis=-1).sum() + print("state error", state_error) if __name__ == "__main__": From 8fc264047abbbebf2ba70be2878de58e5c171ba3 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Mon, 24 Jun 2024 22:23:09 +0000 Subject: [PATCH 19/38] feat: add more popjym configs --- stoix/configs/env/popjym/repeat_first_easy.yaml | 12 ++++++++++++ stoix/configs/env/popjym/repeat_first_hard.yaml | 12 ++++++++++++ stoix/configs/env/popjym/repeat_first_medium.yaml | 12 ++++++++++++ .../configs/env/popjym/stateless_cartpole_easy.yaml | 4 ++++ stoix/configs/network/memoroid.yaml | 12 ++++++------ stoix/configs/network/rnn.yaml | 12 ++++++------ 6 files changed, 52 insertions(+), 12 deletions(-) create mode 100644 stoix/configs/env/popjym/repeat_first_easy.yaml create mode 100644 stoix/configs/env/popjym/repeat_first_hard.yaml create mode 100644 stoix/configs/env/popjym/repeat_first_medium.yaml diff --git a/stoix/configs/env/popjym/repeat_first_easy.yaml b/stoix/configs/env/popjym/repeat_first_easy.yaml new file mode 100644 index 00000000..db90c91b --- /dev/null +++ b/stoix/configs/env/popjym/repeat_first_easy.yaml @@ -0,0 +1,12 @@ +# ---Environment Configs--- +env_name: popjym + +scenario: + name: RepeatFirstEasy + task_name: repeat_first_easy + +kwargs: {} + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return \ No newline at end of file diff --git a/stoix/configs/env/popjym/repeat_first_hard.yaml b/stoix/configs/env/popjym/repeat_first_hard.yaml new file mode 100644 index 00000000..3736f763 --- /dev/null +++ b/stoix/configs/env/popjym/repeat_first_hard.yaml @@ -0,0 +1,12 @@ +# ---Environment Configs--- +env_name: popjym + +scenario: + name: RepeatFirstHard + task_name: repeat_first_hard + +kwargs: {} + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return \ No newline at end of file diff --git a/stoix/configs/env/popjym/repeat_first_medium.yaml b/stoix/configs/env/popjym/repeat_first_medium.yaml new file mode 100644 index 00000000..eafd73e3 --- /dev/null +++ b/stoix/configs/env/popjym/repeat_first_medium.yaml @@ -0,0 +1,12 @@ +# ---Environment Configs--- +env_name: popjym + +scenario: + name: RepeatFirstMedium + task_name: repeat_first_medium + +kwargs: {} + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return \ No newline at end of file diff --git a/stoix/configs/env/popjym/stateless_cartpole_easy.yaml b/stoix/configs/env/popjym/stateless_cartpole_easy.yaml index 06aea7de..3d26bc2e 100644 --- a/stoix/configs/env/popjym/stateless_cartpole_easy.yaml +++ b/stoix/configs/env/popjym/stateless_cartpole_easy.yaml @@ -6,3 +6,7 @@ scenario: task_name: stateless_cartpole_easy kwargs: {} + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return \ No newline at end of file diff --git a/stoix/configs/network/memoroid.yaml b/stoix/configs/network/memoroid.yaml index 0a45a3d0..64f060e3 100644 --- a/stoix/configs/network/memoroid.yaml +++ b/stoix/configs/network/memoroid.yaml @@ -3,7 +3,7 @@ actor_network: pre_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [128] + layer_sizes: [256] use_layer_norm: False activation: silu rnn_layer: @@ -14,10 +14,10 @@ actor_network: _target_: stoix.networks.memoroid.FFMCell trace_size: 64 context_size: 4 - output_size: 128 + output_size: 256 post_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [128] + layer_sizes: [256] use_layer_norm: False activation: silu action_head: @@ -26,7 +26,7 @@ actor_network: critic_network: pre_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [128] + layer_sizes: [256] use_layer_norm: False activation: silu rnn_layer: @@ -37,10 +37,10 @@ critic_network: _target_: stoix.networks.memoroid.FFMCell trace_size: 64 context_size: 4 - output_size: 128 + output_size: 256 post_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [128] + layer_sizes: [256] use_layer_norm: False activation: silu critic_head: diff --git a/stoix/configs/network/rnn.yaml b/stoix/configs/network/rnn.yaml index 14801c2d..459cfa64 100644 --- a/stoix/configs/network/rnn.yaml +++ b/stoix/configs/network/rnn.yaml @@ -3,16 +3,16 @@ actor_network: pre_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [128] + layer_sizes: [256] use_layer_norm: False activation: silu rnn_layer: _target_: stoix.networks.recurrent.ScannedRNN cell_type: gru - hidden_state_dim: 128 + hidden_state_dim: 256 post_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [128] + layer_sizes: [256] use_layer_norm: False activation: silu action_head: @@ -21,16 +21,16 @@ actor_network: critic_network: pre_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [128] + layer_sizes: [256] use_layer_norm: False activation: silu rnn_layer: _target_: stoix.networks.recurrent.ScannedRNN cell_type: gru - hidden_state_dim: 128 + hidden_state_dim: 256 post_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [128] + layer_sizes: [256] use_layer_norm: False activation: silu critic_head: From d8c845bde6e732b80716a08835a28041e35b9988 Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Tue, 25 Jun 2024 02:11:02 +0100 Subject: [PATCH 20/38] fix dummy state --- stoix/networks/memoroid.py | 71 ++++++++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 26 deletions(-) diff --git a/stoix/networks/memoroid.py b/stoix/networks/memoroid.py index b6a4353d..688840a9 100644 --- a/stoix/networks/memoroid.py +++ b/stoix/networks/memoroid.py @@ -71,10 +71,6 @@ def recurrent_associative_scan( # Concatenate the previous state to the inputs and scan over the result # This ensures the previous recurrent state contributes to the current batch - # We need to add a dummy start signal to the inputs - dummy_start = jnp.zeros(inputs[-1].shape[1:], dtype=bool)[jnp.newaxis, ...] - # Add it to the state i.e. (state, timestep) -> ((state, time), reset) - state = (state, dummy_start) scan_inputs = jax.tree.map(lambda s, x: jnp.concatenate([s, x], axis=axis), state, inputs) new_state = jax.lax.associative_scan( cell, @@ -82,9 +78,6 @@ def recurrent_associative_scan( axis=axis, ) - # Get rid of the reset signal i.e. ((state, time), reset) -> (state, time) - new_state, _ = new_state - # The zeroth index corresponds to the previous recurrent state # We just use it to ensure continuity # We do not actually want to use these values, so slice them away @@ -154,7 +147,7 @@ def map_to_h(self, x: InputEmbedding) -> RecurrentState: def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> HiddenState: """Map from the recurrent space to the Markov space""" - state, _ = recurrent_state + (state, _), _ = recurrent_state z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape( state.shape[0], state.shape[1], -1 ) @@ -185,6 +178,7 @@ def initialize_carry( if batch_size is not None: carry_shape = (carry_shape[0], batch_size, *carry_shape[1:]) t_shape = (*t_shape, batch_size) + return jnp.zeros(carry_shape, dtype=jnp.complex64), jnp.zeros(t_shape, dtype=jnp.int32) def __call__(self, carry: RecurrentState, incoming): @@ -232,7 +226,7 @@ def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> Hidd def initialize_carry( self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None ) -> RecurrentState: - return self.cell.initialize_carry(batch_size, rng) + return self.cell.initialize_carry(batch_size, rng), jnp.zeros((1, batch_size), dtype=bool) class ScannedMemoroid(nn.Module): @@ -301,35 +295,60 @@ def test_reset_wrapper(): params = m.init(jax.random.PRNGKey(0), batched_s, (x_batched, batched_starts)) - (batched_out_state, _), batched_out = m.apply(params, batched_s, (x_batched, batched_starts)) - (contig_out_state, _), contig_out = m.apply(params, contig_s, (x_contig, contig_starts)) + (batched_out_state, batched_ts), batched_out = m.apply(params, batched_s, (x_batched, batched_starts)) + (contig_out_state, contig_ts), contig_out = m.apply(params, contig_s, (x_contig, contig_starts)) # This should be nearly zero (1e-10 or something) state_error = jnp.linalg.norm(contig_out_state - batched_out_state[-1], axis=-1).sum() print("state error", state_error) - state_error = jnp.linalg.norm(batched_out - jnp.swapaxes(contig_out.reshape(batch_size, time_steps, -1), 1, 0), axis=-1).sum() - print("state error", state_error) - + out_error = jnp.linalg.norm(batched_out - jnp.swapaxes(contig_out.reshape(batch_size, time_steps, -1), 1, 0), axis=-1).sum() + print("out error", out_error) -if __name__ == "__main__": +def test_reset_wrapper_ts(): BatchFFM = ScannedMemoroid m = BatchFFM( - cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)) + cell=MemoroidResetWrapper(cell=FFMCell(output_size=2, trace_size=2, context_size=3)) ) - batch_size = 8 + batch_size = 2 time_steps = 10 + # Have a batched version with one episode per batch + # and collapse it into a single episode with a single batch (but same start/resets) + # results should be identical + batched_starts = jnp.array([ + [False, False, True, False, False, True, True, False, False, False], + [False, False, True, False, False, True, True, False, False, False], + ]).T + + x_batched = jnp.arange(time_steps * batch_size * 2).reshape((time_steps, batch_size, 2)).astype(jnp.float32) + batched_s = m.initialize_carry(batch_size) + params = m.init(jax.random.PRNGKey(0), batched_s, (x_batched, batched_starts)) + + + ((batched_out_state, batched_ts), batched_reset), batched_out = m.apply(params, batched_s, (x_batched, batched_starts)) + + +if __name__ == "__main__": + # BatchFFM = ScannedMemoroid + + # m = BatchFFM( + # cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)) + # ) + + # batch_size = 8 + # time_steps = 10 - y = jnp.ones((time_steps, batch_size, 2)) - s = m.initialize_carry(batch_size) - start = jnp.zeros((time_steps, batch_size), dtype=bool) - params = m.init(jax.random.PRNGKey(0), s, (y, start)) - out_state, out = m.apply(params, s, (y, start)) + # y = jnp.ones((time_steps, batch_size, 2)) + # s = m.initialize_carry(batch_size) + # start = jnp.zeros((time_steps, batch_size), dtype=bool) + # params = m.init(jax.random.PRNGKey(0), s, (y, start)) + # out_state, out = m.apply(params, s, (y, start)) - out = jnp.swapaxes(out, 0, 1) + # out = jnp.swapaxes(out, 0, 1) - print(out) - print(debug_shape(out_state)) + # print(out) + # print(debug_shape(out_state)) - test_reset_wrapper() + #test_reset_wrapper() + test_reset_wrapper_ts() From 7feaa3fab6a58c84cbdde584672d73d2f3ea8be9 Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Tue, 25 Jun 2024 02:36:48 +0100 Subject: [PATCH 21/38] better reset tests --- stoix/networks/memoroid.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/stoix/networks/memoroid.py b/stoix/networks/memoroid.py index 688840a9..fbf1fbf0 100644 --- a/stoix/networks/memoroid.py +++ b/stoix/networks/memoroid.py @@ -282,10 +282,12 @@ def test_reset_wrapper(): # and collapse it into a single episode with a single batch (but same start/resets) # results should be identical batched_starts = jnp.ones([batch_size], dtype=bool) - batched_starts = jnp.concatenate([ - batched_starts.reshape(1, -1), - jnp.zeros([time_steps - 1, batch_size], dtype=bool) - ], axis=0) + # batched_starts = jnp.concatenate([ + # jnp.zeros([time_steps // 2, batch_size], dtype=bool), + # batched_starts.reshape(1, -1), + # jnp.zeros([time_steps // 2 - 1, batch_size], dtype=bool) + # ], axis=0) + batched_starts = jax.random.uniform(jax.random.PRNGKey(0), (time_steps, batch_size)) < 0.1 contig_starts = jnp.swapaxes(batched_starts, 1, 0).reshape(-1, 1) x_batched = jnp.arange(time_steps * batch_size * 2).reshape((time_steps, batch_size, 2)) @@ -295,14 +297,15 @@ def test_reset_wrapper(): params = m.init(jax.random.PRNGKey(0), batched_s, (x_batched, batched_starts)) - (batched_out_state, batched_ts), batched_out = m.apply(params, batched_s, (x_batched, batched_starts)) - (contig_out_state, contig_ts), contig_out = m.apply(params, contig_s, (x_contig, contig_starts)) + ((batched_out_state, batched_ts), batched_reset), batched_out = m.apply(params, batched_s, (x_batched, batched_starts)) + ((contig_out_state, contig_ts), contig_reset), contig_out = m.apply(params, contig_s, (x_contig, contig_starts)) # This should be nearly zero (1e-10 or something) state_error = jnp.linalg.norm(contig_out_state - batched_out_state[-1], axis=-1).sum() print("state error", state_error) out_error = jnp.linalg.norm(batched_out - jnp.swapaxes(contig_out.reshape(batch_size, time_steps, -1), 1, 0), axis=-1).sum() print("out error", out_error) + print(batched_ts, contig_ts) def test_reset_wrapper_ts(): BatchFFM = ScannedMemoroid @@ -327,6 +330,7 @@ def test_reset_wrapper_ts(): ((batched_out_state, batched_ts), batched_reset), batched_out = m.apply(params, batched_s, (x_batched, batched_starts)) + print(batched_ts == 4) if __name__ == "__main__": @@ -350,5 +354,5 @@ def test_reset_wrapper_ts(): # print(out) # print(debug_shape(out_state)) - #test_reset_wrapper() + test_reset_wrapper() test_reset_wrapper_ts() From 564dc813dda9697461325041895bd25ce6e0b266 Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Tue, 25 Jun 2024 11:21:01 +0100 Subject: [PATCH 22/38] add simple training test --- stoix/configs/network/memoroid.yaml | 32 ++++++++++++++--------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/stoix/configs/network/memoroid.yaml b/stoix/configs/network/memoroid.yaml index 64f060e3..0871725f 100644 --- a/stoix/configs/network/memoroid.yaml +++ b/stoix/configs/network/memoroid.yaml @@ -3,45 +3,45 @@ actor_network: pre_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: False - activation: silu + layer_sizes: [128] + use_layer_norm: True + activation: leaky_relu rnn_layer: _target_: stoix.networks.memoroid.ScannedMemoroid cell: _target_: stoix.networks.memoroid.MemoroidResetWrapper cell: _target_: stoix.networks.memoroid.FFMCell - trace_size: 64 + trace_size: 32 context_size: 4 - output_size: 256 + output_size: 128 post_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: False - activation: silu + layer_sizes: [128, 128] + use_layer_norm: True + activation: leaky_relu action_head: _target_: stoix.networks.heads.CategoricalHead critic_network: pre_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: False - activation: silu + layer_sizes: [128] + use_layer_norm: True + activation: leaky_relu rnn_layer: _target_: stoix.networks.memoroid.ScannedMemoroid cell: _target_: stoix.networks.memoroid.MemoroidResetWrapper cell: _target_: stoix.networks.memoroid.FFMCell - trace_size: 64 + trace_size: 32 context_size: 4 - output_size: 256 + output_size: 128 post_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: False - activation: silu + layer_sizes: [128, 128] + use_layer_norm: True + activation: leaky_relu critic_head: _target_: stoix.networks.heads.ScalarCriticHead From 4752cb81ed65d8bd3be642e1fcd756562402a86b Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Tue, 25 Jun 2024 11:21:23 +0100 Subject: [PATCH 23/38] add simple training test --- stoix/networks/memoroid.py | 54 +++++++++++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/stoix/networks/memoroid.py b/stoix/networks/memoroid.py index fbf1fbf0..fbb9db95 100644 --- a/stoix/networks/memoroid.py +++ b/stoix/networks/memoroid.py @@ -3,7 +3,9 @@ import chex import flax.linen as nn +import flax import jax +import optax import jax.numpy as jnp # Typing aliases @@ -99,7 +101,7 @@ def __call__(self, x): def init_deterministic( - memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 + memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1024 ) -> Tuple[chex.Array, chex.Array]: """Deterministic initialization of the FFM parameters.""" a_low = 1e-6 @@ -247,6 +249,12 @@ def __call__( x, resets = inputs h = self.cell.map_to_h(x) + # TODO: In the original implementation, the recurrent timestep is also one + # recurrent_state = ( + # (recurrent_state[0][0], + # jnp.ones_like(recurrent_state[0][1])), + # recurrent_state[1] + # ) recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, (h, resets)) # recurrent_state is (state, timestep) out = self.cell.map_from_h(recurrent_state, x) @@ -333,6 +341,45 @@ def test_reset_wrapper_ts(): print(batched_ts == 4) + +def train_memorize(): + BatchFFM = ScannedMemoroid + + m = BatchFFM( + cell=MemoroidResetWrapper(cell=FFMCell(output_size=128, trace_size=32, context_size=4)) + ) + + batch_size = 1 + rem_ts = 10 + time_steps = rem_ts * 5 + obs_space = 2 + rng = jax.random.PRNGKey(0) + x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space).reshape(-1, 1, 1) + y = jnp.repeat(x[::rem_ts], x.shape[0] // x[::rem_ts].shape[0]).reshape(-1, 1) + start = jnp.zeros([time_steps, batch_size], dtype=bool).at[::rem_ts].set(True) + #start = jnp.zeros([time_steps, batch_size], dtype=bool) + #start = jnp.ones([time_steps, batch_size], dtype=bool) + + s = m.initialize_carry(batch_size) + params = m.init(jax.random.PRNGKey(0), s, (x, start)) + + def error(params, x, start, key): + s = m.initialize_carry(batch_size) + x = jax.random.randint(key, (time_steps, batch_size), 0, obs_space).reshape(-1, 1, 1) + out_state, y_hat = m.apply(params, s, (x, start)) + return jnp.mean((y - y_hat) ** 2) + + optimizer = optax.adam(learning_rate=0.002) + state = optimizer.init(params) + loss_fn = jax.jit(jax.value_and_grad(error)) + for step in range(10_000): + rng = jax.random.split(rng)[0] + loss, grads = loss_fn(params, x, start, rng) + updates, state = optimizer.update(grads, state) + params = optax.apply_updates(params, updates) + print(f"Step {step+1}, Loss: {loss}") + + if __name__ == "__main__": # BatchFFM = ScannedMemoroid @@ -354,5 +401,6 @@ def test_reset_wrapper_ts(): # print(out) # print(debug_shape(out_state)) - test_reset_wrapper() - test_reset_wrapper_ts() + #test_reset_wrapper() + #test_reset_wrapper_ts() + train_memorize() From 78e5fad342a562c73708e27e2d26d36befb517eb Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Tue, 25 Jun 2024 11:21:57 +0100 Subject: [PATCH 24/38] add simple training test --- stoix/networks/memoroid.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/stoix/networks/memoroid.py b/stoix/networks/memoroid.py index fbb9db95..323e79f6 100644 --- a/stoix/networks/memoroid.py +++ b/stoix/networks/memoroid.py @@ -3,7 +3,6 @@ import chex import flax.linen as nn -import flax import jax import optax import jax.numpy as jnp @@ -101,7 +100,7 @@ def __call__(self, x): def init_deterministic( - memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1024 + memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1000 ) -> Tuple[chex.Array, chex.Array]: """Deterministic initialization of the FFM parameters.""" a_low = 1e-6 From 0b210ceeb0282e05fe7f32bed01f55aacaaa5e13 Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Tue, 25 Jun 2024 11:26:44 +0100 Subject: [PATCH 25/38] oops wrong y, pls pull --- stoix/networks/memoroid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stoix/networks/memoroid.py b/stoix/networks/memoroid.py index 323e79f6..6a148117 100644 --- a/stoix/networks/memoroid.py +++ b/stoix/networks/memoroid.py @@ -350,7 +350,7 @@ def train_memorize(): batch_size = 1 rem_ts = 10 - time_steps = rem_ts * 5 + time_steps = rem_ts * 10 obs_space = 2 rng = jax.random.PRNGKey(0) x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space).reshape(-1, 1, 1) @@ -365,6 +365,7 @@ def train_memorize(): def error(params, x, start, key): s = m.initialize_carry(batch_size) x = jax.random.randint(key, (time_steps, batch_size), 0, obs_space).reshape(-1, 1, 1) + y = jnp.repeat(x[::rem_ts], x.shape[0] // x[::rem_ts].shape[0]).reshape(-1, 1) out_state, y_hat = m.apply(params, s, (x, start)) return jnp.mean((y - y_hat) ** 2) From acf9a298a1170eca551f9534a27ab78aa0fd9cf5 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Tue, 25 Jun 2024 17:51:29 +0000 Subject: [PATCH 26/38] feat: add demos --- .../configs/env/popjym/repeat_first_easy.yaml | 2 +- .../configs/env/popjym/repeat_first_hard.yaml | 2 +- .../env/popjym/repeat_first_medium.yaml | 2 +- .../env/popjym/stateless_cartpole_easy.yaml | 2 +- stoix/configs/network/rnn.yaml | 8 +- stoix/configs/system/rec_ppo.yaml | 16 +- stoix/networks/memoroid.py | 54 ++-- stoix/networks/utils.py | 1 + stoix/networks/working_demo.py | 242 ++++++++++++++++++ stoix/networks/working_demov2.py | 223 ++++++++++++++++ 10 files changed, 515 insertions(+), 37 deletions(-) create mode 100644 stoix/networks/working_demo.py create mode 100644 stoix/networks/working_demov2.py diff --git a/stoix/configs/env/popjym/repeat_first_easy.yaml b/stoix/configs/env/popjym/repeat_first_easy.yaml index db90c91b..5c3118a4 100644 --- a/stoix/configs/env/popjym/repeat_first_easy.yaml +++ b/stoix/configs/env/popjym/repeat_first_easy.yaml @@ -9,4 +9,4 @@ kwargs: {} # Defines the metric that will be used to evaluate the performance of the agent. # This metric is returned at the end of an experiment and can be used for hyperparameter tuning. -eval_metric: episode_return \ No newline at end of file +eval_metric: episode_return diff --git a/stoix/configs/env/popjym/repeat_first_hard.yaml b/stoix/configs/env/popjym/repeat_first_hard.yaml index 3736f763..fb59973f 100644 --- a/stoix/configs/env/popjym/repeat_first_hard.yaml +++ b/stoix/configs/env/popjym/repeat_first_hard.yaml @@ -9,4 +9,4 @@ kwargs: {} # Defines the metric that will be used to evaluate the performance of the agent. # This metric is returned at the end of an experiment and can be used for hyperparameter tuning. -eval_metric: episode_return \ No newline at end of file +eval_metric: episode_return diff --git a/stoix/configs/env/popjym/repeat_first_medium.yaml b/stoix/configs/env/popjym/repeat_first_medium.yaml index eafd73e3..93c90f2c 100644 --- a/stoix/configs/env/popjym/repeat_first_medium.yaml +++ b/stoix/configs/env/popjym/repeat_first_medium.yaml @@ -9,4 +9,4 @@ kwargs: {} # Defines the metric that will be used to evaluate the performance of the agent. # This metric is returned at the end of an experiment and can be used for hyperparameter tuning. -eval_metric: episode_return \ No newline at end of file +eval_metric: episode_return diff --git a/stoix/configs/env/popjym/stateless_cartpole_easy.yaml b/stoix/configs/env/popjym/stateless_cartpole_easy.yaml index 3d26bc2e..516264bd 100644 --- a/stoix/configs/env/popjym/stateless_cartpole_easy.yaml +++ b/stoix/configs/env/popjym/stateless_cartpole_easy.yaml @@ -9,4 +9,4 @@ kwargs: {} # Defines the metric that will be used to evaluate the performance of the agent. # This metric is returned at the end of an experiment and can be used for hyperparameter tuning. -eval_metric: episode_return \ No newline at end of file +eval_metric: episode_return diff --git a/stoix/configs/network/rnn.yaml b/stoix/configs/network/rnn.yaml index 459cfa64..285b1bbb 100644 --- a/stoix/configs/network/rnn.yaml +++ b/stoix/configs/network/rnn.yaml @@ -5,7 +5,7 @@ actor_network: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] use_layer_norm: False - activation: silu + activation: leaky_relu rnn_layer: _target_: stoix.networks.recurrent.ScannedRNN cell_type: gru @@ -14,7 +14,7 @@ actor_network: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] use_layer_norm: False - activation: silu + activation: leaky_relu action_head: _target_: stoix.networks.heads.CategoricalHead @@ -23,7 +23,7 @@ critic_network: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] use_layer_norm: False - activation: silu + activation: leaky_relu rnn_layer: _target_: stoix.networks.recurrent.ScannedRNN cell_type: gru @@ -32,6 +32,6 @@ critic_network: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] use_layer_norm: False - activation: silu + activation: leaky_relu critic_head: _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/configs/system/rec_ppo.yaml b/stoix/configs/system/rec_ppo.yaml index 1c7c44ff..db70ff80 100644 --- a/stoix/configs/system/rec_ppo.yaml +++ b/stoix/configs/system/rec_ppo.yaml @@ -3,18 +3,18 @@ system_name: rec_ppo # Name of the system. # --- RL hyperparameters --- -actor_lr: 2.5e-4 # Learning rate for actor network -critic_lr: 2.5e-4 # Learning rate for critic network -rollout_length: 128 # Number of environment steps per vectorised environment. -epochs: 4 # Number of ppo epochs per training data batch. -num_minibatches: 2 # Number of minibatches per ppo epoch. +actor_lr: 5e-5 # Learning rate for actor network +critic_lr: 5e-5 # Learning rate for critic network +rollout_length: 64 # Number of environment steps per vectorised environment. +epochs: 10 # Number of ppo epochs per training data batch. +num_minibatches: 32 # Number of minibatches per ppo epoch. gamma: 0.99 # Discounting factor. gae_lambda: 0.95 # Lambda value for GAE computation. clip_eps: 0.2 # Clipping value for PPO updates and value function. -ent_coef: 0.01 # Entropy regularisation term for loss function. -vf_coef: 0.5 # Critic weight in +ent_coef: 0.001 # Entropy regularisation term for loss function. +vf_coef: 1.0 # Critic weight in max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update. -decay_learning_rates: False # Whether learning rates should be linearly decayed during training. +decay_learning_rates: True # Whether learning rates should be linearly decayed during training. standardize_advantages: True # Whether to standardize the advantages. # --- Recurrent hyperparameters --- diff --git a/stoix/networks/memoroid.py b/stoix/networks/memoroid.py index 6a148117..1f5d082a 100644 --- a/stoix/networks/memoroid.py +++ b/stoix/networks/memoroid.py @@ -4,8 +4,8 @@ import chex import flax.linen as nn import jax -import optax import jax.numpy as jnp +import optax # Typing aliases Carry = chex.ArrayTree @@ -179,7 +179,7 @@ def initialize_carry( if batch_size is not None: carry_shape = (carry_shape[0], batch_size, *carry_shape[1:]) t_shape = (*t_shape, batch_size) - + return jnp.zeros(carry_shape, dtype=jnp.complex64), jnp.zeros(t_shape, dtype=jnp.int32) def __call__(self, carry: RecurrentState, incoming): @@ -283,7 +283,7 @@ def test_reset_wrapper(): cell=MemoroidResetWrapper(cell=FFMCell(output_size=2, trace_size=2, context_size=3)) ) - batch_size = 4 + batch_size = 4 time_steps = 100 # Have a batched version with one episode per batch # and collapse it into a single episode with a single batch (but same start/resets) @@ -303,17 +303,23 @@ def test_reset_wrapper(): contig_s = m.initialize_carry(1) params = m.init(jax.random.PRNGKey(0), batched_s, (x_batched, batched_starts)) - - ((batched_out_state, batched_ts), batched_reset), batched_out = m.apply(params, batched_s, (x_batched, batched_starts)) - ((contig_out_state, contig_ts), contig_reset), contig_out = m.apply(params, contig_s, (x_contig, contig_starts)) + ((batched_out_state, batched_ts), batched_reset), batched_out = m.apply( + params, batched_s, (x_batched, batched_starts) + ) + ((contig_out_state, contig_ts), contig_reset), contig_out = m.apply( + params, contig_s, (x_contig, contig_starts) + ) # This should be nearly zero (1e-10 or something) state_error = jnp.linalg.norm(contig_out_state - batched_out_state[-1], axis=-1).sum() print("state error", state_error) - out_error = jnp.linalg.norm(batched_out - jnp.swapaxes(contig_out.reshape(batch_size, time_steps, -1), 1, 0), axis=-1).sum() + out_error = jnp.linalg.norm( + batched_out - jnp.swapaxes(contig_out.reshape(batch_size, time_steps, -1), 1, 0), axis=-1 + ).sum() print("out error", out_error) print(batched_ts, contig_ts) + def test_reset_wrapper_ts(): BatchFFM = ScannedMemoroid @@ -321,26 +327,32 @@ def test_reset_wrapper_ts(): cell=MemoroidResetWrapper(cell=FFMCell(output_size=2, trace_size=2, context_size=3)) ) - batch_size = 2 + batch_size = 2 time_steps = 10 # Have a batched version with one episode per batch # and collapse it into a single episode with a single batch (but same start/resets) # results should be identical - batched_starts = jnp.array([ - [False, False, True, False, False, True, True, False, False, False], - [False, False, True, False, False, True, True, False, False, False], - ]).T - - x_batched = jnp.arange(time_steps * batch_size * 2).reshape((time_steps, batch_size, 2)).astype(jnp.float32) + batched_starts = jnp.array( + [ + [False, False, True, False, False, True, True, False, False, False], + [False, False, True, False, False, True, True, False, False, False], + ] + ).T + + x_batched = ( + jnp.arange(time_steps * batch_size * 2) + .reshape((time_steps, batch_size, 2)) + .astype(jnp.float32) + ) batched_s = m.initialize_carry(batch_size) params = m.init(jax.random.PRNGKey(0), batched_s, (x_batched, batched_starts)) - - ((batched_out_state, batched_ts), batched_reset), batched_out = m.apply(params, batched_s, (x_batched, batched_starts)) + ((batched_out_state, batched_ts), batched_reset), batched_out = m.apply( + params, batched_s, (x_batched, batched_starts) + ) print(batched_ts == 4) - def train_memorize(): BatchFFM = ScannedMemoroid @@ -356,8 +368,8 @@ def train_memorize(): x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space).reshape(-1, 1, 1) y = jnp.repeat(x[::rem_ts], x.shape[0] // x[::rem_ts].shape[0]).reshape(-1, 1) start = jnp.zeros([time_steps, batch_size], dtype=bool).at[::rem_ts].set(True) - #start = jnp.zeros([time_steps, batch_size], dtype=bool) - #start = jnp.ones([time_steps, batch_size], dtype=bool) + # start = jnp.zeros([time_steps, batch_size], dtype=bool) + # start = jnp.ones([time_steps, batch_size], dtype=bool) s = m.initialize_carry(batch_size) params = m.init(jax.random.PRNGKey(0), s, (x, start)) @@ -401,6 +413,6 @@ def error(params, x, start, key): # print(out) # print(debug_shape(out_state)) - #test_reset_wrapper() - #test_reset_wrapper_ts() + # test_reset_wrapper() + # test_reset_wrapper_ts() train_memorize() diff --git a/stoix/networks/utils.py b/stoix/networks/utils.py index 9101b67a..9e14550d 100644 --- a/stoix/networks/utils.py +++ b/stoix/networks/utils.py @@ -8,6 +8,7 @@ def parse_activation_fn(activation_fn_name: str) -> Callable[[chex.Array], chex. """Get the activation function.""" activation_fns: Dict[str, Callable[[chex.Array], chex.Array]] = { "relu": nn.relu, + "leaky_relu": nn.leaky_relu, "tanh": nn.tanh, "silu": nn.silu, "elu": nn.elu, diff --git a/stoix/networks/working_demo.py b/stoix/networks/working_demo.py new file mode 100644 index 00000000..6a8f170f --- /dev/null +++ b/stoix/networks/working_demo.py @@ -0,0 +1,242 @@ +from functools import partial +from typing import Any, Dict, Tuple + +import chex +import jax +import optax +from flax import linen as nn +from jax import numpy as jnp +from jax import vmap + + +def init_deterministic( + memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 +) -> Tuple[jax.Array, jax.Array]: + a_low = 1e-6 + a_high = 0.5 + a = jnp.linspace(a_low, a_high, memory_size) + b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) + return a, b + + +class Gate(nn.Module): + output_size: int + + @nn.compact + def __call__(self, x: chex.Array) -> chex.Array: + return jax.nn.sigmoid(nn.Dense(self.output_size)(x)) + + +class FFM(nn.Module): + trace_size: int + context_size: int + output_size: int + + def setup(self) -> None: + self.a = self.param( + "ffm_a", + lambda key, shape: init_deterministic(self.trace_size, self.context_size)[0], + (), + ) + self.b = self.param( + "ffm_b", + lambda key, shape: init_deterministic(self.trace_size, self.context_size)[1], + (), + ) + + @nn.compact + def __call__( + self, x: jax.Array, state: jax.Array, start: jax.Array + ) -> Tuple[jax.Array, jax.Array]: + + x = nn.Dense(64)(x) + x = nn.relu(x) + x = nn.Dense(self.trace_size * 2)(x) + + gate_in = Gate(self.trace_size)(x) + pre = Gate(self.trace_size)(x) + gated_x = pre * gate_in + scan_input = jnp.repeat(jnp.expand_dims(gated_x, 2), self.context_size, axis=2) + state = self.scan(scan_input, state, start) + z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape( + state.shape[0], -1 + ) + z = nn.Dense(64)(z_in) + gate_out = Gate(64)(x) + skip = nn.Dense(64)(x) + out = nn.LayerNorm(use_scale=False, use_bias=False)(z * gate_out) + skip * (1 - gate_out) + final_state = state[-1:] + + out = nn.Dense(64)(out) + out = nn.relu(out) + out = nn.Dense(self.output_size)(out) + + return out, final_state + + def initial_state(self) -> jax.Array: + return jnp.zeros((1, self.trace_size, self.context_size), dtype=jnp.complex64) + + def log_gamma(self, t: jax.Array) -> jax.Array: + a = self.a + b = self.b + a = -jnp.abs(a).reshape((1, self.trace_size, 1)) + b = b.reshape(1, 1, self.context_size) + ab = jax.lax.complex(a, b) + return ab * t.reshape(t.shape[0], 1, 1) + + def gamma(self, t: jax.Array) -> jax.Array: + return jnp.exp(self.log_gamma(t)) + + def unwrapped_associative_update( + self, + carry: Tuple[jax.Array, jax.Array, jax.Array], + incoming: Tuple[jax.Array, jax.Array, jax.Array], + ) -> Tuple[jax.Array, jax.Array, jax.Array]: + ( + state, + i, + ) = carry + x, j = incoming + state = state * self.gamma(j) + x + return state, j + i + + def wrapped_associative_update(self, carry, incoming): + prev_start, state, i = carry + start, x, j = incoming + # Reset all elements in the carry if we are starting a new episode + state = state * jnp.logical_not(start) + j = j * jnp.logical_not(start) + incoming = x, j + carry = (state, i) + out = self.unwrapped_associative_update(carry, incoming) + start_out = jnp.logical_or(start, prev_start) + return (start_out, *out) + + def scan( + self, + x: jax.Array, + state: jax.Array, + start: jax.Array, + ) -> jax.Array: + """Given an input and recurrent state, this will update the recurrent state. This is equivalent + to the inner-function g in the paper.""" + # x: [T, memory_size] + # memory: [1, memory_size, context_size] + T = x.shape[0] + # timestep = jnp.arange(T + 1, dtype=jnp.int32) + timestep = jnp.ones(T + 1, dtype=jnp.int32).reshape(-1, 1, 1) + # Add context dim + start = start.reshape(T, 1, 1) + + # Now insert previous recurrent state + x = jnp.concatenate([state, x], axis=0) + start = jnp.concatenate([jnp.zeros_like(start[:1]), start], axis=0) + + # This is not executed during inference -- method will just return x if size is 1 + _, new_state, _ = jax.lax.associative_scan( + self.wrapped_associative_update, + (start, x, timestep), + axis=0, + ) + return new_state[1:] + + +def train_memorize(): + + USE_BATCH_VERSION = True + + if USE_BATCH_VERSION: + + m = nn.vmap( + FFM, in_axes=1, out_axes=1, variable_axes={"params": None}, split_rngs={"params": None} + )(output_size=1, trace_size=64, context_size=4) + else: + m = FFM(output_size=1, trace_size=64, context_size=4) + + batch_size = 16 + rem_ts = 10 + time_steps = rem_ts * 10 + obs_space = 8 + rng = jax.random.PRNGKey(0) + if USE_BATCH_VERSION: + x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space) + y = jnp.stack( + [ + jnp.repeat(x[::rem_ts, i], x.shape[0] // x[::rem_ts, i].shape[0]) + for i in range(batch_size) + ], + axis=-1, + ) + x = x.reshape(time_steps, batch_size, 1) + y = y.reshape(time_steps, batch_size, 1) + + else: + x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space).reshape(-1, 1) + y = jnp.repeat(x[::rem_ts], x.shape[0] // x[::rem_ts].shape[0]).reshape(-1, 1) + + start = jnp.zeros([time_steps, batch_size], dtype=bool).at[::rem_ts].set(True) + + s = m.initial_state() + + # FOR BATCH VERSION + if USE_BATCH_VERSION: + s = jnp.expand_dims(s, 1) + s = jnp.repeat(s, batch_size, axis=1) + params = m.init(jax.random.PRNGKey(0), x, s, start) + + def error(params, x, start, key): + s = m.initial_state() + + if USE_BATCH_VERSION: + s = jnp.expand_dims(s, 1) + s = jnp.repeat(s, batch_size, axis=1) + + # For BATCH VERSION + if USE_BATCH_VERSION: + x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space) + y = jnp.stack( + [ + jnp.repeat(x[::rem_ts, i], x.shape[0] // x[::rem_ts, i].shape[0]) + for i in range(batch_size) + ], + axis=-1, + ) + x = x.reshape(time_steps, batch_size, 1) + y = y.reshape(time_steps, batch_size, 1) + else: + x = jax.random.randint(key, (time_steps, batch_size), 0, obs_space).reshape(-1, 1) + y = jnp.repeat(x[::rem_ts], x.shape[0] // x[::rem_ts].shape[0]).reshape(-1, 1) + + y_hat, final_state = m.apply(params, x, s, start) + y_hat = jnp.squeeze(y_hat) + y = jnp.squeeze(y) + accuracy = (jnp.round(y_hat) == y).mean() + loss = jnp.mean(jnp.abs(y - y_hat) ** 2) + return loss, {"accuracy": accuracy, "loss": loss} + + optimizer = optax.adam(learning_rate=0.001) + state = optimizer.init(params) + loss_fn = jax.jit(jax.grad(error, has_aux=True)) + for step in range(10_000): + rng = jax.random.split(rng)[0] + grads, loss_info = loss_fn(params, x, start, rng) + updates, state = jax.jit(optimizer.update)(grads, state) + params = jax.jit(optax.apply_updates)(params, updates) + print(f"Step {step+1}, Loss: {loss_info['loss']}, Accuracy: {loss_info['accuracy']}") + + +if __name__ == "__main__": + # m = FFM( + # output_size=4, + # trace_size=5, + # context_size=6, + # ) + # s = m.initial_state() + # x = jnp.ones((10, 2)) + # start = jnp.zeros(10, dtype=bool) + # params = m.init(jax.random.PRNGKey(0), x, s, start) + # out = m.apply(params, x, s, start) + + # print(out) + + train_memorize() diff --git a/stoix/networks/working_demov2.py b/stoix/networks/working_demov2.py new file mode 100644 index 00000000..25ea254f --- /dev/null +++ b/stoix/networks/working_demov2.py @@ -0,0 +1,223 @@ +from functools import partial +from typing import Any, Dict, Tuple + +import chex +import jax +import optax +from flax import linen as nn +from jax import numpy as jnp +from jax import vmap + + +def init_deterministic( + memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 +) -> Tuple[jax.Array, jax.Array]: + a_low = 1e-6 + a_high = 0.5 + a = jnp.linspace(a_low, a_high, memory_size) + b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) + return a, b + + +class Gate(nn.Module): + output_size: int + + @nn.compact + def __call__(self, x: chex.Array) -> chex.Array: + return jax.nn.sigmoid(nn.Dense(self.output_size)(x)) + + +class FFM(nn.Module): + trace_size: int + context_size: int + output_size: int + + def setup(self) -> None: + self.a = self.param( + "ffm_a", + lambda key, shape: init_deterministic(self.trace_size, self.context_size)[0], + (), + ) + self.b = self.param( + "ffm_b", + lambda key, shape: init_deterministic(self.trace_size, self.context_size)[1], + (), + ) + + @nn.compact + def __call__( + self, x: jax.Array, state: jax.Array, start: jax.Array + ) -> Tuple[jax.Array, jax.Array]: + + x = nn.Dense(self.output_size)(x) + x = nn.relu(x) + x = nn.Dense(self.output_size)(x) + + gate_in = Gate(self.trace_size)(x) + pre = Gate(self.trace_size)(x) + gated_x = pre * gate_in + scan_input = jnp.repeat(jnp.expand_dims(gated_x, 3), self.context_size, axis=3) + state = self.scan(scan_input, state, start) + T = state.shape[0] + B = state.shape[1] + z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape(T, B, -1) + z = nn.Dense(self.output_size)(z_in) + gate_out = Gate(self.output_size)(x) + skip = nn.Dense(self.output_size)(x) + out = nn.LayerNorm(use_scale=False, use_bias=False)(z * gate_out) + skip * (1 - gate_out) + final_state = state[-1:] + + out = nn.Dense(self.output_size)(out) + out = nn.relu(out) + out = nn.Dense(1)(out) + + return out, final_state + + def initial_state(self, batch_size: int) -> jax.Array: + return jnp.zeros((1, batch_size, self.trace_size, self.context_size), dtype=jnp.complex64) + + def log_gamma(self, t: jax.Array) -> jax.Array: + T = t.shape[0] + B = t.shape[1] + a = self.a + b = self.b + a = -jnp.abs(a).reshape((1, 1, self.trace_size, 1)) + b = b.reshape(1, 1, 1, self.context_size) + ab = jax.lax.complex(a, b) + return ab * t.reshape(T, B, 1, 1) + + def gamma(self, t: jax.Array) -> jax.Array: + return jnp.exp(self.log_gamma(t)) + + def unwrapped_associative_update( + self, + carry: Tuple[jax.Array, jax.Array, jax.Array], + incoming: Tuple[jax.Array, jax.Array, jax.Array], + ) -> Tuple[jax.Array, jax.Array, jax.Array]: + ( + state, + i, + ) = carry + x, j = incoming + state = state * self.gamma(j) + x + return state, j + i + + def wrapped_associative_update(self, carry, incoming): + prev_start, state, i = carry + start, x, j = incoming + # Reset all elements in the carry if we are starting a new episode + state = state * jnp.logical_not(start) + j = j * jnp.logical_not(start) + incoming = x, j + carry = (state, i) + out = self.unwrapped_associative_update(carry, incoming) + start_out = jnp.logical_or(start, prev_start) + return (start_out, *out) + + def scan( + self, + x: jax.Array, + state: jax.Array, + start: jax.Array, + ) -> jax.Array: + """Given an input and recurrent state, this will update the recurrent state. This is equivalent + to the inner-function g in the paper.""" + # x: [T, memory_size] + # memory: [1, memory_size, context_size] + T = x.shape[0] + B = x.shape[1] + timestep = jnp.ones((T + 1, B), dtype=jnp.int32).reshape(T + 1, B, 1, 1) + # Add context dim + start = start.reshape(T, B, 1, 1) + + # Now insert previous recurrent state + x = jnp.concatenate([state, x], axis=0) + start = jnp.concatenate([jnp.zeros_like(start[:1]), start], axis=0) + + # This is not executed during inference -- method will just return x if size is 1 + _, new_state, _ = jax.lax.associative_scan( + self.wrapped_associative_update, + (start, x, timestep), + axis=0, + ) + return new_state[1:] + + +def train_memorize(): + + USE_BATCH_VERSION = True # required to be true + + m = FFM(output_size=128, trace_size=64, context_size=4) + + batch_size = 16 + rem_ts = 10 + time_steps = rem_ts * 10 + obs_space = 8 + rng = jax.random.PRNGKey(0) + if USE_BATCH_VERSION: + x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space) + y = jnp.stack( + [ + jnp.repeat(x[::rem_ts, i], x.shape[0] // x[::rem_ts, i].shape[0]) + for i in range(batch_size) + ], + axis=-1, + ) + x = x.reshape(time_steps, batch_size, 1) + y = y.reshape(time_steps, batch_size, 1) + + start = jnp.zeros([time_steps, batch_size], dtype=bool).at[::rem_ts].set(True) + + s = m.initial_state(batch_size) + + params = m.init(jax.random.PRNGKey(0), x, s, start) + + def error(params, x, start, key): + s = m.initial_state(batch_size) + + # For BATCH VERSION + if USE_BATCH_VERSION: + x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space) + y = jnp.stack( + [ + jnp.repeat(x[::rem_ts, i], x.shape[0] // x[::rem_ts, i].shape[0]) + for i in range(batch_size) + ], + axis=-1, + ) + x = x.reshape(time_steps, batch_size, 1) + y = y.reshape(time_steps, batch_size, 1) + + y_hat, final_state = m.apply(params, x, s, start) + y_hat = jnp.squeeze(y_hat) + y = jnp.squeeze(y) + accuracy = (jnp.round(y_hat) == y).mean() + loss = jnp.mean(jnp.abs(y - y_hat) ** 2) + return loss, {"accuracy": accuracy, "loss": loss} + + optimizer = optax.adam(learning_rate=0.001) + state = optimizer.init(params) + loss_fn = jax.jit(jax.grad(error, has_aux=True)) + for step in range(10_000): + rng = jax.random.split(rng)[0] + grads, loss_info = loss_fn(params, x, start, rng) + updates, state = jax.jit(optimizer.update)(grads, state) + params = jax.jit(optax.apply_updates)(params, updates) + print(f"Step {step+1}, Loss: {loss_info['loss']}, Accuracy: {loss_info['accuracy']}") + + +if __name__ == "__main__": + # m = FFM( + # output_size=4, + # trace_size=5, + # context_size=6, + # ) + # s = m.initial_state() + # x = jnp.ones((10, 2)) + # start = jnp.zeros(10, dtype=bool) + # params = m.init(jax.random.PRNGKey(0), x, s, start) + # out = m.apply(params, x, s, start) + + # print(out) + + train_memorize() From 6d93baa20ae1282b53769ee68373f0576b4fb768 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Tue, 25 Jun 2024 22:54:25 +0000 Subject: [PATCH 27/38] chore: edit working demo --- stoix/configs/logger/base_logger.yaml | 4 +- stoix/configs/network/memoroid.yaml | 40 +- stoix/configs/system/rec_ppo.yaml | 4 +- stoix/networks/memoroid.py | 814 +++++++++++++------------- stoix/networks/working_demo.py | 242 -------- stoix/networks/working_demov2.py | 146 +++-- 6 files changed, 524 insertions(+), 726 deletions(-) delete mode 100644 stoix/networks/working_demo.py diff --git a/stoix/configs/logger/base_logger.yaml b/stoix/configs/logger/base_logger.yaml index 1b589707..8ee91dfe 100644 --- a/stoix/configs/logger/base_logger.yaml +++ b/stoix/configs/logger/base_logger.yaml @@ -4,12 +4,12 @@ base_exp_path: results # Base path for logging. use_console: True # Whether to log to stdout. use_tb: False # Whether to use tensorboard logging. use_json: False # Whether to log marl-eval style to json files. -use_neptune: False # Whether to log to neptune.ai. +use_neptune: True # Whether to log to neptune.ai. use_wandb: False # Whether to log to wandb.ai. # --- Other logger kwargs --- kwargs: - project: ~ # Project name in neptune.ai or wandb.ai. + project: e.toledo/Stoix # Project name in neptune.ai or wandb.ai. tags: [stoix] # Tags to add to the experiment. detailed_logging: False # having mean/std/min/max can clutter neptune/wandb so we make it optional json_path: ~ # If set, json files will be logged to a set path so that multiple experiments can diff --git a/stoix/configs/network/memoroid.yaml b/stoix/configs/network/memoroid.yaml index 0871725f..24d9a407 100644 --- a/stoix/configs/network/memoroid.yaml +++ b/stoix/configs/network/memoroid.yaml @@ -3,22 +3,18 @@ actor_network: pre_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [128] - use_layer_norm: True + layer_sizes: [256] + use_layer_norm: False activation: leaky_relu rnn_layer: - _target_: stoix.networks.memoroid.ScannedMemoroid - cell: - _target_: stoix.networks.memoroid.MemoroidResetWrapper - cell: - _target_: stoix.networks.memoroid.FFMCell - trace_size: 32 - context_size: 4 - output_size: 128 + _target_: stoix.networks.working_demov2.FFM + trace_size: 16 + context_size: 16 + output_size: 256 post_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [128, 128] - use_layer_norm: True + layer_sizes: [256] + use_layer_norm: False activation: leaky_relu action_head: _target_: stoix.networks.heads.CategoricalHead @@ -26,22 +22,18 @@ actor_network: critic_network: pre_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [128] - use_layer_norm: True + layer_sizes: [256] + use_layer_norm: False activation: leaky_relu rnn_layer: - _target_: stoix.networks.memoroid.ScannedMemoroid - cell: - _target_: stoix.networks.memoroid.MemoroidResetWrapper - cell: - _target_: stoix.networks.memoroid.FFMCell - trace_size: 32 - context_size: 4 - output_size: 128 + _target_: stoix.networks.working_demov2.FFM + trace_size: 16 + context_size: 16 + output_size: 256 post_torso: _target_: stoix.networks.torso.MLPTorso - layer_sizes: [128, 128] - use_layer_norm: True + layer_sizes: [256] + use_layer_norm: False activation: leaky_relu critic_head: _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/configs/system/rec_ppo.yaml b/stoix/configs/system/rec_ppo.yaml index db70ff80..cde26400 100644 --- a/stoix/configs/system/rec_ppo.yaml +++ b/stoix/configs/system/rec_ppo.yaml @@ -3,8 +3,8 @@ system_name: rec_ppo # Name of the system. # --- RL hyperparameters --- -actor_lr: 5e-5 # Learning rate for actor network -critic_lr: 5e-5 # Learning rate for critic network +actor_lr: 1e-4 # Learning rate for actor network +critic_lr: 1e-4 # Learning rate for critic network rollout_length: 64 # Number of environment steps per vectorised environment. epochs: 10 # Number of ppo epochs per training data batch. num_minibatches: 32 # Number of minibatches per ppo epoch. diff --git a/stoix/networks/memoroid.py b/stoix/networks/memoroid.py index 1f5d082a..852c5d00 100644 --- a/stoix/networks/memoroid.py +++ b/stoix/networks/memoroid.py @@ -1,418 +1,420 @@ -from functools import partial -from typing import Optional, Tuple +# CURRENTLY NOT BEING USED -import chex -import flax.linen as nn -import jax -import jax.numpy as jnp -import optax +# from functools import partial +# from typing import Optional, Tuple -# Typing aliases -Carry = chex.ArrayTree +# import chex +# import flax.linen as nn +# import jax +# import jax.numpy as jnp +# import optax -HiddenState = chex.Array -Timestep = chex.Array -Reset = chex.Array +# # Typing aliases +# Carry = chex.ArrayTree -RecurrentState = Tuple[HiddenState, Timestep] +# HiddenState = chex.Array +# Timestep = chex.Array +# Reset = chex.Array -InputEmbedding = chex.Array -Inputs = Tuple[InputEmbedding, Reset] +# RecurrentState = Tuple[HiddenState, Timestep] +# InputEmbedding = chex.Array +# Inputs = Tuple[InputEmbedding, Reset] -def debug_shape(x): - return jax.tree.map(lambda x: x.shape, x) +# def debug_shape(x): +# return jax.tree.map(lambda x: x.shape, x) -class MemoroidCellBase(nn.Module): - """Memoroid cell base class.""" - def map_to_h(self, inputs: Inputs) -> RecurrentState: - """Map from the input space to the recurrent state space""" - raise NotImplementedError +# class MemoroidCellBase(nn.Module): +# """Memoroid cell base class.""" - def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> HiddenState: - """Map from the recurrent space to the Markov space""" - raise NotImplementedError +# def map_to_h(self, inputs: Inputs) -> RecurrentState: +# """Map from the input space to the recurrent state space""" +# raise NotImplementedError - @nn.nowrap - def initialize_carry( - self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None - ) -> RecurrentState: - """Initialize the Memoroid cell carry. - - Args: - batch_size: the batch size of the carry. - rng: random number generator passed to the init_fn. - - Returns: - An initialized carry for the given RNN cell. - """ - raise NotImplementedError - - @property - def num_feature_axes(self) -> int: - """Returns the number of feature axes of the cell.""" - raise NotImplementedError - - -def recurrent_associative_scan( - cell: nn.Module, - state: RecurrentState, - inputs: RecurrentState, - axis: int = 0, -) -> RecurrentState: - """Execute the associative scan to update the recurrent state. - - Note that we do a trick here by concatenating the previous state to the inputs. - This is allowed since the scan is associative. This ensures that the previous - recurrent state feeds information into the scan. Without this method, we need - separate methods for rollouts and training.""" - - # Concatenate the previous state to the inputs and scan over the result - # This ensures the previous recurrent state contributes to the current batch - - scan_inputs = jax.tree.map(lambda s, x: jnp.concatenate([s, x], axis=axis), state, inputs) - new_state = jax.lax.associative_scan( - cell, - scan_inputs, - axis=axis, - ) - - # The zeroth index corresponds to the previous recurrent state - # We just use it to ensure continuity - # We do not actually want to use these values, so slice them away - return jax.tree.map( - lambda x: jax.lax.slice_in_dim(x, start_index=1, limit_index=None, axis=axis), new_state - ) - - -class Gate(nn.Module): - """Sigmoidal gating""" - - output_size: int - - @nn.compact - def __call__(self, x): - x = nn.Dense(self.output_size)(x) - x = nn.sigmoid(x) - return x - - -def init_deterministic( - memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1000 -) -> Tuple[chex.Array, chex.Array]: - """Deterministic initialization of the FFM parameters.""" - a_low = 1e-6 - a_high = 0.5 - a = jnp.linspace(a_low, a_high, memory_size) - b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) - return a, b - - -class FFMCell(MemoroidCellBase): - """The binary associative update function for the FFM.""" - - trace_size: int - context_size: int - output_size: int - - def setup(self): - - # Create the parameters that are explicitly used in the cells core computation - a, b = init_deterministic(self.trace_size, self.context_size) - self.params = (self.param("ffa_a", lambda rng: a), self.param("ffa_b", lambda rng: b)) - - # Create the networks and parameters that are used when - # mapping from input space to recurrent state space - # This is used in the map_to_h method and is used in the - # associative scan outer loop - self.pre = nn.Dense(self.trace_size) - self.gate_in = Gate(self.trace_size) - self.gate_out = Gate(self.output_size) - self.skip = nn.Dense(self.output_size) - self.mix = nn.Dense(self.output_size) - self.ln = nn.LayerNorm(use_scale=False, use_bias=False) - - def map_to_h(self, x: InputEmbedding) -> RecurrentState: - """Map from the input space to the recurrent state space - unlike the call function - this explicitly expects a shape including the sequence dimension. This is used in the - outer network that uses the associative scan.""" - gate_in = self.gate_in(x) - pre = self.pre(x) - gated_x = pre * gate_in - # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous - ts = jnp.ones(x.shape[0:2], dtype=jnp.int32) - z = jnp.repeat(jnp.expand_dims(gated_x, 3), self.context_size, axis=3) - return (z, ts) - - def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> HiddenState: - """Map from the recurrent space to the Markov space""" - (state, _), _ = recurrent_state - z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape( - state.shape[0], state.shape[1], -1 - ) - z = self.mix(z_in) - gate_out = self.gate_out(x) - skip = self.skip(x) - out = self.ln(z * gate_out) + skip * (1 - gate_out) - return out - - def log_gamma(self, t: chex.Array) -> chex.Array: - a, b = self.params - a = -jnp.abs(a).reshape((1, 1, self.trace_size, 1)) - b = b.reshape(1, 1, 1, self.context_size) - ab = jax.lax.complex(a, b) - return ab * t.reshape(t.shape[0], t.shape[1], 1, 1) - - def gamma(self, t: chex.Array) -> chex.Array: - return jnp.exp(self.log_gamma(t)) - - @nn.nowrap - def initialize_carry( - self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None - ) -> RecurrentState: - # inputs should be of shape [*batch, time, feature] - # recurrent states should be of shape [*batch, 1, feature] - carry_shape = (1, self.trace_size, self.context_size) - t_shape = (1,) - if batch_size is not None: - carry_shape = (carry_shape[0], batch_size, *carry_shape[1:]) - t_shape = (*t_shape, batch_size) - - return jnp.zeros(carry_shape, dtype=jnp.complex64), jnp.zeros(t_shape, dtype=jnp.int32) - - def __call__(self, carry: RecurrentState, incoming): - ( - state, - i, - ) = carry - x, j = incoming - state = state * self.gamma(j) + x - return state, j + i - - -class MemoroidResetWrapper(MemoroidCellBase): - """A wrapper around memoroid cells like FFM, LRU, etc that resets - the recurrent state upon a reset signal.""" - - cell: nn.Module - - def __call__(self, carry, incoming, rng=None): - states, prev_carry_reset_flag = carry - xs, start = incoming - - def reset_state(start: Reset, current_state, initial_state): - # Expand to reset all dims of state: [1, B, 1, ...] - assert initial_state.ndim == current_state.ndim - expanded_start = start.reshape(-1, start.shape[1], *([1] * (current_state.ndim - 2))) - out = current_state * jnp.logical_not(expanded_start) + initial_state - return out - - # Add an extra dim, as start will be [Batch] while intialize carry expects [Batch, Feature] - initial_states = self.cell.initialize_carry(rng=rng, batch_size=start.shape[1]) - states = jax.tree.map(partial(reset_state, start), states, initial_states) - out = self.cell(states, xs) - carry_reset_flag = jnp.logical_or(start, prev_carry_reset_flag) - - return out, carry_reset_flag - - def map_to_h(self, x: InputEmbedding) -> RecurrentState: - return self.cell.map_to_h(x) - - def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> HiddenState: - return self.cell.map_from_h(recurrent_state, x) - - @nn.nowrap - def initialize_carry( - self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None - ) -> RecurrentState: - return self.cell.initialize_carry(batch_size, rng), jnp.zeros((1, batch_size), dtype=bool) - - -class ScannedMemoroid(nn.Module): - cell: nn.Module - - @nn.compact - def __call__( - self, recurrent_state: RecurrentState, inputs: Inputs - ) -> Tuple[RecurrentState, HiddenState]: - """Apply the ScannedMemoroid. - This takes in a sequence of batched states and inputs. - The recurrent state that is used requires no sequence dimension but does require a batch dimension.""" - # Recurrent state should be (state, timestep) - # Inputs should be (x, reset) - - # Unsqueeze the recurrent state to add the sequence dimension of size 1 - recurrent_state = jax.tree.map(lambda x: jnp.expand_dims(x, 0), recurrent_state) - - x, resets = inputs - h = self.cell.map_to_h(x) - # TODO: In the original implementation, the recurrent timestep is also one - # recurrent_state = ( - # (recurrent_state[0][0], - # jnp.ones_like(recurrent_state[0][1])), - # recurrent_state[1] - # ) - recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, (h, resets)) - # recurrent_state is (state, timestep) - out = self.cell.map_from_h(recurrent_state, x) - - # TODO: Remove this when we want to return all recurrent states instead of just the last one - final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) - - # Squeeze the sequence dimension of 1 out - final_recurrent_state = jax.tree.map(lambda x: jnp.squeeze(x, 0), final_recurrent_state) - - return final_recurrent_state, out - - @nn.nowrap - def initialize_carry( - self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None - ) -> RecurrentState: - """Initialize the carry for the ScannedMemoroid. This returns the carry in the shape [Batch, ...] i.e. it contains no sequence dimension""" - # We squeeze the sequence dim of 1 out. - return jax.tree.map(lambda x: x.squeeze(0), self.cell.initialize_carry(batch_size, rng)) - - -def test_reset_wrapper(): - """Validate that the reset wrapper works as expected""" - BatchFFM = ScannedMemoroid - - m = BatchFFM( - cell=MemoroidResetWrapper(cell=FFMCell(output_size=2, trace_size=2, context_size=3)) - ) - - batch_size = 4 - time_steps = 100 - # Have a batched version with one episode per batch - # and collapse it into a single episode with a single batch (but same start/resets) - # results should be identical - batched_starts = jnp.ones([batch_size], dtype=bool) - # batched_starts = jnp.concatenate([ - # jnp.zeros([time_steps // 2, batch_size], dtype=bool), - # batched_starts.reshape(1, -1), - # jnp.zeros([time_steps // 2 - 1, batch_size], dtype=bool) - # ], axis=0) - batched_starts = jax.random.uniform(jax.random.PRNGKey(0), (time_steps, batch_size)) < 0.1 - contig_starts = jnp.swapaxes(batched_starts, 1, 0).reshape(-1, 1) - - x_batched = jnp.arange(time_steps * batch_size * 2).reshape((time_steps, batch_size, 2)) - x_contig = jnp.swapaxes(x_batched, 1, 0).reshape(-1, 1, 2) - batched_s = m.initialize_carry(batch_size) - contig_s = m.initialize_carry(1) - params = m.init(jax.random.PRNGKey(0), batched_s, (x_batched, batched_starts)) - - ((batched_out_state, batched_ts), batched_reset), batched_out = m.apply( - params, batched_s, (x_batched, batched_starts) - ) - ((contig_out_state, contig_ts), contig_reset), contig_out = m.apply( - params, contig_s, (x_contig, contig_starts) - ) - - # This should be nearly zero (1e-10 or something) - state_error = jnp.linalg.norm(contig_out_state - batched_out_state[-1], axis=-1).sum() - print("state error", state_error) - out_error = jnp.linalg.norm( - batched_out - jnp.swapaxes(contig_out.reshape(batch_size, time_steps, -1), 1, 0), axis=-1 - ).sum() - print("out error", out_error) - print(batched_ts, contig_ts) - - -def test_reset_wrapper_ts(): - BatchFFM = ScannedMemoroid - - m = BatchFFM( - cell=MemoroidResetWrapper(cell=FFMCell(output_size=2, trace_size=2, context_size=3)) - ) - - batch_size = 2 - time_steps = 10 - # Have a batched version with one episode per batch - # and collapse it into a single episode with a single batch (but same start/resets) - # results should be identical - batched_starts = jnp.array( - [ - [False, False, True, False, False, True, True, False, False, False], - [False, False, True, False, False, True, True, False, False, False], - ] - ).T - - x_batched = ( - jnp.arange(time_steps * batch_size * 2) - .reshape((time_steps, batch_size, 2)) - .astype(jnp.float32) - ) - batched_s = m.initialize_carry(batch_size) - params = m.init(jax.random.PRNGKey(0), batched_s, (x_batched, batched_starts)) - - ((batched_out_state, batched_ts), batched_reset), batched_out = m.apply( - params, batched_s, (x_batched, batched_starts) - ) - print(batched_ts == 4) - - -def train_memorize(): - BatchFFM = ScannedMemoroid - - m = BatchFFM( - cell=MemoroidResetWrapper(cell=FFMCell(output_size=128, trace_size=32, context_size=4)) - ) - - batch_size = 1 - rem_ts = 10 - time_steps = rem_ts * 10 - obs_space = 2 - rng = jax.random.PRNGKey(0) - x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space).reshape(-1, 1, 1) - y = jnp.repeat(x[::rem_ts], x.shape[0] // x[::rem_ts].shape[0]).reshape(-1, 1) - start = jnp.zeros([time_steps, batch_size], dtype=bool).at[::rem_ts].set(True) - # start = jnp.zeros([time_steps, batch_size], dtype=bool) - # start = jnp.ones([time_steps, batch_size], dtype=bool) - - s = m.initialize_carry(batch_size) - params = m.init(jax.random.PRNGKey(0), s, (x, start)) - - def error(params, x, start, key): - s = m.initialize_carry(batch_size) - x = jax.random.randint(key, (time_steps, batch_size), 0, obs_space).reshape(-1, 1, 1) - y = jnp.repeat(x[::rem_ts], x.shape[0] // x[::rem_ts].shape[0]).reshape(-1, 1) - out_state, y_hat = m.apply(params, s, (x, start)) - return jnp.mean((y - y_hat) ** 2) - - optimizer = optax.adam(learning_rate=0.002) - state = optimizer.init(params) - loss_fn = jax.jit(jax.value_and_grad(error)) - for step in range(10_000): - rng = jax.random.split(rng)[0] - loss, grads = loss_fn(params, x, start, rng) - updates, state = optimizer.update(grads, state) - params = optax.apply_updates(params, updates) - print(f"Step {step+1}, Loss: {loss}") - - -if __name__ == "__main__": - # BatchFFM = ScannedMemoroid - - # m = BatchFFM( - # cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)) - # ) - - # batch_size = 8 - # time_steps = 10 - - # y = jnp.ones((time_steps, batch_size, 2)) - # s = m.initialize_carry(batch_size) - # start = jnp.zeros((time_steps, batch_size), dtype=bool) - # params = m.init(jax.random.PRNGKey(0), s, (y, start)) - # out_state, out = m.apply(params, s, (y, start)) - - # out = jnp.swapaxes(out, 0, 1) - - # print(out) - # print(debug_shape(out_state)) - - # test_reset_wrapper() - # test_reset_wrapper_ts() - train_memorize() +# def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> HiddenState: +# """Map from the recurrent space to the Markov space""" +# raise NotImplementedError + +# @nn.nowrap +# def initialize_carry( +# self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None +# ) -> RecurrentState: +# """Initialize the Memoroid cell carry. + +# Args: +# batch_size: the batch size of the carry. +# rng: random number generator passed to the init_fn. + +# Returns: +# An initialized carry for the given RNN cell. +# """ +# raise NotImplementedError + +# @property +# def num_feature_axes(self) -> int: +# """Returns the number of feature axes of the cell.""" +# raise NotImplementedError + + +# def recurrent_associative_scan( +# cell: nn.Module, +# state: RecurrentState, +# inputs: RecurrentState, +# axis: int = 0, +# ) -> RecurrentState: +# """Execute the associative scan to update the recurrent state. + +# Note that we do a trick here by concatenating the previous state to the inputs. +# This is allowed since the scan is associative. This ensures that the previous +# recurrent state feeds information into the scan. Without this method, we need +# separate methods for rollouts and training.""" + +# # Concatenate the previous state to the inputs and scan over the result +# # This ensures the previous recurrent state contributes to the current batch + +# scan_inputs = jax.tree.map(lambda s, x: jnp.concatenate([s, x], axis=axis), state, inputs) +# new_state = jax.lax.associative_scan( +# cell, +# scan_inputs, +# axis=axis, +# ) + +# # The zeroth index corresponds to the previous recurrent state +# # We just use it to ensure continuity +# # We do not actually want to use these values, so slice them away +# return jax.tree.map( +# lambda x: jax.lax.slice_in_dim(x, start_index=1, limit_index=None, axis=axis), new_state +# ) + + +# class Gate(nn.Module): +# """Sigmoidal gating""" + +# output_size: int + +# @nn.compact +# def __call__(self, x): +# x = nn.Dense(self.output_size)(x) +# x = nn.sigmoid(x) +# return x + + +# def init_deterministic( +# memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1000 +# ) -> Tuple[chex.Array, chex.Array]: +# """Deterministic initialization of the FFM parameters.""" +# a_low = 1e-6 +# a_high = 0.5 +# a = jnp.linspace(a_low, a_high, memory_size) +# b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) +# return a, b + + +# class FFMCell(MemoroidCellBase): +# """The binary associative update function for the FFM.""" + +# trace_size: int +# context_size: int +# output_size: int + +# def setup(self): + +# # Create the parameters that are explicitly used in the cells core computation +# a, b = init_deterministic(self.trace_size, self.context_size) +# self.params = (self.param("ffa_a", lambda rng: a), self.param("ffa_b", lambda rng: b)) + +# # Create the networks and parameters that are used when +# # mapping from input space to recurrent state space +# # This is used in the map_to_h method and is used in the +# # associative scan outer loop +# self.pre = nn.Dense(self.trace_size) +# self.gate_in = Gate(self.trace_size) +# self.gate_out = Gate(self.output_size) +# self.skip = nn.Dense(self.output_size) +# self.mix = nn.Dense(self.output_size) +# self.ln = nn.LayerNorm(use_scale=False, use_bias=False) + +# def map_to_h(self, x: InputEmbedding) -> RecurrentState: +# """Map from the input space to the recurrent state space - unlike the call function +# this explicitly expects a shape including the sequence dimension. This is used in the +# outer network that uses the associative scan.""" +# gate_in = self.gate_in(x) +# pre = self.pre(x) +# gated_x = pre * gate_in +# # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous +# ts = jnp.ones(x.shape[0:2], dtype=jnp.int32) +# z = jnp.repeat(jnp.expand_dims(gated_x, 3), self.context_size, axis=3) +# return (z, ts) + +# def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> HiddenState: +# """Map from the recurrent space to the Markov space""" +# (state, _), _ = recurrent_state +# z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape( +# state.shape[0], state.shape[1], -1 +# ) +# z = self.mix(z_in) +# gate_out = self.gate_out(x) +# skip = self.skip(x) +# out = self.ln(z * gate_out) + skip * (1 - gate_out) +# return out + +# def log_gamma(self, t: chex.Array) -> chex.Array: +# a, b = self.params +# a = -jnp.abs(a).reshape((1, 1, self.trace_size, 1)) +# b = b.reshape(1, 1, 1, self.context_size) +# ab = jax.lax.complex(a, b) +# return ab * t.reshape(t.shape[0], t.shape[1], 1, 1) + +# def gamma(self, t: chex.Array) -> chex.Array: +# return jnp.exp(self.log_gamma(t)) + +# @nn.nowrap +# def initialize_carry( +# self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None +# ) -> RecurrentState: +# # inputs should be of shape [*batch, time, feature] +# # recurrent states should be of shape [*batch, 1, feature] +# carry_shape = (1, self.trace_size, self.context_size) +# t_shape = (1,) +# if batch_size is not None: +# carry_shape = (carry_shape[0], batch_size, *carry_shape[1:]) +# t_shape = (*t_shape, batch_size) + +# return jnp.zeros(carry_shape, dtype=jnp.complex64), jnp.zeros(t_shape, dtype=jnp.int32) + +# def __call__(self, carry: RecurrentState, incoming): +# ( +# state, +# i, +# ) = carry +# x, j = incoming +# state = state * self.gamma(j) + x +# return state, j + i + + +# class MemoroidResetWrapper(MemoroidCellBase): +# """A wrapper around memoroid cells like FFM, LRU, etc that resets +# the recurrent state upon a reset signal.""" + +# cell: nn.Module + +# def __call__(self, carry, incoming, rng=None): +# states, prev_carry_reset_flag = carry +# xs, start = incoming + +# def reset_state(start: Reset, current_state, initial_state): +# # Expand to reset all dims of state: [1, B, 1, ...] +# assert initial_state.ndim == current_state.ndim +# expanded_start = start.reshape(-1, start.shape[1], *([1] * (current_state.ndim - 2))) +# out = current_state * jnp.logical_not(expanded_start) + initial_state +# return out + +# # Add an extra dim, as start will be [Batch] while intialize carry expects [Batch, Feature] +# initial_states = self.cell.initialize_carry(rng=rng, batch_size=start.shape[1]) +# states = jax.tree.map(partial(reset_state, start), states, initial_states) +# out = self.cell(states, xs) +# carry_reset_flag = jnp.logical_or(start, prev_carry_reset_flag) + +# return out, carry_reset_flag + +# def map_to_h(self, x: InputEmbedding) -> RecurrentState: +# return self.cell.map_to_h(x) + +# def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> HiddenState: +# return self.cell.map_from_h(recurrent_state, x) + +# @nn.nowrap +# def initialize_carry( +# self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None +# ) -> RecurrentState: +# return self.cell.initialize_carry(batch_size, rng), jnp.zeros((1, batch_size), dtype=bool) + + +# class ScannedMemoroid(nn.Module): +# cell: nn.Module + +# @nn.compact +# def __call__( +# self, recurrent_state: RecurrentState, inputs: Inputs +# ) -> Tuple[RecurrentState, HiddenState]: +# """Apply the ScannedMemoroid. +# This takes in a sequence of batched states and inputs. +# The recurrent state that is used requires no sequence dimension but does require a batch dimension.""" +# # Recurrent state should be (state, timestep) +# # Inputs should be (x, reset) + +# # Unsqueeze the recurrent state to add the sequence dimension of size 1 +# recurrent_state = jax.tree.map(lambda x: jnp.expand_dims(x, 0), recurrent_state) + +# x, resets = inputs +# h = self.cell.map_to_h(x) +# # TODO: In the original implementation, the recurrent timestep is also one +# # recurrent_state = ( +# # (recurrent_state[0][0], +# # jnp.ones_like(recurrent_state[0][1])), +# # recurrent_state[1] +# # ) +# recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, (h, resets)) +# # recurrent_state is (state, timestep) +# out = self.cell.map_from_h(recurrent_state, x) + +# # TODO: Remove this when we want to return all recurrent states instead of just the last one +# final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) + +# # Squeeze the sequence dimension of 1 out +# final_recurrent_state = jax.tree.map(lambda x: jnp.squeeze(x, 0), final_recurrent_state) + +# return final_recurrent_state, out + +# @nn.nowrap +# def initialize_carry( +# self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None +# ) -> RecurrentState: +# """Initialize the carry for the ScannedMemoroid. This returns the carry in the shape [Batch, ...] i.e. it contains no sequence dimension""" +# # We squeeze the sequence dim of 1 out. +# return jax.tree.map(lambda x: x.squeeze(0), self.cell.initialize_carry(batch_size, rng)) + + +# def test_reset_wrapper(): +# """Validate that the reset wrapper works as expected""" +# BatchFFM = ScannedMemoroid + +# m = BatchFFM( +# cell=MemoroidResetWrapper(cell=FFMCell(output_size=2, trace_size=2, context_size=3)) +# ) + +# batch_size = 4 +# time_steps = 100 +# # Have a batched version with one episode per batch +# # and collapse it into a single episode with a single batch (but same start/resets) +# # results should be identical +# batched_starts = jnp.ones([batch_size], dtype=bool) +# # batched_starts = jnp.concatenate([ +# # jnp.zeros([time_steps // 2, batch_size], dtype=bool), +# # batched_starts.reshape(1, -1), +# # jnp.zeros([time_steps // 2 - 1, batch_size], dtype=bool) +# # ], axis=0) +# batched_starts = jax.random.uniform(jax.random.PRNGKey(0), (time_steps, batch_size)) < 0.1 +# contig_starts = jnp.swapaxes(batched_starts, 1, 0).reshape(-1, 1) + +# x_batched = jnp.arange(time_steps * batch_size * 2).reshape((time_steps, batch_size, 2)) +# x_contig = jnp.swapaxes(x_batched, 1, 0).reshape(-1, 1, 2) +# batched_s = m.initialize_carry(batch_size) +# contig_s = m.initialize_carry(1) +# params = m.init(jax.random.PRNGKey(0), batched_s, (x_batched, batched_starts)) + +# ((batched_out_state, batched_ts), batched_reset), batched_out = m.apply( +# params, batched_s, (x_batched, batched_starts) +# ) +# ((contig_out_state, contig_ts), contig_reset), contig_out = m.apply( +# params, contig_s, (x_contig, contig_starts) +# ) + +# # This should be nearly zero (1e-10 or something) +# state_error = jnp.linalg.norm(contig_out_state - batched_out_state[-1], axis=-1).sum() +# print("state error", state_error) +# out_error = jnp.linalg.norm( +# batched_out - jnp.swapaxes(contig_out.reshape(batch_size, time_steps, -1), 1, 0), axis=-1 +# ).sum() +# print("out error", out_error) +# print(batched_ts, contig_ts) + + +# def test_reset_wrapper_ts(): +# BatchFFM = ScannedMemoroid + +# m = BatchFFM( +# cell=MemoroidResetWrapper(cell=FFMCell(output_size=2, trace_size=2, context_size=3)) +# ) + +# batch_size = 2 +# time_steps = 10 +# # Have a batched version with one episode per batch +# # and collapse it into a single episode with a single batch (but same start/resets) +# # results should be identical +# batched_starts = jnp.array( +# [ +# [False, False, True, False, False, True, True, False, False, False], +# [False, False, True, False, False, True, True, False, False, False], +# ] +# ).T + +# x_batched = ( +# jnp.arange(time_steps * batch_size * 2) +# .reshape((time_steps, batch_size, 2)) +# .astype(jnp.float32) +# ) +# batched_s = m.initialize_carry(batch_size) +# params = m.init(jax.random.PRNGKey(0), batched_s, (x_batched, batched_starts)) + +# ((batched_out_state, batched_ts), batched_reset), batched_out = m.apply( +# params, batched_s, (x_batched, batched_starts) +# ) +# print(batched_ts == 4) + + +# def train_memorize(): +# BatchFFM = ScannedMemoroid + +# m = BatchFFM( +# cell=MemoroidResetWrapper(cell=FFMCell(output_size=128, trace_size=32, context_size=4)) +# ) + +# batch_size = 1 +# rem_ts = 10 +# time_steps = rem_ts * 10 +# obs_space = 2 +# rng = jax.random.PRNGKey(0) +# x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space).reshape(-1, 1, 1) +# y = jnp.repeat(x[::rem_ts], x.shape[0] // x[::rem_ts].shape[0]).reshape(-1, 1) +# start = jnp.zeros([time_steps, batch_size], dtype=bool).at[::rem_ts].set(True) +# # start = jnp.zeros([time_steps, batch_size], dtype=bool) +# # start = jnp.ones([time_steps, batch_size], dtype=bool) + +# s = m.initialize_carry(batch_size) +# params = m.init(jax.random.PRNGKey(0), s, (x, start)) + +# def error(params, x, start, key): +# s = m.initialize_carry(batch_size) +# x = jax.random.randint(key, (time_steps, batch_size), 0, obs_space).reshape(-1, 1, 1) +# y = jnp.repeat(x[::rem_ts], x.shape[0] // x[::rem_ts].shape[0]).reshape(-1, 1) +# out_state, y_hat = m.apply(params, s, (x, start)) +# return jnp.mean((y - y_hat) ** 2) + +# optimizer = optax.adam(learning_rate=0.002) +# state = optimizer.init(params) +# loss_fn = jax.jit(jax.value_and_grad(error)) +# for step in range(10_000): +# rng = jax.random.split(rng)[0] +# loss, grads = loss_fn(params, x, start, rng) +# updates, state = optimizer.update(grads, state) +# params = optax.apply_updates(params, updates) +# print(f"Step {step+1}, Loss: {loss}") + + +# if __name__ == "__main__": +# # BatchFFM = ScannedMemoroid + +# # m = BatchFFM( +# # cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)) +# # ) + +# # batch_size = 8 +# # time_steps = 10 + +# # y = jnp.ones((time_steps, batch_size, 2)) +# # s = m.initialize_carry(batch_size) +# # start = jnp.zeros((time_steps, batch_size), dtype=bool) +# # params = m.init(jax.random.PRNGKey(0), s, (y, start)) +# # out_state, out = m.apply(params, s, (y, start)) + +# # out = jnp.swapaxes(out, 0, 1) + +# # print(out) +# # print(debug_shape(out_state)) + +# # test_reset_wrapper() +# # test_reset_wrapper_ts() +# train_memorize() diff --git a/stoix/networks/working_demo.py b/stoix/networks/working_demo.py deleted file mode 100644 index 6a8f170f..00000000 --- a/stoix/networks/working_demo.py +++ /dev/null @@ -1,242 +0,0 @@ -from functools import partial -from typing import Any, Dict, Tuple - -import chex -import jax -import optax -from flax import linen as nn -from jax import numpy as jnp -from jax import vmap - - -def init_deterministic( - memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 -) -> Tuple[jax.Array, jax.Array]: - a_low = 1e-6 - a_high = 0.5 - a = jnp.linspace(a_low, a_high, memory_size) - b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) - return a, b - - -class Gate(nn.Module): - output_size: int - - @nn.compact - def __call__(self, x: chex.Array) -> chex.Array: - return jax.nn.sigmoid(nn.Dense(self.output_size)(x)) - - -class FFM(nn.Module): - trace_size: int - context_size: int - output_size: int - - def setup(self) -> None: - self.a = self.param( - "ffm_a", - lambda key, shape: init_deterministic(self.trace_size, self.context_size)[0], - (), - ) - self.b = self.param( - "ffm_b", - lambda key, shape: init_deterministic(self.trace_size, self.context_size)[1], - (), - ) - - @nn.compact - def __call__( - self, x: jax.Array, state: jax.Array, start: jax.Array - ) -> Tuple[jax.Array, jax.Array]: - - x = nn.Dense(64)(x) - x = nn.relu(x) - x = nn.Dense(self.trace_size * 2)(x) - - gate_in = Gate(self.trace_size)(x) - pre = Gate(self.trace_size)(x) - gated_x = pre * gate_in - scan_input = jnp.repeat(jnp.expand_dims(gated_x, 2), self.context_size, axis=2) - state = self.scan(scan_input, state, start) - z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape( - state.shape[0], -1 - ) - z = nn.Dense(64)(z_in) - gate_out = Gate(64)(x) - skip = nn.Dense(64)(x) - out = nn.LayerNorm(use_scale=False, use_bias=False)(z * gate_out) + skip * (1 - gate_out) - final_state = state[-1:] - - out = nn.Dense(64)(out) - out = nn.relu(out) - out = nn.Dense(self.output_size)(out) - - return out, final_state - - def initial_state(self) -> jax.Array: - return jnp.zeros((1, self.trace_size, self.context_size), dtype=jnp.complex64) - - def log_gamma(self, t: jax.Array) -> jax.Array: - a = self.a - b = self.b - a = -jnp.abs(a).reshape((1, self.trace_size, 1)) - b = b.reshape(1, 1, self.context_size) - ab = jax.lax.complex(a, b) - return ab * t.reshape(t.shape[0], 1, 1) - - def gamma(self, t: jax.Array) -> jax.Array: - return jnp.exp(self.log_gamma(t)) - - def unwrapped_associative_update( - self, - carry: Tuple[jax.Array, jax.Array, jax.Array], - incoming: Tuple[jax.Array, jax.Array, jax.Array], - ) -> Tuple[jax.Array, jax.Array, jax.Array]: - ( - state, - i, - ) = carry - x, j = incoming - state = state * self.gamma(j) + x - return state, j + i - - def wrapped_associative_update(self, carry, incoming): - prev_start, state, i = carry - start, x, j = incoming - # Reset all elements in the carry if we are starting a new episode - state = state * jnp.logical_not(start) - j = j * jnp.logical_not(start) - incoming = x, j - carry = (state, i) - out = self.unwrapped_associative_update(carry, incoming) - start_out = jnp.logical_or(start, prev_start) - return (start_out, *out) - - def scan( - self, - x: jax.Array, - state: jax.Array, - start: jax.Array, - ) -> jax.Array: - """Given an input and recurrent state, this will update the recurrent state. This is equivalent - to the inner-function g in the paper.""" - # x: [T, memory_size] - # memory: [1, memory_size, context_size] - T = x.shape[0] - # timestep = jnp.arange(T + 1, dtype=jnp.int32) - timestep = jnp.ones(T + 1, dtype=jnp.int32).reshape(-1, 1, 1) - # Add context dim - start = start.reshape(T, 1, 1) - - # Now insert previous recurrent state - x = jnp.concatenate([state, x], axis=0) - start = jnp.concatenate([jnp.zeros_like(start[:1]), start], axis=0) - - # This is not executed during inference -- method will just return x if size is 1 - _, new_state, _ = jax.lax.associative_scan( - self.wrapped_associative_update, - (start, x, timestep), - axis=0, - ) - return new_state[1:] - - -def train_memorize(): - - USE_BATCH_VERSION = True - - if USE_BATCH_VERSION: - - m = nn.vmap( - FFM, in_axes=1, out_axes=1, variable_axes={"params": None}, split_rngs={"params": None} - )(output_size=1, trace_size=64, context_size=4) - else: - m = FFM(output_size=1, trace_size=64, context_size=4) - - batch_size = 16 - rem_ts = 10 - time_steps = rem_ts * 10 - obs_space = 8 - rng = jax.random.PRNGKey(0) - if USE_BATCH_VERSION: - x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space) - y = jnp.stack( - [ - jnp.repeat(x[::rem_ts, i], x.shape[0] // x[::rem_ts, i].shape[0]) - for i in range(batch_size) - ], - axis=-1, - ) - x = x.reshape(time_steps, batch_size, 1) - y = y.reshape(time_steps, batch_size, 1) - - else: - x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space).reshape(-1, 1) - y = jnp.repeat(x[::rem_ts], x.shape[0] // x[::rem_ts].shape[0]).reshape(-1, 1) - - start = jnp.zeros([time_steps, batch_size], dtype=bool).at[::rem_ts].set(True) - - s = m.initial_state() - - # FOR BATCH VERSION - if USE_BATCH_VERSION: - s = jnp.expand_dims(s, 1) - s = jnp.repeat(s, batch_size, axis=1) - params = m.init(jax.random.PRNGKey(0), x, s, start) - - def error(params, x, start, key): - s = m.initial_state() - - if USE_BATCH_VERSION: - s = jnp.expand_dims(s, 1) - s = jnp.repeat(s, batch_size, axis=1) - - # For BATCH VERSION - if USE_BATCH_VERSION: - x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space) - y = jnp.stack( - [ - jnp.repeat(x[::rem_ts, i], x.shape[0] // x[::rem_ts, i].shape[0]) - for i in range(batch_size) - ], - axis=-1, - ) - x = x.reshape(time_steps, batch_size, 1) - y = y.reshape(time_steps, batch_size, 1) - else: - x = jax.random.randint(key, (time_steps, batch_size), 0, obs_space).reshape(-1, 1) - y = jnp.repeat(x[::rem_ts], x.shape[0] // x[::rem_ts].shape[0]).reshape(-1, 1) - - y_hat, final_state = m.apply(params, x, s, start) - y_hat = jnp.squeeze(y_hat) - y = jnp.squeeze(y) - accuracy = (jnp.round(y_hat) == y).mean() - loss = jnp.mean(jnp.abs(y - y_hat) ** 2) - return loss, {"accuracy": accuracy, "loss": loss} - - optimizer = optax.adam(learning_rate=0.001) - state = optimizer.init(params) - loss_fn = jax.jit(jax.grad(error, has_aux=True)) - for step in range(10_000): - rng = jax.random.split(rng)[0] - grads, loss_info = loss_fn(params, x, start, rng) - updates, state = jax.jit(optimizer.update)(grads, state) - params = jax.jit(optax.apply_updates)(params, updates) - print(f"Step {step+1}, Loss: {loss_info['loss']}, Accuracy: {loss_info['accuracy']}") - - -if __name__ == "__main__": - # m = FFM( - # output_size=4, - # trace_size=5, - # context_size=6, - # ) - # s = m.initial_state() - # x = jnp.ones((10, 2)) - # start = jnp.zeros(10, dtype=bool) - # params = m.init(jax.random.PRNGKey(0), x, s, start) - # out = m.apply(params, x, s, start) - - # print(out) - - train_memorize() diff --git a/stoix/networks/working_demov2.py b/stoix/networks/working_demov2.py index 25ea254f..c9b9f3f9 100644 --- a/stoix/networks/working_demov2.py +++ b/stoix/networks/working_demov2.py @@ -1,17 +1,22 @@ -from functools import partial -from typing import Any, Dict, Tuple +from typing import Tuple import chex import jax import optax from flax import linen as nn from jax import numpy as jnp -from jax import vmap + +RecurrentState = chex.Array +Reset = chex.Array +Timestep = chex.Array +InputEmbedding = chex.Array +Inputs = Tuple[InputEmbedding, Reset] +ScanInput = chex.Array def init_deterministic( memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 -) -> Tuple[jax.Array, jax.Array]: +) -> Tuple[chex.Array, chex.Array]: a_low = 1e-6 a_high = 0.5 a = jnp.linspace(a_low, a_high, memory_size) @@ -33,50 +38,51 @@ class FFM(nn.Module): output_size: int def setup(self) -> None: + + # Create the FFM parameters + a, b = init_deterministic(self.trace_size, self.context_size) self.a = self.param( "ffm_a", - lambda key, shape: init_deterministic(self.trace_size, self.context_size)[0], + lambda key, shape: a, (), ) self.b = self.param( "ffm_b", - lambda key, shape: init_deterministic(self.trace_size, self.context_size)[1], + lambda key, shape: b, (), ) - @nn.compact - def __call__( - self, x: jax.Array, state: jax.Array, start: jax.Array - ) -> Tuple[jax.Array, jax.Array]: - - x = nn.Dense(self.output_size)(x) - x = nn.relu(x) - x = nn.Dense(self.output_size)(x) - - gate_in = Gate(self.trace_size)(x) - pre = Gate(self.trace_size)(x) + # Create the networks and parameters that are used when + # mapping from input space to recurrent state space + # This is used in the map_to_h method and is used in the + # associative scan outer loop + self.pre = nn.Dense(self.trace_size) + self.gate_in = Gate(self.trace_size) + self.gate_out = Gate(self.output_size) + self.skip = nn.Dense(self.output_size) + self.mix = nn.Dense(self.output_size) + self.ln = nn.LayerNorm(use_scale=False, use_bias=False) + + def map_to_h(self, x: InputEmbedding) -> ScanInput: + """Given an input embedding, this will map it to the format required for the associative scan.""" + gate_in = self.gate_in(x) + pre = self.pre(x) gated_x = pre * gate_in scan_input = jnp.repeat(jnp.expand_dims(gated_x, 3), self.context_size, axis=3) - state = self.scan(scan_input, state, start) + return scan_input + + def map_from_h(self, state: RecurrentState, x: InputEmbedding) -> chex.Array: + """Given the recurrent state and the input embedding, this will map the recurrent state back to the output space.""" T = state.shape[0] B = state.shape[1] z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape(T, B, -1) - z = nn.Dense(self.output_size)(z_in) - gate_out = Gate(self.output_size)(x) - skip = nn.Dense(self.output_size)(x) - out = nn.LayerNorm(use_scale=False, use_bias=False)(z * gate_out) + skip * (1 - gate_out) - final_state = state[-1:] + z = self.mix(z_in) + gate_out = self.gate_out(x) + skip = self.skip(x) + out = self.ln(z * gate_out) + skip * (1 - gate_out) + return out - out = nn.Dense(self.output_size)(out) - out = nn.relu(out) - out = nn.Dense(1)(out) - - return out, final_state - - def initial_state(self, batch_size: int) -> jax.Array: - return jnp.zeros((1, batch_size, self.trace_size, self.context_size), dtype=jnp.complex64) - - def log_gamma(self, t: jax.Array) -> jax.Array: + def log_gamma(self, t: Timestep) -> chex.Array: T = t.shape[0] B = t.shape[1] a = self.a @@ -86,14 +92,14 @@ def log_gamma(self, t: jax.Array) -> jax.Array: ab = jax.lax.complex(a, b) return ab * t.reshape(T, B, 1, 1) - def gamma(self, t: jax.Array) -> jax.Array: + def gamma(self, t: Timestep) -> chex.Array: return jnp.exp(self.log_gamma(t)) def unwrapped_associative_update( self, - carry: Tuple[jax.Array, jax.Array, jax.Array], - incoming: Tuple[jax.Array, jax.Array, jax.Array], - ) -> Tuple[jax.Array, jax.Array, jax.Array]: + carry: Tuple[RecurrentState, Timestep], + incoming: Tuple[InputEmbedding, Timestep], + ) -> Tuple[RecurrentState, Timestep]: ( state, i, @@ -102,7 +108,11 @@ def unwrapped_associative_update( state = state * self.gamma(j) + x return state, j + i - def wrapped_associative_update(self, carry, incoming): + def wrapped_associative_update( + self, + carry: Tuple[Reset, RecurrentState, Timestep], + incoming: Tuple[Reset, InputEmbedding, Timestep], + ) -> Tuple[Reset, RecurrentState, Timestep]: prev_start, state, i = carry start, x, j = incoming # Reset all elements in the carry if we are starting a new episode @@ -116,14 +126,14 @@ def wrapped_associative_update(self, carry, incoming): def scan( self, - x: jax.Array, - state: jax.Array, - start: jax.Array, - ) -> jax.Array: + x: InputEmbedding, + state: RecurrentState, + start: Reset, + ) -> RecurrentState: """Given an input and recurrent state, this will update the recurrent state. This is equivalent to the inner-function g in the paper.""" - # x: [T, memory_size] - # memory: [1, memory_size, context_size] + # x: [T, B, memory_size] + # memory: [1, B, memory_size, context_size] T = x.shape[0] B = x.shape[1] timestep = jnp.ones((T + 1, B), dtype=jnp.int32).reshape(T + 1, B, 1, 1) @@ -142,10 +152,46 @@ def scan( ) return new_state[1:] + @nn.compact + def __call__(self, state: RecurrentState, inputs: Inputs) -> Tuple[RecurrentState, chex.Array]: + + # Add a sequence dimension to the recurrent state. + state = jnp.expand_dims(state, 0) + + # Unpack inputs + x, start = inputs + + # Map the input embedding to the recurrent state space. + # This maps to the format required for the associative scan. + scan_input = self.map_to_h(x) + + # Update the recurrent state + state = self.scan(scan_input, state, start) + + # Map the recurrent state back to the output space + out = self.map_from_h(state, x) + + # Take the final state of the sequence. + final_state = state[-1:] + + # TODO: remove this when not running test + out = nn.Dense(128)(out) + out = nn.relu(out) + out = nn.Dense(1)(out) + + # Remove the sequence dimemnsion from the final state. + final_state = jnp.squeeze(final_state, 0) + + return final_state, out + + @nn.nowrap + def initialize_carry(self, batch_size: int) -> RecurrentState: + return jnp.zeros((batch_size, self.trace_size, self.context_size), dtype=jnp.complex64) + def train_memorize(): - USE_BATCH_VERSION = True # required to be true + USE_BATCH_VERSION = True # Required to be true m = FFM(output_size=128, trace_size=64, context_size=4) @@ -168,12 +214,12 @@ def train_memorize(): start = jnp.zeros([time_steps, batch_size], dtype=bool).at[::rem_ts].set(True) - s = m.initial_state(batch_size) + s = m.initialize_carry(batch_size) - params = m.init(jax.random.PRNGKey(0), x, s, start) + params = m.init(jax.random.PRNGKey(0), s, (x, start)) def error(params, x, start, key): - s = m.initial_state(batch_size) + s = m.initialize_carry(batch_size) # For BATCH VERSION if USE_BATCH_VERSION: @@ -188,7 +234,7 @@ def error(params, x, start, key): x = x.reshape(time_steps, batch_size, 1) y = y.reshape(time_steps, batch_size, 1) - y_hat, final_state = m.apply(params, x, s, start) + final_state, y_hat = m.apply(params, s, (x, start)) y_hat = jnp.squeeze(y_hat) y = jnp.squeeze(y) accuracy = (jnp.round(y_hat) == y).mean() @@ -212,7 +258,7 @@ def error(params, x, start, key): # trace_size=5, # context_size=6, # ) - # s = m.initial_state() + # s = m.initialize_carry() # x = jnp.ones((10, 2)) # start = jnp.zeros(10, dtype=bool) # params = m.init(jax.random.PRNGKey(0), x, s, start) From f79febc2b5228d95b513619d9f5804519dcb6dc4 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Wed, 26 Jun 2024 14:52:29 +0000 Subject: [PATCH 28/38] fix: add required popjym wrapper --- stoix/configs/network/memoroid.yaml | 8 ++--- stoix/utils/make_env.py | 9 ++++- stoix/wrappers/transforms.py | 56 +++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 5 deletions(-) diff --git a/stoix/configs/network/memoroid.yaml b/stoix/configs/network/memoroid.yaml index 24d9a407..ea5a39a1 100644 --- a/stoix/configs/network/memoroid.yaml +++ b/stoix/configs/network/memoroid.yaml @@ -8,8 +8,8 @@ actor_network: activation: leaky_relu rnn_layer: _target_: stoix.networks.working_demov2.FFM - trace_size: 16 - context_size: 16 + trace_size: 64 + context_size: 4 output_size: 256 post_torso: _target_: stoix.networks.torso.MLPTorso @@ -27,8 +27,8 @@ critic_network: activation: leaky_relu rnn_layer: _target_: stoix.networks.working_demov2.FFM - trace_size: 16 - context_size: 16 + trace_size: 64 + context_size: 4 output_size: 256 post_torso: _target_: stoix.networks.torso.MLPTorso diff --git a/stoix/utils/make_env.py b/stoix/utils/make_env.py index bec61173..7471aca5 100644 --- a/stoix/utils/make_env.py +++ b/stoix/utils/make_env.py @@ -27,7 +27,11 @@ from stoix.wrappers.brax import BraxJumanjiWrapper from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper from stoix.wrappers.pgx import PGXWrapper -from stoix.wrappers.transforms import MultiBoundedToBounded, MultiDiscreteToDiscrete +from stoix.wrappers.transforms import ( + AddStartFlagAndPrevAction, + MultiBoundedToBounded, + MultiDiscreteToDiscrete, +) from stoix.wrappers.xminigrid import XMiniGridWrapper @@ -330,6 +334,9 @@ def make_popjym_env(env_name: str, config: DictConfig) -> Tuple[Environment, Env env = GymnaxWrapper(env, env_params) eval_env = GymnaxWrapper(eval_env, eval_env_params) + env = AddStartFlagAndPrevAction(env) + eval_env = AddStartFlagAndPrevAction(eval_env) + env = AutoResetWrapper(env, next_obs_in_extras=True) env = RecordEpisodeMetrics(env) diff --git a/stoix/wrappers/transforms.py b/stoix/wrappers/transforms.py index 9e959bb9..253c333d 100644 --- a/stoix/wrappers/transforms.py +++ b/stoix/wrappers/transforms.py @@ -1,6 +1,7 @@ from typing import Tuple import chex +import jax import jax.numpy as jnp import numpy as np from jumanji import specs @@ -107,3 +108,58 @@ def action_spec(self) -> specs.Spec: dtype=original_action_spec.dtype, name="action", ) + + +class AddStartFlagAndPrevAction(Wrapper): + """Wrapper that adds a start flag and the previous action to the observation.""" + + def __init__(self, env: Environment): + super().__init__(env) + + # Get the action dimension + if isinstance(self.action_spec(), specs.DiscreteArray): + self.action_dim = self.action_spec().num_values + self.discrete = True + else: + self.action_dim = self.action_spec().shape[0] + self.discrete = False + + # Check if the observation is flat + if not len(self.observation_spec().agent_view.shape) == 1: + raise ValueError("The observation must be flat.") + + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: + state, timestep = self._env.reset(key) + start_flag = jnp.array(1.0)[jnp.newaxis] + prev_action = jnp.zeros(self.action_dim) + agent_view = timestep.observation.agent_view + new_agent_view = jnp.concatenate([start_flag, prev_action, agent_view]) + timestep = timestep.replace( + observation=timestep.observation._replace( + agent_view=new_agent_view, + ) + ) + return state, timestep + + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: + state, timestep = self._env.step(state, action) + start_flag = jnp.array(0.0)[jnp.newaxis] + prev_action = action + if self.discrete: + prev_action = jax.nn.one_hot(prev_action, self.action_dim) + agent_view = timestep.observation.agent_view + new_agent_view = jnp.concatenate([start_flag, prev_action, agent_view]) + timestep = timestep.replace( + observation=timestep.observation._replace( + agent_view=new_agent_view, + ) + ) + return state, timestep + + def observation_spec(self) -> Spec: + return self._env.observation_spec().replace( + agent_view=Array( + shape=(1 + self.action_dim + self._env.observation_spec().agent_view.shape[0],), + dtype=jnp.float32, + ) + ) From e90e270a855a7e7abd5205b0472c7f8330fc850f Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Sun, 30 Jun 2024 22:10:19 +0000 Subject: [PATCH 29/38] chore: move all current work into memoroids file to be contained --- stoix/configs/default_rec_ppo.yaml | 4 +- stoix/configs/network/memoroid.yaml | 4 +- stoix/configs/network/s5.yaml | 39 ++ stoix/configs/system/rec_ppo.yaml | 10 +- stoix/networks/{ => memoroids}/ffm.py | 0 stoix/networks/{ => memoroids}/memoroid.py | 0 stoix/networks/memoroids/s5.py | 645 ++++++++++++++++++ .../{ => memoroids}/working_demov2.py | 0 8 files changed, 693 insertions(+), 9 deletions(-) create mode 100644 stoix/configs/network/s5.yaml rename stoix/networks/{ => memoroids}/ffm.py (100%) rename stoix/networks/{ => memoroids}/memoroid.py (100%) create mode 100644 stoix/networks/memoroids/s5.py rename stoix/networks/{ => memoroids}/working_demov2.py (100%) diff --git a/stoix/configs/default_rec_ppo.yaml b/stoix/configs/default_rec_ppo.yaml index 22e49afe..f485273d 100644 --- a/stoix/configs/default_rec_ppo.yaml +++ b/stoix/configs/default_rec_ppo.yaml @@ -2,6 +2,6 @@ defaults: - logger: base_logger - arch: anakin - system: rec_ppo - - network: rnn - - env: gymnax/cartpole + - network: memoroid + - env: popjym/repeat_first_easy - _self_ diff --git a/stoix/configs/network/memoroid.yaml b/stoix/configs/network/memoroid.yaml index ea5a39a1..c070f292 100644 --- a/stoix/configs/network/memoroid.yaml +++ b/stoix/configs/network/memoroid.yaml @@ -7,7 +7,7 @@ actor_network: use_layer_norm: False activation: leaky_relu rnn_layer: - _target_: stoix.networks.working_demov2.FFM + _target_: stoix.networks.memoroids.working_demov2.FFM trace_size: 64 context_size: 4 output_size: 256 @@ -26,7 +26,7 @@ critic_network: use_layer_norm: False activation: leaky_relu rnn_layer: - _target_: stoix.networks.working_demov2.FFM + _target_: stoix.networks.memoroids.working_demov2.FFM trace_size: 64 context_size: 4 output_size: 256 diff --git a/stoix/configs/network/s5.yaml b/stoix/configs/network/s5.yaml new file mode 100644 index 00000000..72be36e6 --- /dev/null +++ b/stoix/configs/network/s5.yaml @@ -0,0 +1,39 @@ +# ---Recurrent Structure Networks for PPO --- + +actor_network: + pre_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256] + use_layer_norm: False + activation: leaky_relu + rnn_layer: + _target_: stoix.networks.s5.StackedEncoderModel + ssm_size: 256 + d_model: 256 + n_layers: 1 + post_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256] + use_layer_norm: False + activation: leaky_relu + action_head: + _target_: stoix.networks.heads.CategoricalHead + +critic_network: + pre_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256] + use_layer_norm: False + activation: leaky_relu + rnn_layer: + _target_: stoix.networks.s5.StackedEncoderModel + ssm_size: 256 + d_model: 256 + n_layers: 1 + post_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256] + use_layer_norm: False + activation: leaky_relu + critic_head: + _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/configs/system/rec_ppo.yaml b/stoix/configs/system/rec_ppo.yaml index cde26400..15a602e0 100644 --- a/stoix/configs/system/rec_ppo.yaml +++ b/stoix/configs/system/rec_ppo.yaml @@ -3,15 +3,15 @@ system_name: rec_ppo # Name of the system. # --- RL hyperparameters --- -actor_lr: 1e-4 # Learning rate for actor network -critic_lr: 1e-4 # Learning rate for critic network -rollout_length: 64 # Number of environment steps per vectorised environment. -epochs: 10 # Number of ppo epochs per training data batch. +actor_lr: 5e-5 # Learning rate for actor network +critic_lr: 5e-5 # Learning rate for critic network +rollout_length: 128 # Number of environment steps per vectorised environment. +epochs: 15 # Number of ppo epochs per training data batch. num_minibatches: 32 # Number of minibatches per ppo epoch. gamma: 0.99 # Discounting factor. gae_lambda: 0.95 # Lambda value for GAE computation. clip_eps: 0.2 # Clipping value for PPO updates and value function. -ent_coef: 0.001 # Entropy regularisation term for loss function. +ent_coef: 0.0 # Entropy regularisation term for loss function. vf_coef: 1.0 # Critic weight in max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update. decay_learning_rates: True # Whether learning rates should be linearly decayed during training. diff --git a/stoix/networks/ffm.py b/stoix/networks/memoroids/ffm.py similarity index 100% rename from stoix/networks/ffm.py rename to stoix/networks/memoroids/ffm.py diff --git a/stoix/networks/memoroid.py b/stoix/networks/memoroids/memoroid.py similarity index 100% rename from stoix/networks/memoroid.py rename to stoix/networks/memoroids/memoroid.py diff --git a/stoix/networks/memoroids/s5.py b/stoix/networks/memoroids/s5.py new file mode 100644 index 00000000..40e3abad --- /dev/null +++ b/stoix/networks/memoroids/s5.py @@ -0,0 +1,645 @@ +from functools import partial + +import chex +import jax +import jax.numpy as np +import jax.numpy as jnp +from flax import linen as nn +from jax import random +from jax.nn.initializers import lecun_normal, normal +from jax.numpy.linalg import eigh + + +class SequenceLayer(nn.Module): + """Defines a single S5 layer, with S5 SSM, nonlinearity, etc. + Args: + ssm (nn.Module): the SSM to be used (i.e. S5 ssm) + d_model (int32): this is the feature size of the layer inputs and outputs + we usually refer to this size as H + activation (string): Type of activation function to use + prenorm (bool): apply prenorm if true or postnorm if false + step_rescale (float32): allows for uniformly changing the timescale parameter, + e.g. after training on a different resolution for + the speech commands benchmark + """ + + ssm: nn.Module + d_model: int + activation: str = "gelu" + do_norm: bool = True + prenorm: bool = True + do_gtrxl_norm: bool = True + step_rescale: float = 1.0 + + def setup(self): + """Initializes the ssm, layer norm and dense layers""" + self.seq = self.ssm(step_rescale=self.step_rescale) + + if self.activation in ["full_glu"]: + self.out1 = nn.Dense(self.d_model) + self.out2 = nn.Dense(self.d_model) + elif self.activation in ["half_glu1", "half_glu2"]: + self.out2 = nn.Dense(self.d_model) + + self.norm = nn.LayerNorm() + + def __call__(self, hidden, x, d): + """ + Compute the LxH output of S5 layer given an LxH input. + Args: + x (float32): input sequence (L, d_model) + d (bool): reset signal (L,) + Returns: + output sequence (float32): (L, d_model) + """ + skip = x + if self.prenorm and self.do_norm: + x = self.norm(x) + # hidden, x = self.seq(hidden, x, d) + hidden, x = jax.vmap(self.seq, in_axes=1, out_axes=1)(hidden, x, d) + # hidden = jnp.swapaxes(hidden, 1, 0) + if self.do_gtrxl_norm: + x = self.norm(x) + + if self.activation in ["full_glu"]: + x = nn.gelu(x) + x = self.out1(x) * jax.nn.sigmoid(self.out2(x)) + elif self.activation in ["half_glu1"]: + x = nn.gelu(x) + x = x * jax.nn.sigmoid(self.out2(x)) + elif self.activation in ["half_glu2"]: + # Only apply GELU to the gate input + x1 = nn.gelu(x) + x = x * jax.nn.sigmoid(self.out2(x1)) + elif self.activation in ["gelu"]: + x = nn.gelu(x) + else: + raise NotImplementedError("Activation: {} not implemented".format(self.activation)) + + x = skip + x + if not self.prenorm and self.do_norm: + x = self.norm(x) + return hidden, x + + @staticmethod + def initialize_carry(batch_size, hidden_size): + # Use a dummy key since the default state init fn is just zeros. + return jnp.zeros((1, batch_size, hidden_size), dtype=jnp.complex64) + + +def log_step_initializer(dt_min=0.001, dt_max=0.1): + """Initialize the learnable timescale Delta by sampling + uniformly between dt_min and dt_max. + Args: + dt_min (float32): minimum value + dt_max (float32): maximum value + Returns: + init function + """ + + def init(key, shape): + """Init function + Args: + key: jax random key + shape tuple: desired shape + Returns: + sampled log_step (float32) + """ + return random.uniform(key, shape) * (np.log(dt_max) - np.log(dt_min)) + np.log(dt_min) + + return init + + +def init_log_steps(key, input): + """Initialize an array of learnable timescale parameters + Args: + key: jax random key + input: tuple containing the array shape H and + dt_min and dt_max + Returns: + initialized array of timescales (float32): (H,) + """ + H, dt_min, dt_max = input + log_steps = [] + for i in range(H): + key, skey = random.split(key) + log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,)) + log_steps.append(log_step) + + return np.array(log_steps) + + +def init_VinvB(init_fun, rng, shape, Vinv): + """Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B. + Note we will parameterize this with two different matrices for complex + numbers. + Args: + init_fun: the initialization function to use, e.g. lecun_normal() + rng: jax random key to be used with init function. + shape (tuple): desired shape (P,H) + Vinv: (complex64) the inverse eigenvectors used for initialization + Returns: + B_tilde (complex64) of shape (P,H,2) + """ + B = init_fun(rng, shape) + VinvB = Vinv @ B + VinvB_real = VinvB.real + VinvB_imag = VinvB.imag + return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1) + + +def trunc_standard_normal(key, shape): + """Sample C with a truncated normal distribution with standard deviation 1. + Args: + key: jax random key + shape (tuple): desired shape, of length 3, (H,P,_) + Returns: + sampled C matrix (float32) of shape (H,P,2) (for complex parameterization) + """ + H, P, _ = shape + Cs = [] + for i in range(H): + key, skey = random.split(key) + C = lecun_normal()(skey, shape=(1, P, 2)) + Cs.append(C) + return np.array(Cs)[:, 0] + + +def init_CV(init_fun, rng, shape, V): + """Initialize C_tilde=CV. First sample C. Then compute CV. + Note we will parameterize this with two different matrices for complex + numbers. + Args: + init_fun: the initialization function to use, e.g. lecun_normal() + rng: jax random key to be used with init function. + shape (tuple): desired shape (H,P) + V: (complex64) the eigenvectors used for initialization + Returns: + C_tilde (complex64) of shape (H,P,2) + """ + C_ = init_fun(rng, shape) + C = C_[..., 0] + 1j * C_[..., 1] + CV = C @ V + CV_real = CV.real + CV_imag = CV.imag + return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1) + + +# Discretization functions +def discretize_bilinear(Lambda, B_tilde, Delta): + """Discretize a diagonalized, continuous-time linear SSM + using bilinear transform method. + Args: + Lambda (complex64): diagonal state matrix (P,) + B_tilde (complex64): input matrix (P, H) + Delta (float32): discretization step sizes (P,) + Returns: + discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) + """ + Identity = np.ones(Lambda.shape[0]) + + BL = 1 / (Identity - (Delta / 2.0) * Lambda) + Lambda_bar = BL * (Identity + (Delta / 2.0) * Lambda) + B_bar = (BL * Delta)[..., None] * B_tilde + return Lambda_bar, B_bar + + +def discretize_zoh(Lambda, B_tilde, Delta): + """Discretize a diagonalized, continuous-time linear SSM + using zero-order hold method. + Args: + Lambda (complex64): diagonal state matrix (P,) + B_tilde (complex64): input matrix (P, H) + Delta (float32): discretization step sizes (P,) + Returns: + discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) + """ + Identity = np.ones(Lambda.shape[0]) + Lambda_bar = np.exp(Lambda * Delta) + B_bar = (1 / Lambda * (Lambda_bar - Identity))[..., None] * B_tilde + return Lambda_bar, B_bar + + +# Parallel scan operations +@jax.vmap +def binary_operator(q_i, q_j): + """Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A. + Args: + q_i: tuple containing A_i and Bu_i at position i (P,), (P,) + q_j: tuple containing A_j and Bu_j at position j (P,), (P,) + Returns: + new element ( A_out, Bu_out ) + """ + A_i, b_i = q_i + A_j, b_j = q_j + return A_j * A_i, A_j * b_i + b_j + + +# Parallel scan operations +@jax.vmap +def binary_operator_reset(q_i, q_j): + """Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A. + Args: + q_i: tuple containing A_i and Bu_i at position i (P,), (P,) + q_j: tuple containing A_j and Bu_j at position j (P,), (P,) + Returns: + new element ( A_out, Bu_out ) + """ + A_i, b_i, c_i = q_i + A_j, b_j, c_j = q_j + return ( + (A_j * A_i) * (1 - c_j) + A_j * c_j, + (A_j * b_i + b_j) * (1 - c_j) + b_j * c_j, + c_i * (1 - c_j) + c_j, + ) + + +def apply_ssm(Lambda_bar, B_bar, C_tilde, hidden, input_sequence, resets, conj_sym, bidirectional): + """Compute the LxH output of discretized SSM given an LxH input. + Args: + Lambda_bar (complex64): discretized diagonal state matrix (P,) + B_bar (complex64): discretized input matrix (P, H) + C_tilde (complex64): output matrix (H, P) + input_sequence (float32): input sequence of features (L, H) + reset (bool): input sequence of features (L,) + conj_sym (bool): whether conjugate symmetry is enforced + bidirectional (bool): whether bidirectional setup is used, + Note for this case C_tilde will have 2P cols + Returns: + ys (float32): the SSM outputs (S5 layer preactivations) (L, H) + """ + Lambda_elements = Lambda_bar * jnp.ones((input_sequence.shape[0], Lambda_bar.shape[0])) + Bu_elements = jax.vmap(lambda u: B_bar @ u)(input_sequence) + + Lambda_elements = jnp.concatenate( + [ + jnp.ones((1, Lambda_bar.shape[0])), + Lambda_elements, + ] + ) + + Bu_elements = jnp.concatenate( + [ + hidden, + Bu_elements, + ] + ) + + if resets is None: + _, xs = jax.lax.associative_scan(binary_operator, (Lambda_elements, Bu_elements)) + else: + resets = jnp.concatenate( + [ + jnp.zeros(1), + resets, + ] + ) + _, xs, _ = jax.lax.associative_scan( + binary_operator_reset, (Lambda_elements, Bu_elements, resets) + ) + xs = xs[1:] + + if conj_sym: + return xs[np.newaxis, -1], jax.vmap(lambda x: 2 * (C_tilde @ x).real)(xs) + else: + return xs[np.newaxis, -1], jax.vmap(lambda x: (C_tilde @ x).real)(xs) + + +class S5SSM(nn.Module): + Lambda_re_init: chex.Array + Lambda_im_init: chex.Array + V: chex.Array + Vinv: chex.Array + + H: int + P: int + C_init: str + discretization: str + dt_min: float + dt_max: float + conj_sym: bool = True + clip_eigs: bool = False + bidirectional: bool = False + step_rescale: float = 1.0 + + """ The S5 SSM + Args: + Lambda_re_init (complex64): Real part of init diag state matrix (P,) + Lambda_im_init (complex64): Imag part of init diag state matrix (P,) + V (complex64): Eigenvectors used for init (P,P) + Vinv (complex64): Inverse eigenvectors used for init (P,P) + H (int32): Number of features of input seq + P (int32): state size + C_init (string): Specifies How C is initialized + Options: [trunc_standard_normal: sample from truncated standard normal + and then multiply by V, i.e. C_tilde=CV. + lecun_normal: sample from Lecun_normal and then multiply by V. + complex_normal: directly sample a complex valued output matrix + from standard normal, does not multiply by V] + conj_sym (bool): Whether conjugate symmetry is enforced + clip_eigs (bool): Whether to enforce left-half plane condition, i.e. + constrain real part of eigenvalues to be negative. + True recommended for autoregressive task/unbounded sequence lengths + Discussed in https://arxiv.org/pdf/2206.11893.pdf. + bidirectional (bool): Whether model is bidirectional, if True, uses two C matrices + discretization: (string) Specifies discretization method + options: [zoh: zero-order hold method, + bilinear: bilinear transform] + dt_min: (float32): minimum value to draw timescale values from when + initializing log_step + dt_max: (float32): maximum value to draw timescale values from when + initializing log_step + step_rescale: (float32): allows for uniformly changing the timescale parameter, e.g. after training + on a different resolution for the speech commands benchmark + """ + + def setup(self): + """Initializes parameters once and performs discretization each time + the SSM is applied to a sequence + """ + + if self.conj_sym: + # Need to account for case where we actually sample real B and C, and then multiply + # by the half sized Vinv and possibly V + local_P = 2 * self.P + else: + local_P = self.P + + # Initialize diagonal state to state matrix Lambda (eigenvalues) + self.Lambda_re = self.param("Lambda_re", lambda rng, shape: self.Lambda_re_init, (None,)) + self.Lambda_im = self.param("Lambda_im", lambda rng, shape: self.Lambda_im_init, (None,)) + if self.clip_eigs: + self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im + else: + self.Lambda = self.Lambda_re + 1j * self.Lambda_im + + # Initialize input to state (B) matrix + B_init = lecun_normal() + B_shape = (local_P, self.H) + self.B = self.param( + "B", lambda rng, shape: init_VinvB(B_init, rng, shape, self.Vinv), B_shape + ) + B_tilde = self.B[..., 0] + 1j * self.B[..., 1] + + # Initialize state to output (C) matrix + if self.C_init in ["trunc_standard_normal"]: + C_init = trunc_standard_normal + C_shape = (self.H, local_P, 2) + elif self.C_init in ["lecun_normal"]: + C_init = lecun_normal() + C_shape = (self.H, local_P, 2) + elif self.C_init in ["complex_normal"]: + C_init = normal(stddev=0.5**0.5) + else: + raise NotImplementedError("C_init method {} not implemented".format(self.C_init)) + + if self.C_init in ["complex_normal"]: + if self.bidirectional: + C = self.param("C", C_init, (self.H, 2 * self.P, 2)) + self.C_tilde = C[..., 0] + 1j * C[..., 1] + + else: + C = self.param("C", C_init, (self.H, self.P, 2)) + self.C_tilde = C[..., 0] + 1j * C[..., 1] + + else: + if self.bidirectional: + self.C1 = self.param( + "C1", lambda rng, shape: init_CV(C_init, rng, shape, self.V), C_shape + ) + self.C2 = self.param( + "C2", lambda rng, shape: init_CV(C_init, rng, shape, self.V), C_shape + ) + + C1 = self.C1[..., 0] + 1j * self.C1[..., 1] + C2 = self.C2[..., 0] + 1j * self.C2[..., 1] + self.C_tilde = np.concatenate((C1, C2), axis=-1) + + else: + self.C = self.param( + "C", lambda rng, shape: init_CV(C_init, rng, shape, self.V), C_shape + ) + + self.C_tilde = self.C[..., 0] + 1j * self.C[..., 1] + + # Initialize feedthrough (D) matrix + self.D = self.param("D", normal(stddev=1.0), (self.H,)) + + # Initialize learnable discretization timescale value + self.log_step = self.param("log_step", init_log_steps, (self.P, self.dt_min, self.dt_max)) + step = self.step_rescale * np.exp(self.log_step[:, 0]) + + # Discretize + if self.discretization in ["zoh"]: + self.Lambda_bar, self.B_bar = discretize_zoh(self.Lambda, B_tilde, step) + elif self.discretization in ["bilinear"]: + self.Lambda_bar, self.B_bar = discretize_bilinear(self.Lambda, B_tilde, step) + else: + raise NotImplementedError( + "Discretization method {} not implemented".format(self.discretization) + ) + + def __call__(self, hidden, input_sequence, resets): + """ + Compute the LxH output of the S5 SSM given an LxH input sequence + using a parallel scan. + Args: + input_sequence (float32): input sequence (L, H) + resets (bool): input sequence (L,) + Returns: + output sequence (float32): (L, H) + """ + hidden, ys = apply_ssm( + self.Lambda_bar, + self.B_bar, + self.C_tilde, + hidden, + input_sequence, + resets, + self.conj_sym, + self.bidirectional, + ) + # Add feedthrough matrix output Du; + Du = jax.vmap(lambda u: self.D * u)(input_sequence) + return hidden, ys + Du + + +def init_S5SSM( + H, + P, + Lambda_re_init, + Lambda_im_init, + V, + Vinv, + C_init, + discretization, + dt_min, + dt_max, + conj_sym, + clip_eigs, + bidirectional, +): + """Convenience function that will be used to initialize the SSM. + Same arguments as defined in S5SSM above.""" + return partial( + S5SSM, + H=H, + P=P, + Lambda_re_init=Lambda_re_init, + Lambda_im_init=Lambda_im_init, + V=V, + Vinv=Vinv, + C_init=C_init, + discretization=discretization, + dt_min=dt_min, + dt_max=dt_max, + conj_sym=conj_sym, + clip_eigs=clip_eigs, + bidirectional=bidirectional, + ) + + +def make_HiPPO(N): + """Create a HiPPO-LegS matrix. + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Args: + N (int32): state size + Returns: + N x N HiPPO LegS matrix + """ + P = np.sqrt(1 + 2 * np.arange(N)) + A = P[:, np.newaxis] * P[np.newaxis, :] + A = np.tril(A) - np.diag(np.arange(N)) + return -A + + +def make_NPLR_HiPPO(N): + """ + Makes components needed for NPLR representation of HiPPO-LegS + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Args: + N (int32): state size + Returns: + N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B + """ + # Make -HiPPO + hippo = make_HiPPO(N) + + # Add in a rank 1 term. Makes it Normal. + P = np.sqrt(np.arange(N) + 0.5) + + # HiPPO also specifies the B matrix + B = np.sqrt(2 * np.arange(N) + 1.0) + return hippo, P, B + + +def make_DPLR_HiPPO(N): + """ + Makes components needed for DPLR representation of HiPPO-LegS + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Note, we will only use the diagonal part + Args: + N: + Returns: + eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B, + eigenvectors V, HiPPO B pre-conjugation + """ + A, P, B = make_NPLR_HiPPO(N) + + S = A + P[:, np.newaxis] * P[np.newaxis, :] + + S_diag = np.diagonal(S) + Lambda_real = np.mean(S_diag) * np.ones_like(S_diag) + + # Diagonalize S to V \Lambda V^* + Lambda_imag, V = eigh(S * -1j) + + P = V.conj().T @ P + B_orig = B + B = V.conj().T @ B + return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig + + +class StackedEncoderModel(nn.Module): + """Defines a stack of S5 layers to be used as an encoder. + Args: + ssm (nn.Module): the SSM to be used (i.e. S5 ssm) + d_model (int32): this is the feature size of the layer inputs and outputs + we usually refer to this size as H + n_layers (int32): the number of S5 layers to stack + activation (string): Type of activation function to use + prenorm (bool): apply prenorm if true or postnorm if false + """ + + ssm_size: int + d_model: int + n_layers: int + activation: str = "gelu" + do_norm: bool = True + prenorm: bool = True + do_gtrxl_norm: bool = True + + def setup(self): + """ + Initializes a linear encoder and the stack of S5 layers. + """ + blocks = 1 + block_size = int(self.ssm_size / blocks) + Lambda, _, _, V, _ = make_DPLR_HiPPO(self.ssm_size) + block_size = block_size // 2 + Lambda = Lambda[:block_size] + V = V[:, :block_size] + Vinv = V.conj().T + # self.encoder = nn.Dense(self.d_model) + self.layers = [ + SequenceLayer( + ssm=init_S5SSM( + H=self.d_model, + P=self.ssm_size // 2, + Lambda_re_init=Lambda.real, + Lambda_im_init=Lambda.imag, + V=V, + Vinv=Vinv, + C_init="lecun_normal", + discretization="zoh", + dt_min=0.001, + dt_max=0.1, + conj_sym=True, + clip_eigs=False, + bidirectional=False, + ), + d_model=self.d_model, + activation=self.activation, + do_norm=self.do_norm, + prenorm=self.prenorm, + do_gtrxl_norm=self.do_gtrxl_norm, + ) + for _ in range(self.n_layers) + ] + + def __call__(self, hidden, inputs): + """ + Compute the LxH output of the stacked encoder given an Lxd_input + input sequence. + Args: + x (float32): input sequence (L, d_input) + Returns: + output sequence (float32): (L, d_model) + """ + x, d = inputs + new_hiddens = [] + hidden = jax.tree.map(lambda x: jnp.expand_dims(x, 0), hidden) + for i, layer in enumerate(self.layers): + new_h, x = layer(hidden[i], x, d) + new_hiddens.append(new_h) + + new_hiddens = jax.tree.map(lambda x: x.squeeze(0), new_hiddens) + return new_hiddens, x + + @nn.nowrap + def initialize_carry(self, batch_size): + # Use a dummy key since the default state init fn is just zeros. + return [ + jnp.zeros((batch_size, self.ssm_size // 2), dtype=jnp.complex64) + for _ in range(self.n_layers) + ] diff --git a/stoix/networks/working_demov2.py b/stoix/networks/memoroids/working_demov2.py similarity index 100% rename from stoix/networks/working_demov2.py rename to stoix/networks/memoroids/working_demov2.py From 5f41a03c667ef53438f77e50021107f460f288dd Mon Sep 17 00:00:00 2001 From: Steven Morad Date: Mon, 1 Jul 2024 00:22:54 +0000 Subject: [PATCH 30/38] small hparam tweaks for memoroid --- stoix/configs/network/memoroid.yaml | 4 ++-- stoix/configs/system/rec_ppo.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/stoix/configs/network/memoroid.yaml b/stoix/configs/network/memoroid.yaml index c070f292..4597fd8e 100644 --- a/stoix/configs/network/memoroid.yaml +++ b/stoix/configs/network/memoroid.yaml @@ -8,7 +8,7 @@ actor_network: activation: leaky_relu rnn_layer: _target_: stoix.networks.memoroids.working_demov2.FFM - trace_size: 64 + trace_size: 32 context_size: 4 output_size: 256 post_torso: @@ -27,7 +27,7 @@ critic_network: activation: leaky_relu rnn_layer: _target_: stoix.networks.memoroids.working_demov2.FFM - trace_size: 64 + trace_size: 32 context_size: 4 output_size: 256 post_torso: diff --git a/stoix/configs/system/rec_ppo.yaml b/stoix/configs/system/rec_ppo.yaml index 15a602e0..1748009e 100644 --- a/stoix/configs/system/rec_ppo.yaml +++ b/stoix/configs/system/rec_ppo.yaml @@ -11,8 +11,8 @@ num_minibatches: 32 # Number of minibatches per ppo epoch. gamma: 0.99 # Discounting factor. gae_lambda: 0.95 # Lambda value for GAE computation. clip_eps: 0.2 # Clipping value for PPO updates and value function. -ent_coef: 0.0 # Entropy regularisation term for loss function. -vf_coef: 1.0 # Critic weight in +ent_coef: 0.01 # Entropy regularisation term for loss function. +vf_coef: 0.5 # Critic weight in max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update. decay_learning_rates: True # Whether learning rates should be linearly decayed during training. standardize_advantages: True # Whether to standardize the advantages. From 9916c791cac38b103c81b6e2f2348ee8e1ef37f2 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Wed, 3 Jul 2024 22:09:29 +0000 Subject: [PATCH 31/38] chore: reorganize code and add start to lru --- stoix/configs/logger/base_logger.yaml | 2 +- .../network/{memoroid.yaml => ffm.yaml} | 20 +- stoix/configs/network/s5.yaml | 4 +- stoix/configs/system/rec_ppo.yaml | 6 +- stoix/networks/memoroids/base.py | 107 +++++ stoix/networks/memoroids/ffm.py | 426 +++++------------- stoix/networks/memoroids/lru.py | 220 +++++++++ stoix/networks/memoroids/old_code1.py | 357 +++++++++++++++ .../memoroids/{memoroid.py => old_code2.py} | 0 stoix/networks/memoroids/{s5.py => old_s5.py} | 0 stoix/networks/memoroids/working_demov2.py | 269 ----------- 11 files changed, 815 insertions(+), 596 deletions(-) rename stoix/configs/network/{memoroid.yaml => ffm.yaml} (67%) create mode 100644 stoix/networks/memoroids/base.py create mode 100644 stoix/networks/memoroids/lru.py create mode 100644 stoix/networks/memoroids/old_code1.py rename stoix/networks/memoroids/{memoroid.py => old_code2.py} (100%) rename stoix/networks/memoroids/{s5.py => old_s5.py} (100%) delete mode 100644 stoix/networks/memoroids/working_demov2.py diff --git a/stoix/configs/logger/base_logger.yaml b/stoix/configs/logger/base_logger.yaml index 8ee91dfe..4d6ecf14 100644 --- a/stoix/configs/logger/base_logger.yaml +++ b/stoix/configs/logger/base_logger.yaml @@ -4,7 +4,7 @@ base_exp_path: results # Base path for logging. use_console: True # Whether to log to stdout. use_tb: False # Whether to use tensorboard logging. use_json: False # Whether to log marl-eval style to json files. -use_neptune: True # Whether to log to neptune.ai. +use_neptune: False # Whether to log to neptune.ai. use_wandb: False # Whether to log to wandb.ai. # --- Other logger kwargs --- diff --git a/stoix/configs/network/memoroid.yaml b/stoix/configs/network/ffm.yaml similarity index 67% rename from stoix/configs/network/memoroid.yaml rename to stoix/configs/network/ffm.yaml index 4597fd8e..30f0afc8 100644 --- a/stoix/configs/network/memoroid.yaml +++ b/stoix/configs/network/ffm.yaml @@ -7,10 +7,12 @@ actor_network: use_layer_norm: False activation: leaky_relu rnn_layer: - _target_: stoix.networks.memoroids.working_demov2.FFM - trace_size: 32 - context_size: 4 - output_size: 256 + _target_: stoix.networks.memoroids.base.ScannedMemoroid + cell: + _target_: stoix.networks.memoroids.ffm.FFMCell + trace_size: 32 + context_size: 4 + output_size: 256 post_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] @@ -26,10 +28,12 @@ critic_network: use_layer_norm: False activation: leaky_relu rnn_layer: - _target_: stoix.networks.memoroids.working_demov2.FFM - trace_size: 32 - context_size: 4 - output_size: 256 + _target_: stoix.networks.memoroids.base.ScannedMemoroid + cell: + _target_: stoix.networks.memoroids.ffm.FFMCell + trace_size: 32 + context_size: 4 + output_size: 256 post_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] diff --git a/stoix/configs/network/s5.yaml b/stoix/configs/network/s5.yaml index 72be36e6..8abbb886 100644 --- a/stoix/configs/network/s5.yaml +++ b/stoix/configs/network/s5.yaml @@ -7,7 +7,7 @@ actor_network: use_layer_norm: False activation: leaky_relu rnn_layer: - _target_: stoix.networks.s5.StackedEncoderModel + _target_: stoix.networks.old_s5.StackedEncoderModel ssm_size: 256 d_model: 256 n_layers: 1 @@ -26,7 +26,7 @@ critic_network: use_layer_norm: False activation: leaky_relu rnn_layer: - _target_: stoix.networks.s5.StackedEncoderModel + _target_: stoix.networks.old_s5.StackedEncoderModel ssm_size: 256 d_model: 256 n_layers: 1 diff --git a/stoix/configs/system/rec_ppo.yaml b/stoix/configs/system/rec_ppo.yaml index 1748009e..75b25f12 100644 --- a/stoix/configs/system/rec_ppo.yaml +++ b/stoix/configs/system/rec_ppo.yaml @@ -3,9 +3,9 @@ system_name: rec_ppo # Name of the system. # --- RL hyperparameters --- -actor_lr: 5e-5 # Learning rate for actor network -critic_lr: 5e-5 # Learning rate for critic network -rollout_length: 128 # Number of environment steps per vectorised environment. +actor_lr: 1e-5 # Learning rate for actor network +critic_lr: 1e-5 # Learning rate for critic network +rollout_length: 256 # Number of environment steps per vectorised environment. epochs: 15 # Number of ppo epochs per training data batch. num_minibatches: 32 # Number of minibatches per ppo epoch. gamma: 0.99 # Discounting factor. diff --git a/stoix/networks/memoroids/base.py b/stoix/networks/memoroids/base.py new file mode 100644 index 00000000..71214ecc --- /dev/null +++ b/stoix/networks/memoroids/base.py @@ -0,0 +1,107 @@ +from typing import List, Optional, Tuple + +import chex +from flax import linen as nn +from jax import numpy as jnp + +RecurrentState = chex.Array +Reset = chex.Array +Timestep = chex.Array +InputEmbedding = chex.Array +Inputs = Tuple[InputEmbedding, Reset] +ScanInput = chex.Array + + +class MemoroidCellBase(nn.Module): + """Memoroid cell base class.""" + + def map_to_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> RecurrentState: + raise NotImplementedError + + def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> RecurrentState: + raise NotImplementedError + + def scan(self, x: InputEmbedding, state: RecurrentState, start: Reset) -> RecurrentState: + raise NotImplementedError + + @nn.nowrap + def initialize_carry( + self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None + ) -> RecurrentState: + """Initialize the Memoroid cell carry. + + Args: + batch_size: the batch size of the carry. + rng: random number generator passed to the init_fn. + + Returns: + An initialized carry for the given Memoroid cell. + """ + raise NotImplementedError + + @property + def num_feature_axes(self) -> int: + """Returns the number of feature axes of the cell.""" + raise NotImplementedError + + +class ScannedMemoroid(nn.Module): + cell: MemoroidCellBase + + @nn.compact + def __call__(self, state: RecurrentState, inputs: Inputs) -> Tuple[RecurrentState, chex.Array]: + + # Add a sequence dimension to the recurrent state. + state = jnp.expand_dims(state, 0) + + # Unpack inputs + x, start = inputs + + # Map the input embedding to the recurrent state space. + # This maps to the format required for the associative scan. + scan_input = self.cell.map_to_h(x) + + # Update the recurrent state + state = self.cell.scan(scan_input, state, start) + + # Map the recurrent state back to the output space + out = self.cell.map_from_h(state, x) + + # Take the final state of the sequence. + final_state = state[-1:] + + # Remove the sequence dimemnsion from the final state. + final_state = jnp.squeeze(final_state, 0) + + return final_state, out + + @nn.nowrap + def initialize_carry(self, batch_size: int) -> RecurrentState: + return self.cell.initialize_carry(batch_size) + + +class StackedMemoroid(nn.Module): + cells: Tuple[ScannedMemoroid] + + @nn.compact + def __call__( + self, all_states: List[RecurrentState], inputs: Inputs + ) -> Tuple[RecurrentState, chex.Array]: + # Ensure all_states is a list + if not isinstance(all_states, list): + all_states = [all_states] + + assert len(all_states) == len( + self.cells + ), f"Expected {len(self.cells)} states, got {len(all_states)}" + + new_states = [] + for cell, mem_state in zip(self.cells, all_states): + new_mem_state, x = cell(mem_state, x) + new_states.append(new_mem_state) + + return new_states, x + + @nn.nowrap + def initialize_carry(self, batch_size: int) -> List[RecurrentState]: + return [cell.initialize_carry(batch_size) for cell in self.cells] diff --git a/stoix/networks/memoroids/ffm.py b/stoix/networks/memoroids/ffm.py index 817d2275..e0352c1c 100644 --- a/stoix/networks/memoroids/ffm.py +++ b/stoix/networks/memoroids/ffm.py @@ -1,56 +1,22 @@ -from functools import partial -from typing import Any, List, Tuple +from typing import Tuple -import flax.linen as nn +import chex import jax -import jax.numpy as jnp +from flax import linen as nn +from jax import numpy as jnp - -def recurrent_associative_scan( - cell: nn.Module, - state: jax.Array, - inputs: jax.Array, - axis: int = 0, -) -> jax.Array: - """Execute the associative scan to update the recurrent state. - - Note that we do a trick here by concatenating the previous state to the inputs. - This is allowed since the scan is associative. This ensures that the previous - recurrent state feeds information into the scan. Without this method, we need - separate methods for rollouts and training.""" - - # Concatenate the prevous state to the inputs and scan over the result - # This ensures the previous recurrent state contributes to the current batch - # state: [start, (x, j)] - # inputs: [start, (x, j)] - scan_inputs = jax.tree.map(lambda x, s: jnp.concatenate([s, x], axis=0), inputs, state) - new_state = jax.lax.associative_scan( - cell, - scan_inputs, - axis=axis, - ) - # The zeroth index corresponds to the previous recurrent state - # We just use it to ensure continuity - # We do not actually want to use these values, so slice them away - return jax.tree.map(lambda x: x[1:], new_state) - - -class Gate(nn.Module): - """Sigmoidal gating""" - - output_size: int - - @nn.compact - def __call__(self, x): - x = nn.Dense(self.output_size)(x) - x = nn.sigmoid(x) - return x +from stoix.networks.memoroids.base import ( + InputEmbedding, + RecurrentState, + Reset, + ScanInput, + Timestep, +) def init_deterministic( memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 -) -> Tuple[jax.Array, jax.Array]: - """Deterministic initialization of the FFM parameters.""" +) -> Tuple[chex.Array, chex.Array]: a_low = 1e-6 a_high = 0.5 a = jnp.linspace(a_low, a_high, memory_size) @@ -58,300 +24,134 @@ def init_deterministic( return a, b -def init_random( - memory_size: int, context_size: int, min_period: int = 1, max_period: int = 10_000, *, key -) -> Tuple[jax.Array, jax.Array]: - _, k1, k2 = jax.random.split(key, 3) - a_low = 1e-6 - a_high = 0.1 - a = jax.random.uniform(k1, (memory_size,), minval=a_low, maxval=a_high) - b = ( - 2 - * jnp.pi - / jnp.exp( - jax.random.uniform( - k2, (context_size,), minval=jnp.log(min_period), maxval=jnp.log(max_period) - ) - ) - ) - return a, b +class Gate(nn.Module): + output_size: int + @nn.compact + def __call__(self, x: chex.Array) -> chex.Array: + return jax.nn.sigmoid(nn.Dense(self.output_size)(x)) -class FFMCell(nn.Module): - """The binary associative update function for the FFM.""" +class FFMCell(nn.Module): trace_size: int context_size: int output_size: int - deterministic_init: bool = True - - def setup(self): - if self.deterministic_init: - a, b = init_deterministic(self.trace_size, self.context_size) - else: - # TODO: Will this result in the same keys for multiple FFMCells? - key = self.make_rng("ffa_params") - a, b = init_random(self.trace_size, self.context_size, key=key) - self.params = (self.param("ffa_a", lambda rng: a), self.param("ffa_b", lambda rng: b)) - - def log_gamma(self, t: jax.Array) -> jax.Array: - a, b = self.params - a = -jnp.abs(a).reshape((1, self.trace_size, 1)) - b = b.reshape(1, 1, self.context_size) - ab = jax.lax.complex(a, b) - return ab * t.reshape(t.shape[0], 1, 1) - - def gamma(self, t: jax.Array) -> jax.Array: - return jnp.exp(self.log_gamma(t)) - - def initialize_carry(self, batch_size: int = None): - if batch_size is None: - return jnp.zeros( - (1, self.trace_size, self.context_size), dtype=jnp.complex64 - ), jnp.ones((1,), dtype=jnp.int32) - - return jnp.zeros( - (1, batch_size, self.trace_size, self.context_size), dtype=jnp.complex64 - ), jnp.ones((1, batch_size), dtype=jnp.int32) - - def __call__(self, carry, incoming): - ( - state, - i, - ) = carry - x, j = incoming - state = state * self.gamma(j) + x - return state, j + i - - -class MemoroidResetWrapper(nn.Module): - """A wrapper around memoroid cells like FFM, LRU, etc that resets - the recurrent state upon a reset signal.""" - - cell: nn.Module - - def __call__(self, carry, incoming): - states, prev_start = carry - xs, start = incoming - - def reset_state(start, current_state, initial_state): - # Expand to reset all dims of state: [B, 1, 1, ...] - expanded_start = start.reshape(-1, *([1] * (current_state.ndim - 1))) - out = current_state * jnp.logical_not(expanded_start) + initial_state - return out - - initial_states = self.cell.initialize_carry() - states = jax.tree.map(partial(reset_state, start), states, initial_states) - out = self.cell(states, xs) - start_carry = jnp.logical_or(start, prev_start) - - return out, start_carry - - def initialize_carry(self, batch_size: int = None): - if batch_size is None: - # TODO: Should this be one or zero? - return self.cell.initialize_carry(batch_size), jnp.zeros((1,), dtype=bool) - - return self.cell.initialize_carry(batch_size), jnp.zeros((batch_size,), dtype=bool) - -class FFM(nn.Module): - """Fast and Forgetful Memory""" + def setup(self) -> None: - trace_size: int - context_size: int - output_size: int - cell: nn.Module + # Create the FFM parameters + a, b = init_deterministic(self.trace_size, self.context_size) + self.a = self.param( + "ffm_a", + lambda key, shape: a, + (), + ) + self.b = self.param( + "ffm_b", + lambda key, shape: b, + (), + ) - def setup(self): + # Create the networks and parameters that are used when + # mapping from input space to recurrent state space + # This is used in the map_to_h method and is used in the + # associative scan outer loop self.pre = nn.Dense(self.trace_size) self.gate_in = Gate(self.trace_size) - self.ffa = FFMCell(self.trace_size, self.context_size, self.output_size) self.gate_out = Gate(self.output_size) self.skip = nn.Dense(self.output_size) self.mix = nn.Dense(self.output_size) self.ln = nn.LayerNorm(use_scale=False, use_bias=False) - def map_to_h(self, inputs): - """Map from the input space to the recurrent state space""" - x, resets = inputs + def map_to_h(self, x: InputEmbedding) -> ScanInput: + """Given an input embedding, this will map it to the format required for the associative scan.""" gate_in = self.gate_in(x) pre = self.pre(x) gated_x = pre * gate_in - # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous - ts = jnp.ones(x.shape[0], dtype=jnp.int32) - z = jnp.repeat(jnp.expand_dims(gated_x, 2), self.context_size, axis=2) - return (z, ts), resets - - def map_from_h(self, recurrent_state, inputs): - """Map from the recurrent space to the Markov space""" - (state, ts), reset = recurrent_state - (x, start) = inputs - z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape( - state.shape[0], -1 - ) + scan_input = jnp.repeat(jnp.expand_dims(gated_x, 3), self.context_size, axis=3) + return scan_input + + def map_from_h(self, state: RecurrentState, x: InputEmbedding) -> chex.Array: + """Given the recurrent state and the input embedding, this will map the recurrent state back to the output space.""" + T = state.shape[0] + B = state.shape[1] + z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape(T, B, -1) z = self.mix(z_in) gate_out = self.gate_out(x) skip = self.skip(x) out = self.ln(z * gate_out) + skip * (1 - gate_out) return out - def __call__(self, recurrent_state, inputs): - # Recurrent state should be ((state, timestep), reset) - # Inputs should be (x, reset) - h = self.map_to_h(inputs) - recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h) - # recurrent_state is ((state, timestep), reset) - out = self.map_from_h(recurrent_state, inputs) - - # TODO: Remove this when we want to return all recurrent states instead of just the last one - final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) - return final_recurrent_state, out - - def initialize_carry(self, batch_size: int = None): - return self.cell.initialize_carry(batch_size) - - -class SFFM(nn.Module): - """Simplified Fast and Forgetful Memory""" - - trace_size: int - context_size: int - hidden_size: int - cell: nn.Module - - def setup(self): - self.W_trace = nn.Dense(self.trace_size) - self.W_context = Gate(self.context_size) - self.ffa = FFMCell( - self.trace_size, self.context_size, self.hidden_size, deterministic_init=False - ) - self.post = nn.Sequential( - [ - # Default init but with smaller weights - nn.Dense( - self.hidden_size, - kernel_init=nn.initializers.variance_scaling( - 0.01, "fan_in", "truncated_normal" - ), - ), - nn.LayerNorm(), - nn.leaky_relu, - nn.Dense(self.hidden_size), - nn.LayerNorm(), - nn.leaky_relu, - ] - ) - - def map_to_h(self, inputs): - x, resets = inputs - pre = jnp.abs(jnp.einsum("bi, bj -> bij", self.W_trace(x), self.W_context(x))) - pre = pre / jnp.sum(pre, axis=(-2, -1), keepdims=True) - # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous - ts = jnp.ones(x.shape[0], dtype=jnp.int32) - return (pre, ts), resets - - def map_from_h(self, recurrent_state, inputs): - x, resets = inputs - (state, ts), reset = recurrent_state - s = state.reshape(state.shape[0], self.context_size * self.trace_size) - eps = s.real + (s.real == 0 + jnp.sign(s.real)) * 0.01 - s = s + eps - scaled = jnp.concatenate( - [ - jnp.log(1 + jnp.abs(s)) * jnp.sin(jnp.angle(s)), - jnp.log(1 + jnp.abs(s)) * jnp.cos(jnp.angle(s)), - ], - axis=-1, - ) - z = self.post(scaled) - return z - - def __call__(self, recurrent_state, inputs): - # Recurrent state should be ((state, timestep), reset) - # Inputs should be (x, reset) - h = self.map_to_h(inputs) - recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h) - # recurrent_state is ((state, timestep), reset) - out = self.map_from_h(recurrent_state, inputs) - - # TODO: Remove this when we want to return all recurrent states instead of just the last one - final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) - return final_recurrent_state, out - - def initialize_carry(self, batch_size: int = None): - return self.cell.initialize_carry(batch_size) - - -class StackedSFFM(nn.Module): - """A multilayer version of SFFM""" - - cells: List[nn.Module] - - def setup(self): - self.project = nn.Dense(cells[0].hidden_size) - - def __call__(self, recurrent_state: jax.Array, inputs: Any) -> Tuple[jax.Array, jax.Array]: - x, start = inputs - x = self.project(x) - inputs = x, start - for i, cell in enumerate(self.cells): - s, y = cell(recurrent_state[i], inputs) - x = x + y - recurrent_state[i] = s - return y, recurrent_state - - def initialize_carry(self, batch_size: int = None): - return [c.initialize_carry(batch_size) for c in self.cells] - - -if __name__ == "__main__": - m = FFM( - output_size=4, - trace_size=5, - context_size=6, - cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)), - ) - s = m.initialize_carry() - x = jnp.ones((10, 2)) - start = jnp.zeros(10, dtype=bool) - params = m.init(jax.random.PRNGKey(0), s, (x, start)) - out_state, out = m.apply(params, s, (x, start)) - - # BatchFFM = nn.vmap( - # FFM, in_axes=1, out_axes=1, variable_axes={"params": None}, split_rngs={"params": False} - # ) - - # m = BatchFFM( - # trace_size=4, - # context_size=5, - # output_size=6, - # cell=MemoroidResetWrapper(cell=FFMCell(4,5,6)) - # ) + def log_gamma(self, t: Timestep) -> chex.Array: + T = t.shape[0] + B = t.shape[1] + a = self.a + b = self.b + a = -jnp.abs(a).reshape((1, 1, self.trace_size, 1)) + b = b.reshape(1, 1, 1, self.context_size) + ab = jax.lax.complex(a, b) + return ab * t.reshape(T, B, 1, 1) - # s = m.initialize_carry(8) - # x = jnp.ones((10, 8, 2)) - # start = jnp.zeros((10, 8), dtype=bool) - # params = m.init(jax.random.PRNGKey(0), s, (x, start)) - # out_state, out = m.apply(params, s, (x, start)) + def gamma(self, t: Timestep) -> chex.Array: + return jnp.exp(self.log_gamma(t)) - # print(out.shape) - # print(out_state.shape) + def unwrapped_associative_update( + self, + carry: Tuple[RecurrentState, Timestep], + incoming: Tuple[InputEmbedding, Timestep], + ) -> Tuple[RecurrentState, Timestep]: + ( + state, + i, + ) = carry + x, j = incoming + state = state * self.gamma(j) + x + return state, j + i - # TODO: Initialize cells with different random streams so the weights are not identical - cells = [ - SFFM( - trace_size=4, - context_size=5, - hidden_size=6, - cell=MemoroidResetWrapper(cell=FFMCell(4, 5, 6)), + def wrapped_associative_update( + self, + carry: Tuple[Reset, RecurrentState, Timestep], + incoming: Tuple[Reset, InputEmbedding, Timestep], + ) -> Tuple[Reset, RecurrentState, Timestep]: + prev_start, state, i = carry + start, x, j = incoming + # Reset all elements in the carry if we are starting a new episode + state = state * jnp.logical_not(start) + j = j * jnp.logical_not(start) + incoming = x, j + carry = (state, i) + out = self.unwrapped_associative_update(carry, incoming) + start_out = jnp.logical_or(start, prev_start) + return (start_out, *out) + + def scan( + self, + x: InputEmbedding, + state: RecurrentState, + start: Reset, + ) -> RecurrentState: + """Given an input and recurrent state, this will update the recurrent state. This is equivalent + to the inner-function g in the paper.""" + # x: [T, B, memory_size] + # memory: [1, B, memory_size, context_size] + T = x.shape[0] + B = x.shape[1] + timestep = jnp.ones((T + 1, B), dtype=jnp.int32).reshape(T + 1, B, 1, 1) + # Add context dim + start = start.reshape(T, B, 1, 1) + + # Now insert previous recurrent state + x = jnp.concatenate([state, x], axis=0) + start = jnp.concatenate([jnp.zeros_like(start[:1]), start], axis=0) + + # This is not executed during inference -- method will just return x if size is 1 + _, new_state, _ = jax.lax.associative_scan( + self.wrapped_associative_update, + (start, x, timestep), + axis=0, ) - for i in range(3) - ] - s2fm = StackedSFFM(cells=cells) + return new_state[1:] - s = s2fm.initialize_carry() - x = jnp.ones((10, 2)) - start = jnp.zeros(10, dtype=bool) - params = s2fm.init(jax.random.PRNGKey(0), s, (x, start)) - out_state, out = s2fm.apply(params, s, (x, start)) + @nn.nowrap + def initialize_carry(self, batch_size: int) -> RecurrentState: + return jnp.zeros((batch_size, self.trace_size, self.context_size), dtype=jnp.complex64) diff --git a/stoix/networks/memoroids/lru.py b/stoix/networks/memoroids/lru.py new file mode 100644 index 00000000..a53ae75d --- /dev/null +++ b/stoix/networks/memoroids/lru.py @@ -0,0 +1,220 @@ +from typing import Tuple, Union + +import chex +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.linen.initializers import Initializer + +from stoix.networks.memoroids.base import ( + InputEmbedding, + Inputs, + MemoroidCellBase, + RecurrentState, + Reset, + ScanInput, +) + +# NOT WORKING YET + +# Parallel scan operations +@jax.vmap +def binary_operator_diag(q_i, q_j): + """Binary operator for parallel scan of linear recurrence""" + A_i, b_i = q_i + A_j, b_j = q_j + return A_j * A_i, A_j * b_i + b_j + + +def wrapped_associative_update(carry: chex.Array, incoming: chex.Array) -> Tuple[chex.Array, ...]: + """The reset-wrapped form of the associative update. + + You might need to override this + if you use variables in associative_update that are not from initial_state. + This is equivalent to the h function in the paper: + b x H -> b x H + """ + prev_start, *carry = carry + start, *incoming = incoming + # Reset all elements in the carry if we are starting a new episode + A, b = carry + + A = jnp.logical_not(start) * A + start * jnp.ones_like(A) + b = jnp.logical_not(start) * b + + out = binary_operator_diag((A, b), incoming) + start_out = jnp.logical_or(start, prev_start) + return (start_out, *out) + + +def matrix_init(normalization: float = 1.0) -> Initializer: + def init( + key: chex.PRNGKey, shape: Tuple[int, ...], dtype: jnp.dtype = jnp.float32 + ) -> jnp.ndarray: + return jax.random.normal(key=key, shape=shape, dtype=dtype) / normalization + + return init + + +def nu_init(r_min: float, r_max: float) -> Initializer: + def init( + key: chex.PRNGKey, shape: Tuple[int, ...], dtype: jnp.dtype = jnp.float32 + ) -> jnp.ndarray: + u = jax.random.uniform(key=key, shape=shape, dtype=dtype) + return jnp.log(-0.5 * jnp.log(u * (r_max**2 - r_min**2) + r_min**2)) + + return init + + +def theta_init(max_phase: float) -> Initializer: + def init( + key: chex.PRNGKey, shape: Tuple[int, ...], dtype: jnp.dtype = jnp.float32 + ) -> jnp.ndarray: + u = jax.random.uniform(key, shape=shape, dtype=dtype) + return jnp.log(max_phase * u) + + return init + + +def gamma_log_init( + lamb: Tuple[Union[float, jnp.ndarray], Union[float, jnp.ndarray]] +) -> Initializer: + def init( + key: chex.PRNGKey, shape: Tuple[int, ...], dtype: jnp.dtype = jnp.float32 + ) -> jnp.ndarray: + nu, theta = lamb + diag_lambda = jnp.exp(-jnp.exp(nu) + 1j * jnp.exp(theta)) + return jnp.log(jnp.sqrt(1 - jnp.abs(diag_lambda) ** 2)) + + return init + + +class LRUCell(MemoroidCellBase): + """ + LRU module in charge of the recurrent processing. + Implementation following the one of Orvieto et al. 2023. + """ + + d_model: int # input and output dimensions + d_hidden: int # hidden state dimension + r_min: float = 0.0 # smallest lambda norm + r_max: float = 1.0 # largest lambda norm + max_phase: float = 6.28 # max phase lambda + + def setup(self): + + self.theta_log = self.param("theta_log", theta_init(self.max_phase), (self.d_hidden,)) + self.nu_log = self.param("nu_log", nu_init(self.r_min, self.r_max), (self.d_hidden,)) + self.gamma_log = self.param( + "gamma_log", gamma_log_init((self.nu_log, self.theta_log)), (self.d_hidden,) + ) + + self.B_re = self.param( + "B_re", + matrix_init(normalization=jnp.sqrt(2 * self.d_model)), + (self.d_hidden, self.d_model), + ) + self.B_im = self.param( + "B_im", + matrix_init(normalization=jnp.sqrt(2 * self.d_model)), + (self.d_hidden, self.d_model), + ) + self.C_re = self.param( + "C_re", + matrix_init(normalization=jnp.sqrt(self.d_hidden)), + (self.d_model, self.d_hidden), + ) + self.C_im = self.param( + "C_im", + matrix_init(normalization=jnp.sqrt(self.d_hidden)), + (self.d_model, self.d_hidden), + ) + self.D = self.param("D", matrix_init(normalization=1), (self.d_model,)) + + self.normalization = nn.LayerNorm() + self.out1 = nn.Dense(self.d_model) + self.out2 = nn.Dense(self.d_model) + + def map_to_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> ScanInput: + x = self.normalization(x) + diag_lambda = jnp.exp(-jnp.exp(self.nu_log) + 1j * jnp.exp(self.theta_log)) + B_norm = (self.B_re + 1j * self.B_im) * jnp.expand_dims(jnp.exp(self.gamma_log), axis=-1) + + Lambda_elements = jnp.repeat(diag_lambda[None, ...], x.shape[0], axis=0) + Bu_elements = jax.vmap(lambda u: B_norm @ u)(x.astype(jnp.complex64)) + + Lambda_elements = jnp.concatenate( + [ + jnp.ones((1, diag_lambda.shape[0])), + Lambda_elements, + ] + ) + + Bu_elements = jnp.concatenate( + [ + recurrent_state, + Bu_elements, + ] + ) + + return (Lambda_elements, Bu_elements) + + def map_from_h(self, recurrent_states: RecurrentState, x: InputEmbedding) -> chex.Array: + + skip = x + + C = self.C_re + 1j * self.C_im + + # Use them to compute the output of the module + x = jax.vmap(lambda x, u: (C @ x).real + self.D * u)(recurrent_states, x) + + x = jax.nn.gelu(x) + o1 = self.out1(x) + x = o1 * jax.nn.sigmoid(self.out2(x)) # GLU + return skip + x # skip connection + + def scan(self, start, Lambda_elements, Bu_elements) -> RecurrentState: + + # Compute hidden states + _, _, xs = jax.lax.associative_scan( + wrapped_associative_update, (start, Lambda_elements, Bu_elements) + ) + + return xs[1:] + + def __call__(self, recurrent_state: RecurrentState, inputs: Inputs): + """Forward pass of a LRU: h_t+1 = lambda * h_t + B x_t+1, y_t = Re[C h_t + D x_t]""" + + x, start = inputs + + (Lambda_elements, Bu_elements) = self.map_to_h(recurrent_state, x) + + start = start.reshape([-1, 1]) + start = jnp.concatenate([jnp.zeros_like(start[:1]), start], axis=0) + + new_recurrent_states = self.scan(start, Lambda_elements, Bu_elements) + + outputs = self.map_from_h(new_recurrent_states, x) + + return new_recurrent_states[None, -1], outputs + + @nn.nowrap + def initialize_carry(self, batch_size: int) -> RecurrentState: + return jnp.zeros((1, self.d_hidden), dtype=jnp.complex64) + + +if __name__ == "__main__": + LRUModel = LRUCell(d_model=2, d_hidden=4) + + m = LRUModel + + batch_size = 1 + time_steps = 10 + + y = jnp.ones((time_steps, 2)) + s = m.initialize_carry(batch_size) + start = jnp.zeros((time_steps,), dtype=bool) + params = m.init(jax.random.PRNGKey(0), s, (y, start)) + out_state, out = m.apply(params, s, (y, start)) + + print(out) diff --git a/stoix/networks/memoroids/old_code1.py b/stoix/networks/memoroids/old_code1.py new file mode 100644 index 00000000..817d2275 --- /dev/null +++ b/stoix/networks/memoroids/old_code1.py @@ -0,0 +1,357 @@ +from functools import partial +from typing import Any, List, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp + + +def recurrent_associative_scan( + cell: nn.Module, + state: jax.Array, + inputs: jax.Array, + axis: int = 0, +) -> jax.Array: + """Execute the associative scan to update the recurrent state. + + Note that we do a trick here by concatenating the previous state to the inputs. + This is allowed since the scan is associative. This ensures that the previous + recurrent state feeds information into the scan. Without this method, we need + separate methods for rollouts and training.""" + + # Concatenate the prevous state to the inputs and scan over the result + # This ensures the previous recurrent state contributes to the current batch + # state: [start, (x, j)] + # inputs: [start, (x, j)] + scan_inputs = jax.tree.map(lambda x, s: jnp.concatenate([s, x], axis=0), inputs, state) + new_state = jax.lax.associative_scan( + cell, + scan_inputs, + axis=axis, + ) + # The zeroth index corresponds to the previous recurrent state + # We just use it to ensure continuity + # We do not actually want to use these values, so slice them away + return jax.tree.map(lambda x: x[1:], new_state) + + +class Gate(nn.Module): + """Sigmoidal gating""" + + output_size: int + + @nn.compact + def __call__(self, x): + x = nn.Dense(self.output_size)(x) + x = nn.sigmoid(x) + return x + + +def init_deterministic( + memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 +) -> Tuple[jax.Array, jax.Array]: + """Deterministic initialization of the FFM parameters.""" + a_low = 1e-6 + a_high = 0.5 + a = jnp.linspace(a_low, a_high, memory_size) + b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) + return a, b + + +def init_random( + memory_size: int, context_size: int, min_period: int = 1, max_period: int = 10_000, *, key +) -> Tuple[jax.Array, jax.Array]: + _, k1, k2 = jax.random.split(key, 3) + a_low = 1e-6 + a_high = 0.1 + a = jax.random.uniform(k1, (memory_size,), minval=a_low, maxval=a_high) + b = ( + 2 + * jnp.pi + / jnp.exp( + jax.random.uniform( + k2, (context_size,), minval=jnp.log(min_period), maxval=jnp.log(max_period) + ) + ) + ) + return a, b + + +class FFMCell(nn.Module): + """The binary associative update function for the FFM.""" + + trace_size: int + context_size: int + output_size: int + deterministic_init: bool = True + + def setup(self): + if self.deterministic_init: + a, b = init_deterministic(self.trace_size, self.context_size) + else: + # TODO: Will this result in the same keys for multiple FFMCells? + key = self.make_rng("ffa_params") + a, b = init_random(self.trace_size, self.context_size, key=key) + self.params = (self.param("ffa_a", lambda rng: a), self.param("ffa_b", lambda rng: b)) + + def log_gamma(self, t: jax.Array) -> jax.Array: + a, b = self.params + a = -jnp.abs(a).reshape((1, self.trace_size, 1)) + b = b.reshape(1, 1, self.context_size) + ab = jax.lax.complex(a, b) + return ab * t.reshape(t.shape[0], 1, 1) + + def gamma(self, t: jax.Array) -> jax.Array: + return jnp.exp(self.log_gamma(t)) + + def initialize_carry(self, batch_size: int = None): + if batch_size is None: + return jnp.zeros( + (1, self.trace_size, self.context_size), dtype=jnp.complex64 + ), jnp.ones((1,), dtype=jnp.int32) + + return jnp.zeros( + (1, batch_size, self.trace_size, self.context_size), dtype=jnp.complex64 + ), jnp.ones((1, batch_size), dtype=jnp.int32) + + def __call__(self, carry, incoming): + ( + state, + i, + ) = carry + x, j = incoming + state = state * self.gamma(j) + x + return state, j + i + + +class MemoroidResetWrapper(nn.Module): + """A wrapper around memoroid cells like FFM, LRU, etc that resets + the recurrent state upon a reset signal.""" + + cell: nn.Module + + def __call__(self, carry, incoming): + states, prev_start = carry + xs, start = incoming + + def reset_state(start, current_state, initial_state): + # Expand to reset all dims of state: [B, 1, 1, ...] + expanded_start = start.reshape(-1, *([1] * (current_state.ndim - 1))) + out = current_state * jnp.logical_not(expanded_start) + initial_state + return out + + initial_states = self.cell.initialize_carry() + states = jax.tree.map(partial(reset_state, start), states, initial_states) + out = self.cell(states, xs) + start_carry = jnp.logical_or(start, prev_start) + + return out, start_carry + + def initialize_carry(self, batch_size: int = None): + if batch_size is None: + # TODO: Should this be one or zero? + return self.cell.initialize_carry(batch_size), jnp.zeros((1,), dtype=bool) + + return self.cell.initialize_carry(batch_size), jnp.zeros((batch_size,), dtype=bool) + + +class FFM(nn.Module): + """Fast and Forgetful Memory""" + + trace_size: int + context_size: int + output_size: int + cell: nn.Module + + def setup(self): + self.pre = nn.Dense(self.trace_size) + self.gate_in = Gate(self.trace_size) + self.ffa = FFMCell(self.trace_size, self.context_size, self.output_size) + self.gate_out = Gate(self.output_size) + self.skip = nn.Dense(self.output_size) + self.mix = nn.Dense(self.output_size) + self.ln = nn.LayerNorm(use_scale=False, use_bias=False) + + def map_to_h(self, inputs): + """Map from the input space to the recurrent state space""" + x, resets = inputs + gate_in = self.gate_in(x) + pre = self.pre(x) + gated_x = pre * gate_in + # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous + ts = jnp.ones(x.shape[0], dtype=jnp.int32) + z = jnp.repeat(jnp.expand_dims(gated_x, 2), self.context_size, axis=2) + return (z, ts), resets + + def map_from_h(self, recurrent_state, inputs): + """Map from the recurrent space to the Markov space""" + (state, ts), reset = recurrent_state + (x, start) = inputs + z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape( + state.shape[0], -1 + ) + z = self.mix(z_in) + gate_out = self.gate_out(x) + skip = self.skip(x) + out = self.ln(z * gate_out) + skip * (1 - gate_out) + return out + + def __call__(self, recurrent_state, inputs): + # Recurrent state should be ((state, timestep), reset) + # Inputs should be (x, reset) + h = self.map_to_h(inputs) + recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h) + # recurrent_state is ((state, timestep), reset) + out = self.map_from_h(recurrent_state, inputs) + + # TODO: Remove this when we want to return all recurrent states instead of just the last one + final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) + return final_recurrent_state, out + + def initialize_carry(self, batch_size: int = None): + return self.cell.initialize_carry(batch_size) + + +class SFFM(nn.Module): + """Simplified Fast and Forgetful Memory""" + + trace_size: int + context_size: int + hidden_size: int + cell: nn.Module + + def setup(self): + self.W_trace = nn.Dense(self.trace_size) + self.W_context = Gate(self.context_size) + self.ffa = FFMCell( + self.trace_size, self.context_size, self.hidden_size, deterministic_init=False + ) + self.post = nn.Sequential( + [ + # Default init but with smaller weights + nn.Dense( + self.hidden_size, + kernel_init=nn.initializers.variance_scaling( + 0.01, "fan_in", "truncated_normal" + ), + ), + nn.LayerNorm(), + nn.leaky_relu, + nn.Dense(self.hidden_size), + nn.LayerNorm(), + nn.leaky_relu, + ] + ) + + def map_to_h(self, inputs): + x, resets = inputs + pre = jnp.abs(jnp.einsum("bi, bj -> bij", self.W_trace(x), self.W_context(x))) + pre = pre / jnp.sum(pre, axis=(-2, -1), keepdims=True) + # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous + ts = jnp.ones(x.shape[0], dtype=jnp.int32) + return (pre, ts), resets + + def map_from_h(self, recurrent_state, inputs): + x, resets = inputs + (state, ts), reset = recurrent_state + s = state.reshape(state.shape[0], self.context_size * self.trace_size) + eps = s.real + (s.real == 0 + jnp.sign(s.real)) * 0.01 + s = s + eps + scaled = jnp.concatenate( + [ + jnp.log(1 + jnp.abs(s)) * jnp.sin(jnp.angle(s)), + jnp.log(1 + jnp.abs(s)) * jnp.cos(jnp.angle(s)), + ], + axis=-1, + ) + z = self.post(scaled) + return z + + def __call__(self, recurrent_state, inputs): + # Recurrent state should be ((state, timestep), reset) + # Inputs should be (x, reset) + h = self.map_to_h(inputs) + recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h) + # recurrent_state is ((state, timestep), reset) + out = self.map_from_h(recurrent_state, inputs) + + # TODO: Remove this when we want to return all recurrent states instead of just the last one + final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) + return final_recurrent_state, out + + def initialize_carry(self, batch_size: int = None): + return self.cell.initialize_carry(batch_size) + + +class StackedSFFM(nn.Module): + """A multilayer version of SFFM""" + + cells: List[nn.Module] + + def setup(self): + self.project = nn.Dense(cells[0].hidden_size) + + def __call__(self, recurrent_state: jax.Array, inputs: Any) -> Tuple[jax.Array, jax.Array]: + x, start = inputs + x = self.project(x) + inputs = x, start + for i, cell in enumerate(self.cells): + s, y = cell(recurrent_state[i], inputs) + x = x + y + recurrent_state[i] = s + return y, recurrent_state + + def initialize_carry(self, batch_size: int = None): + return [c.initialize_carry(batch_size) for c in self.cells] + + +if __name__ == "__main__": + m = FFM( + output_size=4, + trace_size=5, + context_size=6, + cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)), + ) + s = m.initialize_carry() + x = jnp.ones((10, 2)) + start = jnp.zeros(10, dtype=bool) + params = m.init(jax.random.PRNGKey(0), s, (x, start)) + out_state, out = m.apply(params, s, (x, start)) + + # BatchFFM = nn.vmap( + # FFM, in_axes=1, out_axes=1, variable_axes={"params": None}, split_rngs={"params": False} + # ) + + # m = BatchFFM( + # trace_size=4, + # context_size=5, + # output_size=6, + # cell=MemoroidResetWrapper(cell=FFMCell(4,5,6)) + # ) + + # s = m.initialize_carry(8) + # x = jnp.ones((10, 8, 2)) + # start = jnp.zeros((10, 8), dtype=bool) + # params = m.init(jax.random.PRNGKey(0), s, (x, start)) + # out_state, out = m.apply(params, s, (x, start)) + + # print(out.shape) + # print(out_state.shape) + + # TODO: Initialize cells with different random streams so the weights are not identical + cells = [ + SFFM( + trace_size=4, + context_size=5, + hidden_size=6, + cell=MemoroidResetWrapper(cell=FFMCell(4, 5, 6)), + ) + for i in range(3) + ] + s2fm = StackedSFFM(cells=cells) + + s = s2fm.initialize_carry() + x = jnp.ones((10, 2)) + start = jnp.zeros(10, dtype=bool) + params = s2fm.init(jax.random.PRNGKey(0), s, (x, start)) + out_state, out = s2fm.apply(params, s, (x, start)) diff --git a/stoix/networks/memoroids/memoroid.py b/stoix/networks/memoroids/old_code2.py similarity index 100% rename from stoix/networks/memoroids/memoroid.py rename to stoix/networks/memoroids/old_code2.py diff --git a/stoix/networks/memoroids/s5.py b/stoix/networks/memoroids/old_s5.py similarity index 100% rename from stoix/networks/memoroids/s5.py rename to stoix/networks/memoroids/old_s5.py diff --git a/stoix/networks/memoroids/working_demov2.py b/stoix/networks/memoroids/working_demov2.py deleted file mode 100644 index c9b9f3f9..00000000 --- a/stoix/networks/memoroids/working_demov2.py +++ /dev/null @@ -1,269 +0,0 @@ -from typing import Tuple - -import chex -import jax -import optax -from flax import linen as nn -from jax import numpy as jnp - -RecurrentState = chex.Array -Reset = chex.Array -Timestep = chex.Array -InputEmbedding = chex.Array -Inputs = Tuple[InputEmbedding, Reset] -ScanInput = chex.Array - - -def init_deterministic( - memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 -) -> Tuple[chex.Array, chex.Array]: - a_low = 1e-6 - a_high = 0.5 - a = jnp.linspace(a_low, a_high, memory_size) - b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) - return a, b - - -class Gate(nn.Module): - output_size: int - - @nn.compact - def __call__(self, x: chex.Array) -> chex.Array: - return jax.nn.sigmoid(nn.Dense(self.output_size)(x)) - - -class FFM(nn.Module): - trace_size: int - context_size: int - output_size: int - - def setup(self) -> None: - - # Create the FFM parameters - a, b = init_deterministic(self.trace_size, self.context_size) - self.a = self.param( - "ffm_a", - lambda key, shape: a, - (), - ) - self.b = self.param( - "ffm_b", - lambda key, shape: b, - (), - ) - - # Create the networks and parameters that are used when - # mapping from input space to recurrent state space - # This is used in the map_to_h method and is used in the - # associative scan outer loop - self.pre = nn.Dense(self.trace_size) - self.gate_in = Gate(self.trace_size) - self.gate_out = Gate(self.output_size) - self.skip = nn.Dense(self.output_size) - self.mix = nn.Dense(self.output_size) - self.ln = nn.LayerNorm(use_scale=False, use_bias=False) - - def map_to_h(self, x: InputEmbedding) -> ScanInput: - """Given an input embedding, this will map it to the format required for the associative scan.""" - gate_in = self.gate_in(x) - pre = self.pre(x) - gated_x = pre * gate_in - scan_input = jnp.repeat(jnp.expand_dims(gated_x, 3), self.context_size, axis=3) - return scan_input - - def map_from_h(self, state: RecurrentState, x: InputEmbedding) -> chex.Array: - """Given the recurrent state and the input embedding, this will map the recurrent state back to the output space.""" - T = state.shape[0] - B = state.shape[1] - z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape(T, B, -1) - z = self.mix(z_in) - gate_out = self.gate_out(x) - skip = self.skip(x) - out = self.ln(z * gate_out) + skip * (1 - gate_out) - return out - - def log_gamma(self, t: Timestep) -> chex.Array: - T = t.shape[0] - B = t.shape[1] - a = self.a - b = self.b - a = -jnp.abs(a).reshape((1, 1, self.trace_size, 1)) - b = b.reshape(1, 1, 1, self.context_size) - ab = jax.lax.complex(a, b) - return ab * t.reshape(T, B, 1, 1) - - def gamma(self, t: Timestep) -> chex.Array: - return jnp.exp(self.log_gamma(t)) - - def unwrapped_associative_update( - self, - carry: Tuple[RecurrentState, Timestep], - incoming: Tuple[InputEmbedding, Timestep], - ) -> Tuple[RecurrentState, Timestep]: - ( - state, - i, - ) = carry - x, j = incoming - state = state * self.gamma(j) + x - return state, j + i - - def wrapped_associative_update( - self, - carry: Tuple[Reset, RecurrentState, Timestep], - incoming: Tuple[Reset, InputEmbedding, Timestep], - ) -> Tuple[Reset, RecurrentState, Timestep]: - prev_start, state, i = carry - start, x, j = incoming - # Reset all elements in the carry if we are starting a new episode - state = state * jnp.logical_not(start) - j = j * jnp.logical_not(start) - incoming = x, j - carry = (state, i) - out = self.unwrapped_associative_update(carry, incoming) - start_out = jnp.logical_or(start, prev_start) - return (start_out, *out) - - def scan( - self, - x: InputEmbedding, - state: RecurrentState, - start: Reset, - ) -> RecurrentState: - """Given an input and recurrent state, this will update the recurrent state. This is equivalent - to the inner-function g in the paper.""" - # x: [T, B, memory_size] - # memory: [1, B, memory_size, context_size] - T = x.shape[0] - B = x.shape[1] - timestep = jnp.ones((T + 1, B), dtype=jnp.int32).reshape(T + 1, B, 1, 1) - # Add context dim - start = start.reshape(T, B, 1, 1) - - # Now insert previous recurrent state - x = jnp.concatenate([state, x], axis=0) - start = jnp.concatenate([jnp.zeros_like(start[:1]), start], axis=0) - - # This is not executed during inference -- method will just return x if size is 1 - _, new_state, _ = jax.lax.associative_scan( - self.wrapped_associative_update, - (start, x, timestep), - axis=0, - ) - return new_state[1:] - - @nn.compact - def __call__(self, state: RecurrentState, inputs: Inputs) -> Tuple[RecurrentState, chex.Array]: - - # Add a sequence dimension to the recurrent state. - state = jnp.expand_dims(state, 0) - - # Unpack inputs - x, start = inputs - - # Map the input embedding to the recurrent state space. - # This maps to the format required for the associative scan. - scan_input = self.map_to_h(x) - - # Update the recurrent state - state = self.scan(scan_input, state, start) - - # Map the recurrent state back to the output space - out = self.map_from_h(state, x) - - # Take the final state of the sequence. - final_state = state[-1:] - - # TODO: remove this when not running test - out = nn.Dense(128)(out) - out = nn.relu(out) - out = nn.Dense(1)(out) - - # Remove the sequence dimemnsion from the final state. - final_state = jnp.squeeze(final_state, 0) - - return final_state, out - - @nn.nowrap - def initialize_carry(self, batch_size: int) -> RecurrentState: - return jnp.zeros((batch_size, self.trace_size, self.context_size), dtype=jnp.complex64) - - -def train_memorize(): - - USE_BATCH_VERSION = True # Required to be true - - m = FFM(output_size=128, trace_size=64, context_size=4) - - batch_size = 16 - rem_ts = 10 - time_steps = rem_ts * 10 - obs_space = 8 - rng = jax.random.PRNGKey(0) - if USE_BATCH_VERSION: - x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space) - y = jnp.stack( - [ - jnp.repeat(x[::rem_ts, i], x.shape[0] // x[::rem_ts, i].shape[0]) - for i in range(batch_size) - ], - axis=-1, - ) - x = x.reshape(time_steps, batch_size, 1) - y = y.reshape(time_steps, batch_size, 1) - - start = jnp.zeros([time_steps, batch_size], dtype=bool).at[::rem_ts].set(True) - - s = m.initialize_carry(batch_size) - - params = m.init(jax.random.PRNGKey(0), s, (x, start)) - - def error(params, x, start, key): - s = m.initialize_carry(batch_size) - - # For BATCH VERSION - if USE_BATCH_VERSION: - x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space) - y = jnp.stack( - [ - jnp.repeat(x[::rem_ts, i], x.shape[0] // x[::rem_ts, i].shape[0]) - for i in range(batch_size) - ], - axis=-1, - ) - x = x.reshape(time_steps, batch_size, 1) - y = y.reshape(time_steps, batch_size, 1) - - final_state, y_hat = m.apply(params, s, (x, start)) - y_hat = jnp.squeeze(y_hat) - y = jnp.squeeze(y) - accuracy = (jnp.round(y_hat) == y).mean() - loss = jnp.mean(jnp.abs(y - y_hat) ** 2) - return loss, {"accuracy": accuracy, "loss": loss} - - optimizer = optax.adam(learning_rate=0.001) - state = optimizer.init(params) - loss_fn = jax.jit(jax.grad(error, has_aux=True)) - for step in range(10_000): - rng = jax.random.split(rng)[0] - grads, loss_info = loss_fn(params, x, start, rng) - updates, state = jax.jit(optimizer.update)(grads, state) - params = jax.jit(optax.apply_updates)(params, updates) - print(f"Step {step+1}, Loss: {loss_info['loss']}, Accuracy: {loss_info['accuracy']}") - - -if __name__ == "__main__": - # m = FFM( - # output_size=4, - # trace_size=5, - # context_size=6, - # ) - # s = m.initialize_carry() - # x = jnp.ones((10, 2)) - # start = jnp.zeros(10, dtype=bool) - # params = m.init(jax.random.PRNGKey(0), x, s, start) - # out = m.apply(params, x, s, start) - - # print(out) - - train_memorize() From b21d5e26c4d9e4d42728a196a679f9ab7c9acfae Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Wed, 3 Jul 2024 22:25:00 +0000 Subject: [PATCH 32/38] chore: slight config change --- stoix/configs/arch/anakin.yaml | 2 +- stoix/configs/system/rec_ppo.yaml | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/stoix/configs/arch/anakin.yaml b/stoix/configs/arch/anakin.yaml index f6092512..5b74ce41 100644 --- a/stoix/configs/arch/anakin.yaml +++ b/stoix/configs/arch/anakin.yaml @@ -13,6 +13,6 @@ evaluation_greedy: False # Evaluate the policy greedily. If True the policy will # an action which corresponds to the greatest logit. If false, the policy will sample # from the logits. num_eval_episodes: 128 # Number of episodes to evaluate per evaluation. -num_evaluation: 50 # Number of evenly spaced evaluations to perform during training. +num_evaluation: 19 # Number of evenly spaced evaluations to perform during training. absolute_metric: True # Whether the absolute metric should be computed. For more details # on the absolute metric please see: https://arxiv.org/abs/2209.10485 diff --git a/stoix/configs/system/rec_ppo.yaml b/stoix/configs/system/rec_ppo.yaml index 75b25f12..cb6f5c0c 100644 --- a/stoix/configs/system/rec_ppo.yaml +++ b/stoix/configs/system/rec_ppo.yaml @@ -3,11 +3,11 @@ system_name: rec_ppo # Name of the system. # --- RL hyperparameters --- -actor_lr: 1e-5 # Learning rate for actor network -critic_lr: 1e-5 # Learning rate for critic network +actor_lr: 3e-5 # Learning rate for actor network +critic_lr: 3e-5 # Learning rate for critic network rollout_length: 256 # Number of environment steps per vectorised environment. -epochs: 15 # Number of ppo epochs per training data batch. -num_minibatches: 32 # Number of minibatches per ppo epoch. +epochs: 10 # Number of ppo epochs per training data batch. +num_minibatches: 64 # Number of minibatches per ppo epoch. gamma: 0.99 # Discounting factor. gae_lambda: 0.95 # Lambda value for GAE computation. clip_eps: 0.2 # Clipping value for PPO updates and value function. From 0cfb8320f2fc4c6f8ec43ec2da73fae52371dbc3 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Thu, 4 Jul 2024 14:23:09 +0000 Subject: [PATCH 33/38] chore: more editing and add s5 --- stoix/configs/arch/anakin.yaml | 2 +- stoix/configs/network/lru.yaml | 35 ++ stoix/configs/network/s5.yaml | 14 +- stoix/networks/memoroids/base.py | 4 +- stoix/networks/memoroids/ffm.py | 105 ++-- stoix/networks/memoroids/lru.py | 183 +++---- .../memoroids/{ => old_code}/old_code1.py | 0 .../memoroids/{ => old_code}/old_code2.py | 0 .../memoroids/{ => old_code}/old_s5.py | 0 stoix/networks/memoroids/s5.py | 464 ++++++++++++++++++ stoix/systems/ppo/rec_ppo.py | 17 +- 11 files changed, 655 insertions(+), 169 deletions(-) create mode 100644 stoix/configs/network/lru.yaml rename stoix/networks/memoroids/{ => old_code}/old_code1.py (100%) rename stoix/networks/memoroids/{ => old_code}/old_code2.py (100%) rename stoix/networks/memoroids/{ => old_code}/old_s5.py (100%) create mode 100644 stoix/networks/memoroids/s5.py diff --git a/stoix/configs/arch/anakin.yaml b/stoix/configs/arch/anakin.yaml index 5b74ce41..7f3cea10 100644 --- a/stoix/configs/arch/anakin.yaml +++ b/stoix/configs/arch/anakin.yaml @@ -9,7 +9,7 @@ total_timesteps: 1e7 # Set the total environment steps. num_updates: ~ # Number of updates # --- Evaluation --- -evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select +evaluation_greedy: True # Evaluate the policy greedily. If True the policy will select # an action which corresponds to the greatest logit. If false, the policy will sample # from the logits. num_eval_episodes: 128 # Number of episodes to evaluate per evaluation. diff --git a/stoix/configs/network/lru.yaml b/stoix/configs/network/lru.yaml new file mode 100644 index 00000000..f03b0288 --- /dev/null +++ b/stoix/configs/network/lru.yaml @@ -0,0 +1,35 @@ +# ---Recurrent Structure Networks for PPO --- + +actor_network: + pre_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256] + use_layer_norm: True + activation: leaky_relu + rnn_layer: + _target_: stoix.networks.memoroids.lru.LRUCell + hidden_state_dim: 512 + post_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256] + use_layer_norm: False + activation: leaky_relu + action_head: + _target_: stoix.networks.heads.CategoricalHead + +critic_network: + pre_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256] + use_layer_norm: True + activation: leaky_relu + rnn_layer: + _target_: stoix.networks.memoroids.lru.LRUCell + hidden_state_dim: 512 + post_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256] + use_layer_norm: False + activation: leaky_relu + critic_head: + _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/configs/network/s5.yaml b/stoix/configs/network/s5.yaml index 8abbb886..04426503 100644 --- a/stoix/configs/network/s5.yaml +++ b/stoix/configs/network/s5.yaml @@ -7,10 +7,9 @@ actor_network: use_layer_norm: False activation: leaky_relu rnn_layer: - _target_: stoix.networks.old_s5.StackedEncoderModel - ssm_size: 256 - d_model: 256 - n_layers: 1 + _target_: stoix.networks.memoroids.s5.S5Cell + d_model : 256 + state_size: 256 post_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] @@ -26,10 +25,9 @@ critic_network: use_layer_norm: False activation: leaky_relu rnn_layer: - _target_: stoix.networks.old_s5.StackedEncoderModel - ssm_size: 256 - d_model: 256 - n_layers: 1 + _target_: stoix.networks.memoroids.s5.S5Cell + d_model : 256 + state_size: 256 post_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] diff --git a/stoix/networks/memoroids/base.py b/stoix/networks/memoroids/base.py index 71214ecc..6fa9a03a 100644 --- a/stoix/networks/memoroids/base.py +++ b/stoix/networks/memoroids/base.py @@ -94,10 +94,10 @@ def __call__( assert len(all_states) == len( self.cells ), f"Expected {len(self.cells)} states, got {len(all_states)}" - + x, starts = inputs new_states = [] for cell, mem_state in zip(self.cells, all_states): - new_mem_state, x = cell(mem_state, x) + new_mem_state, x = cell(mem_state, (x, starts)) new_states.append(new_mem_state) return new_states, x diff --git a/stoix/networks/memoroids/ffm.py b/stoix/networks/memoroids/ffm.py index e0352c1c..3066d975 100644 --- a/stoix/networks/memoroids/ffm.py +++ b/stoix/networks/memoroids/ffm.py @@ -7,6 +7,7 @@ from stoix.networks.memoroids.base import ( InputEmbedding, + Inputs, RecurrentState, Reset, ScanInput, @@ -14,14 +15,26 @@ ) -def init_deterministic( - memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 +def init_deterministic_a( + memory_size: int, ) -> Tuple[chex.Array, chex.Array]: - a_low = 1e-6 - a_high = 0.5 - a = jnp.linspace(a_low, a_high, memory_size) - b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) - return a, b + def init(key, shape): + a_low = 1e-6 + a_high = 0.5 + a = jnp.linspace(a_low, a_high, memory_size) + return a + + return init + + +def init_deterministic_b( + context_size: int, min_period: int = 1, max_period: int = 1_000 +) -> Tuple[chex.Array, chex.Array]: + def init(key, shape): + b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) + return b + + return init class Gate(nn.Module): @@ -37,36 +50,10 @@ class FFMCell(nn.Module): context_size: int output_size: int - def setup(self) -> None: - - # Create the FFM parameters - a, b = init_deterministic(self.trace_size, self.context_size) - self.a = self.param( - "ffm_a", - lambda key, shape: a, - (), - ) - self.b = self.param( - "ffm_b", - lambda key, shape: b, - (), - ) - - # Create the networks and parameters that are used when - # mapping from input space to recurrent state space - # This is used in the map_to_h method and is used in the - # associative scan outer loop - self.pre = nn.Dense(self.trace_size) - self.gate_in = Gate(self.trace_size) - self.gate_out = Gate(self.output_size) - self.skip = nn.Dense(self.output_size) - self.mix = nn.Dense(self.output_size) - self.ln = nn.LayerNorm(use_scale=False, use_bias=False) - def map_to_h(self, x: InputEmbedding) -> ScanInput: """Given an input embedding, this will map it to the format required for the associative scan.""" - gate_in = self.gate_in(x) - pre = self.pre(x) + gate_in = Gate(self.trace_size)(x) + pre = nn.Dense(self.trace_size)(x) gated_x = pre * gate_in scan_input = jnp.repeat(jnp.expand_dims(gated_x, 3), self.context_size, axis=3) return scan_input @@ -76,17 +63,26 @@ def map_from_h(self, state: RecurrentState, x: InputEmbedding) -> chex.Array: T = state.shape[0] B = state.shape[1] z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape(T, B, -1) - z = self.mix(z_in) - gate_out = self.gate_out(x) - skip = self.skip(x) - out = self.ln(z * gate_out) + skip * (1 - gate_out) + z = nn.Dense(self.output_size)(z_in) + gate_out = Gate(self.output_size)(x) + skip = nn.Dense(self.output_size)(x) + out = nn.LayerNorm(use_scale=False, use_bias=False)(z * gate_out) + skip * (1 - gate_out) return out def log_gamma(self, t: Timestep) -> chex.Array: T = t.shape[0] B = t.shape[1] - a = self.a - b = self.b + + a = self.param( + "ffm_a", + init_deterministic_a(self.trace_size), + (), + ) + b = self.param( + "ffm_b", + init_deterministic_b(self.context_size), + (), + ) a = -jnp.abs(a).reshape((1, 1, self.trace_size, 1)) b = b.reshape(1, 1, 1, self.context_size) ab = jax.lax.complex(a, b) @@ -152,6 +148,33 @@ def scan( ) return new_state[1:] + @nn.compact + def __call__(self, state: RecurrentState, inputs: Inputs) -> Tuple[RecurrentState, chex.Array]: + + # Add a sequence dimension to the recurrent state. + state = jnp.expand_dims(state, 0) + + # Unpack inputs + x, start = inputs + + # Map the input embedding to the recurrent state space. + # This maps to the format required for the associative scan. + scan_input = self.map_to_h(x) + + # Update the recurrent state + state = self.scan(scan_input, state, start) + + # Map the recurrent state back to the output space + out = self.map_from_h(state, x) + + # Take the final state of the sequence. + final_state = state[-1:] + + # Remove the sequence dimemnsion from the final state. + final_state = jnp.squeeze(final_state, 0) + + return final_state, out + @nn.nowrap def initialize_carry(self, batch_size: int) -> RecurrentState: return jnp.zeros((batch_size, self.trace_size, self.context_size), dtype=jnp.complex64) diff --git a/stoix/networks/memoroids/lru.py b/stoix/networks/memoroids/lru.py index a53ae75d..1af3c471 100644 --- a/stoix/networks/memoroids/lru.py +++ b/stoix/networks/memoroids/lru.py @@ -1,10 +1,11 @@ -from typing import Tuple, Union +import functools +from functools import partial +from typing import Tuple import chex import jax import jax.numpy as jnp from flax import linen as nn -from flax.linen.initializers import Initializer from stoix.networks.memoroids.base import ( InputEmbedding, @@ -15,7 +16,6 @@ ScanInput, ) -# NOT WORKING YET # Parallel scan operations @jax.vmap @@ -47,46 +47,24 @@ def wrapped_associative_update(carry: chex.Array, incoming: chex.Array) -> Tuple return (start_out, *out) -def matrix_init(normalization: float = 1.0) -> Initializer: - def init( - key: chex.PRNGKey, shape: Tuple[int, ...], dtype: jnp.dtype = jnp.float32 - ) -> jnp.ndarray: - return jax.random.normal(key=key, shape=shape, dtype=dtype) / normalization +def matrix_init(key, shape, dtype=jnp.float32, normalization=1): + return jax.random.normal(key=key, shape=shape, dtype=dtype) / normalization - return init +def nu_init(key, shape, r_min, r_max, dtype=jnp.float32): + u = jax.random.uniform(key=key, shape=shape, dtype=dtype) + return jnp.log(-0.5 * jnp.log(u * (r_max**2 - r_min**2) + r_min**2)) -def nu_init(r_min: float, r_max: float) -> Initializer: - def init( - key: chex.PRNGKey, shape: Tuple[int, ...], dtype: jnp.dtype = jnp.float32 - ) -> jnp.ndarray: - u = jax.random.uniform(key=key, shape=shape, dtype=dtype) - return jnp.log(-0.5 * jnp.log(u * (r_max**2 - r_min**2) + r_min**2)) - return init +def theta_init(key, shape, max_phase, dtype=jnp.float32): + u = jax.random.uniform(key, shape=shape, dtype=dtype) + return jnp.log(max_phase * u) -def theta_init(max_phase: float) -> Initializer: - def init( - key: chex.PRNGKey, shape: Tuple[int, ...], dtype: jnp.dtype = jnp.float32 - ) -> jnp.ndarray: - u = jax.random.uniform(key, shape=shape, dtype=dtype) - return jnp.log(max_phase * u) - - return init - - -def gamma_log_init( - lamb: Tuple[Union[float, jnp.ndarray], Union[float, jnp.ndarray]] -) -> Initializer: - def init( - key: chex.PRNGKey, shape: Tuple[int, ...], dtype: jnp.dtype = jnp.float32 - ) -> jnp.ndarray: - nu, theta = lamb - diag_lambda = jnp.exp(-jnp.exp(nu) + 1j * jnp.exp(theta)) - return jnp.log(jnp.sqrt(1 - jnp.abs(diag_lambda) ** 2)) - - return init +def gamma_log_init(key, lamb): + nu, theta = lamb + diag_lambda = jnp.exp(-jnp.exp(nu) + 1j * jnp.exp(theta)) + return jnp.log(jnp.sqrt(1 - jnp.abs(diag_lambda) ** 2)) class LRUCell(MemoroidCellBase): @@ -95,53 +73,38 @@ class LRUCell(MemoroidCellBase): Implementation following the one of Orvieto et al. 2023. """ - d_model: int # input and output dimensions - d_hidden: int # hidden state dimension + hidden_state_dim: int # hidden state dimension r_min: float = 0.0 # smallest lambda norm r_max: float = 1.0 # largest lambda norm max_phase: float = 6.28 # max phase lambda - def setup(self): - - self.theta_log = self.param("theta_log", theta_init(self.max_phase), (self.d_hidden,)) - self.nu_log = self.param("nu_log", nu_init(self.r_min, self.r_max), (self.d_hidden,)) - self.gamma_log = self.param( - "gamma_log", gamma_log_init((self.nu_log, self.theta_log)), (self.d_hidden,) + def map_to_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> ScanInput: + d_model = x.shape[-1] + theta_log = self.param( + "theta_log", partial(theta_init, max_phase=self.max_phase), (self.hidden_state_dim,) + ) + nu_log = self.param( + "nu_log", partial(nu_init, r_min=self.r_min, r_max=self.r_max), (self.hidden_state_dim,) ) + gamma_log = self.param("gamma_log", gamma_log_init, (nu_log, theta_log)) - self.B_re = self.param( + B_re = self.param( "B_re", - matrix_init(normalization=jnp.sqrt(2 * self.d_model)), - (self.d_hidden, self.d_model), + partial(matrix_init, normalization=jnp.sqrt(2 * d_model)), + (self.hidden_state_dim, d_model), ) - self.B_im = self.param( + + B_im = self.param( "B_im", - matrix_init(normalization=jnp.sqrt(2 * self.d_model)), - (self.d_hidden, self.d_model), + partial(matrix_init, normalization=jnp.sqrt(2 * d_model)), + (self.hidden_state_dim, d_model), ) - self.C_re = self.param( - "C_re", - matrix_init(normalization=jnp.sqrt(self.d_hidden)), - (self.d_model, self.d_hidden), - ) - self.C_im = self.param( - "C_im", - matrix_init(normalization=jnp.sqrt(self.d_hidden)), - (self.d_model, self.d_hidden), - ) - self.D = self.param("D", matrix_init(normalization=1), (self.d_model,)) - - self.normalization = nn.LayerNorm() - self.out1 = nn.Dense(self.d_model) - self.out2 = nn.Dense(self.d_model) - def map_to_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> ScanInput: - x = self.normalization(x) - diag_lambda = jnp.exp(-jnp.exp(self.nu_log) + 1j * jnp.exp(self.theta_log)) - B_norm = (self.B_re + 1j * self.B_im) * jnp.expand_dims(jnp.exp(self.gamma_log), axis=-1) + diag_lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log)) + B_norm = (B_re + 1j * B_im) * jnp.expand_dims(jnp.exp(gamma_log), axis=-1) Lambda_elements = jnp.repeat(diag_lambda[None, ...], x.shape[0], axis=0) - Bu_elements = jax.vmap(lambda u: B_norm @ u)(x.astype(jnp.complex64)) + Bu_elements = jax.vmap(lambda u: B_norm @ u)(x) Lambda_elements = jnp.concatenate( [ @@ -161,60 +124,70 @@ def map_to_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> ScanIn def map_from_h(self, recurrent_states: RecurrentState, x: InputEmbedding) -> chex.Array: - skip = x + d_model = x.shape[-1] + C_re = self.param( + "C_re", + partial(matrix_init, normalization=jnp.sqrt(self.hidden_state_dim)), + (d_model, self.hidden_state_dim), + ) + C_im = self.param( + "C_im", + partial(matrix_init, normalization=jnp.sqrt(self.hidden_state_dim)), + (d_model, self.hidden_state_dim), + ) + D = self.param("D", matrix_init, (d_model,)) - C = self.C_re + 1j * self.C_im + skip = x # Use them to compute the output of the module - x = jax.vmap(lambda x, u: (C @ x).real + self.D * u)(recurrent_states, x) + C = C_re + 1j * C_im + x = jax.vmap(lambda h, x: (C @ h).real + D * x)(recurrent_states, x) - x = jax.nn.gelu(x) - o1 = self.out1(x) - x = o1 * jax.nn.sigmoid(self.out2(x)) # GLU + x = nn.gelu(x) + x = nn.Dense(d_model)(x) * jax.nn.sigmoid(nn.Dense(d_model)(x)) # GLU return skip + x # skip connection - def scan(self, start, Lambda_elements, Bu_elements) -> RecurrentState: - + def scan( + self, start: Reset, Lambda_elements: chex.Array, Bu_elements: chex.Array + ) -> RecurrentState: + start = start.reshape([-1, 1]) + start = jnp.concatenate([jnp.zeros_like(start[:1]), start], axis=0) # Compute hidden states _, _, xs = jax.lax.associative_scan( wrapped_associative_update, (start, Lambda_elements, Bu_elements) ) - return xs[1:] - def __call__(self, recurrent_state: RecurrentState, inputs: Inputs): + @functools.partial( + nn.vmap, + variable_axes={"params": None}, + in_axes=(0, 1), + out_axes=(0, 1), + split_rngs={"params": False}, + ) + @nn.compact + def __call__( + self, recurrent_state: RecurrentState, inputs: Inputs + ) -> Tuple[RecurrentState, chex.Array]: """Forward pass of a LRU: h_t+1 = lambda * h_t + B x_t+1, y_t = Re[C h_t + D x_t]""" - x, start = inputs + # Add a sequence dimension to the recurrent state + recurrent_state = jnp.expand_dims(recurrent_state, 0) + + x, starts = inputs (Lambda_elements, Bu_elements) = self.map_to_h(recurrent_state, x) - start = start.reshape([-1, 1]) - start = jnp.concatenate([jnp.zeros_like(start[:1]), start], axis=0) + # Compute hidden states + hidden_states = self.scan(starts, Lambda_elements, Bu_elements) - new_recurrent_states = self.scan(start, Lambda_elements, Bu_elements) + outputs = self.map_from_h(hidden_states, x) - outputs = self.map_from_h(new_recurrent_states, x) + # Already has sequence dim removed + new_hidden_state = hidden_states[-1] - return new_recurrent_states[None, -1], outputs + return new_hidden_state, outputs @nn.nowrap def initialize_carry(self, batch_size: int) -> RecurrentState: - return jnp.zeros((1, self.d_hidden), dtype=jnp.complex64) - - -if __name__ == "__main__": - LRUModel = LRUCell(d_model=2, d_hidden=4) - - m = LRUModel - - batch_size = 1 - time_steps = 10 - - y = jnp.ones((time_steps, 2)) - s = m.initialize_carry(batch_size) - start = jnp.zeros((time_steps,), dtype=bool) - params = m.init(jax.random.PRNGKey(0), s, (y, start)) - out_state, out = m.apply(params, s, (y, start)) - - print(out) + return jnp.zeros((batch_size, self.hidden_state_dim), dtype=jnp.complex64) diff --git a/stoix/networks/memoroids/old_code1.py b/stoix/networks/memoroids/old_code/old_code1.py similarity index 100% rename from stoix/networks/memoroids/old_code1.py rename to stoix/networks/memoroids/old_code/old_code1.py diff --git a/stoix/networks/memoroids/old_code2.py b/stoix/networks/memoroids/old_code/old_code2.py similarity index 100% rename from stoix/networks/memoroids/old_code2.py rename to stoix/networks/memoroids/old_code/old_code2.py diff --git a/stoix/networks/memoroids/old_s5.py b/stoix/networks/memoroids/old_code/old_s5.py similarity index 100% rename from stoix/networks/memoroids/old_s5.py rename to stoix/networks/memoroids/old_code/old_s5.py diff --git a/stoix/networks/memoroids/s5.py b/stoix/networks/memoroids/s5.py new file mode 100644 index 00000000..67523fc4 --- /dev/null +++ b/stoix/networks/memoroids/s5.py @@ -0,0 +1,464 @@ +import functools +from functools import partial + +import chex +import jax +import jax.numpy as np +import jax.numpy as jnp +import optax +from flax import linen as nn +from jax import random +from jax.nn.initializers import lecun_normal, normal +from jax.numpy.linalg import eigh + +from stoix.networks.memoroids.base import ( + InputEmbedding, + Inputs, + MemoroidCellBase, + RecurrentState, +) + + +def log_step_initializer(dt_min=0.001, dt_max=0.1): + """Initialize the learnable timescale Delta by sampling + uniformly between dt_min and dt_max. + Args: + dt_min (float32): minimum value + dt_max (float32): maximum value + Returns: + init function + """ + + def init(key, shape): + """Init function + Args: + key: jax random key + shape tuple: desired shape + Returns: + sampled log_step (float32) + """ + return random.uniform(key, shape) * (np.log(dt_max) - np.log(dt_min)) + np.log(dt_min) + + return init + + +def init_log_steps(key, input): + """Initialize an array of learnable timescale parameters + Args: + key: jax random key + input: tuple containing the array shape H and + dt_min and dt_max + Returns: + initialized array of timescales (float32): (H,) + """ + H, dt_min, dt_max = input + log_steps = [] + for i in range(H): + key, skey = random.split(key) + log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,)) + log_steps.append(log_step) + + return np.array(log_steps) + + +def init_VinvB(init_fun, rng, shape, Vinv): + """Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B. + Note we will parameterize this with two different matrices for complex + numbers. + Args: + init_fun: the initialization function to use, e.g. lecun_normal() + rng: jax random key to be used with init function. + shape (tuple): desired shape (P,H) + Vinv: (complex64) the inverse eigenvectors used for initialization + Returns: + B_tilde (complex64) of shape (P,H,2) + """ + B = init_fun(rng, shape) + VinvB = Vinv @ B + VinvB_real = VinvB.real + VinvB_imag = VinvB.imag + return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1) + + +def trunc_standard_normal(key, shape): + """Sample C with a truncated normal distribution with standard deviation 1. + Args: + key: jax random key + shape (tuple): desired shape, of length 3, (H,P,_) + Returns: + sampled C matrix (float32) of shape (H,P,2) (for complex parameterization) + """ + H, P, _ = shape + Cs = [] + for i in range(H): + key, skey = random.split(key) + C = lecun_normal()(skey, shape=(1, P, 2)) + Cs.append(C) + return np.array(Cs)[:, 0] + + +def init_CV(init_fun, rng, shape, V): + """Initialize C_tilde=CV. First sample C. Then compute CV. + Note we will parameterize this with two different matrices for complex + numbers. + Args: + init_fun: the initialization function to use, e.g. lecun_normal() + rng: jax random key to be used with init function. + shape (tuple): desired shape (H,P) + V: (complex64) the eigenvectors used for initialization + Returns: + C_tilde (complex64) of shape (H,P,2) + """ + C_ = init_fun(rng, shape) + C = C_[..., 0] + 1j * C_[..., 1] + CV = C @ V + CV_real = CV.real + CV_imag = CV.imag + return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1) + + +# Discretization functions +def discretize_bilinear(Lambda, B_tilde, Delta): + """Discretize a diagonalized, continuous-time linear SSM + using bilinear transform method. + Args: + Lambda (complex64): diagonal state matrix (P,) + B_tilde (complex64): input matrix (P, H) + Delta (float32): discretization step sizes (P,) + Returns: + discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) + """ + Identity = np.ones(Lambda.shape[0]) + + BL = 1 / (Identity - (Delta / 2.0) * Lambda) + Lambda_bar = BL * (Identity + (Delta / 2.0) * Lambda) + B_bar = (BL * Delta)[..., None] * B_tilde + return Lambda_bar, B_bar + + +def discretize_zoh(Lambda, B_tilde, Delta): + """Discretize a diagonalized, continuous-time linear SSM + using zero-order hold method. + Args: + Lambda (complex64): diagonal state matrix (P,) + B_tilde (complex64): input matrix (P, H) + Delta (float32): discretization step sizes (P,) + Returns: + discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) + """ + Identity = np.ones(Lambda.shape[0]) + Lambda_bar = np.exp(Lambda * Delta) + B_bar = (1 / Lambda * (Lambda_bar - Identity))[..., None] * B_tilde + return Lambda_bar, B_bar + + +# Parallel scan operations +@jax.vmap +def binary_operator_reset(q_i, q_j): + """Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A. + Args: + q_i: tuple containing A_i and Bu_i at position i (P,), (P,) + q_j: tuple containing A_j and Bu_j at position j (P,), (P,) + Returns: + new element ( A_out, Bu_out ) + """ + A_i, b_i, c_i = q_i + A_j, b_j, c_j = q_j + return ( + (A_j * A_i) * (1 - c_j) + A_j * c_j, + (A_j * b_i + b_j) * (1 - c_j) + b_j * c_j, + c_i * (1 - c_j) + c_j, + ) + + +def make_HiPPO(N): + """Create a HiPPO-LegS matrix. + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Args: + N (int32): state size + Returns: + N x N HiPPO LegS matrix + """ + P = np.sqrt(1 + 2 * np.arange(N)) + A = P[:, np.newaxis] * P[np.newaxis, :] + A = np.tril(A) - np.diag(np.arange(N)) + return -A + + +def make_NPLR_HiPPO(N): + """ + Makes components needed for NPLR representation of HiPPO-LegS + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Args: + N (int32): state size + Returns: + N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B + """ + # Make -HiPPO + hippo = make_HiPPO(N) + + # Add in a rank 1 term. Makes it Normal. + P = np.sqrt(np.arange(N) + 0.5) + + # HiPPO also specifies the B matrix + B = np.sqrt(2 * np.arange(N) + 1.0) + return hippo, P, B + + +def make_DPLR_HiPPO(N): + """ + Makes components needed for DPLR representation of HiPPO-LegS + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Note, we will only use the diagonal part + Args: + N: + Returns: + eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B, + eigenvectors V, HiPPO B pre-conjugation + """ + A, P, B = make_NPLR_HiPPO(N) + + S = A + P[:, np.newaxis] * P[np.newaxis, :] + + S_diag = np.diagonal(S) + Lambda_real = np.mean(S_diag) * np.ones_like(S_diag) + + # Diagonalize S to V \Lambda V^* + Lambda_imag, V = eigh(S * -1j) + + P = V.conj().T @ P + B_orig = B + B = V.conj().T @ B + return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig + + +class S5Cell(MemoroidCellBase): + d_model: int + state_size: int + blocks: int = 1 + + activation: str = "gelu" + do_norm: bool = True + prenorm: bool = True + do_gtrxl_norm: bool = True + + C_init: str = "lecun_normal" + discretization: str = "zoh" + dt_min: float = 0.001 + dt_max: float = 0.1 + conj_sym: bool = True + clip_eigs: bool = False + bidirectional: bool = False + step_rescale: float = 1.0 + + def setup(self): + """Initializes parameters once and performs discretization each time + the SSM is applied to a sequence + """ + self.ssm_size = self.state_size * 2 + + block_size = int(self.ssm_size / self.blocks) + Lambda, _, _, V, _ = make_DPLR_HiPPO(self.ssm_size) + block_size = block_size // 2 + Lambda = Lambda[:block_size] + V = V[:, :block_size] + Vinv = V.conj().T + + self.H = self.d_model + self.P = self.state_size + self.Lambda_re_init = Lambda.real + self.Lambda_im_init = Lambda.imag + self.V = V + self.Vinv = Vinv + + if self.conj_sym: + # Need to account for case where we actually sample real B and C, and then multiply + # by the half sized Vinv and possibly V + local_P = 2 * self.P + else: + local_P = self.P + + # Initialize diagonal state to state matrix Lambda (eigenvalues) + self.Lambda_re = self.param("Lambda_re", lambda rng, shape: self.Lambda_re_init, (None,)) + self.Lambda_im = self.param("Lambda_im", lambda rng, shape: self.Lambda_im_init, (None,)) + if self.clip_eigs: + self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im + else: + self.Lambda = self.Lambda_re + 1j * self.Lambda_im + + # Initialize input to state (B) matrix + B_init = lecun_normal() + B_shape = (local_P, self.H) + self.B = self.param( + "B", lambda rng, shape: init_VinvB(B_init, rng, shape, self.Vinv), B_shape + ) + B_tilde = self.B[..., 0] + 1j * self.B[..., 1] + + # Initialize state to output (C) matrix + if self.C_init in ["trunc_standard_normal"]: + C_init = trunc_standard_normal + C_shape = (self.H, local_P, 2) + elif self.C_init in ["lecun_normal"]: + C_init = lecun_normal() + C_shape = (self.H, local_P, 2) + elif self.C_init in ["complex_normal"]: + C_init = normal(stddev=0.5**0.5) + else: + raise NotImplementedError("C_init method {} not implemented".format(self.C_init)) + + if self.C_init in ["complex_normal"]: + if self.bidirectional: + C = self.param("C", C_init, (self.H, 2 * self.P, 2)) + self.C_tilde = C[..., 0] + 1j * C[..., 1] + + else: + C = self.param("C", C_init, (self.H, self.P, 2)) + self.C_tilde = C[..., 0] + 1j * C[..., 1] + + else: + if self.bidirectional: + self.C1 = self.param( + "C1", lambda rng, shape: init_CV(C_init, rng, shape, self.V), C_shape + ) + self.C2 = self.param( + "C2", lambda rng, shape: init_CV(C_init, rng, shape, self.V), C_shape + ) + + C1 = self.C1[..., 0] + 1j * self.C1[..., 1] + C2 = self.C2[..., 0] + 1j * self.C2[..., 1] + self.C_tilde = np.concatenate((C1, C2), axis=-1) + + else: + self.C = self.param( + "C", lambda rng, shape: init_CV(C_init, rng, shape, self.V), C_shape + ) + + self.C_tilde = self.C[..., 0] + 1j * self.C[..., 1] + + # Initialize feedthrough (D) matrix + self.D = self.param("D", normal(stddev=1.0), (self.H,)) + + # Initialize learnable discretization timescale value + self.log_step = self.param("log_step", init_log_steps, (self.P, self.dt_min, self.dt_max)) + step = self.step_rescale * np.exp(self.log_step[:, 0]) + + # Discretize + if self.discretization in ["zoh"]: + self.Lambda_bar, self.B_bar = discretize_zoh(self.Lambda, B_tilde, step) + elif self.discretization in ["bilinear"]: + self.Lambda_bar, self.B_bar = discretize_bilinear(self.Lambda, B_tilde, step) + else: + raise NotImplementedError( + "Discretization method {} not implemented".format(self.discretization) + ) + + if self.activation in ["full_glu"]: + self.out1 = nn.Dense(self.d_model) + self.out2 = nn.Dense(self.d_model) + elif self.activation in ["half_glu1", "half_glu2"]: + self.out2 = nn.Dense(self.d_model) + + self.norm = nn.LayerNorm() + + def map_to_h(self, recurrent_state: RecurrentState, x: Inputs): + + if self.prenorm and self.do_norm: + x = self.norm(x) + + Lambda_elements = self.Lambda_bar * jnp.ones((x.shape[0], self.Lambda_bar.shape[0])) + Bu_elements = jax.vmap(lambda u: self.B_bar @ u)(x) + + Lambda_elements = jnp.concatenate( + [ + jnp.ones((1, self.Lambda_bar.shape[0])), + Lambda_elements, + ] + ) + + Bu_elements = jnp.concatenate( + [ + recurrent_state, + Bu_elements, + ] + ) + + return (Lambda_elements, Bu_elements) + + def scan(self, resets, Lambda_elements, Bu_elements): + + resets = jnp.concatenate( + [ + jnp.zeros(1), + resets, + ] + ) + _, xs, _ = jax.lax.associative_scan( + binary_operator_reset, (Lambda_elements, Bu_elements, resets) + ) + xs = xs[1:] + + return xs + + def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding): + skip = x + if self.conj_sym: + x = jax.vmap(lambda x: 2 * (self.C_tilde @ x).real)(recurrent_state) + else: + x = jax.vmap(lambda x: (self.C_tilde @ x).real)(recurrent_state) + + # Add feedthrough matrix output Du; + Du = jax.vmap(lambda u: self.D * u)(x) + x = x + Du + + if self.do_gtrxl_norm: + x = self.norm(x) + + if self.activation in ["full_glu"]: + x = nn.gelu(x) + x = self.out1(x) * jax.nn.sigmoid(self.out2(x)) + elif self.activation in ["half_glu1"]: + x = nn.gelu(x) + x = x * jax.nn.sigmoid(self.out2(x)) + elif self.activation in ["half_glu2"]: + # Only apply GELU to the gate input + x1 = nn.gelu(x) + x = x * jax.nn.sigmoid(self.out2(x1)) + elif self.activation in ["gelu"]: + x = nn.gelu(x) + + x = skip + x + if not self.prenorm and self.do_norm: + x = self.norm(x) + + return x + + @functools.partial( + nn.vmap, + variable_axes={"params": None}, + in_axes=(0, 1), + out_axes=(0, 1), + split_rngs={"params": False}, + ) + @nn.compact + def __call__(self, recurrent_state: RecurrentState, inputs: Inputs): + + # Add a sequence dimension to the recurrent state + recurrent_state = jnp.expand_dims(recurrent_state, 0) + + x, starts = inputs + + (Lambda_elements, Bu_elements) = self.map_to_h(recurrent_state, x) + + # Compute hidden states + hidden_states = self.scan(starts, Lambda_elements, Bu_elements) + + outputs = self.map_from_h(hidden_states, x) + + # Already has sequence dim removed + new_hidden_state = hidden_states[-1] + + return new_hidden_state, outputs + + @nn.nowrap + def initialize_carry(self, batch_size): + return jnp.zeros((batch_size, self.state_size), dtype=jnp.complex64) diff --git a/stoix/systems/ppo/rec_ppo.py b/stoix/systems/ppo/rec_ppo.py index 3e00b41d..b6ddc45a 100644 --- a/stoix/systems/ppo/rec_ppo.py +++ b/stoix/systems/ppo/rec_ppo.py @@ -80,7 +80,7 @@ def _env_step( last_timestep, last_done, last_truncated, - hstates, + last_hstates, ) = learner_state key, policy_key = jax.random.split(key) @@ -96,10 +96,10 @@ def _env_step( # Run the network. policy_hidden_state, actor_policy = actor_apply_fn( - params.actor_params, hstates.policy_hidden_state, ac_in + params.actor_params, last_hstates.policy_hidden_state, ac_in ) critic_hidden_state, value = critic_apply_fn( - params.critic_params, hstates.critic_hidden_state, ac_in + params.critic_params, last_hstates.critic_hidden_state, ac_in ) # Sample action from the policy and squeeze out the batch dimension. @@ -128,7 +128,7 @@ def _env_step( timestep.reward, log_prob, last_timestep.observation, - hstates, + last_hstates, info, ) learner_state = RNNLearnerState( @@ -143,9 +143,6 @@ def _env_step( ) return learner_state, transition - # INITIALISE RNN STATE - initial_hstates = learner_state.hstates - # STEP ENVIRONMENT FOR ROLLOUT LENGTH learner_state, traj_batch = jax.lax.scan( _env_step, learner_state, None, config.system.rollout_length @@ -316,7 +313,6 @@ def _critic_loss_fn( ( params, opt_states, - init_hstates, traj_batch, advantages, targets, @@ -359,7 +355,6 @@ def _critic_loss_fn( update_state = ( params, opt_states, - init_hstates, traj_batch, advantages, targets, @@ -367,11 +362,9 @@ def _critic_loss_fn( ) return update_state, loss_info - init_hstates = jax.tree_util.tree_map(lambda x: x[None, :], initial_hstates) update_state = ( params, opt_states, - init_hstates, traj_batch, advantages, targets, @@ -383,7 +376,7 @@ def _critic_loss_fn( _update_epoch, update_state, None, config.system.epochs ) - params, opt_states, _, traj_batch, advantages, targets, key = update_state + params, opt_states, traj_batch, advantages, targets, key = update_state learner_state = RNNLearnerState( params, opt_states, From 232b7ca13f5e6adfc1975c646154bed80ba15ede Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Thu, 4 Jul 2024 15:05:43 +0000 Subject: [PATCH 34/38] chore: config and network edits --- stoix/configs/logger/base_logger.yaml | 2 +- stoix/configs/network/ffm.yaml | 24 +++----- stoix/configs/network/lru.yaml | 4 +- stoix/configs/network/rnn.yaml | 4 +- stoix/configs/network/s5.yaml | 4 +- stoix/networks/memoroids/base.py | 88 +++++++++++++-------------- stoix/networks/memoroids/ffm.py | 66 +++++++++++--------- 7 files changed, 99 insertions(+), 93 deletions(-) diff --git a/stoix/configs/logger/base_logger.yaml b/stoix/configs/logger/base_logger.yaml index 4d6ecf14..8ee91dfe 100644 --- a/stoix/configs/logger/base_logger.yaml +++ b/stoix/configs/logger/base_logger.yaml @@ -4,7 +4,7 @@ base_exp_path: results # Base path for logging. use_console: True # Whether to log to stdout. use_tb: False # Whether to use tensorboard logging. use_json: False # Whether to log marl-eval style to json files. -use_neptune: False # Whether to log to neptune.ai. +use_neptune: True # Whether to log to neptune.ai. use_wandb: False # Whether to log to wandb.ai. # --- Other logger kwargs --- diff --git a/stoix/configs/network/ffm.yaml b/stoix/configs/network/ffm.yaml index 30f0afc8..eb0b62e4 100644 --- a/stoix/configs/network/ffm.yaml +++ b/stoix/configs/network/ffm.yaml @@ -4,15 +4,13 @@ actor_network: pre_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] - use_layer_norm: False + use_layer_norm: True activation: leaky_relu rnn_layer: - _target_: stoix.networks.memoroids.base.ScannedMemoroid - cell: - _target_: stoix.networks.memoroids.ffm.FFMCell - trace_size: 32 - context_size: 4 - output_size: 256 + _target_: stoix.networks.memoroids.ffm.FFMCell + trace_size: 64 + context_size: 4 + output_size: 256 post_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] @@ -25,15 +23,13 @@ critic_network: pre_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] - use_layer_norm: False + use_layer_norm: True activation: leaky_relu rnn_layer: - _target_: stoix.networks.memoroids.base.ScannedMemoroid - cell: - _target_: stoix.networks.memoroids.ffm.FFMCell - trace_size: 32 - context_size: 4 - output_size: 256 + _target_: stoix.networks.memoroids.ffm.FFMCell + trace_size: 64 + context_size: 4 + output_size: 256 post_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] diff --git a/stoix/configs/network/lru.yaml b/stoix/configs/network/lru.yaml index f03b0288..247050e5 100644 --- a/stoix/configs/network/lru.yaml +++ b/stoix/configs/network/lru.yaml @@ -8,7 +8,7 @@ actor_network: activation: leaky_relu rnn_layer: _target_: stoix.networks.memoroids.lru.LRUCell - hidden_state_dim: 512 + hidden_state_dim: 256 post_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] @@ -25,7 +25,7 @@ critic_network: activation: leaky_relu rnn_layer: _target_: stoix.networks.memoroids.lru.LRUCell - hidden_state_dim: 512 + hidden_state_dim: 256 post_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] diff --git a/stoix/configs/network/rnn.yaml b/stoix/configs/network/rnn.yaml index 285b1bbb..5f920cdd 100644 --- a/stoix/configs/network/rnn.yaml +++ b/stoix/configs/network/rnn.yaml @@ -4,7 +4,7 @@ actor_network: pre_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] - use_layer_norm: False + use_layer_norm: True activation: leaky_relu rnn_layer: _target_: stoix.networks.recurrent.ScannedRNN @@ -22,7 +22,7 @@ critic_network: pre_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] - use_layer_norm: False + use_layer_norm: True activation: leaky_relu rnn_layer: _target_: stoix.networks.recurrent.ScannedRNN diff --git a/stoix/configs/network/s5.yaml b/stoix/configs/network/s5.yaml index 04426503..1ea232a8 100644 --- a/stoix/configs/network/s5.yaml +++ b/stoix/configs/network/s5.yaml @@ -4,7 +4,7 @@ actor_network: pre_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] - use_layer_norm: False + use_layer_norm: True activation: leaky_relu rnn_layer: _target_: stoix.networks.memoroids.s5.S5Cell @@ -22,7 +22,7 @@ critic_network: pre_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] - use_layer_norm: False + use_layer_norm: True activation: leaky_relu rnn_layer: _target_: stoix.networks.memoroids.s5.S5Cell diff --git a/stoix/networks/memoroids/base.py b/stoix/networks/memoroids/base.py index 6fa9a03a..d88c811b 100644 --- a/stoix/networks/memoroids/base.py +++ b/stoix/networks/memoroids/base.py @@ -45,63 +45,63 @@ def num_feature_axes(self) -> int: raise NotImplementedError -class ScannedMemoroid(nn.Module): - cell: MemoroidCellBase +# class ScannedMemoroid(nn.Module): +# cell: MemoroidCellBase - @nn.compact - def __call__(self, state: RecurrentState, inputs: Inputs) -> Tuple[RecurrentState, chex.Array]: +# @nn.compact +# def __call__(self, state: RecurrentState, inputs: Inputs) -> Tuple[RecurrentState, chex.Array]: - # Add a sequence dimension to the recurrent state. - state = jnp.expand_dims(state, 0) +# # Add a sequence dimension to the recurrent state. +# state = jnp.expand_dims(state, 0) - # Unpack inputs - x, start = inputs +# # Unpack inputs +# x, start = inputs - # Map the input embedding to the recurrent state space. - # This maps to the format required for the associative scan. - scan_input = self.cell.map_to_h(x) +# # Map the input embedding to the recurrent state space. +# # This maps to the format required for the associative scan. +# scan_input = self.cell.map_to_h(x) - # Update the recurrent state - state = self.cell.scan(scan_input, state, start) +# # Update the recurrent state +# state = self.cell.scan(scan_input, state, start) - # Map the recurrent state back to the output space - out = self.cell.map_from_h(state, x) +# # Map the recurrent state back to the output space +# out = self.cell.map_from_h(state, x) - # Take the final state of the sequence. - final_state = state[-1:] +# # Take the final state of the sequence. +# final_state = state[-1:] - # Remove the sequence dimemnsion from the final state. - final_state = jnp.squeeze(final_state, 0) +# # Remove the sequence dimemnsion from the final state. +# final_state = jnp.squeeze(final_state, 0) - return final_state, out +# return final_state, out - @nn.nowrap - def initialize_carry(self, batch_size: int) -> RecurrentState: - return self.cell.initialize_carry(batch_size) +# @nn.nowrap +# def initialize_carry(self, batch_size: int) -> RecurrentState: +# return self.cell.initialize_carry(batch_size) -class StackedMemoroid(nn.Module): - cells: Tuple[ScannedMemoroid] +# class StackedMemoroid(nn.Module): +# cells: Tuple[ScannedMemoroid] - @nn.compact - def __call__( - self, all_states: List[RecurrentState], inputs: Inputs - ) -> Tuple[RecurrentState, chex.Array]: - # Ensure all_states is a list - if not isinstance(all_states, list): - all_states = [all_states] +# @nn.compact +# def __call__( +# self, all_states: List[RecurrentState], inputs: Inputs +# ) -> Tuple[RecurrentState, chex.Array]: +# # Ensure all_states is a list +# if not isinstance(all_states, list): +# all_states = [all_states] - assert len(all_states) == len( - self.cells - ), f"Expected {len(self.cells)} states, got {len(all_states)}" - x, starts = inputs - new_states = [] - for cell, mem_state in zip(self.cells, all_states): - new_mem_state, x = cell(mem_state, (x, starts)) - new_states.append(new_mem_state) +# assert len(all_states) == len( +# self.cells +# ), f"Expected {len(self.cells)} states, got {len(all_states)}" +# x, starts = inputs +# new_states = [] +# for cell, mem_state in zip(self.cells, all_states): +# new_mem_state, x = cell(mem_state, (x, starts)) +# new_states.append(new_mem_state) - return new_states, x +# return new_states, x - @nn.nowrap - def initialize_carry(self, batch_size: int) -> List[RecurrentState]: - return [cell.initialize_carry(batch_size) for cell in self.cells] +# @nn.nowrap +# def initialize_carry(self, batch_size: int) -> List[RecurrentState]: +# return [cell.initialize_carry(batch_size) for cell in self.cells] diff --git a/stoix/networks/memoroids/ffm.py b/stoix/networks/memoroids/ffm.py index 3066d975..15bbc85e 100644 --- a/stoix/networks/memoroids/ffm.py +++ b/stoix/networks/memoroids/ffm.py @@ -23,20 +23,16 @@ def init(key, shape): a_high = 0.5 a = jnp.linspace(a_low, a_high, memory_size) return a - return init - def init_deterministic_b( - context_size: int, min_period: int = 1, max_period: int = 1_000 + context_size: int, min_period: int = 1, max_period: int = 1_000 ) -> Tuple[chex.Array, chex.Array]: def init(key, shape): b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) return b - return init - class Gate(nn.Module): output_size: int @@ -50,10 +46,35 @@ class FFMCell(nn.Module): context_size: int output_size: int - def map_to_h(self, x: InputEmbedding) -> ScanInput: + def setup(self) -> None: + + # Create the FFM parameters + self.a = self.param( + "ffm_a", + init_deterministic_a(self.trace_size), + (), + ) + self.b = self.param( + "ffm_b", + init_deterministic_b(self.context_size), + (), + ) + + # Create the networks and parameters that are used when + # mapping from input space to recurrent state space + # This is used in the map_to_h method and is used in the + # associative scan outer loop + self.pre = nn.Dense(self.trace_size) + self.gate_in = Gate(self.trace_size) + self.gate_out = Gate(self.output_size) + self.skip = nn.Dense(self.output_size) + self.mix = nn.Dense(self.output_size) + self.ln = nn.LayerNorm(use_scale=False, use_bias=False) + + def map_to_h(self, state : RecurrentState, x: InputEmbedding) -> ScanInput: """Given an input embedding, this will map it to the format required for the associative scan.""" - gate_in = Gate(self.trace_size)(x) - pre = nn.Dense(self.trace_size)(x) + gate_in = self.gate_in(x) + pre = self.pre(x) gated_x = pre * gate_in scan_input = jnp.repeat(jnp.expand_dims(gated_x, 3), self.context_size, axis=3) return scan_input @@ -63,28 +84,18 @@ def map_from_h(self, state: RecurrentState, x: InputEmbedding) -> chex.Array: T = state.shape[0] B = state.shape[1] z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape(T, B, -1) - z = nn.Dense(self.output_size)(z_in) - gate_out = Gate(self.output_size)(x) - skip = nn.Dense(self.output_size)(x) - out = nn.LayerNorm(use_scale=False, use_bias=False)(z * gate_out) + skip * (1 - gate_out) + z = self.mix(z_in) + gate_out = self.gate_out(x) + skip = self.skip(x) + out = self.ln(z * gate_out) + skip * (1 - gate_out) return out def log_gamma(self, t: Timestep) -> chex.Array: T = t.shape[0] B = t.shape[1] - - a = self.param( - "ffm_a", - init_deterministic_a(self.trace_size), - (), - ) - b = self.param( - "ffm_b", - init_deterministic_b(self.context_size), - (), - ) - a = -jnp.abs(a).reshape((1, 1, self.trace_size, 1)) - b = b.reshape(1, 1, 1, self.context_size) + + a = -jnp.abs(self.a).reshape((1, 1, self.trace_size, 1)) + b = self.b.reshape(1, 1, 1, self.context_size) ab = jax.lax.complex(a, b) return ab * t.reshape(T, B, 1, 1) @@ -148,7 +159,6 @@ def scan( ) return new_state[1:] - @nn.compact def __call__(self, state: RecurrentState, inputs: Inputs) -> Tuple[RecurrentState, chex.Array]: # Add a sequence dimension to the recurrent state. @@ -159,7 +169,7 @@ def __call__(self, state: RecurrentState, inputs: Inputs) -> Tuple[RecurrentStat # Map the input embedding to the recurrent state space. # This maps to the format required for the associative scan. - scan_input = self.map_to_h(x) + scan_input = self.map_to_h(state, x) # Update the recurrent state state = self.scan(scan_input, state, start) @@ -174,7 +184,7 @@ def __call__(self, state: RecurrentState, inputs: Inputs) -> Tuple[RecurrentStat final_state = jnp.squeeze(final_state, 0) return final_state, out - + @nn.nowrap def initialize_carry(self, batch_size: int) -> RecurrentState: return jnp.zeros((batch_size, self.trace_size, self.context_size), dtype=jnp.complex64) From 6dd59c571e0e7b6c7e20daaa37d48a9917fb315f Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Sun, 7 Jul 2024 13:39:43 +0000 Subject: [PATCH 35/38] feat: add stacked model --- stoix/configs/network/stacked_lrm.yaml | 45 +++++++++++++++++++++ stoix/networks/memoroids/ffm.py | 12 ++++-- stoix/networks/memoroids/layers.py | 56 ++++++++++++++++++++++++++ 3 files changed, 109 insertions(+), 4 deletions(-) create mode 100644 stoix/configs/network/stacked_lrm.yaml create mode 100644 stoix/networks/memoroids/layers.py diff --git a/stoix/configs/network/stacked_lrm.yaml b/stoix/configs/network/stacked_lrm.yaml new file mode 100644 index 00000000..c49270b8 --- /dev/null +++ b/stoix/configs/network/stacked_lrm.yaml @@ -0,0 +1,45 @@ +# ---Recurrent Structure Networks for PPO --- + +actor_network: + pre_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256] + use_layer_norm: True + activation: leaky_relu + rnn_layer: + _target_: stoix.networks.memoroids.layers.StackedMemoroid + num_cells: 2 + lrm_cell_type: ffm + cell_kwargs: + trace_size: 64 + context_size: 4 + output_size: 256 + post_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256] + use_layer_norm: False + activation: leaky_relu + action_head: + _target_: stoix.networks.heads.CategoricalHead + +critic_network: + pre_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256] + use_layer_norm: True + activation: leaky_relu + rnn_layer: + _target_: stoix.networks.memoroids.layers.StackedMemoroid + num_cells: 2 + lrm_cell_type: ffm + cell_kwargs: + trace_size: 64 + context_size: 4 + output_size: 256 + post_torso: + _target_: stoix.networks.torso.MLPTorso + layer_sizes: [256] + use_layer_norm: False + activation: leaky_relu + critic_head: + _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/networks/memoroids/ffm.py b/stoix/networks/memoroids/ffm.py index 15bbc85e..cfb0d40c 100644 --- a/stoix/networks/memoroids/ffm.py +++ b/stoix/networks/memoroids/ffm.py @@ -23,16 +23,20 @@ def init(key, shape): a_high = 0.5 a = jnp.linspace(a_low, a_high, memory_size) return a + return init + def init_deterministic_b( - context_size: int, min_period: int = 1, max_period: int = 1_000 + context_size: int, min_period: int = 1, max_period: int = 1_000 ) -> Tuple[chex.Array, chex.Array]: def init(key, shape): b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) return b + return init + class Gate(nn.Module): output_size: int @@ -71,7 +75,7 @@ def setup(self) -> None: self.mix = nn.Dense(self.output_size) self.ln = nn.LayerNorm(use_scale=False, use_bias=False) - def map_to_h(self, state : RecurrentState, x: InputEmbedding) -> ScanInput: + def map_to_h(self, state: RecurrentState, x: InputEmbedding) -> ScanInput: """Given an input embedding, this will map it to the format required for the associative scan.""" gate_in = self.gate_in(x) pre = self.pre(x) @@ -93,7 +97,7 @@ def map_from_h(self, state: RecurrentState, x: InputEmbedding) -> chex.Array: def log_gamma(self, t: Timestep) -> chex.Array: T = t.shape[0] B = t.shape[1] - + a = -jnp.abs(self.a).reshape((1, 1, self.trace_size, 1)) b = self.b.reshape(1, 1, 1, self.context_size) ab = jax.lax.complex(a, b) @@ -184,7 +188,7 @@ def __call__(self, state: RecurrentState, inputs: Inputs) -> Tuple[RecurrentStat final_state = jnp.squeeze(final_state, 0) return final_state, out - + @nn.nowrap def initialize_carry(self, batch_size: int) -> RecurrentState: return jnp.zeros((batch_size, self.trace_size, self.context_size), dtype=jnp.complex64) diff --git a/stoix/networks/memoroids/layers.py b/stoix/networks/memoroids/layers.py new file mode 100644 index 00000000..a4ffe4d7 --- /dev/null +++ b/stoix/networks/memoroids/layers.py @@ -0,0 +1,56 @@ +from typing import Any, Callable, Dict, List, Tuple + +import chex +from flax import linen as nn + +from stoix.networks.memoroids.base import Inputs, MemoroidCellBase, RecurrentState +from stoix.networks.memoroids.ffm import FFMCell +from stoix.networks.memoroids.lru import LRUCell +from stoix.networks.memoroids.s5 import S5Cell + + +def parse_lrm_cell(lrm_cell_name: str) -> MemoroidCellBase: + """Get the lrm cell.""" + lrm_cells: Dict[str, MemoroidCellBase] = { + "s5": S5Cell, + "ffm": FFMCell, + "lru": LRUCell, + } + return lrm_cells[lrm_cell_name] + + +class StackedMemoroid(nn.Module): + lrm_cell_type: MemoroidCellBase + cell_kwargs: Dict[str, Any] + num_cells: int + + def setup(self) -> None: + """Set up the Memoroid cells for the stacked Memoroid.""" + + cell_cls = parse_lrm_cell(self.lrm_cell_type) + self.cells = [cell_cls(**self.cell_kwargs) for _ in range(self.num_cells)] + + @nn.compact + def __call__( + self, all_states: List[RecurrentState], inputs: Inputs + ) -> Tuple[RecurrentState, chex.Array]: + # Ensure all_states is a list + if not isinstance(all_states, list): + all_states = [all_states] + + assert len(all_states) == len( + self.cells + ), f"Expected {len(self.cells)} states, got {len(all_states)}" + x, starts = inputs + new_states = [] + for cell, mem_state in zip(self.cells, all_states): + new_mem_state, x = cell(mem_state, (x, starts)) + new_states.append(new_mem_state) + + return new_states, x + + @nn.nowrap + def initialize_carry(self, batch_size: int) -> List[RecurrentState]: + cell_cls = parse_lrm_cell(self.lrm_cell_type) + cells = [cell_cls(**self.cell_kwargs) for _ in range(self.num_cells)] + return [cell.initialize_carry(batch_size) for cell in cells] From ab3de3f510eca03a1c999e35e2f1e754edc46456 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Sun, 7 Jul 2024 13:48:22 +0000 Subject: [PATCH 36/38] chore: refactor slightly --- stoix/networks/memoroids/base.py | 107 ----------------------------- stoix/networks/memoroids/ffm.py | 16 +++-- stoix/networks/memoroids/layers.py | 8 +-- stoix/networks/memoroids/lru.py | 5 +- stoix/networks/memoroids/s5.py | 22 +++--- stoix/networks/memoroids/types.py | 12 ++++ 6 files changed, 41 insertions(+), 129 deletions(-) delete mode 100644 stoix/networks/memoroids/base.py create mode 100644 stoix/networks/memoroids/types.py diff --git a/stoix/networks/memoroids/base.py b/stoix/networks/memoroids/base.py deleted file mode 100644 index d88c811b..00000000 --- a/stoix/networks/memoroids/base.py +++ /dev/null @@ -1,107 +0,0 @@ -from typing import List, Optional, Tuple - -import chex -from flax import linen as nn -from jax import numpy as jnp - -RecurrentState = chex.Array -Reset = chex.Array -Timestep = chex.Array -InputEmbedding = chex.Array -Inputs = Tuple[InputEmbedding, Reset] -ScanInput = chex.Array - - -class MemoroidCellBase(nn.Module): - """Memoroid cell base class.""" - - def map_to_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> RecurrentState: - raise NotImplementedError - - def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> RecurrentState: - raise NotImplementedError - - def scan(self, x: InputEmbedding, state: RecurrentState, start: Reset) -> RecurrentState: - raise NotImplementedError - - @nn.nowrap - def initialize_carry( - self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None - ) -> RecurrentState: - """Initialize the Memoroid cell carry. - - Args: - batch_size: the batch size of the carry. - rng: random number generator passed to the init_fn. - - Returns: - An initialized carry for the given Memoroid cell. - """ - raise NotImplementedError - - @property - def num_feature_axes(self) -> int: - """Returns the number of feature axes of the cell.""" - raise NotImplementedError - - -# class ScannedMemoroid(nn.Module): -# cell: MemoroidCellBase - -# @nn.compact -# def __call__(self, state: RecurrentState, inputs: Inputs) -> Tuple[RecurrentState, chex.Array]: - -# # Add a sequence dimension to the recurrent state. -# state = jnp.expand_dims(state, 0) - -# # Unpack inputs -# x, start = inputs - -# # Map the input embedding to the recurrent state space. -# # This maps to the format required for the associative scan. -# scan_input = self.cell.map_to_h(x) - -# # Update the recurrent state -# state = self.cell.scan(scan_input, state, start) - -# # Map the recurrent state back to the output space -# out = self.cell.map_from_h(state, x) - -# # Take the final state of the sequence. -# final_state = state[-1:] - -# # Remove the sequence dimemnsion from the final state. -# final_state = jnp.squeeze(final_state, 0) - -# return final_state, out - -# @nn.nowrap -# def initialize_carry(self, batch_size: int) -> RecurrentState: -# return self.cell.initialize_carry(batch_size) - - -# class StackedMemoroid(nn.Module): -# cells: Tuple[ScannedMemoroid] - -# @nn.compact -# def __call__( -# self, all_states: List[RecurrentState], inputs: Inputs -# ) -> Tuple[RecurrentState, chex.Array]: -# # Ensure all_states is a list -# if not isinstance(all_states, list): -# all_states = [all_states] - -# assert len(all_states) == len( -# self.cells -# ), f"Expected {len(self.cells)} states, got {len(all_states)}" -# x, starts = inputs -# new_states = [] -# for cell, mem_state in zip(self.cells, all_states): -# new_mem_state, x = cell(mem_state, (x, starts)) -# new_states.append(new_mem_state) - -# return new_states, x - -# @nn.nowrap -# def initialize_carry(self, batch_size: int) -> List[RecurrentState]: -# return [cell.initialize_carry(batch_size) for cell in self.cells] diff --git a/stoix/networks/memoroids/ffm.py b/stoix/networks/memoroids/ffm.py index cfb0d40c..05803331 100644 --- a/stoix/networks/memoroids/ffm.py +++ b/stoix/networks/memoroids/ffm.py @@ -5,7 +5,7 @@ from flax import linen as nn from jax import numpy as jnp -from stoix.networks.memoroids.base import ( +from stoix.networks.memoroids.types import ( InputEmbedding, Inputs, RecurrentState, @@ -75,7 +75,7 @@ def setup(self) -> None: self.mix = nn.Dense(self.output_size) self.ln = nn.LayerNorm(use_scale=False, use_bias=False) - def map_to_h(self, state: RecurrentState, x: InputEmbedding) -> ScanInput: + def map_to_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> ScanInput: """Given an input embedding, this will map it to the format required for the associative scan.""" gate_in = self.gate_in(x) pre = self.pre(x) @@ -83,11 +83,13 @@ def map_to_h(self, state: RecurrentState, x: InputEmbedding) -> ScanInput: scan_input = jnp.repeat(jnp.expand_dims(gated_x, 3), self.context_size, axis=3) return scan_input - def map_from_h(self, state: RecurrentState, x: InputEmbedding) -> chex.Array: + def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> chex.Array: """Given the recurrent state and the input embedding, this will map the recurrent state back to the output space.""" - T = state.shape[0] - B = state.shape[1] - z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape(T, B, -1) + T = recurrent_state.shape[0] + B = recurrent_state.shape[1] + z_in = jnp.concatenate( + [jnp.real(recurrent_state), jnp.imag(recurrent_state)], axis=-1 + ).reshape(T, B, -1) z = self.mix(z_in) gate_out = self.gate_out(x) skip = self.skip(x) @@ -137,7 +139,7 @@ def wrapped_associative_update( def scan( self, - x: InputEmbedding, + x: ScanInput, state: RecurrentState, start: Reset, ) -> RecurrentState: diff --git a/stoix/networks/memoroids/layers.py b/stoix/networks/memoroids/layers.py index a4ffe4d7..25e60897 100644 --- a/stoix/networks/memoroids/layers.py +++ b/stoix/networks/memoroids/layers.py @@ -3,15 +3,15 @@ import chex from flax import linen as nn -from stoix.networks.memoroids.base import Inputs, MemoroidCellBase, RecurrentState from stoix.networks.memoroids.ffm import FFMCell from stoix.networks.memoroids.lru import LRUCell from stoix.networks.memoroids.s5 import S5Cell +from stoix.networks.memoroids.types import Inputs, RecurrentState -def parse_lrm_cell(lrm_cell_name: str) -> MemoroidCellBase: +def parse_lrm_cell(lrm_cell_name: str) -> nn.Module: """Get the lrm cell.""" - lrm_cells: Dict[str, MemoroidCellBase] = { + lrm_cells: Dict[str, nn.Module] = { "s5": S5Cell, "ffm": FFMCell, "lru": LRUCell, @@ -20,7 +20,7 @@ def parse_lrm_cell(lrm_cell_name: str) -> MemoroidCellBase: class StackedMemoroid(nn.Module): - lrm_cell_type: MemoroidCellBase + lrm_cell_type: nn.Module cell_kwargs: Dict[str, Any] num_cells: int diff --git a/stoix/networks/memoroids/lru.py b/stoix/networks/memoroids/lru.py index 1af3c471..eb963f55 100644 --- a/stoix/networks/memoroids/lru.py +++ b/stoix/networks/memoroids/lru.py @@ -7,10 +7,9 @@ import jax.numpy as jnp from flax import linen as nn -from stoix.networks.memoroids.base import ( +from stoix.networks.memoroids.types import ( InputEmbedding, Inputs, - MemoroidCellBase, RecurrentState, Reset, ScanInput, @@ -67,7 +66,7 @@ def gamma_log_init(key, lamb): return jnp.log(jnp.sqrt(1 - jnp.abs(diag_lambda) ** 2)) -class LRUCell(MemoroidCellBase): +class LRUCell(nn.Module): """ LRU module in charge of the recurrent processing. Implementation following the one of Orvieto et al. 2023. diff --git a/stoix/networks/memoroids/s5.py b/stoix/networks/memoroids/s5.py index 67523fc4..0f9a072f 100644 --- a/stoix/networks/memoroids/s5.py +++ b/stoix/networks/memoroids/s5.py @@ -1,5 +1,6 @@ import functools from functools import partial +from typing import Tuple import chex import jax @@ -11,11 +12,12 @@ from jax.nn.initializers import lecun_normal, normal from jax.numpy.linalg import eigh -from stoix.networks.memoroids.base import ( +from stoix.networks.memoroids.types import ( InputEmbedding, Inputs, - MemoroidCellBase, RecurrentState, + Reset, + ScanInput, ) @@ -232,7 +234,7 @@ def make_DPLR_HiPPO(N): return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig -class S5Cell(MemoroidCellBase): +class S5Cell(nn.Module): d_model: int state_size: int blocks: int = 1 @@ -360,7 +362,7 @@ def setup(self): self.norm = nn.LayerNorm() - def map_to_h(self, recurrent_state: RecurrentState, x: Inputs): + def map_to_h(self, recurrent_state: RecurrentState, x: Inputs) -> ScanInput: if self.prenorm and self.do_norm: x = self.norm(x) @@ -384,7 +386,9 @@ def map_to_h(self, recurrent_state: RecurrentState, x: Inputs): return (Lambda_elements, Bu_elements) - def scan(self, resets, Lambda_elements, Bu_elements): + def scan( + self, resets: Reset, Lambda_elements: chex.Array, Bu_elements: chex.Array + ) -> RecurrentState: resets = jnp.concatenate( [ @@ -399,7 +403,7 @@ def scan(self, resets, Lambda_elements, Bu_elements): return xs - def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding): + def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> chex.Array: skip = x if self.conj_sym: x = jax.vmap(lambda x: 2 * (self.C_tilde @ x).real)(recurrent_state) @@ -440,7 +444,9 @@ def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding): split_rngs={"params": False}, ) @nn.compact - def __call__(self, recurrent_state: RecurrentState, inputs: Inputs): + def __call__( + self, recurrent_state: RecurrentState, inputs: Inputs + ) -> Tuple[RecurrentState, chex.Array]: # Add a sequence dimension to the recurrent state recurrent_state = jnp.expand_dims(recurrent_state, 0) @@ -460,5 +466,5 @@ def __call__(self, recurrent_state: RecurrentState, inputs: Inputs): return new_hidden_state, outputs @nn.nowrap - def initialize_carry(self, batch_size): + def initialize_carry(self, batch_size: int) -> RecurrentState: return jnp.zeros((batch_size, self.state_size), dtype=jnp.complex64) diff --git a/stoix/networks/memoroids/types.py b/stoix/networks/memoroids/types.py new file mode 100644 index 00000000..56bd8632 --- /dev/null +++ b/stoix/networks/memoroids/types.py @@ -0,0 +1,12 @@ +from typing import List, Optional, Tuple + +import chex +from flax import linen as nn +from jax import numpy as jnp + +RecurrentState = chex.Array +Reset = chex.Array +Timestep = chex.Array +InputEmbedding = chex.Array +Inputs = Tuple[InputEmbedding, Reset] +ScanInput = chex.Array From e278bb5f958b054e886e2dee927c611528509558 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Sun, 7 Jul 2024 20:01:50 +0000 Subject: [PATCH 37/38] chore: clean up code --- stoix/configs/network/ffm.yaml | 39 -- stoix/configs/network/lru.yaml | 35 - stoix/configs/network/s5.yaml | 37 - stoix/configs/network/stacked_lrm.yaml | 22 +- .../{memoroids/layers.py => lrm/base.py} | 39 +- stoix/networks/{memoroids => lrm}/ffm.py | 25 +- stoix/networks/{memoroids => lrm}/lru.py | 32 +- stoix/networks/{memoroids => lrm}/s5.py | 90 +-- stoix/networks/lrm/utils.py | 15 + .../networks/memoroids/old_code/old_code1.py | 357 ---------- .../networks/memoroids/old_code/old_code2.py | 420 ------------ stoix/networks/memoroids/old_code/old_s5.py | 645 ------------------ stoix/networks/memoroids/types.py | 12 - 13 files changed, 138 insertions(+), 1630 deletions(-) delete mode 100644 stoix/configs/network/ffm.yaml delete mode 100644 stoix/configs/network/lru.yaml delete mode 100644 stoix/configs/network/s5.yaml rename stoix/networks/{memoroids/layers.py => lrm/base.py} (60%) rename stoix/networks/{memoroids => lrm}/ffm.py (91%) rename stoix/networks/{memoroids => lrm}/lru.py (87%) rename stoix/networks/{memoroids => lrm}/s5.py (84%) create mode 100644 stoix/networks/lrm/utils.py delete mode 100644 stoix/networks/memoroids/old_code/old_code1.py delete mode 100644 stoix/networks/memoroids/old_code/old_code2.py delete mode 100644 stoix/networks/memoroids/old_code/old_s5.py delete mode 100644 stoix/networks/memoroids/types.py diff --git a/stoix/configs/network/ffm.yaml b/stoix/configs/network/ffm.yaml deleted file mode 100644 index eb0b62e4..00000000 --- a/stoix/configs/network/ffm.yaml +++ /dev/null @@ -1,39 +0,0 @@ -# ---Recurrent Structure Networks for PPO --- - -actor_network: - pre_torso: - _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: True - activation: leaky_relu - rnn_layer: - _target_: stoix.networks.memoroids.ffm.FFMCell - trace_size: 64 - context_size: 4 - output_size: 256 - post_torso: - _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: False - activation: leaky_relu - action_head: - _target_: stoix.networks.heads.CategoricalHead - -critic_network: - pre_torso: - _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: True - activation: leaky_relu - rnn_layer: - _target_: stoix.networks.memoroids.ffm.FFMCell - trace_size: 64 - context_size: 4 - output_size: 256 - post_torso: - _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: False - activation: leaky_relu - critic_head: - _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/configs/network/lru.yaml b/stoix/configs/network/lru.yaml deleted file mode 100644 index 247050e5..00000000 --- a/stoix/configs/network/lru.yaml +++ /dev/null @@ -1,35 +0,0 @@ -# ---Recurrent Structure Networks for PPO --- - -actor_network: - pre_torso: - _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: True - activation: leaky_relu - rnn_layer: - _target_: stoix.networks.memoroids.lru.LRUCell - hidden_state_dim: 256 - post_torso: - _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: False - activation: leaky_relu - action_head: - _target_: stoix.networks.heads.CategoricalHead - -critic_network: - pre_torso: - _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: True - activation: leaky_relu - rnn_layer: - _target_: stoix.networks.memoroids.lru.LRUCell - hidden_state_dim: 256 - post_torso: - _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: False - activation: leaky_relu - critic_head: - _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/configs/network/s5.yaml b/stoix/configs/network/s5.yaml deleted file mode 100644 index 1ea232a8..00000000 --- a/stoix/configs/network/s5.yaml +++ /dev/null @@ -1,37 +0,0 @@ -# ---Recurrent Structure Networks for PPO --- - -actor_network: - pre_torso: - _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: True - activation: leaky_relu - rnn_layer: - _target_: stoix.networks.memoroids.s5.S5Cell - d_model : 256 - state_size: 256 - post_torso: - _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: False - activation: leaky_relu - action_head: - _target_: stoix.networks.heads.CategoricalHead - -critic_network: - pre_torso: - _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: True - activation: leaky_relu - rnn_layer: - _target_: stoix.networks.memoroids.s5.S5Cell - d_model : 256 - state_size: 256 - post_torso: - _target_: stoix.networks.torso.MLPTorso - layer_sizes: [256] - use_layer_norm: False - activation: leaky_relu - critic_head: - _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/configs/network/stacked_lrm.yaml b/stoix/configs/network/stacked_lrm.yaml index c49270b8..778e878f 100644 --- a/stoix/configs/network/stacked_lrm.yaml +++ b/stoix/configs/network/stacked_lrm.yaml @@ -4,16 +4,15 @@ actor_network: pre_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] - use_layer_norm: True + use_layer_norm: False activation: leaky_relu rnn_layer: - _target_: stoix.networks.memoroids.layers.StackedMemoroid + _target_: stoix.networks.lrm.layers.StackedLRM num_cells: 2 - lrm_cell_type: ffm + lrm_cell_type: s5 cell_kwargs: - trace_size: 64 - context_size: 4 - output_size: 256 + d_model: 256 + state_size: 256 post_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] @@ -26,16 +25,15 @@ critic_network: pre_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] - use_layer_norm: True + use_layer_norm: False activation: leaky_relu rnn_layer: - _target_: stoix.networks.memoroids.layers.StackedMemoroid + _target_: stoix.networks.lrm.layers.StackedLRM num_cells: 2 - lrm_cell_type: ffm + lrm_cell_type: s5 cell_kwargs: - trace_size: 64 - context_size: 4 - output_size: 256 + d_model: 256 + state_size: 256 post_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] diff --git a/stoix/networks/memoroids/layers.py b/stoix/networks/lrm/base.py similarity index 60% rename from stoix/networks/memoroids/layers.py rename to stoix/networks/lrm/base.py index 25e60897..255e1d7e 100644 --- a/stoix/networks/memoroids/layers.py +++ b/stoix/networks/lrm/base.py @@ -1,31 +1,36 @@ -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, TypeAlias import chex -from flax import linen as nn +import flax.linen as nn -from stoix.networks.memoroids.ffm import FFMCell -from stoix.networks.memoroids.lru import LRUCell -from stoix.networks.memoroids.s5 import S5Cell -from stoix.networks.memoroids.types import Inputs, RecurrentState +from stoix.networks.lrm.utils import parse_lrm_cell +RecurrentState: TypeAlias = chex.Array +Reset: TypeAlias = chex.Array +Timestep: TypeAlias = chex.Array +InputEmbedding: TypeAlias = chex.Array +Inputs: TypeAlias = Tuple[InputEmbedding, Reset] +ScanInput: TypeAlias = chex.Array -def parse_lrm_cell(lrm_cell_name: str) -> nn.Module: - """Get the lrm cell.""" - lrm_cells: Dict[str, nn.Module] = { - "s5": S5Cell, - "ffm": FFMCell, - "lru": LRUCell, - } - return lrm_cells[lrm_cell_name] + +class LRMCellBase(nn.Module): + def __call__( + self, recurrent_state: RecurrentState, inputs: Inputs + ) -> Tuple[RecurrentState, chex.Array]: + raise NotImplementedError + + @nn.nowrap + def initialize_carry(self, batch_size: int) -> RecurrentState: + raise NotImplementedError -class StackedMemoroid(nn.Module): - lrm_cell_type: nn.Module +class StackedLRM(nn.Module): + lrm_cell_type: LRMCellBase cell_kwargs: Dict[str, Any] num_cells: int def setup(self) -> None: - """Set up the Memoroid cells for the stacked Memoroid.""" + """Set up the LRM cells for the stacked LRM.""" cell_cls = parse_lrm_cell(self.lrm_cell_type) self.cells = [cell_cls(**self.cell_kwargs) for _ in range(self.num_cells)] diff --git a/stoix/networks/memoroids/ffm.py b/stoix/networks/lrm/ffm.py similarity index 91% rename from stoix/networks/memoroids/ffm.py rename to stoix/networks/lrm/ffm.py index 05803331..76a7744c 100644 --- a/stoix/networks/memoroids/ffm.py +++ b/stoix/networks/lrm/ffm.py @@ -1,24 +1,29 @@ -from typing import Tuple +# flake8: noqa +from typing import Sequence, Tuple import chex import jax from flax import linen as nn +from flax.linen.initializers import Initializer from jax import numpy as jnp -from stoix.networks.memoroids.types import ( +from stoix.networks.lrm.base import ( InputEmbedding, Inputs, + LRMCellBase, RecurrentState, Reset, ScanInput, Timestep, ) +# Taken and modified from https://github.com/proroklab/memory-monoids + def init_deterministic_a( memory_size: int, -) -> Tuple[chex.Array, chex.Array]: - def init(key, shape): +) -> Initializer: + def init(key: chex.PRNGKey, shape: Sequence[int]) -> chex.Array: a_low = 1e-6 a_high = 0.5 a = jnp.linspace(a_low, a_high, memory_size) @@ -29,8 +34,8 @@ def init(key, shape): def init_deterministic_b( context_size: int, min_period: int = 1, max_period: int = 1_000 -) -> Tuple[chex.Array, chex.Array]: - def init(key, shape): +) -> Initializer: + def init(key: chex.PRNGKey, shape: Sequence[int]) -> chex.Array: b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) return b @@ -45,7 +50,7 @@ def __call__(self, x: chex.Array) -> chex.Array: return jax.nn.sigmoid(nn.Dense(self.output_size)(x)) -class FFMCell(nn.Module): +class FFMCell(LRMCellBase): trace_size: int context_size: int output_size: int @@ -76,7 +81,8 @@ def setup(self) -> None: self.ln = nn.LayerNorm(use_scale=False, use_bias=False) def map_to_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> ScanInput: - """Given an input embedding, this will map it to the format required for the associative scan.""" + """Given an input embedding, this will map it to the format + required for the associative scan.""" gate_in = self.gate_in(x) pre = self.pre(x) gated_x = pre * gate_in @@ -84,7 +90,8 @@ def map_to_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> ScanIn return scan_input def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> chex.Array: - """Given the recurrent state and the input embedding, this will map the recurrent state back to the output space.""" + """Given the recurrent state and the input embedding, this will map the + recurrent state back to the output space.""" T = recurrent_state.shape[0] B = recurrent_state.shape[1] z_in = jnp.concatenate( diff --git a/stoix/networks/memoroids/lru.py b/stoix/networks/lrm/lru.py similarity index 87% rename from stoix/networks/memoroids/lru.py rename to stoix/networks/lrm/lru.py index eb963f55..2ea56eab 100644 --- a/stoix/networks/memoroids/lru.py +++ b/stoix/networks/lrm/lru.py @@ -1,24 +1,27 @@ +# flake8: noqa import functools from functools import partial -from typing import Tuple +from typing import Sequence, Tuple import chex import jax import jax.numpy as jnp from flax import linen as nn -from stoix.networks.memoroids.types import ( +from stoix.networks.lrm.base import ( InputEmbedding, Inputs, + LRMCellBase, RecurrentState, Reset, ScanInput, ) +# Taken and modified from https://github.com/proroklab/memory-monoids # Parallel scan operations @jax.vmap -def binary_operator_diag(q_i, q_j): +def binary_operator_diag(q_i: chex.Array, q_j: chex.Array) -> Tuple[chex.Array, chex.Array]: """Binary operator for parallel scan of linear recurrence""" A_i, b_i = q_i A_j, b_j = q_j @@ -46,27 +49,40 @@ def wrapped_associative_update(carry: chex.Array, incoming: chex.Array) -> Tuple return (start_out, *out) -def matrix_init(key, shape, dtype=jnp.float32, normalization=1): +def matrix_init( + key: chex.PRNGKey, + shape: Sequence[int], + dtype: jnp.dtype = jnp.float32, + normalization: float = 1, +) -> chex.Array: return jax.random.normal(key=key, shape=shape, dtype=dtype) / normalization -def nu_init(key, shape, r_min, r_max, dtype=jnp.float32): +def nu_init( + key: chex.PRNGKey, + shape: Sequence[int], + r_min: float, + r_max: float, + dtype: jnp.dtype = jnp.float32, +) -> chex.Array: u = jax.random.uniform(key=key, shape=shape, dtype=dtype) return jnp.log(-0.5 * jnp.log(u * (r_max**2 - r_min**2) + r_min**2)) -def theta_init(key, shape, max_phase, dtype=jnp.float32): +def theta_init( + key: chex.PRNGKey, shape: Sequence[int], max_phase: float, dtype: jnp.dtype = jnp.float32 +) -> chex.Array: u = jax.random.uniform(key, shape=shape, dtype=dtype) return jnp.log(max_phase * u) -def gamma_log_init(key, lamb): +def gamma_log_init(key: chex.PRNGKey, lamb: Tuple[chex.Array, chex.Array]) -> chex.Array: nu, theta = lamb diag_lambda = jnp.exp(-jnp.exp(nu) + 1j * jnp.exp(theta)) return jnp.log(jnp.sqrt(1 - jnp.abs(diag_lambda) ** 2)) -class LRUCell(nn.Module): +class LRUCell(LRMCellBase): """ LRU module in charge of the recurrent processing. Implementation following the one of Orvieto et al. 2023. diff --git a/stoix/networks/memoroids/s5.py b/stoix/networks/lrm/s5.py similarity index 84% rename from stoix/networks/memoroids/s5.py rename to stoix/networks/lrm/s5.py index 0f9a072f..fa5ab9ca 100644 --- a/stoix/networks/memoroids/s5.py +++ b/stoix/networks/lrm/s5.py @@ -1,27 +1,29 @@ +# flake8: noqa import functools -from functools import partial -from typing import Tuple +from typing import Sequence, Tuple import chex import jax -import jax.numpy as np import jax.numpy as jnp -import optax from flax import linen as nn +from flax.linen.initializers import Initializer from jax import random from jax.nn.initializers import lecun_normal, normal from jax.numpy.linalg import eigh -from stoix.networks.memoroids.types import ( +from stoix.networks.lrm.base import ( InputEmbedding, Inputs, + LRMCellBase, RecurrentState, Reset, ScanInput, ) +# S5 code taken and modified from https://github.com/luchris429/popjaxrl -def log_step_initializer(dt_min=0.001, dt_max=0.1): + +def log_step_initializer(dt_min: float = 0.001, dt_max: float = 0.1) -> Initializer: """Initialize the learnable timescale Delta by sampling uniformly between dt_min and dt_max. Args: @@ -31,7 +33,7 @@ def log_step_initializer(dt_min=0.001, dt_max=0.1): init function """ - def init(key, shape): + def init(key: chex.PRNGKey, shape: Sequence[int]) -> chex.Array: """Init function Args: key: jax random key @@ -39,12 +41,12 @@ def init(key, shape): Returns: sampled log_step (float32) """ - return random.uniform(key, shape) * (np.log(dt_max) - np.log(dt_min)) + np.log(dt_min) + return random.uniform(key, shape) * (jnp.log(dt_max) - jnp.log(dt_min)) + jnp.log(dt_min) return init -def init_log_steps(key, input): +def init_log_steps(key: chex.PRNGKey, input: Tuple[int, float, float]) -> chex.Array: """Initialize an array of learnable timescale parameters Args: key: jax random key @@ -60,10 +62,12 @@ def init_log_steps(key, input): log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,)) log_steps.append(log_step) - return np.array(log_steps) + return jnp.array(log_steps) -def init_VinvB(init_fun, rng, shape, Vinv): +def init_VinvB( + init_fun: Initializer, rng: chex.PRNGKey, shape: Sequence[int], Vinv: chex.Array +) -> chex.Array: """Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B. Note we will parameterize this with two different matrices for complex numbers. @@ -79,10 +83,10 @@ def init_VinvB(init_fun, rng, shape, Vinv): VinvB = Vinv @ B VinvB_real = VinvB.real VinvB_imag = VinvB.imag - return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1) + return jnp.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1) -def trunc_standard_normal(key, shape): +def trunc_standard_normal(key: chex.PRNGKey, shape: Sequence[int]) -> chex.Array: """Sample C with a truncated normal distribution with standard deviation 1. Args: key: jax random key @@ -96,10 +100,12 @@ def trunc_standard_normal(key, shape): key, skey = random.split(key) C = lecun_normal()(skey, shape=(1, P, 2)) Cs.append(C) - return np.array(Cs)[:, 0] + return jnp.array(Cs)[:, 0] -def init_CV(init_fun, rng, shape, V): +def init_CV( + init_fun: Initializer, rng: chex.PRNGKey, shape: Sequence[int], V: chex.Array +) -> chex.Array: """Initialize C_tilde=CV. First sample C. Then compute CV. Note we will parameterize this with two different matrices for complex numbers. @@ -116,11 +122,13 @@ def init_CV(init_fun, rng, shape, V): CV = C @ V CV_real = CV.real CV_imag = CV.imag - return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1) + return jnp.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1) # Discretization functions -def discretize_bilinear(Lambda, B_tilde, Delta): +def discretize_bilinear( + Lambda: chex.Array, B_tilde: chex.Array, Delta: chex.Array +) -> Tuple[chex.Array, chex.Array]: """Discretize a diagonalized, continuous-time linear SSM using bilinear transform method. Args: @@ -130,7 +138,7 @@ def discretize_bilinear(Lambda, B_tilde, Delta): Returns: discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) """ - Identity = np.ones(Lambda.shape[0]) + Identity = jnp.ones(Lambda.shape[0]) BL = 1 / (Identity - (Delta / 2.0) * Lambda) Lambda_bar = BL * (Identity + (Delta / 2.0) * Lambda) @@ -138,7 +146,9 @@ def discretize_bilinear(Lambda, B_tilde, Delta): return Lambda_bar, B_bar -def discretize_zoh(Lambda, B_tilde, Delta): +def discretize_zoh( + Lambda: chex.Array, B_tilde: chex.Array, Delta: chex.Array +) -> Tuple[chex.Array, chex.Array]: """Discretize a diagonalized, continuous-time linear SSM using zero-order hold method. Args: @@ -148,15 +158,17 @@ def discretize_zoh(Lambda, B_tilde, Delta): Returns: discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) """ - Identity = np.ones(Lambda.shape[0]) - Lambda_bar = np.exp(Lambda * Delta) + Identity = jnp.ones(Lambda.shape[0]) + Lambda_bar = jnp.exp(Lambda * Delta) B_bar = (1 / Lambda * (Lambda_bar - Identity))[..., None] * B_tilde return Lambda_bar, B_bar # Parallel scan operations @jax.vmap -def binary_operator_reset(q_i, q_j): +def binary_operator_reset( + q_i: chex.Array, q_j: chex.Array +) -> Tuple[chex.Array, chex.Array, chex.Array]: """Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A. Args: q_i: tuple containing A_i and Bu_i at position i (P,), (P,) @@ -173,7 +185,7 @@ def binary_operator_reset(q_i, q_j): ) -def make_HiPPO(N): +def make_HiPPO(N: int) -> chex.Array: """Create a HiPPO-LegS matrix. From https://github.com/srush/annotated-s4/blob/main/s4/s4.py Args: @@ -181,13 +193,13 @@ def make_HiPPO(N): Returns: N x N HiPPO LegS matrix """ - P = np.sqrt(1 + 2 * np.arange(N)) - A = P[:, np.newaxis] * P[np.newaxis, :] - A = np.tril(A) - np.diag(np.arange(N)) + P = jnp.sqrt(1 + 2 * jnp.arange(N)) + A = P[:, jnp.newaxis] * P[jnp.newaxis, :] + A = jnp.tril(A) - jnp.diag(jnp.arange(N)) return -A -def make_NPLR_HiPPO(N): +def make_NPLR_HiPPO(N: int) -> Tuple[chex.Array, chex.Array, chex.Array]: """ Makes components needed for NPLR representation of HiPPO-LegS From https://github.com/srush/annotated-s4/blob/main/s4/s4.py @@ -200,14 +212,14 @@ def make_NPLR_HiPPO(N): hippo = make_HiPPO(N) # Add in a rank 1 term. Makes it Normal. - P = np.sqrt(np.arange(N) + 0.5) + P = jnp.sqrt(jnp.arange(N) + 0.5) # HiPPO also specifies the B matrix - B = np.sqrt(2 * np.arange(N) + 1.0) + B = jnp.sqrt(2 * jnp.arange(N) + 1.0) return hippo, P, B -def make_DPLR_HiPPO(N): +def make_DPLR_HiPPO(N: int) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array]: """ Makes components needed for DPLR representation of HiPPO-LegS From https://github.com/srush/annotated-s4/blob/main/s4/s4.py @@ -220,10 +232,10 @@ def make_DPLR_HiPPO(N): """ A, P, B = make_NPLR_HiPPO(N) - S = A + P[:, np.newaxis] * P[np.newaxis, :] + S = A + P[:, jnp.newaxis] * P[jnp.newaxis, :] - S_diag = np.diagonal(S) - Lambda_real = np.mean(S_diag) * np.ones_like(S_diag) + S_diag = jnp.diagonal(S) + Lambda_real = jnp.mean(S_diag) * jnp.ones_like(S_diag) # Diagonalize S to V \Lambda V^* Lambda_imag, V = eigh(S * -1j) @@ -234,7 +246,7 @@ def make_DPLR_HiPPO(N): return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig -class S5Cell(nn.Module): +class S5Cell(LRMCellBase): d_model: int state_size: int blocks: int = 1 @@ -253,7 +265,7 @@ class S5Cell(nn.Module): bidirectional: bool = False step_rescale: float = 1.0 - def setup(self): + def setup(self) -> None: """Initializes parameters once and performs discretization each time the SSM is applied to a sequence """ @@ -284,7 +296,7 @@ def setup(self): self.Lambda_re = self.param("Lambda_re", lambda rng, shape: self.Lambda_re_init, (None,)) self.Lambda_im = self.param("Lambda_im", lambda rng, shape: self.Lambda_im_init, (None,)) if self.clip_eigs: - self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im + self.Lambda = jnp.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im else: self.Lambda = self.Lambda_re + 1j * self.Lambda_im @@ -328,7 +340,7 @@ def setup(self): C1 = self.C1[..., 0] + 1j * self.C1[..., 1] C2 = self.C2[..., 0] + 1j * self.C2[..., 1] - self.C_tilde = np.concatenate((C1, C2), axis=-1) + self.C_tilde = jnp.concatenate((C1, C2), axis=-1) else: self.C = self.param( @@ -342,7 +354,7 @@ def setup(self): # Initialize learnable discretization timescale value self.log_step = self.param("log_step", init_log_steps, (self.P, self.dt_min, self.dt_max)) - step = self.step_rescale * np.exp(self.log_step[:, 0]) + step = self.step_rescale * jnp.exp(self.log_step[:, 0]) # Discretize if self.discretization in ["zoh"]: @@ -362,7 +374,7 @@ def setup(self): self.norm = nn.LayerNorm() - def map_to_h(self, recurrent_state: RecurrentState, x: Inputs) -> ScanInput: + def map_to_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> ScanInput: if self.prenorm and self.do_norm: x = self.norm(x) diff --git a/stoix/networks/lrm/utils.py b/stoix/networks/lrm/utils.py new file mode 100644 index 00000000..3c35bde5 --- /dev/null +++ b/stoix/networks/lrm/utils.py @@ -0,0 +1,15 @@ +from typing import Any, Dict + +from stoix.networks.lrm.ffm import FFMCell +from stoix.networks.lrm.lru import LRUCell +from stoix.networks.lrm.s5 import S5Cell + + +def parse_lrm_cell(lrm_cell_name: str) -> Any: + """Parse a linear recurrent model layer.""" + lrm_cells: Dict[str, Any] = { + "s5": S5Cell, + "ffm": FFMCell, + "lru": LRUCell, + } + return lrm_cells[lrm_cell_name] diff --git a/stoix/networks/memoroids/old_code/old_code1.py b/stoix/networks/memoroids/old_code/old_code1.py deleted file mode 100644 index 817d2275..00000000 --- a/stoix/networks/memoroids/old_code/old_code1.py +++ /dev/null @@ -1,357 +0,0 @@ -from functools import partial -from typing import Any, List, Tuple - -import flax.linen as nn -import jax -import jax.numpy as jnp - - -def recurrent_associative_scan( - cell: nn.Module, - state: jax.Array, - inputs: jax.Array, - axis: int = 0, -) -> jax.Array: - """Execute the associative scan to update the recurrent state. - - Note that we do a trick here by concatenating the previous state to the inputs. - This is allowed since the scan is associative. This ensures that the previous - recurrent state feeds information into the scan. Without this method, we need - separate methods for rollouts and training.""" - - # Concatenate the prevous state to the inputs and scan over the result - # This ensures the previous recurrent state contributes to the current batch - # state: [start, (x, j)] - # inputs: [start, (x, j)] - scan_inputs = jax.tree.map(lambda x, s: jnp.concatenate([s, x], axis=0), inputs, state) - new_state = jax.lax.associative_scan( - cell, - scan_inputs, - axis=axis, - ) - # The zeroth index corresponds to the previous recurrent state - # We just use it to ensure continuity - # We do not actually want to use these values, so slice them away - return jax.tree.map(lambda x: x[1:], new_state) - - -class Gate(nn.Module): - """Sigmoidal gating""" - - output_size: int - - @nn.compact - def __call__(self, x): - x = nn.Dense(self.output_size)(x) - x = nn.sigmoid(x) - return x - - -def init_deterministic( - memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1_000 -) -> Tuple[jax.Array, jax.Array]: - """Deterministic initialization of the FFM parameters.""" - a_low = 1e-6 - a_high = 0.5 - a = jnp.linspace(a_low, a_high, memory_size) - b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) - return a, b - - -def init_random( - memory_size: int, context_size: int, min_period: int = 1, max_period: int = 10_000, *, key -) -> Tuple[jax.Array, jax.Array]: - _, k1, k2 = jax.random.split(key, 3) - a_low = 1e-6 - a_high = 0.1 - a = jax.random.uniform(k1, (memory_size,), minval=a_low, maxval=a_high) - b = ( - 2 - * jnp.pi - / jnp.exp( - jax.random.uniform( - k2, (context_size,), minval=jnp.log(min_period), maxval=jnp.log(max_period) - ) - ) - ) - return a, b - - -class FFMCell(nn.Module): - """The binary associative update function for the FFM.""" - - trace_size: int - context_size: int - output_size: int - deterministic_init: bool = True - - def setup(self): - if self.deterministic_init: - a, b = init_deterministic(self.trace_size, self.context_size) - else: - # TODO: Will this result in the same keys for multiple FFMCells? - key = self.make_rng("ffa_params") - a, b = init_random(self.trace_size, self.context_size, key=key) - self.params = (self.param("ffa_a", lambda rng: a), self.param("ffa_b", lambda rng: b)) - - def log_gamma(self, t: jax.Array) -> jax.Array: - a, b = self.params - a = -jnp.abs(a).reshape((1, self.trace_size, 1)) - b = b.reshape(1, 1, self.context_size) - ab = jax.lax.complex(a, b) - return ab * t.reshape(t.shape[0], 1, 1) - - def gamma(self, t: jax.Array) -> jax.Array: - return jnp.exp(self.log_gamma(t)) - - def initialize_carry(self, batch_size: int = None): - if batch_size is None: - return jnp.zeros( - (1, self.trace_size, self.context_size), dtype=jnp.complex64 - ), jnp.ones((1,), dtype=jnp.int32) - - return jnp.zeros( - (1, batch_size, self.trace_size, self.context_size), dtype=jnp.complex64 - ), jnp.ones((1, batch_size), dtype=jnp.int32) - - def __call__(self, carry, incoming): - ( - state, - i, - ) = carry - x, j = incoming - state = state * self.gamma(j) + x - return state, j + i - - -class MemoroidResetWrapper(nn.Module): - """A wrapper around memoroid cells like FFM, LRU, etc that resets - the recurrent state upon a reset signal.""" - - cell: nn.Module - - def __call__(self, carry, incoming): - states, prev_start = carry - xs, start = incoming - - def reset_state(start, current_state, initial_state): - # Expand to reset all dims of state: [B, 1, 1, ...] - expanded_start = start.reshape(-1, *([1] * (current_state.ndim - 1))) - out = current_state * jnp.logical_not(expanded_start) + initial_state - return out - - initial_states = self.cell.initialize_carry() - states = jax.tree.map(partial(reset_state, start), states, initial_states) - out = self.cell(states, xs) - start_carry = jnp.logical_or(start, prev_start) - - return out, start_carry - - def initialize_carry(self, batch_size: int = None): - if batch_size is None: - # TODO: Should this be one or zero? - return self.cell.initialize_carry(batch_size), jnp.zeros((1,), dtype=bool) - - return self.cell.initialize_carry(batch_size), jnp.zeros((batch_size,), dtype=bool) - - -class FFM(nn.Module): - """Fast and Forgetful Memory""" - - trace_size: int - context_size: int - output_size: int - cell: nn.Module - - def setup(self): - self.pre = nn.Dense(self.trace_size) - self.gate_in = Gate(self.trace_size) - self.ffa = FFMCell(self.trace_size, self.context_size, self.output_size) - self.gate_out = Gate(self.output_size) - self.skip = nn.Dense(self.output_size) - self.mix = nn.Dense(self.output_size) - self.ln = nn.LayerNorm(use_scale=False, use_bias=False) - - def map_to_h(self, inputs): - """Map from the input space to the recurrent state space""" - x, resets = inputs - gate_in = self.gate_in(x) - pre = self.pre(x) - gated_x = pre * gate_in - # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous - ts = jnp.ones(x.shape[0], dtype=jnp.int32) - z = jnp.repeat(jnp.expand_dims(gated_x, 2), self.context_size, axis=2) - return (z, ts), resets - - def map_from_h(self, recurrent_state, inputs): - """Map from the recurrent space to the Markov space""" - (state, ts), reset = recurrent_state - (x, start) = inputs - z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape( - state.shape[0], -1 - ) - z = self.mix(z_in) - gate_out = self.gate_out(x) - skip = self.skip(x) - out = self.ln(z * gate_out) + skip * (1 - gate_out) - return out - - def __call__(self, recurrent_state, inputs): - # Recurrent state should be ((state, timestep), reset) - # Inputs should be (x, reset) - h = self.map_to_h(inputs) - recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h) - # recurrent_state is ((state, timestep), reset) - out = self.map_from_h(recurrent_state, inputs) - - # TODO: Remove this when we want to return all recurrent states instead of just the last one - final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) - return final_recurrent_state, out - - def initialize_carry(self, batch_size: int = None): - return self.cell.initialize_carry(batch_size) - - -class SFFM(nn.Module): - """Simplified Fast and Forgetful Memory""" - - trace_size: int - context_size: int - hidden_size: int - cell: nn.Module - - def setup(self): - self.W_trace = nn.Dense(self.trace_size) - self.W_context = Gate(self.context_size) - self.ffa = FFMCell( - self.trace_size, self.context_size, self.hidden_size, deterministic_init=False - ) - self.post = nn.Sequential( - [ - # Default init but with smaller weights - nn.Dense( - self.hidden_size, - kernel_init=nn.initializers.variance_scaling( - 0.01, "fan_in", "truncated_normal" - ), - ), - nn.LayerNorm(), - nn.leaky_relu, - nn.Dense(self.hidden_size), - nn.LayerNorm(), - nn.leaky_relu, - ] - ) - - def map_to_h(self, inputs): - x, resets = inputs - pre = jnp.abs(jnp.einsum("bi, bj -> bij", self.W_trace(x), self.W_context(x))) - pre = pre / jnp.sum(pre, axis=(-2, -1), keepdims=True) - # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous - ts = jnp.ones(x.shape[0], dtype=jnp.int32) - return (pre, ts), resets - - def map_from_h(self, recurrent_state, inputs): - x, resets = inputs - (state, ts), reset = recurrent_state - s = state.reshape(state.shape[0], self.context_size * self.trace_size) - eps = s.real + (s.real == 0 + jnp.sign(s.real)) * 0.01 - s = s + eps - scaled = jnp.concatenate( - [ - jnp.log(1 + jnp.abs(s)) * jnp.sin(jnp.angle(s)), - jnp.log(1 + jnp.abs(s)) * jnp.cos(jnp.angle(s)), - ], - axis=-1, - ) - z = self.post(scaled) - return z - - def __call__(self, recurrent_state, inputs): - # Recurrent state should be ((state, timestep), reset) - # Inputs should be (x, reset) - h = self.map_to_h(inputs) - recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h) - # recurrent_state is ((state, timestep), reset) - out = self.map_from_h(recurrent_state, inputs) - - # TODO: Remove this when we want to return all recurrent states instead of just the last one - final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) - return final_recurrent_state, out - - def initialize_carry(self, batch_size: int = None): - return self.cell.initialize_carry(batch_size) - - -class StackedSFFM(nn.Module): - """A multilayer version of SFFM""" - - cells: List[nn.Module] - - def setup(self): - self.project = nn.Dense(cells[0].hidden_size) - - def __call__(self, recurrent_state: jax.Array, inputs: Any) -> Tuple[jax.Array, jax.Array]: - x, start = inputs - x = self.project(x) - inputs = x, start - for i, cell in enumerate(self.cells): - s, y = cell(recurrent_state[i], inputs) - x = x + y - recurrent_state[i] = s - return y, recurrent_state - - def initialize_carry(self, batch_size: int = None): - return [c.initialize_carry(batch_size) for c in self.cells] - - -if __name__ == "__main__": - m = FFM( - output_size=4, - trace_size=5, - context_size=6, - cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)), - ) - s = m.initialize_carry() - x = jnp.ones((10, 2)) - start = jnp.zeros(10, dtype=bool) - params = m.init(jax.random.PRNGKey(0), s, (x, start)) - out_state, out = m.apply(params, s, (x, start)) - - # BatchFFM = nn.vmap( - # FFM, in_axes=1, out_axes=1, variable_axes={"params": None}, split_rngs={"params": False} - # ) - - # m = BatchFFM( - # trace_size=4, - # context_size=5, - # output_size=6, - # cell=MemoroidResetWrapper(cell=FFMCell(4,5,6)) - # ) - - # s = m.initialize_carry(8) - # x = jnp.ones((10, 8, 2)) - # start = jnp.zeros((10, 8), dtype=bool) - # params = m.init(jax.random.PRNGKey(0), s, (x, start)) - # out_state, out = m.apply(params, s, (x, start)) - - # print(out.shape) - # print(out_state.shape) - - # TODO: Initialize cells with different random streams so the weights are not identical - cells = [ - SFFM( - trace_size=4, - context_size=5, - hidden_size=6, - cell=MemoroidResetWrapper(cell=FFMCell(4, 5, 6)), - ) - for i in range(3) - ] - s2fm = StackedSFFM(cells=cells) - - s = s2fm.initialize_carry() - x = jnp.ones((10, 2)) - start = jnp.zeros(10, dtype=bool) - params = s2fm.init(jax.random.PRNGKey(0), s, (x, start)) - out_state, out = s2fm.apply(params, s, (x, start)) diff --git a/stoix/networks/memoroids/old_code/old_code2.py b/stoix/networks/memoroids/old_code/old_code2.py deleted file mode 100644 index 852c5d00..00000000 --- a/stoix/networks/memoroids/old_code/old_code2.py +++ /dev/null @@ -1,420 +0,0 @@ -# CURRENTLY NOT BEING USED - -# from functools import partial -# from typing import Optional, Tuple - -# import chex -# import flax.linen as nn -# import jax -# import jax.numpy as jnp -# import optax - -# # Typing aliases -# Carry = chex.ArrayTree - -# HiddenState = chex.Array -# Timestep = chex.Array -# Reset = chex.Array - -# RecurrentState = Tuple[HiddenState, Timestep] - -# InputEmbedding = chex.Array -# Inputs = Tuple[InputEmbedding, Reset] - - -# def debug_shape(x): -# return jax.tree.map(lambda x: x.shape, x) - - -# class MemoroidCellBase(nn.Module): -# """Memoroid cell base class.""" - -# def map_to_h(self, inputs: Inputs) -> RecurrentState: -# """Map from the input space to the recurrent state space""" -# raise NotImplementedError - -# def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> HiddenState: -# """Map from the recurrent space to the Markov space""" -# raise NotImplementedError - -# @nn.nowrap -# def initialize_carry( -# self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None -# ) -> RecurrentState: -# """Initialize the Memoroid cell carry. - -# Args: -# batch_size: the batch size of the carry. -# rng: random number generator passed to the init_fn. - -# Returns: -# An initialized carry for the given RNN cell. -# """ -# raise NotImplementedError - -# @property -# def num_feature_axes(self) -> int: -# """Returns the number of feature axes of the cell.""" -# raise NotImplementedError - - -# def recurrent_associative_scan( -# cell: nn.Module, -# state: RecurrentState, -# inputs: RecurrentState, -# axis: int = 0, -# ) -> RecurrentState: -# """Execute the associative scan to update the recurrent state. - -# Note that we do a trick here by concatenating the previous state to the inputs. -# This is allowed since the scan is associative. This ensures that the previous -# recurrent state feeds information into the scan. Without this method, we need -# separate methods for rollouts and training.""" - -# # Concatenate the previous state to the inputs and scan over the result -# # This ensures the previous recurrent state contributes to the current batch - -# scan_inputs = jax.tree.map(lambda s, x: jnp.concatenate([s, x], axis=axis), state, inputs) -# new_state = jax.lax.associative_scan( -# cell, -# scan_inputs, -# axis=axis, -# ) - -# # The zeroth index corresponds to the previous recurrent state -# # We just use it to ensure continuity -# # We do not actually want to use these values, so slice them away -# return jax.tree.map( -# lambda x: jax.lax.slice_in_dim(x, start_index=1, limit_index=None, axis=axis), new_state -# ) - - -# class Gate(nn.Module): -# """Sigmoidal gating""" - -# output_size: int - -# @nn.compact -# def __call__(self, x): -# x = nn.Dense(self.output_size)(x) -# x = nn.sigmoid(x) -# return x - - -# def init_deterministic( -# memory_size: int, context_size: int, min_period: int = 1, max_period: int = 1000 -# ) -> Tuple[chex.Array, chex.Array]: -# """Deterministic initialization of the FFM parameters.""" -# a_low = 1e-6 -# a_high = 0.5 -# a = jnp.linspace(a_low, a_high, memory_size) -# b = 2 * jnp.pi / jnp.linspace(min_period, max_period, context_size) -# return a, b - - -# class FFMCell(MemoroidCellBase): -# """The binary associative update function for the FFM.""" - -# trace_size: int -# context_size: int -# output_size: int - -# def setup(self): - -# # Create the parameters that are explicitly used in the cells core computation -# a, b = init_deterministic(self.trace_size, self.context_size) -# self.params = (self.param("ffa_a", lambda rng: a), self.param("ffa_b", lambda rng: b)) - -# # Create the networks and parameters that are used when -# # mapping from input space to recurrent state space -# # This is used in the map_to_h method and is used in the -# # associative scan outer loop -# self.pre = nn.Dense(self.trace_size) -# self.gate_in = Gate(self.trace_size) -# self.gate_out = Gate(self.output_size) -# self.skip = nn.Dense(self.output_size) -# self.mix = nn.Dense(self.output_size) -# self.ln = nn.LayerNorm(use_scale=False, use_bias=False) - -# def map_to_h(self, x: InputEmbedding) -> RecurrentState: -# """Map from the input space to the recurrent state space - unlike the call function -# this explicitly expects a shape including the sequence dimension. This is used in the -# outer network that uses the associative scan.""" -# gate_in = self.gate_in(x) -# pre = self.pre(x) -# gated_x = pre * gate_in -# # We also need relative timesteps, i.e., each observation is 1 timestep newer than the previous -# ts = jnp.ones(x.shape[0:2], dtype=jnp.int32) -# z = jnp.repeat(jnp.expand_dims(gated_x, 3), self.context_size, axis=3) -# return (z, ts) - -# def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> HiddenState: -# """Map from the recurrent space to the Markov space""" -# (state, _), _ = recurrent_state -# z_in = jnp.concatenate([jnp.real(state), jnp.imag(state)], axis=-1).reshape( -# state.shape[0], state.shape[1], -1 -# ) -# z = self.mix(z_in) -# gate_out = self.gate_out(x) -# skip = self.skip(x) -# out = self.ln(z * gate_out) + skip * (1 - gate_out) -# return out - -# def log_gamma(self, t: chex.Array) -> chex.Array: -# a, b = self.params -# a = -jnp.abs(a).reshape((1, 1, self.trace_size, 1)) -# b = b.reshape(1, 1, 1, self.context_size) -# ab = jax.lax.complex(a, b) -# return ab * t.reshape(t.shape[0], t.shape[1], 1, 1) - -# def gamma(self, t: chex.Array) -> chex.Array: -# return jnp.exp(self.log_gamma(t)) - -# @nn.nowrap -# def initialize_carry( -# self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None -# ) -> RecurrentState: -# # inputs should be of shape [*batch, time, feature] -# # recurrent states should be of shape [*batch, 1, feature] -# carry_shape = (1, self.trace_size, self.context_size) -# t_shape = (1,) -# if batch_size is not None: -# carry_shape = (carry_shape[0], batch_size, *carry_shape[1:]) -# t_shape = (*t_shape, batch_size) - -# return jnp.zeros(carry_shape, dtype=jnp.complex64), jnp.zeros(t_shape, dtype=jnp.int32) - -# def __call__(self, carry: RecurrentState, incoming): -# ( -# state, -# i, -# ) = carry -# x, j = incoming -# state = state * self.gamma(j) + x -# return state, j + i - - -# class MemoroidResetWrapper(MemoroidCellBase): -# """A wrapper around memoroid cells like FFM, LRU, etc that resets -# the recurrent state upon a reset signal.""" - -# cell: nn.Module - -# def __call__(self, carry, incoming, rng=None): -# states, prev_carry_reset_flag = carry -# xs, start = incoming - -# def reset_state(start: Reset, current_state, initial_state): -# # Expand to reset all dims of state: [1, B, 1, ...] -# assert initial_state.ndim == current_state.ndim -# expanded_start = start.reshape(-1, start.shape[1], *([1] * (current_state.ndim - 2))) -# out = current_state * jnp.logical_not(expanded_start) + initial_state -# return out - -# # Add an extra dim, as start will be [Batch] while intialize carry expects [Batch, Feature] -# initial_states = self.cell.initialize_carry(rng=rng, batch_size=start.shape[1]) -# states = jax.tree.map(partial(reset_state, start), states, initial_states) -# out = self.cell(states, xs) -# carry_reset_flag = jnp.logical_or(start, prev_carry_reset_flag) - -# return out, carry_reset_flag - -# def map_to_h(self, x: InputEmbedding) -> RecurrentState: -# return self.cell.map_to_h(x) - -# def map_from_h(self, recurrent_state: RecurrentState, x: InputEmbedding) -> HiddenState: -# return self.cell.map_from_h(recurrent_state, x) - -# @nn.nowrap -# def initialize_carry( -# self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None -# ) -> RecurrentState: -# return self.cell.initialize_carry(batch_size, rng), jnp.zeros((1, batch_size), dtype=bool) - - -# class ScannedMemoroid(nn.Module): -# cell: nn.Module - -# @nn.compact -# def __call__( -# self, recurrent_state: RecurrentState, inputs: Inputs -# ) -> Tuple[RecurrentState, HiddenState]: -# """Apply the ScannedMemoroid. -# This takes in a sequence of batched states and inputs. -# The recurrent state that is used requires no sequence dimension but does require a batch dimension.""" -# # Recurrent state should be (state, timestep) -# # Inputs should be (x, reset) - -# # Unsqueeze the recurrent state to add the sequence dimension of size 1 -# recurrent_state = jax.tree.map(lambda x: jnp.expand_dims(x, 0), recurrent_state) - -# x, resets = inputs -# h = self.cell.map_to_h(x) -# # TODO: In the original implementation, the recurrent timestep is also one -# # recurrent_state = ( -# # (recurrent_state[0][0], -# # jnp.ones_like(recurrent_state[0][1])), -# # recurrent_state[1] -# # ) -# recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, (h, resets)) -# # recurrent_state is (state, timestep) -# out = self.cell.map_from_h(recurrent_state, x) - -# # TODO: Remove this when we want to return all recurrent states instead of just the last one -# final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state) - -# # Squeeze the sequence dimension of 1 out -# final_recurrent_state = jax.tree.map(lambda x: jnp.squeeze(x, 0), final_recurrent_state) - -# return final_recurrent_state, out - -# @nn.nowrap -# def initialize_carry( -# self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None -# ) -> RecurrentState: -# """Initialize the carry for the ScannedMemoroid. This returns the carry in the shape [Batch, ...] i.e. it contains no sequence dimension""" -# # We squeeze the sequence dim of 1 out. -# return jax.tree.map(lambda x: x.squeeze(0), self.cell.initialize_carry(batch_size, rng)) - - -# def test_reset_wrapper(): -# """Validate that the reset wrapper works as expected""" -# BatchFFM = ScannedMemoroid - -# m = BatchFFM( -# cell=MemoroidResetWrapper(cell=FFMCell(output_size=2, trace_size=2, context_size=3)) -# ) - -# batch_size = 4 -# time_steps = 100 -# # Have a batched version with one episode per batch -# # and collapse it into a single episode with a single batch (but same start/resets) -# # results should be identical -# batched_starts = jnp.ones([batch_size], dtype=bool) -# # batched_starts = jnp.concatenate([ -# # jnp.zeros([time_steps // 2, batch_size], dtype=bool), -# # batched_starts.reshape(1, -1), -# # jnp.zeros([time_steps // 2 - 1, batch_size], dtype=bool) -# # ], axis=0) -# batched_starts = jax.random.uniform(jax.random.PRNGKey(0), (time_steps, batch_size)) < 0.1 -# contig_starts = jnp.swapaxes(batched_starts, 1, 0).reshape(-1, 1) - -# x_batched = jnp.arange(time_steps * batch_size * 2).reshape((time_steps, batch_size, 2)) -# x_contig = jnp.swapaxes(x_batched, 1, 0).reshape(-1, 1, 2) -# batched_s = m.initialize_carry(batch_size) -# contig_s = m.initialize_carry(1) -# params = m.init(jax.random.PRNGKey(0), batched_s, (x_batched, batched_starts)) - -# ((batched_out_state, batched_ts), batched_reset), batched_out = m.apply( -# params, batched_s, (x_batched, batched_starts) -# ) -# ((contig_out_state, contig_ts), contig_reset), contig_out = m.apply( -# params, contig_s, (x_contig, contig_starts) -# ) - -# # This should be nearly zero (1e-10 or something) -# state_error = jnp.linalg.norm(contig_out_state - batched_out_state[-1], axis=-1).sum() -# print("state error", state_error) -# out_error = jnp.linalg.norm( -# batched_out - jnp.swapaxes(contig_out.reshape(batch_size, time_steps, -1), 1, 0), axis=-1 -# ).sum() -# print("out error", out_error) -# print(batched_ts, contig_ts) - - -# def test_reset_wrapper_ts(): -# BatchFFM = ScannedMemoroid - -# m = BatchFFM( -# cell=MemoroidResetWrapper(cell=FFMCell(output_size=2, trace_size=2, context_size=3)) -# ) - -# batch_size = 2 -# time_steps = 10 -# # Have a batched version with one episode per batch -# # and collapse it into a single episode with a single batch (but same start/resets) -# # results should be identical -# batched_starts = jnp.array( -# [ -# [False, False, True, False, False, True, True, False, False, False], -# [False, False, True, False, False, True, True, False, False, False], -# ] -# ).T - -# x_batched = ( -# jnp.arange(time_steps * batch_size * 2) -# .reshape((time_steps, batch_size, 2)) -# .astype(jnp.float32) -# ) -# batched_s = m.initialize_carry(batch_size) -# params = m.init(jax.random.PRNGKey(0), batched_s, (x_batched, batched_starts)) - -# ((batched_out_state, batched_ts), batched_reset), batched_out = m.apply( -# params, batched_s, (x_batched, batched_starts) -# ) -# print(batched_ts == 4) - - -# def train_memorize(): -# BatchFFM = ScannedMemoroid - -# m = BatchFFM( -# cell=MemoroidResetWrapper(cell=FFMCell(output_size=128, trace_size=32, context_size=4)) -# ) - -# batch_size = 1 -# rem_ts = 10 -# time_steps = rem_ts * 10 -# obs_space = 2 -# rng = jax.random.PRNGKey(0) -# x = jax.random.randint(rng, (time_steps, batch_size), 0, obs_space).reshape(-1, 1, 1) -# y = jnp.repeat(x[::rem_ts], x.shape[0] // x[::rem_ts].shape[0]).reshape(-1, 1) -# start = jnp.zeros([time_steps, batch_size], dtype=bool).at[::rem_ts].set(True) -# # start = jnp.zeros([time_steps, batch_size], dtype=bool) -# # start = jnp.ones([time_steps, batch_size], dtype=bool) - -# s = m.initialize_carry(batch_size) -# params = m.init(jax.random.PRNGKey(0), s, (x, start)) - -# def error(params, x, start, key): -# s = m.initialize_carry(batch_size) -# x = jax.random.randint(key, (time_steps, batch_size), 0, obs_space).reshape(-1, 1, 1) -# y = jnp.repeat(x[::rem_ts], x.shape[0] // x[::rem_ts].shape[0]).reshape(-1, 1) -# out_state, y_hat = m.apply(params, s, (x, start)) -# return jnp.mean((y - y_hat) ** 2) - -# optimizer = optax.adam(learning_rate=0.002) -# state = optimizer.init(params) -# loss_fn = jax.jit(jax.value_and_grad(error)) -# for step in range(10_000): -# rng = jax.random.split(rng)[0] -# loss, grads = loss_fn(params, x, start, rng) -# updates, state = optimizer.update(grads, state) -# params = optax.apply_updates(params, updates) -# print(f"Step {step+1}, Loss: {loss}") - - -# if __name__ == "__main__": -# # BatchFFM = ScannedMemoroid - -# # m = BatchFFM( -# # cell=MemoroidResetWrapper(cell=FFMCell(output_size=4, trace_size=5, context_size=6)) -# # ) - -# # batch_size = 8 -# # time_steps = 10 - -# # y = jnp.ones((time_steps, batch_size, 2)) -# # s = m.initialize_carry(batch_size) -# # start = jnp.zeros((time_steps, batch_size), dtype=bool) -# # params = m.init(jax.random.PRNGKey(0), s, (y, start)) -# # out_state, out = m.apply(params, s, (y, start)) - -# # out = jnp.swapaxes(out, 0, 1) - -# # print(out) -# # print(debug_shape(out_state)) - -# # test_reset_wrapper() -# # test_reset_wrapper_ts() -# train_memorize() diff --git a/stoix/networks/memoroids/old_code/old_s5.py b/stoix/networks/memoroids/old_code/old_s5.py deleted file mode 100644 index 40e3abad..00000000 --- a/stoix/networks/memoroids/old_code/old_s5.py +++ /dev/null @@ -1,645 +0,0 @@ -from functools import partial - -import chex -import jax -import jax.numpy as np -import jax.numpy as jnp -from flax import linen as nn -from jax import random -from jax.nn.initializers import lecun_normal, normal -from jax.numpy.linalg import eigh - - -class SequenceLayer(nn.Module): - """Defines a single S5 layer, with S5 SSM, nonlinearity, etc. - Args: - ssm (nn.Module): the SSM to be used (i.e. S5 ssm) - d_model (int32): this is the feature size of the layer inputs and outputs - we usually refer to this size as H - activation (string): Type of activation function to use - prenorm (bool): apply prenorm if true or postnorm if false - step_rescale (float32): allows for uniformly changing the timescale parameter, - e.g. after training on a different resolution for - the speech commands benchmark - """ - - ssm: nn.Module - d_model: int - activation: str = "gelu" - do_norm: bool = True - prenorm: bool = True - do_gtrxl_norm: bool = True - step_rescale: float = 1.0 - - def setup(self): - """Initializes the ssm, layer norm and dense layers""" - self.seq = self.ssm(step_rescale=self.step_rescale) - - if self.activation in ["full_glu"]: - self.out1 = nn.Dense(self.d_model) - self.out2 = nn.Dense(self.d_model) - elif self.activation in ["half_glu1", "half_glu2"]: - self.out2 = nn.Dense(self.d_model) - - self.norm = nn.LayerNorm() - - def __call__(self, hidden, x, d): - """ - Compute the LxH output of S5 layer given an LxH input. - Args: - x (float32): input sequence (L, d_model) - d (bool): reset signal (L,) - Returns: - output sequence (float32): (L, d_model) - """ - skip = x - if self.prenorm and self.do_norm: - x = self.norm(x) - # hidden, x = self.seq(hidden, x, d) - hidden, x = jax.vmap(self.seq, in_axes=1, out_axes=1)(hidden, x, d) - # hidden = jnp.swapaxes(hidden, 1, 0) - if self.do_gtrxl_norm: - x = self.norm(x) - - if self.activation in ["full_glu"]: - x = nn.gelu(x) - x = self.out1(x) * jax.nn.sigmoid(self.out2(x)) - elif self.activation in ["half_glu1"]: - x = nn.gelu(x) - x = x * jax.nn.sigmoid(self.out2(x)) - elif self.activation in ["half_glu2"]: - # Only apply GELU to the gate input - x1 = nn.gelu(x) - x = x * jax.nn.sigmoid(self.out2(x1)) - elif self.activation in ["gelu"]: - x = nn.gelu(x) - else: - raise NotImplementedError("Activation: {} not implemented".format(self.activation)) - - x = skip + x - if not self.prenorm and self.do_norm: - x = self.norm(x) - return hidden, x - - @staticmethod - def initialize_carry(batch_size, hidden_size): - # Use a dummy key since the default state init fn is just zeros. - return jnp.zeros((1, batch_size, hidden_size), dtype=jnp.complex64) - - -def log_step_initializer(dt_min=0.001, dt_max=0.1): - """Initialize the learnable timescale Delta by sampling - uniformly between dt_min and dt_max. - Args: - dt_min (float32): minimum value - dt_max (float32): maximum value - Returns: - init function - """ - - def init(key, shape): - """Init function - Args: - key: jax random key - shape tuple: desired shape - Returns: - sampled log_step (float32) - """ - return random.uniform(key, shape) * (np.log(dt_max) - np.log(dt_min)) + np.log(dt_min) - - return init - - -def init_log_steps(key, input): - """Initialize an array of learnable timescale parameters - Args: - key: jax random key - input: tuple containing the array shape H and - dt_min and dt_max - Returns: - initialized array of timescales (float32): (H,) - """ - H, dt_min, dt_max = input - log_steps = [] - for i in range(H): - key, skey = random.split(key) - log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,)) - log_steps.append(log_step) - - return np.array(log_steps) - - -def init_VinvB(init_fun, rng, shape, Vinv): - """Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B. - Note we will parameterize this with two different matrices for complex - numbers. - Args: - init_fun: the initialization function to use, e.g. lecun_normal() - rng: jax random key to be used with init function. - shape (tuple): desired shape (P,H) - Vinv: (complex64) the inverse eigenvectors used for initialization - Returns: - B_tilde (complex64) of shape (P,H,2) - """ - B = init_fun(rng, shape) - VinvB = Vinv @ B - VinvB_real = VinvB.real - VinvB_imag = VinvB.imag - return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1) - - -def trunc_standard_normal(key, shape): - """Sample C with a truncated normal distribution with standard deviation 1. - Args: - key: jax random key - shape (tuple): desired shape, of length 3, (H,P,_) - Returns: - sampled C matrix (float32) of shape (H,P,2) (for complex parameterization) - """ - H, P, _ = shape - Cs = [] - for i in range(H): - key, skey = random.split(key) - C = lecun_normal()(skey, shape=(1, P, 2)) - Cs.append(C) - return np.array(Cs)[:, 0] - - -def init_CV(init_fun, rng, shape, V): - """Initialize C_tilde=CV. First sample C. Then compute CV. - Note we will parameterize this with two different matrices for complex - numbers. - Args: - init_fun: the initialization function to use, e.g. lecun_normal() - rng: jax random key to be used with init function. - shape (tuple): desired shape (H,P) - V: (complex64) the eigenvectors used for initialization - Returns: - C_tilde (complex64) of shape (H,P,2) - """ - C_ = init_fun(rng, shape) - C = C_[..., 0] + 1j * C_[..., 1] - CV = C @ V - CV_real = CV.real - CV_imag = CV.imag - return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1) - - -# Discretization functions -def discretize_bilinear(Lambda, B_tilde, Delta): - """Discretize a diagonalized, continuous-time linear SSM - using bilinear transform method. - Args: - Lambda (complex64): diagonal state matrix (P,) - B_tilde (complex64): input matrix (P, H) - Delta (float32): discretization step sizes (P,) - Returns: - discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) - """ - Identity = np.ones(Lambda.shape[0]) - - BL = 1 / (Identity - (Delta / 2.0) * Lambda) - Lambda_bar = BL * (Identity + (Delta / 2.0) * Lambda) - B_bar = (BL * Delta)[..., None] * B_tilde - return Lambda_bar, B_bar - - -def discretize_zoh(Lambda, B_tilde, Delta): - """Discretize a diagonalized, continuous-time linear SSM - using zero-order hold method. - Args: - Lambda (complex64): diagonal state matrix (P,) - B_tilde (complex64): input matrix (P, H) - Delta (float32): discretization step sizes (P,) - Returns: - discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) - """ - Identity = np.ones(Lambda.shape[0]) - Lambda_bar = np.exp(Lambda * Delta) - B_bar = (1 / Lambda * (Lambda_bar - Identity))[..., None] * B_tilde - return Lambda_bar, B_bar - - -# Parallel scan operations -@jax.vmap -def binary_operator(q_i, q_j): - """Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A. - Args: - q_i: tuple containing A_i and Bu_i at position i (P,), (P,) - q_j: tuple containing A_j and Bu_j at position j (P,), (P,) - Returns: - new element ( A_out, Bu_out ) - """ - A_i, b_i = q_i - A_j, b_j = q_j - return A_j * A_i, A_j * b_i + b_j - - -# Parallel scan operations -@jax.vmap -def binary_operator_reset(q_i, q_j): - """Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A. - Args: - q_i: tuple containing A_i and Bu_i at position i (P,), (P,) - q_j: tuple containing A_j and Bu_j at position j (P,), (P,) - Returns: - new element ( A_out, Bu_out ) - """ - A_i, b_i, c_i = q_i - A_j, b_j, c_j = q_j - return ( - (A_j * A_i) * (1 - c_j) + A_j * c_j, - (A_j * b_i + b_j) * (1 - c_j) + b_j * c_j, - c_i * (1 - c_j) + c_j, - ) - - -def apply_ssm(Lambda_bar, B_bar, C_tilde, hidden, input_sequence, resets, conj_sym, bidirectional): - """Compute the LxH output of discretized SSM given an LxH input. - Args: - Lambda_bar (complex64): discretized diagonal state matrix (P,) - B_bar (complex64): discretized input matrix (P, H) - C_tilde (complex64): output matrix (H, P) - input_sequence (float32): input sequence of features (L, H) - reset (bool): input sequence of features (L,) - conj_sym (bool): whether conjugate symmetry is enforced - bidirectional (bool): whether bidirectional setup is used, - Note for this case C_tilde will have 2P cols - Returns: - ys (float32): the SSM outputs (S5 layer preactivations) (L, H) - """ - Lambda_elements = Lambda_bar * jnp.ones((input_sequence.shape[0], Lambda_bar.shape[0])) - Bu_elements = jax.vmap(lambda u: B_bar @ u)(input_sequence) - - Lambda_elements = jnp.concatenate( - [ - jnp.ones((1, Lambda_bar.shape[0])), - Lambda_elements, - ] - ) - - Bu_elements = jnp.concatenate( - [ - hidden, - Bu_elements, - ] - ) - - if resets is None: - _, xs = jax.lax.associative_scan(binary_operator, (Lambda_elements, Bu_elements)) - else: - resets = jnp.concatenate( - [ - jnp.zeros(1), - resets, - ] - ) - _, xs, _ = jax.lax.associative_scan( - binary_operator_reset, (Lambda_elements, Bu_elements, resets) - ) - xs = xs[1:] - - if conj_sym: - return xs[np.newaxis, -1], jax.vmap(lambda x: 2 * (C_tilde @ x).real)(xs) - else: - return xs[np.newaxis, -1], jax.vmap(lambda x: (C_tilde @ x).real)(xs) - - -class S5SSM(nn.Module): - Lambda_re_init: chex.Array - Lambda_im_init: chex.Array - V: chex.Array - Vinv: chex.Array - - H: int - P: int - C_init: str - discretization: str - dt_min: float - dt_max: float - conj_sym: bool = True - clip_eigs: bool = False - bidirectional: bool = False - step_rescale: float = 1.0 - - """ The S5 SSM - Args: - Lambda_re_init (complex64): Real part of init diag state matrix (P,) - Lambda_im_init (complex64): Imag part of init diag state matrix (P,) - V (complex64): Eigenvectors used for init (P,P) - Vinv (complex64): Inverse eigenvectors used for init (P,P) - H (int32): Number of features of input seq - P (int32): state size - C_init (string): Specifies How C is initialized - Options: [trunc_standard_normal: sample from truncated standard normal - and then multiply by V, i.e. C_tilde=CV. - lecun_normal: sample from Lecun_normal and then multiply by V. - complex_normal: directly sample a complex valued output matrix - from standard normal, does not multiply by V] - conj_sym (bool): Whether conjugate symmetry is enforced - clip_eigs (bool): Whether to enforce left-half plane condition, i.e. - constrain real part of eigenvalues to be negative. - True recommended for autoregressive task/unbounded sequence lengths - Discussed in https://arxiv.org/pdf/2206.11893.pdf. - bidirectional (bool): Whether model is bidirectional, if True, uses two C matrices - discretization: (string) Specifies discretization method - options: [zoh: zero-order hold method, - bilinear: bilinear transform] - dt_min: (float32): minimum value to draw timescale values from when - initializing log_step - dt_max: (float32): maximum value to draw timescale values from when - initializing log_step - step_rescale: (float32): allows for uniformly changing the timescale parameter, e.g. after training - on a different resolution for the speech commands benchmark - """ - - def setup(self): - """Initializes parameters once and performs discretization each time - the SSM is applied to a sequence - """ - - if self.conj_sym: - # Need to account for case where we actually sample real B and C, and then multiply - # by the half sized Vinv and possibly V - local_P = 2 * self.P - else: - local_P = self.P - - # Initialize diagonal state to state matrix Lambda (eigenvalues) - self.Lambda_re = self.param("Lambda_re", lambda rng, shape: self.Lambda_re_init, (None,)) - self.Lambda_im = self.param("Lambda_im", lambda rng, shape: self.Lambda_im_init, (None,)) - if self.clip_eigs: - self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im - else: - self.Lambda = self.Lambda_re + 1j * self.Lambda_im - - # Initialize input to state (B) matrix - B_init = lecun_normal() - B_shape = (local_P, self.H) - self.B = self.param( - "B", lambda rng, shape: init_VinvB(B_init, rng, shape, self.Vinv), B_shape - ) - B_tilde = self.B[..., 0] + 1j * self.B[..., 1] - - # Initialize state to output (C) matrix - if self.C_init in ["trunc_standard_normal"]: - C_init = trunc_standard_normal - C_shape = (self.H, local_P, 2) - elif self.C_init in ["lecun_normal"]: - C_init = lecun_normal() - C_shape = (self.H, local_P, 2) - elif self.C_init in ["complex_normal"]: - C_init = normal(stddev=0.5**0.5) - else: - raise NotImplementedError("C_init method {} not implemented".format(self.C_init)) - - if self.C_init in ["complex_normal"]: - if self.bidirectional: - C = self.param("C", C_init, (self.H, 2 * self.P, 2)) - self.C_tilde = C[..., 0] + 1j * C[..., 1] - - else: - C = self.param("C", C_init, (self.H, self.P, 2)) - self.C_tilde = C[..., 0] + 1j * C[..., 1] - - else: - if self.bidirectional: - self.C1 = self.param( - "C1", lambda rng, shape: init_CV(C_init, rng, shape, self.V), C_shape - ) - self.C2 = self.param( - "C2", lambda rng, shape: init_CV(C_init, rng, shape, self.V), C_shape - ) - - C1 = self.C1[..., 0] + 1j * self.C1[..., 1] - C2 = self.C2[..., 0] + 1j * self.C2[..., 1] - self.C_tilde = np.concatenate((C1, C2), axis=-1) - - else: - self.C = self.param( - "C", lambda rng, shape: init_CV(C_init, rng, shape, self.V), C_shape - ) - - self.C_tilde = self.C[..., 0] + 1j * self.C[..., 1] - - # Initialize feedthrough (D) matrix - self.D = self.param("D", normal(stddev=1.0), (self.H,)) - - # Initialize learnable discretization timescale value - self.log_step = self.param("log_step", init_log_steps, (self.P, self.dt_min, self.dt_max)) - step = self.step_rescale * np.exp(self.log_step[:, 0]) - - # Discretize - if self.discretization in ["zoh"]: - self.Lambda_bar, self.B_bar = discretize_zoh(self.Lambda, B_tilde, step) - elif self.discretization in ["bilinear"]: - self.Lambda_bar, self.B_bar = discretize_bilinear(self.Lambda, B_tilde, step) - else: - raise NotImplementedError( - "Discretization method {} not implemented".format(self.discretization) - ) - - def __call__(self, hidden, input_sequence, resets): - """ - Compute the LxH output of the S5 SSM given an LxH input sequence - using a parallel scan. - Args: - input_sequence (float32): input sequence (L, H) - resets (bool): input sequence (L,) - Returns: - output sequence (float32): (L, H) - """ - hidden, ys = apply_ssm( - self.Lambda_bar, - self.B_bar, - self.C_tilde, - hidden, - input_sequence, - resets, - self.conj_sym, - self.bidirectional, - ) - # Add feedthrough matrix output Du; - Du = jax.vmap(lambda u: self.D * u)(input_sequence) - return hidden, ys + Du - - -def init_S5SSM( - H, - P, - Lambda_re_init, - Lambda_im_init, - V, - Vinv, - C_init, - discretization, - dt_min, - dt_max, - conj_sym, - clip_eigs, - bidirectional, -): - """Convenience function that will be used to initialize the SSM. - Same arguments as defined in S5SSM above.""" - return partial( - S5SSM, - H=H, - P=P, - Lambda_re_init=Lambda_re_init, - Lambda_im_init=Lambda_im_init, - V=V, - Vinv=Vinv, - C_init=C_init, - discretization=discretization, - dt_min=dt_min, - dt_max=dt_max, - conj_sym=conj_sym, - clip_eigs=clip_eigs, - bidirectional=bidirectional, - ) - - -def make_HiPPO(N): - """Create a HiPPO-LegS matrix. - From https://github.com/srush/annotated-s4/blob/main/s4/s4.py - Args: - N (int32): state size - Returns: - N x N HiPPO LegS matrix - """ - P = np.sqrt(1 + 2 * np.arange(N)) - A = P[:, np.newaxis] * P[np.newaxis, :] - A = np.tril(A) - np.diag(np.arange(N)) - return -A - - -def make_NPLR_HiPPO(N): - """ - Makes components needed for NPLR representation of HiPPO-LegS - From https://github.com/srush/annotated-s4/blob/main/s4/s4.py - Args: - N (int32): state size - Returns: - N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B - """ - # Make -HiPPO - hippo = make_HiPPO(N) - - # Add in a rank 1 term. Makes it Normal. - P = np.sqrt(np.arange(N) + 0.5) - - # HiPPO also specifies the B matrix - B = np.sqrt(2 * np.arange(N) + 1.0) - return hippo, P, B - - -def make_DPLR_HiPPO(N): - """ - Makes components needed for DPLR representation of HiPPO-LegS - From https://github.com/srush/annotated-s4/blob/main/s4/s4.py - Note, we will only use the diagonal part - Args: - N: - Returns: - eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B, - eigenvectors V, HiPPO B pre-conjugation - """ - A, P, B = make_NPLR_HiPPO(N) - - S = A + P[:, np.newaxis] * P[np.newaxis, :] - - S_diag = np.diagonal(S) - Lambda_real = np.mean(S_diag) * np.ones_like(S_diag) - - # Diagonalize S to V \Lambda V^* - Lambda_imag, V = eigh(S * -1j) - - P = V.conj().T @ P - B_orig = B - B = V.conj().T @ B - return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig - - -class StackedEncoderModel(nn.Module): - """Defines a stack of S5 layers to be used as an encoder. - Args: - ssm (nn.Module): the SSM to be used (i.e. S5 ssm) - d_model (int32): this is the feature size of the layer inputs and outputs - we usually refer to this size as H - n_layers (int32): the number of S5 layers to stack - activation (string): Type of activation function to use - prenorm (bool): apply prenorm if true or postnorm if false - """ - - ssm_size: int - d_model: int - n_layers: int - activation: str = "gelu" - do_norm: bool = True - prenorm: bool = True - do_gtrxl_norm: bool = True - - def setup(self): - """ - Initializes a linear encoder and the stack of S5 layers. - """ - blocks = 1 - block_size = int(self.ssm_size / blocks) - Lambda, _, _, V, _ = make_DPLR_HiPPO(self.ssm_size) - block_size = block_size // 2 - Lambda = Lambda[:block_size] - V = V[:, :block_size] - Vinv = V.conj().T - # self.encoder = nn.Dense(self.d_model) - self.layers = [ - SequenceLayer( - ssm=init_S5SSM( - H=self.d_model, - P=self.ssm_size // 2, - Lambda_re_init=Lambda.real, - Lambda_im_init=Lambda.imag, - V=V, - Vinv=Vinv, - C_init="lecun_normal", - discretization="zoh", - dt_min=0.001, - dt_max=0.1, - conj_sym=True, - clip_eigs=False, - bidirectional=False, - ), - d_model=self.d_model, - activation=self.activation, - do_norm=self.do_norm, - prenorm=self.prenorm, - do_gtrxl_norm=self.do_gtrxl_norm, - ) - for _ in range(self.n_layers) - ] - - def __call__(self, hidden, inputs): - """ - Compute the LxH output of the stacked encoder given an Lxd_input - input sequence. - Args: - x (float32): input sequence (L, d_input) - Returns: - output sequence (float32): (L, d_model) - """ - x, d = inputs - new_hiddens = [] - hidden = jax.tree.map(lambda x: jnp.expand_dims(x, 0), hidden) - for i, layer in enumerate(self.layers): - new_h, x = layer(hidden[i], x, d) - new_hiddens.append(new_h) - - new_hiddens = jax.tree.map(lambda x: x.squeeze(0), new_hiddens) - return new_hiddens, x - - @nn.nowrap - def initialize_carry(self, batch_size): - # Use a dummy key since the default state init fn is just zeros. - return [ - jnp.zeros((batch_size, self.ssm_size // 2), dtype=jnp.complex64) - for _ in range(self.n_layers) - ] diff --git a/stoix/networks/memoroids/types.py b/stoix/networks/memoroids/types.py deleted file mode 100644 index 56bd8632..00000000 --- a/stoix/networks/memoroids/types.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import List, Optional, Tuple - -import chex -from flax import linen as nn -from jax import numpy as jnp - -RecurrentState = chex.Array -Reset = chex.Array -Timestep = chex.Array -InputEmbedding = chex.Array -Inputs = Tuple[InputEmbedding, Reset] -ScanInput = chex.Array From 26d1fc469d3a12f05d6b5450195d30a2d48a9463 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Sun, 7 Jul 2024 20:39:32 +0000 Subject: [PATCH 38/38] chore: change configs --- .../configs/env/popjym/auto_encode_easy.yaml | 12 +++++ .../env/popjym/auto_encode_medium.yaml | 12 +++++ .../configs/env/popjym/count_recall_easy.yaml | 12 +++++ .../env/popjym/count_recall_medium.yaml | 12 +++++ stoix/configs/network/stacked_lrm.yaml | 14 +++--- stoix/configs/system/rec_ppo.yaml | 2 +- stoix/networks/lrm/base.py | 41 +---------------- stoix/networks/lrm/shared.py | 44 +++++++++++++++++++ 8 files changed, 100 insertions(+), 49 deletions(-) create mode 100644 stoix/configs/env/popjym/auto_encode_easy.yaml create mode 100644 stoix/configs/env/popjym/auto_encode_medium.yaml create mode 100644 stoix/configs/env/popjym/count_recall_easy.yaml create mode 100644 stoix/configs/env/popjym/count_recall_medium.yaml create mode 100644 stoix/networks/lrm/shared.py diff --git a/stoix/configs/env/popjym/auto_encode_easy.yaml b/stoix/configs/env/popjym/auto_encode_easy.yaml new file mode 100644 index 00000000..ba762fb7 --- /dev/null +++ b/stoix/configs/env/popjym/auto_encode_easy.yaml @@ -0,0 +1,12 @@ +# ---Environment Configs--- +env_name: popjym + +scenario: + name: AutoencodeEasy + task_name: auto_encode_easy + +kwargs: {} + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return diff --git a/stoix/configs/env/popjym/auto_encode_medium.yaml b/stoix/configs/env/popjym/auto_encode_medium.yaml new file mode 100644 index 00000000..2bc8fe63 --- /dev/null +++ b/stoix/configs/env/popjym/auto_encode_medium.yaml @@ -0,0 +1,12 @@ +# ---Environment Configs--- +env_name: popjym + +scenario: + name: AutoencodeMedium + task_name: auto_encode_medium + +kwargs: {} + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return diff --git a/stoix/configs/env/popjym/count_recall_easy.yaml b/stoix/configs/env/popjym/count_recall_easy.yaml new file mode 100644 index 00000000..2cf4481d --- /dev/null +++ b/stoix/configs/env/popjym/count_recall_easy.yaml @@ -0,0 +1,12 @@ +# ---Environment Configs--- +env_name: popjym + +scenario: + name: CountRecallEasy + task_name: count_recall_easy + +kwargs: {} + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return diff --git a/stoix/configs/env/popjym/count_recall_medium.yaml b/stoix/configs/env/popjym/count_recall_medium.yaml new file mode 100644 index 00000000..867bc8a4 --- /dev/null +++ b/stoix/configs/env/popjym/count_recall_medium.yaml @@ -0,0 +1,12 @@ +# ---Environment Configs--- +env_name: popjym + +scenario: + name: CountRecallMedium + task_name: count_recall_medium + +kwargs: {} + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return diff --git a/stoix/configs/network/stacked_lrm.yaml b/stoix/configs/network/stacked_lrm.yaml index 778e878f..cf3ec062 100644 --- a/stoix/configs/network/stacked_lrm.yaml +++ b/stoix/configs/network/stacked_lrm.yaml @@ -7,12 +7,11 @@ actor_network: use_layer_norm: False activation: leaky_relu rnn_layer: - _target_: stoix.networks.lrm.layers.StackedLRM + _target_: stoix.networks.lrm.shared.StackedLRM num_cells: 2 - lrm_cell_type: s5 + lrm_cell_type: lru cell_kwargs: - d_model: 256 - state_size: 256 + hidden_state_dim: 256 post_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] @@ -28,12 +27,11 @@ critic_network: use_layer_norm: False activation: leaky_relu rnn_layer: - _target_: stoix.networks.lrm.layers.StackedLRM + _target_: stoix.networks.lrm.shared.StackedLRM num_cells: 2 - lrm_cell_type: s5 + lrm_cell_type: lru cell_kwargs: - d_model: 256 - state_size: 256 + hidden_state_dim: 256 post_torso: _target_: stoix.networks.torso.MLPTorso layer_sizes: [256] diff --git a/stoix/configs/system/rec_ppo.yaml b/stoix/configs/system/rec_ppo.yaml index cb6f5c0c..e1d1241b 100644 --- a/stoix/configs/system/rec_ppo.yaml +++ b/stoix/configs/system/rec_ppo.yaml @@ -5,7 +5,7 @@ system_name: rec_ppo # Name of the system. # --- RL hyperparameters --- actor_lr: 3e-5 # Learning rate for actor network critic_lr: 3e-5 # Learning rate for critic network -rollout_length: 256 # Number of environment steps per vectorised environment. +rollout_length: 64 # Number of environment steps per vectorised environment. epochs: 10 # Number of ppo epochs per training data batch. num_minibatches: 64 # Number of minibatches per ppo epoch. gamma: 0.99 # Discounting factor. diff --git a/stoix/networks/lrm/base.py b/stoix/networks/lrm/base.py index 255e1d7e..8cde4778 100644 --- a/stoix/networks/lrm/base.py +++ b/stoix/networks/lrm/base.py @@ -1,10 +1,8 @@ -from typing import Any, Dict, List, Tuple, TypeAlias +from typing import Tuple, TypeAlias import chex import flax.linen as nn -from stoix.networks.lrm.utils import parse_lrm_cell - RecurrentState: TypeAlias = chex.Array Reset: TypeAlias = chex.Array Timestep: TypeAlias = chex.Array @@ -22,40 +20,3 @@ def __call__( @nn.nowrap def initialize_carry(self, batch_size: int) -> RecurrentState: raise NotImplementedError - - -class StackedLRM(nn.Module): - lrm_cell_type: LRMCellBase - cell_kwargs: Dict[str, Any] - num_cells: int - - def setup(self) -> None: - """Set up the LRM cells for the stacked LRM.""" - - cell_cls = parse_lrm_cell(self.lrm_cell_type) - self.cells = [cell_cls(**self.cell_kwargs) for _ in range(self.num_cells)] - - @nn.compact - def __call__( - self, all_states: List[RecurrentState], inputs: Inputs - ) -> Tuple[RecurrentState, chex.Array]: - # Ensure all_states is a list - if not isinstance(all_states, list): - all_states = [all_states] - - assert len(all_states) == len( - self.cells - ), f"Expected {len(self.cells)} states, got {len(all_states)}" - x, starts = inputs - new_states = [] - for cell, mem_state in zip(self.cells, all_states): - new_mem_state, x = cell(mem_state, (x, starts)) - new_states.append(new_mem_state) - - return new_states, x - - @nn.nowrap - def initialize_carry(self, batch_size: int) -> List[RecurrentState]: - cell_cls = parse_lrm_cell(self.lrm_cell_type) - cells = [cell_cls(**self.cell_kwargs) for _ in range(self.num_cells)] - return [cell.initialize_carry(batch_size) for cell in cells] diff --git a/stoix/networks/lrm/shared.py b/stoix/networks/lrm/shared.py new file mode 100644 index 00000000..1ba475b8 --- /dev/null +++ b/stoix/networks/lrm/shared.py @@ -0,0 +1,44 @@ +from typing import Any, Dict, List, Tuple + +import chex +import flax.linen as nn + +from stoix.networks.lrm.base import Inputs, LRMCellBase, RecurrentState +from stoix.networks.lrm.utils import parse_lrm_cell + + +class StackedLRM(nn.Module): + lrm_cell_type: LRMCellBase + cell_kwargs: Dict[str, Any] + num_cells: int + + def setup(self) -> None: + """Set up the LRM cells for the stacked LRM.""" + + cell_cls = parse_lrm_cell(self.lrm_cell_type) + self.cells = [cell_cls(**self.cell_kwargs) for _ in range(self.num_cells)] + + @nn.compact + def __call__( + self, all_states: List[RecurrentState], inputs: Inputs + ) -> Tuple[RecurrentState, chex.Array]: + # Ensure all_states is a list + if not isinstance(all_states, list): + all_states = [all_states] + + assert len(all_states) == len( + self.cells + ), f"Expected {len(self.cells)} states, got {len(all_states)}" + x, starts = inputs + new_states = [] + for cell, mem_state in zip(self.cells, all_states): + new_mem_state, x = cell(mem_state, (x, starts)) + new_states.append(new_mem_state) + + return new_states, x + + @nn.nowrap + def initialize_carry(self, batch_size: int) -> List[RecurrentState]: + cell_cls = parse_lrm_cell(self.lrm_cell_type) + cells = [cell_cls(**self.cell_kwargs) for _ in range(self.num_cells)] + return [cell.initialize_carry(batch_size) for cell in cells]