Skip to content

Commit

Permalink
add option to import modules into the main interpreter
Browse files Browse the repository at this point in the history
- some modules crash the subinterpreters if this doesn't happen
- add pytest as a default (as it is required for it)
- re-enable the pytest mark tests
  • Loading branch information
brownben committed Nov 27, 2024
1 parent 74070f4 commit fe2405d
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 8 deletions.
12 changes: 12 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@ pub(crate) struct Settings {
#[clap(long, default_value_t = false)]
pub no_fail_fast: bool,

/// List of (comma separated) modules to try to import in the main interpreter.
///
/// Some modules need imported in the main interpreter, else they crash the subinterpreters.
#[clap(
long,
value_name = "MODULE_NAME",
num_args = 1..,
value_delimiter = ',',
default_value = "pytest",
)]
pub known_imports: Vec<String>,

/// How test results should be reported
#[clap(long, value_enum, default_value_t = OutputFormat::Standard)]
pub output: OutputFormat,
Expand Down
6 changes: 5 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ fn main() -> ExitCode {
interpreter.with_gil(|python| {
// The decimal module crashes Python 3.12 if it is initialised multiple times
// If not initialised in the base interpreter, if a subinterpreter imports it it will crash
_ = python.import_module(c"decimal");
_ = python.import_known_module(c"decimal");

for module in &settings.known_imports {
_ = python.import_module(module);
}
});

// Run tests
Expand Down
1 change: 1 addition & 0 deletions src/python/objects.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use pyo3_ffi::{self as ffi};
use std::{ffi::CStr, fmt, marker::PhantomData, ops::Deref, ptr::NonNull};

/// Represents a Python object
#[must_use]
pub struct PyObject(NonNull<ffi::PyObject>);
impl PyObject {
/// Gets the underlying pointer for the object
Expand Down
23 changes: 17 additions & 6 deletions src/python/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ impl ActiveInterpreter {
unsafe { PyObject::from_ptr_unchecked(result) }
}

/// Imports a module
/// Imports a known module
///
/// SAFETY: Assumes that the module exists
pub fn import_module(&self, module: &CStr) -> PyObject {
pub fn import_known_module(&self, module: &CStr) -> PyObject {
let result = unsafe { ffi::PyImport_ImportModule(module.as_ptr().cast()) };

debug_assert!(!result.is_null());
Expand All @@ -40,12 +40,23 @@ impl ActiveInterpreter {
object
}

/// Imports a module which may not exist
#[must_use]
pub fn import_module(&self, module: &str) -> Option<PyObject> {
let module = self.new_string(module);
let result = unsafe { ffi::PyImport_Import(module.as_ptr()) };
if result.is_null() {
PyError::clear();
}
PyObject::from_ptr(result)
}

/// Redirect stdout and stderr from Python into a string
///
/// Captured output can be fetched by [`Self::get_captured_output`]
pub fn capture_output(&self) {
let sys = self.import_module(c"sys");
let io = self.import_module(c"io");
let sys = self.import_known_module(c"sys");
let io = self.import_known_module(c"io");

let string_io = io.get_attr(&self.new_string("StringIO")).unwrap();
let stdout_io = unsafe { string_io.call_unchecked() };
Expand All @@ -57,7 +68,7 @@ impl ActiveInterpreter {

/// Get the captured stdout and stderr
pub fn get_captured_output(&self) -> (Option<String>, Option<String>) {
let sys = self.import_module(c"sys");
let sys = self.import_known_module(c"sys");

let stdout = sys.get_attr(&self.new_string("stdout")).unwrap();
let stderr = sys.get_attr(&self.new_string("stderr")).unwrap();
Expand All @@ -84,7 +95,7 @@ impl ActiveInterpreter {
/// Most commonly used to add the current folder to the module search path.
/// Assumes Python Interpreter is currently active.
pub fn add_to_sys_modules_path(&self, path: &CStr) {
let sys = self.import_module(c"sys");
let sys = self.import_known_module(c"sys");
let path_list = sys.get_attr(&self.new_string("path")).unwrap();

unsafe {
Expand Down
1 change: 1 addition & 0 deletions tests/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ execution_test!(invalid_remove);
execution_test!(isolated);
#[cfg(feature = "ci")] // Takes a long time, so don't want it slowing down developement cycles
execution_test!(long_running);
execution_test!(pytest_marks);
execution_test!(setup_teardown);
execution_test!(skip_tests);
execution_test!(times); // No tests are in this file, just a standard python file
5 changes: 4 additions & 1 deletion tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Requirements for ./execution/dependencies.txt
# Requirements for ./execution/dependencies.py
astroid
yarl

# Requirements for ./execution/pytest_marks.py
pytest

0 comments on commit fe2405d

Please sign in to comment.