Commit

[Doc] Minor fixes to the docs and type hints (#2548)
thomasbbrunner authored Nov 11, 2024
1 parent 58c3847 commit 50a35f6
Showing 2 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/ppo/ppo_atari.py
@@ -5,7 +5,7 @@
 
 """
 This script reproduces the Proximal Policy Optimization (PPO) Algorithm
-results from Schulman et al. 2017 for the on Atari Environments.
+results from Schulman et al. 2017 for the Atari Environments.
 """
 import hydra
 from torchrl._utils import logger as torchrl_logger
25 changes: 14 additions & 11 deletions torchrl/envs/common.py
@@ -516,7 +516,7 @@ def append_transform(
         self,
         transform: "Transform"  # noqa: F821
         | Callable[[TensorDictBase], TensorDictBase],
-    ) -> None:
+    ) -> EnvBase:
         """Returns a transformed environment where the callable/transform passed is applied.
 
         Args:
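The change from `-> None` to `-> EnvBase` documents that `append_transform` returns the environment itself, which is what makes call chaining possible. A minimal sketch of that pattern, using a hypothetical `MockEnv` class rather than torchrl's actual `EnvBase`:

```python
# Hypothetical sketch (not torchrl's EnvBase): returning ``self`` from
# ``append_transform`` lets callers chain transform registrations, which
# is the behavior the corrected ``-> EnvBase`` annotation describes.
class MockEnv:
    def __init__(self):
        self.transforms = []

    def append_transform(self, transform) -> "MockEnv":
        """Append a callable transform and return the env for chaining."""
        self.transforms.append(transform)
        return self
```

With a `self`-returning method, `MockEnv().append_transform(f).append_transform(g)` registers both transforms in one expression instead of requiring two statements.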
@@ -1482,7 +1482,8 @@ def full_state_spec(self, spec: Composite) -> None:
 
     # Single-env specs can be used to remove the batch size from the spec
     @property
-    def batch_dims(self):
+    def batch_dims(self) -> int:
+        """Number of batch dimensions of the env."""
         return len(self.batch_size)
 
     def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec:
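The newly typed property simply counts the entries of the env's `batch_size`. A self-contained sketch of that relationship, using a hypothetical `MockBatchedEnv` class in place of the real torchrl env:

```python
# Minimal sketch (not the real torchrl EnvBase): shows how a
# ``batch_dims`` property is derived from a ``batch_size`` tuple,
# mirroring the typed property added in this diff.
class MockBatchedEnv:
    def __init__(self, batch_size: tuple = ()):
        # ``batch_size`` mimics torch.Size: () for a single env,
        # (4,) for four envs stepped in a batch, (4, 2) for a 4x2 grid.
        self.batch_size = batch_size

    @property
    def batch_dims(self) -> int:
        """Number of batch dimensions of the env."""
        return len(self.batch_size)
```

So a single env reports `batch_dims == 0`, while an env batched over a `(4, 2)` grid reports `batch_dims == 2`.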
@@ -2444,11 +2445,11 @@ def rollout(
         set_truncated: bool = False,
         out=None,
         trust_policy: bool = False,
-    ):
+    ) -> TensorDictBase:
         """Executes a rollout in the environment.
-        The function will stop as soon as one of the contained environments
-        returns done=True.
+        The function will return as soon as any of the contained environments
+        reaches any of the done states.
 
         Args:
             max_steps (int): maximum number of steps to be executed. The actual number of steps can be smaller if
@@ -2464,14 +2465,16 @@
                 the call to ``rollout``.
 
         Keyword Args:
-            auto_reset (bool, optional): if ``True``, resets automatically the environment
-                if it is in a done state when the rollout is initiated.
-                Default is ``True``.
+            auto_reset (bool, optional): if ``True``, the contained environments will be reset before starting the
+                rollout. If ``False``, then the rollout will continue from a previous state, which requires the
+                ``tensordict`` argument to be passed with the previous rollout. Default is ``True``.
             auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to the
                 policy device before the policy is used. Default is ``False``.
-            break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is
-                called on the sub-envs that are done. Default is True.
-            break_when_all_done (bool): TODO
+            break_when_any_done (bool): if ``True``, break when any of the contained environments reaches any of the
+                done states. If ``False``, then the done environments are reset automatically. Default is ``True``.
+            break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any
+                of the done states. If ``False``, break if at least one environment reaches any of the done states.
+                Default is ``False``.
             return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
             tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
                 tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
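The rewritten docstring distinguishes two stopping rules: `break_when_any_done` stops on the first done sub-env, while `break_when_all_done` waits until every sub-env is done. A hypothetical sketch of that decision logic (not torchrl's actual rollout implementation), with the flags modeled as a plain list of booleans:

```python
# Hypothetical sketch of the stopping rule described above, not the real
# torchrl code: decides whether a batched rollout loop should stop given
# the per-sub-env done flags for the current step.
def should_break(done_flags, break_when_any_done=True, break_when_all_done=False):
    """Return True if the rollout loop should stop this step."""
    if break_when_any_done and any(done_flags):
        # Stop as soon as a single contained env reaches a done state.
        return True
    if break_when_all_done and all(done_flags):
        # Stop only once every contained env has reached a done state;
        # individual done envs would be reset and kept running until then.
        return True
    return False
```

Under the defaults this stops on the first done flag; with `break_when_any_done=False, break_when_all_done=True` a rollout with flags `[True, False]` keeps going, matching the documented `False` default for `break_when_all_done`.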

1 comment on commit 50a35f6

@github-actions


⚠️ Performance Alert ⚠️

A possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous result by more than the alert threshold of 2.

Benchmark suite: benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400]
Current (50a35f6): 35.743673089958456 iter/sec (stddev: 0.1671078810074134)
Previous (58c3847): 230.47089010079605 iter/sec (stddev: 0.0008080878444520098)
Ratio: 6.45

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens
