Add DDPG and TD3 (#16)
* Update to match SB3

* Update min pytorch version

* Remove pytype

* Add base TD3

* Add DDPG

* Remove unused variables
araffin authored Sep 7, 2023
1 parent b8dbac1 commit f662613
Showing 18 changed files with 614 additions and 40 deletions.
.github/workflows/ci.yml (2 additions, 4 deletions)
@@ -20,7 +20,7 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
-        python-version: ["3.7", "3.8", "3.9", "3.10"]
+        python-version: ["3.8", "3.9", "3.10", "3.11"]

    steps:
      - uses: actions/checkout@v3
@@ -32,7 +32,7 @@ jobs:
        run: |
          python -m pip install --upgrade pip
          # cpu version of pytorch
-          pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html
+          pip install torch==1.13+cpu -f https://download.pytorch.org/whl/torch_stable.html
          # # Install Atari Roms
          # pip install autorom
@@ -55,8 +55,6 @@ jobs:
      - name: Type check
        run: |
          make type
-        # skip mypy type check for python3.7
-        if: "!(matrix.python-version == '3.7')"
      - name: Test with pytest
        run: |
          make pytest
Makefile (3 additions, 8 deletions)
@@ -4,13 +4,10 @@ LINT_PATHS=sbx/ tests/ setup.py
pytest:
	./scripts/run_tests.sh

-pytype:
-	pytype -j auto
-
mypy:
	mypy ${LINT_PATHS}

-type: pytype mypy
+type: mypy

lint:
	# stop the build if there are Python syntax errors or undefined names
@@ -44,14 +41,12 @@ commit-checks: format type lint

# PyPi package release
release:
-	python setup.py sdist
-	python setup.py bdist_wheel
+	python -m build
	twine upload dist/*

# Test PyPi package release
test-release:
-	python setup.py sdist
-	python setup.py bdist_wheel
+	python -m build
	twine upload --repository-url https://test.pypi.org/legacy/ dist/*

.PHONY: clean spelling doc lint format check-codestyle commit-checks
README.md (3 additions, 1 deletion)
@@ -15,6 +15,8 @@ Implemented algorithms:
- [Dropout Q-Functions for Doubly Efficient Reinforcement Learning (DroQ)](https://openreview.net/forum?id=xCVJMsPv3RT)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347)
- [Deep Q Network (DQN)](https://arxiv.org/abs/1312.5602)
+- [Twin Delayed DDPG (TD3)](https://arxiv.org/abs/1802.09477)
+- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)


### Install using pip
@@ -34,7 +36,7 @@ pip install sbx-rl
```python
import gymnasium as gym

-from sbx import TQC, DroQ, SAC, PPO, DQN
+from sbx import TQC, DroQ, SAC, PPO, DQN, TD3, DDPG

env = gym.make("Pendulum-v1")

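For context, here is a minimal training sketch with one of the newly exported algorithms, assuming sbx keeps mirroring the standard SB3 `learn`/`predict` API as the README quickstart above does (the timestep count and `verbose` flag are illustrative, not part of this commit):

```python
import gymnasium as gym

from sbx import TD3

env = gym.make("Pendulum-v1")

model = TD3("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)

# Evaluate the trained deterministic policy
obs, _ = env.reset()
for _ in range(500):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, _ = env.reset()
env.close()
```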
pyproject.toml (2 additions, 6 deletions)
@@ -1,8 +1,8 @@
[tool.ruff]
# Same as Black.
line-length = 127
-# Assume Python 3.7
-target-version = "py37"
+# Assume Python 3.8
+target-version = "py38"
# See https://beta.ruff.rs/docs/rules/
select = ["E", "F", "B", "UP", "C90", "RUF"]
# Ignore explicit stacklevel`
@@ -21,10 +21,6 @@ profile = "black"
line_length = 127
src_paths = ["sbx"]

-[tool.pytype]
-inputs = ["sbx"]
-disable = []
-
[tool.mypy]
ignore_missing_imports = true
follow_imports = "silent"
sbx/__init__.py (4 additions)
@@ -1,9 +1,11 @@
import os

+from sbx.ddpg import DDPG
from sbx.dqn import DQN
from sbx.droq import DroQ
from sbx.ppo import PPO
from sbx.sac import SAC
+from sbx.td3 import TD3
from sbx.tqc import TQC

# Read version from file
@@ -12,9 +14,11 @@
    __version__ = file_handler.read().strip()

__all__ = [
+    "DDPG",
    "DQN",
    "DroQ",
    "PPO",
    "SAC",
+    "TD3",
    "TQC",
]
sbx/ddpg/__init__.py (new file, 3 additions)
@@ -0,0 +1,3 @@
+from sbx.ddpg.ddpg import DDPG
+
+__all__ = ["DDPG"]
sbx/ddpg/ddpg.py (new file, 72 additions)
@@ -0,0 +1,72 @@
+from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union
+
+from stable_baselines3.common.buffers import ReplayBuffer
+from stable_baselines3.common.noise import ActionNoise
+from stable_baselines3.common.type_aliases import GymEnv, Schedule
+
+from sbx.td3.policies import TD3Policy
+from sbx.td3.td3 import TD3
+
+
+class DDPG(TD3):
+    policy_aliases: ClassVar[Dict[str, Type[TD3Policy]]] = {
+        "MlpPolicy": TD3Policy,
+    }
+
+    def __init__(
+        self,
+        policy,
+        env: Union[GymEnv, str],
+        learning_rate: Union[float, Schedule] = 3e-4,
+        qf_learning_rate: Optional[float] = 1e-3,
+        buffer_size: int = 1_000_000,  # 1e6
+        learning_starts: int = 100,
+        batch_size: int = 256,
+        tau: float = 0.005,
+        gamma: float = 0.99,
+        train_freq: Union[int, Tuple[int, str]] = 1,
+        gradient_steps: int = 1,
+        action_noise: Optional[ActionNoise] = None,
+        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        tensorboard_log: Optional[str] = None,
+        policy_kwargs: Optional[Dict[str, Any]] = None,
+        verbose: int = 0,
+        seed: Optional[int] = None,
+        device: str = "auto",
+        _init_setup_model: bool = True,
+    ) -> None:
+        super().__init__(
+            policy=policy,
+            env=env,
+            learning_rate=learning_rate,
+            qf_learning_rate=qf_learning_rate,
+            buffer_size=buffer_size,
+            learning_starts=learning_starts,
+            batch_size=batch_size,
+            tau=tau,
+            gamma=gamma,
+            train_freq=train_freq,
+            gradient_steps=gradient_steps,
+            action_noise=action_noise,
+            # Remove all tricks from TD3 to obtain DDPG:
+            # we still need to specify target_policy_noise > 0 to avoid errors
+            policy_delay=1,
+            target_policy_noise=0.1,
+            target_noise_clip=0.0,
+            replay_buffer_class=replay_buffer_class,
+            replay_buffer_kwargs=replay_buffer_kwargs,
+            policy_kwargs=policy_kwargs,
+            tensorboard_log=tensorboard_log,
+            verbose=verbose,
+            seed=seed,
+            device=device,
+            _init_setup_model=False,
+        )
+
+        # Use only one critic
+        if "n_critics" not in self.policy_kwargs:
+            self.policy_kwargs["n_critics"] = 1
+
+        if _init_setup_model:
+            self._setup_model()
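As the constructor comments spell out, this DDPG is TD3 with its three additions switched off: no delayed policy updates (`policy_delay=1`), a single critic instead of twin critics (`n_critics=1`), and target policy smoothing neutralized by clipping the target noise to zero. A sketch of the equivalence, using only parameters visible in this diff (the environment and seed are illustrative):

```python
from sbx import DDPG, TD3

# DDPG as defined above...
ddpg = DDPG("MlpPolicy", "Pendulum-v1", seed=0)

# ...matches TD3 with the tricks disabled by hand:
td3_as_ddpg = TD3(
    "MlpPolicy",
    "Pendulum-v1",
    policy_delay=1,                   # update the actor on every critic update
    target_policy_noise=0.1,          # must stay > 0 to avoid errors (see comment above)
    target_noise_clip=0.0,            # clipping to 0.0 cancels the smoothing noise
    policy_kwargs={"n_critics": 1},   # one critic instead of TD3's twin critics
    seed=0,
)
```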
sbx/dqn/dqn.py (3 additions, 3 deletions)
@@ -1,5 +1,5 @@
import warnings
-from typing import Any, Dict, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union

import gymnasium as gym
import jax
@@ -15,7 +15,7 @@


class DQN(OffPolicyAlgorithmJax):
-    policy_aliases: Dict[str, Type[DQNPolicy]] = {  # type: ignore[assignment]
+    policy_aliases: ClassVar[Dict[str, Type[DQNPolicy]]] = {  # type: ignore[assignment]
        "MlpPolicy": DQNPolicy,
    }
    # Linear schedule will be defined in `_setup_model()`
@@ -248,7 +248,7 @@ def predict(
        if not deterministic and np.random.rand() < self.exploration_rate:
            if self.policy.is_vectorized_observation(observation):
                if isinstance(observation, dict):
-                    n_batch = observation[list(observation.keys())[0]].shape[0]
+                    n_batch = observation[next(iter(observation.keys()))].shape[0]
                else:
                    n_batch = observation.shape[0]
                action = np.array([self.action_space.sample() for _ in range(n_batch)])
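The `predict` change above swaps `list(observation.keys())[0]` for `next(iter(observation.keys()))`, which fetches the first key without materializing the whole key list (the pattern ruff's RUF015 rule flags). A self-contained sketch of why the two are interchangeable here; the observation dict below is hypothetical:

```python
import numpy as np

# Hypothetical batched dict observation (4 environments, 3 features each)
observation = {"position": np.zeros((4, 3)), "velocity": np.zeros((4, 3))}

first_key = next(iter(observation.keys()))       # no intermediate list is built
assert first_key == list(observation.keys())[0]  # same key as the old code

n_batch = observation[first_key].shape[0]
print(n_batch)  # 4: every value shares the same leading batch dimension
```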
sbx/droq/droq.py (2 additions, 2 deletions)
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
@@ -9,7 +9,7 @@


class DroQ(TQC):
-    policy_aliases: Dict[str, Type[TQCPolicy]] = {
+    policy_aliases: ClassVar[Dict[str, Type[TQCPolicy]]] = {
        "MlpPolicy": TQCPolicy,
    }

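The same one-line change lands in DQN and DroQ above and in PPO and SAC below: `policy_aliases` is a mutable dict shared by the class, so annotating it as `ClassVar` tells mypy (and ruff's RUF012 rule) that it is a class-level constant rather than a per-instance attribute. A minimal sketch of the pattern, with hypothetical `Algo` and `BasePolicy` stand-ins:

```python
from typing import ClassVar, Dict, Type


class BasePolicy:  # hypothetical stand-in for the real policy classes
    pass


class Algo:
    # Mutable class attribute: ClassVar marks it as shared across instances,
    # silencing RUF012 and giving mypy the intended type.
    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
        "MlpPolicy": BasePolicy,
    }
```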
sbx/ppo/ppo.py (2 additions, 2 deletions)
@@ -1,6 +1,6 @@
import warnings
from functools import partial
-from typing import Any, Dict, Optional, Type, TypeVar, Union
+from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union

import jax
import jax.numpy as jnp
@@ -68,7 +68,7 @@ class PPO(OnPolicyAlgorithmJax):
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

-    policy_aliases: Dict[str, Type[PPOPolicy]] = {  # type: ignore[assignment]
+    policy_aliases: ClassVar[Dict[str, Type[PPOPolicy]]] = {  # type: ignore[assignment]
        "MlpPolicy": PPOPolicy,
        # "CnnPolicy": ActorCriticCnnPolicy,
        # "MultiInputPolicy": MultiInputActorCriticPolicy,
sbx/sac/sac.py (2 additions, 2 deletions)
@@ -1,5 +1,5 @@
from functools import partial
-from typing import Any, Dict, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union

import flax
import flax.linen as nn
@@ -39,7 +39,7 @@ def __call__(self) -> float:


class SAC(OffPolicyAlgorithmJax):
-    policy_aliases: Dict[str, Type[SACPolicy]] = {  # type: ignore[assignment]
+    policy_aliases: ClassVar[Dict[str, Type[SACPolicy]]] = {  # type: ignore[assignment]
        "MlpPolicy": SACPolicy,
        # Minimal dict support using flatten()
        "MultiInputPolicy": SACPolicy,
sbx/td3/__init__.py (new file, 3 additions)
@@ -0,0 +1,3 @@
+from sbx.td3.td3 import TD3
+
+__all__ = ["TD3"]
(Diffs for the remaining 6 changed files were not loaded.)