Skip to content

Commit 31156e5

Browse files
committed
Implement __int__ and __richcmp__.
1 parent dea57e0 commit 31156e5

File tree

2 files changed

+124
-15
lines changed

2 files changed

+124
-15
lines changed

pyo3-macros-backend/src/pyclass.rs

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,6 @@ struct PyClassEnum<'a> {
410410
ident: &'a syn::Ident,
411411
// The underlying #[repr] of the enum, used to implement __int__ and __richcmp__.
412412
// This matters when the underlying representation may not fit in `isize`.
413-
#[allow(unused, dead_code)]
414413
repr_type: syn::Ident,
415414
variants: Vec<PyClassEnumVariant<'a>>,
416415
}
@@ -522,7 +521,68 @@ fn impl_enum_class(
522521
}
523522
};
524523

525-
let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl]);
524+
let repr_type = &enum_.repr_type;
525+
526+
let default_int = {
527+
// This implementation allows us to convert &T to #repr_type without implementing `Copy`
528+
let variants_to_int = variants.iter().map(|variant| {
529+
let variant_name = variant.ident;
530+
quote! { #cls::#variant_name => #cls::#variant_name as #repr_type, }
531+
});
532+
quote! {
533+
#[doc(hidden)]
534+
#[allow(non_snake_case)]
535+
#[pyo3(name = "__int__")]
536+
fn __pyo3__int__(&self) -> #repr_type {
537+
match self {
538+
#(#variants_to_int)*
539+
}
540+
}
541+
}
542+
};
543+
544+
let default_richcmp = {
545+
let variants_eq = variants.iter().map(|variant| {
546+
let variant_name = variant.ident;
547+
quote! {
548+
(#cls::#variant_name, #cls::#variant_name) =>
549+
Ok(true.to_object(py)),
550+
}
551+
});
552+
quote! {
553+
#[doc(hidden)]
554+
#[allow(non_snake_case)]
555+
#[pyo3(name = "__richcmp__")]
556+
fn __pyo3__richcmp__(
557+
&self,
558+
py: _pyo3::Python,
559+
other: &_pyo3::PyAny,
560+
op: _pyo3::basic::CompareOp
561+
) -> _pyo3::PyResult<_pyo3::PyObject> {
562+
use _pyo3::conversion::ToPyObject;
563+
use ::core::result::Result::*;
564+
match op {
565+
_pyo3::basic::CompareOp::Eq => {
566+
if let Ok(i) = other.extract::<#repr_type>() {
567+
let self_val = self.__pyo3__int__();
568+
return Ok((self_val == i).to_object(py));
569+
}
570+
let other = other.extract::<_pyo3::PyRef<Self>>()?;
571+
let other = &*other;
572+
match (self, other) {
573+
#(#variants_eq)*
574+
_ => Ok(false.to_object(py)),
575+
}
576+
}
577+
_ => Ok(py.NotImplemented()),
578+
}
579+
}
580+
}
581+
};
582+
583+
let default_impls =
584+
gen_default_slot_impls(cls, vec![default_repr_impl, default_richcmp, default_int]);
585+
526586
Ok(quote! {
527587
const _: () = {
528588
use #krate as _pyo3;

tests/test_enum.rs

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,11 @@ pub enum MyEnum {
1414

1515
#[test]
1616
fn test_enum_class_attr() {
17-
let gil = Python::acquire_gil();
18-
let py = gil.python();
19-
let my_enum = py.get_type::<MyEnum>();
20-
py_assert!(py, my_enum, "getattr(my_enum, 'Variant', None) is not None");
21-
py_assert!(py, my_enum, "getattr(my_enum, 'foobar', None) is None");
22-
py_run!(py, my_enum, "my_enum.Variant = None");
17+
Python::with_gil(|py| {
18+
let my_enum = py.get_type::<MyEnum>();
19+
let var = Py::new(py, MyEnum::Variant).unwrap();
20+
py_assert!(py, my_enum var, "my_enum.Variant == var");
21+
})
2322
}
2423

2524
#[pyfunction]
@@ -28,7 +27,6 @@ fn return_enum() -> MyEnum {
2827
}
2928

3029
#[test]
31-
#[ignore] // need to implement __eq__
3230
fn test_return_enum() {
3331
let gil = Python::acquire_gil();
3432
let py = gil.python();
@@ -44,14 +42,24 @@ fn enum_arg(e: MyEnum) {
4442
}
4543

4644
#[test]
47-
#[ignore] // need to implement __eq__
4845
fn test_enum_arg() {
49-
let gil = Python::acquire_gil();
50-
let py = gil.python();
51-
let f = wrap_pyfunction!(enum_arg)(py).unwrap();
52-
let mynum = py.get_type::<MyEnum>();
46+
Python::with_gil(|py| {
47+
let f = wrap_pyfunction!(enum_arg)(py).unwrap();
48+
let mynum = py.get_type::<MyEnum>();
49+
50+
py_run!(py, f mynum, "f(mynum.OtherVariant)")
51+
})
52+
}
5353

54-
py_run!(py, f mynum, "f(mynum.Variant)")
54+
#[test]
55+
fn test_enum_eq() {
56+
Python::with_gil(|py| {
57+
let var1 = Py::new(py, MyEnum::Variant).unwrap();
58+
let var2 = Py::new(py, MyEnum::Variant).unwrap();
59+
let other_var = Py::new(py, MyEnum::OtherVariant).unwrap();
60+
py_assert!(py, var1 var2, "var1 == var2");
61+
py_assert!(py, var1 other_var, "var1 != other_var");
62+
})
5563
}
5664

5765
#[test]
@@ -85,6 +93,47 @@ fn test_custom_discriminant() {
8593
})
8694
}
8795

96+
#[test]
97+
fn test_enum_to_int() {
98+
Python::with_gil(|py| {
99+
let one = Py::new(py, CustomDiscriminant::One).unwrap();
100+
py_assert!(py, one, "int(one) == 1");
101+
let v = Py::new(py, MyEnum::Variant).unwrap();
102+
let v_value = MyEnum::Variant as isize;
103+
py_run!(py, v v_value, "int(v) == v_value");
104+
})
105+
}
106+
107+
#[test]
108+
fn test_enum_compare_int() {
109+
Python::with_gil(|py| {
110+
let one = Py::new(py, CustomDiscriminant::One).unwrap();
111+
py_run!(
112+
py,
113+
one,
114+
r#"
115+
assert one == 1
116+
assert 1 == one
117+
assert one != 2
118+
"#
119+
)
120+
})
121+
}
122+
123+
#[pyclass]
124+
#[repr(u8)]
125+
enum SmallEnum {
126+
V = 1,
127+
}
128+
129+
#[test]
130+
fn test_enum_compare_int_no_throw_when_overflow() {
131+
Python::with_gil(|py| {
132+
let v = Py::new(py, SmallEnum::V).unwrap();
133+
py_assert!(py, v, "v != 1<<30")
134+
})
135+
}
136+
88137
#[pyclass]
89138
#[repr(usize)]
90139
#[allow(clippy::enum_clike_unportable_variant)]

0 commit comments

Comments
 (0)