Noisy Cross Entropy Method (CEM) #62

Open · wants to merge 20 commits into base: master
1 change: 1 addition & 0 deletions README.md
@@ -26,6 +26,7 @@ See documentation for the full list of included features.

**RL Algorithms**:
- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)
- [Noisy Cross Entropy Method (CEM)](http://dx.doi.org/10.1162/neco.2006.18.12.2936)
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
- [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/)
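For context, a minimal usage sketch of the new CEM entry added above. This assumes CEM follows the same `sb3_contrib` API as the other listed algorithms (an `"MlpPolicy"` alias and the standard `.learn()`/`.save()` methods); the environment id and arguments below are illustrative, not taken from this PR.

```python
# Minimal usage sketch; assumes CEM exposes the usual sb3_contrib API
# ("MlpPolicy" alias, standard .learn()/.save()). Arguments are illustrative.
from sb3_contrib import CEM

model = CEM("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=100_000)
model.save("cem_pendulum")
```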
5 changes: 5 additions & 0 deletions docs/misc/changelog.rst
@@ -11,6 +11,7 @@ Breaking Changes:
- Removed deprecated ``create_eval_env``, ``eval_env``, ``eval_log_path``, ``n_eval_episodes`` and ``eval_freq`` parameters;
  please use an ``EvalCallback`` instead
- Removed deprecated ``sde_net_arch`` parameter
- Changed default policy architecture for ARS/CEM to ``[32]`` instead of ``[64, 64]``

New Features:
^^^^^^^^^^^^^
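Regarding the breaking change above (default architecture now ``[32]``): users who want the previous network can presumably pass it back explicitly. A hedged sketch, assuming ARS/CEM policies accept the standard SB3 ``net_arch`` keyword:

```python
# Sketch: restoring the previous [64, 64] architecture after this breaking
# change; assumes ARS/CEM policies accept the standard SB3 ``net_arch`` key.
from sb3_contrib import ARS

model = ARS("MlpPolicy", "Pendulum-v1", policy_kwargs=dict(net_arch=[64, 64]))
```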
@@ -62,6 +63,7 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
- Added noisy Cross Entropy Method (CEM)

Bug Fixes:
^^^^^^^^^^
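As a reminder of what the paper linked in the README proposes, here is a small NumPy sketch of one noisy-CEM iteration: fit a diagonal Gaussian to the elite candidates and add extra variance so the search distribution does not collapse prematurely. This illustrates the method itself, not the exact implementation in this PR.

```python
import numpy as np

def noisy_cem_update(returns, candidates, n_top, extra_noise_std):
    """One noisy-CEM refit: keep the elites, add constant extra variance."""
    elite_idx = np.argsort(returns)[-n_top:]          # indices of best candidates
    elites = candidates[elite_idx]
    new_mean = elites.mean(axis=0)                    # refit the Gaussian mean
    new_std = np.sqrt(elites.var(axis=0) + extra_noise_std**2)  # noise term
    return new_mean, new_std

# Toy loop: sample from N(mean, std), evaluate, refit
mean, std = np.zeros(8), np.ones(8)
for _ in range(50):
    candidates = mean + std * np.random.randn(32, 8)
    returns = -np.sum(candidates**2, axis=1)          # toy objective
    mean, std = noisy_cem_update(returns, candidates, n_top=8, extra_noise_std=0.05)
```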
@@ -108,6 +110,9 @@ Bug Fixes:
Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^

Release 1.5.0 (2022-03-25)
-------------------------------

1 change: 1 addition & 0 deletions sb3_contrib/__init__.py
@@ -1,6 +1,7 @@
import os

from sb3_contrib.ars import ARS
from sb3_contrib.cem import CEM
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.ppo_recurrent import RecurrentPPO
from sb3_contrib.qrdqn import QRDQN
194 changes: 16 additions & 178 deletions sb3_contrib/ars/ars.py
@@ -1,28 +1,21 @@
import copy
import sys
import time
import warnings
from functools import partial
from typing import Any, Dict, Optional, Type, TypeVar, Union

import gym
import numpy as np
import torch as th
import torch.nn.utils
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_schedule_fn, safe_mean
from stable_baselines3.common.utils import get_schedule_fn

from sb3_contrib.ars.policies import ARSPolicy, LinearPolicy, MlpPolicy
from sb3_contrib.common.policies import ESPolicy
from sb3_contrib.common.population_based_algorithm import PopulationBasedAlgorithm
from sb3_contrib.common.vec_env.async_eval import AsyncEval

ARSSelf = TypeVar("ARSSelf", bound="ARS")


class ARS(BaseAlgorithm):
class ARS(PopulationBasedAlgorithm):
"""
Augmented Random Search: https://arxiv.org/abs/1803.07055

@@ -47,14 +40,9 @@ class ARS(BaseAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""

policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": MlpPolicy,
"LinearPolicy": LinearPolicy,
}

def __init__(
self,
policy: Union[str, Type[ARSPolicy]],
policy: Union[str, Type[ESPolicy]],
env: Union[GymEnv, str],
n_delta: int = 8,
n_top: Optional[int] = None,
@@ -75,19 +63,19 @@ def __init__(
policy,
env,
learning_rate=learning_rate,
pop_size=2 * n_delta,
alive_bonus_offset=alive_bonus_offset,
n_eval_episodes=n_eval_episodes,
tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs,
verbose=verbose,
device=device,
supported_action_spaces=(gym.spaces.Box, gym.spaces.Discrete),
support_multi_env=True,
seed=seed,
)

self.n_delta = n_delta
self.pop_size = 2 * n_delta
self.delta_std_schedule = get_schedule_fn(delta_std)
self.n_eval_episodes = n_eval_episodes

if n_top is None:
n_top = n_delta
@@ -99,13 +87,8 @@ def __init__(

self.n_top = n_top

self.alive_bonus_offset = alive_bonus_offset
self.zero_policy = zero_policy
self.weights = None # Need to call init model to initialize weight
self.processes = None
# Keep track of how many steps have elapsed before a new rollout
# Important for syncing observation normalization between workers
self.old_count = 0

if _init_setup_model:
self._setup_model()
@@ -123,138 +106,6 @@ def _setup_model(self) -> None:
self.weights = th.zeros_like(self.weights, requires_grad=False)
self.policy.load_from_vector(self.weights.cpu())

def _mimic_monitor_wrapper(self, episode_rewards: np.ndarray, episode_lengths: np.ndarray) -> None:
"""
Helper to mimic Monitor wrapper and report episode statistics (mean reward, mean episode length).

:param episode_rewards: List containing per-episode rewards
:param episode_lengths: List containing per-episode lengths (in number of steps)
"""
# Mimic Monitor Wrapper
infos = [
{"episode": {"r": episode_reward, "l": episode_length}}
for episode_reward, episode_length in zip(episode_rewards, episode_lengths)
]

self._update_info_buffer(infos)

def _trigger_callback(
self,
_locals: Dict[str, Any],
_globals: Dict[str, Any],
callback: BaseCallback,
n_envs: int,
) -> None:
"""
Callback passed to the ``evaluate_policy()`` helper
in order to increment the number of timesteps
and trigger events in the single process version.

:param _locals:
:param _globals:
:param callback: Callback that will be called at every step
:param n_envs: Number of environments
"""
self.num_timesteps += n_envs
callback.on_step()

def evaluate_candidates(
self, candidate_weights: th.Tensor, callback: BaseCallback, async_eval: Optional[AsyncEval]
) -> th.Tensor:
"""
Evaluate each candidate.

:param candidate_weights: The candidate weights to be evaluated.
:param callback: Callback that will be called at each step
(or after evaluation in the multiprocess version)
:param async_eval: The object for asynchronous evaluation of candidates.
:return: The episodic return for each candidate.
"""

batch_steps = 0
# returns == sum of rewards
candidate_returns = th.zeros(self.pop_size, device=self.device)
train_policy = copy.deepcopy(self.policy)
# Empty buffer to show only mean over one iteration (one set of candidates) in the logs
self.ep_info_buffer = []
callback.on_rollout_start()

if async_eval is not None:
# Multiprocess asynchronous version
async_eval.send_jobs(candidate_weights, self.pop_size)
results = async_eval.get_results()

for weights_idx, (episode_rewards, episode_lengths) in results:

# Update reward to cancel out alive bonus if needed
candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths)
batch_steps += np.sum(episode_lengths)
self._mimic_monitor_wrapper(episode_rewards, episode_lengths)

# Combine the filter stats of each process for normalization
for worker_obs_rms in async_eval.get_obs_rms():
if self._vec_normalize_env is not None:
# worker_obs_rms.count -= self.old_count
self._vec_normalize_env.obs_rms.combine(worker_obs_rms)
# Hack: don't count timesteps twice (the two are synced just below),
# otherwise it will lead to overflow;
# in practice we would need two RunningMeanStd instances
self._vec_normalize_env.obs_rms.count -= self.old_count

# Synchronise VecNormalize if needed
if self._vec_normalize_env is not None:
async_eval.sync_obs_rms(self._vec_normalize_env.obs_rms.copy())
self.old_count = self._vec_normalize_env.obs_rms.count

# Hack to have Callback events
for _ in range(batch_steps // len(async_eval.remotes)):
self.num_timesteps += len(async_eval.remotes)
callback.on_step()
else:
# Single process, synchronous version
for weights_idx in range(self.pop_size):

# Load current candidate weights
train_policy.load_from_vector(candidate_weights[weights_idx].cpu())
# Evaluate the candidate
episode_rewards, episode_lengths = evaluate_policy(
train_policy,
self.env,
n_eval_episodes=self.n_eval_episodes,
return_episode_rewards=True,
# Increment num_timesteps too (slight mismatch with multi envs)
callback=partial(self._trigger_callback, callback=callback, n_envs=self.env.num_envs),
warn=False,
)
# Update reward to cancel out alive bonus if needed
candidate_returns[weights_idx] = sum(episode_rewards) + self.alive_bonus_offset * sum(episode_lengths)
batch_steps += sum(episode_lengths)
self._mimic_monitor_wrapper(episode_rewards, episode_lengths)

# Note: we increment the num_timesteps inside the evaluate_policy()
# however when using multiple environments, there will be a slight
# mismatch between the number of timesteps used and the number
# of calls to the step() method (cf. implementation of evaluate_policy())
# self.num_timesteps += batch_steps

callback.on_rollout_end()

return candidate_returns

def _log_and_dump(self) -> None:
"""
Dump information to the logger.
"""
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
self.logger.record("time/fps", fps)
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(step=self.num_timesteps)

def _do_one_update(self, callback: BaseCallback, async_eval: Optional[AsyncEval]) -> None:
"""
Sample new candidates, evaluate them and then update current policy.
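The methods deleted above (``evaluate_candidates``, ``_mimic_monitor_wrapper``, ``_trigger_callback``, ``_log_and_dump``) are not lost: per the new import, they presumably move into the shared ``PopulationBasedAlgorithm`` base class so that CEM can reuse them. A rough, hypothetical skeleton of that base class (the real signatures in ``sb3_contrib/common/population_based_algorithm.py`` may differ):

```python
# Hypothetical skeleton of the shared base class; names mirror the code
# removed above, but the actual implementation in this PR may differ.
from stable_baselines3.common.base_class import BaseAlgorithm


class PopulationBasedAlgorithm(BaseAlgorithm):
    """Shared machinery for population-based algorithms (ARS, CEM)."""

    def __init__(self, policy, env, learning_rate, pop_size,
                 alive_bonus_offset, n_eval_episodes, **kwargs):
        super().__init__(policy, env, learning_rate=learning_rate, **kwargs)
        self.pop_size = pop_size
        self.alive_bonus_offset = alive_bonus_offset
        self.n_eval_episodes = n_eval_episodes
        self.old_count = 0  # for syncing obs normalization across workers

    def evaluate_candidates(self, candidate_weights, callback, async_eval):
        # Same logic as the block deleted from ars.py: evaluate each candidate
        # synchronously or via AsyncEval and return per-candidate returns.
        ...

    def _do_one_update(self, callback, async_eval):
        # Algorithm-specific: ARS gradient step or CEM distribution refit.
        raise NotImplementedError
```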
@@ -326,25 +177,12 @@ def learn(
:return: the trained model
"""

total_steps, callback = self._setup_learn(
total_timesteps,
callback,
reset_num_timesteps,
tb_log_name,
progress_bar,
return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
tb_log_name=tb_log_name,
reset_num_timesteps=reset_num_timesteps,
async_eval=async_eval,
progress_bar=progress_bar,
)

callback.on_training_start(locals(), globals())

while self.num_timesteps < total_steps:
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
self._do_one_update(callback, async_eval)
if log_interval is not None and self._n_updates % log_interval == 0:
self._log_and_dump()

if async_eval is not None:
async_eval.close()

callback.on_training_end()

return self
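``ARS.learn()`` now simply forwards to the base class, including the ``async_eval`` argument for multiprocess candidate evaluation. For reference, a usage sketch modelled on the existing ARS documentation; the ``AsyncEval`` constructor arguments shown here are assumptions, not taken from this diff.

```python
# Sketch of asynchronous candidate evaluation, modelled on the existing ARS
# docs; the AsyncEval constructor arguments are assumptions.
from stable_baselines3.common.env_util import make_vec_env

from sb3_contrib import ARS
from sb3_contrib.common.vec_env.async_eval import AsyncEval

env_id = "CartPole-v1"
model = ARS("LinearPolicy", env_id, n_delta=2, n_top=1, verbose=1)
# One evaluation worker per process
async_eval = AsyncEval([lambda: make_vec_env(env_id) for _ in range(2)], model.policy)
model.learn(total_timesteps=200_000, async_eval=async_eval)
```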