Skip to content

Commit c45aebd

Browse files
🚨 Support updating template processors (#1652)
* current updates * simplify * set_item works, but `tokenizer._tokenizer.post_processor[1].single = ["$0", "</s>"]` does not ! * fix: `normalizers` deserialization and other refactoring * fix: `pre_tokenizer` deserialization * feat: add `__len__` implementation for `normalizer::PySequence` * feat: add `__setitem__` impl for `normalizers::PySequence` * feat: add `__setitem__` impl to `pre_tokenizer::PySequence` * feat: add `__setitem__` impl to `post_processor::PySequence` * test: add normalizer sequence setter check * refactor: allow unused `processors::setter` macro * test: add `__setitem__` test for processors & pretok * refactor: `unwrap` -> `PyException::new_err()?` * refactor: fmt * refactor: remove unnecessary `pub` * feat(bindings): add missing getters & setters for pretoks * feat(bindings): add missing getters & setters for processors * refactor(bindings): rewrite RwLock poison error msg * refactor: remove debug print * feat(bindings): add description as to why custom deser is needed * feat: make post proc sequence elements mutable * fix(binding): serialization --------- Co-authored-by: Luc Georges <[email protected]>
1 parent e7ed39d commit c45aebd

22 files changed

+1011
-179
lines changed

.github/workflows/python.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ jobs:
5555
steps:
5656
- name: Checkout repository
5757
uses: actions/checkout@v4
58-
5958

6059
- name: Install Rust
6160
uses: actions-rs/toolchain@v1

bindings/python/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ serde = { version = "1.0", features = ["rc", "derive"] }
1414
serde_json = "1.0"
1515
libc = "0.2"
1616
env_logger = "0.11"
17-
pyo3 = { version = "0.23", features = ["abi3", "abi3-py39"] }
17+
pyo3 = { version = "0.23", features = ["abi3", "abi3-py39", "py-clone"] }
1818
numpy = "0.23"
1919
ndarray = "0.16"
2020
itertools = "0.12"

bindings/python/pyproject.toml

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
name = 'tokenizers'
33
requires-python = '>=3.9'
44
authors = [
5-
{name = 'Nicolas Patry', email = '[email protected]'},
6-
{name = 'Anthony Moi', email = '[email protected]'}
5+
{ name = 'Nicolas Patry', email = '[email protected]' },
6+
{ name = 'Anthony Moi', email = '[email protected]' },
77
]
88
classifiers = [
99
"Development Status :: 5 - Production/Stable",
@@ -21,12 +21,7 @@ classifiers = [
2121
"Topic :: Scientific/Engineering :: Artificial Intelligence",
2222
]
2323
keywords = ["NLP", "tokenizer", "BPE", "transformer", "deep learning"]
24-
dynamic = [
25-
'description',
26-
'license',
27-
'readme',
28-
'version',
29-
]
24+
dynamic = ['description', 'license', 'readme', 'version']
3025
dependencies = ["huggingface_hub>=0.16.4,<1.0"]
3126

3227
[project.urls]
@@ -58,16 +53,16 @@ target-version = ['py35']
5853
line-length = 119
5954
target-version = "py311"
6055
lint.ignore = [
61-
# a == None in tests vs is None.
62-
"E711",
63-
# a == False in tests vs is False.
64-
"E712",
65-
# try.. import except.. pattern without using the lib.
66-
"F401",
67-
# Raw type equality is required in asserts
68-
"E721",
69-
# Import order
70-
"E402",
71-
# Fixtures unused import
72-
"F811",
56+
# a == None in tests vs is None.
57+
"E711",
58+
# a == False in tests vs is False.
59+
"E712",
60+
# try.. import except.. pattern without using the lib.
61+
"F401",
62+
# Raw type equality is required in asserts
63+
"E721",
64+
# Import order
65+
"E402",
66+
# Fixtures unused import
67+
"F811",
7368
]

bindings/python/src/normalizers.rs

Lines changed: 116 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use pyo3::exceptions::PyException;
12
use pyo3::types::*;
23
use pyo3::{exceptions, prelude::*};
34
use std::sync::{Arc, RwLock};
@@ -41,7 +42,7 @@ impl PyNormalizedStringMut<'_> {
4142
/// This class is not supposed to be instantiated directly. Instead, any implementation of a
4243
/// Normalizer will return an instance of this class when instantiated.
4344
#[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)]
44-
#[derive(Clone, Serialize, Deserialize)]
45+
#[derive(Clone, Debug, Serialize, Deserialize)]
4546
#[serde(transparent)]
4647
pub struct PyNormalizer {
4748
pub(crate) normalizer: PyNormalizerTypeWrapper,
@@ -58,7 +59,11 @@ impl PyNormalizer {
5859
.into_pyobject(py)?
5960
.into_any()
6061
.into(),
61-
PyNormalizerTypeWrapper::Single(ref inner) => match &*inner.as_ref().read().unwrap() {
62+
PyNormalizerTypeWrapper::Single(ref inner) => match &*inner
63+
.as_ref()
64+
.read()
65+
.map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))?
66+
{
6267
PyNormalizerWrapper::Custom(_) => {
6368
Py::new(py, base)?.into_pyobject(py)?.into_any().into()
6469
}
@@ -218,7 +223,9 @@ macro_rules! getter {
218223
($self: ident, $variant: ident, $name: ident) => {{
219224
let super_ = $self.as_ref();
220225
if let PyNormalizerTypeWrapper::Single(ref norm) = super_.normalizer {
221-
let wrapper = norm.read().unwrap();
226+
let wrapper = norm.read().expect(
227+
"RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer",
228+
);
222229
if let PyNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = (&*wrapper) {
223230
o.$name.clone()
224231
} else {
@@ -234,7 +241,9 @@ macro_rules! setter {
234241
($self: ident, $variant: ident, $name: ident, $value: expr) => {{
235242
let super_ = $self.as_ref();
236243
if let PyNormalizerTypeWrapper::Single(ref norm) = super_.normalizer {
237-
let mut wrapper = norm.write().unwrap();
244+
let mut wrapper = norm.write().expect(
245+
"RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer",
246+
);
238247
if let PyNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(ref mut o)) = *wrapper {
239248
o.$name = $value;
240249
}
@@ -410,25 +419,55 @@ impl PySequence {
410419
PyTuple::new(py, [PyList::empty(py)])
411420
}
412421

413-
fn __len__(&self) -> usize {
414-
0
422+
fn __len__(self_: PyRef<'_, Self>) -> usize {
423+
match &self_.as_ref().normalizer {
424+
PyNormalizerTypeWrapper::Sequence(inner) => inner.len(),
425+
PyNormalizerTypeWrapper::Single(_) => 1,
426+
}
415427
}
416428

417429
fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult<Py<PyAny>> {
418430
match &self_.as_ref().normalizer {
419431
PyNormalizerTypeWrapper::Sequence(inner) => match inner.get(index) {
420-
Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(item)))
432+
Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(item.clone()))
421433
.get_as_subtype(py),
422434
_ => Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
423435
"Index not found",
424436
)),
425437
},
426438
PyNormalizerTypeWrapper::Single(inner) => {
427-
PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(inner)))
428-
.get_as_subtype(py)
439+
PyNormalizer::new(PyNormalizerTypeWrapper::Single(inner.clone())).get_as_subtype(py)
429440
}
430441
}
431442
}
443+
444+
fn __setitem__(self_: PyRef<'_, Self>, index: usize, value: Bound<'_, PyAny>) -> PyResult<()> {
445+
let norm: PyNormalizer = value.extract()?;
446+
let PyNormalizerTypeWrapper::Single(norm) = norm.normalizer else {
447+
return Err(PyException::new_err("normalizer should not be a sequence"));
448+
};
449+
match &self_.as_ref().normalizer {
450+
PyNormalizerTypeWrapper::Sequence(inner) => match inner.get(index) {
451+
Some(item) => {
452+
*item
453+
.write()
454+
.map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))? = norm
455+
.read()
456+
.map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))?
457+
.clone();
458+
}
459+
_ => {
460+
return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
461+
"Index not found",
462+
))
463+
}
464+
},
465+
PyNormalizerTypeWrapper::Single(_) => {
466+
return Err(PyException::new_err("normalizer is not a sequence"))
467+
}
468+
};
469+
Ok(())
470+
}
432471
}
433472

