Commit a5fb1f4

Merge remote-tracking branch 'upstream/main'

Committed Jun 15, 2023 · 2 parents c7faff0 + 5d970c1 · commit a5fb1f4

13 files changed: +606, -10 lines
 

‎.gitignore

+1
@@ -33,6 +33,7 @@ MANIFEST
 # Tools
 .mypy_cache
 .coverage
+.hypothesis
 htmlcov
 
 # General

‎core/Cargo.toml

+3, -3

@@ -10,10 +10,10 @@ crate-type = ["lib"]
 
 [dependencies]
 # tiktoken dependencies
-fancy-regex = "0.10.0"
-regex = "1.7.0"
+fancy-regex = "0.11.0"
+regex = "1.8.3"
 rustc-hash = "1.1.0"
-bstr = "1.0.1"
+bstr = "1.5.0"
 
 [features]
 default = []

‎pyproject.toml

+3, -2

@@ -41,5 +41,6 @@ macos.archs = ["x86_64", "arm64"]
 # Warnings will be silenced with following CIBW_TEST_SKIP
 test-skip = "*-macosx_arm64"
 
-before-test = "pip install pytest"
-test-command = "pytest {project}/tests"
+before-test = "pip install pytest hypothesis"
+test-command = "pytest {project}/tests --import-mode=append"
+

‎python/Cargo.toml

+1, -1

@@ -9,6 +9,6 @@ name = "_tiktoken"
 crate-type = ["cdylib"]
 
 [dependencies]
-pyo3 = { version = "0.17.3", features = ["extension-module"] }
+pyo3 = { version = "0.19.0", features = ["extension-module"] }
 tiktoken_core = { path = "../core", features = ["multithreading"] }
 rustc-hash = "1.1.0"

‎tests/__init__.py

Whitespace-only changes.

‎tests/test_encoding.py

+231 (new file)

# Note that there are more actual tests, they're just not currently public :-)

from typing import Callable

import hypothesis
import hypothesis.strategies as st
import pytest

import tiktoken

from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES


def test_simple():
    enc = tiktoken.get_encoding("gpt2")
    assert enc.encode("hello world") == [31373, 995]
    assert enc.decode([31373, 995]) == "hello world"
    assert enc.encode("hello <|endoftext|>", allowed_special="all") == [31373, 220, 50256]

    enc = tiktoken.get_encoding("cl100k_base")
    assert enc.encode("hello world") == [15339, 1917]
    assert enc.decode([15339, 1917]) == "hello world"
    assert enc.encode("hello <|endoftext|>", allowed_special="all") == [15339, 220, 100257]

    for enc_name in tiktoken.list_encoding_names():
        enc = tiktoken.get_encoding(enc_name)
        for token in range(10_000):
            assert enc.encode_single_token(enc.decode_single_token_bytes(token)) == token


def test_simple_repeated():
    enc = tiktoken.get_encoding("gpt2")
    assert enc.encode("0") == [15]
    assert enc.encode("00") == [405]
    assert enc.encode("000") == [830]
    assert enc.encode("0000") == [2388]
    assert enc.encode("00000") == [20483]
    assert enc.encode("000000") == [10535]
    assert enc.encode("0000000") == [24598]
    assert enc.encode("00000000") == [8269]
    assert enc.encode("000000000") == [10535, 830]
    assert enc.encode("0000000000") == [8269, 405]
    assert enc.encode("00000000000") == [8269, 830]
    assert enc.encode("000000000000") == [8269, 2388]
    assert enc.encode("0000000000000") == [8269, 20483]
    assert enc.encode("00000000000000") == [8269, 10535]
    assert enc.encode("000000000000000") == [8269, 24598]
    assert enc.encode("0000000000000000") == [25645]
    assert enc.encode("00000000000000000") == [8269, 10535, 830]


def test_simple_regex():
    enc = tiktoken.get_encoding("cl100k_base")
    assert enc.encode("rer") == [38149]
    assert enc.encode("'rer") == [2351, 81]
    assert enc.encode("today\n ") == [31213, 198, 220]
    assert enc.encode("today\n \n") == [31213, 27907]
    assert enc.encode("today\n  \n") == [31213, 14211]


def test_basic_encode():
    enc = tiktoken.get_encoding("r50k_base")
    assert enc.encode("hello world") == [31373, 995]

    enc = tiktoken.get_encoding("p50k_base")
    assert enc.encode("hello world") == [31373, 995]

    enc = tiktoken.get_encoding("cl100k_base")
    assert enc.encode("hello world") == [15339, 1917]
    assert enc.encode(" \x850") == [220, 126, 227, 15]


def test_encode_empty():
    enc = tiktoken.get_encoding("r50k_base")
    assert enc.encode("") == []


def test_encode_bytes():
    enc = tiktoken.get_encoding("cl100k_base")
    assert enc._encode_bytes(b" \xec\x8b\xa4\xed") == [62085]


def test_encode_surrogate_pairs():
    enc = tiktoken.get_encoding("cl100k_base")

    assert enc.encode("👍") == [9468, 239, 235]
    # surrogate pair gets converted to codepoint
    assert enc.encode("\ud83d\udc4d") == [9468, 239, 235]

    # lone surrogate just gets replaced
    assert enc.encode("\ud83d") == enc.encode("�")


# ====================
# Roundtrip
# ====================


@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_basic_roundtrip(make_enc):
    enc = make_enc()
    for value in (
        "hello",
        "hello ",
        "hello  ",
        " hello",
        " hello ",
        " hello  ",
        "hello world",
        "请考试我的软件!12345",
    ):
        assert value == enc.decode(enc.encode(value))
        assert value == enc.decode(enc.encode_ordinary(value))


@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(text=st.text())
@hypothesis.settings(deadline=None)
def test_hyp_roundtrip(make_enc: Callable[[], tiktoken.Encoding], text):
    enc = make_enc()

    assert text == enc.decode(enc.encode(text))


@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_single_token_roundtrip(make_enc: Callable[[], tiktoken.Encoding]):
    enc = make_enc()

    for token in range(enc.n_vocab):
        try:
            token_bytes = enc.decode_single_token_bytes(token)
        except KeyError:
            continue
        assert enc.encode_single_token(token_bytes) == token


# ====================
# Special tokens
# ====================


def test_special_token():
    enc = tiktoken.get_encoding("cl100k_base")

    eot = enc.encode_single_token("<|endoftext|>")
    assert eot == enc.eot_token
    fip = enc.encode_single_token("<|fim_prefix|>")
    fim = enc.encode_single_token("<|fim_middle|>")

    text = "<|endoftext|> hello <|fim_prefix|>"
    assert eot not in enc.encode(text, disallowed_special=())
    with pytest.raises(ValueError):
        enc.encode(text)
    with pytest.raises(ValueError):
        enc.encode(text, disallowed_special="all")
    with pytest.raises(ValueError):
        enc.encode(text, disallowed_special={"<|endoftext|>"})
    with pytest.raises(ValueError):
        enc.encode(text, disallowed_special={"<|fim_prefix|>"})

    text = "<|endoftext|> hello <|fim_prefix|> there <|fim_middle|>"
    tokens = enc.encode(text, disallowed_special=())
    assert eot not in tokens
    assert fip not in tokens
    assert fim not in tokens

    tokens = enc.encode(text, allowed_special="all", disallowed_special=())
    assert eot in tokens
    assert fip in tokens
    assert fim in tokens

    tokens = enc.encode(text, allowed_special="all", disallowed_special="all")
    assert eot in tokens
    assert fip in tokens
    assert fim in tokens

    tokens = enc.encode(text, allowed_special={"<|fim_prefix|>"}, disallowed_special=())
    assert eot not in tokens
    assert fip in tokens
    assert fim not in tokens

    tokens = enc.encode(text, allowed_special={"<|endoftext|>"}, disallowed_special=())
    assert eot in tokens
    assert fip not in tokens
    assert fim not in tokens

    tokens = enc.encode(text, allowed_special={"<|fim_middle|>"}, disallowed_special=())
    assert eot not in tokens
    assert fip not in tokens
    assert fim in tokens


@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(text=st.text())
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_special_ordinary(make_enc, text: str):
    enc = make_enc()
    assert enc.encode_ordinary(text) == enc.encode(text, disallowed_special=())


# ====================
# Batch encoding
# ====================


@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_batch_encode(make_enc: Callable[[], tiktoken.Encoding]):
    enc = make_enc()
    text1 = "hello world"
    text2 = "goodbye world"

    assert enc.encode_batch([text1]) == [enc.encode(text1)]
    assert enc.encode_batch([text1, text2]) == [enc.encode(text1), enc.encode(text2)]

    assert enc.encode_ordinary_batch([text1]) == [enc.encode_ordinary(text1)]
    assert enc.encode_ordinary_batch([text1, text2]) == [
        enc.encode_ordinary(text1),
        enc.encode_ordinary(text2),
    ]


@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(batch=st.lists(st.text()))
@hypothesis.settings(deadline=None)
def test_hyp_batch_roundtrip(make_enc: Callable[[], tiktoken.Encoding], batch):
    enc = make_enc()

    encoded = enc.encode_batch(batch)
    assert encoded == [enc.encode(t) for t in batch]
    decoded = enc.decode_batch(encoded)
    assert decoded == batch
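
As a quick illustration of the special-token behaviour pinned down by test_special_token above, here is a minimal sketch, assuming only the cl100k_base encoding and the encode options exercised in those tests:

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
text = "<|endoftext|> hello <|fim_prefix|>"

# Default behaviour: special-token text in the input raises ValueError,
# as the tests assert.
# enc.encode(text)

# disallowed_special=() treats the special strings as ordinary text,
# while allowed_special="all" maps them to their reserved token ids.
ordinary = enc.encode(text, disallowed_special=())
special = enc.encode(text, allowed_special="all")

assert enc.eot_token in special
assert enc.eot_token not in ordinary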

‎tests/test_helpers.py

+22 (new file)

import bisect
import functools
import os

import pytest

import tiktoken

MAX_EXAMPLES: int = int(os.environ.get("TIKTOKEN_MAX_EXAMPLES", "100"))

ENCODINGS = ["r50k_base", "cl100k_base"]
SOME_ENCODINGS = ["cl100k_base"]


ENCODING_FACTORIES = [
    pytest.param(functools.partial(tiktoken.get_encoding, name), id=name) for name in ENCODINGS
]
SOME_ENCODING_FACTORIES = [
    pytest.param(functools.partial(tiktoken.get_encoding, name), id=name) for name in SOME_ENCODINGS
]

‎tests/test_misc.py

+24 (new file)

import subprocess
import sys

import tiktoken


def test_encoding_for_model():
    enc = tiktoken.encoding_for_model("gpt2")
    assert enc.name == "gpt2"
    enc = tiktoken.encoding_for_model("text-davinci-003")
    assert enc.name == "p50k_base"
    enc = tiktoken.encoding_for_model("text-davinci-edit-001")
    assert enc.name == "p50k_edit"
    enc = tiktoken.encoding_for_model("gpt-3.5-turbo-0301")
    assert enc.name == "cl100k_base"


def test_optional_blobfile_dependency():
    prog = """
import tiktoken
import sys
assert "blobfile" not in sys.modules
"""
    subprocess.check_call([sys.executable, "-c", prog])

‎tests/test_offsets.py

+79 (new file)

from typing import Callable

import hypothesis
import pytest
from hypothesis import strategies as st

import tiktoken

from .test_helpers import MAX_EXAMPLES, SOME_ENCODING_FACTORIES


def _common_prefix_len(a, b):
    i = 0
    while i < len(a) and i < len(b) and a[i] == b[i]:
        i += 1
    return i


def _token_offsets_reference(enc, tokens):
    text = enc.decode(tokens, errors="strict")
    res = []
    for i in range(len(tokens)):
        prefix = enc.decode(tokens[:i], errors="ignore")
        res.append(_common_prefix_len(text, prefix))
    return res


@pytest.mark.parametrize("make_enc", SOME_ENCODING_FACTORIES)
@hypothesis.given(data=st.data())
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_offsets(make_enc: Callable[[], tiktoken.Encoding], data):
    enc = make_enc()

    tokens_st = st.lists(
        st.integers(0, enc.n_vocab - 1).filter(
            lambda x: x in enc._special_tokens.values() or x in enc._mergeable_ranks.values()
        ),
        min_size=1,
        max_size=20,
    )
    tokens = data.draw(tokens_st)

    # This is a dumb hack to make sure that our tokens are a valid UTF-8 string
    # We could potentially drop this, see the TODO in decode_with_offsets
    tokens = enc.encode(enc.decode(tokens, errors="ignore"), allowed_special="all")
    assert enc.decode_with_offsets(tokens)[1] == _token_offsets_reference(enc, tokens)


def test_basic_offsets():
    enc = tiktoken.get_encoding("cl100k_base")

    prompt = "hello world"
    p, o = enc.decode_with_offsets(enc.encode(prompt))
    assert p == prompt
    assert o == [0, 5]

    prompt = "hello world<|endoftext|> green cow"
    p, o = enc.decode_with_offsets(enc.encode(prompt, allowed_special="all"))
    assert p == prompt
    assert o == [0, 5, 11, 24, 30]

    prompt = "我非常渴望与人工智能一起工作"
    p, o = enc.decode_with_offsets(enc.encode(prompt))
    assert p == prompt
    assert o == [0, 1, 2, 3, 3, 4, 4, 5, 6, 7, 8, 8, 9, 10, 11, 12, 13]

    # contains the interesting tokens b'\xe0\xae\xbf\xe0\xae' and b'\xe0\xaf\x8d\xe0\xae'
    # in which \xe0 is the start of a 3-byte UTF-8 character
    prompt = "நடிகர் சூர்யா"
    p, o = enc.decode_with_offsets(enc.encode(prompt))
    assert p == prompt
    assert o == [0, 0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 8, 8, 9, 9, 10, 11, 12, 12]

    # contains the interesting token b'\xa0\xe9\x99\xa4'
    # in which \xe9 is the start of a 3-byte UTF-8 character and \xa0 is a continuation byte
    prompt = " Ġ除"
    p, o = enc.decode_with_offsets(enc.encode(prompt))
    assert p == prompt
    assert o == [0, 1]

‎tiktoken/_educational.py

+212 (new file)

"""This is an educational implementation of the byte pair encoding algorithm."""
from __future__ import annotations

import collections
import itertools
from typing import Optional

import regex

import tiktoken


class SimpleBytePairEncoding:
    def __init__(self, *, pat_str: str, mergeable_ranks: dict[bytes, int]) -> None:
        """Creates an Encoding object."""
        # A regex pattern string that is used to split the input text
        self.pat_str = pat_str
        # A dictionary mapping token bytes to their ranks. The ranks correspond to merge priority
        self.mergeable_ranks = mergeable_ranks

        self._decoder = {token: token_bytes for token_bytes, token in mergeable_ranks.items()}
        self._pat = regex.compile(pat_str)

    def encode(self, text: str, visualise: Optional[str] = "colour") -> list[int]:
        """Encodes a string into tokens.

        >>> enc.encode("hello world")
        [388, 372]
        """
        # Use the regex to split the text into (approximately) words
        words = self._pat.findall(text)
        tokens = []
        for word in words:
            # Turn each word into tokens, using the byte pair encoding algorithm
            word_bytes = word.encode("utf-8")
            word_tokens = bpe_encode(self.mergeable_ranks, word_bytes, visualise=visualise)
            tokens.extend(word_tokens)
        return tokens

    def decode_bytes(self, tokens: list[int]) -> bytes:
        """Decodes a list of tokens into bytes.

        >>> enc.decode_bytes([388, 372])
        b'hello world'
        """
        return b"".join(self._decoder[token] for token in tokens)

    def decode(self, tokens: list[int]) -> str:
        """Decodes a list of tokens into a string.

        Decoded bytes are not guaranteed to be valid UTF-8. In that case, we replace
        the invalid bytes with the replacement character "�".

        >>> enc.decode([388, 372])
        'hello world'
        """
        return self.decode_bytes(tokens).decode("utf-8", errors="replace")

    def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
        """Decodes a list of tokens into a list of bytes.

        Useful for visualising how a string is tokenised.

        >>> enc.decode_tokens_bytes([388, 372])
        [b'hello', b' world']
        """
        return [self._decoder[token] for token in tokens]

    @staticmethod
    def train(training_data: str, vocab_size: int, pat_str: str):
        """Train a BPE tokeniser on some data!"""
        mergeable_ranks = bpe_train(data=training_data, vocab_size=vocab_size, pat_str=pat_str)
        return SimpleBytePairEncoding(pat_str=pat_str, mergeable_ranks=mergeable_ranks)

    @staticmethod
    def from_tiktoken(encoding):
        if isinstance(encoding, str):
            encoding = tiktoken.get_encoding(encoding)
        return SimpleBytePairEncoding(
            pat_str=encoding._pat_str, mergeable_ranks=encoding._mergeable_ranks
        )


def bpe_encode(
    mergeable_ranks: dict[bytes, int], input: bytes, visualise: Optional[str] = "colour"
) -> list[int]:
    parts = [bytes([b]) for b in input]
    while True:
        # See the intermediate merges play out!
        if visualise:
            if visualise in ["colour", "color"]:
                visualise_tokens(parts)
            elif visualise == "simple":
                print(parts)

        # Iterate over all pairs and find the pair we want to merge the most
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank

        # If there were no pairs we could merge, we're done!
        if min_rank is None:
            break
        assert min_idx is not None

        # Otherwise, merge that pair and leave the rest unchanged. Then repeat.
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :]

    if visualise:
        print()

    tokens = [mergeable_ranks[part] for part in parts]
    return tokens


def bpe_train(
    data: str, vocab_size: int, pat_str: str, visualise: Optional[str] = "colour"
) -> dict[bytes, int]:
    # First, add tokens for each individual byte value
    if vocab_size < 2**8:
        raise ValueError("vocab_size must be at least 256, so we can encode all bytes")
    ranks = {}
    for i in range(2**8):
        ranks[bytes([i])] = i

    # Splinter up our data into lists of bytes
    # data = "Hello world"
    # words = [
    #     [b'H', b'e', b'l', b'l', b'o'],
    #     [b' ', b'w', b'o', b'r', b'l', b'd']
    # ]
    words: list[list[bytes]] = [
        [bytes([b]) for b in word.encode("utf-8")] for word in regex.findall(pat_str, data)
    ]

    # Now, use our data to figure out which merges we should make
    while len(ranks) < vocab_size:
        # Find the most common pair. This will become our next token
        stats = collections.Counter()
        for piece in words:
            for pair in zip(piece[:-1], piece[1:]):
                stats[pair] += 1

        most_common_pair = max(stats, key=lambda x: stats[x])
        token_bytes = most_common_pair[0] + most_common_pair[1]
        token = len(ranks)
        # Add the new token!
        ranks[token_bytes] = token

        # Now merge that most common pair in all the words. That is, update our training data
        # to reflect our decision to make that pair into a new token.
        new_words = []
        for word in words:
            new_word = []
            i = 0
            while i < len(word) - 1:
                if (word[i], word[i + 1]) == most_common_pair:
                    # We found our pair! Merge it
                    new_word.append(token_bytes)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            if i == len(word) - 1:
                new_word.append(word[i])
            new_words.append(new_word)
        words = new_words

        # See the intermediate merges play out!
        if visualise:
            print(f"The current most common pair is {most_common_pair[0]} + {most_common_pair[1]}")
            print(f"So we made {token_bytes} our {len(ranks)}th token")
            if visualise in ["colour", "color"]:
                print("Now the first fifty words in our training data look like:")
                visualise_tokens([token for word in words[:50] for token in word])
            elif visualise == "simple":
                print("Now the first twenty words in our training data look like:")
                for word in words[:20]:
                    print(word)
            print("\n")

    return ranks


def visualise_tokens(token_values: list[bytes]) -> None:
    backgrounds = itertools.cycle(
        [f"\u001b[48;5;{i}m".encode() for i in [167, 179, 185, 77, 80, 68, 134]]
    )
    interleaved = itertools.chain.from_iterable(zip(backgrounds, token_values))
    print((b"".join(interleaved) + "\u001b[0m".encode()).decode("utf-8"))


def train_simple_encoding():
    gpt2_pattern = (
        r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    )
    with open(__file__, "r") as f:
        data = f.read()

    enc = SimpleBytePairEncoding.train(data, vocab_size=600, pat_str=gpt2_pattern)

    print("This is the sequence of merges performed in order to encode 'hello world':")
    tokens = enc.encode("hello world")
    assert enc.decode(tokens) == "hello world"
    assert enc.decode_bytes(tokens) == b"hello world"
    assert enc.decode_tokens_bytes(tokens) == [b"hello", b" world"]

    return enc
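
A usage sketch for the new educational module, based only on the classes and functions defined above (merge steps are printed to the terminal via visualise_tokens):

from tiktoken._educational import SimpleBytePairEncoding, train_simple_encoding

# Wrap the real cl100k_base merge ranks in the simple implementation and
# watch the intermediate merges print out while "hello world" is encoded.
enc = SimpleBytePairEncoding.from_tiktoken("cl100k_base")
enc.encode("hello world")

# Or train a small 600-token vocabulary on this module's own source code,
# which is what train_simple_encoding() does.
my_enc = train_simple_encoding()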

‎tiktoken/core.py

+25

@@ -276,6 +276,31 @@ def decode_tokens_bytes(self, tokens: list[int]) -> list[bytes]:
         """
         return [self.decode_single_token_bytes(token) for token in tokens]
 
+    def decode_with_offsets(self, tokens: list[int]) -> tuple[str, list[int]]:
+        """Decodes a list of tokens into a string and a list of offsets.
+
+        Each offset is the index into text corresponding to the start of each token.
+        If UTF-8 character boundaries do not line up with token boundaries, the offset is the index
+        of the first character that contains bytes from the token.
+
+        This will currently raise if given tokens that decode to invalid UTF-8; this behaviour may
+        change in the future to be more permissive.
+
+        >>> enc.decode_with_offsets([31373, 995])
+        ('hello world', [0, 5])
+        """
+        token_bytes = self.decode_tokens_bytes(tokens)
+
+        text_len = 0
+        offsets = []
+        for token in token_bytes:
+            offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
+            text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)
+
+        # TODO: assess correctness for errors="ignore" and errors="replace"
+        text = b"".join(token_bytes).decode("utf-8", errors="strict")
+        return text, offsets
+
     def decode_batch(
         self, batch: list[list[int]], *, errors: str = "replace", num_threads: int = 8
     ) -> list[str]:
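
A minimal sketch of the new decode_with_offsets helper in use, reusing values that appear in test_basic_offsets above (token ids are for cl100k_base):

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

# "hello world" is two tokens in cl100k_base; the second token (" world")
# begins at character index 5, so the offsets are [0, 5].
text, offsets = enc.decode_with_offsets(enc.encode("hello world"))
assert text == "hello world"
assert offsets == [0, 5]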

‎tiktoken/model.py

+2, -1

@@ -15,6 +15,7 @@
     # chat
     "gpt-4-": "cl100k_base",  # e.g., gpt-4-0314, etc., plus gpt-4-32k
     "gpt-3.5-turbo-": "cl100k_base",  # e.g, gpt-3.5-turbo-0301, -0401, etc.
+    "gpt-35-turbo": "cl100k_base",  # Azure deployment name
 }
 
 MODEL_TO_ENCODING: dict[str, str] = json.loads(pkg_resources.read_text("tiktoken", "model_to_encoding.json"))
@@ -36,7 +37,7 @@ def encoding_for_model(model_name: str) -> Encoding:
     if encoding_name is None:
         raise KeyError(
             f"Could not automatically map {model_name} to a tokeniser. "
-            "Please use `tiktok.get_encoding` to explicitly get the tokeniser you expect."
+            "Please use `tiktoken.get_encoding` to explicitly get the tokeniser you expect."
         ) from None
 
     return get_encoding(encoding_name)
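
The new prefix entry is what lets Azure-style deployment names (which drop the dot in "3.5") resolve to a tokeniser. A small sketch of the expected effect, assuming the prefix lookup behaves the same way as for the existing entries:

import tiktoken

# "gpt-35-turbo" now matches the Azure prefix entry and resolves to the
# same tokeniser as "gpt-3.5-turbo-0301" in test_encoding_for_model.
enc = tiktoken.encoding_for_model("gpt-35-turbo")
assert enc.name == "cl100k_base"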

‎wasm/Cargo.toml

+3, -3

@@ -11,10 +11,10 @@ crate-type = ["cdylib"]
 [dependencies]
 tiktoken_core = { path = "../core", features = [] }
 # tiktoken dependencies
-fancy-regex = "0.10.0"
-regex = "1.7.0"
+fancy-regex = "0.11.0"
+regex = "1.8.3"
 rustc-hash = "1.1.0"
-bstr = "1.0.1"
+bstr = "1.5.0"
 wasm-bindgen = "0.2.83"
 anyhow = "1.0.69"
 base64 = "0.21.0"
