-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathtokenizer.py
349 lines (272 loc) · 10.5 KB
/
tokenizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
from abc import ABC, abstractmethod
import itertools
import os
import regex as re
import string
import sentencepiece as spm
import tiktoken
import torch
from tiktoken.load import load_tiktoken_bpe
from transformers import AutoTokenizer
from pathlib import Path
from typing import (
Dict,
List,
Literal,
TypedDict,
)
# Device used by the tensor-building helpers in this module when the caller
# does not pass one: "cuda" if PyTorch can see a GPU, otherwise "cpu".
default_device = "cpu"
if torch.cuda.is_available():
    default_device = "cuda"
# Matches strings made up solely of whitespace and ASCII punctuation
# (``string.punctuation``). Compiled once at import time because
# ``is_punc_id`` is applied to every entry of a tokenizer vocabulary.
_PUNC_OR_SPACE_RE = re.compile(rf"[\s{re.escape(string.punctuation)}]*")


def is_punc_id(text):
    """Return True if ``text`` consists only of whitespace and ASCII punctuation.

    Despite the name, this takes the *string* of a token (e.g. one vocabulary
    entry), not a token id. The empty string also returns True, because the
    pattern accepts zero characters.

    Args:
        text: The string to classify.

    Returns:
        bool: True when every character is whitespace or in
        ``string.punctuation``.
    """
    return _PUNC_OR_SPACE_RE.fullmatch(text) is not None
class TokenizerInterface(ABC):
    """Common interface implemented by every tokenizer backend in this module.

    Subclasses wrap a concrete tokenizer and must set ``self.vocab`` during
    ``__init__`` to a list of per-token strings indexed by token id.
    """

    def __init__(self, model_path):
        """Store the tokenizer location; ``vocab`` stays unset until a subclass fills it."""
        self.model_path = model_path
        self.vocab = None

    @abstractmethod
    def encode(self, text):
        """Return the token ids for ``text``."""

    @abstractmethod
    def decode(self, tokens):
        """Return the text for a sequence of token ids."""

    @abstractmethod
    def bos_id(self):
        """Return the beginning-of-sequence token id."""

    @abstractmethod
    def eos_id(self):
        """Return the end-of-sequence token id."""

    @abstractmethod
    def get_terminator_ids(self):
        """Return the list of token ids that end generation."""

    @abstractmethod
    def special_ids(self) -> List[List[int]]:
        """Return special-token id sequences.

        A list of lists is used because some chat-template markers are not a
        single special token.
        """

    @abstractmethod
    def __len__(self):
        """Return the number of token ids the tokenizer can produce."""

    def punctuation_ids(self):
        """Return the ids whose vocabulary string is only whitespace/punctuation.

        Reads the vocabulary through ``get_vocab`` so that a subclass that forgot
        to populate ``vocab`` fails with a clear assertion instead of a
        ``TypeError`` from iterating ``None``.
        """
        return [
            token_id
            for token_id, piece in enumerate(self.get_vocab())
            if is_punc_id(piece)
        ]

    def get_vocab(self):
        """Return ``self.vocab``, asserting that the subclass has populated it."""
        assert (
            self.vocab is not None
        ), "Subclasses should set the vocab attribute during initialization."
        return self.vocab
