Skip to content

Commit 3b4c7d3

Browse files
bors[bot]stuhood
andauthored
Merge #3157
3157: Add support for `#[new]` which is also a `#[classmethod]` r=davidhewitt a=stuhood Fixes #3077. Co-authored-by: Stu Hood <[email protected]>
2 parents 0f00240 + 20c5618 commit 3b4c7d3

File tree

7 files changed

+90
-10
lines changed

7 files changed

+90
-10
lines changed

guide/src/class.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,27 @@ Declares a class method callable from Python.
607607
* For details on `parameter-list`, see the documentation of `Method arguments` section.
608608
* The return type must be `PyResult<T>` or `T` for some `T` that implements `IntoPy<PyObject>`.
609609

610+
### Constructors which accept a class argument
611+
612+
To create a constructor which takes a positional class argument, you can combine the `#[classmethod]` and `#[new]` modifiers:
613+
```rust
614+
# use pyo3::prelude::*;
615+
# use pyo3::types::PyType;
616+
# #[pyclass]
617+
# struct BaseClass(PyObject);
618+
#
619+
#[pymethods]
620+
impl BaseClass {
621+
#[new]
622+
#[classmethod]
623+
fn py_new<'p>(cls: &'p PyType, py: Python<'p>) -> PyResult<Self> {
624+
// Get an abstract attribute (presumably) declared on a subclass of this class.
625+
let subclass_attr = cls.getattr("a_class_attr")?;
626+
Ok(Self(subclass_attr.to_object(py)))
627+
}
628+
}
629+
```
630+
610631
## Static methods
611632

612633
To create a static method for a custom class, the method needs to be annotated with the