434473
/// Lowercase Normalizer
@@ -570,9 +609,31 @@ impl PyReplace {
570609
ToPyResult(Replace::new(pattern, content)).into_py()?.into(),
571610
))
572611
}
612+
613+
#[getter]
614+
fn get_pattern(_self: PyRef<Self>) -> PyResult<()> {
615+
Err(PyException::new_err("Cannot get pattern"))
616+
}
617+
618+
#[setter]
619+
fn set_pattern(_self: PyRef<Self>, _pattern: PyPattern) -> PyResult<()> {
620+
Err(PyException::new_err(
621+
"Cannot set pattern, please instantiate a new replace pattern instead",
622+
))
623+
}
624+
625+
#[getter]
626+
fn get_content(self_: PyRef<Self>) -> String {
627+
getter!(self_, Replace, content)
628+
}
629+
630+
#[setter]
631+
fn set_content(self_: PyRef<Self>, content: String) {
632+
setter!(self_, Replace, content, content)
633+
}
573634
}
574635

575-
#[derive(Debug)]
636+
#[derive(Clone, Debug)]
576637
pub(crate) struct CustomNormalizer {
577638
inner: PyObject,
578639
}
@@ -615,7 +676,7 @@ impl<'de> Deserialize<'de> for CustomNormalizer {
615676
}
616677
}
617678

