Skip to content

Commit dea57e0

Browse files
committed
Parse #[repr(..)] for #[pyclass] enums. Allow custom discriminants.
1 parent 309065d commit dea57e0

File tree

2 files changed

+113
-22
lines changed

2 files changed

+113
-22
lines changed

pyo3-macros-backend/src/pyclass.rs

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,55 @@ struct PyClassEnumVariant<'a> {
406406
/* currently have no more options */
407407
}
408408

409+
struct PyClassEnum<'a> {
410+
ident: &'a syn::Ident,
411+
// The underlying #[repr] of the enum, used to implement __int__ and __richcmp__.
412+
// This matters when the underlying representation may not fit in `isize`.
413+
#[allow(unused, dead_code)]
414+
repr_type: syn::Ident,
415+
variants: Vec<PyClassEnumVariant<'a>>,
416+
}
417+
418+
impl<'a> PyClassEnum<'a> {
419+
fn new(enum_: &'a syn::ItemEnum) -> syn::Result<Self> {
420+
fn is_numeric_type(t: &syn::Ident) -> bool {
421+
[
422+
"u8", "i8", "u16", "i16", "u32", "i32", "u64", "i64", "u128", "i128", "usize",
423+
"isize",
424+
]
425+
.iter()
426+
.any(|&s| t == s)
427+
}
428+
let ident = &enum_.ident;
429+
// According to the [reference](https://doc.rust-lang.org/reference/items/enumerations.html),
430+
// "Under the default representation, the specified discriminant is interpreted as an isize
431+
// value", so `isize` should be enough by default.
432+
let mut repr_type = syn::Ident::new("isize", proc_macro2::Span::call_site());
433+
if let Some(attr) = enum_.attrs.iter().find(|attr| attr.path.is_ident("repr")) {
434+
let args =
435+
attr.parse_args_with(Punctuated::<TokenStream, Token![!]>::parse_terminated)?;
436+
if let Some(ident) = args
437+
.into_iter()
438+
.filter_map(|ts| syn::parse2::<syn::Ident>(ts).ok())
439+
.find(is_numeric_type)
440+
{
441+
repr_type = ident;
442+
}
443+
}
444+
445+
let variants = enum_
446+
.variants
447+
.iter()
448+
.map(extract_variant_data)
449+
.collect::<syn::Result<_>>()?;
450+
Ok(Self {
451+
ident,
452+
repr_type,
453+
variants,
454+
})
455+
}
456+
}
457+
409458
pub fn build_py_enum(
410459
enum_: &mut syn::ItemEnum,
411460
args: &PyClassArgs,
@@ -416,41 +465,37 @@ pub fn build_py_enum(
416465
if enum_.variants.is_empty() {
417466
bail_spanned!(enum_.brace_token.span => "Empty enums can't be #[pyclass].");
418467
}
419-
let variants: Vec<PyClassEnumVariant> = enum_
420-
.variants
421-
.iter()
422-
.map(extract_variant_data)
423-
.collect::<syn::Result<_>>()?;
424-
impl_enum(enum_, args, variants, method_type, options)
425-
}
426-
427-
fn impl_enum(
428-
enum_: &syn::ItemEnum,
429-
args: &PyClassArgs,
430-
variants: Vec<PyClassEnumVariant>,
431-
methods_type: PyClassMethodsType,
432-
options: PyClassPyO3Options,
433-
) -> syn::Result<TokenStream> {
434-
let enum_name = &enum_.ident;
435468
let doc = utils::get_doc(
436469
&enum_.attrs,
437470
options
438471
.text_signature
439472
.as_ref()
440473
.map(|attr| (get_class_python_name(&enum_.ident, args), attr)),
441474
);
475+
let enum_ = PyClassEnum::new(enum_)?;
476+
impl_enum(enum_, args, doc, method_type, options)
477+
}
478+
479+
fn impl_enum(
480+
enum_: PyClassEnum,
481+
args: &PyClassArgs,
482+
doc: PythonDoc,
483+
methods_type: PyClassMethodsType,
484+
options: PyClassPyO3Options,
485+
) -> syn::Result<TokenStream> {
442486
let krate = get_pyo3_crate(&options.krate);
443-
impl_enum_class(enum_name, args, variants, doc, methods_type, krate)
487+
impl_enum_class(enum_, args, doc, methods_type, krate)
444488
}
445489

446490
fn impl_enum_class(
447-
cls: &syn::Ident,
491+
enum_: PyClassEnum,
448492
args: &PyClassArgs,
449-
variants: Vec<PyClassEnumVariant>,
450493
doc: PythonDoc,
451494
methods_type: PyClassMethodsType,
452495
krate: syn::Path,
453496
) -> syn::Result<TokenStream> {
497+
let cls = enum_.ident;
498+
let variants = enum_.variants;
454499
let pytypeinfo = impl_pytypeinfo(cls, args, None);
455500
let pyclass_impls = PyClassImplsBuilder::new(cls, args, methods_type)
456501
.doc(doc)
@@ -528,9 +573,6 @@ fn extract_variant_data(variant: &syn::Variant) -> syn::Result<PyClassEnumVarian
528573
Fields::Unit => &variant.ident,
529574
_ => bail_spanned!(variant.span() => "Currently only support unit variants."),
530575
};
531-
if let Some(discriminant) = variant.discriminant.as_ref() {
532-
bail_spanned!(discriminant.0.span() => "Currently does not support discriminats.")
533-
};
534576
Ok(PyClassEnumVariant { ident })
535577
}
536578

tests/test_enum.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,52 @@ fn test_default_repr_correct() {
6363
py_assert!(py, var2, "repr(var2) == 'MyEnum.OtherVariant'");
6464
})
6565
}
66+
67+
#[pyclass]
68+
enum CustomDiscriminant {
69+
One = 1,
70+
Two = 2,
71+
}
72+
73+
#[test]
74+
fn test_custom_discriminant() {
75+
Python::with_gil(|py| {
76+
#[allow(non_snake_case)]
77+
let CustomDiscriminant = py.get_type::<CustomDiscriminant>();
78+
let one = Py::new(py, CustomDiscriminant::One).unwrap();
79+
let two = Py::new(py, CustomDiscriminant::Two).unwrap();
80+
py_run!(py, CustomDiscriminant one two, r#"
81+
assert CustomDiscriminant.One == one
82+
assert CustomDiscriminant.Two == two
83+
assert one != two
84+
"#);
85+
})
86+
}
87+
88+
#[pyclass]
89+
#[repr(usize)]
90+
#[allow(clippy::enum_clike_unportable_variant)]
91+
enum BigEnum {
92+
V = usize::MAX,
93+
}
94+
95+
#[test]
96+
fn test_big_enum_no_overflow() {
97+
Python::with_gil(|py| {
98+
let usize_max = usize::MAX;
99+
let v = Py::new(py, BigEnum::V).unwrap();
100+
py_assert!(py, usize_max v, "v == usize_max");
101+
py_assert!(py, usize_max v, "int(v) == usize_max");
102+
})
103+
}
104+
105+
#[pyclass]
106+
#[repr(u16, align(8))]
107+
enum TestReprParse {
108+
V,
109+
}
110+
111+
#[test]
112+
fn test_repr_parse() {
113+
assert_eq!(std::mem::align_of::<TestReprParse>(), 8);
114+
}

0 commit comments

Comments
 (0)