Skip to content

Commit 3cb56bb

Browse files
committed
Implement default slot methods and default __repr__ for enums.
1 parent 245617a commit 3cb56bb

File tree

4 files changed

+96
-100
lines changed

4 files changed

+96
-100
lines changed

pyo3-macros-backend/src/pyclass.rs

+20-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use crate::attributes::{self, take_pyo3_options, NameAttribute, TextSignatureAttribute};
44
use crate::deprecations::Deprecations;
55
use crate::konst::{ConstAttributes, ConstSpec};
6-
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
6+
use crate::pyimpl::{gen_default_slot_impls, gen_py_const, PyClassMethodsType};
77
use crate::pymethod::{impl_py_getter_def, impl_py_setter_def, PropertyType};
88
use crate::utils::{self, unwrap_group, PythonDoc};
99
use proc_macro2::{Span, TokenStream};
@@ -422,6 +422,23 @@ fn impl_enum_class(
422422
.impl_all();
423423
let descriptors = unit_variants_as_descriptors(cls, variants.iter().map(|v| v.ident));
424424

425+
let variants_repr = variants.iter().map(|variant| {
426+
let variant_name = variant.ident;
427+
// Assuming all variants are unit variants because they are the only type we support.
428+
let repr = format!("{}.{}", cls, variant_name);
429+
quote! { #cls::#variant_name => #repr }
430+
});
431+
432+
let default_repr_impl = quote! {
433+
fn __repr__(&self) -> &'static str {
434+
match self {
435+
#(#variants_repr),*
436+
, // needs an additional comma for the last item
437+
_ => unreachable!("Unsupported variant type."),
438+
}
439+
}
440+
};
441+
let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl]);
425442
Ok(quote! {
426443

427444
#pytypeinfo
@@ -430,6 +447,8 @@ fn impl_enum_class(
430447

431448
#descriptors
432449

450+
#default_impls
451+
433452
})
434453
}
435454

pyo3-macros-backend/src/pyimpl.rs

+54-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::{
88
pymethod::{self, is_proto_method},
99
};
1010
use proc_macro2::TokenStream;
11-
use pymethod::GeneratedPyMethod;
11+
use pymethod::{GeneratedPyMethod, PyMethodKind, PyMethodProtoKind};
1212
use quote::quote;
1313
use syn::spanned::Spanned;
1414

@@ -139,6 +139,59 @@ pub fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream {
139139
}
140140
}
141141

