@@ -84,6 +84,8 @@ fn handle_argument_error(pat: &syn::Pat) -> syn::Error {
84
84
pub enum MethodTypeAttribute {
85
85
/// `#[new]`
86
86
New ,
87
+ /// `#[new]` && `#[classmethod]`
88
+ NewClassMethod ,
87
89
/// `#[classmethod]`
88
90
ClassMethod ,
89
91
/// `#[classattr]`
@@ -102,6 +104,7 @@ pub enum FnType {
102
104
Setter ( SelfType ) ,
103
105
Fn ( SelfType ) ,
104
106
FnNew ,
107
+ FnNewClass ,
105
108
FnClass ,
106
109
FnStatic ,
107
110
FnModule ,
@@ -122,7 +125,7 @@ impl FnType {
122
125
FnType :: FnNew | FnType :: FnStatic | FnType :: ClassAttribute => {
123
126
quote ! ( )
124
127
}
125
- FnType :: FnClass => {
128
+ FnType :: FnClass | FnType :: FnNewClass => {
126
129
quote ! {
127
130
let _slf = _pyo3:: types:: PyType :: from_type_ptr( _py, _slf as * mut _pyo3:: ffi:: PyTypeObject ) ;
128
131
}
@@ -368,12 +371,16 @@ impl<'a> FnSpec<'a> {
368
371
let ( fn_type, skip_first_arg, fixed_convention) = match fn_type_attr {
369
372
Some ( MethodTypeAttribute :: StaticMethod ) => ( FnType :: FnStatic , false , None ) ,
370
373
Some ( MethodTypeAttribute :: ClassAttribute ) => ( FnType :: ClassAttribute , false , None ) ,
371
- Some ( MethodTypeAttribute :: New ) => {
374
+ Some ( MethodTypeAttribute :: New ) | Some ( MethodTypeAttribute :: NewClassMethod ) => {
372
375
if let Some ( name) = & python_name {
373
376
bail_spanned ! ( name. span( ) => "`name` not allowed with `#[new]`" ) ;
374
377
}
375
378
* python_name = Some ( syn:: Ident :: new ( "__new__" , Span :: call_site ( ) ) ) ;
376
- ( FnType :: FnNew , false , Some ( CallingConvention :: TpNew ) )
379
+ if matches ! ( fn_type_attr, Some ( MethodTypeAttribute :: New ) ) {
380
+ ( FnType :: FnNew , false , Some ( CallingConvention :: TpNew ) )
381
+ } else {
382
+ ( FnType :: FnNewClass , true , Some ( CallingConvention :: TpNew ) )
383
+ }
377
384
}
378
385
Some ( MethodTypeAttribute :: ClassMethod ) => ( FnType :: FnClass , true , None ) ,
379
386
Some ( MethodTypeAttribute :: Getter ) => {
@@ -496,7 +503,11 @@ impl<'a> FnSpec<'a> {
496
503
}
497
504
CallingConvention :: TpNew => {
498
505
let ( arg_convert, args) = impl_arg_params ( self , cls, & py, false ) ?;
499
- let call = quote ! { #rust_name( #( #args) , * ) } ;
506
+ let call = match & self . tp {
507
+ FnType :: FnNew => quote ! { #rust_name( #( #args) , * ) } ,
508
+ FnType :: FnNewClass => quote ! { #rust_name( PyType :: from_type_ptr( #py, subtype) , #( #args) , * ) } ,
509
+ x => panic ! ( "Only `FnNew` or `FnNewClass` may use the `TpNew` calling convention. Got: {:?}" , x) ,
510
+ } ;
500
511
quote ! {
501
512
unsafe fn #ident(
502
513
#py: _pyo3:: Python <' _>,
@@ -609,7 +620,7 @@ impl<'a> FnSpec<'a> {
609
620
FnType :: Getter ( _) | FnType :: Setter ( _) | FnType :: ClassAttribute => return None ,
610
621
FnType :: Fn ( _) => Some ( "self" ) ,
611
622
FnType :: FnModule => Some ( "module" ) ,
612
- FnType :: FnClass => Some ( "cls" ) ,
623
+ FnType :: FnClass | FnType :: FnNewClass => Some ( "cls" ) ,
613
624
FnType :: FnStatic | FnType :: FnNew => None ,
614
625
} ;
615
626
@@ -637,11 +648,22 @@ fn parse_method_attributes(
637
648
let mut deprecated_args = None ;
638
649
let mut ty: Option < MethodTypeAttribute > = None ;
639
650
651
+ macro_rules! set_compound_ty {
652
+ ( $new_ty: expr, $ident: expr) => {
653
+ ty = match ( ty, $new_ty) {
654
+ ( None , new_ty) => Some ( new_ty) ,
655
+ ( Some ( MethodTypeAttribute :: ClassMethod ) , MethodTypeAttribute :: New ) => Some ( MethodTypeAttribute :: NewClassMethod ) ,
656
+ ( Some ( MethodTypeAttribute :: New ) , MethodTypeAttribute :: ClassMethod ) => Some ( MethodTypeAttribute :: NewClassMethod ) ,
657
+ ( Some ( _) , _) => bail_spanned!( $ident. span( ) => "can only combine `new` and `classmethod`" ) ,
658
+ } ;
659
+ } ;
660
+ }
661
+
640
662
macro_rules! set_ty {
641
663
( $new_ty: expr, $ident: expr) => {
642
664
ensure_spanned!(
643
665
ty. replace( $new_ty) . is_none( ) ,
644
- $ident. span( ) => "cannot specify a second method type "
666
+ $ident. span( ) => "cannot combine these method types "
645
667
) ;
646
668
} ;
647
669
}
@@ -650,13 +672,13 @@ fn parse_method_attributes(
650
672
match attr. parse_meta ( ) {
651
673
Ok ( syn:: Meta :: Path ( name) ) => {
652
674
if name. is_ident ( "new" ) || name. is_ident ( "__new__" ) {
653
- set_ty ! ( MethodTypeAttribute :: New , name) ;
675
+ set_compound_ty ! ( MethodTypeAttribute :: New , name) ;
654
676
} else if name. is_ident ( "init" ) || name. is_ident ( "__init__" ) {
655
677
bail_spanned ! ( name. span( ) => "#[init] is disabled since PyO3 0.9.0" ) ;
656
678
} else if name. is_ident ( "call" ) || name. is_ident ( "__call__" ) {
657
679
bail_spanned ! ( name. span( ) => "use `fn __call__` instead of `#[call]` attribute since PyO3 0.15.0" ) ;
658
680
} else if name. is_ident ( "classmethod" ) {
659
- set_ty ! ( MethodTypeAttribute :: ClassMethod , name) ;
681
+ set_compound_ty ! ( MethodTypeAttribute :: ClassMethod , name) ;
660
682
} else if name. is_ident ( "staticmethod" ) {
661
683
set_ty ! ( MethodTypeAttribute :: StaticMethod , name) ;
662
684
} else if name. is_ident ( "classattr" ) {
0 commit comments