class SentencePieceWrapper(TokenizerInterface):
    """Tokenizer backed by a SentencePiece model file.

    ``get_tokenizer`` selects this class (or its ``Llama2ChatFormat`` subclass)
    for model names containing "llama-2".
    """

    def __init__(self, model_path):
        # The base class already stores ``self.model_path``.
        super().__init__(model_path)
        self.processor = spm.SentencePieceProcessor(str(model_path))
        self.terminator_ids = [self.processor.eos_id()]
        # One piece string per token id (``token_id`` avoids shadowing ``id``).
        self.vocab = [
            self.processor.id_to_piece(token_id)
            for token_id in range(self.processor.get_piece_size())
        ]

    def addl_special_ids(self):
        """Return id sequences for the model-specific chat markers.

        Only Llama-2 is recognised, detected by "llama-2" appearing in the
        lower-cased ``model_path``. Its ``[INST]`` / ``[/INST]`` markers are
        encoded as ordinary text, so each may span several ids; a message is
        printed when that happens.

        Raises:
            ValueError: if ``model_path`` does not contain "llama-2".
        """
        if "llama-2" not in str(self.model_path).lower():
            raise ValueError(f"Unknown model path: {self.model_path}")
        special_tokens = ["[INST]", "[/INST]"]

        def _encode_special(token):
            ids = self.processor.EncodeAsIds(token)
            if len(ids) > 1:
                print(f"Special token {token} was tokenized into {len(ids)} tokens")
            return ids

        return [_encode_special(token) for token in special_tokens]

    def special_ids(self) -> List[List[int]]:
        """Return BOS, EOS and the chat markers, each as a list of ids.

        A list of lists is returned because some chat-template markers are not
        a single token. Propagates ``ValueError`` from ``addl_special_ids`` for
        non-Llama-2 model paths.
        """
        return [
            [self.processor.bos_id()],
            [self.processor.eos_id()],
            *self.addl_special_ids(),
        ]

    def encode(self, text):
        """Return the SentencePiece ids for ``text`` (no BOS/EOS added here)."""
        return self.processor.EncodeAsIds(text)

    def decode(self, tokens):
        """Return the text for a sequence of SentencePiece ids."""
        return self.processor.DecodeIds(tokens)

    def bos_id(self):
        """Return the processor's BOS id."""
        return self.processor.bos_id()

    def eos_id(self):
        """Return the processor's EOS id."""
        return self.processor.eos_id()

    def get_terminator_ids(self):
        """Return the ids that end generation (only EOS for this backend)."""
        return self.terminator_ids

    def __len__(self):
        """Return the number of pieces in the SentencePiece model."""
        return self.processor.get_piece_size()
class TiktokenWrapper(TokenizerInterface):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.

    Special tokens are assigned ids immediately after the BPE ranks loaded from
    ``model_path``: 10 named/reserved tokens followed by
    ``<|reserved_special_token_5|>`` .. ``<|reserved_special_token_250|>``,
    i.e. ``num_reserved_special_tokens`` (256) slots in total.
    """

    special_tokens: Dict[str, int]

    # Total number of special-token slots appended after the base BPE ranks.
    num_reserved_special_tokens = 256

    # Pre-tokenisation regex handed to ``tiktoken.Encoding``.
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path):
        super().__init__(model_path)
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if not os.path.isfile(model_path):
            raise FileNotFoundError(str(model_path))
        mergeable_ranks = load_tiktoken_bpe(str(model_path))
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        # Special ids continue directly after the last BPE rank.
        self.special_tokens = {
            token: num_base_tokens + i for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        # BOS / EOS token IDs
        self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self._eos_id: int = self.special_tokens["<|end_of_text|>"]
        # Generation stops on either end-of-text or end-of-turn.
        self.terminator_ids = [self._eos_id, self.special_tokens["<|eot_id|>"]]
        self.vocab = [self.model.decode([i]) for i in range(self.model.n_vocab)]

    def encode(self, text):
        """Return the tiktoken ids for ``text`` using tiktoken's default special-token handling."""
        return self.model.encode(text)

    def special_ids(self) -> List[List[int]]:
        """Return every special-token id, ascending, each wrapped in its own list.

        The list-of-lists shape matches the other backends, where some chat
        markers span multiple tokens.
        """
        return [[token_id] for token_id in sorted(self.special_tokens.values())]

    def decode(self, tokens):
        """Return the text for a sequence of token ids."""
        return self.model.decode(tokens)

    def bos_id(self):
        """Return the id of ``<|begin_of_text|>``."""
        return self._bos_id

    def eos_id(self):
        """Return the id of ``<|end_of_text|>``."""
        return self._eos_id

    def get_terminator_ids(self):
        """Return the ids of ``<|end_of_text|>`` and ``<|eot_id|>``."""
        return self.terminator_ids

    def __len__(self):
        """Return ``n_vocab`` of the underlying encoding (BPE ranks + special tokens)."""
        return self.model.n_vocab
