Skip to content

Commit

Permalink
Fixes #2703: Sort bug with nans (#2755)
Browse files Browse the repository at this point in the history
* Fixes #2703: Sort bug with `nan`s

This PR (fixes #2703)

When a `nan` is present in `a` the value of `min reduce a` will equal `nan`. So `signbit(min reduce a)` will be false even if there are negatives present. This was causing the sort to mishandle `0.0`

I updated the code to do the same thing it used to if `min reduce a` is not a `nan`, and when it is to find the signbits of all values see if any are true (i.e. `| reduce signbit(a)`

I feel like calling `signbit` on every value of `a` then reducing shouldn't be too much more expensive than reducing first and doing only one `signbit` call. But I know the sort code is super optimized, so if ronawho doesn't mind looking this over and making sure I'm not doing something dumb that will kill the performance

* upated in response to PR feedback

---------

Co-authored-by: Pierce Hayes <[email protected]>
  • Loading branch information
stress-tess and Pierce Hayes authored Sep 6, 2023
1 parent 2c8061b commit e574712
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 5 deletions.
52 changes: 52 additions & 0 deletions PROTO_tests/tests/extrema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,29 @@ def make_np_arrays(size, dtype):
return None


def make_np_edge_cases(dtype):
if dtype == "int64":
return np.array([np.iinfo(np.int64).min, -1, 0, 3, np.iinfo(np.int64).max])
elif dtype == "uint64":
return np.array([17, 2**64-1, 0, 3, 2**63 + 10], dtype=np.uint64)
elif dtype == "float64":
return np.array(
[
np.nan,
-np.inf,
np.finfo(np.float64).min,
-3.14,
0.0,
3.14,
8,
np.finfo(np.float64).max,
np.inf,
np.nan,
]
)
return None


class TestExtrema:
@pytest.mark.parametrize("prob_size", pytest.prob_size)
@pytest.mark.parametrize("dtype", ["int64", "uint64", "float64"])
Expand All @@ -34,6 +57,35 @@ def test_extrema(self, prob_size, dtype):
assert (ak.maxk(pda, K) == ak_sorted[-K:]).all()
assert (pda[ak.argmaxk(pda, K)] == ak_sorted[-K:]).all()

@pytest.mark.parametrize("dtype", ["int64", "uint64", "float64"])
def test_extrema_edge_cases(self, dtype):
edge_cases = make_np_edge_cases(dtype)
size = edge_cases.size // 2
# Due to #2754, we need to have replacement off to avoid all values = min/max(dtype)
npa = np.random.choice(edge_cases, edge_cases.size // 2, replace=False)
pda = ak.array(npa)
K = size // 2
np_sorted = np.sort(npa)

# extremas ignore nans
non_nan_sorted = np_sorted[~np.isnan(np_sorted)]

if non_nan_sorted.size >= K:
# compare minimums against first K elements from sorted array
assert np.allclose(ak.mink(pda, K).to_ndarray(), non_nan_sorted[:K], equal_nan=True)
# check for -1s to avoid oob due to #2754
arg_min_k = ak.argmink(pda, K)
if (arg_min_k != -1).all():
assert np.allclose(pda[arg_min_k].to_ndarray(), non_nan_sorted[:K], equal_nan=True)

# compare maximums against last K elements from sorted array
assert np.allclose(ak.maxk(pda, K).to_ndarray(), non_nan_sorted[-K:], equal_nan=True)
# check for -1s to avoid oob due to #2754
arg_max_k = ak.argmaxk(pda, K)
if (arg_max_k != -1).all():
assert np.allclose(pda[arg_max_k].to_ndarray(), non_nan_sorted[-K:], equal_nan=True)


@pytest.mark.parametrize("dtype", NUMERIC_TYPES)
def test_argmin_and_argmax(self, dtype):
np_arr = make_np_arrays(1000, dtype)
Expand Down
8 changes: 8 additions & 0 deletions PROTO_tests/tests/sort_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,11 @@ def test_error_handling(self, algo):
# Test attempt to sort Strings object, which is unsupported
with pytest.raises(TypeError):
ak.sort(ak.array([f"string {i}" for i in range(10)]), algo)

@pytest.mark.parametrize("algo", SortingAlgorithm)
def test_nan_sort(self, algo):
# Reproducer from #2703
neg_arr = np.array([-3.14, np.inf, np.nan, -np.inf, 3.14, 0.0, 3.14, -8])
pos_arr = np.array([3.14, np.inf, np.nan, np.inf, 7.7, 0.0, 3.14, 8])
for npa in neg_arr, pos_arr:
assert np.allclose(np.sort(npa), ak.sort(ak.array(npa), algo).to_ndarray(), equal_nan=True)
2 changes: 1 addition & 1 deletion src/AryUtil.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ module AryUtil

inline proc getBitWidth(a: [?aD] real): (int, bool) {
const bitWidth = numBits(real);
const negs = signbit(min reduce a);
const negs = | reduce signbit(a);
return (bitWidth, negs);
}

Expand Down
16 changes: 12 additions & 4 deletions tests/sort_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
from base_test import ArkoudaTest
from context import arkouda as ak

from arkouda.sorting import SortingAlgorithm

"""
Expand All @@ -24,10 +26,9 @@ def testSort(self):
for algo in SortingAlgorithm:
sorted_pda = ak.sort(pda, algo)
sorted_bi = ak.sort(shift_up, algo)
self.assertListEqual((sorted_bi - 2 ** 200).to_list(), sorted_pda.to_list())
self.assertListEqual((sorted_bi - 2**200).to_list(), sorted_pda.to_list())

def testBitBoundaryHardcode(self):

# test hardcoded 16/17-bit boundaries with and without negative values
a = ak.array([1, -1, 32767]) # 16 bit
b = ak.array([1, 0, 32768]) # 16 bit
Expand All @@ -47,7 +48,6 @@ def testBitBoundaryHardcode(self):
assert ak.is_sorted(ak.sort(f, algo))

def testBitBoundary(self):

# test 17-bit sort
L = -(2**15)
U = 2**16
Expand All @@ -56,7 +56,6 @@ def testBitBoundary(self):
assert ak.is_sorted(ak.sort(a, algo))

def testErrorHandling(self):

# Test RuntimeError from bool NotImplementedError
akbools = ak.randint(0, 1, 1000, dtype=ak.bool)
bools = ak.randint(0, 1, 1000, dtype=bool)
Expand All @@ -75,3 +74,12 @@ def testErrorHandling(self):
# Test attempt to sort Strings object, which is unsupported
with self.assertRaises(TypeError):
ak.sort(ak.array(["String {}".format(i) for i in range(0, 10)]), algo)

def test_nan_sort(self):
# Reproducer from #2703
neg_arr = np.array([-3.14, np.inf, np.nan, -np.inf, 3.14, 0.0, 3.14, -8])
pos_arr = np.array([3.14, np.inf, np.nan, np.inf, 7.7, 0.0, 3.14, 8])
for npa in neg_arr, pos_arr:
self.assertTrue(
np.allclose(np.sort(npa), ak.sort(ak.array(npa)).to_ndarray(), equal_nan=True)
)

0 comments on commit e574712

Please sign in to comment.