Skip to content

Commit

Permalink
feat: Add ak.array_equal (#3215)
Browse files Browse the repository at this point in the history
* Preparatory refactor: create ak_almost_equal._impl

* Adding ak.array_equal

Includes a very minimal test case. Needs more.

* Fixing bug in nplike, array_equal

The NaN values were not compared correctly.
Also adding several tests of ak.array_equal.
Also fixing a minor issue in ak.array_equal.

* Fixing array_equal docstring

* Update src/awkward/operations/ak_almost_equal.py

Use the `is` operator to compare classes

Co-authored-by: Jim Pivarski <[email protected]>

* Remove defaults from args to ak_almost_equal._impl

* Fixing more bugs

* Possible fix for old numpy

* style: pre-commit fixes

---------

Co-authored-by: Jim Pivarski <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 14, 2024
1 parent 6eb8627 commit e15518f
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/awkward/_nplikes/array_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,10 @@ def array_equal(
assert not isinstance(x1, PlaceholderArray)
assert not isinstance(x2, PlaceholderArray)
if equal_nan:
both_nan = self._module.logical_and(x1 == np.nan, x2 == np.nan)
# Only newer numpy.array_equal supports the equal_nan parameter.
both_nan = self._module.logical_and(
self._module.isnan(x1), self._module.isnan(x2)
)
both_equal = x1 == x2
return self._module.all(self._module.logical_or(both_equal, both_nan))
else:
Expand Down
1 change: 1 addition & 0 deletions src/awkward/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from awkward.operations.ak_argmax import *
from awkward.operations.ak_argmin import *
from awkward.operations.ak_argsort import *
from awkward.operations.ak_array_equal import *
from awkward.operations.ak_backend import *
from awkward.operations.ak_broadcast_arrays import *
from awkward.operations.ak_broadcast_fields import *
Expand Down
46 changes: 45 additions & 1 deletion src/awkward/operations/ak_almost_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,32 @@ def almost_equal(
# Dispatch
yield left, right

return _impl(
left,
right,
rtol=rtol,
atol=atol,
dtype_exact=dtype_exact,
check_parameters=check_parameters,
check_regular=check_regular,
exact_eq=False,
same_content_types=False,
equal_nan=False,
)


def _impl(
left,
right,
rtol: float,
atol: float,
dtype_exact: bool,
check_parameters: bool,
check_regular: bool,
exact_eq: bool,
same_content_types: bool,
equal_nan: bool,
):
# Implementation
left_behavior = behavior_of(left)
right_behavior = behavior_of(right)
Expand Down Expand Up @@ -82,6 +108,10 @@ def packed_list_content(layout):
return layout.content[layout.offsets[0] : layout.offsets[-1]]

def visitor(left, right) -> bool:
# Before anything else, check same_content_types (i.e. before any transformations)
if same_content_types and left.__class__ is not right.__class__:
return False

# First, erase indexed types!
if left.is_indexed and not left.is_option:
left = left.project()
Expand Down Expand Up @@ -152,12 +182,26 @@ def visitor(left, right) -> bool:
and backend.nplike.all(left.data == right.data)
and left.shape == right.shape
)
elif exact_eq:
return (
is_approx_dtype(left.dtype, right.dtype)
and backend.nplike.array_equal(
left.data,
right.data,
equal_nan=equal_nan,
)
and left.shape == right.shape
)
else:
return (
is_approx_dtype(left.dtype, right.dtype)
and backend.nplike.all(
backend.nplike.isclose(
left.data, right.data, rtol=rtol, atol=atol, equal_nan=False
left.data,
right.data,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
)
)
and left.shape == right.shape
Expand Down
54 changes: 54 additions & 0 deletions src/awkward/operations/ak_array_equal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import awkward as ak
from awkward._dispatch import high_level_function

__all__ = ("array_equal",)


@high_level_function()
def array_equal(
    a1,
    a2,
    equal_nan: bool = False,
    dtype_exact: bool = True,
    same_content_types: bool = True,
    check_parameters: bool = True,
    check_regular: bool = True,
):
    """
    True if two arrays have the same shape and elements, False otherwise.

    Args:
        a1: Array-like data (anything #ak.to_layout recognizes).
        a2: Array-like data (anything #ak.to_layout recognizes).
        equal_nan (bool): Whether to count NaN values as equal to each other.
        dtype_exact (bool): Whether the dtypes must be exactly the same, or
            just the same family.
        same_content_types (bool): Whether to require all content classes to
            match. This requirement is only enforced when `check_regular` is
            also True.
        check_parameters (bool): Whether to compare parameters.
        check_regular (bool): Whether to consider ragged and regular
            dimensions as unequal.

    TypeTracer arrays are not supported, as there is very little information
    to be compared.
    """
    # Dispatch
    yield a1, a2

    # Delegate to the implementation shared with ak.almost_equal, requesting
    # an exact comparison (exact_eq=True) with zero tolerances.
    return ak.operations.ak_almost_equal._impl(
        a1,
        a2,
        rtol=0.0,
        atol=0.0,
        dtype_exact=dtype_exact,
        check_parameters=check_parameters,
        check_regular=check_regular,
        exact_eq=True,
        # The content-class check is dropped whenever check_regular is False.
        same_content_types=same_content_types and check_regular,
        equal_nan=equal_nan,
    )
90 changes: 90 additions & 0 deletions tests/test_1105_ak_aray_equal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import numpy as np

import awkward as ak
from awkward.contents import NumpyArray
from awkward.index import Index64


def test_array_equal_simple():
    """Two ragged arrays built from the same nested lists compare equal."""
    lhs = ak.Array([[1, 2], [], [3, 4, 5]])
    rhs = ak.Array([[1, 2], [], [3, 4, 5]])
    assert ak.array_equal(lhs, rhs)


def test_array_equal_mixed_dtype():
    """float32 and float64 data compare equal only with ``dtype_exact=False``."""
    values = [1.5, 2.0, 3.25]
    single = ak.Array(np.array(values, dtype=np.float32))
    double = ak.Array(np.array(values, dtype=np.float64))

    # With the default dtype_exact=True the differing dtypes are rejected.
    assert not ak.array_equal(single, double)

    # Relaxing the dtype check lets the same values compare equal.
    assert ak.array_equal(single, double, dtype_exact=False)


def test_array_equal_on_listoffsets():
    """Content beyond the last offset is ignored when comparing lists."""
    tight_content = NumpyArray(np.arange(5) * 1.5)
    # Longer content than the offsets actually reference.
    oversized_content = NumpyArray(np.arange(10) * 1.5)

    tight = ak.contents.ListOffsetArray(
        Index64(np.array([0, 3, 3, 5])), tight_content
    )
    oversized = ak.contents.ListOffsetArray(
        Index64(np.array([0, 3, 3, 5])), oversized_content
    )

    assert ak.array_equal(tight, oversized)


def test_array_equal_mixed_content_type():
    """Exercise ``check_regular`` and ``same_content_types`` across layouts."""
    ragged = ak.Array([[1, 2, 3], [4, 5, 6], [7, 8]])
    head = ragged[:2]
    regular = ak.to_regular(head)

    # Ragged vs. regular: unequal by default, equal once check_regular is off.
    assert not ak.array_equal(head, regular)
    assert ak.array_equal(head, regular, check_regular=False)
    # Different lengths stay unequal regardless of check_regular.
    assert not ak.array_equal(ragged, regular, check_regular=False)

    # high-level automatically converted to layout in pre-check
    assert ak.array_equal(ragged, ragged.layout)

    plain = ak.contents.NumpyArray(np.arange(3))
    indexed = ak.contents.IndexedArray(
        Index64(np.array([0, 1, 2])), NumpyArray(np.arange(3))
    )
    assert ak.array_equal(plain, indexed, same_content_types=False)


def test_array_equal_opion_types():
    """Option-type arrays: same values, different option representations."""
    literal = ak.Array([1, 2, None, 4])
    joined = ak.concatenate([ak.Array([1, 2]), ak.Array([None, 4])])
    assert ak.array_equal(literal, joined)

    byte_masked = literal.layout.to_ByteMaskedArray(valid_when=True)
    assert not ak.array_equal(literal, byte_masked, same_content_types=True)
    assert ak.array_equal(literal, byte_masked, same_content_types=False)

    # An array without the missing value must never compare equal.
    without_none = ak.Array([1, 2, 3, 4])
    assert not ak.array_equal(
        literal, without_none, same_content_types=False, dtype_exact=False
    )


def test_array_equal_nan():
    """NaN compares equal to itself only when ``equal_nan=True``."""
    with_nan = ak.Array([1.0, 2.5, np.nan])
    layout = with_nan.layout
    nplike = layout.backend.nplike

    # Low-level nplike comparison on the raw buffer.
    assert not nplike.array_equal(layout.data, layout.data)
    assert nplike.array_equal(layout.data, layout.data, equal_nan=True)

    # High-level comparison on the array itself.
    assert not ak.array_equal(with_nan, with_nan)
    assert ak.array_equal(with_nan, with_nan, equal_nan=True)


def test_array_equal_with_params():
    """Differing parameters matter unless ``check_parameters=False``."""

    def with_bar(bar_value):
        # Same uint32 data each time; only the nested parameter value varies.
        return NumpyArray(
            np.array([1, 2, 3], dtype=np.uint32),
            parameters={"foo": {"bar": bar_value}},
        )

    first = with_bar("baz")
    second = with_bar("Not so fast")

    assert not ak.array_equal(first, second)
    assert ak.array_equal(first, second, check_parameters=False)

0 comments on commit e15518f

Please sign in to comment.