|
| 1 | +# Note that there are more actual tests, they're just not currently public :-) |
| 2 | + |
| 3 | +from typing import Callable |
| 4 | + |
| 5 | +import hypothesis |
| 6 | +import hypothesis.strategies as st |
| 7 | +import pytest |
| 8 | + |
| 9 | +import tiktoken |
| 10 | + |
| 11 | +from .test_helpers import ENCODING_FACTORIES, MAX_EXAMPLES |
| 12 | + |
| 13 | + |
def test_simple():
    """Smoke test: known encode/decode pairs for gpt2 and cl100k_base,
    plus a single-token roundtrip over the first 10k ids of every encoding."""
    gpt2 = tiktoken.get_encoding("gpt2")
    assert gpt2.encode("hello world") == [31373, 995]
    assert gpt2.decode([31373, 995]) == "hello world"
    assert gpt2.encode("hello <|endoftext|>", allowed_special="all") == [31373, 220, 50256]

    cl100k = tiktoken.get_encoding("cl100k_base")
    assert cl100k.encode("hello world") == [15339, 1917]
    assert cl100k.decode([15339, 1917]) == "hello world"
    assert cl100k.encode("hello <|endoftext|>", allowed_special="all") == [15339, 220, 100257]

    # Every registered encoding must roundtrip its low token ids exactly.
    for name in tiktoken.list_encoding_names():
        enc = tiktoken.get_encoding(name)
        for token_id in range(10_000):
            raw = enc.decode_single_token_bytes(token_id)
            assert enc.encode_single_token(raw) == token_id
| 29 | + |
| 30 | + |
def test_simple_repeated():
    """Runs of '0' of increasing length map to specific gpt2 token sequences
    (exercises the BPE merge behavior on repeated characters)."""
    enc = tiktoken.get_encoding("gpt2")
    expected = {
        1: [15],
        2: [405],
        3: [830],
        4: [2388],
        5: [20483],
        6: [10535],
        7: [24598],
        8: [8269],
        9: [10535, 830],
        10: [8269, 405],
        11: [8269, 830],
        12: [8269, 2388],
        13: [8269, 20483],
        14: [8269, 10535],
        15: [8269, 24598],
        16: [25645],
        17: [8269, 10535, 830],
    }
    for length, tokens in expected.items():
        assert enc.encode("0" * length) == tokens
| 50 | + |
| 51 | + |
def test_simple_regex():
    """Spot-check the cl100k_base pre-tokenization regex on contraction and
    whitespace splitting edge cases."""
    enc = tiktoken.get_encoding("cl100k_base")
    assert enc.encode("rer") == [38149]
    assert enc.encode("'rer") == [2351, 81]
    assert enc.encode("today\n ") == [31213, 198, 220]
    assert enc.encode("today\n \n") == [31213, 27907]
    # BUG FIX: the previous line's input was duplicated here with a different
    # expected result, which cannot both hold. Token 14211 corresponds to the
    # two-space variant, so this case must use "today\n  \n".
    assert enc.encode("today\n  \n") == [31213, 14211]
| 59 | + |
| 60 | + |
def test_basic_encode():
    """'hello world' encodes to the expected ids under each base encoding."""
    # r50k and p50k share the gpt2 vocabulary for ordinary text.
    for name in ("r50k_base", "p50k_base"):
        assert tiktoken.get_encoding(name).encode("hello world") == [31373, 995]

    enc = tiktoken.get_encoding("cl100k_base")
    assert enc.encode("hello world") == [15339, 1917]
    # Includes a control character (U+0085) to exercise byte-level fallback.
    assert enc.encode(" \x850") == [220, 126, 227, 15]
| 71 | + |
| 72 | + |
def test_encode_empty():
    """Encoding the empty string yields an empty token list."""
    assert tiktoken.get_encoding("r50k_base").encode("") == []
| 76 | + |
| 77 | + |
def test_encode_bytes():
    """_encode_bytes accepts raw bytes that are not valid UTF-8."""
    enc = tiktoken.get_encoding("cl100k_base")
    payload = b" \xec\x8b\xa4\xed"  # ends mid-way through a multibyte sequence
    assert enc._encode_bytes(payload) == [62085]
| 81 | + |
| 82 | + |
def test_encode_surrogate_pairs():
    """Surrogate handling: pairs collapse to their codepoint, lone surrogates
    are replaced with U+FFFD."""
    enc = tiktoken.get_encoding("cl100k_base")

    thumbs_up = enc.encode("👍")
    assert thumbs_up == [9468, 239, 235]
    # A valid surrogate pair encodes the same as the codepoint it denotes.
    assert enc.encode("\ud83d\udc4d") == thumbs_up

    # An unpaired surrogate behaves like the replacement character.
    assert enc.encode("\ud83d") == enc.encode("�")
| 92 | + |
| 93 | + |
| 94 | +# ==================== |
| 95 | +# Roundtrip |
| 96 | +# ==================== |
| 97 | + |
| 98 | + |
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_basic_roundtrip(make_enc):
    """decode(encode(s)) == s for hand-picked strings with tricky whitespace
    and non-ASCII text, via both encode and encode_ordinary."""
    enc = make_enc()
    samples = [
        "hello",
        "hello ",
        "hello  ",
        " hello",
        " hello ",
        " hello  ",
        "hello world",
        "请考试我的软件!12345",
    ]
    for sample in samples:
        assert enc.decode(enc.encode(sample)) == sample
        assert enc.decode(enc.encode_ordinary(sample)) == sample
| 114 | + |
| 115 | + |
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(text=st.text())
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_roundtrip(make_enc: Callable[[], tiktoken.Encoding], text: str):
    """Property: decode(encode(text)) == text for arbitrary unicode text.

    CONSISTENCY FIX: bound the run with max_examples=MAX_EXAMPLES like the
    file's other hypothesis tests (see test_hyp_special_ordinary); MAX_EXAMPLES
    was imported but unused here, leaving this property test unbounded.
    """
    enc = make_enc()

    assert text == enc.decode(enc.encode(text))
| 123 | + |
| 124 | + |
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_single_token_roundtrip(make_enc: Callable[[], tiktoken.Encoding]):
    """Every decodable token id re-encodes to the same id."""
    enc = make_enc()

    for token_id in range(enc.n_vocab):
        try:
            raw = enc.decode_single_token_bytes(token_id)
        except KeyError:
            # Vocab may have gaps (ids with no byte sequence); skip them.
            continue
        assert enc.encode_single_token(raw) == token_id
| 135 | + |
| 136 | + |
| 137 | +# ==================== |
| 138 | +# Special tokens |
| 139 | +# ==================== |
| 140 | + |
| 141 | + |
def test_special_token():
    """allowed_special / disallowed_special control how special-token text
    is treated: rejected, encoded as specials, or encoded as plain text."""
    enc = tiktoken.get_encoding("cl100k_base")

    eot = enc.encode_single_token("<|endoftext|>")
    assert eot == enc.eot_token
    fip = enc.encode_single_token("<|fim_prefix|>")
    fim = enc.encode_single_token("<|fim_middle|>")

    text = "<|endoftext|> hello <|fim_prefix|>"
    # With nothing disallowed, the special text is split like ordinary text.
    assert eot not in enc.encode(text, disallowed_special=())
    # The default disallows all specials, so plain encode must raise.
    with pytest.raises(ValueError):
        enc.encode(text)
    # Any disallowed set that covers a special present in the text raises too.
    for disallowed in ("all", {"<|endoftext|>"}, {"<|fim_prefix|>"}):
        with pytest.raises(ValueError):
            enc.encode(text, disallowed_special=disallowed)

    text = "<|endoftext|> hello <|fim_prefix|> there <|fim_middle|>"
    # Each allowed set yields exactly the corresponding special tokens.
    cases = [
        (set(), (False, False, False)),
        ("all", (True, True, True)),
        ({"<|fim_prefix|>"}, (False, True, False)),
        ({"<|endoftext|>"}, (True, False, False)),
        ({"<|fim_middle|>"}, (False, False, True)),
    ]
    for allowed, (want_eot, want_fip, want_fim) in cases:
        tokens = enc.encode(text, allowed_special=allowed, disallowed_special=())
        assert (eot in tokens) == want_eot
        assert (fip in tokens) == want_fip
        assert (fim in tokens) == want_fim

    # allowed_special wins over disallowed_special for the same tokens.
    tokens = enc.encode(text, allowed_special="all", disallowed_special="all")
    assert eot in tokens
    assert fip in tokens
    assert fim in tokens
| 191 | + |
| 192 | + |
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(text=st.text())
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_special_ordinary(make_enc, text: str):
    """Property: encode_ordinary agrees with encode when no specials are disallowed."""
    enc = make_enc()
    ordinary = enc.encode_ordinary(text)
    permissive = enc.encode(text, disallowed_special=())
    assert ordinary == permissive
| 199 | + |
| 200 | + |
| 201 | +# ==================== |
| 202 | +# Batch encoding |
| 203 | +# ==================== |
| 204 | + |
| 205 | + |
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
def test_batch_encode(make_enc: Callable[[], tiktoken.Encoding]):
    """Batch encoding APIs agree with per-string encoding."""
    enc = make_enc()
    text1 = "hello world"
    text2 = "goodbye world"

    # Check a single-element batch and a two-element batch for both APIs.
    for batch in ([text1], [text1, text2]):
        assert enc.encode_batch(batch) == [enc.encode(t) for t in batch]
        assert enc.encode_ordinary_batch(batch) == [
            enc.encode_ordinary(t) for t in batch
        ]
| 220 | + |
| 221 | + |
@pytest.mark.parametrize("make_enc", ENCODING_FACTORIES)
@hypothesis.given(batch=st.lists(st.text()))
@hypothesis.settings(deadline=None, max_examples=MAX_EXAMPLES)
def test_hyp_batch_roundtrip(make_enc: Callable[[], tiktoken.Encoding], batch):
    """Property: encode_batch matches per-item encode, and decode_batch inverts it.

    CONSISTENCY FIX: bound the run with max_examples=MAX_EXAMPLES like the
    file's other hypothesis tests (see test_hyp_special_ordinary); without it
    this property test's example count was unbounded by the shared setting.
    """
    enc = make_enc()

    encoded = enc.encode_batch(batch)
    assert encoded == [enc.encode(t) for t in batch]
    decoded = enc.decode_batch(encoded)
    assert decoded == batch