Skip to content

Commit ec6d587

Browse files
authored
support Bound for classmethod and pass_module (#3831)
* support `Bound` for `classmethod` and `pass_module` * `from_ref_to_ptr` -> `ref_from_ptr` * add detailed docs to `ref_from_ptr`
1 parent 05aedc9 commit ec6d587

16 files changed

+205
-51
lines changed

guide/src/class.md

+5-5
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ This is the equivalent of the Python decorator `@classmethod`.
691691
#[pymethods]
692692
impl MyClass {
693693
#[classmethod]
694-
fn cls_method(cls: &PyType) -> PyResult<i32> {
694+
fn cls_method(cls: &Bound<'_, PyType>) -> PyResult<i32> {
695695
Ok(10)
696696
}
697697
}
@@ -719,10 +719,10 @@ To create a constructor which takes a positional class argument, you can combine
719719
impl BaseClass {
720720
#[new]
721721
#[classmethod]
722-
fn py_new<'p>(cls: &'p PyType, py: Python<'p>) -> PyResult<Self> {
722+
fn py_new(cls: &Bound<'_, PyType>) -> PyResult<Self> {
723723
// Get an abstract attribute (presumably) declared on a subclass of this class.
724-
let subclass_attr = cls.getattr("a_class_attr")?;
725-
Ok(Self(subclass_attr.to_object(py)))
724+
let subclass_attr: Bound<'_, PyAny> = cls.getattr("a_class_attr")?;
725+
Ok(Self(subclass_attr.unbind()))
726726
}
727727
}
728728
```
@@ -928,7 +928,7 @@ impl MyClass {
928928
// similarly for classmethod arguments, use $cls
929929
#[classmethod]
930930
#[pyo3(text_signature = "($cls, e, f)")]
931-
fn my_class_method(cls: &PyType, e: i32, f: i32) -> i32 {
931+
fn my_class_method(cls: &Bound<'_, PyType>, e: i32, f: i32) -> i32 {
932932
e + f
933933
}
934934
#[staticmethod]

guide/src/function.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,11 @@ The `#[pyo3]` attribute can be used to modify properties of the generated Python
8383

8484
```rust
8585
use pyo3::prelude::*;
86+
use pyo3::types::PyString;
8687

8788
#[pyfunction]
8889
#[pyo3(pass_module)]
89-
fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> {
90+
fn pyfunction_with_module<'py>(module: &Bound<'py, PyModule>) -> PyResult<Bound<'py, PyString>> {
9091
module.name()
9192
}
9293

pyo3-macros-backend/src/method.rs

+11-3
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,21 @@ impl FnType {
127127
let slf: Ident = syn::Ident::new("_slf", Span::call_site());
128128
quote_spanned! { *span =>
129129
#[allow(clippy::useless_conversion)]
130-
::std::convert::Into::into(_pyo3::types::PyType::from_type_ptr(#py, #slf.cast())),
130+
::std::convert::Into::into(
131+
_pyo3::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf.cast())
132+
.downcast_unchecked::<_pyo3::types::PyType>()
133+
),
131134
}
132135
}
133136
FnType::FnModule(span) => {
137+
let py = syn::Ident::new("py", Span::call_site());
138+
let slf: Ident = syn::Ident::new("_slf", Span::call_site());
134139
quote_spanned! { *span =>
135140
#[allow(clippy::useless_conversion)]
136-
::std::convert::Into::into(py.from_borrowed_ptr::<_pyo3::types::PyModule>(_slf)),
141+
::std::convert::Into::into(
142+
_pyo3::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf.cast())
143+
.downcast_unchecked::<_pyo3::types::PyModule>()
144+
),
137145
}
138146
}
139147
}
@@ -409,7 +417,7 @@ impl<'a> FnSpec<'a> {
409417
// will error on incorrect type.
410418
Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
411419
Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
412-
sig.paren_token.span.join() => "Expected `&PyType` or `Py<PyType>` as the first argument to `#[classmethod]`"
420+
sig.paren_token.span.join() => "Expected `&Bound<PyType>` or `Py<PyType>` as the first argument to `#[classmethod]`"
413421
),
414422
};
415423
FnType::FnClass(span)

pytests/src/pyclasses.rs

+20
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,25 @@ struct AssertingBaseClass;
4444

