Skip to content

Commit 44d05dd

Browse files
committed
Implement default slot methods and default __repr__ for enums.
1 parent 245617a commit 44d05dd

File tree

3 files changed

+76
-92
lines changed

3 files changed

+76
-92
lines changed

pyo3-macros-backend/src/pyclass.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use crate::attributes::{self, take_pyo3_options, NameAttribute, TextSignatureAttribute};
44
use crate::deprecations::Deprecations;
55
use crate::konst::{ConstAttributes, ConstSpec};
6-
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
6+
use crate::pyimpl::{gen_default_slot_impls, gen_py_const, PyClassMethodsType};
77
use crate::pymethod::{impl_py_getter_def, impl_py_setter_def, PropertyType};
88
use crate::utils::{self, unwrap_group, PythonDoc};
99
use proc_macro2::{Span, TokenStream};
@@ -422,6 +422,24 @@ fn impl_enum_class(
422422
.impl_all();
423423
let descriptors = unit_variants_as_descriptors(cls, variants.iter().map(|v| v.ident));
424424

425+
let variants_repr = variants.iter().map(|variant| {
426+
let variant_name = variant.ident;
427+
// Assuming all variants are unit variants because they are the only type we support.
428+
let repr = format!("{}.{}", cls, variant_name);
429+
quote! { #cls::#variant_name => #repr, }
430+
});
431+
432+
let default_repr_impl = quote! {
433+
#[allow(non_snake_case)]
434+
#[pyo3(name = "__repr__")]
435+
fn __pyo3__repr__(&self) -> &'static str {
436+
match self {
437+
#(#variants_repr)*
438+
_ => unreachable!("Unsupported variant type."),
439+
}
440+
}
441+
};
442+
let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl]);
425443
Ok(quote! {
426444

427445
#pytypeinfo
@@ -430,6 +448,8 @@ fn impl_enum_class(
430448

431449
#descriptors
432450

451+
#default_impls
452+
433453
})
434454
}
435455

pyo3-macros-backend/src/pyimpl.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,47 @@ pub fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream {
139139
}
140140
}
141141

