Skip to content

Commit 8a03778

Browse files
authored
Merge pull request #2014 from b05902132/default_impl
Support default method implementation
2 parents 7455518 + aac0e56 commit 8a03778

File tree

5 files changed

+122
-1
lines changed

5 files changed

+122
-1
lines changed

pyo3-macros-backend/src/pyclass.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use crate::attributes::{self, take_pyo3_options, NameAttribute, TextSignatureAttribute};
44
use crate::deprecations::Deprecations;
55
use crate::konst::{ConstAttributes, ConstSpec};
6-
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
6+
use crate::pyimpl::{gen_default_slot_impls, gen_py_const, PyClassMethodsType};
77
use crate::pymethod::{impl_py_getter_def, impl_py_setter_def, PropertyType};
88
use crate::utils::{self, unwrap_group, PythonDoc};
99
use proc_macro2::{Span, TokenStream};
@@ -425,6 +425,27 @@ fn impl_enum_class(
425425
.impl_all();
426426
let descriptors = unit_variants_as_descriptors(cls, variants.iter().map(|v| v.ident));
427427

428+
let default_repr_impl = {
429+
let variants_repr = variants.iter().map(|variant| {
430+
let variant_name = variant.ident;
431+
// Assuming all variants are unit variants because they are the only type we support.
432+
let repr = format!("{}.{}", cls, variant_name);
433+
quote! { #cls::#variant_name => #repr, }
434+
});
435+
quote! {
436+
#[doc(hidden)]
437+
#[allow(non_snake_case)]
438+
#[pyo3(name = "__repr__")]
439+
fn __pyo3__repr__(&self) -> &'static str {
440+
match self {
441+
#(#variants_repr)*
442+
_ => unreachable!("Unsupported variant type."),
443+
}
444+
}
445+
}
446+
};
447+
448+
let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl]);
428449
Ok(quote! {
429450

430451
#pytypeinfo
@@ -433,6 +454,8 @@ fn impl_enum_class(
433454

434455
#descriptors
435456

457+
#default_impls
458+
436459
})
437460
}
438461

@@ -758,6 +781,9 @@ impl<'a> PyClassImplsBuilder<'a> {
758781
// Implementation which uses dtolnay specialization to load all slots.
759782
use ::pyo3::class::impl_::*;
760783
let collector = PyClassImplCollector::<Self>::new();
784+
// This depends on Python implementation detail;
785+
// an old slot entry will be overriden by newer ones.
786+
visitor(collector.py_class_default_slots());
761787
visitor(collector.object_protocol_slots());
762788
visitor(collector.number_protocol_slots());
763789
visitor(collector.iter_protocol_slots());

pyo3-macros-backend/src/pyimpl.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,47 @@ pub fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream {
139139
}
140140
}
141141

142+
pub fn gen_default_slot_impls(cls: &syn::Ident, method_defs: Vec<TokenStream>) -> TokenStream {
143+
// This function uses a lot of `unwrap()`; since method_defs are provided by us, they should
144+
// all succeed.
145+
let ty: syn::Type = syn::parse_quote!(#cls);
146+
147+
let mut method_defs: Vec<_> = method_defs
148+
.into_iter()
149+
.map(|token| syn::parse2::<syn::ImplItemMethod>(token).unwrap())
150+
.collect();
151+
152+
let mut proto_impls = Vec::new();
153+
154+
for meth in &mut method_defs {
155+
let options = PyFunctionOptions::from_attrs(&mut meth.attrs).unwrap();
156+
match pymethod::gen_py_method(&ty, &mut meth.sig, &mut meth.attrs, options).unwrap() {
157+
GeneratedPyMethod::Proto(token_stream) => {
158+
let attrs = get_cfg_attributes(&meth.attrs);
159+
proto_impls.push(quote!(#(#attrs)* #token_stream))
160+
}
161+
GeneratedPyMethod::SlotTraitImpl(..) => {
162+
panic!("SlotFragment methods cannot have default implementation!")
163+
}
164+
GeneratedPyMethod::Method(_) | GeneratedPyMethod::TraitImpl(_) => {
165+
panic!("Only protocol methods can have default implementation!")
166+
}
167+
}
168+
}
169+
170+
quote! {
171+
impl #cls {
172+
#(#method_defs)*
173+
}
174+
impl ::pyo3::class::impl_::PyClassDefaultSlots<#cls>
175+
for ::pyo3::class::impl_::PyClassImplCollector<#cls> {
176+
fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
177+
&[#(#proto_impls),*]
178+
}
179+
}
180+
}
181+
}
182+
142183
fn impl_py_methods(ty: &syn::Type, methods: Vec<TokenStream>) -> TokenStream {
143184
quote! {
144185
impl ::pyo3::class::impl_::PyMethods<#ty>

src/class/impl_.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,9 @@ slots_trait!(PyAsyncProtocolSlots, async_protocol_slots);
657657
slots_trait!(PySequenceProtocolSlots, sequence_protocol_slots);
658658
slots_trait!(PyBufferProtocolSlots, buffer_protocol_slots);
659659

660+
// slots that PyO3 implements by default, but can be overidden by the users.
661+
slots_trait!(PyClassDefaultSlots, py_class_default_slots);
662+
660663
// Protocol slots from #[pymethods] if not using inventory.
661664
#[cfg(not(feature = "multiple-pymethods"))]
662665
slots_trait!(PyMethodsProtocolSlots, methods_protocol_slots);

tests/test_default_impls.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
use pyo3::prelude::*;
2+
3+
mod common;
4+
5+
// Test default generated __repr__.
6+
#[pyclass]
7+
enum TestDefaultRepr {
8+
Var,
9+
}
10+
11+
#[test]
12+
fn test_default_slot_exists() {
13+
Python::with_gil(|py| {
14+
let test_object = Py::new(py, TestDefaultRepr::Var).unwrap();
15+
py_assert!(
16+
py,
17+
test_object,
18+
"repr(test_object) == 'TestDefaultRepr.Var'"
19+
);
20+
})
21+
}
22+
23+
#[pyclass]
24+
enum OverrideSlot {
25+
Var,
26+
}
27+
28+
#[pymethods]
29+
impl OverrideSlot {
30+
fn __repr__(&self) -> &str {
31+
"overriden"
32+
}
33+
}
34+
35+
#[test]
36+
fn test_override_slot() {
37+
Python::with_gil(|py| {
38+
let test_object = Py::new(py, OverrideSlot::Var).unwrap();
39+
py_assert!(py, test_object, "repr(test_object) == 'overriden'");
40+
})
41+
}

tests/test_enum.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,13 @@ fn test_enum_arg() {
5151

5252
py_run!(py, f mynum, "f(mynum.Variant)")
5353
}
54+
55+
#[test]
56+
fn test_default_repr_correct() {
57+
Python::with_gil(|py| {
58+
let var1 = Py::new(py, MyEnum::Variant).unwrap();
59+
let var2 = Py::new(py, MyEnum::OtherVariant).unwrap();
60+
py_assert!(py, var1, "repr(var1) == 'MyEnum.Variant'");
61+
py_assert!(py, var2, "repr(var2) == 'MyEnum.OtherVariant'");
62+
})
63+
}

0 commit comments

Comments
 (0)