
Commit 1224f40: Update

[ghstack-poisoned]

vmoens committed Nov 8, 2024
2 parents 09bc3af + f5c13b1
Showing 21 changed files with 1,274 additions and 555 deletions.
41 changes: 29 additions & 12 deletions benchmarks/test_replaybuffer_benchmark.py
@@ -173,23 +173,29 @@ def test_rb_populate(benchmark, rb, storage, sampler, size):
 )
 
 
-class create_tensor_rb:
-    def __init__(self, rb, storage, sampler, size=1_000_000, iters=100):
+class create_compiled_tensor_rb:
+    def __init__(
+        self, rb, storage, sampler, storage_size, data_size, iters, compilable=False
+    ):
         self.storage = storage
         self.rb = rb
         self.sampler = sampler
-        self.size = size
+        self.storage_size = storage_size
+        self.data_size = data_size
         self.iters = iters
+        self.compilable = compilable
 
     def __call__(self):
         kwargs = {}
         if self.sampler is not None:
             kwargs["sampler"] = self.sampler()
         if self.storage is not None:
-            kwargs["storage"] = self.storage(10 * self.size)
+            kwargs["storage"] = self.storage(
+                self.storage_size, compilable=self.compilable
+            )
 
-        rb = self.rb(batch_size=3, **kwargs)
-        data = torch.randn(self.size, 1)
+        rb = self.rb(batch_size=3, compilable=self.compilable, **kwargs)
+        data = torch.randn(self.data_size, 1)
         return ((rb, data, self.iters), {})

@@ -210,21 +216,32 @@ def fn(td):


 @pytest.mark.parametrize(
-    "rb,storage,sampler,size,iters,compiled",
+    "rb,storage,sampler,storage_size,data_size,iters,compiled",
     [
-        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, True],
-        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1000, 100, False],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 10_000, 10_000, 100, False],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 100_000, 10_000, 100, False],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, True],
+        [ReplayBuffer, LazyTensorStorage, RandomSampler, 1_000_000, 10_000, 100, False],
     ],
 )
-def test_rb_extend_sample(benchmark, rb, storage, sampler, size, iters, compiled):
+def test_rb_extend_sample(
+    benchmark, rb, storage, sampler, storage_size, data_size, iters, compiled
+):
     if compiled:
         torch._dynamo.reset_code_caches()
 
     benchmark.pedantic(
         extend_and_sample_compiled if compiled else extend_and_sample,
-        setup=create_tensor_rb(
+        setup=create_compiled_tensor_rb(
             rb=rb,
             storage=storage,
             sampler=sampler,
-            size=size,
+            storage_size=storage_size,
+            data_size=data_size,
             iters=iters,
+            compilable=compiled,
         ),
         iterations=1,
         warmup_rounds=10,
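Note on the benchmark harness: `benchmark.pedantic` comes from pytest-benchmark, and when a `setup` callable is supplied it is invoked before each round and must return an `(args, kwargs)` pair that is splatted into the benchmarked target (which is also why `iterations=1` is required above). A minimal sketch of that contract, with `make_extend_args` and `extend_only` as illustrative names rather than code from this commit:

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer


def make_extend_args():
    # Called once per benchmark round; builds a fresh buffer so every round
    # starts from the same state, then returns (args, kwargs) for the target.
    rb = ReplayBuffer(storage=LazyTensorStorage(10_000), batch_size=3)
    data = torch.randn(1_000, 1)
    return (rb, data), {}


def extend_only(rb, data):
    rb.extend(data)
    return rb.sample()


def test_extend_benchmark(benchmark):
    # iterations must be 1 when setup provides the arguments.
    benchmark.pedantic(extend_only, setup=make_extend_args, iterations=1, rounds=10)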
2 changes: 2 additions & 0 deletions docs/source/reference/data.rst
@@ -985,11 +985,13 @@ TorchRL offers a set of classes and functions that can be used to represent tree

     BinaryToDecimal
     HashToInt
+    MCTSForeset
     QueryModule
     RandomProjectionHash
     SipHash
     TensorDictMap
     TensorMap
+    Tree
 
 
 Reinforcement Learning From Human Feedback (RLHF)
1 change: 0 additions & 1 deletion setup.py
@@ -275,7 +275,6 @@ def _main(argv):
         extras_require=extra_requires,
         zip_safe=False,
         classifiers=[
-            "Programming Language :: Python :: 3.8",
             "Programming Language :: Python :: 3.9",
             "Programming Language :: Python :: 3.10",
             "Programming Language :: Python :: 3.11",
93 changes: 92 additions & 1 deletion test/test_cost.py
@@ -7650,6 +7650,7 @@ def _create_mock_actor(
observation_key="observation",
sample_log_prob_key="sample_log_prob",
composite_action_dist=False,
aggregate_probabilities=True,
):
# Actor
action_spec = Bounded(
@@ -7668,7 +7669,7 @@
"action1": (action_key, "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
aggregate_probabilities=aggregate_probabilities,
)
module_out_keys = [
("params", "action1", "loc"),
@@ -8038,6 +8039,96 @@ def test_ppo(
         assert counter == 2
         actor.zero_grad()
 
+    @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
+    @pytest.mark.parametrize("gradient_mode", (True, False))
+    @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
+    @pytest.mark.parametrize("functional", [True, False])
+    def test_ppo_composite_no_aggregate(
+        self, loss_class, device, gradient_mode, advantage, td_est, functional
+    ):
+        torch.manual_seed(self.seed)
+        td = self._create_seq_mock_data_ppo(device=device, composite_action_dist=True)
+
+        actor = self._create_mock_actor(
+            device=device,
+            composite_action_dist=True,
+            aggregate_probabilities=False,
+        )
+        value = self._create_mock_value(device=device)
+        if advantage == "gae":
+            advantage = GAE(
+                gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
+            )
+        elif advantage == "vtrace":
+            advantage = VTrace(
+                gamma=0.9,
+                value_network=value,
+                actor_network=actor,
+                differentiable=gradient_mode,
+            )
+        elif advantage == "td":
+            advantage = TD1Estimator(
+                gamma=0.9, value_network=value, differentiable=gradient_mode
+            )
+        elif advantage == "td_lambda":
+            advantage = TDLambdaEstimator(
+                gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode
+            )
+        elif advantage is None:
+            pass
+        else:
+            raise NotImplementedError
+
+        loss_fn = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            functional=functional,
+        )
+        if advantage is not None:
+            advantage(td)
+        else:
+            if td_est is not None:
+                loss_fn.make_value_estimator(td_est)
+
+        loss = loss_fn(td)
+        if isinstance(loss_fn, KLPENPPOLoss):
+            kl = loss.pop("kl_approx")
+            assert (kl != 0).any()
+
+        loss_critic = loss["loss_critic"]
+        loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
+        loss_critic.backward(retain_graph=True)
+        # check that grads are independent and non null
+        named_parameters = loss_fn.named_parameters()
+        counter = 0
+        for name, p in named_parameters:
+            if p.grad is not None and p.grad.norm() > 0.0:
+                counter += 1
+                assert "actor" not in name
+                assert "critic" in name
+            if p.grad is None:
+                assert ("actor" in name) or ("target_" in name)
+                assert ("critic" not in name) or ("target_" in name)
+        assert counter == 2
+
+        value.zero_grad()
+        loss_objective.backward()
+        counter = 0
+        named_parameters = loss_fn.named_parameters()
+        for name, p in named_parameters:
+            if p.grad is not None and p.grad.norm() > 0.0:
+                counter += 1
+                assert "actor" in name
+                assert "critic" not in name
+            if p.grad is None:
+                assert ("actor" not in name) or ("target_" in name)
+                assert ("critic" in name) or ("target_" in name)
+        assert counter == 2
+        actor.zero_grad()
+
     @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
     @pytest.mark.parametrize("gradient_mode", (True,))
     @pytest.mark.parametrize("device", get_default_devices())
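Note on `aggregate_probabilities`: with a composite action distribution, `True` produces a single joint log-probability (the sum over components) under the configured `log_prob_key`, while `False` keeps one log-probability entry per component, which is what the new `test_ppo_composite_no_aggregate` exercises. A sketch of the distinction in plain `torch.distributions` terms, illustrating the concept only rather than tensordict's `CompositeDistribution` API:

import torch
from torch.distributions import Normal

# Two independent components of a composite action.
dist1 = Normal(torch.zeros(4), torch.ones(4))
dist2 = Normal(torch.zeros(2), torch.ones(2))
action1, action2 = dist1.sample(), dist2.sample()

# aggregate_probabilities=False: one log-prob entry per component.
log_probs = {
    "action1_log_prob": dist1.log_prob(action1).sum(-1),
    "action2_log_prob": dist2.log_prob(action2).sum(-1),
}

# aggregate_probabilities=True: a single joint log-prob; summing is valid
# because the components are independent, so the joint density factorizes.
joint_log_prob = sum(log_probs.values())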
93 changes: 77 additions & 16 deletions test/test_rb.py
@@ -178,18 +178,24 @@
 )
 @pytest.mark.parametrize("size", [3, 5, 100])
 class TestComposableBuffers:
-    def _get_rb(self, rb_type, size, sampler, writer, storage):
+    def _get_rb(self, rb_type, size, sampler, writer, storage, compilable=False):
 
         if storage is not None:
-            storage = storage(size)
+            storage = storage(size, compilable=compilable)
 
         sampler_args = {}
         if sampler is samplers.PrioritizedSampler:
             sampler_args = {"max_capacity": size, "alpha": 0.8, "beta": 0.9}
 
         sampler = sampler(**sampler_args)
-        writer = writer()
-        rb = rb_type(storage=storage, sampler=sampler, writer=writer, batch_size=3)
+        writer = writer(compilable=compilable)
+        rb = rb_type(
+            storage=storage,
+            sampler=sampler,
+            writer=writer,
+            batch_size=3,
+            compilable=compilable,
+        )
         return rb
 
     def _get_datum(self, datatype):
@@ -421,8 +427,9 @@ def data_iter():
     # <https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/_inductor/cpp_builder.py#L143>
     # Our Windows CI jobs do not have "cl", so skip this test.
     @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
+    @pytest.mark.parametrize("avoid_max_size", [False, True])
     def test_extend_sample_recompile(
-        self, rb_type, sampler, writer, storage, size, datatype
+        self, rb_type, sampler, writer, storage, size, datatype, avoid_max_size
     ):
         if rb_type is not ReplayBuffer:
             pytest.skip(
@@ -443,28 +450,36 @@ def test_extend_sample_recompile(

         torch._dynamo.reset_code_caches()
 
-        storage_size = 10 * size
+        # Number of times to extend the replay buffer
+        num_extend = 10
+        data_size = size
+
+        # These two cases are separated because when the max storage size is
+        # reached, the code execution path changes, causing necessary
+        # recompiles.
+        if avoid_max_size:
+            storage_size = (num_extend + 1) * data_size
+        else:
+            storage_size = 2 * data_size
+
         rb = self._get_rb(
             rb_type=rb_type,
             sampler=sampler,
             writer=writer,
             storage=storage,
             size=storage_size,
+            compilable=True,
         )
-        data_size = size
         data = self._get_data(datatype, size=data_size)
 
         @torch.compile
         def extend_and_sample(data):
             rb.extend(data)
             return rb.sample()
 
-        # Number of times to extend the replay buffer
-        num_extend = 30
-
-        # NOTE: The first two calls to 'extend' and 'sample' currently cause
-        # recompilations, so avoid capturing those for now.
-        num_extend_before_capture = 2
+        # NOTE: The first three calls to 'extend' and 'sample' can currently
+        # cause recompilations, so avoid capturing those.
+        num_extend_before_capture = 3
 
         for _ in range(num_extend_before_capture):
             extend_and_sample(data)
@@ -477,12 +492,12 @@ def extend_and_sample(data):
             for _ in range(num_extend - num_extend_before_capture):
                 extend_and_sample(data)
 
-            assert len(rb) == storage_size
-            assert len(records) == 0
-
         finally:
             torch._logging.set_logs()
 
+        assert len(rb) == min((num_extend * data_size), storage_size)
+        assert len(records) == 0
+
     def test_sample(self, rb_type, sampler, writer, storage, size, datatype):
         if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows:
             pytest.skip(
@@ -806,6 +821,52 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend):
         s = new_replay_buffer.sample()
         assert (s.exclude("index") == 1).all()
 
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0"
+    )
+    @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile")
+    # This test checks if the `torch._dynamo.disable` wrapper around
+    # `TensorStorage._rand_given_ndim` is still necessary.
+    def test__rand_given_ndim_recompile(self):
+        torch._dynamo.reset_code_caches()
+
+        # Number of times to extend the replay buffer
+        num_extend = 10
+        data_size = 100
+        storage_size = (num_extend + 1) * data_size
+        sample_size = 3
+
+        storage = LazyTensorStorage(storage_size, compilable=True)
+        sampler = RandomSampler()
+
+        # Override to avoid the `torch._dynamo.disable` wrapper
+        storage._rand_given_ndim = storage._rand_given_ndim_impl
+
+        @torch.compile
+        def extend_and_sample(data):
+            storage.set(torch.arange(data_size) + len(storage), data)
+            return sampler.sample(storage, sample_size)
+
+        data = torch.randint(100, (data_size, 1))
+
+        try:
+            torch._logging.set_logs(recompiles=True)
+            records = []
+            capture_log_records(records, "torch._dynamo", "recompiles")
+
+            for _ in range(num_extend):
+                extend_and_sample(data)
+
+        finally:
+            torch._logging.set_logs()
+
+        assert len(storage) == num_extend * data_size
+        assert len(records) == 8, (
+            "If this ever decreases, that's probably good news and the "
+            "`torch._dynamo.disable` wrapper around "
+            "`TensorStorage._rand_given_ndim` can be removed."
+        )
+
     @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
     def test_extend_lazystack(self, storage_type):

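Note on the recompile counting used by both new tests: `torch._logging.set_logs(recompiles=True)` enables the dynamo recompile artifact logs, and the repo's `capture_log_records` helper collects them so the tests can assert on their count. A rough sketch of such a helper built on the standard `logging` module, assuming the recompile messages flow through the `torch._dynamo` logger hierarchy (the actual helper in torchrl's test utilities may differ):

import logging


class _ListHandler(logging.Handler):
    # Collects every record passed to emit(); callers can filter on
    # record.getMessage() afterwards.
    def __init__(self, records):
        super().__init__()
        self.records = records

    def emit(self, record):
        self.records.append(record)


def capture_log_records(records, logger_name):
    # Stand-in for the torchrl test helper: attach a collecting handler to
    # the named logger and return it so it can be removed later.
    handler = _ListHandler(records)
    logging.getLogger(logger_name).addHandler(handler)
    return handler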
