-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update to match SB3 * Update min pytorch version * Remove pytype * Add base TD3 * Add DDPG * Remove unused variables
- Loading branch information
Showing
18 changed files
with
614 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from sbx.ddpg.ddpg import DDPG | ||
|
||
__all__ = ["DDPG"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union | ||
|
||
from stable_baselines3.common.buffers import ReplayBuffer | ||
from stable_baselines3.common.noise import ActionNoise | ||
from stable_baselines3.common.type_aliases import GymEnv, Schedule | ||
|
||
from sbx.td3.policies import TD3Policy | ||
from sbx.td3.td3 import TD3 | ||
|
||
|
||
class DDPG(TD3): | ||
policy_aliases: ClassVar[Dict[str, Type[TD3Policy]]] = { | ||
"MlpPolicy": TD3Policy, | ||
} | ||
|
||
def __init__( | ||
self, | ||
policy, | ||
env: Union[GymEnv, str], | ||
learning_rate: Union[float, Schedule] = 3e-4, | ||
qf_learning_rate: Optional[float] = 1e-3, | ||
buffer_size: int = 1_000_000, # 1e6 | ||
learning_starts: int = 100, | ||
batch_size: int = 256, | ||
tau: float = 0.005, | ||
gamma: float = 0.99, | ||
train_freq: Union[int, Tuple[int, str]] = 1, | ||
gradient_steps: int = 1, | ||
action_noise: Optional[ActionNoise] = None, | ||
replay_buffer_class: Optional[Type[ReplayBuffer]] = None, | ||
replay_buffer_kwargs: Optional[Dict[str, Any]] = None, | ||
tensorboard_log: Optional[str] = None, | ||
policy_kwargs: Optional[Dict[str, Any]] = None, | ||
verbose: int = 0, | ||
seed: Optional[int] = None, | ||
device: str = "auto", | ||
_init_setup_model: bool = True, | ||
) -> None: | ||
super().__init__( | ||
policy=policy, | ||
env=env, | ||
learning_rate=learning_rate, | ||
qf_learning_rate=qf_learning_rate, | ||
buffer_size=buffer_size, | ||
learning_starts=learning_starts, | ||
batch_size=batch_size, | ||
tau=tau, | ||
gamma=gamma, | ||
train_freq=train_freq, | ||
gradient_steps=gradient_steps, | ||
action_noise=action_noise, | ||
# Remove all tricks from TD3 to obtain DDPG: | ||
# we still need to specify target_policy_noise > 0 to avoid errors | ||
policy_delay=1, | ||
target_policy_noise=0.1, | ||
target_noise_clip=0.0, | ||
replay_buffer_class=replay_buffer_class, | ||
replay_buffer_kwargs=replay_buffer_kwargs, | ||
policy_kwargs=policy_kwargs, | ||
tensorboard_log=tensorboard_log, | ||
verbose=verbose, | ||
seed=seed, | ||
device=device, | ||
_init_setup_model=False, | ||
) | ||
|
||
# Use only one critic | ||
if "n_critics" not in self.policy_kwargs: | ||
self.policy_kwargs["n_critics"] = 1 | ||
|
||
if _init_setup_model: | ||
self._setup_model() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from sbx.td3.td3 import TD3 | ||
|
||
__all__ = ["TD3"] |
Oops, something went wrong.