Skip to content

Commit b2e25b1

Browse files
committed
pyfunction: allow wrap_pyfunction to work on imports (even cross-crate)
1 parent 2cee7fe commit b2e25b1

File tree

9 files changed

+82
-36
lines changed

9 files changed

+82
-36
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2727
- Add buffer magic methods `__getbuffer__` and `__releasebuffer__` to `#[pymethods]`. [#2067](https://github.com/PyO3/pyo3/pull/2067)
2828
- Accept paths in `wrap_pyfunction` and `wrap_pymodule`. [#2081](https://github.com/PyO3/pyo3/pull/2081)
2929
- Add check for correct number of arguments on magic methods. [#2083](https://github.com/PyO3/pyo3/pull/2083)
30+
- `wrap_pyfunction!` can now wrap a `#[pyfunction]` which is implemented in a different Rust module or crate. [#2091](https://github.com/PyO3/pyo3/pull/2091)
3031

3132
### Changed
3233

pyo3-macros-backend/src/module.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,11 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
9999
if let syn::Stmt::Item(syn::Item::Fn(func)) = &mut stmt {
100100
if let Some(pyfn_args) = get_pyfn_attr(&mut func.attrs)? {
101101
let module_name = pyfn_args.modname;
102-
let (ident, wrapped_function) = impl_wrap_pyfunction(func, pyfn_args.options)?;
102+
let wrapped_function = impl_wrap_pyfunction(func, pyfn_args.options)?;
103+
let name = &func.sig.ident;
103104
let statements: Vec<syn::Stmt> = syn::parse_quote! {
104105
#wrapped_function
105-
#module_name.add_function(#ident(#module_name)?)?;
106+
#module_name.add_function(#name::wrap(#name::DEF, #module_name)?)?;
106107
};
107108
stmts.extend(statements);
108109
}

pyo3-macros-backend/src/pyfunction.rs

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@ use crate::{
99
method::{self, CallingConvention, FnArg},
1010
pymethod::check_generic,
1111
utils::{self, ensure_not_async_fn, get_pyo3_crate},
12-
wrap::function_wrapper_ident,
1312
};
1413
use proc_macro2::{Span, TokenStream};
1514
use quote::{format_ident, quote};
1615
use syn::punctuated::Punctuated;
17-
use syn::{ext::IdentExt, spanned::Spanned, Ident, NestedMeta, Path, Result};
16+
use syn::{ext::IdentExt, spanned::Spanned, NestedMeta, Path, Result};
1817
use syn::{
1918
parse::{Parse, ParseBuffer, ParseStream},
2019
token::Comma,
@@ -364,15 +363,15 @@ pub fn build_py_function(
364363
mut options: PyFunctionOptions,
365364
) -> syn::Result<TokenStream> {
366365
options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?;
367-
Ok(impl_wrap_pyfunction(ast, options)?.1)
366+
impl_wrap_pyfunction(ast, options)
368367
}
369368

370369
/// Generates python wrapper over a function that allows adding it to a python module as a python
371370
/// function
372371
pub fn impl_wrap_pyfunction(
373372
func: &mut syn::ItemFn,
374373
options: PyFunctionOptions,
375-
) -> syn::Result<(Ident, TokenStream)> {
374+
) -> syn::Result<TokenStream> {
376375
check_generic(&func.sig)?;
377376
ensure_not_async_fn(&func.sig)?;
378377

@@ -412,7 +411,6 @@ pub fn impl_wrap_pyfunction(
412411
.map(|attr| (&python_name, attr)),
413412
);
414413

415-
let function_wrapper_ident = function_wrapper_ident(&func.sig.ident);
416414
let krate = get_pyo3_crate(&options.krate);
417415

418416
let spec = method::FnSpec {
@@ -434,21 +432,40 @@ pub fn impl_wrap_pyfunction(
434432
unsafety: func.sig.unsafety,
435433
};
436434

437-
let wrapper_ident = format_ident!("__pyo3_raw_{}", spec.name);
435+
let vis = &func.vis;
436+
let name = &func.sig.ident;
437+
438+
let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
438439
let wrapper = spec.get_wrapper_function(&wrapper_ident, None)?;
439440
let methoddef = spec.get_methoddef(wrapper_ident);
440441

441442
let wrapped_pyfunction = quote! {
442443
#wrapper
443444

444-
pub(crate) fn #function_wrapper_ident<'a>(
445-
args: impl ::std::convert::Into<#krate::derive_utils::PyFunctionArguments<'a>>
446-
) -> #krate::PyResult<&'a #krate::types::PyCFunction> {
445+
// Create a module with the same name as the `#[pyfunction]` - this way `use <the function>`
446+
// will actually bring both the module and the function into scope.
447+
#[doc(hidden)]
448+
#vis mod #name {
447449
use #krate as _pyo3;
448-
_pyo3::types::PyCFunction::internal_new(#methoddef, args.into())
450+
pub(crate) struct PyO3Def;
451+
452+
// Exported for `wrap_pyfunction!`
453+
pub use _pyo3::impl_::pyfunction::wrap_pyfunction as wrap;
454+
pub const DEF: _pyo3::PyMethodDef = <PyO3Def as _pyo3::impl_::pyfunction::PyFunctionDef>::DEF;
449455
}
456+
457+
// Generate the definition inside an anonymous function in the same scope as the original function -
458+
// this avoids complications around the fact that the generated module has a different scope
459+
// (and `super` doesn't always refer to the outer scope, e.g. if the `#[pyfunction] is
460+
// inside a function body)
461+
const _: () = {
462+
use #krate as _pyo3;
463+
impl _pyo3::impl_::pyfunction::PyFunctionDef for #name::PyO3Def {
464+
const DEF: _pyo3::PyMethodDef = #methoddef;
465+
}
466+
};
450467
};
451-
Ok((function_wrapper_ident, wrapped_pyfunction))
468+
Ok(wrapped_pyfunction)
452469
}
453470

454471
fn type_is_pymodule(ty: &syn::Type) -> bool {

pyo3-macros-backend/src/wrap.rs

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,16 @@ impl Parse for WrapPyFunctionArgs {
2222
}
2323
}
2424

25-
pub fn wrap_pyfunction_impl(args: WrapPyFunctionArgs) -> syn::Result<TokenStream> {
25+
pub fn wrap_pyfunction_impl(args: WrapPyFunctionArgs) -> TokenStream {
2626
let WrapPyFunctionArgs {
27-
mut function,
27+
function,
2828
comma_and_arg,
2929
} = args;
30-
let span = function.span();
31-
let last_segment = function
32-
.segments
33-
.last_mut()
34-
.ok_or_else(|| err_spanned!(span => "expected non-empty path"))?;
35-
36-
last_segment.ident = function_wrapper_ident(&last_segment.ident);
37-
38-
let output = if let Some((_, arg)) = comma_and_arg {
39-
quote! { #function(#arg) }
30+
if let Some((_, arg)) = comma_and_arg {
31+
quote! { #function::wrap(#function::DEF, #arg) }
4032
} else {
41-
quote! { &|arg| #function(arg) }
42-
};
43-
Ok(output)
33+
quote! { &|arg| #function::wrap(#function::DEF, arg) }
34+
}
4435
}
4536

4637
pub fn wrap_pymodule_impl(mut module_path: syn::Path) -> syn::Result<TokenStream> {
@@ -58,10 +49,6 @@ pub fn wrap_pymodule_impl(mut module_path: syn::Path) -> syn::Result<TokenStream
5849
})
5950
}
6051

61-
pub(crate) fn function_wrapper_ident(name: &Ident) -> Ident {
62-
format_ident!("__pyo3_get_function_{}", name)
63-
}
64-
6552
pub(crate) fn module_def_ident(name: &Ident) -> Ident {
6653
format_ident!("__PYO3_PYMODULE_DEF_{}", name.to_string().to_uppercase())
6754
}

pyo3-macros/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ pub fn derive_from_py_object(item: TokenStream) -> TokenStream {
208208
#[proc_macro]
209209
pub fn wrap_pyfunction(input: TokenStream) -> TokenStream {
210210
let args = parse_macro_input!(input as WrapPyFunctionArgs);
211-
wrap_pyfunction_impl(args).unwrap_or_compile_error().into()
211+
wrap_pyfunction_impl(args).into()
212212
}
213213

214214
/// Returns a function that takes a `Python` instance and returns a Python module.

src/derive_utils.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44

55
//! Functionality for the code generated by the derive backend
66
7-
use crate::err::PyErr;
8-
use crate::types::PyModule;
9-
use crate::{PyCell, PyClass, Python};
7+
use crate::{types::PyModule, PyCell, PyClass, PyErr, Python};
108

119
/// Utility trait to enable &PyClass as a pymethod/function argument
1210
#[doc(hidden)]

src/impl_.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,7 @@ pub(crate) mod not_send;
1313
#[doc(hidden)]
1414
pub mod pyclass;
1515
#[doc(hidden)]
16+
pub mod pyfunction;
17+
#[doc(hidden)]
1618
pub mod pymethods;
1719
pub mod pymodule;

src/impl_/pyfunction.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
use crate::{
2+
class::methods::PyMethodDef, derive_utils::PyFunctionArguments, types::PyCFunction, PyResult,
3+
};
4+
5+
pub trait PyFunctionDef {
6+
const DEF: crate::PyMethodDef;
7+
}
8+
9+
pub fn wrap_pyfunction<'a>(
10+
method_def: PyMethodDef,
11+
args: impl Into<PyFunctionArguments<'a>>,
12+
) -> PyResult<&'a PyCFunction> {
13+
PyCFunction::internal_new(method_def, args.into())
14+
}

tests/test_pyfunction.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,29 @@ fn test_closure_counter() {
269269
py_assert!(py, counter_py, "counter_py() == 2");
270270
py_assert!(py, counter_py, "counter_py() == 3");
271271
}
272+
273+
#[test]
274+
fn use_pyfunction() {
275+
mod function_in_module {
276+
use pyo3::prelude::*;
277+
278+
#[pyfunction]
279+
pub fn foo(x: i32) -> i32 {
280+
x
281+
}
282+
}
283+
284+
Python::with_gil(|py| {
285+
use function_in_module::foo;
286+
287+
// check imported name can be wrapped
288+
let f = wrap_pyfunction!(foo, py).unwrap();
289+
assert_eq!(f.call1((5,)).unwrap().extract::<i32>().unwrap(), 5);
290+
assert_eq!(f.call1((42,)).unwrap().extract::<i32>().unwrap(), 42);
291+
292+
// check path import can be wrapped
293+
let f2 = wrap_pyfunction!(function_in_module::foo, py).unwrap();
294+
assert_eq!(f2.call1((5,)).unwrap().extract::<i32>().unwrap(), 5);
295+
assert_eq!(f2.call1((42,)).unwrap().extract::<i32>().unwrap(), 42);
296+
})
297+
}

0 commit comments

Comments
 (0)