Skip to content

pyfunction: allow wrap_pyfunction to work on imports (even cross-crate) #2091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add buffer magic methods `__getbuffer__` and `__releasebuffer__` to `#[pymethods]`. [#2067](https://github.com/PyO3/pyo3/pull/2067)
- Accept paths in `wrap_pyfunction` and `wrap_pymodule`. [#2081](https://github.com/PyO3/pyo3/pull/2081)
- Add check for correct number of arguments on magic methods. [#2083](https://github.com/PyO3/pyo3/pull/2083)
- `wrap_pyfunction!` can now wrap a `#[pyfunction]` which is implemented in a different Rust module or crate. [#2091](https://github.com/PyO3/pyo3/pull/2091)

### Changed

Expand Down
5 changes: 3 additions & 2 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,11 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
if let syn::Stmt::Item(syn::Item::Fn(func)) = &mut stmt {
if let Some(pyfn_args) = get_pyfn_attr(&mut func.attrs)? {
let module_name = pyfn_args.modname;
let (ident, wrapped_function) = impl_wrap_pyfunction(func, pyfn_args.options)?;
let wrapped_function = impl_wrap_pyfunction(func, pyfn_args.options)?;
let name = &func.sig.ident;
let statements: Vec<syn::Stmt> = syn::parse_quote! {
#wrapped_function
#module_name.add_function(#ident(#module_name)?)?;
#module_name.add_function(#name::wrap(#name::DEF, #module_name)?)?;
};
stmts.extend(statements);
}
Expand Down
39 changes: 28 additions & 11 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ use crate::{
method::{self, CallingConvention, FnArg},
pymethod::check_generic,
utils::{self, ensure_not_async_fn, get_pyo3_crate},
wrap::function_wrapper_ident,
};
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use syn::punctuated::Punctuated;
use syn::{ext::IdentExt, spanned::Spanned, Ident, NestedMeta, Path, Result};
use syn::{ext::IdentExt, spanned::Spanned, NestedMeta, Path, Result};
use syn::{
parse::{Parse, ParseBuffer, ParseStream},
token::Comma,
Expand Down Expand Up @@ -364,15 +363,15 @@ pub fn build_py_function(
mut options: PyFunctionOptions,
) -> syn::Result<TokenStream> {
options.add_attributes(take_pyo3_options(&mut ast.attrs)?)?;
Ok(impl_wrap_pyfunction(ast, options)?.1)
impl_wrap_pyfunction(ast, options)
}

/// Generates python wrapper over a function that allows adding it to a python module as a python
/// function
pub fn impl_wrap_pyfunction(
func: &mut syn::ItemFn,
options: PyFunctionOptions,
) -> syn::Result<(Ident, TokenStream)> {
) -> syn::Result<TokenStream> {
check_generic(&func.sig)?;
ensure_not_async_fn(&func.sig)?;

Expand Down Expand Up @@ -412,7 +411,6 @@ pub fn impl_wrap_pyfunction(
.map(|attr| (&python_name, attr)),
);

let function_wrapper_ident = function_wrapper_ident(&func.sig.ident);
let krate = get_pyo3_crate(&options.krate);

let spec = method::FnSpec {
Expand All @@ -434,21 +432,40 @@ pub fn impl_wrap_pyfunction(
unsafety: func.sig.unsafety,
};

let wrapper_ident = format_ident!("__pyo3_raw_{}", spec.name);
let vis = &func.vis;
let name = &func.sig.ident;

let wrapper_ident = format_ident!("__pyfunction_{}", spec.name);
let wrapper = spec.get_wrapper_function(&wrapper_ident, None)?;
let methoddef = spec.get_methoddef(wrapper_ident);

let wrapped_pyfunction = quote! {
#wrapper

pub(crate) fn #function_wrapper_ident<'a>(
args: impl ::std::convert::Into<#krate::derive_utils::PyFunctionArguments<'a>>
) -> #krate::PyResult<&'a #krate::types::PyCFunction> {
// Create a module with the same name as the `#[pyfunction]` - this way `use <the function>`
// will actually bring both the module and the function into scope.
#[doc(hidden)]
#vis mod #name {
use #krate as _pyo3;
_pyo3::types::PyCFunction::internal_new(#methoddef, args.into())
pub(crate) struct PyO3Def;

// Exported for `wrap_pyfunction!`
pub use _pyo3::impl_::pyfunction::wrap_pyfunction as wrap;
pub const DEF: _pyo3::PyMethodDef = <PyO3Def as _pyo3::impl_::pyfunction::PyFunctionDef>::DEF;
}

// Generate the definition inside an anonymous function in the same scope as the original function -
// this avoids complications around the fact that the generated module has a different scope
// (and `super` doesn't always refer to the outer scope, e.g. if the `#[pyfunction] is
// inside a function body)
const _: () = {
use #krate as _pyo3;
impl _pyo3::impl_::pyfunction::PyFunctionDef for #name::PyO3Def {
const DEF: _pyo3::PyMethodDef = #methoddef;
}
};
};
Ok((function_wrapper_ident, wrapped_pyfunction))
Ok(wrapped_pyfunction)
}

fn type_is_pymodule(ty: &syn::Type) -> bool {
Expand Down
25 changes: 6 additions & 19 deletions pyo3-macros-backend/src/wrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,16 @@ impl Parse for WrapPyFunctionArgs {
}
}

pub fn wrap_pyfunction_impl(args: WrapPyFunctionArgs) -> syn::Result<TokenStream> {
pub fn wrap_pyfunction_impl(args: WrapPyFunctionArgs) -> TokenStream {
let WrapPyFunctionArgs {
mut function,
function,
comma_and_arg,
} = args;
let span = function.span();
let last_segment = function
.segments
.last_mut()
.ok_or_else(|| err_spanned!(span => "expected non-empty path"))?;

last_segment.ident = function_wrapper_ident(&last_segment.ident);

let output = if let Some((_, arg)) = comma_and_arg {
quote! { #function(#arg) }
if let Some((_, arg)) = comma_and_arg {
quote! { #function::wrap(#function::DEF, #arg) }
} else {
quote! { &|arg| #function(arg) }
};
Ok(output)
quote! { &|arg| #function::wrap(#function::DEF, arg) }
}
}

pub fn wrap_pymodule_impl(mut module_path: syn::Path) -> syn::Result<TokenStream> {
Expand All @@ -58,10 +49,6 @@ pub fn wrap_pymodule_impl(mut module_path: syn::Path) -> syn::Result<TokenStream
})
}

pub(crate) fn function_wrapper_ident(name: &Ident) -> Ident {
format_ident!("__pyo3_get_function_{}", name)
}

pub(crate) fn module_def_ident(name: &Ident) -> Ident {
format_ident!("__PYO3_PYMODULE_DEF_{}", name.to_string().to_uppercase())
}
10 changes: 8 additions & 2 deletions pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ pub fn pymethods(_: TokenStream, input: TokenStream) -> TokenStream {

/// A proc macro used to expose Rust functions to Python.
///
/// Functions annotated with `#[pyfunction]` can also be annotated with the following `#[pyo3]` options:
/// Functions annotated with `#[pyfunction]` can also be annotated with the following `#[pyo3]`
/// options:
///
/// | Annotation | Description |
/// | :- | :- |
Expand All @@ -176,6 +177,11 @@ pub fn pymethods(_: TokenStream, input: TokenStream) -> TokenStream {
///
/// For more on exposing functions see the [function section of the guide][1].
///
/// Due to technical limitations on how `#[pyfunction]` is implemented, a function marked
/// `#[pyfunction]` cannot have a module with the same name in the same scope. (The
/// `#[pyfunction]` implementation generates a hidden module with the same name containing
/// metadata about the function, which is used by `wrap_pyfunction!`).
///
/// [1]: https://pyo3.rs/latest/function.html
#[proc_macro_attribute]
pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -208,7 +214,7 @@ pub fn derive_from_py_object(item: TokenStream) -> TokenStream {
#[proc_macro]
pub fn wrap_pyfunction(input: TokenStream) -> TokenStream {
let args = parse_macro_input!(input as WrapPyFunctionArgs);
wrap_pyfunction_impl(args).unwrap_or_compile_error().into()
wrap_pyfunction_impl(args).into()
}

/// Returns a function that takes a `Python` instance and returns a Python module.
Expand Down
4 changes: 1 addition & 3 deletions src/derive_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

//! Functionality for the code generated by the derive backend

use crate::err::PyErr;
use crate::types::PyModule;
use crate::{PyCell, PyClass, Python};
use crate::{types::PyModule, PyCell, PyClass, PyErr, Python};

/// Utility trait to enable &PyClass as a pymethod/function argument
#[doc(hidden)]
Expand Down
2 changes: 2 additions & 0 deletions src/impl_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@ pub(crate) mod not_send;
#[doc(hidden)]
pub mod pyclass;
#[doc(hidden)]
pub mod pyfunction;
#[doc(hidden)]
pub mod pymethods;
pub mod pymodule;
14 changes: 14 additions & 0 deletions src/impl_/pyfunction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use crate::{
class::methods::PyMethodDef, derive_utils::PyFunctionArguments, types::PyCFunction, PyResult,
};

pub trait PyFunctionDef {
const DEF: crate::PyMethodDef;
}

pub fn wrap_pyfunction<'a>(
method_def: PyMethodDef,
args: impl Into<PyFunctionArguments<'a>>,
) -> PyResult<&'a PyCFunction> {
PyCFunction::internal_new(method_def, args.into())
}
26 changes: 26 additions & 0 deletions tests/test_pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,29 @@ fn test_closure_counter() {
py_assert!(py, counter_py, "counter_py() == 2");
py_assert!(py, counter_py, "counter_py() == 3");
}

#[test]
fn use_pyfunction() {
mod function_in_module {
use pyo3::prelude::*;

#[pyfunction]
pub fn foo(x: i32) -> i32 {
x
}
}

Python::with_gil(|py| {
use function_in_module::foo;

// check imported name can be wrapped
let f = wrap_pyfunction!(foo, py).unwrap();
assert_eq!(f.call1((5,)).unwrap().extract::<i32>().unwrap(), 5);
assert_eq!(f.call1((42,)).unwrap().extract::<i32>().unwrap(), 42);

// check path import can be wrapped
let f2 = wrap_pyfunction!(function_in_module::foo, py).unwrap();
assert_eq!(f2.call1((5,)).unwrap().extract::<i32>().unwrap(), 5);
assert_eq!(f2.call1((42,)).unwrap().extract::<i32>().unwrap(), 42);
})
}