Skip to content

Commit 4dcdede

Browse files
committed
derive(FromPyObject): adds default option
Takes an optional expression to set a custom value that is not the one from the Default trait
1 parent 93823d2 commit 4dcdede

File tree

5 files changed

+81
-13
lines changed

5 files changed

+81
-13
lines changed

newsfragments/4829.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
`derive(FromPyObject)` allow a `default` attribute to set a default value for extracted fields. The default value is either provided explicitly or fetched via `Default::default()`.

pyo3-macros-backend/src/attributes.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ pub mod kw {
4545
syn::custom_keyword!(unsendable);
4646
syn::custom_keyword!(weakref);
4747
syn::custom_keyword!(gil_used);
48+
syn::custom_keyword!(default);
4849
}
4950

5051
fn take_int(read: &mut &str, tracker: &mut usize) -> String {
@@ -351,6 +352,8 @@ impl<K: ToTokens, V: ToTokens> ToTokens for OptionalKeywordAttribute<K, V> {
351352

352353
pub type FromPyWithAttribute = KeywordAttribute<kw::from_py_with, LitStrValue<ExprPath>>;
353354

355+
pub type DefaultAttribute = OptionalKeywordAttribute<kw::default, Expr>;
356+
354357
/// For specifying the path to the pyo3 crate.
355358
pub type CrateAttribute = KeywordAttribute<Token![crate], LitStrValue<Path>>;
356359

pyo3-macros-backend/src/frompyobject.rs

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
use crate::attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute};
1+
use crate::attributes::{
2+
self, get_pyo3_options, CrateAttribute, DefaultAttribute, FromPyWithAttribute,
3+
};
24
use crate::utils::Ctx;
35
use proc_macro2::TokenStream;
4-
use quote::{format_ident, quote};
6+
use quote::{format_ident, quote, ToTokens};
57
use syn::{
68
ext::IdentExt,
79
parenthesized,
@@ -90,6 +92,7 @@ struct NamedStructField<'a> {
9092
ident: &'a syn::Ident,
9193
getter: Option<FieldGetter>,
9294
from_py_with: Option<FromPyWithAttribute>,
95+
default: Option<DefaultAttribute>,
9396
}
9497

9598
struct TupleStructField {
@@ -193,6 +196,7 @@ impl<'a> Container<'a> {
193196
ident,
194197
getter: attrs.getter,
195198
from_py_with: attrs.from_py_with,
199+
default: attrs.default,
196200
})
197201
})
198202
.collect::<Result<Vec<_>>>()?;
@@ -346,18 +350,33 @@ impl<'a> Container<'a> {
346350
quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #field_name)))
347351
}
348352
};
349-
let extractor = match &field.from_py_with {
350-
None => {
351-
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&#getter?, #struct_name, #field_name)?)
352-
}
353-
Some(FromPyWithAttribute {
354-
value: expr_path, ..
355-
}) => {
356-
quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &#getter?, #struct_name, #field_name)?)
357-
}
353+
let extractor = if let Some(FromPyWithAttribute {
354+
value: expr_path, ..
355+
}) = &field.from_py_with
356+
{
357+
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &value, #struct_name, #field_name)?)
358+
} else {
359+
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?)
360+
};
361+
let extracted = if let Some(default) = &field.default {
362+
let default_expr = if let Some(default_expr) = &default.value {
363+
default_expr.to_token_stream()
364+
} else {
365+
quote!(Default::default())
366+
};
367+
quote!(if let Ok(value) = #getter {
368+
#extractor
369+
} else {
370+
#default_expr
371+
})
372+
} else {
373+
quote!({
374+
let value = #getter?;
375+
#extractor
376+
})
358377
};
359378

360-
fields.push(quote!(#ident: #extractor));
379+
fields.push(quote!(#ident: #extracted));
361380
}
362381

363382
quote!(::std::result::Result::Ok(#self_ty{#fields}))
@@ -458,6 +477,7 @@ impl ContainerOptions {
458477
struct FieldPyO3Attributes {
459478
getter: Option<FieldGetter>,
460479
from_py_with: Option<FromPyWithAttribute>,
480+
default: Option<DefaultAttribute>,
461481
}
462482

463483
#[derive(Clone, Debug)]
@@ -469,6 +489,7 @@ enum FieldGetter {
469489
enum FieldPyO3Attribute {
470490
Getter(FieldGetter),
471491
FromPyWith(FromPyWithAttribute),
492+
Default(DefaultAttribute),
472493
}
473494

474495
impl Parse for FieldPyO3Attribute {
@@ -512,6 +533,8 @@ impl Parse for FieldPyO3Attribute {
512533
}
513534
} else if lookahead.peek(attributes::kw::from_py_with) {
514535
input.parse().map(FieldPyO3Attribute::FromPyWith)
536+
} else if lookahead.peek(attributes::kw::default) {
537+
input.parse().map(FieldPyO3Attribute::Default)
515538
} else {
516539
Err(lookahead.error())
517540
}
@@ -523,6 +546,7 @@ impl FieldPyO3Attributes {
523546
fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
524547
let mut getter = None;
525548
let mut from_py_with = None;
549+
let mut default = None;
526550

527551
for attr in attrs {
528552
if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
@@ -542,6 +566,13 @@ impl FieldPyO3Attributes {
542566
);
543567
from_py_with = Some(from_py_with_attr);
544568
}
569+
FieldPyO3Attribute::Default(default_attr) => {
570+
ensure_spanned!(
571+
default.is_none(),
572+
attr.span() => "`default` may only be provided once"
573+
);
574+
default = Some(default_attr);
575+
}
545576
}
546577
}
547578
}
@@ -550,6 +581,7 @@ impl FieldPyO3Attributes {
550581
Ok(FieldPyO3Attributes {
551582
getter,
552583
from_py_with,
584+
default,
553585
})
554586
}
555587
}

tests/test_frompyobject.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,3 +686,35 @@ fn test_with_keyword_item() {
686686
assert_eq!(result, expected);
687687
});
688688
}
689+
690+
#[derive(Debug, FromPyObject, PartialEq, Eq)]
691+
pub struct WithDefaultItem {
692+
#[pyo3(item, default)]
693+
value: Option<usize>,
694+
}
695+
696+
#[test]
697+
fn test_with_default_item() {
698+
Python::with_gil(|py| {
699+
let dict = PyDict::new(py);
700+
let result = dict.extract::<WithDefaultItem>().unwrap();
701+
let expected = WithDefaultItem { value: None };
702+
assert_eq!(result, expected);
703+
});
704+
}
705+
706+
#[derive(Debug, FromPyObject, PartialEq, Eq)]
707+
pub struct WithExplicitDefaultItem {
708+
#[pyo3(item, default = 1)]
709+
value: usize,
710+
}
711+
712+
#[test]
713+
fn test_with_explicit_default_item() {
714+
Python::with_gil(|py| {
715+
let dict = PyDict::new(py);
716+
let result = dict.extract::<WithExplicitDefaultItem>().unwrap();
717+
let expected = WithExplicitDefaultItem { value: 1 };
718+
assert_eq!(result, expected);
719+
});
720+
}

tests/ui/invalid_frompy_derive.stderr

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ error: transparent structs and variants can only have 1 field
8484
70 | | },
8585
| |_____^
8686

87-
error: expected one of: `attribute`, `item`, `from_py_with`
87+
error: expected one of: `attribute`, `item`, `from_py_with`, `default`
8888
--> tests/ui/invalid_frompy_derive.rs:76:12
8989
|
9090
76 | #[pyo3(attr)]

0 commit comments

Comments
 (0)