Skip to content

Commit da7c031

Browse files
authored
Merge pull request #3287 from alex/existing-instance
Allow `#[new]` to return existing instances
2 parents 1a0c9be + 0b78bb8 commit da7c031

File tree

4 files changed

+88
-7
lines changed

4 files changed

+88
-7
lines changed

guide/src/class.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ impl Nonzero {
114114
}
115115
```
116116

117+
If you want to return an existing object (for example, because your `new`
118+
method caches the values it returns), `new` can return `pyo3::Py<Self>`.
119+
117120
As you can see, the Rust method name is not important here; this way you can
118121
still, use `new()` for a Rust-level constructor.
119122

newsfragments/3287.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`#[new]` methods may now return `Py<Self>` in order to return existing instances

src/pyclass_init.rs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Contains initialization utilities for `#[pyclass]`.
22
use crate::callback::IntoPyCallbackOutput;
33
use crate::impl_::pyclass::{PyClassBaseType, PyClassDict, PyClassThreadChecker, PyClassWeakRef};
4-
use crate::{ffi, PyCell, PyClass, PyErr, PyResult, Python};
4+
use crate::{ffi, IntoPyPointer, Py, PyCell, PyClass, PyErr, PyResult, Python};
55
use crate::{
66
ffi::PyTypeObject,
77
pycell::{
@@ -134,17 +134,22 @@ impl<T: PyTypeInfo> PyObjectInit<T> for PyNativeTypeInitializer<T> {
134134
/// );
135135
/// });
136136
/// ```
137-
pub struct PyClassInitializer<T: PyClass> {
138-
init: T,
139-
super_init: <T::BaseType as PyClassBaseType>::Initializer,
137+
pub struct PyClassInitializer<T: PyClass>(PyClassInitializerImpl<T>);
138+
139+
enum PyClassInitializerImpl<T: PyClass> {
140+
Existing(Py<T>),
141+
New {
142+
init: T,
143+
super_init: <T::BaseType as PyClassBaseType>::Initializer,
144+
},
140145
}
141146

142147
impl<T: PyClass> PyClassInitializer<T> {
143148
/// Constructs a new initializer from value `T` and base class' initializer.
144149
///
145150
/// It is recommended to use `add_subclass` instead of this method for most usage.
146151
pub fn new(init: T, super_init: <T::BaseType as PyClassBaseType>::Initializer) -> Self {
147-
Self { init, super_init }
152+
Self(PyClassInitializerImpl::New { init, super_init })
148153
}
149154

150155
/// Constructs a new initializer from an initializer for the base class.
@@ -242,13 +247,18 @@ impl<T: PyClass> PyObjectInit<T> for PyClassInitializer<T> {
242247
contents: MaybeUninit<PyCellContents<T>>,
243248
}
244249

245-
let obj = self.super_init.into_new_object(py, subtype)?;
250+
let (init, super_init) = match self.0 {
251+
PyClassInitializerImpl::Existing(value) => return Ok(value.into_ptr()),
252+
PyClassInitializerImpl::New { init, super_init } => (init, super_init),
253+
};
254+
255+
let obj = super_init.into_new_object(py, subtype)?;
246256

247257
let cell: *mut PartiallyInitializedPyCell<T> = obj as _;
248258
std::ptr::write(
249259
(*cell).contents.as_mut_ptr(),
250260
PyCellContents {
251-
value: ManuallyDrop::new(UnsafeCell::new(self.init)),
261+
value: ManuallyDrop::new(UnsafeCell::new(init)),
252262
borrow_checker: <T::PyClassMutability as PyClassMutability>::Storage::new(),
253263
thread_checker: T::ThreadChecker::new(),
254264
dict: T::Dict::INIT,
@@ -284,6 +294,13 @@ where
284294
}
285295
}
286296

297+
impl<T: PyClass> From<Py<T>> for PyClassInitializer<T> {
298+
#[inline]
299+
fn from(value: Py<T>) -> PyClassInitializer<T> {
300+
PyClassInitializer(PyClassInitializerImpl::Existing(value))
301+
}
302+
}
303+
287304
// Implementation used by proc macros to allow anything convertible to PyClassInitializer<T> to be
288305
// the return value of pyclass #[new] method (optionally wrapped in `Result<U, E>`).
289306
impl<T, U> IntoPyCallbackOutput<PyClassInitializer<T>> for U

tests/test_class_new.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
use pyo3::exceptions::PyValueError;
44
use pyo3::prelude::*;
5+
use pyo3::sync::GILOnceCell;
56
use pyo3::types::IntoPyDict;
67

78
#[pyclass]
@@ -204,3 +205,62 @@ fn new_with_custom_error() {
204205
assert_eq!(err.to_string(), "ValueError: custom error");
205206
});
206207
}
208+
209+
#[pyclass]
210+
struct NewExisting {
211+
#[pyo3(get)]
212+
num: usize,
213+
}
214+
215+
#[pymethods]
216+
impl NewExisting {
217+
#[new]
218+
fn new(py: pyo3::Python<'_>, val: usize) -> pyo3::Py<NewExisting> {
219+
static PRE_BUILT: GILOnceCell<[pyo3::Py<NewExisting>; 2]> = GILOnceCell::new();
220+
let existing = PRE_BUILT.get_or_init(py, || {
221+
[
222+
pyo3::PyCell::new(py, NewExisting { num: 0 })
223+
.unwrap()
224+
.into(),
225+
pyo3::PyCell::new(py, NewExisting { num: 1 })
226+
.unwrap()
227+
.into(),
228+
]
229+
});
230+
231+
if val < existing.len() {
232+
return existing[val].clone_ref(py);
233+
}
234+
235+
pyo3::PyCell::new(py, NewExisting { num: val })
236+
.unwrap()
237+
.into()
238+
}
239+
}
240+
241+
#[test]
242+
fn test_new_existing() {
243+
Python::with_gil(|py| {
244+
let typeobj = py.get_type::<NewExisting>();
245+
246+
let obj1 = typeobj.call1((0,)).unwrap();
247+
let obj2 = typeobj.call1((0,)).unwrap();
248+
let obj3 = typeobj.call1((1,)).unwrap();
249+
let obj4 = typeobj.call1((1,)).unwrap();
250+
let obj5 = typeobj.call1((2,)).unwrap();
251+
let obj6 = typeobj.call1((2,)).unwrap();
252+
253+
assert!(obj1.getattr("num").unwrap().extract::<u32>().unwrap() == 0);
254+
assert!(obj2.getattr("num").unwrap().extract::<u32>().unwrap() == 0);
255+
assert!(obj3.getattr("num").unwrap().extract::<u32>().unwrap() == 1);
256+
assert!(obj4.getattr("num").unwrap().extract::<u32>().unwrap() == 1);
257+
assert!(obj5.getattr("num").unwrap().extract::<u32>().unwrap() == 2);
258+
assert!(obj6.getattr("num").unwrap().extract::<u32>().unwrap() == 2);
259+
260+
assert!(obj1.is(obj2));
261+
assert!(obj3.is(obj4));
262+
assert!(!obj1.is(obj3));
263+
assert!(!obj1.is(obj5));
264+
assert!(!obj5.is(obj6));
265+
});
266+
}

0 commit comments

Comments
 (0)