Skip to content

Commit

Permalink
Add optional eos_token parameter to LLTokenizer initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Feb 4, 2025
1 parent 6e1ec11 commit b80ccb8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
1 change: 1 addition & 0 deletions python/llguidance/_lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class LLTokenizer:
cls,
tokenizer: Union[str, TokenizerWrapper],
n_vocab: Optional[int] = None,
eos_token: Optional[TokenId] = None,
slices: Optional[List[str]] = None,
) -> "LLTokenizer":
"""
Expand Down
23 changes: 14 additions & 9 deletions python_ext/src/py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ impl LLTokenizer {
fn py_new(
tokenizer: Bound<'_, PyAny>,
n_vocab: Option<usize>,
eos_token: Option<u32>,
slices: Option<Vec<String>>,
) -> PyResult<Self> {
let tok_env: TokEnv = if let Some(tokenizer_str) = tokenizer.extract::<String>().ok() {
Expand All @@ -224,15 +225,19 @@ impl LLTokenizer {
"</s>",
"<|endoftext|>",
];
let eos_token = candidates
.iter()
.filter_map(|s| trie.get_special_token(s))
.next()
.ok_or_else(|| {
PyValueError::new_err(format!(
"Expecting a tokenizer with an EOS token, but none was found"
))
})?;
let eos_token = if let Some(eos_token) = eos_token {
eos_token
} else {
candidates
.iter()
.filter_map(|s| trie.get_special_token(s))
.next()
.ok_or_else(|| {
PyValueError::new_err(format!(
"Expecting a tokenizer with an EOS token, but none was found"
))
})?
};
let trie = trie.with_eos_token(eos_token);
Arc::new(ApproximateTokEnv::new(trie))
} else {
Expand Down

0 comments on commit b80ccb8

Please sign in to comment.