Skip to content

Commit

Permalink
[Feature] single_<attr>_spec
Browse files Browse the repository at this point in the history
ghstack-source-id: 27e247ea1775e455999a114dd6d95fac748376c4
Pull Request resolved: #2549
  • Loading branch information
vmoens committed Nov 11, 2024
1 parent 19dbeeb commit 58c3847
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 0 deletions.
16 changes: 16 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3510,6 +3510,22 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
assert (td[3].get("next") != 0).any()


def test_single_env_spec():
env = NestedCountingEnv(batch_size=[3, 1, 7])
assert not env.single_full_action_spec.shape
assert not env.single_full_done_spec.shape
assert not env.single_input_spec.shape
assert not env.single_full_observation_spec.shape
assert not env.single_output_spec.shape
assert not env.single_full_reward_spec.shape

assert env.single_action_spec.shape
assert env.single_reward_spec.shape

assert env.output_spec.is_in(env.single_output_spec.zeros(env.shape))
assert env.input_spec.is_in(env.single_input_spec.zeros(env.shape))


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
71 changes: 71 additions & 0 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,77 @@ def full_state_spec(self) -> Composite:
def full_state_spec(self, spec: Composite) -> None:
self.state_spec = spec

# Single-env specs can be used to remove the batch size from the spec
@property
def batch_dims(self):
return len(self.batch_size)

def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec:
if not self.batch_dims:
return spec
idx = tuple(0 for _ in range(self.batch_dims))
return spec[idx]

@property
def single_full_action_spec(self) -> Composite:
"""Returns the action spec of the env as if it had no batch dimensions."""
return self._make_single_env_spec(self.full_action_spec)

@property
def single_action_spec(self) -> TensorSpec:
"""Returns the action spec of the env as if it had no batch dimensions."""
return self._make_single_env_spec(self.action_spec)

@property
def single_full_observation_spec(self) -> Composite:
"""Returns the observation spec of the env as if it had no batch dimensions."""
return self._make_single_env_spec(self.full_action_spec)

@property
def single_observation_spec(self) -> Composite:
"""Returns the observation spec of the env as if it had no batch dimensions."""
return self._make_single_env_spec(self.observation_spec)

@property
def single_full_reward_spec(self) -> Composite:
"""Returns the reward spec of the env as if it had no batch dimensions."""
return self._make_single_env_spec(self.full_action_spec)

@property
def single_reward_spec(self) -> TensorSpec:
"""Returns the reward spec of the env as if it had no batch dimensions."""
return self._make_single_env_spec(self.reward_spec)

@property
def single_full_done_spec(self) -> Composite:
"""Returns the done spec of the env as if it had no batch dimensions."""
return self._make_single_env_spec(self.full_action_spec)

@property
def single_done_spec(self) -> TensorSpec:
"""Returns the done spec of the env as if it had no batch dimensions."""
return self._make_single_env_spec(self.done_spec)

@property
def single_output_spec(self) -> Composite:
"""Returns the output spec of the env as if it had no batch dimensions."""
return self._make_single_env_spec(self.output_spec)

@property
def single_input_spec(self) -> Composite:
"""Returns the input spec of the env as if it had no batch dimensions."""
return self._make_single_env_spec(self.input_spec)

@property
def single_full_state_spec(self) -> Composite:
"""Returns the state spec of the env as if it had no batch dimensions."""
return self._make_single_env_spec(self.full_state_spec)

@property
def single_state_spec(self) -> TensorSpec:
"""Returns the state spec of the env as if it had no batch dimensions."""
return self._make_single_env_spec(self.state_spec)

def step(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Makes a step in the environment.
Expand Down

1 comment on commit 58c3847

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 58c3847 Previous: 19dbeeb Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] 508.2675237221923 iter/sec (stddev: 0.03932827345715761) 1974.0783457474377 iter/sec (stddev: 0.00008492815988513649) 3.88

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.