Skip to content

Commit 334e6b2

Browse files
authored
fix: handle reordered contents in ak.almost_equal (#2424)
* initial commit * test: check unions and records in `almost_equal` * fix: unions and records in almost_equal
1 parent 973c530 commit 334e6b2

File tree

5 files changed

+190
-16
lines changed

5 files changed

+190
-16
lines changed

src/awkward/_nplikes/array_module.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@
55

66
import numpy
77

8-
from awkward._nplikes.numpylike import ArrayLike, IndexType, NumpyLike, NumpyMetadata
8+
from awkward._nplikes.numpylike import (
9+
ArrayLike,
10+
IndexType,
11+
NumpyLike,
12+
NumpyMetadata,
13+
UniqueAllResult,
14+
)
915
from awkward._nplikes.shape import ShapeItem, unknown_length
1016
from awkward._typing import Final, Literal
1117

@@ -192,6 +198,31 @@ def unique_values(self, x: ArrayLike) -> ArrayLike:
192198
equal_nan=False,
193199
)
194200

201+
def unique_all(self, x: ArrayLike) -> UniqueAllResult:
202+
values, indices, inverse_indices, counts = self._module.unique(
203+
x, return_counts=True, return_index=True, return_inverse=True
204+
)
205+
# np.unique() flattens inverse indices, but they need to share x's shape
206+
# See https://github.com/numpy/numpy/issues/20638
207+
inverse_indices = inverse_indices.reshape(x.shape)
208+
return UniqueAllResult(values, indices, inverse_indices, counts)
209+
210+
def sort(
211+
self,
212+
x: ArrayLike,
213+
*,
214+
axis: int = -1,
215+
descending: bool = False,
216+
stable: bool = True,
217+
) -> ArrayLike:
218+
# Note: NumPy's sort takes a `kind` keyword rather than `stable`, and NumPy's default sort kind is not stable, whereas here `stable` defaults to True.
219+
kind = "stable" if stable else "quicksort"
220+
res = self._module.sort(x, axis=axis, kind=kind)
221+
if descending:
222+
return self._module.flip(res, axis=axis)
223+
else:
224+
return res
225+
195226
def concat(
196227
self,
197228
arrays: list[ArrayLike] | tuple[ArrayLike, ...],

src/awkward/_nplikes/numpylike.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from awkward._singleton import Singleton
1010
from awkward._typing import (
1111
Literal,
12+
NamedTuple,
1213
Protocol,
1314
Self,
1415
SupportsIndex,
@@ -19,6 +20,13 @@
1920
IndexType: TypeAlias = "int | ArrayLike"
2021

2122

23+
class UniqueAllResult(NamedTuple):
24+
values: ArrayLike
25+
indices: ArrayLike
26+
inverse_indices: ArrayLike
27+
counts: ArrayLike
28+
29+
2230
class ArrayLike(Protocol):
2331
@property
2432
@abstractmethod
@@ -402,6 +410,21 @@ def where(self, condition: ArrayLike, x1: ArrayLike, x2: ArrayLike) -> ArrayLike
402410
def unique_values(self, x: ArrayLike) -> ArrayLike:
403411
...
404412

413+
@abstractmethod
414+
def unique_all(self, x: ArrayLike) -> UniqueAllResult:
415+
...
416+
417+
@abstractmethod
418+
def sort(
419+
self,
420+
x: ArrayLike,
421+
*,
422+
axis: int = -1,
423+
descending: bool = False,
424+
stable: bool = True,
425+
) -> ArrayLike:
426+
...
427+
405428
@abstractmethod
406429
def concat(
407430
self,

src/awkward/_nplikes/typetracer.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88

99
import awkward as ak
1010
from awkward._nplikes.dispatch import register_nplike
11-
from awkward._nplikes.numpylike import ArrayLike, IndexType, NumpyLike, NumpyMetadata
11+
from awkward._nplikes.numpylike import (
12+
ArrayLike,
13+
IndexType,
14+
NumpyLike,
15+
NumpyMetadata,
16+
UniqueAllResult,
17+
)
1218
from awkward._nplikes.shape import ShapeItem, unknown_length
1319
from awkward._regularize import is_integer, is_non_string_like_sequence
1420
from awkward._typing import (
@@ -1081,6 +1087,26 @@ def unique_values(self, x: ArrayLike) -> TypeTracerArray:
10811087
try_touch_data(x)
10821088
return TypeTracerArray._new(x.dtype, shape=(unknown_length,))
10831089

1090+
def unique_all(self, x: ArrayLike) -> UniqueAllResult:
1091+
try_touch_data(x)
1092+
return UniqueAllResult(
1093+
TypeTracerArray._new(x.dtype, shape=(unknown_length,)),
1094+
TypeTracerArray._new(np.int64, shape=(unknown_length,)),
1095+
TypeTracerArray._new(np.int64, shape=x.shape),
1096+
TypeTracerArray._new(np.int64, shape=(unknown_length,)),
1097+
)
1098+
1099+
def sort(
1100+
self,
1101+
x: ArrayLike,
1102+
*,
1103+
axis: int = -1,
1104+
descending: bool = False,
1105+
stable: bool = True,
1106+
) -> ArrayLike:
1107+
try_touch_data(x)
1108+
return TypeTracerArray._new(x.dtype, shape=x.shape)
1109+
10841110
def concat(self, arrays, *, axis: int | None = 0) -> TypeTracerArray:
10851111
if axis is None:
10861112
assert all(x.ndim == 1 for x in arrays)

src/awkward/operations/ak_almost_equal.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -105,35 +105,80 @@ def visitor(left, right) -> bool:
105105
elif left.is_regular:
106106
return (left.size == right.size) and visitor(left.content, right.content)
107107
elif left.is_numpy:
108-
return (
109-
is_approx_dtype(left.dtype, right.dtype)
110-
and backend.nplike.all(
111-
backend.nplike.isclose(
112-
left.data, right.data, rtol=rtol, atol=atol, equal_nan=False
108+
# Timelike types must be exactly compared, including their units
109+
if (
110+
np.issubdtype(left.dtype, np.datetime64)
111+
or np.issubdtype(right.dtype, np.datetime64)
112+
or np.issubdtype(left.dtype, np.timedelta64)
113+
or np.issubdtype(right.dtype, np.timedelta64)
114+
):
115+
return (
116+
(left.dtype == right.dtype)
117+
and backend.nplike.all(left.data == right.data)
118+
and left.shape == right.shape
119+
)
120+
else:
121+
return (
122+
is_approx_dtype(left.dtype, right.dtype)
123+
and backend.nplike.all(
124+
backend.nplike.isclose(
125+
left.data, right.data, rtol=rtol, atol=atol, equal_nan=False
126+
)
113127
)
128+
and left.shape == right.shape
114129
)
115-
and left.shape == right.shape
116-
)
117-
118130
elif left.is_option:
119131
return backend.index_nplike.array_equal(
120132
left.index.data < 0, right.index.data < 0
121133
) and visitor(left.project().to_packed(), right.project().to_packed())
122134
elif left.is_union:
123-
return (len(left.contents) == len(right.contents)) and all(
124-
visitor(left.project(i).to_packed(), right.project(i).to_packed())
125-
for i, _ in enumerate(left.contents)
135+
# For two unions with different content orderings to match, their tags must correspond one-to-one at each index
136+
# Therefore, we can order the contents by index appearance
137+
def ordered_unique_values(values):
138+
# First, find unique values and their appearance (from smallest to largest)
139+
# unique_index is in ascending order of `unique` value
140+
(
141+
unique,
142+
unique_index,
143+
*_,
144+
) = backend.index_nplike.unique_all(values)
145+
# Now re-order `unique` by order of appearance (`unique_index`)
146+
return values[backend.index_nplike.sort(unique_index)]
147+
148+
# Find the order of appearance of each union's tags, and assume these form a one-to-one map
149+
left_tag_order = ordered_unique_values(left.tags.data)
150+
right_tag_order = ordered_unique_values(right.tags.data)
151+
152+
# Create map from left tags to right tags
153+
left_tag_to_right_tag = backend.index_nplike.empty(
154+
left_tag_order.size, dtype=np.int64
126155
)
156+
left_tag_to_right_tag[left_tag_order] = right_tag_order
157+
158+
# Map left tags onto right, such that the result should equal right.tags
159+
# if the two tag arrays are equivalent
160+
new_left_tag = left_tag_to_right_tag[left.tags.data]
161+
if not backend.index_nplike.all(new_left_tag == right.tags.data):
162+
return False
163+
164+
# Now project out the contents, and check for equality
165+
for i, j in zip(left_tag_order, right_tag_order):
166+
if not visitor(
167+
left.project(i).to_packed(), right.project(j).to_packed()
168+
):
169+
return False
170+
return True
171+
127172
elif left.is_record:
128173
return (
129174
(
130175
get_record_class(left, left_behavior)
131176
is get_record_class(right, right_behavior)
132177
or not check_parameters
133178
)
134-
and (left.fields == right.fields)
135-
and (left.is_tuple == right.is_tuple)
136-
and all(visitor(x, y) for x, y in zip(left.contents, right.contents))
179+
and left.is_tuple == right.is_tuple
180+
and (left.is_tuple or (len(left.fields) == len(right.fields)))
181+
and all(visitor(left.content(f), right.content(f)) for f in left.fields)
137182
)
138183
elif left.is_unknown:
139184
return True
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
2+
3+
import numpy as np
4+
import pytest # noqa: F401
5+
6+
import awkward as ak
7+
8+
9+
def test_records_almost_equal():
10+
first = ak.contents.RecordArray(
11+
[
12+
ak.contents.NumpyArray(np.array([1, 2, 3], dtype=np.int64)),
13+
ak.contents.NumpyArray(np.array([0], dtype=np.dtype("<M8[s]"))),
14+
],
15+
["x", "y"],
16+
)
17+
18+
second = ak.contents.RecordArray(
19+
[
20+
ak.contents.NumpyArray(np.array([0], dtype=np.dtype("<M8[s]"))),
21+
ak.contents.NumpyArray(np.array([1, 2, 3], dtype=np.int64)),
22+
],
23+
["y", "x"],
24+
)
25+
26+
assert ak.almost_equal(first, second)
27+
28+
29+
def test_unions_almost_equal():
30+
# Check unions agree!
31+
first = ak.contents.UnionArray(
32+
ak.index.Index8([0, 0, 2, 1, 1]),
33+
ak.index.Index64([0, 1, 0, 0, 1]),
34+
[
35+
ak.contents.NumpyArray(np.array([1, 2, 3], dtype=np.int64)),
36+
ak.contents.NumpyArray(np.array([0, 1], dtype=np.dtype("<M8[s]"))),
37+
ak.contents.NumpyArray(np.array([0, 1, 0, 1], dtype=np.bool_)),
38+
],
39+
)
40+
second = ak.contents.UnionArray(
41+
ak.index.Index8([1, 1, 0, 2, 2]),
42+
ak.index.Index64([0, 1, 0, 0, 1]),
43+
[
44+
ak.contents.NumpyArray(np.array([0, 1, 0, 1], dtype=np.bool_)),
45+
ak.contents.NumpyArray(np.array([1, 2, 3], dtype=np.int64)),
46+
ak.contents.NumpyArray(np.array([0, 1], dtype=np.dtype("<M8[s]"))),
47+
],
48+
)
49+
assert ak.almost_equal(first, second)

0 commit comments

Comments
 (0)