Start testing simba
araffin committed Nov 1, 2024
1 parent 1c79684 commit 9589326
Showing 5 changed files with 181 additions and 7 deletions.
17 changes: 17 additions & 0 deletions sbx/common/jax_layers.py
@@ -1,5 +1,6 @@
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen.module import Module, compact, merge_param
@@ -204,3 +205,19 @@ def __call__(self, x, use_running_average: Optional[bool] = None):
self.bias_init,
self.scale_init,
)


# Adapted from simba: https://github.com/SonyResearch/simba
class SimbaResidualBlock(nn.Module):
hidden_dim: int
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
scale_factor: int = 4

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
residual = x
x = nn.LayerNorm()(x)
x = nn.Dense(self.hidden_dim * self.scale_factor, kernel_init=nn.initializers.he_normal())(x)
x = self.activation_fn(x)
x = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.he_normal())(x)
return residual + x
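A minimal usage sketch of the new block (not part of the diff; the shapes are illustrative): because the two-layer MLP output is added back onto its input, the incoming feature dimension must already equal hidden_dim, and the block preserves the shape of whatever it is applied to.

import jax
import jax.numpy as jnp

from sbx.common.jax_layers import SimbaResidualBlock

# Pre-LayerNorm residual block: LayerNorm -> Dense(4 * hidden) -> relu -> Dense(hidden) + skip
block = SimbaResidualBlock(hidden_dim=64)
x = jnp.zeros((2, 64))  # (batch, features); features must equal hidden_dim
variables = block.init(jax.random.PRNGKey(0), x)
y = block.apply(variables, x)
assert y.shape == x.shape  # the skip connection preserves the shape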
53 changes: 53 additions & 0 deletions sbx/common/policies.py
@@ -10,6 +10,8 @@
from stable_baselines3.common.preprocessing import is_image_space, maybe_transpose
from stable_baselines3.common.utils import is_vectorized_observation

from sbx.common.jax_layers import SimbaResidualBlock


class Flatten(nn.Module):
"""
@@ -143,6 +145,29 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
return x


class SimbaContinuousCritic(nn.Module):
net_arch: Sequence[int]
dropout_rate: Optional[float] = None
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
scale_factor: int = 4

@nn.compact
def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
x = Flatten()(x)
x = jnp.concatenate([x, action], -1)
# Note: simba was using kernel_init=orthogonal_init(1)
x = nn.Dense(self.net_arch[0])(x)
for n_units in self.net_arch:
x = SimbaResidualBlock(n_units, self.activation_fn, self.scale_factor)(x)
# TODO: double check where to put the dropout
if self.dropout_rate is not None and self.dropout_rate > 0:
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
x = nn.LayerNorm()(x)

x = nn.Dense(1)(x)
return x


class VectorCritic(nn.Module):
net_arch: Sequence[int]
use_layer_norm: bool = False
@@ -169,3 +194,31 @@ def __call__(self, obs: jnp.ndarray, action: jnp.ndarray):
activation_fn=self.activation_fn,
)(obs, action)
return q_values


class SimbaVectorCritic(nn.Module):
net_arch: Sequence[int]
# Note: use_layer_norm is kept for interface consistency, but it is not used (layer norm is always on)
use_layer_norm: bool = True
dropout_rate: Optional[float] = None
n_critics: int = 1 # only one critic by default
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

@nn.compact
def __call__(self, obs: jnp.ndarray, action: jnp.ndarray):
# Idea taken from https://github.com/perrin-isir/xpag
# Similar to https://github.com/tinkoff-ai/CORL for PyTorch
vmap_critic = nn.vmap(
SimbaContinuousCritic,
variable_axes={"params": 0}, # parameters not shared between the critics
split_rngs={"params": True, "dropout": True}, # different initializations
in_axes=None,
out_axes=0,
axis_size=self.n_critics,
)
q_values = vmap_critic(
dropout_rate=self.dropout_rate,
net_arch=self.net_arch,
activation_fn=self.activation_fn,
)(obs, action)
return q_values
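A rough usage sketch of the ensembled critic (not part of the diff; observation/action sizes are illustrative): nn.vmap stacks n_critics independently initialized copies of SimbaContinuousCritic along the leading output axis, so the Q-values come back with shape (n_critics, batch_size, 1). The "dropout" rng below only matters when dropout_rate > 0.

import jax
import jax.numpy as jnp

from sbx.common.policies import SimbaVectorCritic

qf = SimbaVectorCritic(net_arch=[256], n_critics=2)
obs = jnp.zeros((8, 17))    # (batch, obs_dim)
action = jnp.zeros((8, 6))  # (batch, action_dim)
variables = qf.init({"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}, obs, action)
q_values = qf.apply(variables, obs, action)
assert q_values.shape == (2, 8, 1)  # one Q-value per critic and per transition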
96 changes: 92 additions & 4 deletions sbx/sac/policies.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import flax.linen as nn
import jax
@@ -11,7 +11,8 @@
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.distributions import TanhTransformedDistribution
from sbx.common.policies import BaseJaxPolicy, Flatten, VectorCritic
from sbx.common.jax_layers import SimbaResidualBlock
from sbx.common.policies import BaseJaxPolicy, Flatten, SimbaVectorCritic, VectorCritic
from sbx.common.type_aliases import RLTrainState

tfd = tfp.distributions
@@ -43,6 +44,40 @@ def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-def
return dist


class SimbaActor(nn.Module):
# Note: each element in net_arch corresponds to a residual block,
# not just a single layer
net_arch: Sequence[int]
action_dim: int
# num_blocks: int = 2
log_std_min: float = -20
log_std_max: float = 2
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
scale_factor: int = 4

def get_std(self):
# Make it work with gSDE
return jnp.array(0.0)

@nn.compact
def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined]
x = Flatten()(x)

# Note: simba was using kernel_init=orthogonal_init(1)
x = nn.Dense(self.net_arch[0])(x)
for n_units in self.net_arch:
x = SimbaResidualBlock(n_units, self.activation_fn, self.scale_factor)(x)
x = nn.LayerNorm()(x)

mean = nn.Dense(self.action_dim)(x)
log_std = nn.Dense(self.action_dim)(x)
log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
dist = TanhTransformedDistribution(
tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)),
)
return dist
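A short sketch of querying the actor (not part of the diff; dimensions are illustrative): the network returns a tanh-squashed Gaussian, so sampled actions are already bounded in (-1, 1), and log_prob includes the tanh change-of-variables correction used by the SAC entropy term.

import jax
import jax.numpy as jnp

from sbx.sac.policies import SimbaActor

actor = SimbaActor(net_arch=[128], action_dim=6)
obs = jnp.zeros((1, 17))
params = actor.init(jax.random.PRNGKey(0), obs)
dist = actor.apply(params, obs)
action = dist.sample(seed=jax.random.PRNGKey(1))  # shape (1, 6), values in (-1, 1)
log_prob = dist.log_prob(action)                  # shape (1,)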


class SACPolicy(BaseJaxPolicy):
action_space: spaces.Box # type: ignore[assignment]

@@ -68,6 +103,8 @@ def __init__(
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = False,
actor_class: Type[nn.Module] = Actor,
vector_critic_class: Type[nn.Module] = VectorCritic,
):
super().__init__(
observation_space,
@@ -91,6 +128,8 @@ def __init__(
self.n_critics = n_critics
self.use_sde = use_sde
self.activation_fn = activation_fn
self.actor_class = actor_class
self.vector_critic_class = vector_critic_class

self.key = self.noise_key = jax.random.PRNGKey(0)

@@ -107,7 +146,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
obs = jnp.array([self.observation_space.sample()])
action = jnp.array([self.action_space.sample()])

self.actor = Actor(
self.actor = self.actor_class(
action_dim=int(np.prod(self.action_space.shape)),
net_arch=self.net_arch_pi,
activation_fn=self.activation_fn,
@@ -124,7 +163,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
),
)

self.qf = VectorCritic(
self.qf = self.vector_critic_class(
dropout_rate=self.dropout_rate,
use_layer_norm=self.layer_norm,
net_arch=self.net_arch_qf,
@@ -174,3 +213,52 @@ def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.n
if not self.use_sde:
self.reset_noise()
return BaseJaxPolicy.sample_action(self.actor_state, observation, self.noise_key)


class SimbaSACPolicy(SACPolicy):
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Box,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
dropout_rate: float = 0,
layer_norm: bool = False,
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu,
use_sde: bool = False,
log_std_init: float = -3,
use_expln: bool = False,
clip_mean: float = 2,
features_extractor_class=None,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
# AdamW for simba
optimizer_class: Callable[..., optax.GradientTransformation] = optax.adamw,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = False,
actor_class: Type[nn.Module] = SimbaActor,
vector_critic_class: Type[nn.Module] = SimbaVectorCritic,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
dropout_rate,
layer_norm,
activation_fn,
use_sde,
log_std_init,
use_expln,
clip_mean,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
n_critics,
share_features_extractor,
actor_class,
vector_critic_class,
)
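The new actor_class / vector_critic_class arguments are the extension hook that SimbaSACPolicy relies on; the same hook should also be reachable from the regular policy through policy_kwargs. A hedged sketch (assuming policy_kwargs are forwarded to the policy constructor, as usual in stable-baselines3 / sbx):

from sbx import SAC
from sbx.common.policies import SimbaVectorCritic
from sbx.sac.policies import SimbaActor

# Plain SACPolicy, with the simba actor and critic injected through the new arguments
model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(actor_class=SimbaActor, vector_critic_class=SimbaVectorCritic),
    verbose=1,
)
model.learn(total_timesteps=1_000)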
14 changes: 13 additions & 1 deletion sbx/sac/sac.py
@@ -16,7 +16,7 @@

from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax
from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState
from sbx.sac.policies import SACPolicy
from sbx.sac.policies import SACPolicy, SimbaSACPolicy


class EntropyCoef(nn.Module):
@@ -42,6 +42,18 @@ def __call__(self) -> float:
class SAC(OffPolicyAlgorithmJax):
policy_aliases: ClassVar[Dict[str, Type[SACPolicy]]] = { # type: ignore[assignment]
"MlpPolicy": SACPolicy,
# Residual net, from https://github.com/SonyResearch/simba
# hyperparams:
# https://github.com/SonyResearch/simba/blob/master/configs/agent/sac_simba.yaml#L16
# NOTE: the simba codebase uses several tricks:
# - special initialization
# - residual blocks with a scale factor (x4)
# - AdamW with weight_decay=1e-2
# - heuristic for gamma from TD-MPC2
# - only one critic, except for humanoid
# - no policy delay, but a larger UTD (replay ratio) instead
# see https://github.com/SonyResearch/simba/blob/master/configs/base.yaml#L15
"SimbaPolicy": SimbaSACPolicy,
# Minimal dict support using flatten()
"MultiInputPolicy": SACPolicy,
}
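For completeness, a hedged usage sketch of the new alias (not part of this diff): the weight decay and the extra gradient steps below just mirror the tricks listed in the note above and are illustrative values, not defaults enforced by "SimbaPolicy".

from sbx import SAC

model = SAC(
    "SimbaPolicy",
    "Pendulum-v1",
    # AdamW weight decay as in the upstream simba config (assumed value)
    policy_kwargs=dict(optimizer_kwargs=dict(weight_decay=1e-2)),
    # larger update-to-data ratio instead of policy delay (illustrative)
    gradient_steps=2,
    verbose=1,
)
model.learn(total_timesteps=1_000)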
8 changes: 6 additions & 2 deletions tests/test_run.py
@@ -76,10 +76,14 @@ def test_tqc(tmp_path) -> None:
check_save_load(model, TQC, tmp_path)


@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, CrossQ])
@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, CrossQ, "SimbaSAC"])
def test_sac_td3(tmp_path, model_class) -> None:
policy = "MlpPolicy"
if model_class == "SimbaSAC":
model_class = SAC
policy = "SimbaPolicy"
model = model_class(
"MlpPolicy",
policy,
"Pendulum-v1",
verbose=1,
gradient_steps=1,