618-
#[derive(Debug, Deserialize)]
679+
#[derive(Clone, Debug, Deserialize)]
619680
#[serde(untagged)]
620681
pub(crate) enum PyNormalizerWrapper {
621682
Custom(CustomNormalizer),
@@ -634,13 +695,27 @@ impl Serialize for PyNormalizerWrapper {
634695
}
635696
}
636697

637-
#[derive(Debug, Clone, Deserialize)]
638-
#[serde(untagged)]
698+
#[derive(Debug, Clone)]
639699
pub(crate) enum PyNormalizerTypeWrapper {
640700
Sequence(Vec<Arc<RwLock<PyNormalizerWrapper>>>),
641701
Single(Arc<RwLock<PyNormalizerWrapper>>),
642702
}
643703

704+
/// XXX: we need to manually implement deserialize here because of the structure of the
705+
/// PyNormalizerTypeWrapper enum. Given the underlying PyNormalizerWrapper can contain a Sequence,
706+
/// default deserialization will give us a PyNormalizerTypeWrapper::Single(Sequence) when we'd like
707+
/// it to be PyNormalizerTypeWrapper::Sequence(// ...).
708+
impl<'de> Deserialize<'de> for PyNormalizerTypeWrapper {
709+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
710+
where
711+
D: Deserializer<'de>,
712+
{
713+
let wrapper = NormalizerWrapper::deserialize(deserializer)?;
714+
let py_wrapper: PyNormalizerWrapper = wrapper.into();
715+
Ok(py_wrapper.into())
716+
}
717+
}
718+
644719
impl Serialize for PyNormalizerTypeWrapper {
645720
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
646721
where
@@ -672,7 +747,17 @@ where
672747
I: Into<PyNormalizerWrapper>,
673748
{
674749
fn from(norm: I) -> Self {
675-
PyNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm.into())))
750+
let norm = norm.into();
751+
match norm {
752+
PyNormalizerWrapper::Wrapped(NormalizerWrapper::Sequence(seq)) => {
753+
PyNormalizerTypeWrapper::Sequence(
754+
seq.into_iter()
755+
.map(|e| Arc::new(RwLock::new(PyNormalizerWrapper::Wrapped(e.clone()))))
756+
.collect(),
757+
)
758+
}
759+
_ => PyNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm))),
760+
}
676761
}
677762
}
678763

@@ -690,10 +775,15 @@ where
690775
impl Normalizer for PyNormalizerTypeWrapper {
691776
fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
692777
match self {
693-
PyNormalizerTypeWrapper::Single(inner) => inner.read().unwrap().normalize(normalized),
694-
PyNormalizerTypeWrapper::Sequence(inner) => inner
695-
.iter()
696-
.try_for_each(|n| n.read().unwrap().normalize(normalized)),
778+
PyNormalizerTypeWrapper::Single(inner) => inner
779+
.read()
780+
.map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))?
781+
.normalize(normalized),
782+
PyNormalizerTypeWrapper::Sequence(inner) => inner.iter().try_for_each(|n| {
783+
n.read()
784+
.map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))?
785+
.normalize(normalized)
786+
}),
697787
}
698788
}
699789
}
@@ -793,18 +883,14 @@ mod test {
793883
let normalizer: PyNormalizer = serde_json::from_str(&sequence_string).unwrap();
794884

795885
match normalizer.normalizer {
796-
PyNormalizerTypeWrapper::Single(inner) => match &*inner.as_ref().read().unwrap() {
797-
PyNormalizerWrapper::Wrapped(NormalizerWrapper::Sequence(sequence)) => {
798-
let normalizers = sequence.get_normalizers();
799-
assert_eq!(normalizers.len(), 1);
800-
match normalizers[0] {
801-
NormalizerWrapper::NFKC(_) => {}
802-
_ => panic!("Expected NFKC"),
803-
}
804-
}
805-
_ => panic!("Expected sequence"),
806-
},
807-
_ => panic!("Expected single"),
886+
PyNormalizerTypeWrapper::Sequence(inner) => {
887+
assert_eq!(inner.len(), 1);
888+
match *inner[0].as_ref().read().unwrap() {
889+
PyNormalizerWrapper::Wrapped(NormalizerWrapper::NFKC(_)) => {}
890+
_ => panic!("Expected NFKC"),
891+
};
892+
}
893+
_ => panic!("Expected sequence"),
808894
};
809895
}
810896
}

0 commit comments

Comments
 (0)