@@ -2317,9 +2317,11 @@ def rollout(
2317
2317
max_steps : int ,
2318
2318
policy : Optional [Callable [[TensorDictBase ], TensorDictBase ]] = None ,
2319
2319
callback : Optional [Callable [[TensorDictBase , ...], Any ]] = None ,
2320
+ * ,
2320
2321
auto_reset : bool = True ,
2321
2322
auto_cast_to_device : bool = False ,
2322
- break_when_any_done : bool = True ,
2323
+ break_when_any_done : bool | None = None ,
2324
+ break_when_all_done : bool | None = None ,
2323
2325
return_contiguous : bool = True ,
2324
2326
tensordict : Optional [TensorDictBase ] = None ,
2325
2327
set_truncated : bool = False ,
@@ -2342,13 +2344,16 @@ def rollout(
2342
2344
TensorDict. Defaults to ``None``. The output of ``callback`` will not be collected, it is the user
2343
2345
responsibility to save any result within the callback call if data needs to be carried over beyond
2344
2346
the call to ``rollout``.
2347
+
2348
+ Keyword Args:
2345
2349
auto_reset (bool, optional): if ``True``, resets automatically the environment
2346
2350
if it is in a done state when the rollout is initiated.
2347
2351
Default is ``True``.
2348
2352
auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to the
2349
2353
policy device before the policy is used. Default is ``False``.
2350
2354
break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is
2351
2355
called on the sub-envs that are done. Default is True.
2356
+ break_when_all_done (bool): TODO
2352
2357
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
2353
2358
tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
2354
2359
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
@@ -2545,6 +2550,19 @@ def rollout(
2545
2550
... )
2546
2551
2547
2552
"""
2553
+ if break_when_any_done is None : # True by default
2554
+ if break_when_all_done : # all overrides
2555
+ break_when_any_done = False
2556
+ else :
2557
+ break_when_any_done = True
2558
+ if break_when_all_done is None :
2559
+ # There is no case where break_when_all_done is True by default
2560
+ break_when_all_done = False
2561
+ if break_when_all_done and break_when_any_done :
2562
+ raise TypeError (
2563
+ "Cannot have both break_when_all_done and break_when_any_done True at the same time."
2564
+ )
2565
+
2548
2566
if policy is not None :
2549
2567
policy = _make_compatible_policy (
2550
2568
policy , self .observation_spec , env = self , fast_wrap = True
@@ -2578,8 +2596,12 @@ def rollout(
2578
2596
"env_device" : env_device ,
2579
2597
"callback" : callback ,
2580
2598
}
2581
- if break_when_any_done :
2582
- tensordicts = self ._rollout_stop_early (** kwargs )
2599
+ if break_when_any_done or break_when_all_done :
2600
+ tensordicts = self ._rollout_stop_early (
2601
+ break_when_all_done = break_when_all_done ,
2602
+ break_when_any_done = break_when_any_done ,
2603
+ ** kwargs ,
2604
+ )
2583
2605
else :
2584
2606
tensordicts = self ._rollout_nonstop (** kwargs )
2585
2607
batch_size = self .batch_size if tensordict is None else tensordict .batch_size
@@ -2639,6 +2661,8 @@ def _step_mdp(self):
2639
2661
def _rollout_stop_early (
2640
2662
self ,
2641
2663
* ,
2664
+ break_when_any_done ,
2665
+ break_when_all_done ,
2642
2666
tensordict ,
2643
2667
auto_cast_to_device ,
2644
2668
max_steps ,
@@ -2651,6 +2675,7 @@ def _rollout_stop_early(
2651
2675
if auto_cast_to_device :
2652
2676
sync_func = _get_sync_func (policy_device , env_device )
2653
2677
tensordicts = []
2678
+ partial_steps = True
2654
2679
for i in range (max_steps ):
2655
2680
if auto_cast_to_device :
2656
2681
if policy_device is not None :
@@ -2668,23 +2693,46 @@ def _rollout_stop_early(
2668
2693
tensordict .clear_device_ ()
2669
2694
tensordict = self .step (tensordict )
2670
2695
td_append = tensordict .copy ()
2696
+ if break_when_all_done :
2697
+ if partial_steps is not True :
2698
+ # At least one partial step has been done
2699
+ del td_append ["_partial_steps" ]
2700
+ td_append = torch .where (
2701
+ partial_steps .view (td_append .shape ), td_append , tensordicts [- 1 ]
2702
+ )
2703
+
2671
2704
tensordicts .append (td_append )
2672
2705
2673
2706
if i == max_steps - 1 :
2674
2707
# we don't truncate as one could potentially continue the run
2675
2708
break
2676
2709
tensordict = self ._step_mdp (tensordict )
2677
2710
2678
- # done and truncated are in done_keys
2679
- # We read if any key is done.
2680
- any_done = _terminated_or_truncated (
2681
- tensordict ,
2682
- full_done_spec = self .output_spec ["full_done_spec" ],
2683
- key = None ,
2684
- )
2685
-
2686
- if any_done :
2687
- break
2711
+ if break_when_any_done :
2712
+ # done and truncated are in done_keys
2713
+ # We read if any key is done.
2714
+ any_done = _terminated_or_truncated (
2715
+ tensordict ,
2716
+ full_done_spec = self .output_spec ["full_done_spec" ],
2717
+ key = None ,
2718
+ )
2719
+ if any_done :
2720
+ break
2721
+ else :
2722
+ _terminated_or_truncated (
2723
+ tensordict ,
2724
+ full_done_spec = self .output_spec ["full_done_spec" ],
2725
+ key = "_partial_steps" ,
2726
+ write_full_false = False ,
2727
+ )
2728
+ partial_step_curr = tensordict .get ("_partial_steps" , None )
2729
+ if partial_step_curr is not None :
2730
+ partial_step_curr = ~ partial_step_curr
2731
+ partial_steps = partial_steps & partial_step_curr
2732
+ if partial_steps is not True :
2733
+ if not partial_steps .any ():
2734
+ break
2735
+ tensordict .set ("_partial_steps" , partial_steps )
2688
2736
2689
2737
if callback is not None :
2690
2738
callback (self , tensordict )
0 commit comments