Skip to content

Commit e34294a

Browse files
committed
Set the module of #[pyfunction]s.
Previously neither the module nor the name of the module of pyfunctions were registered. This commit passes the module and its name when creating a new pyfunction. PyModule::add_function and PyModule::add_module have been added and are set to replace `add_wrapped` in a future release. `add_wrapped` is kept for compatibility reasons during the transition. Depending on whether a `PyModule` or `Python` is the argument for the Python function-wrapper, the module will be registered with the function.
1 parent 21ad52a commit e34294a

File tree

11 files changed

+113
-37
lines changed

11 files changed

+113
-37
lines changed

examples/rustapi_module/src/datetime.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -215,29 +215,29 @@ impl TzClass {
215215

216216
#[pymodule]
217217
fn datetime(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
218-
m.add_wrapped(wrap_pyfunction!(make_date))?;
219-
m.add_wrapped(wrap_pyfunction!(get_date_tuple))?;
220-
m.add_wrapped(wrap_pyfunction!(date_from_timestamp))?;
221-
m.add_wrapped(wrap_pyfunction!(make_time))?;
222-
m.add_wrapped(wrap_pyfunction!(get_time_tuple))?;
223-
m.add_wrapped(wrap_pyfunction!(make_delta))?;
224-
m.add_wrapped(wrap_pyfunction!(get_delta_tuple))?;
225-
m.add_wrapped(wrap_pyfunction!(make_datetime))?;
226-
m.add_wrapped(wrap_pyfunction!(get_datetime_tuple))?;
227-
m.add_wrapped(wrap_pyfunction!(datetime_from_timestamp))?;
218+
m.add_function(wrap_pyfunction!(make_date))?;
219+
m.add_function(wrap_pyfunction!(get_date_tuple))?;
220+
m.add_function(wrap_pyfunction!(date_from_timestamp))?;
221+
m.add_function(wrap_pyfunction!(make_time))?;
222+
m.add_function(wrap_pyfunction!(get_time_tuple))?;
223+
m.add_function(wrap_pyfunction!(make_delta))?;
224+
m.add_function(wrap_pyfunction!(get_delta_tuple))?;
225+
m.add_function(wrap_pyfunction!(make_datetime))?;
226+
m.add_function(wrap_pyfunction!(get_datetime_tuple))?;
227+
m.add_function(wrap_pyfunction!(datetime_from_timestamp))?;
228228

229229
// Python 3.6+ functions
230230
#[cfg(Py_3_6)]
231231
{
232-
m.add_wrapped(wrap_pyfunction!(time_with_fold))?;
232+
m.add_function(wrap_pyfunction!(time_with_fold))?;
233233
#[cfg(not(PyPy))]
234234
{
235-
m.add_wrapped(wrap_pyfunction!(get_time_tuple_fold))?;
236-
m.add_wrapped(wrap_pyfunction!(get_datetime_tuple_fold))?;
235+
m.add_function(wrap_pyfunction!(get_time_tuple_fold))?;
236+
m.add_function(wrap_pyfunction!(get_datetime_tuple_fold))?;
237237
}
238238
}
239239

240-
m.add_wrapped(wrap_pyfunction!(issue_219))?;
240+
m.add_function(wrap_pyfunction!(issue_219))?;
241241
m.add_class::<TzClass>()?;
242242

243243
Ok(())

examples/rustapi_module/src/othermod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ fn double(x: i32) -> i32 {
3131

3232
#[pymodule]
3333
fn othermod(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
34-
m.add_wrapped(wrap_pyfunction!(double))?;
34+
m.add_function(wrap_pyfunction!(double))?;
3535

3636
m.add_class::<ModClass>()?;
3737

examples/word-count/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ fn count_line(line: &str, needle: &str) -> usize {
5555

5656
#[pymodule]
5757
fn word_count(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
58-
m.add_wrapped(wrap_pyfunction!(search))?;
59-
m.add_wrapped(wrap_pyfunction!(search_sequential))?;
60-
m.add_wrapped(wrap_pyfunction!(search_sequential_allow_threads))?;
58+
m.add_function(wrap_pyfunction!(search))?;
59+
m.add_function(wrap_pyfunction!(search_sequential))?;
60+
m.add_function(wrap_pyfunction!(search_sequential_allow_threads))?;
6161

6262
Ok(())
6363
}

guide/src/function.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ fn double(x: usize) -> usize {
3636

3737
#[pymodule]
3838
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
39-
m.add_wrapped(wrap_pyfunction!(double)).unwrap();
39+
m.add_function(wrap_pyfunction!(double)).unwrap();
4040

4141
Ok(())
4242
}
@@ -65,7 +65,7 @@ fn num_kwds(kwds: Option<&PyDict>) -> usize {
6565

6666
#[pymodule]
6767
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
68-
m.add_wrapped(wrap_pyfunction!(num_kwds)).unwrap();
68+
m.add_function(wrap_pyfunction!(num_kwds)).unwrap();
6969
Ok(())
7070
}
7171

guide/src/module.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,13 @@ fn subfunction() -> String {
6767

6868
#[pymodule]
6969
fn submodule(_py: Python, module: &PyModule) -> PyResult<()> {
70-
module.add_wrapped(wrap_pyfunction!(subfunction))?;
70+
module.add_function(wrap_pyfunction!(subfunction))?;
7171
Ok(())
7272
}
7373

7474
#[pymodule]
7575
fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> {
76-
module.add_wrapped(wrap_pymodule!(submodule))?;
76+
module.add_module(wrap_pymodule!(submodule))?;
7777
Ok(())
7878
}
7979

pyo3-derive-backend/src/module.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
4545
let item: syn::ItemFn = syn::parse_quote! {
4646
fn block_wrapper() {
4747
#function_to_python
48-
#module_name.add_wrapped(&#function_wrapper_ident)?;
48+
#module_name.add_function(&#function_wrapper_ident)?;
4949
}
5050
};
5151
stmts.extend(item.block.stmts.into_iter());
@@ -193,7 +193,10 @@ pub fn add_fn_to_module(
193193
let wrapper = function_c_wrapper(&func.sig.ident, &spec);
194194

195195
Ok(quote! {
196-
fn #function_wrapper_ident(py: pyo3::Python) -> pyo3::PyObject {
196+
fn #function_wrapper_ident<'a, 'b>(
197+
args: impl pyo3::derive_utils::WrapPyFunctionArguments<'a, 'b>
198+
) -> pyo3::PyObject {
199+
let (py, maybe_module) = args.arguments();
197200
#wrapper
198201

199202
let _def = pyo3::class::PyMethodDef {
@@ -203,12 +206,25 @@ pub fn add_fn_to_module(
203206
ml_doc: #doc,
204207
};
205208

209+
let (mod_ptr, name) = if let Some(m) = maybe_module {
210+
let mod_ptr = <pyo3::types::PyModule as ::pyo3::conversion::AsPyPointer>::as_ptr(m);
211+
let name = unsafe { pyo3::ffi::PyModule_GetNameObject(mod_ptr) };
212+
if name.is_null() {
213+
let err = PyErr::fetch(py);
214+
return <PyErr as pyo3::conversion::IntoPy<PyObject>>::into_py(err, py);
215+
}
216+
(mod_ptr, name)
217+
} else {
218+
(std::ptr::null_mut(), std::ptr::null_mut())
219+
};
220+
206221
let function = unsafe {
207222
pyo3::PyObject::from_owned_ptr(
208223
py,
209-
pyo3::ffi::PyCFunction_New(
224+
pyo3::ffi::PyCFunction_NewEx(
210225
Box::into_raw(Box::new(_def.as_method_def())),
211-
::std::ptr::null_mut()
226+
mod_ptr,
227+
name
212228
)
213229
)
214230
};

src/derive_utils.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,24 @@ where
207207
<R as std::convert::TryFrom<&'a PyCell<T>>>::try_from(cell)
208208
}
209209
}
210+
211+
/// Trait to abstract over the arguments of Python function wrappers.
212+
#[doc(hidden)]
213+
pub trait WrapPyFunctionArguments<'a, 'b> {
214+
fn arguments(self) -> (Python<'b>, Option<&'a PyModule>);
215+
}
216+
217+
impl<'a, 'b> WrapPyFunctionArguments<'a, 'b> for Python<'b> {
218+
fn arguments(self) -> (Python<'b>, Option<&'a PyModule>) {
219+
(self, None)
220+
}
221+
}
222+
223+
impl<'a, 'b> WrapPyFunctionArguments<'a, 'b> for &'a PyModule
224+
where
225+
'a: 'b,
226+
{
227+
fn arguments(self) -> (Python<'b>, Option<&'a PyModule>) {
228+
(self.py(), Some(self))
229+
}
230+
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
//! #[pymodule]
7272
//! /// A Python module implemented in Rust.
7373
//! fn string_sum(py: Python, m: &PyModule) -> PyResult<()> {
74-
//! m.add_wrapped(wrap_pyfunction!(sum_as_string))?;
74+
//! m.add_function(wrap_pyfunction!(sum_as_string))?;
7575
//!
7676
//! Ok(())
7777
//! }

src/python.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ impl<'p> Python<'p> {
134134
/// let gil = Python::acquire_gil();
135135
/// let py = gil.python();
136136
/// let m = PyModule::new(py, "pcount").unwrap();
137-
/// m.add_wrapped(wrap_pyfunction!(parallel_count)).unwrap();
137+
/// m.add_function(wrap_pyfunction!(parallel_count)).unwrap();
138138
/// let locals = [("pcount", m)].into_py_dict(py);
139139
/// py.run(r#"
140140
/// s = ["Flow", "my", "tears", "the", "Policeman", "Said"]

src/types/module.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,50 @@ impl PyModule {
194194
/// ```rust,ignore
195195
/// m.add("also_double", wrap_pyfunction!(double)(py));
196196
/// ```
197+
///
198+
/// **This function will be deprecated in the next release. Please use the specific
199+
/// [add_function] and [add_module] functions instead.**
197200
pub fn add_wrapped(&self, wrapper: &impl Fn(Python) -> PyObject) -> PyResult<()> {
198201
let function = wrapper(self.py());
199202
let name = function
200203
.getattr(self.py(), "__name__")
201204
.expect("A function or module must have a __name__");
202205
self.add(name.extract(self.py()).unwrap(), function)
203206
}
207+
208+
/// Adds a (sub)module to a module.
209+
///
210+
/// Use this together with `#[pymodule]` and [wrap_pymodule!].
211+
///
212+
/// ```rust,ignore
213+
/// m.add_module(wrap_pymodule!(utils));
214+
/// ```
215+
pub fn add_module(&self, wrapper: &impl Fn(Python) -> PyObject) -> PyResult<()> {
216+
let function = wrapper(self.py());
217+
let name = function
218+
.getattr(self.py(), "__name__")
219+
.expect("A module must have a __name__");
220+
self.add(name.extract(self.py()).unwrap(), function)
221+
}
222+
223+
/// Adds a function to a module, using the functions __name__ as name.
224+
///
225+
/// Use this together with the`#[pyfunction]` and [wrap_pyfunction!].
226+
///
227+
/// ```rust,ignore
228+
/// m.add_function(wrap_pyfunction!(double));
229+
/// ```
230+
///
231+
/// You can also add a function with a custom name using [add](PyModule::add):
232+
///
233+
/// ```rust,ignore
234+
/// m.add("also_double", wrap_pyfunction!(double)(py, m));
235+
/// ```
236+
pub fn add_function<'a>(&'a self, wrapper: &impl Fn(&'a Self) -> PyObject) -> PyResult<()> {
237+
let function = wrapper(self);
238+
let name = function
239+
.getattr(self.py(), "__name__")
240+
.expect("A function or module must have a __name__");
241+
self.add(name.extract(self.py()).unwrap(), function)
242+
}
204243
}

tests/test_module.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ fn double(x: usize) -> usize {
3535

3636
/// This module is implemented in Rust.
3737
#[pymodule]
38-
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
38+
fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> {
3939
use pyo3::wrap_pyfunction;
4040

4141
#[pyfn(m, "sum_as_string")]
@@ -60,8 +60,8 @@ fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
6060

6161
m.add("foo", "bar").unwrap();
6262

63-
m.add_wrapped(wrap_pyfunction!(double)).unwrap();
64-
m.add("also_double", wrap_pyfunction!(double)(py)).unwrap();
63+
m.add_function(wrap_pyfunction!(double)).unwrap();
64+
m.add("also_double", wrap_pyfunction!(double)(m)).unwrap();
6565

6666
Ok(())
6767
}
@@ -157,7 +157,7 @@ fn r#move() -> usize {
157157
fn raw_ident_module(_py: Python, module: &PyModule) -> PyResult<()> {
158158
use pyo3::wrap_pyfunction;
159159

160-
module.add_wrapped(wrap_pyfunction!(r#move))
160+
module.add_function(wrap_pyfunction!(r#move))
161161
}
162162

163163
#[test]
@@ -182,7 +182,7 @@ fn custom_named_fn() -> usize {
182182
fn foobar_module(_py: Python, m: &PyModule) -> PyResult<()> {
183183
use pyo3::wrap_pyfunction;
184184

185-
m.add_wrapped(wrap_pyfunction!(custom_named_fn))?;
185+
m.add_function(wrap_pyfunction!(custom_named_fn))?;
186186
m.dict().set_item("yay", "me")?;
187187
Ok(())
188188
}
@@ -216,7 +216,7 @@ fn subfunction() -> String {
216216
fn submodule(_py: Python, module: &PyModule) -> PyResult<()> {
217217
use pyo3::wrap_pyfunction;
218218

219-
module.add_wrapped(wrap_pyfunction!(subfunction))?;
219+
module.add_function(wrap_pyfunction!(subfunction))?;
220220
Ok(())
221221
}
222222

@@ -229,8 +229,8 @@ fn superfunction() -> String {
229229
fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> {
230230
use pyo3::{wrap_pyfunction, wrap_pymodule};
231231

232-
module.add_wrapped(wrap_pyfunction!(superfunction))?;
233-
module.add_wrapped(wrap_pymodule!(submodule))?;
232+
module.add_function(wrap_pyfunction!(superfunction))?;
233+
module.add_module(wrap_pymodule!(submodule))?;
234234
Ok(())
235235
}
236236

@@ -268,7 +268,7 @@ fn vararg_module(_py: Python, m: &PyModule) -> PyResult<()> {
268268
ext_vararg_fn(py, a, vararg)
269269
}
270270

271-
m.add_wrapped(pyo3::wrap_pyfunction!(ext_vararg_fn))
271+
m.add_function(pyo3::wrap_pyfunction!(ext_vararg_fn))
272272
.unwrap();
273273
Ok(())
274274
}

0 commit comments

Comments
 (0)