diff --git a/test/test_transforms.py b/test/test_transforms.py index ec413b2b34c..6a57d4faa1e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -689,6 +689,85 @@ def test_single_trans_env_check(self, out_keys): ) check_env_specs(env) + @pytest.mark.parametrize("cat_dim", [-1, -2, -3]) + @pytest.mark.parametrize("cat_N", [3, 10]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_with_permute_no_env(self, cat_dim, cat_N, device): + torch.manual_seed(cat_dim * cat_N) + pixels = torch.randn(8, 5, 3, 10, 4, device=device) + + a = TensorDict( + { + "pixels": pixels, + }, + [ + pixels.shape[0], + ], + device=device, + ) + + t0 = Compose( + CatFrames(N=cat_N, dim=cat_dim), + ) + + def get_rand_perm(ndim): + cat_dim_perm = cat_dim + # Ensure that the permutation moves the cat_dim + while cat_dim_perm == cat_dim: + perm_pos = torch.randperm(ndim) + perm = perm_pos - ndim + cat_dim_perm = (perm == cat_dim).nonzero().item() - ndim + perm_inv = perm_pos.argsort() - ndim + return perm.tolist(), perm_inv.tolist(), cat_dim_perm + + perm, perm_inv, cat_dim_perm = get_rand_perm(pixels.dim() - 1) + + t1 = Compose( + PermuteTransform(perm, in_keys=["pixels"]), + CatFrames(N=cat_N, dim=cat_dim_perm), + PermuteTransform(perm_inv, in_keys=["pixels"]), + ) + + b = t0._call(a.clone()) + c = t1._call(a.clone()) + assert (b == c).all() + + @pytest.mark.skipif(not _has_gym, reason="Test executed on gym") + @pytest.mark.parametrize("cat_dim", [-1, -2]) + def test_with_permute_env(self, cat_dim): + env0 = TransformedEnv( + GymEnv("Pendulum-v1"), + Compose( + UnsqueezeTransform(-1, in_keys=["observation"]), + CatFrames(N=4, dim=cat_dim, in_keys=["observation"]), + ), + ) + + env1 = TransformedEnv( + GymEnv("Pendulum-v1"), + Compose( + UnsqueezeTransform(-1, in_keys=["observation"]), + PermuteTransform((-1, -2), in_keys=["observation"]), + CatFrames(N=4, dim=-3 - cat_dim, in_keys=["observation"]), + PermuteTransform((-1, -2), in_keys=["observation"]), + ), + ) + + torch.manual_seed(0) + env0.set_seed(0) + td0 = env0.reset() + + torch.manual_seed(0) + env1.set_seed(0) + td1 = env1.reset() + + assert (td0 == td1).all() + + td0 = env0.step(td0.update(env0.full_action_spec.rand())) + td1 = env0.step(td0.update(env1.full_action_spec.rand())) + + assert (td0 == td1).all() + def test_serial_trans_env_check(self): env = SerialEnv( 2,