Skip to content

Commit 554cffd

Browse files
add #[pyo3(from_py_with="...")] attribute (#1411)
* allow from_py_with inside #[derive(FromPyObject)] * split up FnSpec::parse
1 parent a9f064d commit 554cffd

17 files changed

+485
-145
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
88
## [Unreleased]
99
### Added
1010
- Add conversions between `OsStr`/`OsString`/`Path`/`PathBuf` and Python strings. [#1379](https://github.com/PyO3/pyo3/pull/1379)
11+
- Add #[pyo3(from_py_with = "...")]` attribute for function arguments and struct fields to override the default from-Python conversion. [#1411](https://github.com/PyO3/pyo3/pull/1411)
1112
- Add FFI definition `PyCFunction_CheckExact` for Python 3.9 and later. [#1425](https://github.com/PyO3/pyo3/pull/1425)
1213
- Add FFI definition `Py_IS_TYPE`. [#1429](https://github.com/PyO3/pyo3/pull/1429)
1314

pyo3-macros-backend/src/attrs.rs

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
use syn::spanned::Spanned;
2+
use syn::{ExprPath, Lit, Meta, MetaNameValue, Result};
3+
4+
#[derive(Clone, Debug, PartialEq)]
5+
pub struct FromPyWithAttribute(pub ExprPath);
6+
7+
impl FromPyWithAttribute {
8+
pub fn from_meta(meta: Meta) -> Result<Self> {
9+
let string_literal = match meta {
10+
Meta::NameValue(MetaNameValue {
11+
lit: Lit::Str(string_literal),
12+
..
13+
}) => string_literal,
14+
meta => {
15+
bail_spanned!(meta.span() => "expected a name-value: `pyo3(from_py_with = \"func\")`")
16+
}
17+
};
18+
19+
let expr_path = string_literal.parse::<ExprPath>()?;
20+
Ok(FromPyWithAttribute(expr_path))
21+
}
22+
}

pyo3-macros-backend/src/from_pyobject.rs

+70-40
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::attrs::FromPyWithAttribute;
12
use proc_macro2::TokenStream;
23
use quote::quote;
34
use syn::punctuated::Punctuated;
@@ -85,7 +86,7 @@ enum ContainerType<'a> {
8586
/// Struct Container, e.g. `struct Foo { a: String }`
8687
///
8788
/// Variant contains the list of field identifiers and the corresponding extraction call.
88-
Struct(Vec<(&'a Ident, FieldAttribute)>),
89+
Struct(Vec<(&'a Ident, FieldAttributes)>),
8990
/// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }`
9091
///
9192
/// The field specified by the identifier is extracted directly from the object.
@@ -156,9 +157,8 @@ impl<'a> Container<'a> {
156157
.ident
157158
.as_ref()
158159
.expect("Named fields should have identifiers");
159-
let attr = FieldAttribute::parse_attrs(&field.attrs)?
160-
.unwrap_or(FieldAttribute::GetAttr(None));
161-
fields.push((ident, attr))
160+
let attrs = FieldAttributes::parse_attrs(&field.attrs)?;
161+
fields.push((ident, attrs))
162162
}
163163
ContainerType::Struct(fields)
164164
}
@@ -235,17 +235,24 @@ impl<'a> Container<'a> {
235235
)
236236
}
237237

238-
fn build_struct(&self, tups: &[(&Ident, FieldAttribute)]) -> TokenStream {
238+
fn build_struct(&self, tups: &[(&Ident, FieldAttributes)]) -> TokenStream {
239239
let self_ty = &self.path;
240240
let mut fields: Punctuated<TokenStream, syn::Token![,]> = Punctuated::new();
241-
for (ident, attr) in tups {
242-
let ext_fn = match attr {
243-
FieldAttribute::GetAttr(Some(name)) => quote!(getattr(#name)),
244-
FieldAttribute::GetAttr(None) => quote!(getattr(stringify!(#ident))),
245-
FieldAttribute::GetItem(Some(key)) => quote!(get_item(#key)),
246-
FieldAttribute::GetItem(None) => quote!(get_item(stringify!(#ident))),
241+
for (ident, attrs) in tups {
242+
let getter = match &attrs.getter {
243+
FieldGetter::GetAttr(Some(name)) => quote!(getattr(#name)),
244+
FieldGetter::GetAttr(None) => quote!(getattr(stringify!(#ident))),
245+
FieldGetter::GetItem(Some(key)) => quote!(get_item(#key)),
246+
FieldGetter::GetItem(None) => quote!(get_item(stringify!(#ident))),
247247
};
248-
fields.push(quote!(#ident: obj.#ext_fn?.extract()?));
248+
249+
let get_field = quote!(obj.#getter?);
250+
let extractor = match &attrs.from_py_with {
251+
None => quote!(#get_field.extract()?),
252+
Some(FromPyWithAttribute(expr_path)) => quote! (#expr_path(#get_field)?),
253+
};
254+
255+
fields.push(quote!(#ident: #extractor));
249256
}
250257
quote!(Ok(#self_ty{#fields}))
251258
}
@@ -309,40 +316,59 @@ impl ContainerAttribute {
309316

310317
/// Attributes for deriving FromPyObject scoped on fields.
311318
#[derive(Clone, Debug)]
312-
enum FieldAttribute {
319+
struct FieldAttributes {
320+
getter: FieldGetter,
321+
from_py_with: Option<FromPyWithAttribute>,
322+
}
323+
324+
#[derive(Clone, Debug)]
325+
enum FieldGetter {
313326
GetItem(Option<syn::Lit>),
314327
GetAttr(Option<syn::LitStr>),
315328
}
316329

317-
impl FieldAttribute {
318-
/// Extract the field attribute.
330+
impl FieldAttributes {
331+
/// Extract the field attributes.
319332
///
320-
/// Currently fails if more than 1 attribute is passed in `pyo3`
321-
fn parse_attrs(attrs: &[Attribute]) -> Result<Option<Self>> {
333+
fn parse_attrs(attrs: &[Attribute]) -> Result<Self> {
334+
let mut getter = None;
335+
let mut from_py_with = None;
336+
322337
let list = get_pyo3_meta_list(attrs)?;
323-
let metaitem = match list.nested.len() {
324-
0 => return Ok(None),
325-
1 => list.nested.into_iter().next().unwrap(),
326-
_ => bail_spanned!(
327-
list.nested.span() =>
328-
"only one of `attribute` or `item` can be provided"
329-
),
330-
};
331-
let meta = match metaitem {
332-
syn::NestedMeta::Meta(meta) => meta,
333-
syn::NestedMeta::Lit(lit) => bail_spanned!(
334-
lit.span() =>
335-
"expected `attribute` or `item`, got a literal"
336-
),
337-
};
338-
let path = meta.path();
339-
if path.is_ident("attribute") {
340-
Ok(Some(FieldAttribute::GetAttr(Self::attribute_arg(meta)?)))
341-
} else if path.is_ident("item") {
342-
Ok(Some(FieldAttribute::GetItem(Self::item_arg(meta)?)))
343-
} else {
344-
bail_spanned!(meta.span() => "expected `attribute` or `item`");
338+
339+
for meta_item in list.nested {
340+
let meta = match meta_item {
341+
syn::NestedMeta::Meta(meta) => meta,
342+
syn::NestedMeta::Lit(lit) => bail_spanned!(
343+
lit.span() =>
344+
"expected `attribute`, `item` or `from_py_with`, got a literal"
345+
),
346+
};
347+
let path = meta.path();
348+
349+
if path.is_ident("attribute") {
350+
ensure_spanned!(
351+
getter.is_none(),
352+
meta.span() => "only one of `attribute` or `item` can be provided"
353+
);
354+
getter = Some(FieldGetter::GetAttr(Self::attribute_arg(meta)?))
355+
} else if path.is_ident("item") {
356+
ensure_spanned!(
357+
getter.is_none(),
358+
meta.span() => "only one of `attribute` or `item` can be provided"
359+
);
360+
getter = Some(FieldGetter::GetItem(Self::item_arg(meta)?))
361+
} else if path.is_ident("from_py_with") {
362+
from_py_with = Some(Self::from_py_with_arg(meta)?)
363+
} else {
364+
bail_spanned!(meta.span() => "expected `attribute`, `item` or `from_py_with`")
365+
};
345366
}
367+
368+
Ok(FieldAttributes {
369+
getter: getter.unwrap_or(FieldGetter::GetAttr(None)),
370+
from_py_with,
371+
})
346372
}
347373

348374
fn attribute_arg(meta: Meta) -> syn::Result<Option<syn::LitStr>> {
@@ -389,6 +415,10 @@ impl FieldAttribute {
389415

390416
bail_spanned!(arg_list.span() => "expected a single literal argument");
391417
}
418+
419+
fn from_py_with_arg(meta: Meta) -> syn::Result<FromPyWithAttribute> {
420+
FromPyWithAttribute::from_meta(meta)
421+
}
392422
}
393423

394424
/// Extract pyo3 metalist, flattens multiple lists into a single one.
@@ -426,7 +456,7 @@ fn verify_and_get_lifetime(generics: &syn::Generics) -> Result<Option<&syn::Life
426456
/// * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier
427457
/// * At least one field, in case of `#[transparent]`, exactly one field
428458
/// * At least one variant for enums.
429-
/// * Fields of input structs and enums must implement `FromPyObject`
459+
/// * Fields of input structs and enums must implement `FromPyObject` or be annotated with `from_py_with`
430460
/// * Derivation for structs with generic fields like `struct<T> Foo(T)`
431461
/// adds `T: FromPyObject` on the derived implementation.
432462
pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result<TokenStream> {

pyo3-macros-backend/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#[macro_use]
88
mod utils;
99

10+
mod attrs;
1011
mod defs;
1112
mod from_pyobject;
1213
mod konst;

0 commit comments

Comments
 (0)