Skip to content

Commit 54390bc

Browse files
authored
Merge pull request #3712 from alex/binops
add PyAnyMethods for binary operators
2 parents e1fcb4e + 339660c commit 54390bc

File tree

4 files changed

+132
-0
lines changed

4 files changed

+132
-0
lines changed

newsfragments/3712.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added methods to `PyAnyMethods` for binary operators (`add`, `sub`, etc.)

src/tests/common.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ mod inner {
2323
};
2424
}
2525

26+
#[macro_export]
27+
macro_rules! assert_py_eq {
28+
($val:expr, $expected:expr) => {
29+
assert!($val.eq($expected).unwrap());
30+
};
31+
}
32+
2633
#[macro_export]
2734
macro_rules! py_expect_exception {
2835
// Case1: idents & no err_msg

src/types/any.rs

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,58 @@ pub trait PyAnyMethods<'py> {
12081208
where
12091209
O: ToPyObject;
12101210

1211+
/// Computes `self + other`.
1212+
fn add<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
1213+
where
1214+
O: ToPyObject;
1215+
1216+
/// Computes `self - other`.
1217+
fn sub<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
1218+
where
1219+
O: ToPyObject;
1220+
1221+
/// Computes `self * other`.
1222+
fn mul<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
1223+
where
1224+
O: ToPyObject;
1225+
1226+
/// Computes `self / other`.
1227+
fn div<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
1228+
where
1229+
O: ToPyObject;
1230+
1231+
/// Computes `self << other`.
1232+
fn lshift<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
1233+
where
1234+
O: ToPyObject;
1235+
1236+
/// Computes `self >> other`.
1237+
fn rshift<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
1238+
where
1239+
O: ToPyObject;
1240+
1241+
/// Computes `self ** other % modulus` (`pow(self, other, modulus)`).
1242+
/// `py.None()` may be passed for the `modulus`.
1243+
fn pow<O1, O2>(&self, other: O1, modulus: O2) -> PyResult<Bound<'py, PyAny>>
1244+
where
1245+
O1: ToPyObject,
1246+
O2: ToPyObject;
1247+
1248+
/// Computes `self & other`.
1249+
fn bitand<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
1250+
where
1251+
O: ToPyObject;
1252+
1253+
/// Computes `self | other`.
1254+
fn bitor<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
1255+
where
1256+
O: ToPyObject;
1257+
1258+
/// Computes `self ^ other`.
1259+
fn bitxor<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
1260+
where
1261+
O: ToPyObject;
1262+
12111263
/// Determines whether this object appears callable.
12121264
///
12131265
/// This is equivalent to Python's [`callable()`][1] function.
@@ -1680,6 +1732,26 @@ pub trait PyAnyMethods<'py> {
16801732
fn py_super(&self) -> PyResult<Bound<'py, PySuper>>;
16811733
}
16821734

1735+
macro_rules! implement_binop {
1736+
($name:ident, $c_api:ident, $op:expr) => {
1737+
#[doc = concat!("Computes `self ", $op, " other`.")]
1738+
fn $name<O>(&self, other: O) -> PyResult<Bound<'py, PyAny>>
1739+
where
1740+
O: ToPyObject,
1741+
{
1742+
fn inner<'py>(
1743+
any: &Bound<'py, PyAny>,
1744+
other: Bound<'_, PyAny>,
1745+
) -> PyResult<Bound<'py, PyAny>> {
1746+
unsafe { ffi::$c_api(any.as_ptr(), other.as_ptr()).assume_owned_or_err(any.py()) }
1747+
}
1748+
1749+
let py = self.py();
1750+
inner(self, other.to_object(py).into_bound(py))
1751+
}
1752+
};
1753+
}
1754+
16831755
impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> {
16841756
#[inline]
16851757
fn is<T: AsPyPointer>(&self, other: &T) -> bool {
@@ -1855,6 +1927,42 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> {
18551927
.and_then(|any| any.is_truthy())
18561928
}
18571929

1930+
implement_binop!(add, PyNumber_Add, "+");
1931+
implement_binop!(sub, PyNumber_Subtract, "-");
1932+
implement_binop!(mul, PyNumber_Multiply, "*");
1933+
implement_binop!(div, PyNumber_TrueDivide, "/");
1934+
implement_binop!(lshift, PyNumber_Lshift, "<<");
1935+
implement_binop!(rshift, PyNumber_Rshift, ">>");
1936+
implement_binop!(bitand, PyNumber_And, "&");
1937+
implement_binop!(bitor, PyNumber_Or, "|");
1938+
implement_binop!(bitxor, PyNumber_Xor, "^");
1939+
1940+
/// Computes `self ** other % modulus` (`pow(self, other, modulus)`).
1941+
/// `py.None()` may be passed for the `modulus`.
1942+
fn pow<O1, O2>(&self, other: O1, modulus: O2) -> PyResult<Bound<'py, PyAny>>
1943+
where
1944+
O1: ToPyObject,
1945+
O2: ToPyObject,
1946+
{
1947+
fn inner<'py>(
1948+
any: &Bound<'py, PyAny>,
1949+
other: Bound<'_, PyAny>,
1950+
modulus: Bound<'_, PyAny>,
1951+
) -> PyResult<Bound<'py, PyAny>> {
1952+
unsafe {
1953+
ffi::PyNumber_Power(any.as_ptr(), other.as_ptr(), modulus.as_ptr())
1954+
.assume_owned_or_err(any.py())
1955+
}
1956+
}
1957+
1958+
let py = self.py();
1959+
inner(
1960+
self,
1961+
other.to_object(py).into_bound(py),
1962+
modulus.to_object(py).into_bound(py),
1963+
)
1964+
}
1965+
18581966
fn is_callable(&self) -> bool {
18591967
unsafe { ffi::PyCallable_Check(self.as_ptr()) != 0 }
18601968
}

tests/test_arithmetics.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ impl BinaryArithmetic {
178178
format!("BA * {:?}", rhs)
179179
}
180180

181+
fn __truediv__(&self, rhs: &PyAny) -> String {
182+
format!("BA / {:?}", rhs)
183+
}
184+
181185
fn __lshift__(&self, rhs: &PyAny) -> String {
182186
format!("BA << {:?}", rhs)
183187
}
@@ -233,6 +237,18 @@ fn binary_arithmetic() {
233237
py_expect_exception!(py, c, "1 ** c", PyTypeError);
234238

235239
py_run!(py, c, "assert pow(c, 1, 100) == 'BA ** 1 (mod: Some(100))'");
240+
241+
let c: Bound<'_, PyAny> = c.extract().unwrap();
242+
assert_py_eq!(c.add(&c).unwrap(), "BA + BA");
243+
assert_py_eq!(c.sub(&c).unwrap(), "BA - BA");
244+
assert_py_eq!(c.mul(&c).unwrap(), "BA * BA");
245+
assert_py_eq!(c.div(&c).unwrap(), "BA / BA");
246+
assert_py_eq!(c.lshift(&c).unwrap(), "BA << BA");
247+
assert_py_eq!(c.rshift(&c).unwrap(), "BA >> BA");
248+
assert_py_eq!(c.bitand(&c).unwrap(), "BA & BA");
249+
assert_py_eq!(c.bitor(&c).unwrap(), "BA | BA");
250+
assert_py_eq!(c.bitxor(&c).unwrap(), "BA ^ BA");
251+
assert_py_eq!(c.pow(&c, py.None()).unwrap(), "BA ** BA (mod: None)");
236252
});
237253
}
238254

0 commit comments

Comments
 (0)