Skip to content

Commit bbe4393

Browse files
committed
Add more tests in method_with_pyclassarg
1 parent e63e0cb commit bbe4393

File tree

1 file changed

+22
-20
lines changed

1 file changed

+22
-20
lines changed

tests/test_methods.rs

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -426,11 +426,19 @@ impl MethodWithPyClassArg {
426426
fn inplace_add(&self, other: &mut MethodWithPyClassArg) {
427427
other.value += self.value;
428428
}
429+
fn inplace_add_pyref(&self, mut other: PyRefMut<MethodWithPyClassArg>) {
430+
other.value += self.value;
431+
}
429432
fn optional_add(&self, other: Option<&MethodWithPyClassArg>) -> MethodWithPyClassArg {
430433
MethodWithPyClassArg {
431434
value: self.value + other.map(|o| o.value).unwrap_or(10),
432435
}
433436
}
437+
fn optional_inplace_add(&self, other: Option<&mut MethodWithPyClassArg>) {
438+
if let Some(other) = other {
439+
other.value += self.value;
440+
}
441+
}
434442
}
435443

436444
#[test]
@@ -439,26 +447,20 @@ fn method_with_pyclassarg() {
439447
let py = gil.python();
440448
let obj1 = PyCell::new(py, MethodWithPyClassArg { value: 10 }).unwrap();
441449
let obj2 = PyCell::new(py, MethodWithPyClassArg { value: 10 }).unwrap();
442-
py_run!(
443-
py,
444-
obj1 obj2,
445-
"obj = obj1.add(obj2); assert obj.value == 20"
446-
);
447-
py_run!(
448-
py,
449-
obj1 obj2,
450-
"obj = obj1.add_pyref(obj2); assert obj.value == 20"
451-
);
452-
py_run!(
453-
py,
454-
obj1 obj2,
455-
"obj = obj1.optional_add(); assert obj.value == 20"
456-
);
457-
py_run!(
458-
py,
459-
obj1 obj2,
460-
"obj1.inplace_add(obj2); assert obj2.value == 20"
461-
);
450+
let objs = [("obj1", obj1), ("obj2", obj2)].into_py_dict(py);
451+
let run = |code| {
452+
py.run(code, None, Some(objs))
453+
.map_err(|e| e.print(py))
454+
.unwrap()
455+
};
456+
run("obj = obj1.add(obj2); assert obj.value == 20");
457+
run("obj = obj1.add_pyref(obj2); assert obj.value == 20");
458+
run("obj = obj1.optional_add(); assert obj.value == 20");
459+
run("obj = obj1.optional_add(obj2); assert obj.value == 20");
460+
run("obj1.inplace_add(obj2); assert obj.value == 20");
461+
run("obj1.inplace_add_pyref(obj2); assert obj.value == 30");
462+
run("obj1.optional_inplace_add(obj2); assert obj.value == 40");
463+
run("obj1.optional_inplace_add(); assert obj.value == 40");
462464
}
463465

464466
#[pyclass]

0 commit comments

Comments
 (0)