Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 20, 2024
1 parent 133d709 commit 72441e6
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 63 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False):
)

# copy action from the input tensordict to the output
transformed_env.append_transform(TensorDictPrimer(action=base_env.action_spec))
transformed_env.append_transform(TensorDictPrimer(base_env.full_action_spec))

transformed_env.append_transform(DoubleToFloat())
obsnorm = ObservationNorm(
Expand Down
105 changes: 80 additions & 25 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import functools
import os

import pytest
Expand All @@ -12,6 +13,7 @@
import torchrl.modules
from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list
from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
from tensordict.utils import assert_close
from torch import nn
from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
from torchrl.envs import (
Expand Down Expand Up @@ -938,10 +940,12 @@ def test_multi_consecutive(self, shape, python_based):
@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
@pytest.mark.parametrize("within", [False, True])
def test_lstm_parallel_env(self, python_based, parallel, heterogeneous, within):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

torch.manual_seed(0)
num_envs = 3
device = "cuda" if torch.cuda.device_count() else "cpu"
# tests that hidden states are carried over with parallel envs
lstm_module = LSTMModule(
Expand All @@ -958,25 +962,36 @@ def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
else:
cls = SerialEnv

def create_transformed_env():
primer = lstm_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
if within:

def create_transformed_env():
primer = lstm_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

else:
create_transformed_env = functools.partial(
DiscreteActionVecMockEnv,
categorical_action_encoding=True,
device=device,
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

if heterogeneous:
create_transformed_env = [
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env) for _ in range(num_envs)
]
env = cls(
create_env_fn=create_transformed_env,
num_workers=2,
num_workers=num_envs,
)
if not within:
env = env.append_transform(InitTracker())
env.append_transform(lstm_module.make_tensordict_primer())

mlp = TensorDictModule(
MLP(
Expand All @@ -1002,6 +1017,19 @@ def create_transformed_env():
data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
assert (data.get(("next", "recurrent_state_c")) != 0.0).all()
assert (data.get("recurrent_state_c") != 0.0).any()
return data

@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_lstm_parallel_within(self, python_based, parallel, heterogeneous):
    """Check that the LSTM rollout is identical whether the primer transform
    lives inside each worker env (within=True) or on the outer batched env
    (within=False)."""
    # Run the parametrized env test once per placement mode, within-first.
    rollouts = [
        self.test_lstm_parallel_env(
            python_based, parallel, heterogeneous, within=flag
        )
        for flag in (True, False)
    ]
    # Both placements must produce the same data, element-wise.
    assert_close(rollouts[0], rollouts[1])

@pytest.mark.skipif(
not _has_functorch, reason="vmap can only be used with functorch"
Expand Down Expand Up @@ -1330,10 +1358,12 @@ def test_multi_consecutive(self, shape, python_based):
@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
@pytest.mark.parametrize("within", [False, True])
def test_gru_parallel_env(self, python_based, parallel, heterogeneous, within):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

torch.manual_seed(0)
num_workers = 3

device = "cuda" if torch.cuda.device_count() else "cpu"
# tests that hidden states are carried over with parallel envs
Expand All @@ -1347,30 +1377,42 @@ def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
python_based=python_based,
)

def create_transformed_env():
primer = gru_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
if within:

def create_transformed_env():
primer = gru_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

else:
create_transformed_env = functools.partial(
DiscreteActionVecMockEnv,
categorical_action_encoding=True,
device=device,
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

if parallel:
cls = ParallelEnv
else:
cls = SerialEnv
if heterogeneous:
create_transformed_env = [
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env) for _ in range(num_workers)
]

env = cls(
env: ParallelEnv | SerialEnv = cls(
create_env_fn=create_transformed_env,
num_workers=2,
num_workers=num_workers,
)
if not within:
primer = gru_module.make_tensordict_primer()
env = env.append_transform(InitTracker())
env.append_transform(primer)

mlp = TensorDictModule(
MLP(
Expand All @@ -1396,6 +1438,19 @@ def create_transformed_env():
data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
assert (data.get("recurrent_state") != 0.0).any()
assert (data.get(("next", "recurrent_state")) != 0.0).all()
return data

@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_gru_parallel_within(self, python_based, parallel, heterogeneous):
    """Check that the GRU rollout is identical whether the primer transform
    lives inside each worker env (within=True) or on the outer batched env
    (within=False)."""
    # Run the parametrized env test once per placement mode, within-first.
    rollouts = [
        self.test_gru_parallel_env(
            python_based, parallel, heterogeneous, within=flag
        )
        for flag in (True, False)
    ]
    # Both placements must produce the same data, element-wise.
    assert_close(rollouts[0], rollouts[1])

@pytest.mark.skipif(
not _has_functorch, reason="vmap can only be used with functorch"
Expand Down
45 changes: 36 additions & 9 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7408,7 +7408,7 @@ def make_env():
def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
env = TransformedEnv(
maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=Unbounded([2, 4])),
TensorDictPrimer(mykey=Unbounded([4])),
)
try:
check_env_specs(env)
Expand All @@ -7423,11 +7423,39 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
pass

@pytest.mark.parametrize("spec_shape", [[4], [2, 4]])
def test_trans_serial_env_check(self, spec_shape):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=Unbounded(spec_shape)),
)
@pytest.mark.parametrize("expand_specs", [True, False, None])
def test_trans_serial_env_check(self, spec_shape, expand_specs):
if expand_specs is None:
with pytest.warns(FutureWarning, match=""):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(
mykey=Unbounded(spec_shape), expand_specs=expand_specs
),
)
env.observation_spec
elif expand_specs is True:
shape = spec_shape[:-1]
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(
Composite(mykey=Unbounded(spec_shape), shape=shape),
expand_specs=expand_specs,
),
)
else:
# If we don't expand, we can't use [4]
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(
mykey=Unbounded(spec_shape), expand_specs=expand_specs
),
)
if spec_shape == [4]:
with pytest.raises(ValueError):
env.observation_spec
return

check_env_specs(env)
assert "mykey" in env.reset().keys()
r = env.rollout(3)
Expand Down Expand Up @@ -10310,9 +10338,8 @@ def _make_transform_env(self, out_key, base_env):
transform = KLRewardTransform(actor, out_keys=out_key)
return Compose(
TensorDictPrimer(
primers={
"sample_log_prob": Unbounded(shape=base_env.action_spec.shape[:-1])
}
sample_log_prob=Unbounded(shape=base_env.action_spec.shape[:-1]),
shape=base_env.shape,
),
transform,
)
Expand Down
Loading

0 comments on commit 72441e6

Please sign in to comment.