Skip to content

Commit 905754b

Browse files
author
Vincent Moens
committed
[Feature] break_when_all_done in rollout
ghstack-source-id: dd464da Pull Request resolved: #2381
1 parent 5f9614c commit 905754b

File tree

2 files changed

+72
-19
lines changed

2 files changed

+72
-19
lines changed

torchrl/envs/batched_envs.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,8 +1071,9 @@ def _step(
10711071
if partial_steps is not None and partial_steps.all():
10721072
partial_steps = None
10731073
if partial_steps is not None:
1074+
partial_steps = partial_steps.view(tensordict.shape)
10741075
tensordict = tensordict[partial_steps]
1075-
workers_range = partial_steps.nonzero().squeeze().tolist()
1076+
workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
10761077
tensordict_in = tensordict
10771078
# if self._use_buffers:
10781079
# shared_tensordict_parent = (
@@ -1466,8 +1467,9 @@ def _step_and_maybe_reset_no_buffers(
14661467
if partial_steps is not None and partial_steps.all():
14671468
partial_steps = None
14681469
if partial_steps is not None:
1470+
partial_steps = partial_steps.view(tensordict.shape)
14691471
tensordict = tensordict[partial_steps]
1470-
workers_range = partial_steps.nonzero().squeeze().tolist()
1472+
workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
14711473
else:
14721474
workers_range = range(self.num_workers)
14731475

@@ -1516,6 +1518,7 @@ def step_and_maybe_reset(
15161518
if partial_steps is not None and partial_steps.all():
15171519
partial_steps = None
15181520
if partial_steps is not None:
1521+
partial_steps = partial_steps.view(tensordict.shape)
15191522
shared_tensordict_parent = (
15201523
self.shared_tensordict_parent._get_sub_tensordict(partial_steps)
15211524
)
@@ -1526,7 +1529,7 @@ def step_and_maybe_reset(
15261529
partial_steps
15271530
)
15281531
tensordict = tensordict[partial_steps]
1529-
workers_range = partial_steps.nonzero().squeeze().tolist()
1532+
workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
15301533
else:
15311534
workers_range = range(self.num_workers)
15321535
shared_tensordict_parent = self.shared_tensordict_parent
@@ -1633,8 +1636,9 @@ def _step_no_buffers(
16331636
if partial_steps is not None and partial_steps.all():
16341637
partial_steps = None
16351638
if partial_steps is not None:
1639+
partial_steps = partial_steps.view(tensordict.shape)
16361640
tensordict = tensordict[partial_steps]
1637-
workers_range = partial_steps.nonzero().squeeze().tolist()
1641+
workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
16381642
else:
16391643
workers_range = range(self.num_workers)
16401644

@@ -1674,11 +1678,12 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
16741678
if partial_steps is not None and partial_steps.all():
16751679
partial_steps = None
16761680
if partial_steps is not None:
1681+
partial_steps = partial_steps.view(tensordict.shape)
16771682
shared_tensordict_parent = (
16781683
self.shared_tensordict_parent._get_sub_tensordict(partial_steps)
16791684
)
16801685
tensordict = tensordict[partial_steps]
1681-
workers_range = partial_steps.nonzero().squeeze().tolist()
1686+
workers_range = partial_steps.nonzero(as_tuple=True)[0].tolist()
16821687
else:
16831688
workers_range = range(self.num_workers)
16841689
shared_tensordict_parent = self.shared_tensordict_parent
@@ -1693,7 +1698,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
16931698
# if we have input "next" data (eg, RNNs which pass the next state)
16941699
# the sub-envs will need to process them through step_and_maybe_reset.
16951700
# We keep track of which keys are present to let the worker know what
1696-
# should be passd to the env (we don't want to pass done states for instance)
1701+
# should be passed to the env (we don't want to pass done states for instance)
16971702
next_td_keys = list(next_td_passthrough.keys(True, True))
16981703
data = [
16991704
{"next_td_passthrough_keys": next_td_keys}

torchrl/envs/common.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2317,9 +2317,11 @@ def rollout(
23172317
max_steps: int,
23182318
policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
23192319
callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
2320+
*,
23202321
auto_reset: bool = True,
23212322
auto_cast_to_device: bool = False,
2322-
break_when_any_done: bool = True,
2323+
break_when_any_done: bool | None = None,
2324+
break_when_all_done: bool | None = None,
23232325
return_contiguous: bool = True,
23242326
tensordict: Optional[TensorDictBase] = None,
23252327
set_truncated: bool = False,
@@ -2342,13 +2344,16 @@ def rollout(
23422344
TensorDict. Defaults to ``None``. The output of ``callback`` will not be collected; it is the user's
23432345
responsibility to save any result within the callback call if data needs to be carried over beyond
23442346
the call to ``rollout``.
2347+
2348+
Keyword Args:
23452349
auto_reset (bool, optional): if ``True``, automatically resets the environment
23462350
if it is in a done state when the rollout is initiated.
23472351
Default is ``True``.
23482352
auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to the
23492353
policy device before the policy is used. Default is ``False``.
23502354
break_when_any_done (bool): breaks if any of the done states is ``True``. If ``False``, a reset() is
23512355
called on the sub-envs that are done. Defaults to ``True``, unless ``break_when_all_done`` is ``True``, in which case it defaults to ``False``.
2356+
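break_when_all_done (bool): if ``True``, the rollout only breaks once every sub-env has reached a done state; until then, a ``"_partial_steps"`` mask marks the sub-envs that are not done yet, and done sub-envs keep their last collected data. Cannot be ``True`` together with ``break_when_any_done``. Defaults to ``False``.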
23522357
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
23532358
tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
23542359
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
@@ -2545,6 +2550,19 @@ def rollout(
25452550
... )
25462551
25472552
"""
2553+
if break_when_any_done is None: # True by default
2554+
if break_when_all_done: # all overrides
2555+
break_when_any_done = False
2556+
else:
2557+
break_when_any_done = True
2558+
if break_when_all_done is None:
2559+
# There is no case where break_when_all_done is True by default
2560+
break_when_all_done = False
2561+
if break_when_all_done and break_when_any_done:
2562+
raise TypeError(
2563+
"Cannot have both break_when_all_done and break_when_any_done True at the same time."
2564+
)
2565+
25482566
if policy is not None:
25492567
policy = _make_compatible_policy(
25502568
policy, self.observation_spec, env=self, fast_wrap=True
@@ -2578,8 +2596,12 @@ def rollout(
25782596
"env_device": env_device,
25792597
"callback": callback,
25802598
}
2581-
if break_when_any_done:
2582-
tensordicts = self._rollout_stop_early(**kwargs)
2599+
if break_when_any_done or break_when_all_done:
2600+
tensordicts = self._rollout_stop_early(
2601+
break_when_all_done=break_when_all_done,
2602+
break_when_any_done=break_when_any_done,
2603+
**kwargs,
2604+
)
25832605
else:
25842606
tensordicts = self._rollout_nonstop(**kwargs)
25852607
batch_size = self.batch_size if tensordict is None else tensordict.batch_size
@@ -2639,6 +2661,8 @@ def _step_mdp(self):
26392661
def _rollout_stop_early(
26402662
self,
26412663
*,
2664+
break_when_any_done,
2665+
break_when_all_done,
26422666
tensordict,
26432667
auto_cast_to_device,
26442668
max_steps,
@@ -2651,6 +2675,7 @@ def _rollout_stop_early(
26512675
if auto_cast_to_device:
26522676
sync_func = _get_sync_func(policy_device, env_device)
26532677
tensordicts = []
2678+
partial_steps = True
26542679
for i in range(max_steps):
26552680
if auto_cast_to_device:
26562681
if policy_device is not None:
@@ -2668,23 +2693,46 @@ def _rollout_stop_early(
26682693
tensordict.clear_device_()
26692694
tensordict = self.step(tensordict)
26702695
td_append = tensordict.copy()
2696+
if break_when_all_done:
2697+
if partial_steps is not True:
2698+
# At least one partial step has been done
2699+
del td_append["_partial_steps"]
2700+
td_append = torch.where(
2701+
partial_steps.view(td_append.shape), td_append, tensordicts[-1]
2702+
)
2703+
26712704
tensordicts.append(td_append)
26722705

26732706
if i == max_steps - 1:
26742707
# we don't truncate as one could potentially continue the run
26752708
break
26762709
tensordict = self._step_mdp(tensordict)
26772710

2678-
# done and truncated are in done_keys
2679-
# We read if any key is done.
2680-
any_done = _terminated_or_truncated(
2681-
tensordict,
2682-
full_done_spec=self.output_spec["full_done_spec"],
2683-
key=None,
2684-
)
2685-
2686-
if any_done:
2687-
break
2711+
if break_when_any_done:
2712+
# done and truncated are in done_keys
2713+
# We read if any key is done.
2714+
any_done = _terminated_or_truncated(
2715+
tensordict,
2716+
full_done_spec=self.output_spec["full_done_spec"],
2717+
key=None,
2718+
)
2719+
if any_done:
2720+
break
2721+
else:
2722+
_terminated_or_truncated(
2723+
tensordict,
2724+
full_done_spec=self.output_spec["full_done_spec"],
2725+
key="_partial_steps",
2726+
write_full_false=False,
2727+
)
2728+
partial_step_curr = tensordict.get("_partial_steps", None)
2729+
if partial_step_curr is not None:
2730+
partial_step_curr = ~partial_step_curr
2731+
partial_steps = partial_steps & partial_step_curr
2732+
if partial_steps is not True:
2733+
if not partial_steps.any():
2734+
break
2735+
tensordict.set("_partial_steps", partial_steps)
26882736

26892737
if callback is not None:
26902738
callback(self, tensordict)

0 commit comments

Comments
 (0)