[Feature] Disk storage for replay buffers (#155)
* add disk storage option for ReplayBuffer

* remove use of tmp folder

* add physical storage to Benchmark class

* docs

* amend

* amend

* amend

---------

Co-authored-by: Matteo Bettini <[email protected]>
JoseLuisC99 and matteobettini authored Jan 26, 2025
1 parent fd41c1e commit 6382cf8
Showing 3 changed files with 26 additions and 7 deletions.
19 changes: 16 additions & 3 deletions benchmarl/algorithms/common.py
@@ -5,6 +5,7 @@
 #
 
 import pathlib
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type
@@ -13,6 +14,7 @@
 from tensordict.nn import TensorDictModule, TensorDictSequential
 from torchrl.data import (
     Categorical,
+    LazyMemmapStorage,
     LazyTensorStorage,
     OneHot,
     ReplayBuffer,
@@ -162,6 +164,7 @@ def get_replay_buffer(
         memory_size = -(-memory_size // sequence_length)
         sampling_size = -(-sampling_size // sequence_length)
 
+        # Sampler
         if self.on_policy:
             sampler = SamplerWithoutReplacement()
         elif self.experiment_config.off_policy_use_prioritized_replay_buffer:
@@ -173,11 +176,21 @@
         else:
             sampler = RandomSampler()
 
-        return TensorDictReplayBuffer(
-            storage=LazyTensorStorage(
+        # Storage
+        if self.buffer_device == "disk" and not self.on_policy:
+            storage = LazyMemmapStorage(
+                memory_size,
+                device=self.device,
+                scratch_dir=self.experiment.folder_name / f"buffer_{group}",
+            )
+        else:
+            storage = LazyTensorStorage(
                 memory_size,
                 device=self.device if self.on_policy else self.buffer_device,
-            ),
+            )
+
+        return TensorDictReplayBuffer(
+            storage=storage,
             sampler=sampler,
             batch_size=sampling_size,
             priority_key=(group, "td_error"),
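The storage hunk above is the core of the change: when `buffer_device` is `"disk"` and the algorithm is off-policy, the replay buffer is backed by TorchRL's `LazyMemmapStorage`, which memory-maps the stored transitions to files under a scratch directory, instead of `LazyTensorStorage`, which keeps them in RAM or VRAM. A minimal standalone sketch of the two backends (buffer size, keys, and the scratch path are illustrative, not BenchMARL's):

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import RandomSampler

memory_size = 1_000

# In-memory storage: transitions live on the chosen device (RAM or VRAM).
ram_storage = LazyTensorStorage(memory_size, device="cpu")

# Disk-backed storage: transitions are memory-mapped files under scratch_dir,
# so only small views of the buffer ever sit in RAM.
disk_storage = LazyMemmapStorage(
    memory_size,
    device="cpu",
    scratch_dir="buffer_example",  # illustrative path
)

buffer = TensorDictReplayBuffer(
    storage=disk_storage,  # swap in ram_storage for the in-memory equivalent
    sampler=RandomSampler(),
    batch_size=32,
)

# Fill with dummy transitions and sample a batch, exactly as with in-memory storage.
buffer.extend(TensorDict({"observation": torch.randn(100, 4)}, batch_size=[100]))
batch = buffer.sample()
print(batch["observation"].shape)  # torch.Size([32, 4])
```

Because sampling reads transparently from the memory-mapped files, large off-policy buffers no longer have to fit in memory; the trade-off is slower access from disk.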
3 changes: 2 additions & 1 deletion benchmarl/conf/experiment/base_experiment.yaml
@@ -6,7 +6,8 @@ defaults:
 sampling_device: "cpu"
 # The device for training (e.g. cuda)
 train_device: "cpu"
-# The device for the replay buffer of off-policy algorithms (e.g. cuda)
+# The device for the replay buffer of off-policy algorithms (e.g. cuda).
+# Use "disk" to store it on disk (in the experiment save_folder)
 buffer_device: "cpu"
 
 # Whether to share the parameters of the policy within agent groups
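With the option documented in the base config, enabling disk-backed buffers should only require setting `buffer_device: "disk"` for an off-policy algorithm. A sketch of how that might look from the Python API, assuming the usual BenchMARL pattern of loading configs from YAML (the task, algorithm, and model choices here are illustrative):

```python
from benchmarl.algorithms import IqlConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.mlp import MlpConfig

experiment_config = ExperimentConfig.get_from_yaml()
experiment_config.buffer_device = "disk"  # store the off-policy replay buffer on disk

experiment = Experiment(
    task=VmasTask.BALANCE.get_from_yaml(),
    algorithm_config=IqlConfig.get_from_yaml(),  # off-policy, so the buffer matters
    model_config=MlpConfig.get_from_yaml(),
    seed=0,
    config=experiment_config,
)
experiment.run()
```

As the YAML comment says, the memory-mapped buffer files are written inside the experiment's save folder, one `buffer_<group>` directory per agent group.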
11 changes: 8 additions & 3 deletions benchmarl/experiment/experiment.py
@@ -10,6 +10,7 @@
 import importlib
 
 import os
+import shutil
 import time
 from collections import deque, OrderedDict
 from dataclasses import dataclass, MISSING
@@ -360,16 +361,16 @@ def on_policy(self) -> bool:
     def _setup(self):
         self.config.validate(self.on_policy)
         seed_everything(self.seed)
-        self._perfrom_checks()
+        self._perform_checks()
         self._set_action_type()
+        self._setup_name()
         self._setup_task()
         self._setup_algorithm()
         self._setup_collector()
-        self._setup_name()
         self._setup_logger()
         self._on_setup()
 
-    def _perfrom_checks(self):
+    def _perform_checks(self):
         for config in (self.model_config, self.critic_model_config):
             if isinstance(config, SequenceModelConfig):
                 for layer_config in config.model_configs[1:]:
@@ -766,6 +767,10 @@ def close(self):
         self.test_env.close()
         self.logger.finish()
 
+        for buffer in self.replay_buffers.values():
+            if hasattr(buffer.storage, "scratch_dir"):
+                shutil.rmtree(buffer.storage.scratch_dir, ignore_errors=False)
+
     def _get_excluded_keys(self, group: str):
         excluded_keys = []
         for other_group in self.group_map.keys():
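The new lines in `close()` exist because, unlike the in-memory storage, the memory-mapped files that `LazyMemmapStorage` writes under `scratch_dir` persist on disk after the buffer object is gone, so the experiment has to delete them explicitly. A small illustration of that behaviour (path, sizes, and keys are made up):

```python
import pathlib
import shutil

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

scratch_dir = pathlib.Path("buffer_scratch")  # illustrative path
buffer = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(100, scratch_dir=scratch_dir),
)
buffer.extend(TensorDict({"obs": torch.zeros(10, 2)}, batch_size=[10]))

# The backing files stay on disk until removed; the
# hasattr(buffer.storage, "scratch_dir") check in close() is meant to
# apply this cleanup to disk-backed buffers only.
print(list(scratch_dir.rglob("*"))[:3])  # memmap files created by the storage
shutil.rmtree(scratch_dir, ignore_errors=True)
```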
