[Feature] Tokenizer transform #2701

1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
@@ -94,6 +94,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     Transform,
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
@@ -55,6 +55,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     Transform,
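Taken together, the two ``__init__.py`` changes above re-export the new class at both package levels. A minimal sketch of the resulting import surface (not itself part of the diff):

    from torchrl.envs import Tokenizer              # top-level export added above
    from torchrl.envs.transforms import Tokenizer   # same class, exported here too
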
80 changes: 75 additions & 5 deletions torchrl/envs/transforms/transforms.py
@@ -4426,8 +4426,8 @@ class UnaryTransform(Transform):
     Args:
         in_keys (sequence of NestedKey): the keys of inputs to the unary operation.
         out_keys (sequence of NestedKey): the keys of the outputs of the unary operation.
-        in_keys_inv (sequence of NestedKey): the keys of inputs to the unary operation during inverse call.
-        out_keys_inv (sequence of NestedKey): the keys of the outputs of the unary operation durin inverse call.
+        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the unary operation during inverse call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the unary operation during inverse call.
 
     Keyword Args:
         fn (Callable): the function to use as the unary operation. If it accepts
@@ -4569,7 +4569,6 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
             input_spec["full_state_spec"],
             test_input_spec,
         )
-        print(input_spec)
         return input_spec

def transform_output_spec(self, output_spec: Composite) -> Composite:
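For context, ``UnaryTransform`` (which the new ``Tokenizer`` below subclasses) applies the ``fn`` keyword argument to each entry named in ``in_keys`` and writes the result under ``out_keys``. A minimal sketch under that reading; the key names and the length-counting function are hypothetical:

    import torch
    from torchrl.envs.transforms import UnaryTransform

    # Hypothetical unary op: map a string observation to its length.
    t = UnaryTransform(
        in_keys=["observation"],
        out_keys=["obs_len"],
        fn=lambda s: torch.tensor(len(s)),
    )
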
@@ -4649,8 +4648,8 @@ class Hash(UnaryTransform):
     Args:
         in_keys (sequence of NestedKey): the keys of the values to hash.
         out_keys (sequence of NestedKey): the keys of the resulting hashes.
-        in_keys_inv (sequence of NestedKey): the keys of the values to hash during inv call.
-        out_keys_inv (sequence of NestedKey): the keys of the resulting hashes during inv call.
+        in_keys_inv (sequence of NestedKey, optional): the keys of the values to hash during inv call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the resulting hashes during inv call.
 
     Keyword Args:
         hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
@@ -4801,6 +4800,77 @@ def reproducible_hash(cls, string, seed=None):
         return torch.frombuffer(hash_bytes, dtype=torch.uint8)
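
From the context line above, ``Hash.reproducible_hash`` returns its digest as a ``torch.uint8`` tensor. A short sketch of what that implies, assuming the classmethod signature shown in the hunk header; the input string is arbitrary:

    import torch
    from torchrl.envs.transforms import Hash

    digest = Hash.reproducible_hash("some observation text")
    assert digest.dtype == torch.uint8  # matches torch.frombuffer(hash_bytes, dtype=torch.uint8)
    # Same input, same digest: reproducible across calls, per the method's name.
    assert (digest == Hash.reproducible_hash("some observation text")).all()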


class Tokenizer(UnaryTransform):
    r"""Applies a tokenization operation on the specified inputs.

    Args:
        in_keys (sequence of NestedKey): the keys of inputs to the tokenization operation.
        out_keys (sequence of NestedKey): the keys of the outputs of the tokenization operation.
        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the tokenization operation during inverse call.
        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the tokenization operation during inverse call.

    Keyword Args:
        tokenizer (transformers.PreTrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``,
            "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a
            pre-trained tokenizer.
        use_raw_nontensor (bool, optional): if ``False``, data is extracted from
            :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before the tokenization
            function is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
            inputs are given directly to the tokenization function, which must support those inputs. Default is ``False``.
        additional_tokens (List[str], optional): list of additional tokens to add to the tokenizer's vocabulary.
    """

    def __init__(
        self,
        in_keys: Sequence[NestedKey],
        out_keys: Sequence[NestedKey],
        in_keys_inv: Sequence[NestedKey] | None = None,
        out_keys_inv: Sequence[NestedKey] | None = None,
        *,
        tokenizer: "transformers.PreTrainedTokenizerBase" = None,  # noqa: F821
        use_raw_nontensor: bool = False,
        additional_tokens: List[str] | None = None,
    ):
        if tokenizer is None:
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        elif isinstance(tokenizer, str):
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(tokenizer)

        self.tokenizer = tokenizer
        if additional_tokens:
            self.tokenizer.add_tokens(additional_tokens)
        super().__init__(
            in_keys=in_keys,
            out_keys=out_keys,
            in_keys_inv=in_keys_inv,
            out_keys_inv=out_keys_inv,
            fn=self.call_tokenizer_fn,
            use_raw_nontensor=use_raw_nontensor,
        )

    @property
    def device(self):
        # Cache the device after the first lookup so repeated calls are cheap.
        if "_device" in self.__dict__:
            return self._device
        parent = self.parent
        if parent is None:
            return None
        device = parent.device
        self._device = device
        return device

    def call_tokenizer_fn(self, value: str | List[str]):
        device = self.device
        # Encode to token ids as a torch.Tensor, then move the result to the
        # transform's device if one is available.
        out = self.tokenizer.encode(value, return_tensors="pt")
        if device is not None and out.device != device:
            out = out.to(device)
        return out
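
A minimal usage sketch of the transform defined above, applied to a standalone tensordict (assuming, as with other torchrl transforms, that it can be called directly; the key names and input string are hypothetical):

    from tensordict import NonTensorData, TensorDict
    from torchrl.envs.transforms import Tokenizer

    t = Tokenizer(in_keys=["text"], out_keys=["tokens"])  # defaults to "bert-base-uncased"
    td = TensorDict({"text": NonTensorData("the cat sat on the mat")}, batch_size=[])
    td = t(td)
    td["tokens"]  # token ids from tokenizer.encode(..., return_tensors="pt"), on the parent device if set

The same instance can be appended to a ``TransformedEnv`` so that text observations are tokenized at each step.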


class Stack(Transform):
"""Stacks tensors and tensordicts.