142+
pub fn gen_default_slot_impls(cls: &syn::Ident, method_defs: Vec<TokenStream>) -> TokenStream {
143+
// This function uses a lot of `unwrap()`; since method_defs are provided by us, they should
144+
// all succeed.
145+
let ty: syn::Type = syn::parse_quote!(#cls);
146+
147+
let mut method_defs: Vec<_> = method_defs
148+
.into_iter()
149+
.map(|token| syn::parse2::<syn::ImplItemMethod>(token).unwrap())
150+
.collect();
151+
152+
let mut proto_impls = Vec::new();
153+
154+
for meth in &mut method_defs {
155+
let options = PyFunctionOptions::from_attrs(&mut meth.attrs).unwrap();
156+
match pymethod::gen_py_method(&ty, &mut meth.sig, &mut meth.attrs, options).unwrap() {
157+
GeneratedPyMethod::Proto(token_stream) => {
158+
let attrs = get_cfg_attributes(&meth.attrs);
159+
proto_impls.push(quote!(#(#attrs)* #token_stream))
160+
}
161+
GeneratedPyMethod::SlotTraitImpl(..) => {
162+
todo!()
163+
}
164+
GeneratedPyMethod::Method(_) | GeneratedPyMethod::TraitImpl(_) => {
165+
panic!("Only protocol methods can have default implementation!")
166+
}
167+
}
168+
}
169+
170+
quote! {
171+
impl #cls {
172+
#(#method_defs)*
173+
}
174+
impl ::pyo3::class::impl_::PyClassDefaultSlots<#cls>
175+
for ::pyo3::class::impl_::PyClassImplCollector<#cls> {
176+
fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
177+
&[#(#proto_impls),*]
178+
}
179+
}
180+
}
181+
}
182+
142183
fn impl_py_methods(ty: &syn::Type, methods: Vec<TokenStream>) -> TokenStream {
143184
quote! {
144185
impl ::pyo3::class::impl_::PyMethods<#ty>

tests/test_default_impls.rs

Lines changed: 14 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,41 @@
1-
#![allow(non_snake_case)]
2-
use pyo3::class::PyMethodDefType;
31
use pyo3::prelude::*;
4-
use pyo3::py_run;
52

63
mod common;
74

8-
// Tests for PyClassDefaultSlots
5+
// Test default generated __repr__.
96
#[pyclass]
10-
struct TestDefaultSlots;
11-
12-
// generated using `Cargo expand`
13-
// equivalent to
14-
// ```
15-
// impl TestDefaultSlots {{
16-
// fn __str__(&self) -> &'static str {
17-
// "default"
18-
// }
19-
// }
20-
// ```
21-
impl TestDefaultSlots {
22-
fn __pyo3__str__(&self) -> &'static str {
23-
"default"
24-
}
25-
}
26-
27-
impl ::pyo3::class::impl_::PyClassDefaultSlots<TestDefaultSlots>
28-
for ::pyo3::class::impl_::PyClassImplCollector<TestDefaultSlots>
29-
{
30-
fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
31-
&[{
32-
unsafe extern "C" fn __wrap(
33-
_raw_slf: *mut ::pyo3::ffi::PyObject,
34-
) -> *mut ::pyo3::ffi::PyObject {
35-
let _slf = _raw_slf;
36-
::pyo3::callback::handle_panic(|_py| {
37-
let _cell = _py
38-
.from_borrowed_ptr::<::pyo3::PyAny>(_slf)
39-
.downcast::<::pyo3::PyCell<TestDefaultSlots>>()?;
40-
let _ref = _cell.try_borrow()?;
41-
let _slf = &_ref;
42-
::pyo3::callback::convert(_py, TestDefaultSlots::__pyo3__str__(_slf))
43-
})
44-
}
45-
::pyo3::ffi::PyType_Slot {
46-
slot: ::pyo3::ffi::Py_tp_str,
47-
pfunc: __wrap as ::pyo3::ffi::reprfunc as _,
48-
}
49-
}]
50-
}
7+
enum TestDefaultRepr {
8+
Var,
519
}
5210

5311
#[test]
5412
fn test_default_slot_exists() {
5513
Python::with_gil(|py| {
56-
let test_object = Py::new(py, TestDefaultSlots).unwrap();
57-
py_assert!(py, test_object, "str(test_object) == 'default'");
14+
let test_object = Py::new(py, TestDefaultRepr::Var).unwrap();
15+
py_assert!(
16+
py,
17+
test_object,
18+
"repr(test_object) == 'TestDefaultRepr.Var'"
19+
);
5820
})
5921
}
6022

6123
#[pyclass]
62-
struct OverrideSlot;
63-
64-
// generated using `Cargo expand`
65-
// equivalent to
66-
// ```
67-
// impl OverrideMagicMethod {
68-
// fn __str__(&self) -> &'static str {
69-
// "default"
70-
// }
71-
// }
72-
// ```
73-
impl OverrideSlot {
74-
fn __pyo3__str__(&self) -> &'static str {
75-
"default"
76-
}
77-
}
78-
79-
impl ::pyo3::class::impl_::PyClassDefaultSlots<OverrideSlot>
80-
for ::pyo3::class::impl_::PyClassImplCollector<OverrideSlot>
81-
{
82-
fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
83-
&[{
84-
unsafe extern "C" fn __wrap(
85-
_raw_slf: *mut ::pyo3::ffi::PyObject,
86-
) -> *mut ::pyo3::ffi::PyObject {
87-
let _slf = _raw_slf;
88-
::pyo3::callback::handle_panic(|_py| {
89-
let _cell = _py
90-
.from_borrowed_ptr::<::pyo3::PyAny>(_slf)
91-
.downcast::<::pyo3::PyCell<OverrideSlot>>()?;
92-
let _ref = _cell.try_borrow()?;
93-
let _slf = &_ref;
94-
::pyo3::callback::convert(_py, OverrideSlot::__pyo3__str__(_slf))
95-
})
96-
}
97-
::pyo3::ffi::PyType_Slot {
98-
slot: ::pyo3::ffi::Py_tp_str,
99-
pfunc: __wrap as ::pyo3::ffi::reprfunc as _,
100-
}
101-
}]
102-
}
24+
enum OverrideSlot {
25+
Var,
10326
}
10427

10528
#[pymethods]
10629
impl OverrideSlot {
107-
fn __str__(&self) -> &str {
30+
fn __repr__(&self) -> &str {
10831
"overriden"
10932
}
11033
}
11134

11235
#[test]
11336
fn test_override_slot() {
11437
Python::with_gil(|py| {
115-
let test_object = Py::new(py, OverrideSlot).unwrap();
116-
py_assert!(py, test_object, "str(test_object) == 'overriden'");
38+
let test_object = Py::new(py, OverrideSlot::Var).unwrap();
39+
py_assert!(py, test_object, "repr(test_object) == 'overriden'");
11740
})
11841
}

0 commit comments

Comments
 (0)