Skip to content

Commit

Permalink
[Test] Add tests for CatFrames with PermuteTransform
Browse files Browse the repository at this point in the history
ghstack-source-id: e554d1cda8d7e4458c9397f1f93345c855e68e5c
Pull Request resolved: #2715
  • Loading branch information
kurtamohler committed Jan 24, 2025
1 parent 80690d2 commit d4e4019
Showing 1 changed file with 79 additions and 0 deletions.
79 changes: 79 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,85 @@ def test_single_trans_env_check(self, out_keys):
)
check_env_specs(env)

@pytest.mark.parametrize("cat_dim", [-1, -2, -3])
@pytest.mark.parametrize("cat_N", [3, 10])
@pytest.mark.parametrize("device", get_default_devices())
def test_with_permute_no_env(self, cat_dim, cat_N, device):
torch.manual_seed(cat_dim * cat_N)
pixels = torch.randn(8, 5, 3, 10, 4, device=device)

a = TensorDict(
{
"pixels": pixels,
},
[
pixels.shape[0],
],
device=device,
)

t0 = Compose(
CatFrames(N=cat_N, dim=cat_dim),
)

def get_rand_perm(ndim):
cat_dim_perm = cat_dim
# Ensure that the permutation moves the cat_dim
while cat_dim_perm == cat_dim:
perm_pos = torch.randperm(ndim)
perm = perm_pos - ndim
cat_dim_perm = (perm == cat_dim).nonzero().item() - ndim
perm_inv = perm_pos.argsort() - ndim
return perm.tolist(), perm_inv.tolist(), cat_dim_perm

perm, perm_inv, cat_dim_perm = get_rand_perm(pixels.dim() - 1)

t1 = Compose(
PermuteTransform(perm, in_keys=["pixels"]),
CatFrames(N=cat_N, dim=cat_dim_perm),
PermuteTransform(perm_inv, in_keys=["pixels"]),
)

b = t0._call(a.clone())
c = t1._call(a.clone())
assert (b == c).all()

@pytest.mark.skipif(not _has_gym, reason="Test executed on gym")
@pytest.mark.parametrize("cat_dim", [-1, -2])
def test_with_permute_env(self, cat_dim):
env0 = TransformedEnv(
GymEnv("Pendulum-v1"),
Compose(
UnsqueezeTransform(-1, in_keys=["observation"]),
CatFrames(N=4, dim=cat_dim, in_keys=["observation"]),
),
)

env1 = TransformedEnv(
GymEnv("Pendulum-v1"),
Compose(
UnsqueezeTransform(-1, in_keys=["observation"]),
PermuteTransform((-1, -2), in_keys=["observation"]),
CatFrames(N=4, dim=-3 - cat_dim, in_keys=["observation"]),
PermuteTransform((-1, -2), in_keys=["observation"]),
),
)

torch.manual_seed(0)
env0.set_seed(0)
td0 = env0.reset()

torch.manual_seed(0)
env1.set_seed(0)
td1 = env1.reset()

assert (td0 == td1).all()

td0 = env0.step(td0.update(env0.full_action_spec.rand()))
td1 = env0.step(td0.update(env1.full_action_spec.rand()))

assert (td0 == td1).all()

def test_serial_trans_env_check(self):
env = SerialEnv(
2,
Expand Down

0 comments on commit d4e4019

Please sign in to comment.