Fix off-by-one and improve type annotation
araffin committed Feb 8, 2024
1 parent 6504fe9 · commit bfbe00b
Showing 3 changed files with 28 additions and 17 deletions.
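
Why the off-by-one matters: `policy_delay_offset` is computed from the update counter outside the jitted loop and decides which gradient steps inside `_train` trigger an actor update. A minimal sketch of the schedule change (illustrative only, not code from this repo):

```python
# Illustrative sketch of the off-by-one fix: which gradient steps update
# the actor, given the offset passed into the jitted training loop.
policy_delay = 2

def actor_update_steps(n_updates: int, gradient_steps: int, fixed: bool) -> list:
    offset = ((n_updates + 1) if fixed else n_updates) % policy_delay
    # inside _train, the actor branch runs when (offset + i) % policy_delay == 0
    return [i for i in range(gradient_steps) if (offset + i) % policy_delay == 0]

# Old offset: the actor trains on the very first critic update (step 0).
print(actor_update_steps(n_updates=0, gradient_steps=4, fixed=False))  # [0, 2]
# New offset: the actor trains only after policy_delay critic updates,
# mirroring SB3's "increment counter, then check" convention.
print(actor_update_steps(n_updates=0, gradient_steps=4, fixed=True))   # [1, 3]
```
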
13 changes: 7 additions & 6 deletions sbx/sac/sac.py
@@ -1,6 +1,7 @@
from functools import partial
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union

+ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
@@ -212,7 +213,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
gradient_steps,
data,
self.policy_delay,
- self._n_updates % self.policy_delay,
+ (self._n_updates + 1) % self.policy_delay,
self.policy.qf_state,
self.policy.actor_state,
self.ent_coef_state,
@@ -259,7 +260,7 @@ def update_critic(
# shape is (batch_size, 1)
target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values

- def mse_loss(params, dropout_key):
+ def mse_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array:
# shape is (n_critics, batch_size, 1)
current_q_values = qf_state.apply_fn(params, observations, actions, rngs={"dropout": dropout_key})
return 0.5 * ((target_q_values - current_q_values) ** 2).mean(axis=1).sum()
@@ -284,7 +285,7 @@ def update_actor(
):
key, dropout_key, noise_key = jax.random.split(key, 3)

- def actor_loss(params):
+ def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]:
dist = actor_state.apply_fn(params, observations)
actor_actions = dist.sample(seed=noise_key)
log_prob = dist.log_prob(actor_actions).reshape(-1, 1)
@@ -308,16 +309,16 @@ def actor_loss(params):

@staticmethod
@jax.jit
- def soft_update(tau: float, qf_state: RLTrainState):
+ def soft_update(tau: float, qf_state: RLTrainState) -> RLTrainState:
qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau))
return qf_state

@staticmethod
@jax.jit
def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
- def temperature_loss(temp_params):
+ def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
- ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean()
+ ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean()  # type: ignore[union-attr]
return ent_coef_loss

ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)
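For context on the newly annotated `soft_update`: it delegates to `optax.incremental_update`, which computes the Polyak average `step_size * new + (1 - step_size) * old` leaf-by-leaf over the two parameter pytrees. A standalone sketch with made-up parameter values:

```python
# Standalone sketch of the Polyak update behind soft_update.
import jax.numpy as jnp
import optax

online_params = {"w": jnp.array([1.0, 2.0])}  # made-up values
target_params = {"w": jnp.array([0.0, 0.0])}

tau = 0.005  # same role as the tau argument of soft_update
target_params = optax.incremental_update(online_params, target_params, tau)
print(target_params["w"])  # [0.005 0.01]
```
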
13 changes: 9 additions & 4 deletions sbx/td3/td3.py
@@ -1,6 +1,7 @@
from functools import partial
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union

+ import flax
import jax
import jax.numpy as jnp
import numpy as np
@@ -151,7 +152,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
gradient_steps,
data,
self.policy_delay,
- self._n_updates % self.policy_delay,
+ (self._n_updates + 1) % self.policy_delay,
self.target_policy_noise,
self.target_noise_clip,
self.policy.qf_state,
@@ -197,7 +198,7 @@ def update_critic(
# shape is (batch_size, 1)
target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values

- def mse_loss(params, dropout_key):
+ def mse_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array:
# shape is (n_critics, batch_size, 1)
current_q_values = qf_state.apply_fn(params, observations, actions, rngs={"dropout": dropout_key})
return 0.5 * ((target_q_values - current_q_values) ** 2).mean(axis=1).sum()
@@ -221,7 +222,7 @@ def update_actor(
):
key, dropout_key = jax.random.split(key, 2)

- def actor_loss(params):
+ def actor_loss(params: flax.core.FrozenDict) -> jax.Array:
actor_actions = actor_state.apply_fn(params, observations)

qf_pi = qf_state.apply_fn(

@staticmethod
@jax.jit
- def soft_update(tau: float, qf_state: RLTrainState, actor_state: RLTrainState):
+ def soft_update(tau: float, qf_state: RLTrainState, actor_state: RLTrainState) -> Tuple[RLTrainState, RLTrainState]:
qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau))
actor_state = actor_state.replace(
target_params=optax.incremental_update(actor_state.params, actor_state.target_params, tau)
@@ -279,6 +280,8 @@ def _train(
}

def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
+ # Note: this function must be defined inline because
+ # `fori_loop` expects a signature fn(index, carry) -> carry
actor_state = carry["actor_state"]
qf_state = carry["qf_state"]
key = carry["key"]
@@ -309,7 +312,9 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:

(actor_state, qf_state, actor_loss_value, key) = jax.lax.cond(
(policy_delay_offset + i) % policy_delay == 0,
+ # If True:
cls.update_actor,
+ # If False:
lambda *_: (actor_state, qf_state, info["actor_loss"], key),
actor_state,
qf_state,
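The comments added to `one_update` document two JAX control-flow constraints: `fori_loop` fixes the body signature to `fn(index, carry) -> carry`, and `lax.cond` requires both branches to return the same pytree structure (hence the lambda that merely re-emits the carry). A self-contained sketch of the pattern with a toy carry (not the real train states), using the same `(policy_delay_offset + i) % policy_delay` test as the diff:

```python
import jax
import jax.numpy as jnp

policy_delay, policy_delay_offset = 2, 1  # offset as computed after the fix

def one_update(i, carry):
    # the "critic" trains on every gradient step
    carry = {**carry, "critic_steps": carry["critic_steps"] + 1}
    # the "actor" trains only when the delay test passes; both branches
    # must accept and return the same carry structure
    return jax.lax.cond(
        (policy_delay_offset + i) % policy_delay == 0,
        lambda c: {**c, "actor_steps": c["actor_steps"] + 1},  # If True: update
        lambda c: c,                                           # If False: no-op
        carry,
    )

carry = {"critic_steps": jnp.int32(0), "actor_steps": jnp.int32(0)}
carry = jax.lax.fori_loop(0, 6, one_update, carry)
print(carry)  # actor_steps == 3, critic_steps == 6
```
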
19 changes: 12 additions & 7 deletions sbx/tqc/tqc.py
@@ -1,6 +1,7 @@
from functools import partial
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union

+ import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
@@ -215,7 +216,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
self.policy.n_target_quantiles,
data,
self.policy_delay,
- self._n_updates % self.policy_delay,
+ (self._n_updates + 1) % self.policy_delay,
self.policy.qf1_state,
self.policy.qf2_state,
self.policy.actor_state,
@@ -284,9 +285,9 @@ def update_critic(
# Make target_quantiles broadcastable to (batch_size, n_quantiles, n_target_quantiles).
target_quantiles = jnp.expand_dims(target_quantiles, axis=1)

- def huber_quantile_loss(params, noise_key):
+ def huber_quantile_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array:
# Compute huber quantile loss
- current_quantiles = qf1_state.apply_fn(params, observations, actions, True, rngs={"dropout": noise_key})
+ current_quantiles = qf1_state.apply_fn(params, observations, actions, True, rngs={"dropout": dropout_key})
# convert to shape: (batch_size, n_quantiles, 1) for broadcast
current_quantiles = jnp.expand_dims(current_quantiles, axis=-1)

@@ -327,7 +328,7 @@ def update_actor(
):
key, dropout_key_1, dropout_key_2, noise_key = jax.random.split(key, 4)

- def actor_loss(params):
+ def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]:
dist = actor_state.apply_fn(params, observations)
actor_actions = dist.sample(seed=noise_key)
log_prob = dist.log_prob(actor_actions).reshape(-1, 1)
@@ -364,18 +365,18 @@ def actor_loss(params):

@staticmethod
@jax.jit
- def soft_update(tau: float, qf1_state: RLTrainState, qf2_state: RLTrainState):
+ def soft_update(tau: float, qf1_state: RLTrainState, qf2_state: RLTrainState) -> Tuple[RLTrainState, RLTrainState]:
qf1_state = qf1_state.replace(target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, tau))
qf2_state = qf2_state.replace(target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, tau))
return qf1_state, qf2_state

@staticmethod
@jax.jit
def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
- def temperature_loss(temp_params):
+ def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
# ent_coef_loss = (jnp.log(ent_coef_value) * (entropy - target_entropy)).mean()
- ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean()
+ ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean()  # type: ignore[union-attr]
return ent_coef_loss

ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)
@@ -444,6 +445,8 @@ def _train(
}

def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
+ # Note: this function must be defined inline because
+ # `fori_loop` expects a signature fn(index, carry) -> carry
actor_state = carry["actor_state"]
qf1_state = carry["qf1_state"]
qf2_state = carry["qf2_state"]
@@ -477,7 +480,9 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:

(actor_state, (qf1_state, qf2_state), ent_coef_state, actor_loss_value, ent_coef_loss_value, key) = jax.lax.cond(
(policy_delay_offset + i) % policy_delay == 0,
+ # If True:
cls.update_actor_and_temperature,
+ # If False:
lambda *_: (
actor_state,
(qf1_state, qf2_state),
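The `# type: ignore[union-attr]` lines only silence mypy on `(entropy - target_entropy).mean()`; the computation is unchanged. The temperature update itself is an ordinary `jax.value_and_grad` of a scalar loss, shown here as a simplified stand-in (a bare log-coefficient parameter and made-up numbers instead of the `ent_coef_state` TrainState):

```python
import jax
import jax.numpy as jnp

target_entropy = -6.0                 # e.g. -action_dim
entropy = jnp.array([5.2, 6.1, 5.8])  # made-up per-sample -log_prob values

def temperature_loss(log_ent_coef: jnp.ndarray) -> jnp.ndarray:
    ent_coef_value = jnp.exp(log_ent_coef)
    return ent_coef_value * (entropy - target_entropy).mean()

# same pattern as update_temperature, without the TrainState plumbing
loss, grad = jax.value_and_grad(temperature_loss)(jnp.zeros(()))
print(loss, grad)  # 11.7 11.7
```
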
