diff --git a/navix/_version.py b/navix/_version.py
index 6006f8f..9026dd1 100644
--- a/navix/_version.py
+++ b/navix/_version.py
@@ -18,5 +18,5 @@
 # under the License.
 
-__version__ = "0.6.0"
+__version__ = "0.6.1"
 
 __version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())
diff --git a/navix/agents/models.py b/navix/agents/models.py
index f743d77..b9e3b36 100644
--- a/navix/agents/models.py
+++ b/navix/agents/models.py
@@ -1,3 +1,4 @@
+from typing import Tuple
 from jax import Array
 import jax.numpy as jnp
 import distrax
@@ -5,6 +6,18 @@
 from flax.linen.initializers import constant, orthogonal
 
 
+RNNState = tuple
+
+
+class DenseRNN(nn.Dense):
+    """A linear layer with the (carry, x) interface of an RNN cell:
+    it takes and returns an empty carry, so it can stand in for one."""
+
+    @nn.compact
+    def __call__(self, carry: RNNState, x: Array) -> Tuple[RNNState, Array]:
+        return (), super().__call__(x)
+
+
 class MLPEncoder(nn.Module):
     hidden_size: int = 64
 
@@ -66,11 +79,59 @@ def setup(self):
             ]
         )
 
-    def __call__(self, x):
+    def __call__(self, x: Array) -> Tuple[distrax.Distribution, Array]:
         return distrax.Categorical(self.actor(x)), jnp.squeeze(self.critic(x), -1)
 
     def policy(self, x: Array) -> distrax.Distribution:
         return distrax.Categorical(logits=self.actor(x))
-    
+
     def value(self, x: Array) -> Array:
         return jnp.squeeze(self.critic(x), -1)
+
+
+class ActorCriticRnn(nn.Module):
+    action_dim: int
+    actor_encoder: nn.Module = MLPEncoder()
+    critic_encoder: nn.Module = MLPEncoder()
+    hidden_size: int = 64
+    recurrent: bool = False
+
+    def setup(self):
+        self.actor_head = nn.Dense(
+            self.action_dim,
+            kernel_init=orthogonal(0.01),
+            bias_init=constant(0.0),
+        )
+        self.critic_head = nn.Dense(
+            1, kernel_init=orthogonal(1.0), bias_init=constant(0.0)
+        )
+
+        if self.recurrent:
+            self.core_actor = nn.LSTMCell(self.hidden_size)
+            self.core_critic = nn.LSTMCell(self.hidden_size)
+        else:
+            # Stateless stand-ins that keep the (carry, x) cell interface.
+            self.core_actor = DenseRNN(self.hidden_size)
+            self.core_critic = DenseRNN(self.hidden_size)
+
+    def __call__(
+        self, x: Array, carry: RNNState
+    ) -> Tuple[RNNState, Tuple[distrax.Distribution, Array]]:
+        # Thread the carry through both recurrent cores in turn, so that
+        # init/apply exercise every submodule.
+        carry, pi = self.policy(carry, x)
+        carry, v = self.value(x, carry)
+        return carry, (pi, v)
+
+    def policy(
+        self, carry: RNNState, x: Array
+    ) -> Tuple[RNNState, distrax.Distribution]:
+        actor_embed = self.actor_encoder(x)
+        carry, actor_embed = self.core_actor(carry, actor_embed)
+        # Apply the head to the recurrent features, not to a re-encoding of x.
+        return carry, distrax.Categorical(logits=self.actor_head(actor_embed))
+
+    def value(self, x: Array, carry: RNNState) -> Tuple[RNNState, Array]:
+        critic_embed = self.critic_encoder(x)
+        carry, critic_embed = self.core_critic(carry, critic_embed)
+        return carry, jnp.squeeze(self.critic_head(critic_embed), axis=-1)
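
Usage sketch (not part of the diff): one way to initialize and step the recurrent model above. The observation width (128), action count (6), batch size, and RNG seeds are illustrative assumptions, not values from the PR; initialize_carry follows the flax >= 0.7 LSTMCell API implied by the diff's nn.LSTMCell(self.hidden_size).

import jax
import jax.numpy as jnp
import flax.linen as nn
from navix.agents.models import ActorCriticRnn

obs = jnp.zeros((1, 128))  # dummy batched observation (shapes are assumptions)
model = ActorCriticRnn(action_dim=6, recurrent=True, hidden_size=64)

# LSTM carries are (cell, hidden) pairs shaped by the batch dims of the input;
# the non-recurrent DenseRNN path would use an empty tuple () instead.
carry = nn.LSTMCell(64).initialize_carry(jax.random.PRNGKey(0), obs.shape)
params = model.init(jax.random.PRNGKey(1), obs, carry)

carry, (pi, value) = model.apply(params, obs, carry)
action = pi.sample(seed=jax.random.PRNGKey(2))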