Skip to content

Commit 1afcfc2

Browse files
committed
WIP declarative module
1 parent ff6cce9 commit 1afcfc2

File tree

7 files changed

+205
-52
lines changed

7 files changed

+205
-52
lines changed

pyo3-macros-backend/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ mod pyproto;
2929
mod wrap;
3030

3131
pub use frompyobject::build_derive_from_pyobject;
32-
pub use module::{process_functions_in_module, pymodule_impl, PyModuleOptions};
32+
pub use module::{
33+
process_functions_in_module, pymodule_function_impl, pymodule_module_impl, PyModuleOptions,
34+
};
3335
pub use pyclass::{build_py_class, build_py_enum, PyClassArgs};
3436
pub use pyfunction::{build_py_function, PyFunctionOptions};
3537
pub use pyimpl::{build_py_methods, PyClassMethodsType};

pyo3-macros-backend/src/module.rs

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::{
55
attributes::{
66
self, is_attribute_ident, take_attributes, take_pyo3_options, CrateAttribute, NameAttribute,
77
},
8+
get_doc,
89
pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
910
utils::{get_pyo3_crate, PythonDoc},
1011
};
@@ -59,9 +60,104 @@ impl PyModuleOptions {
5960
}
6061
}
6162

63+
pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
64+
let syn::ItemMod {
65+
attrs,
66+
vis,
67+
ident,
68+
mod_token,
69+
content,
70+
semi: _,
71+
} = &mut module;
72+
let items = match content {
73+
Some((_, items)) => items,
74+
None => bail_spanned!(module.span() => "`#[pymodule]` can only be used on inline modules"),
75+
};
76+
let options = PyModuleOptions::from_attrs(attrs)?;
77+
let doc = get_doc(attrs, None);
78+
79+
let name = options.name.unwrap_or_else(|| ident.unraw());
80+
let krate = get_pyo3_crate(&options.krate);
81+
let pyinit_symbol = format!("PyInit_{}", name);
82+
83+
let mut module_items = Vec::new();
84+
85+
fn extract_use_items(source: &syn::UseTree, target: &mut Vec<Ident>) {
86+
match source {
87+
syn::UseTree::Name(name) => target.push(name.ident.clone()),
88+
syn::UseTree::Path(path) => extract_use_items(&path.tree, target),
89+
syn::UseTree::Group(group) => {
90+
for tree in &group.items {
91+
extract_use_items(tree, target)
92+
}
93+
}
94+
x @ _ => {
95+
dbg!(x);
96+
todo!()
97+
}
98+
}
99+
}
100+
101+
let mut pymodule_init = None;
102+
103+
for item in items.iter_mut() {
104+
match item {
105+
syn::Item::Use(item_use) => {
106+
let mut is_pyo3 = false;
107+
item_use.attrs.retain(|attr| {
108+
let found = attr.path.is_ident("pyo3");
109+
is_pyo3 |= found;
110+
!found
111+
});
112+
if is_pyo3 {
113+
extract_use_items(&item_use.tree, &mut module_items);
114+
}
115+
}
116+
syn::Item::Fn(item_fn) => {
117+
let mut is_module_init = false;
118+
item_fn.attrs.retain(|attr| {
119+
let found = attr.path.is_ident("pymodule_init");
120+
is_module_init |= found;
121+
!found
122+
});
123+
if is_module_init {
124+
ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one pymodule_init may be specified");
125+
let ident = &item_fn.sig.ident;
126+
pymodule_init = Some(quote! { #ident(module)?; });
127+
}
128+
}
129+
_ => {}
130+
}
131+
}
132+
133+
Ok(quote! {
134+
#vis #mod_token #ident {
135+
#(#items)*
136+
137+
pub static DEF: #krate::impl_::pymodule::ModuleDef = unsafe {
138+
use #krate::impl_::pymodule as impl_;
139+
impl_::ModuleDef::new(concat!(stringify!(#name), "\0"), #doc, impl_::ModuleInitializer(__pyo3_pymodule))
140+
};
141+
142+
pub fn __pyo3_pymodule(_py: #krate::Python, module: &#krate::types::PyModule) -> #krate::PyResult<()> {
143+
#(#module_items::DEF.add_to_module(module)?;)*
144+
#pymodule_init
145+
Ok(())
146+
}
147+
148+
/// This autogenerated function is called by the python interpreter when importing
149+
/// the module.
150+
#[export_name = #pyinit_symbol]
151+
pub unsafe extern "C" fn __pyo3_init() -> *mut #krate::ffi::PyObject {
152+
DEF.module_init()
153+
}
154+
}
155+
})
156+
}
157+
62158
/// Generates the function that is called by the python interpreter to initialize the native
63159
/// module
64-
pub fn pymodule_impl(
160+
pub fn pymodule_function_impl(
65161
fnname: &Ident,
66162
options: PyModuleOptions,
67163
doc: PythonDoc,

pyo3-macros/src/lib.rs

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ use proc_macro::TokenStream;
99
use proc_macro2::TokenStream as TokenStream2;
1010
use pyo3_macros_backend::{
1111
build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods,
12-
get_doc, process_functions_in_module, pymodule_impl, wrap_pyfunction_impl, wrap_pymodule_impl,
13-
PyClassArgs, PyClassMethodsType, PyFunctionOptions, PyModuleOptions, WrapPyFunctionArgs,
12+
get_doc, process_functions_in_module, pymodule_function_impl, pymodule_module_impl,
13+
wrap_pyfunction_impl, wrap_pymodule_impl, PyClassArgs, PyClassMethodsType, PyFunctionOptions,
14+
PyModuleOptions, WrapPyFunctionArgs,
1415
};
1516
use quote::quote;
1617
use syn::{parse::Nothing, parse_macro_input};
@@ -35,25 +36,30 @@ use syn::{parse::Nothing, parse_macro_input};
3536
pub fn pymodule(args: TokenStream, input: TokenStream) -> TokenStream {
3637
parse_macro_input!(args as Nothing);
3738

38-
let mut ast = parse_macro_input!(input as syn::ItemFn);
39-
let options = match PyModuleOptions::from_attrs(&mut ast.attrs) {
40-
Ok(options) => options,
41-
Err(e) => return e.into_compile_error().into(),
42-
};
43-
44-
if let Err(err) = process_functions_in_module(&mut ast) {
45-
return err.into_compile_error().into();
46-
}
39+
if let Ok(module) = syn::parse(input.clone()) {
40+
pymodule_module_impl(module)
41+
.unwrap_or_compile_error()
42+
.into()
43+
} else {
44+
let mut ast = parse_macro_input!(input as syn::ItemFn);
45+
let options = match PyModuleOptions::from_attrs(&mut ast.attrs) {
46+
Ok(options) => options,
47+
Err(e) => return e.into_compile_error().into(),
48+
};
4749

48-
let doc = get_doc(&ast.attrs, None);
50+
if let Err(err) = process_functions_in_module(&mut ast) {
51+
return err.into_compile_error().into();
52+
}
4953

50-
let expanded = pymodule_impl(&ast.sig.ident, options, doc, &ast.vis);
54+
let doc = get_doc(&ast.attrs, None);
5155

52-
quote!(
53-
#ast
54-
#expanded
55-
)
56-
.into()
56+
let expanded = pymodule_function_impl(&ast.sig.ident, options, doc, &ast.vis);
57+
quote!(
58+
#ast
59+
#expanded
60+
)
61+
.into()
62+
}
5763
}
5864

5965
/// A proc macro used to implement Python's [dunder methods][1].

pytests/src/lib.rs

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
use pyo3::prelude::*;
2-
use pyo3::types::PyDict;
3-
use pyo3::wrap_pymodule;
42

53
pub mod buf_and_str;
64
pub mod datetime;
@@ -14,35 +12,40 @@ pub mod pyfunctions;
1412
pub mod subclassing;
1513

1614
#[pymodule]
17-
fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {
18-
#[cfg(not(Py_LIMITED_API))]
19-
m.add_wrapped(wrap_pymodule!(buf_and_str::buf_and_str))?;
20-
#[cfg(not(Py_LIMITED_API))]
21-
m.add_wrapped(wrap_pymodule!(datetime::datetime))?;
22-
m.add_wrapped(wrap_pymodule!(dict_iter::dict_iter))?;
23-
m.add_wrapped(wrap_pymodule!(misc::misc))?;
24-
m.add_wrapped(wrap_pymodule!(objstore::objstore))?;
25-
m.add_wrapped(wrap_pymodule!(othermod::othermod))?;
26-
m.add_wrapped(wrap_pymodule!(path::path))?;
27-
m.add_wrapped(wrap_pymodule!(pyclasses::pyclasses))?;
28-
m.add_wrapped(wrap_pymodule!(pyfunctions::pyfunctions))?;
29-
m.add_wrapped(wrap_pymodule!(subclassing::subclassing))?;
15+
mod pyo3_pytests {
16+
use pyo3::types::{PyDict, PyModule};
17+
use pyo3::PyResult;
3018

31-
// Inserting to sys.modules allows importing submodules nicely from Python
32-
// e.g. import pyo3_pytests.buf_and_str as bas
19+
#[pyo3]
20+
use {
21+
// #[cfg(not(Py_LIMITED_API))]
22+
crate::buf_and_str::buf_and_str,
23+
// #[cfg(not(Py_LIMITED_API))]
24+
crate::datetime::datetime,
25+
crate::dict_iter::dict_iter,
26+
crate::misc::misc,
27+
crate::objstore::objstore,
28+
crate::othermod::othermod,
29+
crate::path::path,
30+
crate::pyclasses::pyclasses,
31+
crate::pyfunctions::pyfunctions,
32+
crate::subclassing::subclassing,
33+
};
3334

34-
let sys = PyModule::import(py, "sys")?;
35-
let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?;
36-
sys_modules.set_item("pyo3_pytests.buf_and_str", m.getattr("buf_and_str")?)?;
37-
sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?;
38-
sys_modules.set_item("pyo3_pytests.dict_iter", m.getattr("dict_iter")?)?;
39-
sys_modules.set_item("pyo3_pytests.misc", m.getattr("misc")?)?;
40-
sys_modules.set_item("pyo3_pytests.objstore", m.getattr("objstore")?)?;
41-
sys_modules.set_item("pyo3_pytests.othermod", m.getattr("othermod")?)?;
42-
sys_modules.set_item("pyo3_pytests.path", m.getattr("path")?)?;
43-
sys_modules.set_item("pyo3_pytests.pyclasses", m.getattr("pyclasses")?)?;
44-
sys_modules.set_item("pyo3_pytests.pyfunctions", m.getattr("pyfunctions")?)?;
45-
sys_modules.set_item("pyo3_pytests.subclassing", m.getattr("subclassing")?)?;
46-
47-
Ok(())
35+
#[pymodule_init]
36+
fn init(m: &PyModule) -> PyResult<()> {
37+
let sys = PyModule::import(m.py(), "sys")?;
38+
let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?;
39+
sys_modules.set_item("pyo3_pytests.buf_and_str", m.getattr("buf_and_str")?)?;
40+
sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?;
41+
sys_modules.set_item("pyo3_pytests.dict_iter", m.getattr("dict_iter")?)?;
42+
sys_modules.set_item("pyo3_pytests.misc", m.getattr("misc")?)?;
43+
sys_modules.set_item("pyo3_pytests.objstore", m.getattr("objstore")?)?;
44+
sys_modules.set_item("pyo3_pytests.othermod", m.getattr("othermod")?)?;
45+
sys_modules.set_item("pyo3_pytests.path", m.getattr("path")?)?;
46+
sys_modules.set_item("pyo3_pytests.pyclasses", m.getattr("pyclasses")?)?;
47+
sys_modules.set_item("pyo3_pytests.pyfunctions", m.getattr("pyfunctions")?)?;
48+
sys_modules.set_item("pyo3_pytests.subclassing", m.getattr("subclassing")?)?;
49+
Ok(())
50+
}
4851
}

src/impl_/pymodule.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ impl ModuleDef {
9090
}),
9191
)
9292
}
93+
94+
pub fn add_to_module(&'static self, module: &PyModule) -> PyResult<()> {
95+
module.add_object(self.make_module(module.py())?)
96+
}
9397
}
9498