4545
#[pymethods]
4646
impl AssertingBaseClass {
47+
#[new]
48+
#[classmethod]
49+
fn new(cls: &Bound<'_, PyType>, expected_type: Bound<'_, PyType>) -> PyResult<Self> {
50+
if !cls.is(&expected_type) {
51+
return Err(PyValueError::new_err(format!(
52+
"{:?} != {:?}",
53+
cls, expected_type
54+
)));
55+
}
56+
Ok(Self)
57+
}
58+
}
59+
60+
#[pyclass(subclass)]
61+
#[derive(Clone, Debug)]
62+
struct AssertingBaseClassGilRef;
63+
64+
#[pymethods]
65+
impl AssertingBaseClassGilRef {
4766
#[new]
4867
#[classmethod]
4968
fn new(cls: &PyType, expected_type: &PyType) -> PyResult<Self> {
@@ -65,6 +84,7 @@ pub fn pyclasses(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
6584
m.add_class::<EmptyClass>()?;
6685
m.add_class::<PyClassIter>()?;
6786
m.add_class::<AssertingBaseClass>()?;
87+
m.add_class::<AssertingBaseClassGilRef>()?;
6888
m.add_class::<ClassWithoutConstructor>()?;
6989
Ok(())
7090
}

pytests/tests/test_pyclasses.py

+11
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ def test_new_classmethod():
4141
_ = AssertingSubClass(expected_type=str)
4242

4343

44+
def test_new_classmethod_gil_ref():
45+
class AssertingSubClass(pyclasses.AssertingBaseClassGilRef):
46+
pass
47+
48+
# The `AssertingBaseClass` constructor errors if it is not passed the
49+
# relevant subclass.
50+
_ = AssertingSubClass(expected_type=AssertingSubClass)
51+
with pytest.raises(ValueError):
52+
_ = AssertingSubClass(expected_type=str)
53+
54+
4455
class ClassWithoutConstructorPy:
4556
def __new__(cls):
4657
raise TypeError("No constructor defined")

src/impl_/pymethods.rs

+52-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ use crate::exceptions::PyStopAsyncIteration;
33
use crate::gil::LockGIL;
44
use crate::impl_::panic::PanicTrap;
55
use crate::internal_tricks::extract_c_string;
6+
use crate::types::{any::PyAnyMethods, PyModule, PyType};
67
use crate::{
7-
ffi, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit, Python,
8+
ffi, Bound, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit,
9+
Python,
810
};
911
use std::borrow::Cow;
1012
use std::ffi::CStr;
@@ -466,3 +468,52 @@ pub trait AsyncIterResultOptionKind {
466468
}
467469

468470
impl<Value, Error> AsyncIterResultOptionKind for Result<Option<Value>, Error> {}
471+
472+
/// Used in `#[classmethod]` to pass the class object to the method
473+
/// and also in `#[pyfunction(pass_module)]`.
474+
///
475+
/// This is a wrapper to avoid implementing `From<Bound>` for GIL Refs.
476+
///
477+
/// Once the GIL Ref API is fully removed, it should be possible to simplify
478+
/// this to just `&'a Bound<'py, T>` and `From` implementations.
479+
pub struct BoundRef<'a, 'py, T>(pub &'a Bound<'py, T>);
480+
481+
impl<'a, 'py> BoundRef<'a, 'py, PyAny> {
482+
pub unsafe fn ref_from_ptr(py: Python<'py>, ptr: &'a *mut ffi::PyObject) -> Self {
483+
BoundRef(Bound::ref_from_ptr(py, ptr))
484+
}
485+
486+
pub unsafe fn downcast_unchecked<T>(self) -> BoundRef<'a, 'py, T> {
487+
BoundRef(self.0.downcast_unchecked::<T>())
488+
}
489+
}
490+
491+
// GIL Ref implementations for &'a T ran into trouble with orphan rules,
492+
// so explicit implementations are used instead for the two relevant types.
493+
impl<'a> From<BoundRef<'a, 'a, PyType>> for &'a PyType {
494+
#[inline]
495+
fn from(bound: BoundRef<'a, 'a, PyType>) -> Self {
496+
bound.0.as_gil_ref()
497+
}
498+
}
499+
500+
impl<'a> From<BoundRef<'a, 'a, PyModule>> for &'a PyModule {
501+
#[inline]
502+
fn from(bound: BoundRef<'a, 'a, PyModule>) -> Self {
503+
bound.0.as_gil_ref()
504+
}
505+
}
506+
507+
impl<'a, 'py, T> From<BoundRef<'a, 'py, T>> for &'a Bound<'py, T> {
508+
#[inline]
509+
fn from(bound: BoundRef<'a, 'py, T>) -> Self {
510+
bound.0
511+
}
512+
}
513+
514+
impl<T> From<BoundRef<'_, '_, T>> for Py<T> {
515+
#[inline]
516+
fn from(bound: BoundRef<'_, '_, T>) -> Self {
517+
bound.0.clone().unbind()
518+
}
519+
}

src/instance.rs

+18
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,24 @@ impl<'py> Bound<'py, PyAny> {
138138
) -> PyResult<Self> {
139139
Py::from_owned_ptr_or_err(py, ptr).map(|obj| Self(py, ManuallyDrop::new(obj)))
140140
}
141+
142+
/// This slightly strange method is used to obtain `&Bound<PyAny>` from a pointer in macro code
143+
/// where we need to constrain the lifetime `'a` safely.
144+
///
145+
/// Note that `'py` is required to outlive `'a` implicitly by the nature of the fact that
146+
/// `&'a Bound<'py>` means that `Bound<'py>` exists for at least the lifetime `'a`.
147+
///
148+
/// # Safety
149+
/// - `ptr` must be a valid pointer to a Python object for the lifetime `'a`. The `ptr` can
150+
/// be either a borrowed reference or an owned reference, it does not matter, as this is
151+
/// just `&Bound` there will never be any ownership transfer.
152+
#[inline]
153+
pub(crate) unsafe fn ref_from_ptr<'a>(
154+
_py: Python<'py>,
155+
ptr: &'a *mut ffi::PyObject,
156+
) -> &'a Self {
157+
&*(ptr as *const *mut ffi::PyObject).cast::<Bound<'py, PyAny>>()
158+
}
141159
}
142160

