Skip to content

fix: handle reordered contents in ak.almost_equal #2424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion src/awkward/_nplikes/array_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@

import numpy

from awkward._nplikes.numpylike import ArrayLike, IndexType, NumpyLike, NumpyMetadata
from awkward._nplikes.numpylike import (
ArrayLike,
IndexType,
NumpyLike,
NumpyMetadata,
UniqueAllResult,
)
from awkward._nplikes.shape import ShapeItem, unknown_length
from awkward._typing import Final, Literal

Expand Down Expand Up @@ -192,6 +198,31 @@ def unique_values(self, x: ArrayLike) -> ArrayLike:
equal_nan=False,
)

def unique_all(self, x: ArrayLike) -> UniqueAllResult:
values, indices, inverse_indices, counts = self._module.unique(
x, return_counts=True, return_index=True, return_inverse=True
)
# np.unique() flattens inverse indices, but they need to share x's shape
# See https://github.com/numpy/numpy/issues/20638
inverse_indices = inverse_indices.reshape(x.shape)
return UniqueAllResult(values, indices, inverse_indices, counts)

def sort(
self,
x: ArrayLike,
*,
axis: int = -1,
descending: bool = False,
stable: bool = True,
) -> ArrayLike:
# Note: this keyword argument is different, and the default is different.
kind = "stable" if stable else "quicksort"
res = self._module.sort(x, axis=axis, kind=kind)
if descending:
return self._module.flip(res, axis=axis)
else:
return res

