Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 17, 2025
1 parent dc4f6e5 commit 42ce732
Showing 1 changed file with 70 additions and 8 deletions.
78 changes: 70 additions & 8 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4426,10 +4426,12 @@ class UnaryTransform(Transform):
Args:
in_keys (sequence of NestedKey): the keys of inputs to the unary operation.
out_keys (sequence of NestedKey): the keys of the outputs of the unary operation.
fn (Callable): the function to use as the unary operation. If it accepts
a non-tensor input, it must also accept ``None``.
in_keys_inv (sequence of NestedKey): the keys of inputs to the unary operation during inverse call.
out_keys_inv (sequence of NestedKey): the keys of the outputs of the unary operation durin inverse call.
Keyword Args:
fn (Callable): the function to use as the unary operation. If it accepts
a non-tensor input, it must also accept ``None``.
use_raw_nontensor (bool, optional): if ``False``, data is extracted from
:class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
Expand Down Expand Up @@ -4500,11 +4502,18 @@ def __init__(
self,
in_keys: Sequence[NestedKey],
out_keys: Sequence[NestedKey],
fn: Callable,
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
*,
fn: Callable,
use_raw_nontensor: bool = False,
):
super().__init__(in_keys=in_keys, out_keys=out_keys)
super().__init__(
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
out_keys_inv=out_keys_inv,
)
self._fn = fn
self._use_raw_nontensor = use_raw_nontensor

Expand All @@ -4519,13 +4528,50 @@ def _apply_transform(self, value):
value = value.tolist()
return self._fn(value)

def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
if not self._use_raw_nontensor:
if isinstance(state, NonTensorData):
if state.dim() == 0:
state = state.get("data")
else:
state = state.tolist()
elif isinstance(state, NonTensorStack):
state = state.tolist()
return self._fn(state)

def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
with _set_missing_tolerance(self, True):
tensordict_reset = self._call(tensordict_reset)
return tensordict_reset

def transform_input_spec(self, input_spec: Composite) -> Composite:
input_spec = input_spec.clone()

# Make a generic input from the spec, call the transform with that
# input, and then generate the output spec from the output.
zero_input_ = input_spec.zero()
test_input = zero_input_["full_action_spec"].update(
zero_input_["full_state_spec"]
)
test_output = self.inv(test_input)
test_input_spec = make_composite_from_td(
test_output, unsqueeze_null_shapes=False
)

input_spec["full_action_spec"] = self.transform_action_spec(
input_spec["full_action_spec"],
test_input_spec,
)
if "full_state_spec" in input_spec.keys():
input_spec["full_state_spec"] = self.transform_state_spec(
input_spec["full_state_spec"],
test_input_spec,
)
print(input_spec)
return input_spec

def transform_output_spec(self, output_spec: Composite) -> Composite:
output_spec = output_spec.clone()

Expand Down Expand Up @@ -4586,19 +4632,31 @@ def transform_done_spec(
) -> TensorSpec:
return self._transform_spec(done_spec, test_output_spec)

def transform_action_spec(
self, action_spec: TensorSpec, test_input_spec: TensorSpec
) -> TensorSpec:
return self._transform_spec(action_spec, test_input_spec)

def transform_state_spec(
self, state_spec: TensorSpec, test_input_spec: TensorSpec
) -> TensorSpec:
return self._transform_spec(state_spec, test_input_spec)


class Hash(UnaryTransform):
r"""Adds a hash value to a tensordict.
Args:
in_keys (sequence of NestedKey): the keys of the values to hash.
out_keys (sequence of NestedKey): the keys of the resulting hashes.
in_keys_inv (sequence of NestedKey): the keys of the values to hash during inv call.
out_keys_inv (sequence of NestedKey): the keys of the resulting hashes during inv call.
Keyword Args:
hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
the hash function must accept it as its second argument. Default is
``Hash.reproducible_hash``.
seed (optional): seed to use for the hash function, if it requires one.
Keyword Args:
use_raw_nontensor (bool, optional): if ``False``, data is extracted from
:class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
Expand Down Expand Up @@ -4684,9 +4742,11 @@ def __init__(
self,
in_keys: Sequence[NestedKey],
out_keys: Sequence[NestedKey],
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
*,
hash_fn: Callable = None,
seed: Any | None = None,
*,
use_raw_nontensor: bool = False,
):
if hash_fn is None:
Expand All @@ -4697,6 +4757,8 @@ def __init__(
super().__init__(
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
out_keys_inv=out_keys_inv,
fn=self.call_hash_fn,
use_raw_nontensor=use_raw_nontensor,
)
Expand Down Expand Up @@ -4725,7 +4787,7 @@ def reproducible_hash(cls, string, seed=None):
if seed is not None:
seeded_string = seed + string
else:
seeded_string = string
seeded_string = str(string)

# Create a new SHA-256 hash object
hash_object = hashlib.sha256()
Expand Down

0 comments on commit 42ce732

Please sign in to comment.