@@ -406,6 +406,55 @@ struct PyClassEnumVariant<'a> {
406
406
/* currently have no more options */
407
407
}
408
408
409
+ struct PyClassEnum < ' a > {
410
+ ident : & ' a syn:: Ident ,
411
+ // The underlying #[repr] of the enum, used to implement __int__ and __richcmp__.
412
+ // This matters when the underlying representation may not fit in `isize`.
413
+ #[ allow( unused, dead_code) ]
414
+ repr_type : syn:: Ident ,
415
+ variants : Vec < PyClassEnumVariant < ' a > > ,
416
+ }
417
+
418
+ impl < ' a > PyClassEnum < ' a > {
419
+ fn new ( enum_ : & ' a syn:: ItemEnum ) -> syn:: Result < Self > {
420
+ fn is_numeric_type ( t : & syn:: Ident ) -> bool {
421
+ [
422
+ "u8" , "i8" , "u16" , "i16" , "u32" , "i32" , "u64" , "i64" , "u128" , "i128" , "usize" ,
423
+ "isize" ,
424
+ ]
425
+ . iter ( )
426
+ . any ( |& s| t == s)
427
+ }
428
+ let ident = & enum_. ident ;
429
+ // According to the [reference](https://doc.rust-lang.org/reference/items/enumerations.html),
430
+ // "Under the default representation, the specified discriminant is interpreted as an isize
431
+ // value", so `isize` should be enough by default.
432
+ let mut repr_type = syn:: Ident :: new ( "isize" , proc_macro2:: Span :: call_site ( ) ) ;
433
+ if let Some ( attr) = enum_. attrs . iter ( ) . find ( |attr| attr. path . is_ident ( "repr" ) ) {
434
+ let args =
435
+ attr. parse_args_with ( Punctuated :: < TokenStream , Token ! [ !] > :: parse_terminated) ?;
436
+ if let Some ( ident) = args
437
+ . into_iter ( )
438
+ . filter_map ( |ts| syn:: parse2 :: < syn:: Ident > ( ts) . ok ( ) )
439
+ . find ( is_numeric_type)
440
+ {
441
+ repr_type = ident;
442
+ }
443
+ }
444
+
445
+ let variants = enum_
446
+ . variants
447
+ . iter ( )
448
+ . map ( extract_variant_data)
449
+ . collect :: < syn:: Result < _ > > ( ) ?;
450
+ Ok ( Self {
451
+ ident,
452
+ repr_type,
453
+ variants,
454
+ } )
455
+ }
456
+ }
457
+
409
458
pub fn build_py_enum (
410
459
enum_ : & mut syn:: ItemEnum ,
411
460
args : & PyClassArgs ,
@@ -416,41 +465,37 @@ pub fn build_py_enum(
416
465
if enum_. variants . is_empty ( ) {
417
466
bail_spanned ! ( enum_. brace_token. span => "Empty enums can't be #[pyclass]." ) ;
418
467
}
419
- let variants: Vec < PyClassEnumVariant > = enum_
420
- . variants
421
- . iter ( )
422
- . map ( extract_variant_data)
423
- . collect :: < syn:: Result < _ > > ( ) ?;
424
- impl_enum ( enum_, args, variants, method_type, options)
425
- }
426
-
427
- fn impl_enum (
428
- enum_ : & syn:: ItemEnum ,
429
- args : & PyClassArgs ,
430
- variants : Vec < PyClassEnumVariant > ,
431
- methods_type : PyClassMethodsType ,
432
- options : PyClassPyO3Options ,
433
- ) -> syn:: Result < TokenStream > {
434
- let enum_name = & enum_. ident ;
435
468
let doc = utils:: get_doc (
436
469
& enum_. attrs ,
437
470
options
438
471
. text_signature
439
472
. as_ref ( )
440
473
. map ( |attr| ( get_class_python_name ( & enum_. ident , args) , attr) ) ,
441
474
) ;
475
+ let enum_ = PyClassEnum :: new ( enum_) ?;
476
+ impl_enum ( enum_, args, doc, method_type, options)
477
+ }
478
+
479
+ fn impl_enum (
480
+ enum_ : PyClassEnum ,
481
+ args : & PyClassArgs ,
482
+ doc : PythonDoc ,
483
+ methods_type : PyClassMethodsType ,
484
+ options : PyClassPyO3Options ,
485
+ ) -> syn:: Result < TokenStream > {
442
486
let krate = get_pyo3_crate ( & options. krate ) ;
443
- impl_enum_class ( enum_name , args, variants , doc, methods_type, krate)
487
+ impl_enum_class ( enum_ , args, doc, methods_type, krate)
444
488
}
445
489
446
490
fn impl_enum_class (
447
- cls : & syn :: Ident ,
491
+ enum_ : PyClassEnum ,
448
492
args : & PyClassArgs ,
449
- variants : Vec < PyClassEnumVariant > ,
450
493
doc : PythonDoc ,
451
494
methods_type : PyClassMethodsType ,
452
495
krate : syn:: Path ,
453
496
) -> syn:: Result < TokenStream > {
497
+ let cls = enum_. ident ;
498
+ let variants = enum_. variants ;
454
499
let pytypeinfo = impl_pytypeinfo ( cls, args, None ) ;
455
500
let pyclass_impls = PyClassImplsBuilder :: new ( cls, args, methods_type)
456
501
. doc ( doc)
@@ -528,9 +573,6 @@ fn extract_variant_data(variant: &syn::Variant) -> syn::Result<PyClassEnumVarian
528
573
Fields :: Unit => & variant. ident ,
529
574
_ => bail_spanned ! ( variant. span( ) => "Currently only support unit variants." ) ,
530
575
} ;
531
- if let Some ( discriminant) = variant. discriminant . as_ref ( ) {
532
- bail_spanned ! ( discriminant. 0 . span( ) => "Currently does not support discriminats." )
533
- } ;
534
576
Ok ( PyClassEnumVariant { ident } )
535
577
}
536
578
0 commit comments