Commit 50a35f6

[Doc] Minor fixes to the docs and type hints (#2548)
1 parent: 58c3847 · commit: 50a35f6

File tree

2 files changed: +15 -12 lines


sota-implementations/ppo/ppo_atari.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 """
 This script reproduces the Proximal Policy Optimization (PPO) Algorithm
-results from Schulman et al. 2017 for the on Atari Environments.
+results from Schulman et al. 2017 for the Atari Environments.
 """
 import hydra
 from torchrl._utils import logger as torchrl_logger

torchrl/envs/common.py

Lines changed: 14 additions & 11 deletions
@@ -516,7 +516,7 @@ def append_transform(
         self,
         transform: "Transform"  # noqa: F821
         | Callable[[TensorDictBase], TensorDictBase],
-    ) -> None:
+    ) -> EnvBase:
         """Returns a transformed environment where the callable/transform passed is applied.
 
         Args:
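
With the annotation changed from None to EnvBase, the signature now reflects that append_transform hands back an environment that can be used directly. A minimal sketch of that usage (assuming gymnasium is installed so GymEnv can be built; this example is not part of the commit):

from torchrl.envs import GymEnv, StepCounter

base_env = GymEnv("CartPole-v1")  # assumes gymnasium is installed
# The returned environment carries the appended transform, so the result
# can be captured (or chained) instead of being discarded.
env = base_env.append_transform(StepCounter())
rollout = env.rollout(3)
print(rollout["next", "step_count"])  # step counts written by the transform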
@@ -1482,7 +1482,8 @@ def full_state_spec(self, spec: Composite) -> None:
 
     # Single-env specs can be used to remove the batch size from the spec
     @property
-    def batch_dims(self):
+    def batch_dims(self) -> int:
+        """Number of batch dimensions of the env."""
         return len(self.batch_size)
 
     def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec:
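
The new return annotation and one-line docstring describe a simple property: the number of leading batch dimensions of the environment. A quick sketch of what it reports (again assuming gymnasium is available; illustrative only):

from torchrl.envs import GymEnv, SerialEnv

single = GymEnv("CartPole-v1")
batched = SerialEnv(2, lambda: GymEnv("CartPole-v1"))

# batch_dims is simply len(batch_size)
print(single.batch_size, single.batch_dims)    # torch.Size([]) 0
print(batched.batch_size, batched.batch_dims)  # torch.Size([2]) 1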
@@ -2444,11 +2445,11 @@ def rollout(
         set_truncated: bool = False,
         out=None,
         trust_policy: bool = False,
-    ):
+    ) -> TensorDictBase:
         """Executes a rollout in the environment.
 
-        The function will stop as soon as one of the contained environments
-        returns done=True.
+        The function will return as soon as any of the contained environments
+        reaches any of the done states.
 
         Args:
            max_steps (int): maximum number of steps to be executed. The actual number of steps can be smaller if
@@ -2464,14 +2465,16 @@ def rollout(
                the call to ``rollout``.
 
         Keyword Args:
-            auto_reset (bool, optional): if ``True``, resets automatically the environment
-                if it is in a done state when the rollout is initiated.
-                Default is ``True``.
+            auto_reset (bool, optional): if ``True``, the contained environments will be reset before starting the
+                rollout. If ``False``, then the rollout will continue from a previous state, which requires the
+                ``tensordict`` argument to be passed with the previous rollout. Default is ``True``.
             auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to the
                 policy device before the policy is used. Default is ``False``.
-            break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is
-                called on the sub-envs that are done. Default is True.
-            break_when_all_done (bool): TODO
+            break_when_any_done (bool): if ``True``, break when any of the contained environments reaches any of the
+                done states. If ``False``, then the done environments are reset automatically. Default is ``True``.
+            break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any
+                of the done states. If ``False``, break if at least one environment reaches any of the done states.
+                Default is ``False``.
             return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
             tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
                 tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
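
The reworded keyword arguments are easiest to read side by side. The sketch below shows how they interact, assuming gymnasium is installed; the comments paraphrase the clarified docstring rather than quoting the library's own examples:

from torchrl.envs import GymEnv

env = GymEnv("CartPole-v1")

# Defaults (auto_reset=True, break_when_any_done=True): the rollout returns
# as soon as the environment reaches a done state, so it can be shorter
# than max_steps.
short = env.rollout(max_steps=1000)

# break_when_any_done=False: done environments are reset automatically and
# the rollout spans the full max_steps.
full = env.rollout(max_steps=1000, break_when_any_done=False)

# auto_reset=False: continue from a state supplied through `tensordict`,
# here the output of an explicit reset().
td0 = env.reset()
cont = env.rollout(max_steps=50, auto_reset=False, tensordict=td0)

print(short.shape, full.shape, cont.shape)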

Comments (0)