Skip to content

Commit 62c7fd0

Browse files
committed
allow #[pyo3(signature = ...)] on complex enum variants to specify constructor signature
1 parent ef13bc6 commit 62c7fd0

File tree

5 files changed

+91
-15
lines changed

5 files changed

+91
-15
lines changed

pyo3-macros-backend/src/pyclass.rs

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::attributes::{
88
use crate::deprecations::Deprecations;
99
use crate::konst::{ConstAttributes, ConstSpec};
1010
use crate::method::{FnArg, FnSpec, PyArg, RegularArg};
11+
use crate::pyfunction::SignatureAttribute;
1112
use crate::pyimpl::{gen_py_const, PyClassMethodsType};
1213
use crate::pymethod::{
1314
impl_py_getter_def, impl_py_setter_def, MethodAndMethodDef, MethodAndSlotDef, PropertyType,
@@ -622,17 +623,21 @@ struct PyClassEnumVariantNamedField<'a> {
622623
/// `#[pyo3()]` options for pyclass enum variants
623624
struct EnumVariantPyO3Options {
624625
name: Option<NameAttribute>,
626+
signature: Option<SignatureAttribute>,
625627
}
626628

627629
enum EnumVariantPyO3Option {
628630
Name(NameAttribute),
631+
Signature(SignatureAttribute),
629632
}
630633

631634
impl Parse for EnumVariantPyO3Option {
632635
fn parse(input: ParseStream<'_>) -> Result<Self> {
633636
let lookahead = input.lookahead1();
634637
if lookahead.peek(attributes::kw::name) {
635638
input.parse().map(EnumVariantPyO3Option::Name)
639+
} else if lookahead.peek(attributes::kw::signature) {
640+
input.parse().map(EnumVariantPyO3Option::Signature)
636641
} else {
637642
Err(lookahead.error())
638643
}
@@ -641,7 +646,10 @@ impl Parse for EnumVariantPyO3Option {
641646

642647
impl EnumVariantPyO3Options {
643648
fn take_pyo3_options(attrs: &mut Vec<syn::Attribute>) -> Result<Self> {
644-
let mut options = EnumVariantPyO3Options { name: None };
649+
let mut options = EnumVariantPyO3Options {
650+
name: None,
651+
signature: None,
652+
};
645653

646654
for option in take_pyo3_options(attrs)? {
647655
match option {
@@ -652,6 +660,13 @@ impl EnumVariantPyO3Options {
652660
);
653661
options.name = Some(name);
654662
}
663+
EnumVariantPyO3Option::Signature(signature) => {
664+
ensure_spanned!(
665+
options.signature.is_none(),
666+
signature.span() => "`signature` may only be specified once"
667+
);
668+
options.signature = Some(signature);
669+
}
655670
}
656671
}
657672

@@ -691,19 +706,20 @@ fn impl_simple_enum(
691706

692707
let (default_repr, default_repr_slot) = {
693708
let variants_repr = variants.iter().map(|variant| {
709+
ensure_spanned!(variant.options.signature.is_none(), variant.options.signature.span() => "`signature` can't be used on a simple enum variant");
694710
let variant_name = variant.ident;
695711
// Assuming all variants are unit variants because they are the only type we support.
696712
let repr = format!(
697713
"{}.{}",
698714
get_class_python_name(cls, args),
699715
variant.get_python_name(args),
700716
);
701-
quote! { #cls::#variant_name => #repr, }
702-
});
717+
Ok(quote! { #cls::#variant_name => #repr, })
718+
}).collect::<Result<TokenStream>>()?;
703719
let mut repr_impl: syn::ImplItemFn = syn::parse_quote! {
704720
fn __pyo3__repr__(&self) -> &'static str {
705721
match self {
706-
#(#variants_repr)*
722+
#variants_repr
707723
}
708724
}
709725
};
@@ -889,7 +905,7 @@ fn impl_complex_enum(
889905
let mut variant_cls_pytypeinfos = vec![];
890906
let mut variant_cls_pyclass_impls = vec![];
891907
let mut variant_cls_impls = vec![];
892-
for variant in &variants {
908+
for variant in variants {
893909
let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident());
894910

895911
let variant_cls_zst = quote! {
@@ -908,11 +924,11 @@ fn impl_complex_enum(
908924
let variant_cls_pytypeinfo = impl_pytypeinfo(&variant_cls, &variant_args, None, ctx);
909925
variant_cls_pytypeinfos.push(variant_cls_pytypeinfo);
910926

911-
let variant_new = complex_enum_variant_new(cls, variant, ctx)?;
912-
913-
let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, variant, ctx)?;
927+
let (variant_cls_impl, field_getters) = impl_complex_enum_variant_cls(cls, &variant, ctx)?;
914928
variant_cls_impls.push(variant_cls_impl);
915929

930+
let variant_new = complex_enum_variant_new(cls, variant, ctx)?;
931+
916932
let pyclass_impl = PyClassImplsBuilder::new(
917933
&variant_cls,
918934
&variant_args,
@@ -1120,7 +1136,7 @@ pub fn gen_complex_enum_variant_attr(
11201136

11211137
fn complex_enum_variant_new<'a>(
11221138
cls: &'a syn::Ident,
1123-
variant: &'a PyClassEnumVariant<'a>,
1139+
variant: PyClassEnumVariant<'a>,
11241140
ctx: &Ctx,
11251141
) -> Result<MethodAndSlotDef> {
11261142
match variant {
@@ -1132,7 +1148,7 @@ fn complex_enum_variant_new<'a>(
11321148

11331149
fn complex_enum_struct_variant_new<'a>(
11341150
cls: &'a syn::Ident,
1135-
variant: &'a PyClassEnumStructVariant<'a>,
1151+
variant: PyClassEnumStructVariant<'a>,
11361152
ctx: &Ctx,
11371153
) -> Result<MethodAndSlotDef> {
11381154
let Ctx { pyo3_path } = ctx;
@@ -1162,7 +1178,12 @@ fn complex_enum_struct_variant_new<'a>(
11621178
}
11631179
args
11641180
};
1165-
let signature = crate::pyfunction::FunctionSignature::from_arguments(args)?;
1181+
1182+
let signature = if let Some(signature) = variant.options.signature {
1183+
crate::pyfunction::FunctionSignature::from_arguments_and_attribute(args, signature)?
1184+
} else {
1185+
crate::pyfunction::FunctionSignature::from_arguments(args)?
1186+
};
11661187

11671188
let spec = FnSpec {
11681189
tp: crate::method::FnType::FnNew,

pytests/src/enums.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,26 @@ pub fn do_simple_stuff(thing: &SimpleEnum) -> SimpleEnum {
3939

4040
#[pyclass]
4141
pub enum ComplexEnum {
42-
Int { i: i32 },
43-
Float { f: f64 },
44-
Str { s: String },
42+
Int {
43+
i: i32,
44+
},
45+
Float {
46+
f: f64,
47+
},
48+
Str {
49+
s: String,
50+
},
4551
EmptyStruct {},
46-
MultiFieldStruct { a: i32, b: f64, c: bool },
52+
MultiFieldStruct {
53+
a: i32,
54+
b: f64,
55+
c: bool,
56+
},
57+
#[pyo3(signature = (a = 42, b = None))]
58+
VariantWithDefault {
59+
a: i32,
60+
b: Option<String>,
61+
},
4762
}
4863

4964
#[pyfunction]
@@ -58,5 +73,9 @@ pub fn do_complex_stuff(thing: &ComplexEnum) -> ComplexEnum {
5873
b: *b,
5974
c: *c,
6075
},
76+
ComplexEnum::VariantWithDefault { a, b } => ComplexEnum::VariantWithDefault {
77+
a: 2 * a,
78+
b: b.as_ref().map(|s| s.to_uppercase()),
79+
},
6180
}
6281
}

pytests/tests/test_enums.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ def test_complex_enum_variant_constructors():
1818
multi_field_struct_variant = enums.ComplexEnum.MultiFieldStruct(42, 3.14, True)
1919
assert isinstance(multi_field_struct_variant, enums.ComplexEnum.MultiFieldStruct)
2020

21+
variant_with_default_1 = enums.ComplexEnum.VariantWithDefault()
22+
assert isinstance(variant_with_default_1, enums.ComplexEnum.VariantWithDefault)
23+
24+
variant_with_default_2 = enums.ComplexEnum.VariantWithDefault(25, "Hello")
25+
assert isinstance(variant_with_default_2, enums.ComplexEnum.VariantWithDefault)
26+
2127

2228
@pytest.mark.parametrize(
2329
"variant",
@@ -27,6 +33,7 @@ def test_complex_enum_variant_constructors():
2733
enums.ComplexEnum.Str("hello"),
2834
enums.ComplexEnum.EmptyStruct(),
2935
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
36+
enums.ComplexEnum.VariantWithDefault(),
3037
],
3138
)
3239
def test_complex_enum_variant_subclasses(variant: enums.ComplexEnum):
@@ -48,6 +55,10 @@ def test_complex_enum_field_getters():
4855
assert multi_field_struct_variant.b == 3.14
4956
assert multi_field_struct_variant.c is True
5057

58+
variant_with_default = enums.ComplexEnum.VariantWithDefault()
59+
assert variant_with_default.a == 42
60+
assert variant_with_default.b is None
61+
5162

5263
@pytest.mark.parametrize(
5364
"variant",
@@ -57,6 +68,7 @@ def test_complex_enum_field_getters():
5768
enums.ComplexEnum.Str("hello"),
5869
enums.ComplexEnum.EmptyStruct(),
5970
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
71+
enums.ComplexEnum.VariantWithDefault(),
6072
],
6173
)
6274
def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
@@ -78,6 +90,11 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
7890
assert x == 42
7991
assert y == 3.14
8092
assert z is True
93+
elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
94+
x = variant.a
95+
y = variant.b
96+
assert x == 42
97+
assert y is None
8198
else:
8299
assert False
83100

@@ -90,6 +107,7 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
90107
enums.ComplexEnum.Str("hello"),
91108
enums.ComplexEnum.EmptyStruct(),
92109
enums.ComplexEnum.MultiFieldStruct(42, 3.14, True),
110+
enums.ComplexEnum.VariantWithDefault(b="hello"),
93111
],
94112
)
95113
def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEnum):
@@ -112,5 +130,10 @@ def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEn
112130
assert x == 42
113131
assert y == 3.14
114132
assert z is True
133+
elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
134+
x = variant.a
135+
y = variant.b
136+
assert x == 84
137+
assert y == "HELLO"
115138
else:
116139
assert False

tests/ui/invalid_pyclass_enum.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,11 @@ enum NoTupleVariants {
2727
TupleVariant(i32),
2828
}
2929

30+
#[pyclass]
31+
enum SimpleNoSignature {
32+
#[pyo3(signature = (a, b))]
33+
A,
34+
B,
35+
}
36+
3037
fn main() {}

tests/ui/invalid_pyclass_enum.stderr

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,9 @@ error: Tuple variant `TupleVariant` is not yet supported in a complex enum
3131
|
3232
27 | TupleVariant(i32),
3333
| ^^^^^^^^^^^^
34+
35+
error: `signature` can't be used on a simple enum variant
36+
--> tests/ui/invalid_pyclass_enum.rs:32:12
37+
|
38+
32 | #[pyo3(signature = (a, b))]
39+
| ^^^^^^^^^

0 commit comments

Comments
 (0)