Skip to content

Commit

Permalink
[Feature] Add Stack transform
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtamohler committed Nov 14, 2024
1 parent a4c1ee3 commit 6fdc2b7
Show file tree
Hide file tree
Showing 5 changed files with 419 additions and 2 deletions.
157 changes: 155 additions & 2 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
from typing import Dict, List, Optional

import torch
import torch.nn as nn
Expand All @@ -24,7 +24,12 @@
from torchrl.data.utils import consolidate_spec
from torchrl.envs.common import EnvBase
from torchrl.envs.model_based.common import ModelBasedEnvBase
from torchrl.envs.utils import _terminated_or_truncated
from torchrl.envs.utils import (
_terminated_or_truncated,
check_marl_grouping,
MarlGroupMapType,
)


spec_dict = {
"bounded": Bounded,
Expand Down Expand Up @@ -1055,6 +1060,154 @@ def _step(
return tensordict


class MultiAgentCountingEnv(EnvBase):
    """A multi-agent env that is done after a given number of steps.

    All agents have identical specs.
    The count is incremented by 1 on each step, and every agent reports
    ``done``/``terminated`` as soon as the count is strictly greater than
    ``max_steps``.

    Args:
        n_agents: number of agents. Agents are named ``agent_0`` to
            ``agent_{n_agents - 1}``.
        group_map: either a :class:`MarlGroupMapType` member (resolved to a
            concrete mapping through ``get_group_map``) or an explicit
            ``{group_name: [agent_name, ...]}`` mapping. The resulting mapping
            is validated with ``check_marl_grouping``.
        max_steps: count threshold above which the env reports done.
        start_val: value the counter is set to on reset.
        **kwargs: forwarded to :class:`EnvBase`.
    """

    def __init__(
        self,
        n_agents: int,
        # Quoted on purpose: an unquoted ``A | B`` union is evaluated when the
        # function is defined and raises ``TypeError`` on Python < 3.10.
        group_map: "MarlGroupMapType | Dict[str, List[str]]" = (
            MarlGroupMapType.ALL_IN_ONE_GROUP
        ),
        max_steps: int = 5,
        start_val: int = 0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.max_steps = max_steps
        self.start_val = start_val
        self.n_agents = n_agents
        self.agent_names = [f"agent_{idx}" for idx in range(n_agents)]

        if isinstance(group_map, MarlGroupMapType):
            group_map = group_map.get_group_map(self.agent_names)
        check_marl_grouping(group_map, self.agent_names)

        self.group_map = group_map

        observation_specs = {}
        reward_specs = {}
        done_specs = {}
        action_specs = {}

        # Specs are nested as ``spec[group_name][agent_name]``; every agent
        # gets the same leaf specs.
        for group_name, agents in group_map.items():
            observation_specs[group_name] = {}
            reward_specs[group_name] = {}
            done_specs[group_name] = {}
            action_specs[group_name] = {}

            for agent_name in agents:
                observation_specs[group_name][agent_name] = Composite(
                    observation=Unbounded(
                        (
                            *self.batch_size,
                            3,
                            4,
                        ),
                        dtype=torch.float32,
                        device=self.device,
                    ),
                    shape=self.batch_size,
                    device=self.device,
                )
                reward_specs[group_name][agent_name] = Composite(
                    reward=Unbounded(
                        (
                            *self.batch_size,
                            1,
                        ),
                        device=self.device,
                    ),
                    shape=self.batch_size,
                    device=self.device,
                )
                done_specs[group_name][agent_name] = Composite(
                    done=Categorical(
                        2,
                        dtype=torch.bool,
                        shape=(
                            *self.batch_size,
                            1,
                        ),
                        device=self.device,
                    ),
                    shape=self.batch_size,
                    device=self.device,
                )
                action_specs[group_name][agent_name] = Composite(
                    action=Binary(n=1, shape=[*self.batch_size, 1], device=self.device),
                    shape=self.batch_size,
                    device=self.device,
                )

        self.observation_spec = Composite(observation_specs)
        self.reward_spec = Composite(reward_specs)
        self.done_spec = Composite(done_specs)
        self.action_spec = Composite(action_specs)
        # A single counter of shape ``(*batch_size, 1)`` shared by all agents.
        self.register_buffer(
            "count",
            torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
        )

    def _set_seed(self, seed: Optional[int]):
        """Seed the global torch RNG used to draw observations."""
        torch.manual_seed(seed)

    def _make_tensordict(self, include_reward: bool) -> TensorDictBase:
        """Build the ``{group: {agent: data}}`` tensordict for the current count.

        Each agent gets a fresh random ``observation`` and ``done`` /
        ``terminated`` flags computed as ``count > max_steps``. A zero
        ``reward`` shaped like the counter is added when ``include_reward``
        is true.
        """
        source = {}
        for group_name, agents in self.group_map.items():
            source[group_name] = {}
            for agent_name in agents:
                agent_data = {
                    "observation": torch.rand(
                        (*self.batch_size, 3, 4), device=self.device
                    ),
                    "done": self.count > self.max_steps,
                    "terminated": self.count > self.max_steps,
                }
                if include_reward:
                    agent_data["reward"] = torch.zeros_like(
                        self.count, dtype=torch.float
                    )
                source[group_name][agent_name] = TensorDict(
                    source=agent_data,
                    batch_size=self.batch_size,
                    device=self.device,
                )
        return TensorDict(source, batch_size=self.batch_size, device=self.device)

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Set the counter back to ``start_val`` and return fresh agent data.

        When ``tensordict`` carries a ``"_reset"`` entry, only the counter
        elements it selects are reset; otherwise the whole counter is.
        The returned tensordict has no reward entries.
        """
        if tensordict is not None and "_reset" in tensordict.keys():
            _reset = tensordict.get("_reset")
            self.count[_reset] = self.start_val
        else:
            self.count[:] = self.start_val

        return self._make_tensordict(include_reward=False)

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        """Increment the counter by 1 and return agent data with zero rewards.

        The input ``tensordict`` (and therefore the action) is not read.
        """
        self.count += 1
        return self._make_tensordict(include_reward=True)


class IncrementingEnv(CountingEnv):
# Same as CountingEnv but always increments the count by 1 regardless of the action.
def _step(
Expand Down
118 changes: 118 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
IncrementingEnv,
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
MultiAgentCountingEnv,
MultiKeyCountingEnv,
MultiKeyCountingEnvPolicy,
NestedCountingEnv,
Expand All @@ -69,6 +70,7 @@
IncrementingEnv,
MockBatchedLockedEnv,
MockBatchedUnLockedEnv,
MultiAgentCountingEnv,
MultiKeyCountingEnv,
MultiKeyCountingEnvPolicy,
NestedCountingEnv,
Expand Down Expand Up @@ -132,6 +134,7 @@
SerialEnv,
SignTransform,
SqueezeTransform,
Stack,
StepCounter,
TargetReturn,
TensorDictPrimer,
Expand Down Expand Up @@ -2136,6 +2139,121 @@ def test_transform_no_env(self, device, batch):
pytest.skip("TrajCounter cannot be called without env")


class TestStack(TransformBase):
    """Tests for the ``Stack`` transform: forward, inverse and env usage."""

    def test_stack_tensors(self):
        """Stack two tensor entries along dim 1, then undo it with ``inv``."""
        reference = TensorDict(
            {name: torch.rand(1, 3) for name in ("key1", "key2", "key3")},
            [1],
        )
        stack = Stack(
            in_keys=[("key1",), ("key2",)],
            out_key=("stacked",),
            dim=1,
        )
        td = stack(reference.clone())

        # The stacked inputs are consumed; untouched entries survive.
        for consumed_key in (("key1",), ("key2",)):
            assert consumed_key not in td.keys()
        for present_key in (("key3",), ("stacked",)):
            assert present_key in td.keys()

        assert td["stacked"].shape == torch.Size([1, 2, 3])
        for position, name in enumerate(("key1", "key2")):
            assert (td["stacked"][:, position] == reference[name]).all()

        restored = stack.inv(td)
        assert (restored == reference).all()

    def test_stack_tensordicts(self):
        """Stack two nested tensordicts along dim 0, then undo it with ``inv``."""

        def make_entry():
            return TensorDict(
                {
                    "a": torch.rand(3),
                    "b": torch.rand(2, 4),
                }
            )

        reference = TensorDict(
            {name: make_entry() for name in ("key1", "key2", "key3")},
            [],
        )
        stack = Stack(
            in_keys=[("key1",), ("key2",)],
            out_key=("stacked",),
            dim=0,
        )
        td = stack(reference.clone())

        for consumed_key in (("key1",), ("key2",)):
            assert consumed_key not in td.keys()
        for leaf in ("a", "b"):
            assert ("stacked", leaf) in td.keys(include_nested=True)
        assert ("key3",) in td.keys()

        # The new leading dim of size 2 comes from the two stacked entries.
        assert td["stacked", "a"].shape == torch.Size([2, 3])
        assert td["stacked", "b"].shape == torch.Size([2, 2, 4])
        for position, name in enumerate(("key1", "key2")):
            assert (td["stacked"][position] == reference[name]).all()
        assert (td["key3"] == reference["key3"]).all()

        restored = stack.inv(td)
        assert (restored == reference).all()

    def test_stack_env(self):
        """Stack three of five agents of a multi-agent env and roll it out."""
        stacked_agents = ("agent_0", "agent_2", "agent_3")
        untouched_agents = ("agent_1", "agent_4")

        base_env = MultiAgentCountingEnv(
            n_agents=5,
        )
        check_env_specs(base_env)

        stack = Stack(
            in_keys=[("agents", agent) for agent in stacked_agents],
            out_key="stacked_agents",
        )
        env = TransformedEnv(base_env, stack)
        check_env_specs(env)

        # Same seed on both envs so the reset outputs can be compared.
        base_env.set_seed(123)
        td_base = base_env.reset()
        env.set_seed(123)
        td = env.reset()

        nested_keys = td.keys(include_nested=True)

        for agent in stacked_agents:
            assert ("agents", agent) not in nested_keys
        for agent in untouched_agents:
            assert ("agents", agent) in nested_keys
        assert ("stacked_agents",) in nested_keys

        for position, agent in enumerate(stacked_agents):
            assert (td["stacked_agents"][position] == td_base["agents", agent]).all()
        for agent in untouched_agents:
            assert (td["agents", agent] == td_base["agents", agent]).all()

        # A single step with a random action must run before the rollout.
        env.step(env.full_action_spec.rand())
        rollout = env.rollout(6)

        stacked_done = rollout["next", "stacked_agents", "done"]
        assert stacked_done.shape == torch.Size([6, 3, 1])
        assert not (stacked_done[:-1]).any()
        assert (stacked_done[-1]).all()


class TestCatTensors(TransformBase):
@pytest.mark.parametrize("append", [True, False])
def test_cattensors_empty(self, append):
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
SelectTransform,
SignTransform,
SqueezeTransform,
Stack,
StepCounter,
TargetReturn,
TensorDictPrimer,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
SelectTransform,
SignTransform,
SqueezeTransform,
Stack,
StepCounter,
TargetReturn,
TensorDictPrimer,
Expand Down
Loading

0 comments on commit 6fdc2b7

Please sign in to comment.