9599
#[cfg(test)]

src/types/module.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
// based on Daniel Grunwald's https://github.com/dgrunwald/rust-cpython
44

55
use crate::callback::IntoPyCallbackOutput;
6-
use crate::err::{PyErr, PyResult};
6+
use crate::err::{self, PyErr, PyResult};
77
use crate::exceptions;
88
use crate::ffi;
99
use crate::pyclass::PyClass;
@@ -252,6 +252,16 @@ impl PyModule {
252252
self.setattr(name, value.into_py(self.py()))
253253
}
254254

255+
pub(crate) fn add_object(&self, value: PyObject) -> PyResult<()> {
256+
let py = self.py();
257+
let attr_name = value.getattr(py, "__name__")?;
258+
259+
unsafe {
260+
let ret = ffi::PyObject_SetAttr(self.as_ptr(), attr_name.as_ptr(), value.as_ptr());
261+
err::error_on_minusone(py, ret)
262+
}
263+
}
264+
255265
/// Adds a new class to the module.
256266
///
257267
/// Notice that this method does not take an argument.

tests/test_module.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,3 +443,35 @@ fn test_module_doc_hidden() {
443443
py_assert!(py, m, "m.__doc__ == ''");
444444
})
445445
}
446+
447+
/// A module written using declarative syntax.
448+
#[pymodule]
449+
mod declarative_module {
450+
451+
#[pyo3]
452+
use super::module_with_functions;
453+
}
454+
455+
#[test]
456+
fn test_declarative_module() {
457+
Python::with_gil(|py| {
458+
let m = pyo3::wrap_pymodule!(declarative_module)(py).into_ref(py);
459+
py_assert!(
460+
py,
461+
m,
462+
"m.__doc__ == 'A module written using declarative syntax.'"
463+
);
464+
465+
let submodule = m.getattr("module_with_functions").unwrap();
466+
assert_eq!(
467+
submodule
468+
.getattr("no_parameters")
469+
.unwrap()
470+
.call0()
471+
.unwrap()
472+
.extract::<i32>()
473+
.unwrap(),
474+
42
475+
);
476+
})
477+
}

0 commit comments

Comments
 (0)