142+
pub fn gen_default_slot_impls(cls: &syn::Ident, method_defs: Vec<TokenStream>) -> TokenStream {
143+
// This function uses a lot of `unwrap()`; since method_defs are provided by us, they should
144+
// all succeed.
145+
let ty: syn::Type = syn::parse_quote!(#cls);
146+
147+
let mut method_defs: Vec<syn::ImplItemMethod> = method_defs
148+
.into_iter()
149+
.map(|tokens| syn::parse2(tokens).unwrap())
150+
.collect();
151+
152+
// Generate the python method definition, and change the rust method name
153+
// to avoid name collision.
154+
let py_methods: Vec<_> = method_defs
155+
.iter_mut()
156+
.map(|method_def| -> TokenStream {
157+
// The new name of internal method.
158+
let new_name: syn::Ident =
159+
syn::parse_str(&format!("__pyo3__{}", method_def.sig.ident.clone())).unwrap();
160+
let mut method = pymethod::PyMethod::parse(
161+
&mut method_def.sig,
162+
&mut method_def.attrs,
163+
Default::default(),
164+
)
165+
.unwrap();
166+
// Have to change the name after parsing so the slot is parsed correctly.
167+
method.spec.name = &new_name;
168+
let tok = match method.kind {
169+
PyMethodKind::Proto(proto_kind) => match proto_kind {
170+
PyMethodProtoKind::Slot(slot_def) => {
171+
slot_def.generate_type_slot(&ty, &method.spec).unwrap()
172+
}
173+
_ => todo!("Unsupported prototype kind"),
174+
},
175+
_ => unreachable!("Not a prototype method"),
176+
};
177+
// Change the name here because `method` mutably borrows from `method_def`
178+
method_def.sig.ident = new_name;
179+
tok
180+
})
181+
.collect(); // So `py_methods` don't mutably borrow from `method_defs`
182+
quote! {
183+
impl #cls {
184+
#(#method_defs)*
185+
}
186+
impl ::pyo3::class::impl_::PyClassDefaultSlots<#ty>
187+
for ::pyo3::class::impl_::PyClassImplCollector<#ty> {
188+
fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
189+
&[#(#py_methods),*]
190+
}
191+
}
192+
}
193+
}
194+
142195
fn impl_py_methods(ty: &syn::Type, methods: Vec<TokenStream>) -> TokenStream {
143196
quote! {
144197
impl ::pyo3::class::impl_::PyMethods<#ty>

pyo3-macros-backend/src/pymethod.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ pub enum GeneratedPyMethod {
2525
}
2626

2727
pub struct PyMethod<'a> {
28-
kind: PyMethodKind,
28+
pub kind: PyMethodKind,
2929
method_name: String,
30-
spec: FnSpec<'a>,
30+
pub spec: FnSpec<'a>,
3131
}
3232

33-
enum PyMethodKind {
33+
pub enum PyMethodKind {
3434
Fn,
3535
Proto(PyMethodProtoKind),
3636
}
@@ -49,14 +49,14 @@ impl PyMethodKind {
4949
}
5050
}
5151

52-
enum PyMethodProtoKind {
52+
pub enum PyMethodProtoKind {
5353
Slot(&'static SlotDef),
5454
Call,
5555
SlotFragment(&'static SlotFragmentDef),
5656
}
5757

5858
impl<'a> PyMethod<'a> {
59-
fn parse(
59+
pub fn parse(
6060
sig: &'a mut syn::Signature,
6161
meth_attrs: &mut Vec<syn::Attribute>,
6262
options: PyFunctionOptions,
@@ -744,7 +744,7 @@ impl ReturnMode {
744744
}
745745
}
746746

747-
struct SlotDef {
747+
pub struct SlotDef {
748748
slot: StaticIdent,
749749
func_ty: StaticIdent,
750750
arguments: &'static [Ty],
@@ -799,7 +799,7 @@ impl SlotDef {
799799
self
800800
}
801801

802-
fn generate_type_slot(&self, cls: &syn::Type, spec: &FnSpec) -> Result<TokenStream> {
802+
pub fn generate_type_slot(&self, cls: &syn::Type, spec: &FnSpec) -> Result<TokenStream> {
803803
let SlotDef {
804804
slot,
805805
func_ty,
@@ -871,7 +871,7 @@ fn generate_method_body(
871871
})
872872
}
873873

874-
struct SlotFragmentDef {
874+
pub struct SlotFragmentDef {
875875
fragment: &'static str,
876876
arguments: &'static [Ty],
877877
extract_error_mode: ExtractErrorMode,

tests/test_default_impls.rs

+14-90
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,42 @@
11
#![allow(non_snake_case)]
2-
use pyo3::class::PyMethodDefType;
32
use pyo3::prelude::*;
4-
use pyo3::py_run;
53

64
mod common;
75

8-
// Tests for PyClassDefaultSlots
6+
// Test default generated __repr__.
97
#[pyclass]
10-
struct TestDefaultSlots;
11-
12-
// generated using `Cargo expand`
13-
// equivalent to
14-
// ```
15-
// impl TestDefaultSlots {{
16-
// fn __str__(&self) -> &'static str {
17-
// "default"
18-
// }
19-
// }
20-
// ```
21-
impl TestDefaultSlots {
22-
fn __pyo3__str__(&self) -> &'static str {
23-
"default"
24-
}
25-
}
26-
27-
impl ::pyo3::class::impl_::PyClassDefaultSlots<TestDefaultSlots>
28-
for ::pyo3::class::impl_::PyClassImplCollector<TestDefaultSlots>
29-
{
30-
fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
31-
&[{
32-
unsafe extern "C" fn __wrap(
33-
_raw_slf: *mut ::pyo3::ffi::PyObject,
34-
) -> *mut ::pyo3::ffi::PyObject {
35-
let _slf = _raw_slf;
36-
::pyo3::callback::handle_panic(|_py| {
37-
let _cell = _py
38-
.from_borrowed_ptr::<::pyo3::PyAny>(_slf)
39-
.downcast::<::pyo3::PyCell<TestDefaultSlots>>()?;
40-
let _ref = _cell.try_borrow()?;
41-
let _slf = &_ref;
42-
::pyo3::callback::convert(_py, TestDefaultSlots::__pyo3__str__(_slf))
43-
})
44-
}
45-
::pyo3::ffi::PyType_Slot {
46-
slot: ::pyo3::ffi::Py_tp_str,
47-
pfunc: __wrap as ::pyo3::ffi::reprfunc as _,
48-
}
49-
}]
50-
}
8+
enum TestDefaultRepr {
9+
Var,
5110
}
5211

5312
#[test]
5413
fn test_default_slot_exists() {
5514
Python::with_gil(|py| {
56-
let test_object = Py::new(py, TestDefaultSlots).unwrap();
57-
py_assert!(py, test_object, "str(test_object) == 'default'");
15+
let test_object = Py::new(py, TestDefaultRepr::Var).unwrap();
16+
py_assert!(
17+
py,
18+
test_object,
19+
"repr(test_object) == 'TestDefaultRepr.Var'"
20+
);
5821
})
5922
}
6023

6124
#[pyclass]
62-
struct OverrideSlot;
63-
64-
// generated using `Cargo expand`
65-
// equivalent to
66-
// ```
67-
// impl OverrideMagicMethod {
68-
// fn __str__(&self) -> &'static str {
69-
// "default"
70-
// }
71-
// }
72-
// ```
73-
impl OverrideSlot {
74-
fn __pyo3__str__(&self) -> &'static str {
75-
"default"
76-
}
77-
}
78-
79-
impl ::pyo3::class::impl_::PyClassDefaultSlots<OverrideSlot>
80-
for ::pyo3::class::impl_::PyClassImplCollector<OverrideSlot>
81-
{
82-
fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
83-
&[{
84-
unsafe extern "C" fn __wrap(
85-
_raw_slf: *mut ::pyo3::ffi::PyObject,
86-
) -> *mut ::pyo3::ffi::PyObject {
87-
let _slf = _raw_slf;
88-
::pyo3::callback::handle_panic(|_py| {
89-
let _cell = _py
90-
.from_borrowed_ptr::<::pyo3::PyAny>(_slf)
91-
.downcast::<::pyo3::PyCell<OverrideSlot>>()?;
92-
let _ref = _cell.try_borrow()?;
93-
let _slf = &_ref;
94-
::pyo3::callback::convert(_py, OverrideSlot::__pyo3__str__(_slf))
95-
})
96-
}
97-
::pyo3::ffi::PyType_Slot {
98-
slot: ::pyo3::ffi::Py_tp_str,
99-
pfunc: __wrap as ::pyo3::ffi::reprfunc as _,
100-
}
101-
}]
102-
}
25+
enum OverrideSlot {
26+
Var,
10327
}
10428

10529
#[pymethods]
10630
impl OverrideSlot {
107-
fn __str__(&self) -> &str {
31+
fn __repr__(&self) -> &str {
10832
"overriden"
10933
}
11034
}
11135

11236
#[test]
11337
fn test_override_slot() {
11438
Python::with_gil(|py| {
115-
let test_object = Py::new(py, OverrideSlot).unwrap();
116-
py_assert!(py, test_object, "str(test_object) == 'overriden'");
39+
let test_object = Py::new(py, OverrideSlot::Var).unwrap();
40+
py_assert!(py, test_object, "repr(test_object) == 'overriden'");
11741
})
11842
}

0 commit comments

Comments
 (0)