class TokenizersWrapper(TokenizerInterface):
    """Tokenizer backed by a Hugging Face ``AutoTokenizer``.

    ``get_tokenizer`` falls back to this class for models that are neither
    Llama-2 nor Llama-3.
    """

    def __init__(self, model_path):
        super().__init__(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.terminator_ids = [self.tokenizer.eos_token_id]
        # Decode every id counted by ``__len__`` (``len(self.tokenizer)``), not
        # just ``vocab_size``, so tokens added on top of the base vocabulary are
        # included and ``len(self.vocab) == len(self)`` as in the other wrappers.
        self.vocab = [
            self.tokenizer.decode(i) for i in range(len(self.tokenizer))
        ]

    def special_ids(self) -> List[List[int]]:
        """Return the tokenizer's special-token ids, each wrapped in its own list.

        Uses ``special_token_ids`` when the tokenizer provides it. Otherwise the
        values of ``special_tokens_map`` (single tokens or lists of tokens) are
        flattened and de-duplicated in first-seen order, which keeps the result
        stable across interpreter runs (``set`` order depends on str hashing).
        """
        if hasattr(self.tokenizer, "special_token_ids"):
            return [[x] for x in self.tokenizer.special_token_ids]
        # It's likely a tokenizer that has a special_tokens_map attribute.
        flattened = []
        for value in self.tokenizer.special_tokens_map.values():
            if isinstance(value, (list, tuple)):
                flattened.extend(value)
            else:
                flattened.append(value)
        unique_tokens = list(dict.fromkeys(flattened))
        return [[self.tokenizer.convert_tokens_to_ids(t)] for t in unique_tokens]

    def encode(self, text):
        """Return the ids for ``text`` without adding the tokenizer's special tokens."""
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, tokens):
        """Return the text for a sequence of token ids."""
        return self.tokenizer.decode(tokens)

    def bos_id(self):
        """Return the tokenizer's ``bos_token_id``."""
        return self.tokenizer.bos_token_id

    def eos_id(self):
        """Return the tokenizer's ``eos_token_id``."""
        return self.tokenizer.eos_token_id

    def get_terminator_ids(self):
        """Return the ids that end generation (only EOS for this backend)."""
        return self.terminator_ids

    def __len__(self):
        """Return ``len(self.tokenizer)``, including any added tokens."""
        return len(self.tokenizer)
def get_tokenizer(tokenizer_model_path, model_name, is_chat=False):
    """Build the tokenizer that matches ``model_name``.

    The lower-cased model name picks the backend: "llama-3" -> tiktoken,
    "llama-2" -> SentencePiece, anything else -> Hugging Face ``AutoTokenizer``.
    "llama-3" is checked first.

    Args:
        tokenizer_model_path: Path of the tokenizer model/file to load.
        model_name: Model name; only its lower-cased text is inspected.
        is_chat: When truthy, return the chat-format subclass (which adds
            ``encode_prompt``) instead of the plain wrapper.

    Returns:
        TokenizerInterface: a freshly constructed tokenizer instance.
    """
    lowered_name = str(model_name).lower()
    if "llama-3" in lowered_name:
        chat_cls, plain_cls = Llama3ChatFormat, TiktokenWrapper
    elif "llama-2" in lowered_name:
        chat_cls, plain_cls = Llama2ChatFormat, SentencePieceWrapper
    else:
        chat_cls, plain_cls = TokenizersChatFormat, TokenizersWrapper
    tokenizer_cls = chat_cls if is_chat else plain_cls
    return tokenizer_cls(tokenizer_model_path)
# Speaker of a chat message.
Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    """One chat turn, as consumed by the chat-format ``encode_*`` methods."""

    role: Role  # who is speaking
    content: str  # the message text