newsfragments/3157.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Allow combining `#[new]` and `#[classmethod]` to create a constructor which receives a (subtype's) class/`PyType` as its first argument.

pyo3-macros-backend/src/method.rs

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ fn handle_argument_error(pat: &syn::Pat) -> syn::Error {
8484
pub enum MethodTypeAttribute {
8585
/// `#[new]`
8686
New,
87+
/// `#[new]` && `#[classmethod]`
88+
NewClassMethod,
8789
/// `#[classmethod]`
8890
ClassMethod,
8991
/// `#[classattr]`
@@ -102,6 +104,7 @@ pub enum FnType {
102104
Setter(SelfType),
103105
Fn(SelfType),
104106
FnNew,
107+
FnNewClass,
105108
FnClass,
106109
FnStatic,
107110
FnModule,
@@ -122,7 +125,7 @@ impl FnType {
122125
FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => {
123126
quote!()
124127
}
125-
FnType::FnClass => {
128+
FnType::FnClass | FnType::FnNewClass => {
126129
quote! {
127130
let _slf = _pyo3::types::PyType::from_type_ptr(_py, _slf as *mut _pyo3::ffi::PyTypeObject);
128131
}
@@ -368,12 +371,16 @@ impl<'a> FnSpec<'a> {
368371
let (fn_type, skip_first_arg, fixed_convention) = match fn_type_attr {
369372
Some(MethodTypeAttribute::StaticMethod) => (FnType::FnStatic, false, None),
370373
Some(MethodTypeAttribute::ClassAttribute) => (FnType::ClassAttribute, false, None),
371-
Some(MethodTypeAttribute::New) => {
374+
Some(MethodTypeAttribute::New) | Some(MethodTypeAttribute::NewClassMethod) => {
372375
if let Some(name) = &python_name {
373376
bail_spanned!(name.span() => "`name` not allowed with `#[new]`");
374377
}
375378
*python_name = Some(syn::Ident::new("__new__", Span::call_site()));
376-
(FnType::FnNew, false, Some(CallingConvention::TpNew))
379+
if matches!(fn_type_attr, Some(MethodTypeAttribute::New)) {
380+
(FnType::FnNew, false, Some(CallingConvention::TpNew))
381+
} else {
382+
(FnType::FnNewClass, true, Some(CallingConvention::TpNew))
383+
}
377384
}
378385
Some(MethodTypeAttribute::ClassMethod) => (FnType::FnClass, true, None),
379386
Some(MethodTypeAttribute::Getter) => {
@@ -496,7 +503,11 @@ impl<'a> FnSpec<'a> {
496503
}
497504
CallingConvention::TpNew => {
498505
let (arg_convert, args) = impl_arg_params(self, cls, &py, false)?;
499-
let call = quote! { #rust_name(#(#args),*) };
506+
let call = match &self.tp {
507+
FnType::FnNew => quote! { #rust_name(#(#args),*) },
508+
FnType::FnNewClass => quote! { #rust_name(PyType::from_type_ptr(#py, subtype), #(#args),*) },
509+
x => panic!("Only `FnNew` or `FnNewClass` may use the `TpNew` calling convention. Got: {:?}", x),
510+
};
500511
quote! {
501512
unsafe fn #ident(
502513
#py: _pyo3::Python<'_>,
@@ -609,7 +620,7 @@ impl<'a> FnSpec<'a> {
609620
FnType::Getter(_) | FnType::Setter(_) | FnType::ClassAttribute => return None,
610621
FnType::Fn(_) => Some("self"),
611622
FnType::FnModule => Some("module"),
612-
FnType::FnClass => Some("cls"),
623+
FnType::FnClass | FnType::FnNewClass => Some("cls"),
613624
FnType::FnStatic | FnType::FnNew => None,
614625
};
615626

@@ -637,11 +648,22 @@ fn parse_method_attributes(
637648
let mut deprecated_args = None;
638649
let mut ty: Option<MethodTypeAttribute> = None;
639650

651+
macro_rules! set_compound_ty {
652+
($new_ty:expr, $ident:expr) => {
653+
ty = match (ty, $new_ty) {
654+
(None, new_ty) => Some(new_ty),
655+
(Some(MethodTypeAttribute::ClassMethod), MethodTypeAttribute::New) => Some(MethodTypeAttribute::NewClassMethod),
656+
(Some(MethodTypeAttribute::New), MethodTypeAttribute::ClassMethod) => Some(MethodTypeAttribute::NewClassMethod),
657+
(Some(_), _) => bail_spanned!($ident.span() => "can only combine `new` and `classmethod`"),
658+
};
659+
};
660+
}
661+
640662
macro_rules! set_ty {
641663
($new_ty:expr, $ident:expr) => {
642664
ensure_spanned!(
643665
ty.replace($new_ty).is_none(),
644-
$ident.span() => "cannot specify a second method type"
666+
$ident.span() => "cannot combine these method types"
645667
);
646668
};
647669
}
@@ -650,13 +672,13 @@ fn parse_method_attributes(
650672
match attr.parse_meta() {
651673
Ok(syn::Meta::Path(name)) => {
652674
if name.is_ident("new") || name.is_ident("__new__") {
653-
set_ty!(MethodTypeAttribute::New, name);
675+
set_compound_ty!(MethodTypeAttribute::New, name);
654676
} else if name.is_ident("init") || name.is_ident("__init__") {
655677
bail_spanned!(name.span() => "#[init] is disabled since PyO3 0.9.0");
656678
} else if name.is_ident("call") || name.is_ident("__call__") {
657679
bail_spanned!(name.span() => "use `fn __call__` instead of `#[call]` attribute since PyO3 0.15.0");
658680
} else if name.is_ident("classmethod") {
659-
set_ty!(MethodTypeAttribute::ClassMethod, name);
681+
set_compound_ty!(MethodTypeAttribute::ClassMethod, name);
660682
} else if name.is_ident("staticmethod") {
661683
set_ty!(MethodTypeAttribute::StaticMethod, name);
662684
} else if name.is_ident("classattr") {

pyo3-macros-backend/src/pymethod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,9 @@ pub fn gen_py_method(
234234
Some(quote!(_pyo3::ffi::METH_STATIC)),
235235
)?),
236236
// special prototypes
237-
(_, FnType::FnNew) => GeneratedPyMethod::Proto(impl_py_method_def_new(cls, spec)?),
237+
(_, FnType::FnNew) | (_, FnType::FnNewClass) => {
238+
GeneratedPyMethod::Proto(impl_py_method_def_new(cls, spec)?)
239+
}
238240

239241
(_, FnType::Getter(self_type)) => GeneratedPyMethod::Method(impl_py_getter_def(
240242
cls,

pytests/src/pyclasses.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
use pyo3::exceptions::PyValueError;
12
use pyo3::iter::IterNextOutput;
23
use pyo3::prelude::*;
4+
use pyo3::types::PyType;
35

46
#[pyclass]
57
struct EmptyClass {}
@@ -35,9 +37,30 @@ impl PyClassIter {
3537
}
3638
}
3739

40+
/// Demonstrates a base class which can operate on the relevant subclass in its constructor.
41+
#[pyclass(subclass)]
42+
#[derive(Clone, Debug)]
43+
struct AssertingBaseClass;
44+
45+
#[pymethods]
46+
impl AssertingBaseClass {
47+
#[new]
48+
#[classmethod]
49+
fn new(cls: &PyType, expected_type: &PyType) -> PyResult<Self> {
50+
if !cls.is(expected_type) {
51+
return Err(PyValueError::new_err(format!(
52+
"{:?} != {:?}",
53+
cls, expected_type
54+
)));
55+
}
56+
Ok(Self)
57+
}
58+
}
59+
3860
#[pymodule]
3961
pub fn pyclasses(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
4062
m.add_class::<EmptyClass>()?;
4163
m.add_class::<PyClassIter>()?;
64+
m.add_class::<AssertingBaseClass>()?;
4265
Ok(())
4366
}

pytests/tests/test_pyclasses.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,14 @@ def test_iter():
2525
with pytest.raises(StopIteration) as excinfo:
2626
next(i)
2727
assert excinfo.value.value == "Ended"
28+
29+
30+
class AssertingSubClass(pyclasses.AssertingBaseClass):
31+
pass
32+
33+
34+
def test_new_classmethod():
35+
# The `AssertingBaseClass` constructor errors if it is not passed the relevant subclass.
36+
_ = AssertingSubClass(expected_type=AssertingSubClass)
37+
with pytest.raises(ValueError):
38+
_ = AssertingSubClass(expected_type=str)

tests/ui/invalid_pymethods.stderr

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ error: `signature` not allowed with `classattr`
8888
105 | #[pyo3(signature = ())]
8989
| ^^^^^^^^^
9090

91-
error: cannot specify a second method type
91+
error: cannot combine these method types
9292
--> tests/ui/invalid_pymethods.rs:112:7
9393
|
9494
112 | #[staticmethod]

0 commit comments

Comments
 (0)