Skip to content

Commit 449e917

Browse files
davidhewittbirkenfeld
authored andcommitted
WIP declarative module
1 parent 849b699 commit 449e917

File tree

7 files changed

+219
-56
lines changed

7 files changed

+219
-56
lines changed

pyo3-macros-backend/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ mod pyimpl;
2121
mod pymethod;
2222

2323
pub use frompyobject::build_derive_from_pyobject;
24-
pub use module::{process_functions_in_module, pymodule_impl, PyModuleOptions};
24+
pub use module::{
25+
process_functions_in_module, pymodule_function_impl, pymodule_module_impl, PyModuleOptions,
26+
};
2527
pub use pyclass::{build_py_class, build_py_enum, PyClassArgs};
2628
pub use pyfunction::{build_py_function, PyFunctionOptions};
2729
pub use pyimpl::{build_py_methods, PyClassMethodsType};

pyo3-macros-backend/src/module.rs

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use crate::{
44
attributes::{self, take_attributes, take_pyo3_options, CrateAttribute, NameAttribute},
5+
get_doc,
56
pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
67
utils::{get_pyo3_crate, PythonDoc},
78
};
@@ -56,9 +57,118 @@ impl PyModuleOptions {
5657
}
5758
}
5859

60+
pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
61+
let syn::ItemMod {
62+
attrs,
63+
vis,
64+
ident,
65+
mod_token,
66+
content,
67+
unsafety: _,
68+
semi: _,
69+
} = &mut module;
70+
let items = match content {
71+
Some((_, items)) => items,
72+
None => bail_spanned!(module.span() => "`#[pymodule]` can only be used on inline modules"),
73+
};
74+
let options = PyModuleOptions::from_attrs(attrs)?;
75+
let doc = get_doc(attrs, None);
76+
77+
let name = options.name.unwrap_or_else(|| ident.unraw());
78+
let krate = get_pyo3_crate(&options.krate);
79+
let pyinit_symbol = format!("PyInit_{}", name);
80+
81+
let mut module_items = Vec::new();
82+
let mut module_attrs = Vec::new();
83+
84+
fn extract_use_items(source: &syn::UseTree, cfg_attrs: &Vec<syn::Attribute>,
85+
names: &mut Vec<Ident>, attrs: &mut Vec<Vec<syn::Attribute>>) -> Result<()> {
86+
match source {
87+
syn::UseTree::Name(name) => {
88+
names.push(name.ident.clone());
89+
attrs.push(cfg_attrs.clone());
90+
}
91+
syn::UseTree::Path(path) => extract_use_items(&path.tree, cfg_attrs, names, attrs)?,
92+
syn::UseTree::Group(group) => {
93+
for tree in &group.items {
94+
extract_use_items(tree, cfg_attrs, names, attrs)?
95+
}
96+
}
97+
syn::UseTree::Glob(glob) => {
98+
bail_spanned!(glob.span() => "#[pyo3] cannot import glob statements")
99+
}
100+
syn::UseTree::Rename(rename) => {
101+
names.push(rename.ident.clone());
102+
attrs.push(cfg_attrs.clone());
103+
}
104+
}
105+
Ok(())
106+
}
107+
108+
let mut pymodule_init = None;
109+
110+
for item in items.iter_mut() {
111+
match item {
112+
syn::Item::Use(item_use) => {
113+
let mut is_pyo3 = false;
114+
item_use.attrs.retain(|attr| {
115+
let found = attr.path().is_ident("pyo3");
116+
is_pyo3 |= found;
117+
!found
118+
});
119+
if is_pyo3 {
120+
let cfg_attrs: Vec<_> = item_use.attrs.iter().filter(|attr| attr.path().is_ident("cfg")).map(Clone::clone).collect();
121+
extract_use_items(&item_use.tree, &cfg_attrs, &mut module_items, &mut module_attrs)?;
122+
}
123+
}
124+
syn::Item::Fn(item_fn) => {
125+
let mut is_module_init = false;
126+
item_fn.attrs.retain(|attr| {
127+
let found = attr.path().is_ident("pymodule_init");
128+
is_module_init |= found;
129+
!found
130+
});
131+
if is_module_init {
132+
ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one pymodule_init may be specified");
133+
let ident = &item_fn.sig.ident;
134+
pymodule_init = Some(quote! { #ident(module)?; });
135+
}
136+
}
137+
_ => {}
138+
}
139+
}
140+
141+
Ok(quote! {
142+
#vis #mod_token #ident {
143+
#(#items)*
144+
145+
pub static DEF: #krate::impl_::pymodule::ModuleDef = unsafe {
146+
use #krate::impl_::pymodule as impl_;
147+
impl_::ModuleDef::new(concat!(stringify!(#name), "\0"), #doc, impl_::ModuleInitializer(__pyo3_pymodule))
148+
};
149+
150+
pub fn __pyo3_pymodule(_py: #krate::Python, module: &#krate::types::PyModule) -> #krate::PyResult<()> {
151+
#(
152+
#(#module_attrs)*
153+
#module_items::DEF.add_to_module(module)?;
154+
)*
155+
#pymodule_init
156+
Ok(())
157+
}
158+
159+
/// This autogenerated function is called by the python interpreter when importing
160+
/// the module.
161+
#[export_name = #pyinit_symbol]
162+
pub unsafe extern "C" fn __pyo3_init() -> *mut #krate::ffi::PyObject {
163+
#krate::impl_::trampoline::module_init(|py| DEF.make_module(py))
164+
}
165+
}
166+
})
167+
}
168+
59169
/// Generates the function that is called by the python interpreter to initialize the native
60170
/// module
61-
pub fn pymodule_impl(
171+
pub fn pymodule_function_impl(
62172
fnname: &Ident,
63173
options: PyModuleOptions,
64174
doc: PythonDoc,

pyo3-macros/src/lib.rs

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ use proc_macro::TokenStream;
88
use proc_macro2::TokenStream as TokenStream2;
99
use pyo3_macros_backend::{
1010
build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods,
11-
get_doc, process_functions_in_module, pymodule_impl, PyClassArgs, PyClassMethodsType,
12-
PyFunctionOptions, PyModuleOptions,
11+
get_doc, process_functions_in_module, pymodule_function_impl, pymodule_module_impl,
12+
PyClassArgs, PyClassMethodsType, PyFunctionOptions, PyModuleOptions,
1313
};
1414
use quote::quote;
1515
use syn::{parse::Nothing, parse_macro_input};
@@ -39,25 +39,30 @@ use syn::{parse::Nothing, parse_macro_input};
3939
pub fn pymodule(args: TokenStream, input: TokenStream) -> TokenStream {
4040
parse_macro_input!(args as Nothing);
4141

42-
let mut ast = parse_macro_input!(input as syn::ItemFn);
43-
let options = match PyModuleOptions::from_attrs(&mut ast.attrs) {
44-
Ok(options) => options,
45-
Err(e) => return e.into_compile_error().into(),
46-
};
47-
48-
if let Err(err) = process_functions_in_module(&options, &mut ast) {
49-
return err.into_compile_error().into();
50-
}
51-
52-
let doc = get_doc(&ast.attrs, None);
42+
if let Ok(module) = syn::parse(input.clone()) {
43+
pymodule_module_impl(module)
44+
.unwrap_or_compile_error()
45+
.into()
46+
} else {
47+
let mut ast = parse_macro_input!(input as syn::ItemFn);
48+
let options = match PyModuleOptions::from_attrs(&mut ast.attrs) {
49+
Ok(options) => options,
50+
Err(e) => return e.into_compile_error().into(),
51+
};
52+
53+
if let Err(err) = process_functions_in_module(&options, &mut ast) {
54+
return err.into_compile_error().into();
55+
}
5356

54-
let expanded = pymodule_impl(&ast.sig.ident, options, doc, &ast.vis);
57+
let doc = get_doc(&ast.attrs, None);
5558

56-
quote!(
57-
#ast
58-
#expanded
59-
)
60-
.into()
59+
let expanded = pymodule_function_impl(&ast.sig.ident, options, doc, &ast.vis);
60+
quote!(
61+
#ast
62+
#expanded
63+
)
64+
.into()
65+
}
6166
}
6267

6368
#[proc_macro_attribute]

pytests/src/lib.rs

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
use pyo3::prelude::*;
2-
use pyo3::types::PyDict;
3-
use pyo3::wrap_pymodule;
42

53
pub mod buf_and_str;
64
pub mod comparisons;
@@ -16,39 +14,41 @@ pub mod sequence;
1614
pub mod subclassing;
1715

1816
#[pymodule]
19-
fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {
20-
#[cfg(not(Py_LIMITED_API))]
21-
m.add_wrapped(wrap_pymodule!(buf_and_str::buf_and_str))?;
22-
m.add_wrapped(wrap_pymodule!(comparisons::comparisons))?;
23-
#[cfg(not(Py_LIMITED_API))]
24-
m.add_wrapped(wrap_pymodule!(datetime::datetime))?;
25-
m.add_wrapped(wrap_pymodule!(dict_iter::dict_iter))?;
26-
m.add_wrapped(wrap_pymodule!(misc::misc))?;
27-
m.add_wrapped(wrap_pymodule!(objstore::objstore))?;
28-
m.add_wrapped(wrap_pymodule!(othermod::othermod))?;
29-
m.add_wrapped(wrap_pymodule!(path::path))?;
30-
m.add_wrapped(wrap_pymodule!(pyclasses::pyclasses))?;
31-
m.add_wrapped(wrap_pymodule!(pyfunctions::pyfunctions))?;
32-
m.add_wrapped(wrap_pymodule!(sequence::sequence))?;
33-
m.add_wrapped(wrap_pymodule!(subclassing::subclassing))?;
17+
mod pyo3_pytests {
18+
use pyo3::types::{PyDict, PyModule};
19+
use pyo3::PyResult;
3420

35-
// Inserting to sys.modules allows importing submodules nicely from Python
36-
// e.g. import pyo3_pytests.buf_and_str as bas
21+
#[pyo3]
22+
use {
23+
crate::comparisons::comparisons, crate::dict_iter::dict_iter, crate::misc::misc,
24+
crate::objstore::objstore, crate::othermod::othermod, crate::path::path,
25+
crate::pyclasses::pyclasses, crate::pyfunctions::pyfunctions, crate::sequence::sequence,
26+
crate::subclassing::subclassing,
27+
};
28+
29+
#[pyo3]
30+
#[cfg(not(Py_LIMITED_API))]
31+
use {crate::buf_and_str::buf_and_str, crate::datetime::datetime};
3732

38-
let sys = PyModule::import(py, "sys")?;
39-
let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?;
40-
sys_modules.set_item("pyo3_pytests.buf_and_str", m.getattr("buf_and_str")?)?;
41-
sys_modules.set_item("pyo3_pytests.comparisons", m.getattr("comparisons")?)?;
42-
sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?;
43-
sys_modules.set_item("pyo3_pytests.dict_iter", m.getattr("dict_iter")?)?;
44-
sys_modules.set_item("pyo3_pytests.misc", m.getattr("misc")?)?;
45-
sys_modules.set_item("pyo3_pytests.objstore", m.getattr("objstore")?)?;
46-
sys_modules.set_item("pyo3_pytests.othermod", m.getattr("othermod")?)?;
47-
sys_modules.set_item("pyo3_pytests.path", m.getattr("path")?)?;
48-
sys_modules.set_item("pyo3_pytests.pyclasses", m.getattr("pyclasses")?)?;
49-
sys_modules.set_item("pyo3_pytests.pyfunctions", m.getattr("pyfunctions")?)?;
50-
sys_modules.set_item("pyo3_pytests.sequence", m.getattr("sequence")?)?;
51-
sys_modules.set_item("pyo3_pytests.subclassing", m.getattr("subclassing")?)?;
33+
#[pymodule_init]
34+
fn init(m: &PyModule) -> PyResult<()> {
35+
let sys = PyModule::import(m.py(), "sys")?;
36+
let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?;
37+
#[cfg(not(Py_LIMITED_API))]
38+
sys_modules.set_item("pyo3_pytests.buf_and_str", m.getattr("buf_and_str")?)?;
39+
sys_modules.set_item("pyo3_pytests.comparisons", m.getattr("comparisons")?)?;
40+
#[cfg(not(Py_LIMITED_API))]
41+
sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?;
42+
sys_modules.set_item("pyo3_pytests.dict_iter", m.getattr("dict_iter")?)?;
43+
sys_modules.set_item("pyo3_pytests.misc", m.getattr("misc")?)?;
44+
sys_modules.set_item("pyo3_pytests.objstore", m.getattr("objstore")?)?;
45+
sys_modules.set_item("pyo3_pytests.othermod", m.getattr("othermod")?)?;
46+
sys_modules.set_item("pyo3_pytests.path", m.getattr("path")?)?;
47+
sys_modules.set_item("pyo3_pytests.pyclasses", m.getattr("pyclasses")?)?;
48+
sys_modules.set_item("pyo3_pytests.pyfunctions", m.getattr("pyfunctions")?)?;
49+
sys_modules.set_item("pyo3_pytests.sequence", m.getattr("sequence")?)?;
50+
sys_modules.set_item("pyo3_pytests.subclassing", m.getattr("subclassing")?)?;
5251

53-
Ok(())
52+
Ok(())
53+
}
5454
}

src/impl_/pymodule.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ impl ModuleDef {
8282
(self.initializer.0)(py, module.as_ref(py))?;
8383
Ok(module)
8484
}
85+
86+
pub fn add_to_module(&'static self, module: &PyModule) -> PyResult<()> {
87+
module.add_object(self.make_module(module.py())?.into())
88+
}
8589
}
8690

8791
#[cfg(test)]

src/types/module.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::callback::IntoPyCallbackOutput;
2-
use crate::err::{PyErr, PyResult};
2+
use crate::err::{self, PyErr, PyResult};
33
use crate::exceptions;
44
use crate::ffi;
55
use crate::pyclass::PyClass;
@@ -248,6 +248,16 @@ impl PyModule {
248248
self.setattr(name, value.into_py(self.py()))
249249
}
250250

251+
pub(crate) fn add_object(&self, value: PyObject) -> PyResult<()> {
252+
let py = self.py();
253+
let attr_name = value.getattr(py, "__name__")?;
254+
255+
unsafe {
256+
let ret = ffi::PyObject_SetAttr(self.as_ptr(), attr_name.as_ptr(), value.as_ptr());
257+
err::error_on_minusone(py, ret)
258+
}
259+
}
260+
251261
/// Adds a new class to the module.
252262
///
253263
/// Notice that this method does not take an argument.

tests/test_module.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,3 +450,35 @@ fn test_module_doc_hidden() {
450450
py_assert!(py, m, "m.__doc__ == ''");
451451
})
452452
}
453+
454+
/// A module written using declarative syntax.
455+
#[pymodule]
456+
mod declarative_module {
457+
458+
#[pyo3]
459+
use super::module_with_functions;
460+
}
461+
462+
#[test]
463+
fn test_declarative_module() {
464+
Python::with_gil(|py| {
465+
let m = pyo3::wrap_pymodule!(declarative_module)(py).into_ref(py);
466+
py_assert!(
467+
py,
468+
m,
469+
"m.__doc__ == 'A module written using declarative syntax.'"
470+
);
471+
472+
let submodule = m.getattr("module_with_functions").unwrap();
473+
assert_eq!(
474+
submodule
475+
.getattr("no_parameters")
476+
.unwrap()
477+
.call0()
478+
.unwrap()
479+
.extract::<i32>()
480+
.unwrap(),
481+
42
482+
);
483+
})
484+
}

0 commit comments

Comments
 (0)