def concat(
self,
arrays: list[ArrayLike] | tuple[ArrayLike, ...],
Expand Down
23 changes: 23 additions & 0 deletions src/awkward/_nplikes/numpylike.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from awkward._singleton import Singleton
from awkward._typing import (
Literal,
NamedTuple,
Protocol,
Self,
SupportsIndex,
Expand All @@ -19,6 +20,13 @@
IndexType: TypeAlias = "int | ArrayLike"


class UniqueAllResult(NamedTuple):
values: ArrayLike
indices: ArrayLike
inverse_indices: ArrayLike
counts: ArrayLike


class ArrayLike(Protocol):
@property
@abstractmethod
Expand Down Expand Up @@ -402,6 +410,21 @@ def where(self, condition: ArrayLike, x1: ArrayLike, x2: ArrayLike) -> ArrayLike
def unique_values(self, x: ArrayLike) -> ArrayLike:
...

@abstractmethod
def unique_all(self, x: ArrayLike) -> UniqueAllResult:
...

@abstractmethod
def sort(
self,
x: ArrayLike,
*,
axis: int = -1,
descending: bool = False,
stable: bool = True,
) -> ArrayLike:
...

@abstractmethod
def concat(
self,
Expand Down
28 changes: 27 additions & 1 deletion src/awkward/_nplikes/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@

import awkward as ak
from awkward._nplikes.dispatch import register_nplike
from awkward._nplikes.numpylike import ArrayLike, IndexType, NumpyLike, NumpyMetadata
from awkward._nplikes.numpylike import (
ArrayLike,
IndexType,
NumpyLike,
NumpyMetadata,
UniqueAllResult,
)
from awkward._nplikes.shape import ShapeItem, unknown_length
from awkward._regularize import is_integer, is_non_string_like_sequence
from awkward._typing import (
Expand Down Expand Up @@ -1081,6 +1087,26 @@ def unique_values(self, x: ArrayLike) -> TypeTracerArray:
try_touch_data(x)
return TypeTracerArray._new(x.dtype, shape=(unknown_length,))

def unique_all(self, x: ArrayLike) -> UniqueAllResult:
try_touch_data(x)
return UniqueAllResult(
TypeTracerArray._new(x.dtype, shape=(unknown_length,)),
TypeTracerArray._new(np.int64, shape=(unknown_length,)),
TypeTracerArray._new(np.int64, shape=x.shape),
TypeTracerArray._new(np.int64, shape=(unknown_length,)),
)

def sort(
self,
x: ArrayLike,
*,
axis: int = -1,
descending: bool = False,
stable: bool = True,
) -> ArrayLike:
try_touch_data(x)
return TypeTracerArray._new(x.dtype, shape=x.shape)

def concat(self, arrays, *, axis: int | None = 0) -> TypeTracerArray:
if axis is None:
assert all(x.ndim == 1 for x in arrays)
Expand Down
73 changes: 59 additions & 14 deletions src/awkward/operations/ak_almost_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,35 +105,80 @@ def visitor(left, right) -> bool:
elif left.is_regular:
return (left.size == right.size) and visitor(left.content, right.content)
elif left.is_numpy:
return (
is_approx_dtype(left.dtype, right.dtype)
and backend.nplike.all(
backend.nplike.isclose(
left.data, right.data, rtol=rtol, atol=atol, equal_nan=False
# Timelike types must be exactly compared, including their units
if (
np.issubdtype(left.dtype, np.datetime64)
or np.issubdtype(right.dtype, np.datetime64)
or np.issubdtype(left.dtype, np.timedelta64)
or np.issubdtype(right.dtype, np.timedelta64)
):
return (
(left.dtype == right.dtype)
and backend.nplike.all(left.data == right.data)
and left.shape == right.shape
)
else:
return (
is_approx_dtype(left.dtype, right.dtype)
and backend.nplike.all(
backend.nplike.isclose(
left.data, right.data, rtol=rtol, atol=atol, equal_nan=False
)
)
and left.shape == right.shape
)
and left.shape == right.shape
)

elif left.is_option:
return backend.index_nplike.array_equal(
left.index.data < 0, right.index.data < 0
) and visitor(left.project().to_packed(), right.project().to_packed())
elif left.is_union:
return (len(left.contents) == len(right.contents)) and all(
visitor(left.project(i).to_packed(), right.project(i).to_packed())
for i, _ in enumerate(left.contents)
# For two unions with different content orderings to match, the tags should be equal at each index
# Therefore, we can order the contents by index appearance
def ordered_unique_values(values):
# First, find unique values and their appearance (from smallest to largest)
# unique_index is in ascending order of `unique` value
(
unique,
unique_index,
*_,
) = backend.index_nplike.unique_all(values)
# Now re-order `unique` by order of appearance (`unique_index`)
return values[backend.index_nplike.sort(unique_index)]

# Find order of appearance for each union tags, and assume these are one-to-one maps
left_tag_order = ordered_unique_values(left.tags.data)
right_tag_order = ordered_unique_values(right.tags.data)

# Create map from left tags to right tags
left_tag_to_right_tag = backend.index_nplike.empty(
left_tag_order.size, dtype=np.int64
)
left_tag_to_right_tag[left_tag_order] = right_tag_order

# Map left tags onto right, such that the result should equal right.tags
# if the two tag arrays are equivalent
new_left_tag = left_tag_to_right_tag[left.tags.data]
if not backend.index_nplike.all(new_left_tag == right.tags.data):
return False

# Now project out the contents, and check for equality
for i, j in zip(left_tag_order, right_tag_order):
if not visitor(
left.project(i).to_packed(), right.project(j).to_packed()
):
return False
return True

elif left.is_record:
return (
(
get_record_class(left, left_behavior)
is get_record_class(right, right_behavior)
or not check_parameters
)
and (left.fields == right.fields)
and (left.is_tuple == right.is_tuple)
and all(visitor(x, y) for x, y in zip(left.contents, right.contents))
and left.is_tuple == right.is_tuple
and (left.is_tuple or (len(left.fields) == len(right.fields)))
and all(visitor(left.content(f), right.content(f)) for f in left.fields)
)
elif left.is_unknown:
return True
Expand Down
49 changes: 49 additions & 0 deletions tests/test_2424_almost_equal_union_record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import numpy as np
import pytest # noqa: F401

import awkward as ak


def test_records_almost_equal():
first = ak.contents.RecordArray(
[
ak.contents.NumpyArray(np.array([1, 2, 3], dtype=np.int64)),
ak.contents.NumpyArray(np.array([0], dtype=np.dtype("<M8[s]"))),
],
["x", "y"],
)

second = ak.contents.RecordArray(
[
ak.contents.NumpyArray(np.array([0], dtype=np.dtype("<M8[s]"))),
ak.contents.NumpyArray(np.array([1, 2, 3], dtype=np.int64)),
],
["y", "x"],
)

assert ak.almost_equal(first, second)


def test_unions_almost_equal():
# Check unions agree!
first = ak.contents.UnionArray(
ak.index.Index8([0, 0, 2, 1, 1]),
ak.index.Index64([0, 1, 0, 0, 1]),
[
ak.contents.NumpyArray(np.array([1, 2, 3], dtype=np.int64)),
ak.contents.NumpyArray(np.array([0, 1], dtype=np.dtype("<M8[s]"))),
ak.contents.NumpyArray(np.array([0, 1, 0, 1], dtype=np.bool_)),
],
)
second = ak.contents.UnionArray(
ak.index.Index8([1, 1, 0, 2, 2]),
ak.index.Index64([0, 1, 0, 0, 1]),
[
ak.contents.NumpyArray(np.array([0, 1, 0, 1], dtype=np.bool_)),
ak.contents.NumpyArray(np.array([1, 2, 3], dtype=np.int64)),
ak.contents.NumpyArray(np.array([0, 1], dtype=np.dtype("<M8[s]"))),
],
)
assert ak.almost_equal(first, second)