Commit 0ff2ab0

Fixing the stream by removing the read_index altogether. (#1716)
* Fixing the stream by removing the read_index altogether.
* Moving the test location because.. Windows.
* Ok whatever.
* Rust 1.84
* Fmt.
1 parent 862d1a3 commit 0ff2ab0

File tree (8 files changed, +89 −132 lines changed):

  bindings/python/Makefile
  bindings/python/py_src/tokenizers/tools/visualizer.py
  bindings/python/src/decoders.rs
  tokenizers/Makefile
  tokenizers/src/models/bpe/word.rs
  tokenizers/src/processors/template.rs
  tokenizers/src/tokenizer/mod.rs
  tokenizers/tests/stream.rs

bindings/python/Makefile

Lines changed: 2 additions & 2 deletions
@@ -14,8 +14,8 @@ style:
 # Check the source code is formatted correctly
 check-style:
 	python stub.py --check
-	ruff check examples py_src/tokenizers tests
-	ruff format --check examples py_src/tokenizers tests
+	ruff check $(check_dirs)
+	ruff format --check $(check_dirs)
 
 TESTS_RESOURCES = $(DATA_DIR)/small.txt $(DATA_DIR)/roberta.json
 
bindings/python/py_src/tokenizers/tools/visualizer.py

Lines changed: 1 addition & 1 deletion
@@ -241,7 +241,7 @@ def consecutive_chars_to_html(
         # In this case we are looking at a group/single char that is not tokenized.
         # e.g. white space
         css_classes.append("non-token")
-    css = f'''class="{' '.join(css_classes)}"'''
+    css = f'''class="{" ".join(css_classes)}"'''
     data = ""
     for key, val in data_items.items():
         data += f' data-{key}="{val}"'

bindings/python/src/decoders.rs

Lines changed: 0 additions & 7 deletions
@@ -646,11 +646,6 @@ pub struct PyDecodeStream {
     /// The index within the ids corresponding to the prefix so we can drain
     /// correctly
     prefix_index: usize,
-    /// We need to keep 2 prefixes.
-    /// Prefix is the second one that was already emitted to discard the part
-    /// of the text of all the ids
-    /// read is the prefix kept only for starting side effects of the prefix
-    read_index: usize,
 }
 
 #[pymethods]
@@ -663,7 +658,6 @@ impl PyDecodeStream {
             ids: vec![],
             prefix: "".to_string(),
             prefix_index: 0,
-            read_index: 0,
         }
     }
 
@@ -676,7 +670,6 @@ impl PyDecodeStream {
             &mut self.ids,
             &mut self.prefix,
             &mut self.prefix_index,
-            &mut self.read_index,
         ))
         .into()
     }

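Note on the change above: with read_index gone, the per-stream state on the Python side is just the skip_special_tokens flag plus the three fields that survive the diff, and step forwards them to the shared step_decode_stream helper. A minimal sketch of that shape, for orientation only (the real type is the #[pyclass] PyDecodeStream; pyo3 attributes and the PyTokenizer plumbing are omitted):

// Illustration of the remaining stream state; not the actual definition.
struct StreamStateSketch {
    // Whether special tokens are skipped when decoding.
    skip_special_tokens: bool,
    // All ids that still take part in decoding.
    ids: Vec<u32>,
    // Text already emitted for the prefix of `ids`.
    prefix: String,
    // Where that prefix ends inside `ids`, so it can be drained correctly.
    prefix_index: usize,
}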
tokenizers/Makefile

Lines changed: 5 additions & 1 deletion
@@ -4,7 +4,7 @@ TESTS_DIR = tests
 
 dir_guard=@mkdir -p $(@D)
 
-SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt $(DATA_DIR)/albert-base-v1-tokenizer.json
+SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt $(DATA_DIR)/albert-base-v1-tokenizer.json $(DATA_DIR)/llama-3-tokenizer.json
 BENCHMARK_RESOURCES = $(SHARED_RESOURCES)
 TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt $(DATA_DIR)/roberta.json $(DATA_DIR)/tokenizer-wiki.json $(DATA_DIR)/bert-wiki.json
 
@@ -79,3 +79,7 @@ $(DATA_DIR)/tokenizer-wiki.json :
 $(DATA_DIR)/bert-wiki.json :
 	$(dir_guard)
 	wget https://s3.amazonaws.com/models.huggingface.co/bert/anthony/doc-pipeline/tokenizer.json -O $@
+
+$(DATA_DIR)/llama-3-tokenizer.json :
+	$(dir_guard)
+	wget https://huggingface.co/hf-internal-testing/llama3-tokenizer/resolve/main/tokenizer.json -O $@

tokenizers/src/models/bpe/word.rs

Lines changed: 2 additions & 2 deletions
@@ -199,9 +199,9 @@ impl Word {
 
             // Make sure we are not processing an expired queue entry
             let target_new_pair = (self.symbols[top.pos].c, right.c);
-            if !merges
+            if merges
                 .get(&target_new_pair)
-                .map_or(false, |(_, new_id)| *new_id == top.new_id)
+                .is_none_or(|(_, new_id)| *new_id != top.new_id)
             {
                 continue;
             }

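The word.rs change is the rewrite newer Clippy (the commit message cites Rust 1.84) suggests: a negated Option::map_or(false, …) becomes Option::is_none_or with the predicate flipped (is_none_or is stable since Rust 1.82). A small self-contained illustration with made-up values:

fn main() {
    // Made-up (rank, new_id) pair standing in for a merge-table lookup.
    let merge: Option<(u32, u32)> = Some((7, 42));
    let top_new_id = 42;

    // Old form: negate map_or.
    let expired_old = !merge.map_or(false, |(_, new_id)| new_id == top_new_id);
    // New form: is_none_or with the negated predicate; same truth table.
    let expired_new = merge.is_none_or(|(_, new_id)| new_id != top_new_id);

    assert_eq!(expired_old, expired_new);
}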
tokenizers/src/processors/template.rs

Lines changed: 1 addition & 1 deletion
@@ -441,7 +441,7 @@ impl TemplateProcessingBuilder {
             let exist = self
                 .special_tokens
                 .as_ref()
-                .map_or(false, |map| map.0.contains_key(sp));
+                .is_some_and(|map| map.0.contains_key(sp));
 
             match exist {
                 false => Some(sp),

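Same family of cleanup as above, in the positive direction: Option::map_or(false, p) is Option::is_some_and(p) (stable since Rust 1.70). A standalone sketch using a hypothetical special-token map, not the builder's actual type:

use std::collections::HashMap;

fn main() {
    // Hypothetical stand-in for the builder's optional special-token map.
    let special_tokens: Option<HashMap<String, u32>> =
        Some(HashMap::from([("[CLS]".to_string(), 101)]));
    let sp = "[CLS]";

    // The two forms are equivalent; the new one reads as a plain predicate.
    let exist_old = special_tokens.as_ref().map_or(false, |map| map.contains_key(sp));
    let exist_new = special_tokens.as_ref().is_some_and(|map| map.contains_key(sp));
    assert_eq!(exist_old, exist_new);
}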
tokenizers/src/tokenizer/mod.rs

Lines changed: 0 additions & 118 deletions
@@ -1035,11 +1035,6 @@ pub struct DecodeStream<'tok, M, N, PT, PP, D> {
     /// The index within the ids corresponding to the prefix so we can drain
     /// correctly
     prefix_index: usize,
-    /// We need to keep 2 prefixes.
-    /// Prefix is the second one that was already emitted to discard the part
-    /// of the text of all the ids
-    /// read is the prefix kept only for starting side effects of the prefix
-    read_index: usize,
 }
 
 #[derive(thiserror::Error, Debug)]
@@ -1063,7 +1058,6 @@
             skip_special_tokens,
             prefix: "".to_string(),
             prefix_index: 0,
-            read_index: 0,
         }
     }
 
@@ -1076,7 +1070,6 @@
             &mut self.ids,
             &mut self.prefix,
             &mut self.prefix_index,
-            &mut self.read_index,
         )
     }
 }
@@ -1089,7 +1082,6 @@ pub fn step_decode_stream<M, N, PT, PP, D>(
     ids: &mut Vec<u32>,
     prefix: &mut String,
     prefix_index: &mut usize,
-    read_index: &mut usize,
 ) -> Result<Option<String>>
 where
     M: Model,
@@ -1108,7 +1100,6 @@ where
         let new_prefix_index = ids.len() - *prefix_index;
         *ids = ids.drain(*prefix_index..).collect();
         *prefix = tokenizer.decode(ids, skip_special_tokens)?;
-        *read_index = *prefix_index;
         *prefix_index = new_prefix_index;
         Ok(Some(new_text.to_string()))
     } else {
@@ -1563,112 +1554,3 @@ where
         Ok(())
     }
 }
-
-#[cfg(test)]
-mod test {
-    #[cfg(feature = "http")]
-    #[test]
-    fn test_decoding_with_added_bpe() {
-        use crate::{
-            normalizers,
-            pre_tokenizers::split::{Split, SplitPattern},
-            AddedToken, NormalizerWrapper, PreTokenizerWrapper, SplitDelimiterBehavior, Tokenizer,
-        };
-
-        let mut tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();
-        tokenizer.normalizer = Some(NormalizerWrapper::from(normalizers::ByteLevel::new()));
-        tokenizer.pre_tokenizer = Some(PreTokenizerWrapper::Split(
-            Split::new(
-                SplitPattern::Regex(r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+".into()),
-                SplitDelimiterBehavior::Isolated,
-                false,
-            )
-            .unwrap(),
-        ));
-        tokenizer.add_tokens(&[AddedToken::from("嗎", false).normalized(false)]);
-        let encoded = tokenizer
-            .encode("Hey! how is this token: 嗎", false)
-            .unwrap();
-        assert_eq!(
-            encoded.get_ids(),
-            [19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128256]
-        );
-        assert_eq!(
-            encoded.get_tokens(),
-            ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "嗎"]
-        );
-
-        let decoded = tokenizer.decode(encoded.get_ids(), false);
-        assert_eq!(decoded.unwrap(), "Hey! how is this token: 嗎");
-
-        tokenizer.add_tokens(&[AddedToken::from("д", false).normalized(true)]);
-        let encoded = tokenizer
-            .encode("Hey! how is this token: д", false)
-            .unwrap();
-        assert_eq!(
-            encoded.get_ids(),
-            [19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128257]
-        );
-        assert_eq!(
-            encoded.get_tokens(),
-            ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "д"]
-        );
-        let decoded = tokenizer.decode(encoded.get_ids(), false);
-        assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
-    }
-
-    #[cfg(feature = "http")]
-    #[test]
-    fn test_decode_stream_step_no_panic() {
-        use std::panic;
-
-        use crate::Tokenizer;
-
-        let tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();
-
-        // "A B C D E F G H I J"
-        let mut decode_stream = tokenizer.decode_stream(false);
-        let output_tokens = vec![32, 426, 356, 423, 469, 435, 480, 473, 358, 622];
-        let expected_outputs = vec![
-            Some("A".to_string()),
-            Some(" B".to_string()),
-            Some(" C".to_string()),
-            Some(" D".to_string()),
-            Some(" E".to_string()),
-            Some(" F".to_string()),
-            Some(" G".to_string()),
-            Some(" H".to_string()),
-            Some(" I".to_string()),
-            Some(" J".to_string()),
-        ];
-        for (i, &token) in output_tokens.iter().enumerate() {
-            let maybe_panic =
-                panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
-            assert!(maybe_panic.is_ok());
-            let result = maybe_panic.unwrap();
-            assert!(result.is_ok());
-            assert_eq!(result.unwrap(), expected_outputs[i]);
-        }
-
-        // "삥뽕빵" (Korean words composed of 2-3 tokens: [80690, 98], [167, 121, 243], and [102457, 113])
-        let mut decode_stream = tokenizer.decode_stream(false);
-        let output_tokens = vec![80690, 98, 167, 121, 243, 102457, 113];
-        let expected_outputs = vec![
-            None,
-            Some("삥".to_string()),
-            None,
-            None,
-            Some("뽕".to_string()),
-            None,
-            Some("빵".to_string()),
-        ];
-        for (i, &token) in output_tokens.iter().enumerate() {
-            let maybe_panic =
-                panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
-            assert!(maybe_panic.is_ok());
-            let result = maybe_panic.unwrap();
-            assert!(result.is_ok());
-            assert_eq!(result.unwrap(), expected_outputs[i]);
-        }
-    }
-}

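What remains of the streaming logic after this commit can be summarized in one function. The following is a simplified sketch of step_decode_stream, not the exact implementation: generics, error handling and the prefix sanity check are dropped, and `decode` stands in for Tokenizer::decode. It illustrates why prefix_index alone is enough to slide the window; in the removed code above, read_index was only ever written, never read.

// Sketch only; assumes `decode` behaves like Tokenizer::decode on a slice of ids.
fn step_sketch(
    decode: impl Fn(&[u32]) -> String,
    id: u32,
    ids: &mut Vec<u32>,
    prefix: &mut String,
    prefix_index: &mut usize,
) -> Option<String> {
    ids.push(id);
    let new_text = decode(ids.as_slice());
    // Emit only once the decoded text grows past the already-emitted prefix and
    // does not end in an incomplete character (rendered as U+FFFD).
    if new_text.len() > prefix.len() && !new_text.ends_with('\u{FFFD}') {
        let emitted = new_text[prefix.len()..].to_string();
        // Slide the window: drop the ids already covered by the old prefix and
        // recompute the prefix over the ids that remain.
        let new_prefix_index = ids.len() - *prefix_index;
        *ids = ids.drain(*prefix_index..).collect();
        *prefix = decode(ids.as_slice());
        *prefix_index = new_prefix_index;
        Some(emitted)
    } else {
        None
    }
}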
tokenizers/tests/stream.rs

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+use tokenizers::{
+    normalizers,
+    pre_tokenizers::split::{Split, SplitPattern},
+    AddedToken, NormalizerWrapper, PreTokenizerWrapper, SplitDelimiterBehavior, Tokenizer,
+};
+
+#[test]
+fn test_decoding_with_added_bpe() {
+    let mut tokenizer = Tokenizer::from_file("data/llama-3-tokenizer.json").unwrap();
+    tokenizer.with_normalizer(Some(NormalizerWrapper::from(normalizers::ByteLevel::new())));
+    tokenizer.with_pre_tokenizer(Some(PreTokenizerWrapper::Split(
+        Split::new(
+            SplitPattern::Regex(r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+".into()),
+            SplitDelimiterBehavior::Isolated,
+            false,
+        )
+        .unwrap(),
+    )));
+    tokenizer.add_tokens(&[AddedToken::from("嗎", false).normalized(false)]);
+    let encoded = tokenizer
+        .encode("Hey! how is this token: 嗎", false)
+        .unwrap();
+    assert_eq!(
+        encoded.get_ids(),
+        [19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128256]
+    );
+    assert_eq!(
+        encoded.get_tokens(),
+        ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "嗎"]
+    );
+
+    let decoded = tokenizer.decode(encoded.get_ids(), false);
+    assert_eq!(decoded.unwrap(), "Hey! how is this token: 嗎");
+
+    tokenizer.add_tokens(&[AddedToken::from("д", false).normalized(true)]);
+    let encoded = tokenizer
+        .encode("Hey! how is this token: д", false)
+        .unwrap();
+    assert_eq!(
+        encoded.get_ids(),
+        [19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128257]
+    );
+    assert_eq!(
+        encoded.get_tokens(),
+        ["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "д"]
+    );
+    let decoded = tokenizer.decode(encoded.get_ids(), false);
+    assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
+}
+
+#[test]
+fn test_decode_stream_step_no_panic() {
+    let tokenizer = Tokenizer::from_file("data/llama-3-tokenizer.json").unwrap();
+
+    // "A B C D E F G H I J"
+    let mut decode_stream = tokenizer.decode_stream(false);
+    assert_eq!(decode_stream.step(32).unwrap(), Some("A".to_string()));
+    assert_eq!(decode_stream.step(426).unwrap(), Some(" B".to_string()));
+    assert_eq!(decode_stream.step(356).unwrap(), Some(" C".to_string()));
+    assert_eq!(decode_stream.step(423).unwrap(), Some(" D".to_string()));
+    assert_eq!(decode_stream.step(469).unwrap(), Some(" E".to_string()));
+    assert_eq!(decode_stream.step(435).unwrap(), Some(" F".to_string()));
+    assert_eq!(decode_stream.step(480).unwrap(), Some(" G".to_string()));
+    assert_eq!(decode_stream.step(473).unwrap(), Some(" H".to_string()));
+    assert_eq!(decode_stream.step(358).unwrap(), Some(" I".to_string()));
+    assert_eq!(decode_stream.step(622).unwrap(), Some(" J".to_string()));
+    // for (i, &token) in output_tokens.iter().enumerate() {}
+
+    // "삥뽕빵" (Korean words composed of 2-3 tokens: [80690, 98], [167, 121, 243], and [102457, 113])
+    let mut decode_stream = tokenizer.decode_stream(false);
+    assert_eq!(decode_stream.step(80690).unwrap(), None);
+    assert_eq!(decode_stream.step(98).unwrap(), Some("삥".to_string()));
+    assert_eq!(decode_stream.step(167).unwrap(), None);
+    assert_eq!(decode_stream.step(121).unwrap(), None);
+    assert_eq!(decode_stream.step(243).unwrap(), Some("뽕".to_string()));
+    assert_eq!(decode_stream.step(102457).unwrap(), None);
+    assert_eq!(decode_stream.step(113).unwrap(), Some("빵".to_string()));
+}

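For reference, the consumer this API targets looks like the following: feed model-generated ids one at a time and print text as soon as it is unambiguous. A sketch assuming the data/llama-3-tokenizer.json file from the Makefile rule above has been downloaded; the ids are the first few from the test's "Hey! how is…" encoding:

use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let tokenizer = Tokenizer::from_file("data/llama-3-tokenizer.json")?;
    let mut stream = tokenizer.decode_stream(false);

    // Ids would normally arrive one by one from a generation loop.
    for id in [19182u32, 0, 1268, 602, 82] {
        // `step` returns Ok(None) while the pending ids do not yet form
        // printable text (e.g. an incomplete multi-byte character).
        if let Some(chunk) = stream.step(id)? {
            print!("{chunk}");
        }
    }
    println!();
    Ok(())
}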