Skip to content

Commit acf7271

Browse files
authored
Merge pull request #1494 from PyO3/enhance-py-run
Extend py_run! to take locals dict and refactor tests using it
2 parents 6137e3a + 9b88a45 commit acf7271

10 files changed

+279
-210
lines changed

src/lib.rs

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -289,10 +289,10 @@ macro_rules! wrap_pymodule {
289289
/// # Example
290290
/// ```
291291
/// use pyo3::{prelude::*, py_run, types::PyList};
292-
/// let gil = Python::acquire_gil();
293-
/// let py = gil.python();
294-
/// let list = PyList::new(py, &[1, 2, 3]);
295-
/// py_run!(py, list, "assert list == [1, 2, 3]");
292+
/// Python::with_gil(|py| {
293+
/// let list = PyList::new(py, &[1, 2, 3]);
294+
/// py_run!(py, list, "assert list == [1, 2, 3]");
295+
/// });
296296
/// ```
297297
///
298298
/// You can use this macro to test pyfunctions or pyclasses quickly.
@@ -320,15 +320,33 @@ macro_rules! wrap_pymodule {
320320
/// (self.hour, self.minute, self.second)
321321
/// }
322322
/// }
323-
/// let gil = Python::acquire_gil();
324-
/// let py = gil.python();
325-
/// let time = PyCell::new(py, Time {hour: 8, minute: 43, second: 16}).unwrap();
326-
/// let time_as_tuple = (8, 43, 16);
327-
/// py_run!(py, time time_as_tuple, r#"
328-
/// assert time.hour == 8
329-
/// assert time.repl_japanese() == "8時43分16秒"
330-
/// assert time.as_tuple() == time_as_tuple
331-
/// "#);
323+
/// Python::with_gil(|py| {
324+
/// let time = PyCell::new(py, Time {hour: 8, minute: 43, second: 16}).unwrap();
325+
/// let time_as_tuple = (8, 43, 16);
326+
/// py_run!(py, time time_as_tuple, r#"
327+
/// assert time.hour == 8
328+
/// assert time.repl_japanese() == "8時43分16秒"
329+
/// assert time.as_tuple() == time_as_tuple
330+
/// "#);
331+
/// });
332+
/// ```
333+
///
334+
/// If you need to prepare the `locals` dict by yourself, you can pass it as `*locals`.
335+
///
336+
/// ```
337+
/// use pyo3::prelude::*;
338+
/// use pyo3::types::IntoPyDict;
339+
/// #[pyclass]
340+
/// struct MyClass {}
341+
/// #[pymethods]
342+
/// impl MyClass {
343+
/// #[new]
344+
/// fn new() -> Self { MyClass {} }
345+
/// }
346+
/// Python::with_gil(|py| {
347+
/// let locals = [("C", py.get_type::<MyClass>())].into_py_dict(py);
348+
/// pyo3::py_run!(py, *locals, "c = C()");
349+
/// });
332350
/// ```
333351
///
334352
/// **Note**
@@ -345,6 +363,12 @@ macro_rules! py_run {
345363
($py:expr, $($val:ident)+, $code:expr) => {{
346364
$crate::py_run_impl!($py, $($val)+, &$crate::unindent::unindent($code))
347365
}};
366+
($py:expr, *$dict:expr, $code:literal) => {{
367+
$crate::py_run_impl!($py, *$dict, $crate::indoc::indoc!($code))
368+
}};
369+
($py:expr, *$dict:expr, $code:expr) => {{
370+
$crate::py_run_impl!($py, *$dict, &$crate::unindent::unindent($code))
371+
}};
348372
}
349373

350374
#[macro_export]
@@ -355,8 +379,10 @@ macro_rules! py_run_impl {
355379
use $crate::types::IntoPyDict;
356380
use $crate::ToPyObject;
357381
let d = [$((stringify!($val), $val.to_object($py)),)+].into_py_dict($py);
358-
359-
if let Err(e) = $py.run($code, None, Some(d)) {
382+
$crate::py_run_impl!($py, *d, $code)
383+
}};
384+
($py:expr, *$dict:expr, $code:expr) => {{
385+
if let Err(e) = $py.run($code, None, Some($dict)) {
360386
e.print($py);
361387
// So when this c api function the last line called printed the error to stderr,
362388
// the output is only written into a buffer which is never flushed because we

tests/common.rs

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,43 @@
1-
//! Useful tips for writing tests:
2-
//! - Tests are run in parallel; There's still a race condition in test_owned with some other test
3-
//! - You need to use flush=True to get any output from print
1+
//! Some common macros for tests
42
53
#[macro_export]
64
macro_rules! py_assert {
7-
($py:expr, $val:ident, $assertion:expr) => {
8-
pyo3::py_run!($py, $val, concat!("assert ", $assertion))
5+
($py:expr, $($val:ident)+, $assertion:literal) => {
6+
pyo3::py_run!($py, $($val)+, concat!("assert ", $assertion))
7+
};
8+
($py:expr, *$dict:expr, $assertion:literal) => {
9+
pyo3::py_run!($py, *$dict, concat!("assert ", $assertion))
910
};
1011
}
1112

1213
#[macro_export]
1314
macro_rules! py_expect_exception {
14-
($py:expr, $val:ident, $code:expr, $err:ident) => {{
15+
// Case1: idents & no err_msg
16+
($py:expr, $($val:ident)+, $code:expr, $err:ident) => {{
1517
use pyo3::types::IntoPyDict;
16-
let d = [(stringify!($val), &$val)].into_py_dict($py);
17-
18-
let res = $py.run($code, None, Some(d));
18+
let d = [$((stringify!($val), $val.to_object($py)),)+].into_py_dict($py);
19+
py_expect_exception!($py, *d, $code, $err)
20+
}};
21+
// Case2: dict & no err_msg
22+
($py:expr, *$dict:expr, $code:expr, $err:ident) => {{
23+
let res = $py.run($code, None, Some($dict));
1924
let err = res.expect_err(&format!("Did not raise {}", stringify!($err)));
2025
if !err.matches($py, $py.get_type::<pyo3::exceptions::$err>()) {
2126
panic!("Expected {} but got {:?}", stringify!($err), err)
2227
}
2328
err
2429
}};
25-
($py:expr, $val:ident, $code:expr, $err:ident, $err_msg:expr) => {{
26-
let err = py_expect_exception!($py, $val, $code, $err);
27-
assert_eq!(
28-
err.instance($py)
29-
.str()
30-
.expect("error str() failed")
31-
.to_str()
32-
.expect("message was not valid utf8"),
33-
$err_msg
34-
);
30+
// Case3: idents & err_msg
31+
($py:expr, $($val:ident)+, $code:expr, $err:ident, $err_msg:literal) => {{
32+
let err = py_expect_exception!($py, $($val)+, $code, $err);
33+
// Suppose that the error message looks like 'TypeError: ~'
34+
assert_eq!(format!("Py{}", err), concat!(stringify!($err), ": ", $err_msg));
35+
err
36+
}};
37+
// Case4: dict & err_msg
38+
($py:expr, *$dict:expr, $code:expr, $err:ident, $err_msg:literal) => {{
39+
let err = py_expect_exception!($py, *$dict, $code, $err);
40+
assert_eq!(format!("Py{}", err), concat!(stringify!($err), ": ", $err_msg));
3541
err
3642
}};
3743
}

tests/test_buffer_protocol.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ use std::ptr;
1313
use std::sync::atomic::{AtomicBool, Ordering};
1414
use std::sync::Arc;
1515

16+
mod common;
17+
1618
#[pyclass]
1719
struct TestBufferClass {
1820
vec: Vec<u8>,
@@ -93,8 +95,7 @@ fn test_buffer() {
9395
)
9496
.unwrap();
9597
let env = [("ob", instance)].into_py_dict(py);
96-
py.run("assert bytes(ob) == b' 23'", None, Some(env))
97-
.unwrap();
98+
py_assert!(py, *env, "bytes(ob) == b' 23'");
9899
}
99100

100101
assert!(drop_called.load(Ordering::Relaxed));

tests/test_dunder.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use pyo3::class::{
44
};
55
use pyo3::exceptions::{PyIndexError, PyValueError};
66
use pyo3::prelude::*;
7-
use pyo3::types::{IntoPyDict, PySlice, PyType};
7+
use pyo3::types::{PySlice, PyType};
88
use pyo3::{ffi, py_run, AsPyPointer, PyCell};
99
use std::convert::TryFrom;
1010
use std::{isize, iter};
@@ -450,11 +450,9 @@ fn test_cls_impl() {
450450
let py = gil.python();
451451

452452
let ob = Py::new(py, Test {}).unwrap();
453-
let d = [("ob", ob)].into_py_dict(py);
454453

455-
py.run("assert ob[1] == 'int'", None, Some(d)).unwrap();
456-
py.run("assert ob[100:200:1] == 'slice'", None, Some(d))
457-
.unwrap();
454+
py_assert!(py, ob, "ob[1] == 'int'");
455+
py_assert!(py, ob, "ob[100:200:1] == 'slice'");
458456
}
459457

460458
#[pyclass(dict, subclass)]

tests/test_getter_setter.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,7 @@ fn class_with_properties() {
6060
py_run!(py, inst, "assert inst.data_list == [42]");
6161

6262
let d = [("C", py.get_type::<ClassWithProperties>())].into_py_dict(py);
63-
py.run(
64-
"assert C.DATA.__doc__ == 'a getter for data'",
65-
None,
66-
Some(d),
67-
)
68-
.unwrap();
63+
py_assert!(py, *d, "C.DATA.__doc__ == 'a getter for data'");
6964
}
7065

7166
#[pyclass]

tests/test_mapping.rs

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ use std::collections::HashMap;
22

33
use pyo3::exceptions::PyKeyError;
44
use pyo3::prelude::*;
5+
use pyo3::py_run;
56
use pyo3::types::IntoPyDict;
67
use pyo3::types::PyList;
78
use pyo3::PyMappingProtocol;
89

10+
mod common;
11+
912
#[pyclass]
1013
struct Mapping {
1114
index: HashMap<String, usize>,
@@ -66,61 +69,58 @@ impl PyMappingProtocol for Mapping {
6669
}
6770
}
6871

72+
/// Return a dict with `m = Mapping(['1', '2', '3'])`.
73+
fn map_dict(py: Python) -> &pyo3::types::PyDict {
74+
let d = [("Mapping", py.get_type::<Mapping>())].into_py_dict(py);
75+
py_run!(py, *d, "m = Mapping(['1', '2', '3'])");
76+
d
77+
}
78+
6979
#[test]
7080
fn test_getitem() {
7181
let gil = Python::acquire_gil();
7282
let py = gil.python();
73-
let d = [("Mapping", py.get_type::<Mapping>())].into_py_dict(py);
74-
75-
let run = |code| py.run(code, None, Some(d)).unwrap();
76-
let err = |code| py.run(code, None, Some(d)).unwrap_err();
83+
let d = map_dict(py);
7784

78-
run("m = Mapping(['1', '2', '3']); assert m['1'] == 0");
79-
run("m = Mapping(['1', '2', '3']); assert m['2'] == 1");
80-
run("m = Mapping(['1', '2', '3']); assert m['3'] == 2");
81-
err("m = Mapping(['1', '2', '3']); print(m['4'])");
85+
py_assert!(py, *d, "m['1'] == 0");
86+
py_assert!(py, *d, "m['2'] == 1");
87+
py_assert!(py, *d, "m['3'] == 2");
88+
py_expect_exception!(py, *d, "print(m['4'])", PyKeyError);
8289
}
8390

8491
#[test]
8592
fn test_setitem() {
8693
let gil = Python::acquire_gil();
8794
let py = gil.python();
88-
let d = [("Mapping", py.get_type::<Mapping>())].into_py_dict(py);
89-
90-
let run = |code| py.run(code, None, Some(d)).unwrap();
91-
let err = |code| py.run(code, None, Some(d)).unwrap_err();
95+
let d = map_dict(py);
9296

93-
run("m = Mapping(['1', '2', '3']); m['1'] = 4; assert m['1'] == 4");
94-
run("m = Mapping(['1', '2', '3']); m['0'] = 0; assert m['0'] == 0");
95-
run("m = Mapping(['1', '2', '3']); len(m) == 4");
96-
err("m = Mapping(['1', '2', '3']); m[0] = 'hello'");
97-
err("m = Mapping(['1', '2', '3']); m[0] = -1");
97+
py_run!(py, *d, "m['1'] = 4; assert m['1'] == 4");
98+
py_run!(py, *d, "m['0'] = 0; assert m['0'] == 0");
99+
py_assert!(py, *d, "len(m) == 4");
100+
py_expect_exception!(py, *d, "m[0] = 'hello'", PyTypeError);
101+
py_expect_exception!(py, *d, "m[0] = -1", PyTypeError);
98102
}
99103

100104
#[test]
101105
fn test_delitem() {
102106
let gil = Python::acquire_gil();
103107
let py = gil.python();
104108

105-
let d = [("Mapping", py.get_type::<Mapping>())].into_py_dict(py);
106-
let run = |code| py.run(code, None, Some(d)).unwrap();
107-
let err = |code| py.run(code, None, Some(d)).unwrap_err();
108-
109-
run(
110-
"m = Mapping(['1', '2', '3']); del m['1']; assert len(m) == 2; \
111-
assert m['2'] == 1; assert m['3'] == 2",
109+
let d = map_dict(py);
110+
py_run!(
111+
py,
112+
*d,
113+
"del m['1']; assert len(m) == 2 and m['2'] == 1 and m['3'] == 2"
112114
);
113-
err("m = Mapping(['1', '2', '3']); del m[-1]");
114-
err("m = Mapping(['1', '2', '3']); del m['4']");
115+
py_expect_exception!(py, *d, "del m[-1]", PyTypeError);
116+
py_expect_exception!(py, *d, "del m['4']", PyKeyError);
115117
}
116118

117119
#[test]
118120
fn test_reversed() {
119121
let gil = Python::acquire_gil();
120122
let py = gil.python();
121123

122-
let d = [("Mapping", py.get_type::<Mapping>())].into_py_dict(py);
123-
let run = |code| py.run(code, None, Some(d)).unwrap();
124-
125-
run("m = Mapping(['1', '2']); assert set(reversed(m)) == {'1', '2'}");
124+
let d = map_dict(py);
125+
py_assert!(py, *d, "set(reversed(m)) == {'1', '2', '3'}");
126126
}

0 commit comments

Comments
 (0)