Skip to content

Commit ccfb358

Browse files
committed
Set the module of #[pyfunction]s.
Previously neither the module nor the name of the module of pyfunctions were registered. This commit passes the module and its name when creating a new pyfunction. PyModule::add_function and PyModule::add_module have been added and are set to replace `add_wrapped` in a future release. `add_wrapped` is kept for compatibility reasons during the transition. Depending on whether a `PyModule` or `Python` is the argument for the Python function-wrapper, the module will be registered with the function.
1 parent 21ad52a commit ccfb358

File tree

12 files changed

+121
-38
lines changed

12 files changed

+121
-38
lines changed

examples/rustapi_module/src/datetime.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -215,29 +215,29 @@ impl TzClass {
215215

216216
#[pymodule]
217217
fn datetime(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
218-
m.add_wrapped(wrap_pyfunction!(make_date))?;
219-
m.add_wrapped(wrap_pyfunction!(get_date_tuple))?;
220-
m.add_wrapped(wrap_pyfunction!(date_from_timestamp))?;
221-
m.add_wrapped(wrap_pyfunction!(make_time))?;
222-
m.add_wrapped(wrap_pyfunction!(get_time_tuple))?;
223-
m.add_wrapped(wrap_pyfunction!(make_delta))?;
224-
m.add_wrapped(wrap_pyfunction!(get_delta_tuple))?;
225-
m.add_wrapped(wrap_pyfunction!(make_datetime))?;
226-
m.add_wrapped(wrap_pyfunction!(get_datetime_tuple))?;
227-
m.add_wrapped(wrap_pyfunction!(datetime_from_timestamp))?;
218+
m.add_function(wrap_pyfunction!(make_date))?;
219+
m.add_function(wrap_pyfunction!(get_date_tuple))?;
220+
m.add_function(wrap_pyfunction!(date_from_timestamp))?;
221+
m.add_function(wrap_pyfunction!(make_time))?;
222+
m.add_function(wrap_pyfunction!(get_time_tuple))?;
223+
m.add_function(wrap_pyfunction!(make_delta))?;
224+
m.add_function(wrap_pyfunction!(get_delta_tuple))?;
225+
m.add_function(wrap_pyfunction!(make_datetime))?;
226+
m.add_function(wrap_pyfunction!(get_datetime_tuple))?;
227+
m.add_function(wrap_pyfunction!(datetime_from_timestamp))?;
228228

229229
// Python 3.6+ functions
230230
#[cfg(Py_3_6)]
231231
{
232-
m.add_wrapped(wrap_pyfunction!(time_with_fold))?;
232+
m.add_function(wrap_pyfunction!(time_with_fold))?;
233233
#[cfg(not(PyPy))]
234234
{
235-
m.add_wrapped(wrap_pyfunction!(get_time_tuple_fold))?;
236-
m.add_wrapped(wrap_pyfunction!(get_datetime_tuple_fold))?;
235+
m.add_function(wrap_pyfunction!(get_time_tuple_fold))?;
236+
m.add_function(wrap_pyfunction!(get_datetime_tuple_fold))?;
237237
}
238238
}
239239

240-
m.add_wrapped(wrap_pyfunction!(issue_219))?;
240+
m.add_function(wrap_pyfunction!(issue_219))?;
241241
m.add_class::<TzClass>()?;
242242

243243
Ok(())

examples/rustapi_module/src/othermod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ fn double(x: i32) -> i32 {
3131

3232
#[pymodule]
3333
fn othermod(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
34-
m.add_wrapped(wrap_pyfunction!(double))?;
34+
m.add_function(wrap_pyfunction!(double))?;
3535

3636
m.add_class::<ModClass>()?;
3737

examples/word-count/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ fn count_line(line: &str, needle: &str) -> usize {
5555

5656
#[pymodule]
5757
fn word_count(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
58-
m.add_wrapped(wrap_pyfunction!(search))?;
59-
m.add_wrapped(wrap_pyfunction!(search_sequential))?;
60-
m.add_wrapped(wrap_pyfunction!(search_sequential_allow_threads))?;
58+
m.add_function(wrap_pyfunction!(search))?;
59+
m.add_function(wrap_pyfunction!(search_sequential))?;
60+
m.add_function(wrap_pyfunction!(search_sequential_allow_threads))?;
6161

6262
Ok(())
6363
}

guide/src/function.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ fn double(x: usize) -> usize {
3636

3737
#[pymodule]
3838
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
39-
m.add_wrapped(wrap_pyfunction!(double)).unwrap();
39+
m.add_function(wrap_pyfunction!(double)).unwrap();
4040

4141
Ok(())
4242
}
@@ -65,7 +65,7 @@ fn num_kwds(kwds: Option<&PyDict>) -> usize {
6565

6666
#[pymodule]
6767
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
68-
m.add_wrapped(wrap_pyfunction!(num_kwds)).unwrap();
68+
m.add_function(wrap_pyfunction!(num_kwds)).unwrap();
6969
Ok(())
7070
}
7171

guide/src/module.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,13 @@ fn subfunction() -> String {
6767

6868
#[pymodule]
6969
fn submodule(_py: Python, module: &PyModule) -> PyResult<()> {
70-
module.add_wrapped(wrap_pyfunction!(subfunction))?;
70+
module.add_function(wrap_pyfunction!(subfunction))?;
7171
Ok(())
7272
}
7373

7474
#[pymodule]
7575
fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> {
76-
module.add_wrapped(wrap_pymodule!(submodule))?;
76+
module.add_module(wrap_pymodule!(submodule))?;
7777
Ok(())
7878
}
7979

pyo3-derive-backend/src/module.rs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
4545
let item: syn::ItemFn = syn::parse_quote! {
4646
fn block_wrapper() {
4747
#function_to_python
48-
#module_name.add_wrapped(&#function_wrapper_ident)?;
48+
#module_name.add_function(&#function_wrapper_ident)?;
4949
}
5050
};
5151
stmts.extend(item.block.stmts.into_iter());
@@ -193,7 +193,17 @@ pub fn add_fn_to_module(
193193
let wrapper = function_c_wrapper(&func.sig.ident, &spec);
194194

195195
Ok(quote! {
196-
fn #function_wrapper_ident(py: pyo3::Python) -> pyo3::PyObject {
196+
fn #function_wrapper_ident<'a>(
197+
args: impl Into<pyo3::derive_utils::WrapPyFunctionArguments<'a>>
198+
) -> pyo3::PyObject {
199+
let arg = args.into();
200+
let (py, maybe_module) = match arg {
201+
pyo3::derive_utils::WrapPyFunctionArguments::Python(py) => (py, None),
202+
pyo3::derive_utils::WrapPyFunctionArguments::PyModule(module) => {
203+
let py = <pyo3::types::PyModule as pyo3::PyNativeType>::py(module);
204+
(py, Some(module))
205+
}
206+
};
197207
#wrapper
198208

199209
let _def = pyo3::class::PyMethodDef {
@@ -203,12 +213,26 @@ pub fn add_fn_to_module(
203213
ml_doc: #doc,
204214
};
205215

216+
let (mod_ptr, name) = if let Some(m) = maybe_module {
217+
let mod_ptr = <pyo3::types::PyModule as ::pyo3::conversion::AsPyPointer>::as_ptr(m);
218+
let name = match m.name() {
219+
Ok(name) => <&str as pyo3::conversion::IntoPy<PyObject>>::into_py(name, py),
220+
Err(err) => {
221+
return <PyErr as pyo3::conversion::IntoPy<PyObject>>::into_py(err, py);
222+
}
223+
};
224+
(mod_ptr, <PyObject as pyo3::AsPyPointer>::as_ptr(&name))
225+
} else {
226+
(std::ptr::null_mut(), std::ptr::null_mut())
227+
};
228+
206229
let function = unsafe {
207230
pyo3::PyObject::from_owned_ptr(
208231
py,
209-
pyo3::ffi::PyCFunction_New(
232+
pyo3::ffi::PyCFunction_NewEx(
210233
Box::into_raw(Box::new(_def.as_method_def())),
211-
::std::ptr::null_mut()
234+
mod_ptr,
235+
name
212236
)
213237
)
214238
};

src/derive_utils.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,22 @@ where
207207
<R as std::convert::TryFrom<&'a PyCell<T>>>::try_from(cell)
208208
}
209209
}
210+
211+
/// Enum to abstract over the arguments of Python function wrappers.
212+
#[doc(hidden)]
213+
pub enum WrapPyFunctionArguments<'a> {
214+
Python(Python<'a>),
215+
PyModule(&'a PyModule),
216+
}
217+
218+
impl<'a> From<Python<'a>> for WrapPyFunctionArguments<'a> {
219+
fn from(py: Python<'a>) -> WrapPyFunctionArguments<'a> {
220+
WrapPyFunctionArguments::Python(py)
221+
}
222+
}
223+
224+
impl<'a> From<&'a PyModule> for WrapPyFunctionArguments<'a> {
225+
fn from(module: &'a PyModule) -> WrapPyFunctionArguments<'a> {
226+
WrapPyFunctionArguments::PyModule(module)
227+
}
228+
}

src/ffi/moduleobject.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ extern "C" {
2525
pub fn PyModule_New(name: *const c_char) -> *mut PyObject;
2626
#[cfg_attr(PyPy, link_name = "PyPyModule_GetDict")]
2727
pub fn PyModule_GetDict(arg1: *mut PyObject) -> *mut PyObject;
28+
#[cfg_attr(PyPy, link_name = "PyPyModule_GetNameObject")]
2829
pub fn PyModule_GetNameObject(arg1: *mut PyObject) -> *mut PyObject;
2930
#[cfg_attr(PyPy, link_name = "PyPyModule_GetName")]
3031
pub fn PyModule_GetName(arg1: *mut PyObject) -> *const c_char;

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
//! #[pymodule]
7272
//! /// A Python module implemented in Rust.
7373
//! fn string_sum(py: Python, m: &PyModule) -> PyResult<()> {
74-
//! m.add_wrapped(wrap_pyfunction!(sum_as_string))?;
74+
//! m.add_function(wrap_pyfunction!(sum_as_string))?;
7575
//!
7676
//! Ok(())
7777
//! }

src/python.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ impl<'p> Python<'p> {
134134
/// let gil = Python::acquire_gil();
135135
/// let py = gil.python();
136136
/// let m = PyModule::new(py, "pcount").unwrap();
137-
/// m.add_wrapped(wrap_pyfunction!(parallel_count)).unwrap();
137+
/// m.add_function(wrap_pyfunction!(parallel_count)).unwrap();
138138
/// let locals = [("pcount", m)].into_py_dict(py);
139139
/// py.run(r#"
140140
/// s = ["Flow", "my", "tears", "the", "Policeman", "Said"]

src/types/module.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,50 @@ impl PyModule {
194194
/// ```rust,ignore
195195
/// m.add("also_double", wrap_pyfunction!(double)(py));
196196
/// ```
197-
pub fn add_wrapped(&self, wrapper: &impl Fn(Python) -> PyObject) -> PyResult<()> {
197+
///
198+
/// **This function will be deprecated in the next release. Please use the specific
199+
/// [add_function] and [add_module] functions instead.**
200+
pub fn add_wrapped<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> {
198201
let function = wrapper(self.py());
199202
let name = function
200203
.getattr(self.py(), "__name__")
201204
.expect("A function or module must have a __name__");
202205
self.add(name.extract(self.py()).unwrap(), function)
203206
}
207+
208+
/// Adds a (sub)module to a module.
209+
///
210+
/// Use this together with `#[pymodule]` and [wrap_pymodule!].
211+
///
212+
/// ```rust,ignore
213+
/// m.add_module(wrap_pymodule!(utils));
214+
/// ```
215+
pub fn add_module<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> {
216+
let function = wrapper(self.py());
217+
let name = function
218+
.getattr(self.py(), "__name__")
219+
.expect("A module must have a __name__");
220+
self.add(name.extract(self.py()).unwrap(), function)
221+
}
222+
223+
/// Adds a function to a module, using the functions __name__ as name.
224+
///
225+
/// Use this together with the`#[pyfunction]` and [wrap_pyfunction!].
226+
///
227+
/// ```rust,ignore
228+
/// m.add_function(wrap_pyfunction!(double));
229+
/// ```
230+
///
231+
/// You can also add a function with a custom name using [add](PyModule::add):
232+
///
233+
/// ```rust,ignore
234+
/// m.add("also_double", wrap_pyfunction!(double)(py, m));
235+
/// ```
236+
pub fn add_function<'a>(&'a self, wrapper: &impl Fn(&'a Self) -> PyObject) -> PyResult<()> {
237+
let function = wrapper(self);
238+
let name = function
239+
.getattr(self.py(), "__name__")
240+
.expect("A function or module must have a __name__");
241+
self.add(name.extract(self.py()).unwrap(), function)
242+
}
204243
}

tests/test_module.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ fn double(x: usize) -> usize {
3535

3636
/// This module is implemented in Rust.
3737
#[pymodule]
38-
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
38+
fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> {
3939
use pyo3::wrap_pyfunction;
4040

4141
#[pyfn(m, "sum_as_string")]
@@ -60,8 +60,8 @@ fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
6060

6161
m.add("foo", "bar").unwrap();
6262

63-
m.add_wrapped(wrap_pyfunction!(double)).unwrap();
64-
m.add("also_double", wrap_pyfunction!(double)(py)).unwrap();
63+
m.add_function(wrap_pyfunction!(double)).unwrap();
64+
m.add("also_double", wrap_pyfunction!(double)(m)).unwrap();
6565

6666
Ok(())
6767
}
@@ -157,7 +157,7 @@ fn r#move() -> usize {
157157
fn raw_ident_module(_py: Python, module: &PyModule) -> PyResult<()> {
158158
use pyo3::wrap_pyfunction;
159159

160-
module.add_wrapped(wrap_pyfunction!(r#move))
160+
module.add_function(wrap_pyfunction!(r#move))
161161
}
162162

163163
#[test]
@@ -182,7 +182,7 @@ fn custom_named_fn() -> usize {
182182
fn foobar_module(_py: Python, m: &PyModule) -> PyResult<()> {
183183
use pyo3::wrap_pyfunction;
184184

185-
m.add_wrapped(wrap_pyfunction!(custom_named_fn))?;
185+
m.add_function(wrap_pyfunction!(custom_named_fn))?;
186186
m.dict().set_item("yay", "me")?;
187187
Ok(())
188188
}
@@ -216,7 +216,7 @@ fn subfunction() -> String {
216216
fn submodule(_py: Python, module: &PyModule) -> PyResult<()> {
217217
use pyo3::wrap_pyfunction;
218218

219-
module.add_wrapped(wrap_pyfunction!(subfunction))?;
219+
module.add_function(wrap_pyfunction!(subfunction))?;
220220
Ok(())
221221
}
222222

@@ -229,8 +229,8 @@ fn superfunction() -> String {
229229
fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> {
230230
use pyo3::{wrap_pyfunction, wrap_pymodule};
231231

232-
module.add_wrapped(wrap_pyfunction!(superfunction))?;
233-
module.add_wrapped(wrap_pymodule!(submodule))?;
232+
module.add_function(wrap_pyfunction!(superfunction))?;
233+
module.add_module(wrap_pymodule!(submodule))?;
234234
Ok(())
235235
}
236236

@@ -268,7 +268,7 @@ fn vararg_module(_py: Python, m: &PyModule) -> PyResult<()> {
268268
ext_vararg_fn(py, a, vararg)
269269
}
270270

271-
m.add_wrapped(pyo3::wrap_pyfunction!(ext_vararg_fn))
271+
m.add_function(pyo3::wrap_pyfunction!(ext_vararg_fn))
272272
.unwrap();
273273
Ok(())
274274
}

0 commit comments

Comments
 (0)