Skip to content

Commit 1ef9cb0

Browse files
committed
Implement __int__ and __eq__ betweeen enum and int
1 parent dc549cc commit 1ef9cb0

File tree

2 files changed

+89
-5
lines changed

2 files changed

+89
-5
lines changed

pyo3-macros-backend/src/pyclass.rs

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -501,30 +501,58 @@ fn impl_enum_class(
501501
}
502502
};
503503

504+
let repr_type = &enum_.repr;
505+
506+
let default_int = {
507+
// This implementation allows us to convert &T to #repr_type without implementing `Copy`
508+
let variants_to_int = variants.iter().map(|variant| {
509+
let variant_name = variant.ident;
510+
quote! { #cls::#variant_name => #cls::#variant_name as #repr_type, }
511+
});
512+
quote! {
513+
#[doc(hidden)]
514+
#[allow(non_snake_case)]
515+
#[pyo3(name = "__int__")]
516+
fn __pyo3__int__(&self) -> #repr_type {
517+
match self {
518+
#(#variants_to_int)*
519+
}
520+
}
521+
}
522+
};
523+
504524
let default_richcmp = {
505525
let variants_eq = variants.iter().map(|variant| {
506526
let variant_name = variant.ident;
507-
quote! {(#cls::#variant_name, #cls::#variant_name) => true.to_object(py),}
527+
quote! {(#cls::#variant_name, #cls::#variant_name) => Ok(true.to_object(py)),}
508528
});
509529
quote! {
510530
#[doc(hidden)]
511531
#[allow(non_snake_case)]
512532
#[pyo3(name = "__richcmp__")]
513-
fn __pyo3__richcmp__(&self, py: ::pyo3::Python, other: &Self, op: ::pyo3::basic::CompareOp) -> PyObject {
533+
fn __pyo3__richcmp__(&self, py: ::pyo3::Python, other: &PyAny, op: ::pyo3::basic::CompareOp) -> PyResult<PyObject> {
514534
match op {
515535
::pyo3::basic::CompareOp::Eq => {
536+
if let Ok(i) = other.extract::<#repr_type>() {
537+
let self_val = self.__pyo3__int__();
538+
return Ok((self_val == i).to_object(py));
539+
}
540+
let other = other.extract::<PyRef<Self>>()?;
541+
let other = &*other;
516542
match (self, other) {
517543
#(#variants_eq)*
518-
_ => false.to_object(py),
544+
_ => Ok(false.to_object(py)),
519545
}
520546
}
521-
_ => py.NotImplemented(),
547+
_ => Ok(py.NotImplemented()),
522548
}
523549
}
524550
}
525551
};
526552

527-
let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl, default_richcmp]);
553+
let default_impls =
554+
gen_default_slot_impls(cls, vec![default_repr_impl, default_richcmp, default_int]);
555+
528556
Ok(quote! {
529557

530558
#pytypeinfo

tests/test_enum.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,59 @@ fn test_custom_discriminant() {
9090
"#);
9191
})
9292
}
93+
94+
#[test]
95+
fn test_enum_to_int() {
96+
Python::with_gil(|py| {
97+
let one = Py::new(py, CustomDiscriminant::One).unwrap();
98+
py_assert!(py, one, "int(one) == 1");
99+
let test_noexception = Py::new(py, MyEnum::Variant).unwrap();
100+
py_run!(py, test_noexception, "int(test_noexception)");
101+
})
102+
}
103+
104+
#[test]
105+
fn test_enum_compare_int() {
106+
Python::with_gil(|py| {
107+
let one = Py::new(py, CustomDiscriminant::One).unwrap();
108+
py_run!(
109+
py,
110+
one,
111+
r#"
112+
assert one == 1
113+
assert 1 == one
114+
assert one != 2
115+
"#
116+
)
117+
})
118+
}
119+
120+
#[pyclass]
121+
#[repr(u8)]
122+
enum SmallEnum {
123+
V = 1,
124+
}
125+
126+
#[test]
127+
fn test_enum_compare_int_no_throw_when_overflow() {
128+
Python::with_gil(|py| {
129+
let v = Py::new(py, SmallEnum::V).unwrap();
130+
py_assert!(py, v, "v != 1<<30")
131+
})
132+
}
133+
134+
#[pyclass]
135+
#[repr(usize)]
136+
enum BigEnum {
137+
V = usize::MAX,
138+
}
139+
140+
#[test]
141+
fn test_big_enum_no_overflow() {
142+
Python::with_gil(|py| {
143+
let usize_max = usize::MAX;
144+
let v = Py::new(py, BigEnum::V).unwrap();
145+
py_assert!(py, usize_max v, "v == usize_max");
146+
py_assert!(py, usize_max v, "int(v) == usize_max");
147+
})
148+
}

0 commit comments

Comments
 (0)