
Commit 1224f40

Author: Vincent Moens

Commit message:
Update
[ghstack-poisoned]

Parents: 09bc3af + f5c13b1

File tree: 21 files changed (+1274 / -555 lines)

benchmarks/test_replaybuffer_benchmark.py

Lines changed: 29 additions & 12 deletions
@@ -173,23 +173,29 @@ def test_rb_populate(benchmark, rb, storage, sampler, size):
     )


-class create_tensor_rb:
-    def __init__(self, rb, storage, sampler, size=1_000_000, iters=100):
+class create_compiled_tensor_rb:
+    def __init__(
+        self, rb, storage, sampler, storage_size, data_size, iters, compilable=False
+    ):
         self.storage = storage
         self.rb = rb
         self.sampler = sampler
-        self.size = size
+        self.storage_size = storage_size
+        self.data_size = data_size
         self.iters = iters
+        self.compilable = compilable

     def __call__(self):
         kwargs = {}
         if self.sampler is not None:
             kwargs["sampler"] = self.sampler()
         if self.storage is not None:
-            kwargs["storage"] = self.storage(10 * self.size)
+            kwargs["storage"] = self.storage(
+                self.storage_size, compilable=self.compilable
+            )

-        rb = self.rb(batch_size=3, **kwargs)
-        data = torch.randn(self.size, 1)
+        rb = self.rb(batch_size=3, compilable=self.compilable, **kwargs)
+        data = torch.randn(self.data_size, 1)
         return ((rb, data, self.iters), {})


@@ -210,21 +216,32 @@ def fn(td):


 @pytest.mark.parametrize(
-    "rb,storage,sampler,size,iters,compiled",
+    "rb,storage,sampler,storage_size,data_size,iters,compiled",
     [
-        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, True],
-        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, False],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, False],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, False],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, False],
     ],
 )
-def test_rb_extend_sample(benchmark, rb, storage, sampler, size, iters, compiled):
+def test_rb_extend_sample(
+    benchmark, rb, storage, sampler, storage_size, data_size, iters, compiled
+):
+    if compiled:
+        torch._dynamo.reset_code_caches()
+
     benchmark.pedantic(
         extend_and_sample_compiled if compiled else extend_and_sample,
-        setup=create_tensor_rb(
+        setup=create_compiled_tensor_rb(
             rb=rb,
             storage=storage,
             sampler=sampler,
-            size=size,
+            storage_size=storage_size,
+            data_size=data_size,
             iters=iters,
+            compilable=compiled,
         ),
         iterations=1,
         warmup_rounds=10,
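
For context, the parametrization above exercises the new compilable flag end to end: the setup object builds a compilable storage and buffer, and the timed callable extends then samples. A minimal sketch of that pattern, assuming the torchrl API exactly as it appears in this diff (the compilable keyword on LazyTensorStorage and ReplayBuffer introduced here), with sizes mirroring one parametrized case:

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers import RandomSampler

# Buffer that opts into the torch.compile-friendly code paths
# (the `compilable` flag is the one added by this commit).
storage = LazyTensorStorage(100_000, compilable=True)
rb = ReplayBuffer(
    storage=storage, sampler=RandomSampler(), batch_size=3, compilable=True
)

@torch.compile
def extend_and_sample(data):
    rb.extend(data)
    return rb.sample()

data = torch.randn(10_000, 1)
for _ in range(100):  # mirrors iters=100 in the benchmark setup
    extend_and_sample(data)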

docs/source/reference/data.rst

Lines changed: 2 additions & 0 deletions
@@ -985,11 +985,13 @@ TorchRL offers a set of classes and functions that can be used to represent tree

     BinaryToDecimal
     HashToInt
+    MCTSForeset
     QueryModule
     RandomProjectionHash
     SipHash
     TensorDictMap
     TensorMap
+    Tree


 Reinforcement Learning From Human Feedback (RLHF)

setup.py

Lines changed: 0 additions & 1 deletion
@@ -275,7 +275,6 @@ def _main(argv):
         extras_require=extra_requires,
         zip_safe=False,
         classifiers=[
-            "Programming Language :: Python :: 3.8",
             "Programming Language :: Python :: 3.9",
             "Programming Language :: Python :: 3.10",
             "Programming Language :: Python :: 3.11",

test/test_cost.py

Lines changed: 92 additions & 1 deletion
@@ -7650,6 +7650,7 @@ def _create_mock_actor(
         observation_key="observation",
         sample_log_prob_key="sample_log_prob",
         composite_action_dist=False,
+        aggregate_probabilities=True,
     ):
         # Actor
         action_spec = Bounded(
@@ -7668,7 +7669,7 @@ def _create_mock_actor(
                 "action1": (action_key, "action1"),
             },
             log_prob_key=sample_log_prob_key,
-            aggregate_probabilities=True,
+            aggregate_probabilities=aggregate_probabilities,
         )
         module_out_keys = [
             ("params", "action1", "loc"),
@@ -8038,6 +8039,96 @@ def test_ppo(
         assert counter == 2
         actor.zero_grad()

+    @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
+    @pytest.mark.parametrize("gradient_mode", (True, False))
+    @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
+    @pytest.mark.parametrize("functional", [True, False])
+    def test_ppo_composite_no_aggregate(
+        self, loss_class, device, gradient_mode, advantage, td_est, functional
+    ):
+        torch.manual_seed(self.seed)
+        td = self._create_seq_mock_data_ppo(device=device, composite_action_dist=True)
+
+        actor = self._create_mock_actor(
+            device=device,
+            composite_action_dist=True,
+            aggregate_probabilities=False,
+        )
+        value = self._create_mock_value(device=device)
+        if advantage == "gae":
+            advantage = GAE(
+                gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
+            )
+        elif advantage == "vtrace":
+            advantage = VTrace(
+                gamma=0.9,
+                value_network=value,
+                actor_network=actor,
+                differentiable=gradient_mode,
+            )
+        elif advantage == "td":
+            advantage = TD1Estimator(
+                gamma=0.9, value_network=value, differentiable=gradient_mode
+            )
+        elif advantage == "td_lambda":
+            advantage = TDLambdaEstimator(
+                gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
+            )
+        elif advantage is None:
+            pass
+        else:
+            raise NotImplementedError
+
+        loss_fn = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            functional=functional,
+        )
+        if advantage is not None:
+            advantage(td)
+        else:
+            if td_est is not None:
+                loss_fn.make_value_estimator(td_est)
+
+        loss = loss_fn(td)
+        if isinstance(loss_fn, KLPENPPOLoss):
+            kl = loss.pop("kl_approx")
+            assert (kl != 0).any()
+
+        loss_critic = loss["loss_critic"]
+        loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
+        loss_critic.backward(retain_graph=True)
+        # check that grads are independent and non null
+        named_parameters = loss_fn.named_parameters()
+        counter = 0
+        for name, p in named_parameters:
+            if p.grad is not None and p.grad.norm() > 0.0:
+                counter += 1
+                assert "actor" not in name
+                assert "critic" in name
+            if p.grad is None:
+                assert ("actor" in name) or ("target_" in name)
+                assert ("critic" not in name) or ("target_" in name)
+        assert counter == 2
+
+        value.zero_grad()
+        loss_objective.backward()
+        counter = 0
+        named_parameters = loss_fn.named_parameters()
+        for name, p in named_parameters:
+            if p.grad is not None and p.grad.norm() > 0.0:
+                counter += 1
+                assert "actor" in name
+                assert "critic" not in name
+            if p.grad is None:
+                assert ("actor" not in name) or ("target_" in name)
+                assert ("critic" in name) or ("target_" in name)
+        assert counter == 2
+        actor.zero_grad()
+
     @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
     @pytest.mark.parametrize("gradient_mode", (True,))
     @pytest.mark.parametrize("device", get_default_devices())
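
The new test asserts that, with a composite action distribution and aggregate_probabilities=False, backpropagating the critic loss only populates critic gradients and backpropagating the objective only populates actor gradients. A minimal, self-contained sketch of that gradient-independence pattern, using hypothetical stand-in modules (plain nn.Linear layers, not the torchrl actor and critic above):

import torch
from torch import nn

actor = nn.Linear(4, 2)   # stand-in for the policy network
critic = nn.Linear(4, 1)  # stand-in for the value network
obs = torch.randn(8, 4)

loss_critic = critic(obs).pow(2).mean()
loss_actor = actor(obs).sum()

# Critic loss should leave actor parameters untouched...
loss_critic.backward(retain_graph=True)
assert all(p.grad is None for p in actor.parameters())
assert all(p.grad is not None and p.grad.norm() > 0 for p in critic.parameters())

# ...and the objective loss should only populate actor gradients.
loss_actor.backward()
assert all(p.grad is not None for p in actor.parameters())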

test/test_rb.py

Lines changed: 77 additions & 16 deletions
@@ -178,18 +178,24 @@
 )
 @pytest.mark.parametrize("size", [3, 5, 100])
 class TestComposableBuffers:
-    def _get_rb(self, rb_type, size, sampler, writer, storage):
+    def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False):

         if storage is not None:
-            storage = storage(size)
+            storage = storage(size, compilable=compilable)

         sampler_args = {}
         if sampler is samplers.PrioritizedSampler:
             sampler_args = {"max_capacity": size, "alpha": 0.8, "beta": 0.9}

         sampler = sampler(**sampler_args)
-        writer = writer()
-        rb = rb_type(storage=storage, sampler=sampler, writer=writer, batch_size=3)
+        writer = writer(compilable=compilable)
+        rb = rb_type(
+            storage=storage,
+            sampler=sampler,
+            writer=writer,
+            batch_size=3,
+            compilable=compilable,
+        )
         return rb

     def _get_datum(self, datatype):
@@ -421,8 +427,9 @@ def data_iter():
     # <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
     # Our Windows CI jobs do not have "cl", so skip this test.
     @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
+    @pytest.mark.parametrize("avoid_max_size", [False, True])
     def test_extend_sample_recompile(
-        self, rb_type, sampler, writer, storage, size, datatype
+        self, rb_type, sampler, writer, storage, size, datatype, avoid_max_size
     ):
         if rb_type is not ReplayBuffer:
             pytest.skip(
@@ -443,28 +450,36 @@ def test_extend_sample_recompile(

         torch._dynamo.reset_code_caches()

-        storage_size = 10 * size
+        # Number of times to extend the replay buffer
+        num_extend = 10
+        data_size = size
+
+        # These two cases are separated because when the max storage size is
+        # reached, the code execution path changes, causing necessary
+        # recompiles.
+        if avoid_max_size:
+            storage_size = (num_extend + 1) * data_size
+        else:
+            storage_size = 2 * data_size
+
         rb = self._get_rb(
             rb_type=rb_type,
             sampler=sampler,
             writer=writer,
             storage=storage,
             size=storage_size,
+            compilable=True,
         )
-        data_size = size
         data = self._get_data(datatype, size=data_size)

         @torch.compile
         def extend_and_sample(data):
             rb.extend(data)
             return rb.sample()

-        # Number of times to extend the replay buffer
-        num_extend = 30
-
-        # NOTE: The first two calls to 'extend' and 'sample' currently cause
-        # recompilations, so avoid capturing those for now.
-        num_extend_before_capture = 2
+        # NOTE: The first three calls to 'extend' and 'sample' can currently
+        # cause recompilations, so avoid capturing those.
+        num_extend_before_capture = 3

         for _ in range(num_extend_before_capture):
             extend_and_sample(data)
@@ -477,12 +492,12 @@ def extend_and_sample(data):
             for _ in range(num_extend - num_extend_before_capture):
                 extend_and_sample(data)

-            assert len(rb) == storage_size
-            assert len(records) == 0
-
         finally:
             torch._logging.set_logs()

+        assert len(rb) == min((num_extend * data_size), storage_size)
+        assert len(records) == 0
+
     def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
         if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
             pytest.skip(
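
The comment added above explains that reaching the storage's maximum size changes the execution path and forces recompiles, which is why avoid_max_size is parametrized separately. A conceptual sketch (a toy buffer, not torchrl internals) of the kind of branch that appears once a fixed-capacity buffer wraps around:

import torch

class TinyRingBuffer:
    """Toy fixed-capacity buffer illustrating the wrap-around branch."""

    def __init__(self, capacity: int):
        self.storage = torch.empty(capacity)
        self.cursor = 0
        self.full = False

    def extend(self, data: torch.Tensor):
        n = data.numel()
        idx = (torch.arange(n) + self.cursor) % self.storage.numel()
        self.storage[idx] = data
        self.cursor = (self.cursor + n) % self.storage.numel()
        if not self.full and self.cursor < n:
            # Wrapped around: from here on, len() takes a different branch.
            self.full = True

    def __len__(self):
        return self.storage.numel() if self.full else self.cursor

buf = TinyRingBuffer(8)
buf.extend(torch.randn(5))   # len(buf) == 5, not yet full
buf.extend(torch.randn(5))   # wraps: len(buf) == 8 from now on
print(len(buf))

Keeping the run below capacity (the avoid_max_size=True case) keeps a single branch alive for the whole sequence of extends, so no extra guards are triggered.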
@@ -806,6 +821,52 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend):
         s = new_replay_buffer.sample()
         assert (s.exclude("index") == 1).all()

+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
+    )
+    @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
+    # This test checks if the `torch._dynamo.disable` wrapper around
+    # `TensorStorage._rand_given_ndim` is still necessary.
+    def test__rand_given_ndim_recompile(self):
+        torch._dynamo.reset_code_caches()
+
+        # Number of times to extend the replay buffer
+        num_extend = 10
+        data_size = 100
+        storage_size = (num_extend + 1) * data_size
+        sample_size = 3
+
+        storage = LazyTensorStorage(storage_size, compilable=True)
+        sampler = RandomSampler()
+
+        # Override to avoid the `torch._dynamo.disable` wrapper
+        storage._rand_given_ndim = storage._rand_given_ndim_impl
+
+        @torch.compile
+        def extend_and_sample(data):
+            storage.set(torch.arange(data_size) + len(storage), data)
+            return sampler.sample(storage, sample_size)
+
+        data = torch.randint(100, (data_size, 1))
+
+        try:
+            torch._logging.set_logs(recompiles=True)
+            records = []
+            capture_log_records(records, "torch._dynamo", "recompiles")
+
+            for _ in range(num_extend):
+                extend_and_sample(data)
+
+        finally:
+            torch._logging.set_logs()
+
+        assert len(storage) == num_extend * data_size
+        assert len(records) == 8, (
+            "If this ever decreases, that's probably good news and the "
+            "`torch._dynamo.disable` wrapper around "
+            "`TensorStorage._rand_given_ndim` can be removed."
+        )
+
     @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
     def test_extend_lazystack(self, storage_type):
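
The test above is a canary for the torch._dynamo.disable wrapper around TensorStorage._rand_given_ndim: once the recompile count drops, the wrapper can go. As background, a hedged sketch of what disabling a small helper does in general (illustrative names, not torchrl code): the helper runs eagerly, so its data-dependent randomness never enters the compiled graph, at the cost of a graph break.

import torch
import torch._dynamo

@torch._dynamo.disable
def pick_indices(n: int, k: int) -> torch.Tensor:
    # Runs eagerly; Dynamo does not trace or guard on this randomness.
    return torch.randint(n, (k,))

@torch.compile
def sample_rows(x: torch.Tensor) -> torch.Tensor:
    idx = pick_indices(x.shape[0], 3)  # graph break here, by design
    return x[idx]

print(sample_rows(torch.randn(10, 4)).shape)  # torch.Size([3, 4])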
