From 877747bd43add591d6caa2075e5c26259f30bbff Mon Sep 17 00:00:00 2001 From: Ethan-DeBandi99 <16845933+Ethan-DeBandi99@users.noreply.github.com> Date: Mon, 6 Jun 2022 13:38:55 -0400 Subject: [PATCH] Closes #1452 - Support for `argmin` and `argmax` on `bool` values. (#1469) * Adding argmax/argmin processing for boolean arrays. Adding testing. Updating server messaging to support argmax/argmin for segment reductions. Updated Groupby to remove exception when bool pdarray passed for values. Added testing. * REbase cleanup * Correcting test case * Removing test cases that are no longer valid. --- arkouda/groupbyclass.py | 6 ++---- src/ReductionMsg.chpl | 16 ++++++++++++++++ tests/extrema_test.py | 12 ++++++++++++ tests/groupby_test.py | 18 ++++++++++++------ 4 files changed, 42 insertions(+), 10 deletions(-) diff --git a/arkouda/groupbyclass.py b/arkouda/groupbyclass.py index a599a74124..1c71f1267e 100644 --- a/arkouda/groupbyclass.py +++ b/arkouda/groupbyclass.py @@ -645,8 +645,7 @@ def argmin(self, values: pdarray) -> Tuple[groupable, pdarray]: >>> g.argmin(b) (array([2, 3, 4]), array([5, 4, 2])) """ - if values.dtype == bool: - raise TypeError("argmin is only supported for pdarrays of dtype float64, uint64, and int64") + return self.aggregate(values, "argmin") def argmax(self, values: pdarray) -> Tuple[groupable, pdarray]: @@ -695,8 +694,7 @@ def argmax(self, values: pdarray) -> Tuple[groupable, pdarray]: >>> g.argmax(b) (array([2, 3, 4]), array([9, 3, 2])) """ - if values.dtype == bool: - raise TypeError("argmax is only supported for pdarrays of dtype float64, uint64, and int64") + return self.aggregate(values, "argmax") def nunique(self, values: groupable) -> Tuple[groupable, pdarray]: diff --git a/src/ReductionMsg.chpl b/src/ReductionMsg.chpl index ccb4ed8d77..71ee20a400 100644 --- a/src/ReductionMsg.chpl +++ b/src/ReductionMsg.chpl @@ -241,6 +241,14 @@ module ReductionMsg if (| reduce e.a) { val = "True"; } else { val = "False"; } repMsg = "bool %s".format(val); } + when "argmax" { + var (maxVal, maxLoc) = maxloc reduce zip(e.a,e.aD); + repMsg = "int64 %i".format(maxLoc); + } + when "argmin" { + var (minVal, minLoc) = minloc reduce zip(e.a,e.aD); + repMsg = "int64 %i".format(minLoc); + } otherwise { var errorMsg = notImplementedError(pn,reductionop,gEnt.dtype); rmLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); @@ -485,6 +493,14 @@ module ReductionMsg var res = segMean(values.a, segments.a); st.addEntry(rname, new shared SymEntry(res)); } + when "argmin" { + var (vals, locs) = segArgmin(values.a, segments.a); + st.addEntry(rname, new shared SymEntry(locs)); + } + when "argmax" { + var (vals, locs) = segArgmax(values.a, segments.a); + st.addEntry(rname, new shared SymEntry(locs)); + } otherwise { var errorMsg = notImplementedError(pn,op,gVal.dtype); rmLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg); diff --git a/tests/extrema_test.py b/tests/extrema_test.py index d5b8f98062..b48717a521 100644 --- a/tests/extrema_test.py +++ b/tests/extrema_test.py @@ -172,3 +172,15 @@ def test_error_handling(self): with self.assertRaises(ValueError): ak.argmaxk(ak.array([]), 1) + +class ArgMinTest(ArkoudaTest): + def test_argmin(self): + np_arr = np.array([False, False, True, True, False]) + ak_arr = ak.array(np_arr) + self.assertTrue(np_arr.argmin() == ak_arr.argmin()) + +class ArgMaxTest(ArkoudaTest): + def test_argmax(self): + np_arr = np.array([False, False, True, True, False]) + ak_arr = ak.array(np_arr) + self.assertTrue(np_arr.argmax() == ak_arr.argmax()) diff --git a/tests/groupby_test.py b/tests/groupby_test.py index 7c5e12615b..777e283e0b 100755 --- a/tests/groupby_test.py +++ b/tests/groupby_test.py @@ -166,6 +166,18 @@ def test_groupby_on_two_levels(self): """ self.assertEqual(0, run_test(2, verbose)) + def test_argmax_argmin(self): + b = ak.array([True, False, True, True, False, True]) + x = ak.array([True, True, False, True, False, False]) + g = ak.GroupBy(x) + keys, locs = g.argmin(b) + self.assertListEqual(keys.to_ndarray().tolist(), [False, True]) + self.assertListEqual(locs.to_ndarray().tolist(), [4, 1]) + + keys, locs = g.argmax(b) + self.assertListEqual(keys.to_ndarray().tolist(), [False, True]) + self.assertListEqual(locs.to_ndarray().tolist(), [2, 0]) + def test_boolean_arrays(self): a = ak.array([True, False, True, True, False]) true_ct = a.sum() @@ -332,12 +344,6 @@ def test_error_handling(self): with self.assertRaises(TypeError): self.igb.max(ak.randint(0, 1, 10, dtype=bool)) - with self.assertRaises(TypeError): - self.igb.argmin(ak.randint(0, 1, 10, dtype=bool)) - - with self.assertRaises(TypeError): - self.igb.argmax(ak.randint(0, 1, 10, dtype=bool)) - def test_aggregate_strings(self): s = ak.array(["a", "b", "a", "b", "c"]) i = ak.arange(s.size)