143161
impl<'py, T> Bound<'py, T>

src/tests/hygiene/pymethods.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ impl Dummy {
375375
#[staticmethod]
376376
fn staticmethod() {}
377377
#[classmethod]
378-
fn clsmethod(_: &crate::types::PyType) {}
378+
fn clsmethod(_: &crate::Bound<'_, crate::types::PyType>) {}
379379
#[pyo3(signature = (*_args, **_kwds))]
380380
fn __call__(
381381
&self,
@@ -770,7 +770,7 @@ impl Dummy {
770770
#[staticmethod]
771771
fn staticmethod() {}
772772
#[classmethod]
773-
fn clsmethod(_: &crate::types::PyType) {}
773+
fn clsmethod(_: &crate::Bound<'_, crate::types::PyType>) {}
774774
#[pyo3(signature = (*_args, **_kwds))]
775775
fn __call__(
776776
&self,

tests/test_class_basics.rs

+14-2
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ fn panic_unsendable_child() {
284284
test_unsendable::<UnsendableChild>().unwrap();
285285
}
286286

287-
fn get_length(obj: &PyAny) -> PyResult<usize> {
287+
fn get_length(obj: &Bound<'_, PyAny>) -> PyResult<usize> {
288288
let length = obj.len()?;
289289

290290
Ok(length)
@@ -299,7 +299,18 @@ impl ClassWithFromPyWithMethods {
299299
argument
300300
}
301301
#[classmethod]
302-
fn classmethod(_cls: &PyType, #[pyo3(from_py_with = "PyAny::len")] argument: usize) -> usize {
302+
fn classmethod(
303+
_cls: &Bound<'_, PyType>,
304+
#[pyo3(from_py_with = "Bound::<'_, PyAny>::len")] argument: usize,
305+
) -> usize {
306+
argument
307+
}
308+
309+
#[classmethod]
310+
fn classmethod_gil_ref(
311+
_cls: &PyType,
312+
#[pyo3(from_py_with = "PyAny::len")] argument: usize,
313+
) -> usize {
303314
argument
304315
}
305316

@@ -322,6 +333,7 @@ fn test_pymethods_from_py_with() {
322333
323334
assert instance.instance_method(arg) == 2
324335
assert instance.classmethod(arg) == 2
336+
assert instance.classmethod_gil_ref(arg) == 2
325337
assert instance.staticmethod(arg) == 2
326338
"#
327339
);

tests/test_methods.rs

+15-5
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,13 @@ impl ClassMethod {
7373

7474
#[classmethod]
7575
/// Test class method.
76-
fn method(cls: &PyType) -> PyResult<String> {
76+
fn method(cls: &Bound<'_, PyType>) -> PyResult<String> {
77+
Ok(format!("{}.method()!", cls.as_gil_ref().qualname()?))
78+
}
79+
80+
#[classmethod]
81+
/// Test class method.
82+
fn method_gil_ref(cls: &PyType) -> PyResult<String> {
7783
Ok(format!("{}.method()!", cls.qualname()?))
7884
}
7985

@@ -108,8 +114,12 @@ struct ClassMethodWithArgs {}
108114
#[pymethods]
109115
impl ClassMethodWithArgs {
110116
#[classmethod]
111-
fn method(cls: &PyType, input: &PyString) -> PyResult<String> {
112-
Ok(format!("{}.method({})", cls.qualname()?, input))
117+
fn method(cls: &Bound<'_, PyType>, input: &PyString) -> PyResult<String> {
118+
Ok(format!(
119+
"{}.method({})",
120+
cls.as_gil_ref().qualname()?,
121+
input
122+
))
113123
}
114124
}
115125

@@ -915,7 +925,7 @@ impl r#RawIdents {
915925
}
916926

917927
#[classmethod]
918-
pub fn r#class_method(_: &PyType, r#type: PyObject) -> PyObject {
928+
pub fn r#class_method(_: &Bound<'_, PyType>, r#type: PyObject) -> PyObject {
919929
r#type
920930
}
921931

@@ -1082,7 +1092,7 @@ issue_1506!(
10821092

10831093
#[classmethod]
10841094
fn issue_1506_class(
1085-
_cls: &PyType,
1095+
_cls: &Bound<'_, PyType>,
10861096
_py: Python<'_>,
10871097
_arg: &PyAny,
10881098
_args: &PyTuple,

0 commit comments

Comments
 (0)