Skip to content

Commit 77726a5

Browse files
committed
Draft of default_impl
1 parent b7419b5 commit 77726a5

File tree

3 files changed

+167
-0
lines changed

3 files changed

+167
-0
lines changed

pyo3-macros-backend/src/pyclass.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ impl<'a> PyClassImplsBuilder<'a> {
706706
}
707707
}
708708
}
709+
709710
fn impl_pyclassimpl(&self) -> TokenStream {
710711
let cls = self.cls;
711712
let doc = self.doc.as_ref().map_or(quote! {"\0"}, |doc| quote! {#doc});
@@ -765,6 +766,9 @@ impl<'a> PyClassImplsBuilder<'a> {
765766
visitor(collector.descr_protocol_methods());
766767
visitor(collector.mapping_protocol_methods());
767768
visitor(collector.number_protocol_methods());
769+
770+
// It's collected last so Python ignore them if it found methods with same names.
771+
visitor(collector.py_class_default_impls());
768772
}
769773
fn get_new() -> ::std::option::Option<::pyo3::ffi::newfunc> {
770774
use ::pyo3::class::impl_::*;
@@ -786,6 +790,7 @@ impl<'a> PyClassImplsBuilder<'a> {
786790
// Implementation which uses dtolnay specialization to load all slots.
787791
use ::pyo3::class::impl_::*;
788792
let collector = PyClassImplCollector::<Self>::new();
793+
visitor(collector.py_class_default_slots());
789794
visitor(collector.object_protocol_slots());
790795
visitor(collector.number_protocol_slots());
791796
visitor(collector.iter_protocol_slots());

src/class/impl_.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,9 @@ pub trait HasMethodsInventory {
629629
// Methods from #[pyo3(get, set)] on struct fields.
630630
methods_trait!(PyClassDescriptors, py_class_descriptors);
631631

632+
// Methods that PyO3 implemented by default, but can be overridden by the user.
633+
methods_trait!(PyClassDefaultImpls, py_class_default_impls);
634+
632635
// Methods from #[pymethods] if not using inventory.
633636
#[cfg(not(feature = "multiple-pymethods"))]
634637
methods_trait!(PyMethods, py_methods);
@@ -664,6 +667,9 @@ slots_trait!(PyAsyncProtocolSlots, async_protocol_slots);
664667
slots_trait!(PySequenceProtocolSlots, sequence_protocol_slots);
665668
slots_trait!(PyBufferProtocolSlots, buffer_protocol_slots);
666669

670+
// slots that PyO3 implements by default, but can be overidden by the users.
671+
slots_trait!(PyClassDefaultSlots, py_class_default_slots);
672+
667673
// Protocol slots from #[pymethods] if not using inventory.
668674
#[cfg(not(feature = "multiple-pymethods"))]
669675
slots_trait!(PyMethodsProtocolSlots, methods_protocol_slots);

tests/test_default_impls.rs

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
#![allow(non_snake_case)]
2+
use pyo3::class::PyMethodDefType;
3+
use pyo3::prelude::*;
4+
use pyo3::py_run;
5+
6+
mod common;
7+
8+
// Tests for PyClassDefaultImpls
9+
#[pyclass]
10+
struct TestDefaultImpl;
11+
12+
// generated using `Cargo expand`
13+
// equivalent to
14+
// ```
15+
// impl TestDefaultImpl {
16+
// #[staticmethod]
17+
// fn default() {}
18+
// #[staticmethod]
19+
// fn overriden() -> bool { false }
20+
// }
21+
// ```
22+
impl TestDefaultImpl {
23+
fn __pyo3__default() {}
24+
fn __pyo3__overriden() -> bool {
25+
false
26+
}
27+
}
28+
29+
impl pyo3::class::impl_::PyClassDefaultImpls<TestDefaultImpl>
30+
for pyo3::class::impl_::PyClassImplCollector<TestDefaultImpl>
31+
{
32+
fn py_class_default_impls(self) -> &'static [PyMethodDefType] {
33+
static METHODS: &[::pyo3::class::methods::PyMethodDefType] = &[
34+
::pyo3::class::PyMethodDefType::Static(
35+
::pyo3::class::methods::PyMethodDef::noargs(
36+
"default\u{0}",
37+
::pyo3::class::methods::PyCFunction({
38+
unsafe extern "C" fn __wrap(
39+
_slf: *mut ::pyo3::ffi::PyObject,
40+
_args: *mut ::pyo3::ffi::PyObject,
41+
) -> *mut ::pyo3::ffi::PyObject {
42+
::pyo3::callback::handle_panic(|_py| {
43+
::pyo3::callback::convert(_py, TestDefaultImpl::__pyo3__default())
44+
})
45+
}
46+
__wrap
47+
}),
48+
"\u{0}",
49+
)
50+
.flags(::pyo3::ffi::METH_STATIC),
51+
),
52+
::pyo3::class::PyMethodDefType::Static(
53+
::pyo3::class::methods::PyMethodDef::noargs(
54+
"overriden\u{0}",
55+
::pyo3::class::methods::PyCFunction({
56+
unsafe extern "C" fn __wrap(
57+
_slf: *mut ::pyo3::ffi::PyObject,
58+
_args: *mut ::pyo3::ffi::PyObject,
59+
) -> *mut ::pyo3::ffi::PyObject {
60+
::pyo3::callback::handle_panic(|_py| {
61+
::pyo3::callback::convert(_py, TestDefaultImpl::__pyo3__overriden())
62+
})
63+
}
64+
__wrap
65+
}),
66+
"\u{0}",
67+
)
68+
.flags(::pyo3::ffi::METH_STATIC),
69+
),
70+
];
71+
METHODS
72+
}
73+
}
74+
75+
#[pymethods]
76+
impl TestDefaultImpl {
77+
#[staticmethod]
78+
fn overriden() -> bool {
79+
true
80+
}
81+
}
82+
83+
#[test]
84+
fn test_default_impl_exists() {
85+
Python::with_gil(|py| {
86+
let test_object = Py::new(py, TestDefaultImpl).unwrap();
87+
py_run!(py, test_object, "test_object.default()");
88+
})
89+
}
90+
91+
#[test]
92+
fn test_default_impl_is_overriden() {
93+
Python::with_gil(|py| {
94+
let test_object = Py::new(py, TestDefaultImpl).unwrap();
95+
py_assert!(py, test_object, "test_object.overriden() == True");
96+
})
97+
}
98+
99+
#[pyclass]
100+
struct OverrideMagicMethod;
101+
102+
// generated using `Cargo expand`
103+
// equivalent to
104+
// ```
105+
// impl OverrideMagicMethod {
106+
// fn __str__(&self) -> &'static str {
107+
// "default"
108+
// }
109+
// }
110+
// ```
111+
impl OverrideMagicMethod {
112+
fn __pyo3__str__(&self) -> &'static str {
113+
"default"
114+
}
115+
}
116+
117+
impl ::pyo3::class::impl_::PyClassDefaultSlots<OverrideMagicMethod>
118+
for ::pyo3::class::impl_::PyClassImplCollector<OverrideMagicMethod>
119+
{
120+
fn py_class_default_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
121+
&[{
122+
unsafe extern "C" fn __wrap(
123+
_raw_slf: *mut ::pyo3::ffi::PyObject,
124+
) -> *mut ::pyo3::ffi::PyObject {
125+
let _slf = _raw_slf;
126+
::pyo3::callback::handle_panic(|_py| {
127+
let _cell = _py
128+
.from_borrowed_ptr::<::pyo3::PyAny>(_slf)
129+
.downcast::<::pyo3::PyCell<OverrideMagicMethod>>()?;
130+
let _ref = _cell.try_borrow()?;
131+
let _slf = &_ref;
132+
::pyo3::callback::convert(_py, OverrideMagicMethod::__pyo3__str__(_slf))
133+
})
134+
}
135+
::pyo3::ffi::PyType_Slot {
136+
slot: ::pyo3::ffi::Py_tp_str,
137+
pfunc: __wrap as ::pyo3::ffi::reprfunc as _,
138+
}
139+
}]
140+
}
141+
}
142+
143+
#[pymethods]
144+
impl OverrideMagicMethod {
145+
fn __str__(&self) -> &str {
146+
"overriden"
147+
}
148+
}
149+
150+
#[test]
151+
fn test_override_magic_method() {
152+
Python::with_gil(|py| {
153+
let test_object = Py::new(py, OverrideMagicMethod).unwrap();
154+
py_assert!(py, test_object, "str(test_object) == 'overriden'");
155+
})
156+
}

0 commit comments

Comments
 (0)