[Bugfix] enhance offset_mapping to fit more tokenizers (PaddlePaddle#3858)

* enable offset_mapping

* fix bert-like offset-mapping
wj-Mcat authored Nov 24, 2022
1 parent 0d931ad commit 0d0eb1e
Showing 5 changed files with 97 additions and 25 deletions.
26 changes: 26 additions & 0 deletions paddlenlp/transformers/bart/tokenizer.py
@@ -406,3 +406,29 @@ def convert_tokens_to_string(self, tokens):
text = bytearray([self.byte_decoder[c]
for c in text]).decode('utf-8', errors=self.errors)
return text

def build_offset_mapping_with_special_tokens(self,
offset_mapping_0,
offset_mapping_1=None):
"""
Build offset map from a pair of offset map by concatenating and adding offsets of special tokens.
A BERT offset_mapping has the following format:
- single sequence: ``(0,0) X (0,0)``
- pair of sequences: ``(0,0) A (0,0) B (0,0)``
Args:
offset_mapping_ids_0 (List[tuple]):
List of wordpiece offsets to which the special tokens will be added.
offset_mapping_ids_1 (List[tuple], optional):
Optional second list of wordpiece offsets for offset mapping pairs. Defaults to None.
Returns:
List[tuple]: A list of wordpiece offsets with the appropriate offsets of special tokens.
"""
if offset_mapping_1 is None:
return [(0, 0)] + offset_mapping_0 + [(0, 0)]

return [(0, 0)] + offset_mapping_0 + [(0, 0), (0, 0)] + offset_mapping_1 + [(0, 0)]
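
For illustration, a minimal sketch of how the added method combines offset mappings. The offset lists below are made up, and the "bart-base" checkpoint name is an assumption for the example; the printed results follow directly from the concatenation in the method above.

from paddlenlp.transformers import BartTokenizer

# Hypothetical wordpiece offsets for two short sequences.
offsets_a = [(0, 5), (6, 11)]  # e.g. "Hello", "world"
offsets_b = [(0, 4), (5, 8)]   # e.g. "Nice", "day"

tokenizer = BartTokenizer.from_pretrained("bart-base")  # assumed checkpoint name

# Single sequence, <s> A </s>:
print(tokenizer.build_offset_mapping_with_special_tokens(offsets_a))
# [(0, 0), (0, 5), (6, 11), (0, 0)]

# Pair of sequences, <s> A </s></s> B </s>:
print(tokenizer.build_offset_mapping_with_special_tokens(offsets_a, offsets_b))
# [(0, 0), (0, 5), (6, 11), (0, 0), (0, 0), (0, 4), (5, 8), (0, 0)]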
87 changes: 70 additions & 17 deletions paddlenlp/transformers/tokenizer_utils.py
@@ -1318,11 +1318,10 @@ def _batch_prepare_for_model(
batch_outputs_list[i][k] = v[i]
return batch_outputs_list

def get_offset_mapping(self, text):
def _get_bert_like_offset_mapping(self, text: str):
"""
Returns the offset map of tokens, i.e. the start and end character indices of each token in the original text.
Modified from https://github.com/bojone/bert4keras/blob/master/bert4keras/tokenizers.py#L372
Args:
text (str):
Input text.
@@ -1332,21 +1331,7 @@ def get_offset_mapping(self, text):
"""
if text is None:
return None
split_tokens = []
if hasattr(self, "basic_tokenizer"):
for token in self.basic_tokenizer.tokenize(
text, never_split=self.all_special_tokens):
# If the token is part of the never_split set
if token in self.basic_tokenizer.never_split:
split_tokens.append(token)
else:
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(
sub_token if sub_token != self.unk_token else token)
else:
for sub_token in self.wordpiece_tokenizer.tokenize(text):
split_tokens.append(
sub_token if sub_token != self.unk_token else text)
split_tokens = self.tokenize(text)

normalized_text, char_mapping = '', []

@@ -1393,6 +1378,74 @@

return token_mapping

def get_offset_mapping(self,
text: str,
split_tokens: Optional[List[str]] = None):
"""
Returns the offset map of tokens, i.e. the start and end character indices of each token in the original text.
Modified from https://github.com/bojone/bert4keras/blob/master/bert4keras/tokenizers.py#L372
Args:
text (str):
Input text.
split_tokens (Optional[List[str]]):
Tokens produced by a previous call to ``tokenize``; passing them avoids re-tokenizing ``text`` and speeds up the operation.
Returns:
list: The offset map of input text.
"""
if text is None:
return None

# bert-like tokenizers use the legacy offset-mapping code path
if hasattr(self, "basic_tokenizer") or hasattr(self, "wordpiece_tokenizer"):
return self._get_bert_like_offset_mapping(text)

if not split_tokens:
split_tokens = self.tokenize(text)

normalized_text, char_mapping = '', []

for i, ch in enumerate(text):
normalized_text += normalize_chars(ch)
char_mapping.extend([i] * len(ch))

text, token_mapping, offset = normalized_text, [], 0
do_lower_case = getattr(self, 'do_lower_case', False)

# lower-case the text if the tokenizer does, so it stays aligned with the tokens
if do_lower_case:
text = text.lower()

for token in split_tokens:

# convert the token back into its original string form
token: str = self.convert_tokens_to_string(token).strip()

if token in self.all_special_tokens:
if do_lower_case:
token = token.lower()

# The Greek letter sigma has two lowercase forms, σ and ς.
# The final form (ς) is used when sigma is the last letter of a word; otherwise σ is used.
# https://latin.stackexchange.com/questions/6168/how-and-when-did-we-get-two-forms-of-sigma
if "σ" in token or "ς" in token:
start = text[offset:].replace("ς", "σ").index(
token.replace("ς", "σ")) + offset
else:
start = text[offset:].index(token) + offset

end = start + len(token)

token_mapping.append(
(char_mapping[start], char_mapping[end - 1] + 1))
offset = end

return token_mapping

def _decode(self,
token_ids: List[int],
skip_special_tokens: bool = False,
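
Stepping out of the diff for a moment: the core idea of the new generic get_offset_mapping is to walk through the tokens, convert each one back to a string, locate it in the normalized (and optionally lower-cased) text starting from the current offset, and record the resulting character span. Below is a minimal standalone sketch of that idea, with the special-token, character-normalization and sigma handling stripped out; it is not the library implementation.

from typing import List, Tuple

def simple_offset_mapping(text: str,
                          tokens: List[str],
                          do_lower_case: bool = True) -> List[Tuple[int, int]]:
    # Lower-case the text so it stays aligned with lower-cased tokens.
    haystack = text.lower() if do_lower_case else text
    mapping, offset = [], 0
    for token in tokens:
        # Search only from the current offset so repeated tokens are
        # matched left to right; assumes every token occurs verbatim.
        start = haystack.index(token, offset)
        end = start + len(token)
        mapping.append((start, end))
        offset = end
    return mapping

print(simple_offset_mapping("Hello world", ["hello", "world"]))
# [(0, 5), (6, 11)]

The full method additionally maps ς to σ in both the token and the text slice before searching, since lower-casing can yield either the medial (σ) or final (ς) sigma form depending on the token's position in a word.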
2 changes: 0 additions & 2 deletions tests/transformers/albert/test_tokenizer.py
@@ -19,7 +19,6 @@
from paddlenlp.transformers.albert.tokenizer import (
AlbertEnglishTokenizer,
AlbertChineseTokenizer,
AlbertTokenizer,
)
from paddlenlp.transformers.bert.tokenizer import (
BasicTokenizer,
@@ -37,7 +36,6 @@ class AlbertEnglishTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
from_pretrained_vocab_key = "sentencepiece_model_file"
test_sentencepiece = True
test_sentencepiece_ignore_case = True
test_offsets = False

def setUp(self):
super().setUp()
6 changes: 1 addition & 5 deletions tests/transformers/bart/test_tokenizer.py
@@ -15,8 +15,6 @@
import json
import os
import unittest
import tempfile
import shutil

from paddlenlp.transformers import BartTokenizer

@@ -31,6 +29,7 @@
class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = BartTokenizer
test_rust_tokenizer = False

from_pretrained_filter = filter_roberta_detectors

# from_pretrained_kwargs = {'add_prefix_space': True}
@@ -176,6 +175,3 @@ def test_special_tokens(self):

def test_pretokenized_inputs(self):
pass

def test_offsets_mapping(self):
pass
1 change: 0 additions & 1 deletion tests/transformers/gpt/test_tokenizer.py
@@ -32,7 +32,6 @@ class GPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = GPTTokenizer
from_pretrained_kwargs = {"add_prefix_space": True}
test_seq2seq = False
test_offsets = False

def setUp(self):
super().setUp()
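
The test_offsets and test_offsets_mapping opt-outs removed in the test files above are no longer needed because the generic get_offset_mapping path now covers byte-level BPE tokenizers such as GPT and BART. A hedged usage sketch follows; the "gpt2-en" checkpoint name is assumed, and the exact spans depend on the vocabulary.

from paddlenlp.transformers import GPTTokenizer

tokenizer = GPTTokenizer.from_pretrained("gpt2-en")  # assumed checkpoint name
print(tokenizer.get_offset_mapping("Hello world"))
# indicative output: [(0, 5), (6, 11)] - the leading space of the second
# BPE token (" world") is stripped before its span is located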
