
Prioritized experience replay #1622

Open: wants to merge 31 commits into base: master

Changes from 23 commits

Commits (31):
- c607585 Started PER (AlexPasqua, Jul 22, 2023)
- 57f1192 Added "add" method + other improvements (AlexPasqua, Jul 23, 2023)
- 2b9df33 Docstrings, type hints, doc (AlexPasqua, Jul 23, 2023)
- 31dc46c Merge branch 'master' into prioritized-experience-replay (AlexPasqua, Jul 23, 2023)
- 1a32377 Merge branch 'master' into prioritized-experience-replay (araffin, Jul 24, 2023)
- ccf7dc3 Merge branch 'master' into prioritized-experience-replay (AlexPasqua, Aug 6, 2023)
- aee1d30 FIxed for pytype checks (partially) (AlexPasqua, Aug 6, 2023)
- c51b173 make format (AlexPasqua, Aug 6, 2023)
- 18c9d28 Made pytype ignore type on PER's sample method (AlexPasqua, Aug 6, 2023)
- 840dde2 Merge branch 'master' into prioritized-experience-replay (araffin, Aug 30, 2023)
- dcfbf88 Merge branch 'master' into prioritized-experience-replay (araffin, Sep 29, 2023)
- fb33732 Switch to numpy for the backend (araffin, Sep 29, 2023)
- f984e5c Move to common and add tests (araffin, Sep 29, 2023)
- 5edf8bf Updated DQN docs (AlexPasqua, Sep 30, 2023)
- 2f76038 Update doc (araffin, Oct 2, 2023)
- 42f2f4a Rename things to be consistent with buffers.py (araffin, Oct 2, 2023)
- 007105f Rename variables and add priority update (araffin, Oct 2, 2023)
- cc37cba Ignore mypy (araffin, Oct 2, 2023)
- b60ef03 Add beta schedule (araffin, Oct 3, 2023)
- ec272b9 Merge branch 'master' into prioritized-experience-replay (araffin, Nov 8, 2023)
- a043cfd Merge branch 'master' into prioritized-experience-replay (araffin, Nov 22, 2023)
- f6accf9 Merge branch 'master' into prioritized-experience-replay (araffin, Jan 30, 2024)
- b21ef33 Merge branch 'master' into prioritized-experience-replay (araffin, May 6, 2024)
- f57444a Merge branch 'master' into prioritized-experience-replay (araffin, May 24, 2024)
- be00231 Minor fix in PER (AlexPasqua, May 24, 2024)
- 4390ec7 Merge branch 'master' into prioritized-experience-replay (araffin, Jun 7, 2024)
- bee9cbe Merge branch 'master' into prioritized-experience-replay (araffin, Jul 7, 2024)
- fb1a9f7 Only convert to numpy if needed (araffin, Jul 12, 2024)
- 150b09a Increase min priority to avoid division by zero (araffin, Jul 12, 2024)
- 5c0c79d Merge branch 'master' into prioritized-experience-replay (araffin, Jul 17, 2024)
- 148e4aa Merge branch 'master' into prioritized-experience-replay (araffin, Nov 5, 2024)
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
@@ -179,6 +179,7 @@ New Features:
- Improved error message when mixing Gym API with VecEnv API (see GH#1694)
- Add support for setting ``options`` at reset with VecEnv via the ``set_options()`` method. Same as seeds logic, options are reset at the end of an episode (@ReHoss)
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to on-policy algorithms (A2C and PPO)
- Added Prioritized Experience Replay for DQN (@AlexPasqua)


Bug Fixes:
12 changes: 11 additions & 1 deletion docs/modules/dqn.rst
@@ -27,8 +27,9 @@ Notes
- Further reference: https://www.nature.com/articles/nature14236

.. note::
This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay.

This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN or Dueling-DQN.
To use Prioritized Experience Replay, you need to pass it via the ``replay_buffer_class`` argument.
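
A minimal usage sketch (assuming the ``PrioritizedReplayBuffer`` import path added in this PR; the ``alpha``/``beta`` values shown are simply the constructor defaults):

    from stable_baselines3 import DQN
    from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer

    model = DQN(
        "MlpPolicy",
        "CartPole-v1",
        replay_buffer_class=PrioritizedReplayBuffer,
        replay_buffer_kwargs=dict(alpha=0.5, beta=0.4),
        verbose=1,
    )
    model.learn(total_timesteps=10_000)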

Can I use?
----------
@@ -48,6 +49,15 @@ MultiBinary ❌ ✔️
Dict ❌ ✔️️
============= ====== ===========

- Rainbow DQN extensions:

- Double Q-Learning: ❌
- Prioritized Experience Replay: ✔️ (``from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer``)
- Dueling Networks: ❌
- Multi-step Learning: ❌
- Distributional RL: ✔️ (``QR-DQN`` is implemented in the SB3 contrib repo)
- Noisy Nets: ❌


Example
-------
4 changes: 2 additions & 2 deletions stable_baselines3/common/buffers.py
@@ -291,7 +291,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB
:param batch_size: Number of element to sample
:param env: associated gym VecEnv
to normalize the observations/rewards when sampling
:return:
:return: a batch of sampled experiences from the buffer.
"""
if not self.optimize_memory_usage:
return super().sample(batch_size=batch_size, env=env)
@@ -321,7 +321,7 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non
(self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1),
self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env),
)
return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
return ReplayBufferSamples(*tuple(map(self.to_torch, data))) # type: ignore[arg-type]

@staticmethod
def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike:
259 changes: 259 additions & 0 deletions stable_baselines3/common/prioritized_replay_buffer.py
@@ -0,0 +1,259 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch as th
from gymnasium import spaces

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.type_aliases import ReplayBufferSamples
from stable_baselines3.common.utils import get_linear_fn
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize


class SumTree:
"""
SumTree data structure for Prioritized Replay Buffer.
This code is inspired by: https://github.com/Howuhh/prioritized_experience_replay

:param buffer_size: Max number of elements in the buffer.
"""

def __init__(self, buffer_size: int) -> None:
self.nodes = np.zeros(2 * buffer_size - 1)
# The data array stores transition indices
self.data = np.zeros(buffer_size)
self.buffer_size = buffer_size
self.pos = 0
self.full = False

@property
def total_sum(self) -> float:
"""
Returns the root node value, which represents the total sum of all priorities in the tree.

:return: Total sum of all priorities in the tree.
"""
return self.nodes[0].item()

def update(self, leaf_node_idx: int, value: float) -> None:
"""
Update the priority of a leaf node.

:param leaf_node_idx: Index of the leaf node to update.
:param value: New priority value.
"""
idx = leaf_node_idx + self.buffer_size - 1 # child index in tree array
change = value - self.nodes[idx]
self.nodes[idx] = value
parent = (idx - 1) // 2
while parent >= 0:
self.nodes[parent] += change
parent = (parent - 1) // 2

def add(self, value: float, data: int) -> None:
"""
Add a new transition with its priority value:
it adds a new leaf node and updates the cumulative sums.

:param value: Priority value.
:param data: Data for the new leaf node, storing transition index
in the case of the prioritized replay buffer.
"""
# Note: transition_indices should be constant
# as the replay buffer already updates a pointer
self.data[self.pos] = data
self.update(self.pos, value)
self.pos = (self.pos + 1) % self.buffer_size

def get(self, cumulative_sum: float) -> Tuple[int, float, th.Tensor]:
"""
Get a leaf node index, its priority value and transition index by cumulative_sum value.

:param cumulative_sum: Cumulative sum value.
:return: Leaf node index, its priority value and transition index.
"""
assert cumulative_sum <= self.total_sum

idx = 0
while 2 * idx + 1 < len(self.nodes):
left, right = 2 * idx + 1, 2 * idx + 2
if cumulative_sum <= self.nodes[left]:
idx = left
else:
idx = right
cumulative_sum = cumulative_sum - self.nodes[left]

leaf_node_idx = idx - self.buffer_size + 1
return leaf_node_idx, self.nodes[idx].item(), self.data[leaf_node_idx]

def __repr__(self) -> str:
return f"SumTree(nodes={self.nodes!r}, data={self.data!r})"


class PrioritizedReplayBuffer(ReplayBuffer):
"""
Prioritized Replay Buffer (proportional priorities version).
Paper: https://arxiv.org/abs/1511.05952
This code is inspired by: https://github.com/Howuhh/prioritized_experience_replay

:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device: PyTorch device
:param n_envs: Number of parallel environments
:param alpha: How much prioritization is used (0 - no prioritization aka uniform case, 1 - full prioritization)
:param beta: To what degree to use importance weights (0 - no corrections, 1 - full correction)
:param final_beta: Value of beta at the end of training.
Linear annealing is used to interpolate between initial value of beta and final beta.
:param min_priority: Minimum priority, prevents zero probabilities, so that all samples
always have a non-zero probability to be sampled.
"""

def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
n_envs: int = 1,
alpha: float = 0.5,
beta: float = 0.4,
final_beta: float = 1.0,
optimize_memory_usage: bool = False,
min_priority: float = 1e-8,
):
super().__init__(buffer_size, observation_space, action_space, device, n_envs)

assert optimize_memory_usage is False, "PrioritizedReplayBuffer doesn't support optimize_memory_usage=True"

self.min_priority = min_priority
self.alpha = alpha
self.max_priority = self.min_priority # priority for new samples, init as eps
# Track the training progress remaining (from 1 to 0)
# this is used to update beta
self._current_progress_remaining = 1.0
self.initial_beta = beta
self.final_beta = final_beta
self.beta_schedule = get_linear_fn(
self.initial_beta,
self.final_beta,
end_fraction=1.0,
)
# SumTree: data structure to store priorities
self.tree = SumTree(buffer_size=buffer_size)

@property
def beta(self) -> float:
# Linear schedule
return self.beta_schedule(self._current_progress_remaining)

def add(
self,
obs: np.ndarray,
next_obs: np.ndarray,
action: np.ndarray,
reward: np.ndarray,
done: np.ndarray,
infos: List[Dict[str, Any]],
) -> None:
"""
Add a new transition to the buffer.

:param obs: Starting observation of the transition to be stored.
:param next_obs: Destination observation of the transition to be stored.
:param action: Action performed in the transition to be stored.
:param reward: Reward received in the transition to be stored.
:param done: Whether the episode was finished after the transition to be stored.
:param infos: Additional information returned by the environment, if any.
"""
# store transition index with maximum priority in sum tree
self.tree.add(self.max_priority, self.pos)

# store transition in the buffer
super().add(obs, next_obs, action, reward, done, infos)

def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
"""
Sample elements from the prioritized replay buffer.

:param batch_size: Number of elements to sample
:param env: associated gym VecEnv
to normalize the observations/rewards when sampling
:return: a batch of sampled experiences from the buffer.
"""
assert self.buffer_size >= batch_size, "The buffer contains fewer samples than the batch size requires."

leaf_nodes_indices = np.zeros(batch_size, dtype=np.uint32)
priorities = np.zeros((batch_size, 1))
sample_indices = np.zeros(batch_size, dtype=np.uint32)

# To sample a minibatch of size k, the range [0, total_sum] is divided equally into k ranges.
# Next, a value is uniformly sampled from each range. Finally the transitions that correspond
# to each of these sampled values are retrieved from the tree.
segment_size = self.tree.total_sum / batch_size
for batch_idx in range(batch_size):
# extremes of the current segment
start, end = segment_size * batch_idx, segment_size * (batch_idx + 1)

# uniformly sample a value from the current segment
cumulative_sum = np.random.uniform(start, end)

# leaf_node_idx is the index of the sample in the tree, needed later to update priorities
# sample_idx is the sample index in the buffer, needed later to fetch the actual transitions
leaf_node_idx, priority, sample_idx = self.tree.get(cumulative_sum)

leaf_nodes_indices[batch_idx] = leaf_node_idx
priorities[batch_idx] = priority
sample_indices[batch_idx] = sample_idx

# probability of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha
# where p_i > 0 is the priority of transition i.
probs = priorities / self.tree.total_sum

# Importance sampling weights.
# All weights w_i are scaled so that max_i w_i = 1.
weights = (self.size() * probs) ** -self.beta
weights = weights / weights.max()

# TODO: add proper support for multi env

Review thread on this TODO:

Comment: How could we add proper support for multiple envs? Is there any idea? Could the random line below work?

Member: Not sure yet, the random line below might work, but we need to check that it won't affect performance first.

Member: araffin/sbx@b5ce091 should be better, see araffin/sbx#50

# env_indices = np.random.randint(0, high=self.n_envs, size=(batch_size,))
env_indices = np.zeros(batch_size, dtype=np.uint32)

if self.optimize_memory_usage:
next_obs = self._normalize_obs(self.observations[(sample_indices + 1) % self.buffer_size, env_indices, :], env)
else:
next_obs = self._normalize_obs(self.next_observations[sample_indices, env_indices, :], env)

batch = (
self._normalize_obs(self.observations[sample_indices, env_indices, :], env),
self.actions[sample_indices, env_indices, :],
next_obs,
self.dones[sample_indices],
self.rewards[sample_indices],
weights,
)
return ReplayBufferSamples(*tuple(map(self.to_torch, batch)), leaf_nodes_indices) # type: ignore[arg-type,call-arg]

def update_priorities(self, leaf_nodes_indices: np.ndarray, td_errors: th.Tensor, progress_remaining: float) -> None:
"""
Update transition priorities.

:param leaf_nodes_indices: Indices of the leaf nodes to update
(corresponding to the transitions)
:param td_errors: New priorities, the TD error in the case of the
proportional prioritized replay buffer.
:param progress_remaining: Current progress remaining (starts at 1 and ends at 0),
used to linearly anneal beta from its initial value to 1.0 at the end of training
"""
# Update beta schedule
self._current_progress_remaining = progress_remaining
td_errors = td_errors.detach().cpu().numpy().flatten()

for leaf_node_idx, td_error in zip(leaf_nodes_indices, td_errors):
# Proportional prioritization priority = (abs(td_error) + eps) ^ alpha
# where eps is a small positive constant that prevents the edge-case of transitions not being
# revisited once their error is zero. (Section 3.3)
priority = (abs(td_error) + self.min_priority) ** self.alpha
self.tree.update(leaf_node_idx, priority)
# Update max priority for new samples
self.max_priority = max(self.max_priority, priority)
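
For reference, a minimal self-contained sketch of how a training loop could interact with this buffer (hypothetical usage: a random policy and fake TD errors stand in for the Q-network update that DQN performs further below in this PR):

    import gymnasium as gym
    import numpy as np
    import torch as th

    from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer

    env = gym.make("CartPole-v1")
    buffer = PrioritizedReplayBuffer(
        buffer_size=1_000,
        observation_space=env.observation_space,
        action_space=env.action_space,
    )

    # Fill the buffer with random transitions (new samples get max_priority)
    obs, _ = env.reset(seed=0)
    for _ in range(200):
        action = env.action_space.sample()
        next_obs, reward, terminated, truncated, info = env.step(action)
        buffer.add(
            np.array([obs]), np.array([next_obs]), np.array([action]),
            np.array([reward]), np.array([terminated]), [info],
        )
        obs = next_obs
        if terminated or truncated:
            obs, _ = env.reset()

    # Sampling is proportional to priority; the batch carries IS weights and leaf node indices
    batch = buffer.sample(batch_size=32)

    # Fake TD errors (in DQN they come from |current_q - target_q|)
    td_errors = th.rand(32, 1)
    buffer.update_priorities(batch.leaf_nodes_indices, td_errors, progress_remaining=0.5)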
4 changes: 4 additions & 0 deletions stable_baselines3/common/type_aliases.py
@@ -52,6 +52,8 @@ class ReplayBufferSamples(NamedTuple):
next_observations: th.Tensor
dones: th.Tensor
rewards: th.Tensor
weights: Union[th.Tensor, float] = 1.0
leaf_nodes_indices: Optional[np.ndarray] = None


class DictReplayBufferSamples(NamedTuple):
@@ -60,6 +62,8 @@ class DictReplayBufferSamples(NamedTuple):
next_observations: TensorDict
dones: th.Tensor
rewards: th.Tensor
weights: Union[th.Tensor, float] = 1.0
leaf_nodes_indices: Optional[np.ndarray] = None


class RolloutReturn(NamedTuple):
17 changes: 15 additions & 2 deletions stable_baselines3/dqn/dqn.py
@@ -9,6 +9,7 @@
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork
@@ -208,8 +209,20 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
# Retrieve the q-values for the actions from the replay buffer
current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())

# Compute Huber loss (less sensitive to outliers)
loss = F.smooth_l1_loss(current_q_values, target_q_values)
# Special case when using PrioritizedReplayBuffer (PER)
if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
# TD error in absolute value
td_error = th.abs(current_q_values - target_q_values)
# Weighted Huber loss using importance sampling weights
loss = (replay_data.weights * th.where(td_error < 1.0, 0.5 * td_error**2, td_error - 0.5)).mean()
# Update priorities, they will be proportional to the td error
assert replay_data.leaf_nodes_indices is not None, "No leaf node indices provided"
self.replay_buffer.update_priorities(
replay_data.leaf_nodes_indices, td_error, self._current_progress_remaining
)
else:
# Compute Huber loss (less sensitive to outliers)
loss = F.smooth_l1_loss(current_q_values, target_q_values)
Comment on lines +212 to +225
Collaborator:
@AlexPasqua Ideally, we'd like to be able to associate it with all off-policy algo's without adaptation, but I don't see a simple way of doing it at this stage.
Also related, we had discussed not modifying DQN: Stable-Baselines-Team/stable-baselines3-contrib#127 (comment)

Comment:
I'm interested in this PR. Since every algo-specific train method includes a replay_buffer.sample line, couldn't we just additionally add a replay_buffer.update line? The update function could take in the current and target q values whenever a value function is present or maybe even all the local variables. It would do nothing for the vanilla replay buffer. Would this be an acceptable modification?

Collaborator:
Thanks for your comment!
How do you handle the loss in your proposal?

Comment:
If we want this to work for general off-policy algorithms, we could update the ReplayBufferSample-like classes to additionally include an importance_sampling_weight attribute which would be updated from the replay_buffer.update method.

Then I see two ways to handle the loss under this interface:

1. Estimate the TD error from the loss, e.g.:

   losses = loss_fn(current_q_values, target_q_values, reduction='none')
   # e.g. if the loss is L2, then it's basically th.sqrt(loss); if the loss is L1, td_error = loss
   td_error = importance_sampling_weight * function_to_approx_td_error(losses)
   loss = losses.mean()

   Obviously the downside of this is that it requires hand engineering for the different types of loss functions or priority metrics.

2. Make any value-based train methods "TD-error centric", in the sense that we always compute td_error = importance_sampling_weight * th.abs(current_q_values - target_q_values) first, then the loss as loss = loss_fn(td_error). The downside of this approach is that we can't use the PyTorch API for computing the loss, and would have to write functions for those.

Either approach requires computing a td_error variable which unfortunately requires somewhat intrusive code changes. What do you think?
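
For illustration only, a rough sketch of the hook being proposed here (the class and method names are hypothetical and do not exist in SB3):

    import torch as th

    class ReplayBufferUpdateHook:
        """Hypothetical base hook: vanilla buffers ignore the TD information."""

        def update(self, current_q_values: th.Tensor, target_q_values: th.Tensor) -> None:
            pass

    class PrioritizedUpdateHook(ReplayBufferUpdateHook):
        """Hypothetical PER override: recompute priorities from |TD error|."""

        def update(self, current_q_values: th.Tensor, target_q_values: th.Tensor) -> None:
            td_error = th.abs(current_q_values - target_q_values)
            # e.g. forward td_error to PrioritizedReplayBuffer.update_priorities(...)
            # using the leaf node indices remembered from the last sample() call
            ...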

Member (@araffin, May 6, 2024):
maybe to make things clearer: my plan is not to have PER for all algorithms, mainly for two reasons:

  1. Keep the code concise (in fact, I would like to have RAINBOW and keep vanilla DQN, see [Feature Request] RAINBOW #622)
  2. I don't think it works for entropy-RL algorithms (SAC and derivatives), so it would be limited to DQN/QR-DQN and TD3

If users really want PER in other algos, they can take inspiration from a reference implementation in SB3 and integrate it (the same way we don't provide maskable + recurrent PPO at the same time).

Member:
"just" yes, I would be happy to receive such PR =)
The main thing is to benchmark the implementation and reproduce the published results.
This PR is also still open because I was not satisfied with the results of DQN + PER (I couldn't see a significant difference with respect to DQN).

Member:
One thing I had in mind was to implement CNN for SBX (https://github.com/araffin/sbx) in order to iterate faster and check the PER, but I had no time to do so until now...

Comment:
Why don't we implement the toy environment from figure 1 of https://arxiv.org/pdf/1511.05952 as the PER benchmark? It would be a simpler initial check for correctness than the Atari environments
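
For reference, a rough sketch of that toy task ("Blind Cliffwalk", Figure 1 of the paper); the details here are approximate and this environment is not part of the PR:

    import gymnasium as gym
    from gymnasium import spaces

    class BlindCliffwalk(gym.Env):
        """Approximate sketch: n states, 2 actions; a 'wrong' action ends the
        episode with no reward, and only the full chain of 'right' actions gives
        reward 1, so random exploration rarely sees a non-zero reward."""

        def __init__(self, n_states: int = 16):
            super().__init__()
            self.n_states = n_states
            self.observation_space = spaces.Discrete(n_states + 1)
            self.action_space = spaces.Discrete(2)
            self.state = 0

        def reset(self, *, seed=None, options=None):
            super().reset(seed=seed)
            self.state = 0
            return self.state, {}

        def step(self, action):
            # Assumption: the correct action alternates with state parity,
            # so a single memorised action cannot solve the task
            if action != self.state % 2:
                return self.state, 0.0, True, False, {}
            self.state += 1
            reached_end = self.state == self.n_states
            return self.state, float(reached_end), reached_end, False, {}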

Member (@araffin, May 6, 2024):
The toy environment can be a start for fast iteration and debugging, but what we learned in the past is that subtle bugs only show up when doing more complex tasks (see #48 and #47, where we found bugs like PyTorch and TF RMSProp not being the same)

Comment:
I see, will definitely work towards it!

losses.append(loss.item())

# Optimize the policy
8 changes: 6 additions & 2 deletions tests/test_buffers.py
@@ -8,6 +8,7 @@
from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer
from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize
@@ -109,7 +110,9 @@ def test_replay_buffer_normalization(replay_buffer_cls):
assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1)


@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer])
@pytest.mark.parametrize(
"replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer, PrioritizedReplayBuffer]
)
@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
def test_device_buffer(replay_buffer_cls, device):
if device == "cuda" and not th.cuda.is_available():
@@ -120,6 +123,7 @@ def test_device_buffer(replay_buffer_cls, device):
DictRolloutBuffer: DummyDictEnv,
ReplayBuffer: DummyEnv,
DictReplayBuffer: DummyDictEnv,
PrioritizedReplayBuffer: DummyEnv,
}[replay_buffer_cls]
env = make_vec_env(env)

@@ -140,7 +144,7 @@ def test_device_buffer(replay_buffer_cls, device):
# Get data from the buffer
if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
data = buffer.get(50)
elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]:
elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer, PrioritizedReplayBuffer]:
data = buffer.sample(50)

# Check that all data are on the desired device