start implementing rnn ppo
epignatelli committed Jun 6, 2024
1 parent a42bbb3 commit dff090b
Showing 3 changed files with 415 additions and 35 deletions.
95 changes: 62 additions & 33 deletions navix/agents/models.py
@@ -1,5 +1,7 @@
import functools
from typing import Tuple
from jax import Array
import jax
import jax.numpy as jnp
import distrax
import flax.linen as nn
@@ -9,14 +11,24 @@
RNNState = tuple


-class DenseRNN(nn.Dense):
+class DenseRNN(nn.Dense, nn.RNNCellBase):
"""A linear module that returns an empty RNN state,
which makes it behave like an RNN layer."""

@nn.compact
def __call__(self, carry: RNNState, x: Array) -> Tuple[RNNState, Array]:
return (), super().__call__(x)

+    @nn.nowrap
+    def initialize_carry(
+        self, rng: Array, input_shape: Tuple[int, ...]
+    ) -> Tuple[Array, Array]:
+        return (jnp.asarray(()), jnp.asarray(()))
+
+    @property
+    def num_feature_axes(self) -> int:
+        return 1


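For context rather than as part of the diff: implementing initialize_carry and num_feature_axes is what satisfies Flax's RNNCellBase interface, so the feed-forward DenseRNN above can be swapped for a real recurrent cell without changing the calling code. A minimal sketch of that calling convention using a standard nn.GRUCell for comparison (shapes and variable names here are illustrative):

import jax
import jax.numpy as jnp
import flax.linen as nn

# RNNCellBase convention: build a carry, then call cell(carry, x) -> (carry, y).
cell = nn.GRUCell(features=64)
x = jnp.ones((8, 32))  # (batch, input_features)
carry = cell.initialize_carry(jax.random.PRNGKey(0), x.shape)
params = cell.init(jax.random.PRNGKey(1), carry, x)
new_carry, y = cell.apply(params, carry, x)  # y: (8, 64)
# DenseRNN follows the same protocol but threads an empty carry through,
# so it can stand in for the recurrent core when recurrent=False.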
class MLPEncoder(nn.Module):
hidden_size: int = 64
@@ -89,56 +101,73 @@ def value(self, x: Array) -> Array:
return jnp.squeeze(self.critic(x), -1)


-class ActorCriticRnn(nn.Module):
+class ActorCriticRNN(nn.Module):
action_dim: int
actor_encoder: nn.Module = MLPEncoder()
critic_encoder: nn.Module = MLPEncoder()
hidden_size: int = 64
recurrent: bool = False

def setup(self):
-        self.actor = nn.Sequential(
-            [
-                self.actor_encoder,
-                nn.Dense(
-                    self.action_dim,
-                    kernel_init=orthogonal(0.01),
-                    bias_init=constant(0.0),
-                ),
-                # lambda x: distrax.Categorical(logits=x),
-            ]
-        )
-
-        self.critic = nn.Sequential(
-            [
-                self.critic_encoder,
-                nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0)),
-                # lambda x: jnp.squeeze(x, axis=-1),
-            ]
-        )

if self.recurrent:
self.core_actor = nn.LSTMCell(self.hidden_size)
self.core_critic = nn.LSTMCell(self.hidden_size)
else:
self.core_actor = DenseRNN(self.hidden_size)
self.core_critic = DenseRNN(self.hidden_size)

+        self.actor_head = nn.Dense(
+            self.action_dim,
+            kernel_init=orthogonal(0.01),
+            bias_init=constant(0.0),
+        )
+        self.critic_head = nn.Dense(
+            1, kernel_init=orthogonal(1.0), bias_init=constant(0.0)
+        )

+    @functools.partial(
+        nn.scan,
+        variable_broadcast="params",
+        in_axes=0,
+        out_axes=0,
+        split_rngs={"params": False},
+    )
def __call__(
-        self, x: Array, carry: RNNState
+        self, x: Array, carry: Tuple[RNNState, RNNState], done=None
    ) -> Tuple[RNNState, Tuple[distrax.Distribution, Array]]:
-        pi = distrax.Categorical(logits=self.actor(x))
-        v = jnp.squeeze(self.critic(x), -1)
-        return carry, (pi, v)
+        if done is None:
+            done = jnp.zeros(x.shape[0], dtype=jnp.bool_)
+
+        # TODO(epignatelli): Implement reset
+
+        carry_actor, pi = self.policy(carry, x)
+        carry_critic, v = self.value(carry, x)
+        return (carry_actor, carry_critic), (pi, v)
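Not part of this commit, but relevant to the TODO above: a common way to implement the reset in recurrent PPO is to re-initialize the carry wherever done is True before stepping the cells. A hedged sketch, assuming each carry leaf has shape (batch, hidden) and initial_carry comes from initialize_carry (the helper name is illustrative, not the repository's API):

import jax
import jax.numpy as jnp

def reset_carry_on_done(carry, initial_carry, done):
    # done: (batch,) bool. Where an episode just ended, replace the running
    # carry with a freshly initialized one; elsewhere keep it unchanged.
    return jax.tree_util.tree_map(
        lambda c, c0: jnp.where(done[:, None], c0, c), carry, initial_carry
    )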

+    @nn.nowrap
+    def initialize_carry(
+        self, rng: Array, input_shape: Tuple[int, ...]
+    ) -> Tuple[RNNState, RNNState]:
+        carry_actor = self.core_actor.initialize_carry(rng, input_shape)
+        carry_critic = self.core_critic.initialize_carry(rng, input_shape)
+        return (carry_actor, carry_critic)

def policy(
-        self, carry: RNNState, x: Array
+        self, carry: Tuple[RNNState, RNNState], x: Array
    ) -> Tuple[RNNState, distrax.Distribution]:
+        carry_actor, carry_critic = carry
actor_embed = self.actor_encoder(x)
-        carry, actor_embed = self.core_actor(carry, actor_embed)
-        return carry, distrax.Categorical(logits=self.actor(x))
-
-    def value(self, x: Array, carry: RNNState) -> Tuple[RNNState, Array]:
+        carry_actor, actor_embed = self.core_actor(carry_actor, actor_embed)
+        logits = self.actor_head(actor_embed)
+        carry = (carry_actor, carry_critic)
+        return carry, distrax.Categorical(logits=logits)

+    def value(
+        self, carry: Tuple[RNNState, RNNState], x: Array
+    ) -> Tuple[RNNState, Array]:
+        carry_actor, carry_critic = carry
critic_embed = self.critic_encoder(x)
-        carry, critic_embed = self.core_critic(carry, critic_embed)
-        return carry, jnp.squeeze(self.critic(x), axis=-1)
+        carry_critic, critic_embed = self.core_critic(carry_critic, critic_embed)
+        value = self.critic_head(critic_embed)
+        carry = (carry_actor, carry_critic)
+        return carry, jnp.squeeze(value, axis=-1)
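For reference, not part of the diff: the functools.partial(nn.scan, ...) transform applied to __call__ above unrolls the module over the leading (time) axis while broadcasting one set of parameters to every step and threading the carry through. A self-contained sketch of that pattern with a plain nn.GRUCell instead of the repository's cells (module name and shapes are illustrative):

import functools
import jax
import jax.numpy as jnp
import flax.linen as nn

class ScannedCell(nn.Module):
    hidden_size: int = 64

    @functools.partial(
        nn.scan,
        variable_broadcast="params",   # share the same weights at every step
        in_axes=0,                     # scan inputs over the time axis
        out_axes=0,                    # stack outputs over the time axis
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        return nn.GRUCell(self.hidden_size)(carry, x)

T, B, F = 16, 8, 32
x = jnp.ones((T, B, F))  # time-major inputs
carry = nn.GRUCell(64).initialize_carry(jax.random.PRNGKey(0), (B, F))
model = ScannedCell()
params = model.init(jax.random.PRNGKey(1), carry, x)
carry, y = model.apply(params, carry, x)  # y: (T, B, 64)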
4 changes: 2 additions & 2 deletions navix/agents/ppo.py
@@ -22,7 +22,7 @@
from navix.environments.environment import Timestep
from navix.states import State

-from .models import ActorCritic
+from .models import ActorCriticRNN


@dataclass
@@ -87,7 +87,7 @@ class TrainingState(TrainState):

class PPO(Agent):
hparams: PPOHparams = struct.field(pytree_node=False)
-    network: ActorCritic = struct.field(pytree_node=False)
+    network: ActorCriticRNN = struct.field(pytree_node=False)
env: Environment

def collect_experience(
