1
+ use pyo3:: exceptions:: PyException ;
1
2
use pyo3:: types:: * ;
2
3
use pyo3:: { exceptions, prelude:: * } ;
3
4
use std:: sync:: { Arc , RwLock } ;
@@ -41,7 +42,7 @@ impl PyNormalizedStringMut<'_> {
41
42
/// This class is not supposed to be instantiated directly. Instead, any implementation of a
42
43
/// Normalizer will return an instance of this class when instantiated.
43
44
#[ pyclass( dict, module = "tokenizers.normalizers" , name = "Normalizer" , subclass) ]
44
- #[ derive( Clone , Serialize , Deserialize ) ]
45
+ #[ derive( Clone , Debug , Serialize , Deserialize ) ]
45
46
#[ serde( transparent) ]
46
47
pub struct PyNormalizer {
47
48
pub ( crate ) normalizer : PyNormalizerTypeWrapper ,
@@ -58,7 +59,11 @@ impl PyNormalizer {
58
59
. into_pyobject ( py) ?
59
60
. into_any ( )
60
61
. into ( ) ,
61
- PyNormalizerTypeWrapper :: Single ( ref inner) => match & * inner. as_ref ( ) . read ( ) . unwrap ( ) {
62
+ PyNormalizerTypeWrapper :: Single ( ref inner) => match & * inner
63
+ . as_ref ( )
64
+ . read ( )
65
+ . map_err ( |_| PyException :: new_err ( "RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer" ) ) ?
66
+ {
62
67
PyNormalizerWrapper :: Custom ( _) => {
63
68
Py :: new ( py, base) ?. into_pyobject ( py) ?. into_any ( ) . into ( )
64
69
}
@@ -218,7 +223,9 @@ macro_rules! getter {
218
223
( $self: ident, $variant: ident, $name: ident) => { {
219
224
let super_ = $self. as_ref( ) ;
220
225
if let PyNormalizerTypeWrapper :: Single ( ref norm) = super_. normalizer {
221
- let wrapper = norm. read( ) . unwrap( ) ;
226
+ let wrapper = norm. read( ) . expect(
227
+ "RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer" ,
228
+ ) ;
222
229
if let PyNormalizerWrapper :: Wrapped ( NormalizerWrapper :: $variant( o) ) = ( & * wrapper) {
223
230
o. $name. clone( )
224
231
} else {
@@ -234,7 +241,9 @@ macro_rules! setter {
234
241
( $self: ident, $variant: ident, $name: ident, $value: expr) => { {
235
242
let super_ = $self. as_ref( ) ;
236
243
if let PyNormalizerTypeWrapper :: Single ( ref norm) = super_. normalizer {
237
- let mut wrapper = norm. write( ) . unwrap( ) ;
244
+ let mut wrapper = norm. write( ) . expect(
245
+ "RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer" ,
246
+ ) ;
238
247
if let PyNormalizerWrapper :: Wrapped ( NormalizerWrapper :: $variant( ref mut o) ) = * wrapper {
239
248
o. $name = $value;
240
249
}
@@ -410,25 +419,55 @@ impl PySequence {
410
419
PyTuple :: new ( py, [ PyList :: empty ( py) ] )
411
420
}
412
421
413
- fn __len__ ( & self ) -> usize {
414
- 0
422
+ fn __len__ ( self_ : PyRef < ' _ , Self > ) -> usize {
423
+ match & self_. as_ref ( ) . normalizer {
424
+ PyNormalizerTypeWrapper :: Sequence ( inner) => inner. len ( ) ,
425
+ PyNormalizerTypeWrapper :: Single ( _) => 1 ,
426
+ }
415
427
}
416
428
417
429
fn __getitem__ ( self_ : PyRef < ' _ , Self > , py : Python < ' _ > , index : usize ) -> PyResult < Py < PyAny > > {
418
430
match & self_. as_ref ( ) . normalizer {
419
431
PyNormalizerTypeWrapper :: Sequence ( inner) => match inner. get ( index) {
420
- Some ( item) => PyNormalizer :: new ( PyNormalizerTypeWrapper :: Single ( Arc :: clone ( item ) ) )
432
+ Some ( item) => PyNormalizer :: new ( PyNormalizerTypeWrapper :: Single ( item . clone ( ) ) )
421
433
. get_as_subtype ( py) ,
422
434
_ => Err ( PyErr :: new :: < pyo3:: exceptions:: PyIndexError , _ > (
423
435
"Index not found" ,
424
436
) ) ,
425
437
} ,
426
438
PyNormalizerTypeWrapper :: Single ( inner) => {
427
- PyNormalizer :: new ( PyNormalizerTypeWrapper :: Single ( Arc :: clone ( inner) ) )
428
- . get_as_subtype ( py)
439
+ PyNormalizer :: new ( PyNormalizerTypeWrapper :: Single ( inner. clone ( ) ) ) . get_as_subtype ( py)
429
440
}
430
441
}
431
442
}
443
+
444
+ fn __setitem__ ( self_ : PyRef < ' _ , Self > , index : usize , value : Bound < ' _ , PyAny > ) -> PyResult < ( ) > {
445
+ let norm: PyNormalizer = value. extract ( ) ?;
446
+ let PyNormalizerTypeWrapper :: Single ( norm) = norm. normalizer else {
447
+ return Err ( PyException :: new_err ( "normalizer should not be a sequence" ) ) ;
448
+ } ;
449
+ match & self_. as_ref ( ) . normalizer {
450
+ PyNormalizerTypeWrapper :: Sequence ( inner) => match inner. get ( index) {
451
+ Some ( item) => {
452
+ * item
453
+ . write ( )
454
+ . map_err ( |_| PyException :: new_err ( "RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer" ) ) ? = norm
455
+ . read ( )
456
+ . map_err ( |_| PyException :: new_err ( "RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer" ) ) ?
457
+ . clone ( ) ;
458
+ }
459
+ _ => {
460
+ return Err ( PyErr :: new :: < pyo3:: exceptions:: PyIndexError , _ > (
461
+ "Index not found" ,
462
+ ) )
463
+ }
464
+ } ,
465
+ PyNormalizerTypeWrapper :: Single ( _) => {
466
+ return Err ( PyException :: new_err ( "normalizer is not a sequence" ) )
467
+ }
468
+ } ;
469
+ Ok ( ( ) )
470
+ }
432
471
}
433
472
434
473
/// Lowercase Normalizer
@@ -570,9 +609,31 @@ impl PyReplace {
570
609
ToPyResult ( Replace :: new ( pattern, content) ) . into_py ( ) ?. into ( ) ,
571
610
) )
572
611
}
612
+
613
+ #[ getter]
614
+ fn get_pattern ( _self : PyRef < Self > ) -> PyResult < ( ) > {
615
+ Err ( PyException :: new_err ( "Cannot get pattern" ) )
616
+ }
617
+
618
+ #[ setter]
619
+ fn set_pattern ( _self : PyRef < Self > , _pattern : PyPattern ) -> PyResult < ( ) > {
620
+ Err ( PyException :: new_err (
621
+ "Cannot set pattern, please instantiate a new replace pattern instead" ,
622
+ ) )
623
+ }
624
+
625
+ #[ getter]
626
+ fn get_content ( self_ : PyRef < Self > ) -> String {
627
+ getter ! ( self_, Replace , content)
628
+ }
629
+
630
+ #[ setter]
631
+ fn set_content ( self_ : PyRef < Self > , content : String ) {
632
+ setter ! ( self_, Replace , content, content)
633
+ }
573
634
}
574
635
575
- #[ derive( Debug ) ]
636
+ #[ derive( Clone , Debug ) ]
576
637
pub ( crate ) struct CustomNormalizer {
577
638
inner : PyObject ,
578
639
}
@@ -615,7 +676,7 @@ impl<'de> Deserialize<'de> for CustomNormalizer {
615
676
}
616
677
}
617
678
618
- #[ derive( Debug , Deserialize ) ]
679
+ #[ derive( Clone , Debug , Deserialize ) ]
619
680
#[ serde( untagged) ]
620
681
pub ( crate ) enum PyNormalizerWrapper {
621
682
Custom ( CustomNormalizer ) ,
@@ -634,13 +695,27 @@ impl Serialize for PyNormalizerWrapper {
634
695
}
635
696
}
636
697
637
- #[ derive( Debug , Clone , Deserialize ) ]
638
- #[ serde( untagged) ]
698
+ #[ derive( Debug , Clone ) ]
639
699
pub ( crate ) enum PyNormalizerTypeWrapper {
640
700
Sequence ( Vec < Arc < RwLock < PyNormalizerWrapper > > > ) ,
641
701
Single ( Arc < RwLock < PyNormalizerWrapper > > ) ,
642
702
}
643
703
704
+ /// XXX: we need to manually implement deserialize here because of the structure of the
705
+ /// PyNormalizerTypeWrapper enum. Given the underlying PyNormalizerWrapper can contain a Sequence,
706
+ /// default deserialization will give us a PyNormalizerTypeWrapper::Single(Sequence) when we'd like
707
+ /// it to be PyNormalizerTypeWrapper::Sequence(// ...).
708
+ impl < ' de > Deserialize < ' de > for PyNormalizerTypeWrapper {
709
+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
710
+ where
711
+ D : Deserializer < ' de > ,
712
+ {
713
+ let wrapper = NormalizerWrapper :: deserialize ( deserializer) ?;
714
+ let py_wrapper: PyNormalizerWrapper = wrapper. into ( ) ;
715
+ Ok ( py_wrapper. into ( ) )
716
+ }
717
+ }
718
+
644
719
impl Serialize for PyNormalizerTypeWrapper {
645
720
fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
646
721
where
@@ -672,7 +747,17 @@ where
672
747
I : Into < PyNormalizerWrapper > ,
673
748
{
674
749
fn from ( norm : I ) -> Self {
675
- PyNormalizerTypeWrapper :: Single ( Arc :: new ( RwLock :: new ( norm. into ( ) ) ) )
750
+ let norm = norm. into ( ) ;
751
+ match norm {
752
+ PyNormalizerWrapper :: Wrapped ( NormalizerWrapper :: Sequence ( seq) ) => {
753
+ PyNormalizerTypeWrapper :: Sequence (
754
+ seq. into_iter ( )
755
+ . map ( |e| Arc :: new ( RwLock :: new ( PyNormalizerWrapper :: Wrapped ( e. clone ( ) ) ) ) )
756
+ . collect ( ) ,
757
+ )
758
+ }
759
+ _ => PyNormalizerTypeWrapper :: Single ( Arc :: new ( RwLock :: new ( norm) ) ) ,
760
+ }
676
761
}
677
762
}
678
763
@@ -690,10 +775,15 @@ where
690
775
impl Normalizer for PyNormalizerTypeWrapper {
691
776
fn normalize ( & self , normalized : & mut NormalizedString ) -> tk:: Result < ( ) > {
692
777
match self {
693
- PyNormalizerTypeWrapper :: Single ( inner) => inner. read ( ) . unwrap ( ) . normalize ( normalized) ,
694
- PyNormalizerTypeWrapper :: Sequence ( inner) => inner
695
- . iter ( )
696
- . try_for_each ( |n| n. read ( ) . unwrap ( ) . normalize ( normalized) ) ,
778
+ PyNormalizerTypeWrapper :: Single ( inner) => inner
779
+ . read ( )
780
+ . map_err ( |_| PyException :: new_err ( "RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer" ) ) ?
781
+ . normalize ( normalized) ,
782
+ PyNormalizerTypeWrapper :: Sequence ( inner) => inner. iter ( ) . try_for_each ( |n| {
783
+ n. read ( )
784
+ . map_err ( |_| PyException :: new_err ( "RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer" ) ) ?
785
+ . normalize ( normalized)
786
+ } ) ,
697
787
}
698
788
}
699
789
}
@@ -793,18 +883,14 @@ mod test {
793
883
let normalizer: PyNormalizer = serde_json:: from_str ( & sequence_string) . unwrap ( ) ;
794
884
795
885
match normalizer. normalizer {
796
- PyNormalizerTypeWrapper :: Single ( inner) => match & * inner. as_ref ( ) . read ( ) . unwrap ( ) {
797
- PyNormalizerWrapper :: Wrapped ( NormalizerWrapper :: Sequence ( sequence) ) => {
798
- let normalizers = sequence. get_normalizers ( ) ;
799
- assert_eq ! ( normalizers. len( ) , 1 ) ;
800
- match normalizers[ 0 ] {
801
- NormalizerWrapper :: NFKC ( _) => { }
802
- _ => panic ! ( "Expected NFKC" ) ,
803
- }
804
- }
805
- _ => panic ! ( "Expected sequence" ) ,
806
- } ,
807
- _ => panic ! ( "Expected single" ) ,
886
+ PyNormalizerTypeWrapper :: Sequence ( inner) => {
887
+ assert_eq ! ( inner. len( ) , 1 ) ;
888
+ match * inner[ 0 ] . as_ref ( ) . read ( ) . unwrap ( ) {
889
+ PyNormalizerWrapper :: Wrapped ( NormalizerWrapper :: NFKC ( _) ) => { }
890
+ _ => panic ! ( "Expected NFKC" ) ,
891
+ } ;
892
+ }
893
+ _ => panic ! ( "Expected sequence" ) ,
808
894
} ;
809
895
}
810
896
}
0 commit comments