Skip to content

derive(FromPyObject): adds default option #4829

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions guide/src/conversions/traits.md
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,48 @@ If the input is neither a string nor an integer, the error message will be:
- apply a custom function to convert the field from Python the desired Rust type.
- the argument must be the name of the function as a string.
- the function signature must be `fn(&Bound<PyAny>) -> PyResult<T>` where `T` is the Rust type of the argument.
- `pyo3(default)`, `pyo3(default = ...)`
- if the argument is set, uses the given default value.
- in this case, the argument must be a Rust expression returning a value of the desired Rust type.
- if the argument is not set, [`Default::default`](https://doc.rust-lang.org/std/default/trait.Default.html#tymethod.default) is used.
- note that the default value is only used if the field is not set.
If the field is set and the conversion function from Python to Rust fails, an exception is raised and the default value is not used.
- this attribute is only supported on named fields.

For example, the code below applies the given conversion function on the `"value"` dict item to compute its length or fall back to the type default value (0):

```rust
use pyo3::prelude::*;

#[derive(FromPyObject)]
struct RustyStruct {
#[pyo3(item("value"), default, from_py_with = "Bound::<'_, PyAny>::len")]
len: usize,
#[pyo3(item)]
other: usize,
}
#
# use pyo3::types::PyDict;
# fn main() -> PyResult<()> {
# Python::with_gil(|py| -> PyResult<()> {
# // Filled case
# let dict = PyDict::new(py);
# dict.set_item("value", (1,)).unwrap();
# dict.set_item("other", 1).unwrap();
# let result = dict.extract::<RustyStruct>()?;
# assert_eq!(result.len, 1);
# assert_eq!(result.other, 1);
#
# // Empty case
# let dict = PyDict::new(py);
# dict.set_item("other", 1).unwrap();
# let result = dict.extract::<RustyStruct>()?;
# assert_eq!(result.len, 0);
# assert_eq!(result.other, 1);
# Ok(())
# })
# }
```

### `IntoPyObject`
The ['IntoPyObject'] trait defines the to-python conversion for a Rust type. All types in PyO3 implement this trait,
Expand Down
1 change: 1 addition & 0 deletions newsfragments/4829.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`derive(FromPyObject)` allow a `default` attribute to set a default value for extracted fields of named structs. The default value is either provided explicitly or fetched via `Default::default()`.
2 changes: 2 additions & 0 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ impl<K: ToTokens, V: ToTokens> ToTokens for OptionalKeywordAttribute<K, V> {

pub type FromPyWithAttribute = KeywordAttribute<kw::from_py_with, LitStrValue<ExprPath>>;

pub type DefaultAttribute = OptionalKeywordAttribute<Token![default], Expr>;

/// For specifying the path to the pyo3 crate.
pub type CrateAttribute = KeywordAttribute<Token![crate], LitStrValue<Path>>;

Expand Down
66 changes: 53 additions & 13 deletions pyo3-macros-backend/src/frompyobject.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::attributes::{self, get_pyo3_options, CrateAttribute, FromPyWithAttribute};
use crate::attributes::{
self, get_pyo3_options, CrateAttribute, DefaultAttribute, FromPyWithAttribute,
};
use crate::utils::Ctx;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use quote::{format_ident, quote, ToTokens};
use syn::{
ext::IdentExt,
parenthesized,
Expand Down Expand Up @@ -90,6 +92,7 @@ struct NamedStructField<'a> {
ident: &'a syn::Ident,
getter: Option<FieldGetter>,
from_py_with: Option<FromPyWithAttribute>,
default: Option<DefaultAttribute>,
}

struct TupleStructField {
Expand Down Expand Up @@ -144,6 +147,10 @@ impl<'a> Container<'a> {
attrs.getter.is_none(),
field.span() => "`getter` is not permitted on tuple struct elements."
);
ensure_spanned!(
attrs.default.is_none(),
field.span() => "`default` is not permitted on tuple struct elements."
);
Ok(TupleStructField {
from_py_with: attrs.from_py_with,
})
Expand Down Expand Up @@ -193,10 +200,15 @@ impl<'a> Container<'a> {
ident,
getter: attrs.getter,
from_py_with: attrs.from_py_with,
default: attrs.default,
})
})
.collect::<Result<Vec<_>>>()?;
if options.transparent {
if struct_fields.iter().all(|field| field.default.is_some()) {
bail_spanned!(
fields.span() => "cannot derive FromPyObject for structs and variants with only default values"
)
} else if options.transparent {
ensure_spanned!(
struct_fields.len() == 1,
fields.span() => "transparent structs and variants can only have 1 field"
Expand Down Expand Up @@ -346,18 +358,33 @@ impl<'a> Container<'a> {
quote!(#pyo3_path::types::PyAnyMethods::get_item(obj, #pyo3_path::intern!(obj.py(), #field_name)))
}
};
let extractor = match &field.from_py_with {
None => {
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&#getter?, #struct_name, #field_name)?)
}
Some(FromPyWithAttribute {
value: expr_path, ..
}) => {
quote! (#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &#getter?, #struct_name, #field_name)?)
}
let extractor = if let Some(FromPyWithAttribute {
value: expr_path, ..
}) = &field.from_py_with
{
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field_with(#expr_path as fn(_) -> _, &value, #struct_name, #field_name)?)
} else {
quote!(#pyo3_path::impl_::frompyobject::extract_struct_field(&value, #struct_name, #field_name)?)
};
let extracted = if let Some(default) = &field.default {
let default_expr = if let Some(default_expr) = &default.value {
default_expr.to_token_stream()
} else {
quote!(::std::default::Default::default())
};
quote!(if let ::std::result::Result::Ok(value) = #getter {
#extractor
} else {
#default_expr
})
} else {
quote!({
let value = #getter?;
#extractor
})
};

fields.push(quote!(#ident: #extractor));
fields.push(quote!(#ident: #extracted));
}

quote!(::std::result::Result::Ok(#self_ty{#fields}))
Expand Down Expand Up @@ -458,6 +485,7 @@ impl ContainerOptions {
struct FieldPyO3Attributes {
getter: Option<FieldGetter>,
from_py_with: Option<FromPyWithAttribute>,
default: Option<DefaultAttribute>,
}

#[derive(Clone, Debug)]
Expand All @@ -469,6 +497,7 @@ enum FieldGetter {
enum FieldPyO3Attribute {
Getter(FieldGetter),
FromPyWith(FromPyWithAttribute),
Default(DefaultAttribute),
}

impl Parse for FieldPyO3Attribute {
Expand Down Expand Up @@ -512,6 +541,8 @@ impl Parse for FieldPyO3Attribute {
}
} else if lookahead.peek(attributes::kw::from_py_with) {
input.parse().map(FieldPyO3Attribute::FromPyWith)
} else if lookahead.peek(Token![default]) {
input.parse().map(FieldPyO3Attribute::Default)
} else {
Err(lookahead.error())
}
Expand All @@ -523,6 +554,7 @@ impl FieldPyO3Attributes {
fn from_attrs(attrs: &[Attribute]) -> Result<Self> {
let mut getter = None;
let mut from_py_with = None;
let mut default = None;

for attr in attrs {
if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
Expand All @@ -542,6 +574,13 @@ impl FieldPyO3Attributes {
);
from_py_with = Some(from_py_with_attr);
}
FieldPyO3Attribute::Default(default_attr) => {
ensure_spanned!(
default.is_none(),
attr.span() => "`default` may only be provided once"
);
default = Some(default_attr);
}
}
}
}
Expand All @@ -550,6 +589,7 @@ impl FieldPyO3Attributes {
Ok(FieldPyO3Attributes {
getter,
from_py_with,
default,
})
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/tests/hygiene/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct Derive3 {
f: i32,
#[pyo3(item(42))]
g: i32,
#[pyo3(default)]
h: i32,
} // struct case

#[derive(crate::FromPyObject)]
Expand Down
114 changes: 114 additions & 0 deletions tests/test_frompyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,3 +686,117 @@ fn test_with_keyword_item() {
assert_eq!(result, expected);
});
}

#[derive(Debug, FromPyObject, PartialEq, Eq)]
pub struct WithDefaultItem {
#[pyo3(item, default)]
opt: Option<usize>,
#[pyo3(item)]
value: usize,
}

#[test]
fn test_with_default_item() {
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item("value", 3).unwrap();
let result = dict.extract::<WithDefaultItem>().unwrap();
let expected = WithDefaultItem {
value: 3,
opt: None,
};
assert_eq!(result, expected);
});
}

#[derive(Debug, FromPyObject, PartialEq, Eq)]
pub struct WithExplicitDefaultItem {
#[pyo3(item, default = 1)]
opt: usize,
#[pyo3(item)]
value: usize,
}

#[test]
fn test_with_explicit_default_item() {
Python::with_gil(|py| {
let dict = PyDict::new(py);
dict.set_item("value", 3).unwrap();
let result = dict.extract::<WithExplicitDefaultItem>().unwrap();
let expected = WithExplicitDefaultItem { value: 3, opt: 1 };
assert_eq!(result, expected);
});
}

#[derive(Debug, FromPyObject, PartialEq, Eq)]
pub struct WithDefaultItemAndConversionFunction {
#[pyo3(item, default, from_py_with = "Bound::<'_, PyAny>::len")]
opt: usize,
#[pyo3(item)]
value: usize,
}

#[test]
fn test_with_default_item_and_conversion_function() {
Python::with_gil(|py| {
// Filled case
let dict = PyDict::new(py);
dict.set_item("opt", (1,)).unwrap();
dict.set_item("value", 3).unwrap();
let result = dict
.extract::<WithDefaultItemAndConversionFunction>()
.unwrap();
let expected = WithDefaultItemAndConversionFunction { opt: 1, value: 3 };
assert_eq!(result, expected);

// Empty case
let dict = PyDict::new(py);
dict.set_item("value", 3).unwrap();
let result = dict
.extract::<WithDefaultItemAndConversionFunction>()
.unwrap();
let expected = WithDefaultItemAndConversionFunction { opt: 0, value: 3 };
assert_eq!(result, expected);

// Error case
let dict = PyDict::new(py);
dict.set_item("value", 3).unwrap();
dict.set_item("opt", 1).unwrap();
assert!(dict
.extract::<WithDefaultItemAndConversionFunction>()
.is_err());
});
}

#[derive(Debug, FromPyObject, PartialEq, Eq)]
pub enum WithDefaultItemEnum {
#[pyo3(from_item_all)]
Foo {
a: usize,
#[pyo3(default)]
b: usize,
},
NeverUsedA {
a: usize,
},
}

#[test]
fn test_with_default_item_enum() {
Python::with_gil(|py| {
// A and B filled
let dict = PyDict::new(py);
dict.set_item("a", 1).unwrap();
dict.set_item("b", 2).unwrap();
let result = dict.extract::<WithDefaultItemEnum>().unwrap();
let expected = WithDefaultItemEnum::Foo { a: 1, b: 2 };
assert_eq!(result, expected);

// A filled
let dict = PyDict::new(py);
dict.set_item("a", 1).unwrap();
let result = dict.extract::<WithDefaultItemEnum>().unwrap();
let expected = WithDefaultItemEnum::Foo { a: 1, b: 0 };
assert_eq!(result, expected);
});
}
17 changes: 17 additions & 0 deletions tests/ui/invalid_frompy_derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,21 @@ struct FromItemAllConflictAttrWithArgs {
field: String,
}

#[derive(FromPyObject)]
struct StructWithOnlyDefaultValues {
#[pyo3(default)]
field: String,
}

#[derive(FromPyObject)]
enum EnumVariantWithOnlyDefaultValues {
Foo {
#[pyo3(default)]
field: String,
},
}

#[derive(FromPyObject)]
struct NamedTuplesWithDefaultValues(#[pyo3(default)] String);

fn main() {}
28 changes: 27 additions & 1 deletion tests/ui/invalid_frompy_derive.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ error: transparent structs and variants can only have 1 field
70 | | },
| |_____^

error: expected one of: `attribute`, `item`, `from_py_with`
error: expected one of: `attribute`, `item`, `from_py_with`, `default`
--> tests/ui/invalid_frompy_derive.rs:76:12
|
76 | #[pyo3(attr)]
Expand Down Expand Up @@ -223,3 +223,29 @@ error: The struct is already annotated with `from_item_all`, `attribute` is not
|
210 | #[pyo3(from_item_all)]
| ^^^^^^^^^^^^^

error: cannot derive FromPyObject for structs and variants with only default values
--> tests/ui/invalid_frompy_derive.rs:217:36
|
217 | struct StructWithOnlyDefaultValues {
| ____________________________________^
218 | | #[pyo3(default)]
219 | | field: String,
220 | | }
| |_^

error: cannot derive FromPyObject for structs and variants with only default values
--> tests/ui/invalid_frompy_derive.rs:224:9
|
224 | Foo {
| _________^
225 | | #[pyo3(default)]
226 | | field: String,
227 | | },
| |_____^

error: `default` is not permitted on tuple struct elements.
--> tests/ui/invalid_frompy_derive.rs:231:37
|
231 | struct NamedTuplesWithDefaultValues(#[pyo3(default)] String);
| ^
Loading