Skip to content

Commit b03c4cb

Browse files
authored
Merge pull request #3506 from davidhewitt/default-ne
Fix bug in default implementation of `__ne__`
2 parents b73c069 + e1d4173 commit b03c4cb

File tree

4 files changed

+65
-22
lines changed

4 files changed

+65
-22
lines changed

guide/src/class/object.md

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ impl Number {
7373

7474
In the `__repr__`, we used a hard-coded class name. This is sometimes not ideal,
7575
because if the class is subclassed in Python, we would like the repr to reflect
76-
the subclass name. This is typically done in Python code by accessing
76+
the subclass name. This is typically done in Python code by accessing
7777
`self.__class__.__name__`. In order to be able to access the Python type information
7878
*and* the Rust struct, we need to use a `PyCell` as the `self` argument.
7979

@@ -149,8 +149,8 @@ impl Number {
149149
150150
### Comparisons
151151
152-
Unlike in Python, PyO3 does not provide the magic comparison methods you might expect like `__eq__`,
153-
`__lt__` and so on. Instead you have to implement all six operations at once with `__richcmp__`.
152+
PyO3 supports the usual magic comparison methods available in Python such as `__eq__`, `__lt__`
153+
and so on. It is also possible to support all six operations at once with `__richcmp__`.
154154
This method will be called with a value of `CompareOp` depending on the operation.
155155
156156
```rust
@@ -198,28 +198,31 @@ impl Number {
198198
It checks that the `std::cmp::Ordering` obtained from Rust's `Ord` matches
199199
the given `CompareOp`.
200200

201-
Alternatively, if you want to leave some operations unimplemented, you can
202-
return `py.NotImplemented()` for some of the operations:
201+
Alternatively, you can implement just equality using `__eq__`:
203202

204203

205204
```rust
206-
use pyo3::class::basic::CompareOp;
207-
208205
# use pyo3::prelude::*;
209206
#
210207
# #[pyclass]
211208
# struct Number(i32);
212209
#
213210
#[pymethods]
214211
impl Number {
215-
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
216-
match op {
217-
CompareOp::Eq => (self.0 == other.0).into_py(py),
218-
CompareOp::Ne => (self.0 != other.0).into_py(py),
219-
_ => py.NotImplemented(),
220-
}
212+
fn __eq__(&self, other: &Self) -> bool {
213+
self.0 == other.0
221214
}
222215
}
216+
217+
# fn main() -> PyResult<()> {
218+
# Python::with_gil(|py| {
219+
# let x = PyCell::new(py, Number(4))?;
220+
# let y = PyCell::new(py, Number(4))?;
221+
# assert!(x.eq(y)?);
222+
# assert!(!x.ne(y)?);
223+
# Ok(())
224+
# })
225+
# }
223226
```
224227

225228
### Truthyness

guide/src/class/protocols.md

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,19 +76,41 @@ given signatures should be interpreted as follows:
7676
- `__richcmp__(<self>, object, pyo3::basic::CompareOp) -> object`
7777

7878
Implements Python comparison operations (`==`, `!=`, `<`, `<=`, `>`, and `>=`) in a single method.
79-
The `CompareOp` argument indicates the comparison operation being performed.
79+
The `CompareOp` argument indicates the comparison operation being performed. You can use
80+
[`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result to the requested comparison.
8081

8182
_This method cannot be implemented in combination with any of `__lt__`, `__le__`, `__eq__`, `__ne__`, `__gt__`, or `__ge__`._
8283

8384
_Note that implementing `__richcmp__` will cause Python not to generate a default `__hash__` implementation, so consider implementing `__hash__` when implementing `__richcmp__`._
8485
<details>
8586
<summary>Return type</summary>
8687
The return type will normally be `PyResult<bool>`, but any Python object can be returned.
88+
89+
If you want to leave some operations unimplemented, you can return `py.NotImplemented()`
90+
for some of the operations:
91+
92+
```rust
93+
use pyo3::class::basic::CompareOp;
94+
95+
# use pyo3::prelude::*;
96+
#
97+
# #[pyclass]
98+
# struct Number(i32);
99+
#
100+
#[pymethods]
101+
impl Number {
102+
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
103+
match op {
104+
CompareOp::Eq => (self.0 == other.0).into_py(py),
105+
CompareOp::Ne => (self.0 != other.0).into_py(py),
106+
_ => py.NotImplemented(),
107+
}
108+
}
109+
}
110+
```
111+
87112
If the second argument `object` is not of the type specified in the
88113
signature, the generated code will automatically `return NotImplemented`.
89-
90-
You can use [`CompareOp::matches`] to adapt a Rust `std::cmp::Ordering` result
91-
to the requested comparison.
92114
</details>
93115

94116
- `__getattr__(<self>, object) -> object`

pytests/tests/test_comparisons.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,14 @@ def test_eq(ty: Type[Union[Eq, PyEq]]):
2323
c = ty(1)
2424

2525
assert a == b
26+
assert not (a != b)
2627
assert a != c
28+
assert not (a == c)
2729

2830
assert b == a
31+
assert not (a != b)
2932
assert b != c
33+
assert not (b == c)
3034

3135
with pytest.raises(TypeError):
3236
assert a <= b
@@ -49,17 +53,21 @@ def __eq__(self, other: Self) -> bool:
4953
return self.x == other.x
5054

5155

52-
@pytest.mark.parametrize("ty", (Eq, PyEq), ids=("rust", "python"))
56+
@pytest.mark.parametrize("ty", (EqDefaultNe, PyEqDefaultNe), ids=("rust", "python"))
5357
def test_eq_default_ne(ty: Type[Union[EqDefaultNe, PyEqDefaultNe]]):
5458
a = ty(0)
5559
b = ty(0)
5660
c = ty(1)
5761

5862
assert a == b
63+
assert not (a != b)
5964
assert a != c
65+
assert not (a == c)
6066

6167
assert b == a
68+
assert not (a != b)
6269
assert b != c
70+
assert not (b == c)
6371

6472
with pytest.raises(TypeError):
6573
assert a <= b
@@ -152,19 +160,25 @@ def test_ordered_default_ne(ty: Type[Union[OrderedDefaultNe, PyOrderedDefaultNe]
152160
c = ty(1)
153161

154162
assert a == b
163+
assert not (a != b)
155164
assert a <= b
156165
assert a >= b
157166
assert a != c
167+
assert not (a == c)
158168
assert a <= c
159169

160170
assert b == a
171+
assert not (b != a)
161172
assert b <= a
162173
assert b >= a
163174
assert b != c
175+
assert not (b == c)
164176
assert b <= c
165177

166178
assert c != a
179+
assert not (c == a)
167180
assert c != b
181+
assert not (c == b)
168182
assert c > a
169183
assert c >= a
170184
assert c > b

src/impl_/pyclass.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::{
66
internal_tricks::extract_c_string,
77
pycell::PyCellLayout,
88
pyclass_init::PyObjectInit,
9+
types::PyBool,
910
Py, PyAny, PyCell, PyClass, PyErr, PyMethodDefType, PyNativeType, PyResult, PyTypeInfo, Python,
1011
};
1112
use std::{
@@ -805,11 +806,14 @@ slot_fragment_trait! {
805806
#[inline]
806807
unsafe fn __ne__(
807808
self,
808-
_py: Python<'_>,
809-
_slf: *mut ffi::PyObject,
810-
_other: *mut ffi::PyObject,
809+
py: Python<'_>,
810+
slf: *mut ffi::PyObject,
811+
other: *mut ffi::PyObject,
811812
) -> PyResult<*mut ffi::PyObject> {
812-
Ok(ffi::_Py_NewRef(ffi::Py_NotImplemented()))
813+
// By default `__ne__` will try `__eq__` and invert the result
814+
let slf: &PyAny = py.from_borrowed_ptr(slf);
815+
let other: &PyAny = py.from_borrowed_ptr(other);
816+
slf.eq(other).map(|is_eq| PyBool::new(py, !is_eq).into_ptr())
813817
}
814818
}
815819

0 commit comments

Comments
 (0)