Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 21, 2025
1 parent f70bc1b commit 344e7ae
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4943,15 +4943,15 @@ def call_tokenizer_fn(self, value: str | List[str]):
if isinstance(value, str):
out = self.tokenizer.encode(value, return_tensors="pt", **kwargs)[0]
# TODO: incorporate attention mask
attention_mask = torch.ones_like(out, dtype=torch.bool)
# attention_mask = torch.ones_like(out, dtype=torch.bool)
else:
kwargs["padding"] = (
self.padding if self.max_length is None else "max_length"
)
# kwargs["return_attention_mask"] = False
# kwargs["return_token_type_ids"] = False
out = self.tokenizer.batch_encode_plus(value, return_tensors="pt", **kwargs)
attention_mask = out["attention_mask"]
# attention_mask = out["attention_mask"]
out = out["input_ids"]

if device is not None and out.device != device:
Expand Down

0 comments on commit 344e7ae

Please sign in to comment.