Add CNN support for DQN (#49)
* Add CNN support for DQN

* Update version and deps

* Fix CNN, channel last, padding and reshape
araffin authored Jul 11, 2024
1 parent 27de67c commit 19c85a1
Showing 5 changed files with 103 additions and 5 deletions.
4 changes: 3 additions & 1 deletion sbx/dqn/dqn.py
@@ -11,12 +11,13 @@

from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax
from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState
-from sbx.dqn.policies import DQNPolicy
+from sbx.dqn.policies import CNNPolicy, DQNPolicy


class DQN(OffPolicyAlgorithmJax):
    policy_aliases: ClassVar[Dict[str, Type[DQNPolicy]]] = {  # type: ignore[assignment]
        "MlpPolicy": DQNPolicy,
        "CnnPolicy": CNNPolicy,
    }
    # Linear schedule will be defined in `_setup_model()`
    exploration_schedule: Schedule
@@ -36,6 +37,7 @@ def __init__(
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        optimize_memory_usage: bool = False,  # Note: unused but to match SB3 API
        # max_grad_norm: float = 10,
        train_freq: Union[int, Tuple[int, str]] = 4,
        gradient_steps: int = 1,
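Below is a minimal usage sketch (not part of this commit) of the new `"CnnPolicy"` alias registered above. `FakeImageEnv` is the image-observation test env from `stable_baselines3.common.envs` (also used by the new test further down); the buffer size and timestep budget are illustrative only.

```python
from stable_baselines3.common.envs import FakeImageEnv

from sbx import DQN

# Fake 84x84 grayscale image observations (channel-last from the env);
# SB3's image preprocessing transposes them to channel-first before the policy sees them.
env = FakeImageEnv(screen_height=84, screen_width=84, n_channels=1)
model = DQN("CnnPolicy", env, buffer_size=5_000, learning_starts=100, verbose=1)
model.learn(total_timesteps=1_000)
```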
54 changes: 53 additions & 1 deletion sbx/dqn/policies.py
@@ -28,6 +28,32 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return x


# Add CNN policy from DQN paper
class NatureCNN(nn.Module):
    n_actions: int
    n_units: int = 512
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Convert from channel-first (PyTorch) to channel-last (Jax)
        x = jnp.transpose(x, (0, 2, 3, 1))
        # Convert to float and normalize the image
        x = x.astype(jnp.float32) / 255.0
        x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
        x = self.activation_fn(x)
        x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
        x = self.activation_fn(x)
        x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
        x = self.activation_fn(x)
        # Flatten
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(self.n_units)(x)
        x = self.activation_fn(x)
        x = nn.Dense(self.n_actions)(x)
        return x


class DQNPolicy(BaseJaxPolicy):
    action_space: spaces.Discrete  # type: ignore[assignment]

@@ -65,7 +91,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:

        obs = jnp.array([self.observation_space.sample()])

-        self.qf = QNetwork(
+        self.qf: nn.Module = QNetwork(
            n_actions=int(self.action_space.n),
            n_units=self.n_units,
            activation_fn=self.activation_fn,
@@ -97,3 +123,29 @@ def select_action(qf_state, observations):

    def _predict(self, observation: np.ndarray, deterministic: bool = True) -> np.ndarray:  # type: ignore[override]
        return DQNPolicy.select_action(self.qf_state, observation)


class CNNPolicy(DQNPolicy):
    def build(self, key: jax.Array, lr_schedule: Schedule) -> jax.Array:
        key, qf_key = jax.random.split(key, 2)

        obs = jnp.array([self.observation_space.sample()])

        self.qf = NatureCNN(
            n_actions=int(self.action_space.n),
            n_units=self.n_units,
            activation_fn=self.activation_fn,
        )

        self.qf_state = RLTrainState.create(
            apply_fn=self.qf.apply,
            params=self.qf.init({"params": qf_key}, obs),
            target_params=self.qf.init({"params": qf_key}, obs),
            tx=self.optimizer_class(
                learning_rate=lr_schedule(1),  # type: ignore[call-arg]
                **self.optimizer_kwargs,
            ),
        )
        self.qf.apply = jax.jit(self.qf.apply)  # type: ignore[method-assign]

        return key
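For reference, a standalone sketch of how the `NatureCNN` module added above is initialized and applied, mirroring what `CNNPolicy.build` does with a dummy channel-first observation. The action count and image size here are assumptions for illustration, not values taken from the commit.

```python
import jax
import jax.numpy as jnp

from sbx.dqn.policies import NatureCNN

qf = NatureCNN(n_actions=6, n_units=512)
# Dummy uint8 observation in channel-first layout (batch, channels, height, width);
# NatureCNN transposes it to channel-last and rescales pixels to [0, 1] internally.
obs = jnp.zeros((1, 1, 84, 84), dtype=jnp.uint8)
params = qf.init({"params": jax.random.PRNGKey(0)}, obs)
q_values = qf.apply(params, obs)  # shape (1, 6): one Q-value per action
```

The internal transpose keeps the replay buffer's channel-first (PyTorch-style) layout while letting Flax's channel-last convolutions run unchanged.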
2 changes: 1 addition & 1 deletion sbx/version.txt
@@ -1 +1 @@
-0.16.0
+0.17.0
4 changes: 2 additions & 2 deletions setup.py
@@ -27,7 +27,7 @@
## Example
```python
-from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ, CrossQ
+from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
model = TQC("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)
@@ -40,7 +40,7 @@
    packages=[package for package in find_packages() if package.startswith("sbx")],
    package_data={"sbx": ["py.typed", "version.txt"]},
    install_requires=[
-        "stable_baselines3>=2.3.0",
+        "stable_baselines3>=2.4.0a4,<3.0",
        "jax",
        "jaxlib",
        "flax",
44 changes: 44 additions & 0 deletions tests/test_cnn.py
@@ -0,0 +1,44 @@
import numpy as np
import pytest
from stable_baselines3.common.envs import FakeImageEnv

from sbx import DQN


@pytest.mark.parametrize("model_class", [DQN])
def test_cnn(tmp_path, model_class):
    SAVE_NAME = "cnn_model.zip"
    # Fake grayscale with frameskip
    # Atari after preprocessing: 84x84x1, here we are using lower resolution
    # to check that the network handles it automatically
    env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1)
    model = model_class(
        "CnnPolicy",
        env,
        buffer_size=250,
        policy_kwargs=dict(net_arch=[64]),
        learning_starts=100,
        verbose=1,
    )
    model.learn(total_timesteps=250)

    obs, _ = env.reset()

    # Test stochastic predict with channel last input
    if model_class == DQN:
        model.exploration_rate = 0.9

    for _ in range(10):
        model.predict(obs, deterministic=False)

    action, _ = model.predict(obs, deterministic=True)

    model.save(tmp_path / SAVE_NAME)
    del model

    model = model_class.load(tmp_path / SAVE_NAME)

    # Check that the prediction is the same
    assert np.allclose(action, model.predict(obs, deterministic=True)[0])

    (tmp_path / SAVE_NAME).unlink()