class Llama3ChatFormat(TiktokenWrapper):
    """Llama-3 chat prompt builder on top of the tiktoken tokenizer.

    ``encode_dialog_prompt`` emits ``<|begin_of_text|>``, then for each message
    a header (``<|start_header_id|>`` role ``<|end_header_id|>`` + two newlines),
    the stripped content and ``<|eot_id|>``, and finally an open assistant
    header for the model to complete. Construction is inherited unchanged.
    """

    def encode_header(self, message: Message) -> List[int]:
        """Encode the turn header for ``message["role"]`` followed by two newlines."""
        header = [self.special_tokens["<|start_header_id|>"]]
        header.extend(self.encode(message["role"]))
        header.append(self.special_tokens["<|end_header_id|>"])
        header.extend(self.encode("\n\n"))
        return header

    def encode_prompt(self, prompt: str):
        """Encode ``prompt`` as a single user turn ready for the assistant reply."""
        single_turn = [{"role": "user", "content": prompt}]
        return self.encode_dialog_prompt(single_turn)

    def encode_message(self, message: Message) -> List[int]:
        """Encode one complete turn: header, stripped content, then ``<|eot_id|>``."""
        body = self.encode_header(message)
        body += self.encode(message["content"].strip())
        body.append(self.special_tokens["<|eot_id|>"])
        return body

    def encode_dialog_prompt(self, dialog: List[Message]) -> List[int]:
        """Encode a whole dialog and open an assistant turn at the end."""
        ids = [self.special_tokens["<|begin_of_text|>"]]
        for turn in dialog:
            ids.extend(self.encode_message(turn))
        # Add the start of an assistant message for the model to complete.
        ids.extend(self.encode_header({"role": "assistant", "content": ""}))
        return ids
class Llama2ChatFormat(SentencePieceWrapper):
    """Llama-2 single-turn chat prompt builder on top of SentencePiece.

    Construction is inherited unchanged from ``SentencePieceWrapper``.
    """

    B_INST = "[INST]"
    E_INST = "[/INST]"

    def encode_prompt(self, prompt: str):
        """Return BOS, then the ids of "[INST]" plus two newlines, then the ids of
        ``prompt`` followed by " [/INST]".

        The two segments are encoded with separate ``encode`` calls.
        NOTE(review): the blank lines after ``[INST]`` are unusual without a
        system block — confirm this layout is intended.
        """
        bos = self.bos_id()
        opening = self.encode(f"{Llama2ChatFormat.B_INST}\n\n")
        request = self.encode(f"{prompt} {Llama2ChatFormat.E_INST}")
        return [bos, *opening, *request]
class TokenizersChatFormat(TokenizersWrapper):
    """Chat prompt builder that renders dialogs with the Hugging Face chat template.

    Construction is inherited unchanged from ``TokenizersWrapper``.
    """

    def encode_prompt(self, prompt: str):
        """Encode ``prompt`` as a single user turn followed by a generation prompt."""
        return self.encode_dialog_prompt([{"role": "user", "content": prompt}])

    def encode_dialog_prompt(self, dialog: List[Message]) -> List[int]:
        """Render ``dialog`` through the chat template to text, then encode it.

        ``add_generation_prompt=True`` asks the template to append the opening
        of an assistant turn. ``self.encode`` adds no special tokens of its own;
        presumably the template supplies any BOS/role markers — confirm per model.
        """
        rendered = self.tokenizer.apply_chat_template(
            dialog,
            tokenize=False,
            add_generation_prompt=True,
        )
        return self.encode(rendered)
def encode_tokens(tokenizer, string, bos=True, device=default_device):
    """Encode raw text (no chat template) into an int tensor on ``device``.

    Args:
        tokenizer: Any object with ``encode`` and ``bos_id`` methods.
        string: Text to encode.
        bos: When truthy, prepend ``tokenizer.bos_id()``.
        device: Target device for the returned tensor.

    Returns:
        torch.Tensor: 1-D tensor of dtype ``torch.int``.
    """
    ids = tokenizer.encode(string)
    if bos:
        ids = [tokenizer.bos_id()] + ids
    return torch.tensor(ids, dtype=torch.int, device=device)
def encode(tokenizer, prompt, device=default_device, bos=True, is_chat=True):
    """Encode ``prompt`` into an int tensor, with or without a chat template.

    Args:
        tokenizer: A tokenizer from this module; chat mode needs ``encode_prompt``.
        prompt: The user text.
        device: Target device for the returned tensor.
        bos: Only used when ``is_chat`` is falsy (passed to ``encode_tokens``).
        is_chat: When truthy, use ``tokenizer.encode_prompt`` and ignore ``bos``.

    Returns:
        torch.Tensor: 1-D tensor of dtype ``torch.int``.
    """
    if not is_chat:
        return encode_tokens(tokenizer, prompt, device=device, bos=bos)
    # Chat formats build their own prefix tokens, so ``bos`` is not applied here.
    chat_ids = tokenizer.encode_prompt(prompt)
    return torch.tensor(chat_ids, dtype=torch.int, device=device)