Skip to content

Commit

Permalink
Closes #1452 - Support for argmin and argmax on bool values. (#…
Browse files Browse the repository at this point in the history
…1469)

* Adding argmax/argmin processing for boolean arrays. Adding testing.

Updating server messaging to support argmax/argmin for segment reductions. Updated GroupBy to remove the exception raised when a bool pdarray is passed for values. Added testing.

* Rebase cleanup

* Correcting test case

* Removing test cases that are no longer valid.
  • Loading branch information
Ethan-DeBandi99 committed Jun 6, 2022
1 parent 8db89d0 commit 877747b
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 10 deletions.
6 changes: 2 additions & 4 deletions arkouda/groupbyclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,8 +645,7 @@ def argmin(self, values: pdarray) -> Tuple[groupable, pdarray]:
>>> g.argmin(b)
(array([2, 3, 4]), array([5, 4, 2]))
"""
if values.dtype == bool:
raise TypeError("argmin is only supported for pdarrays of dtype float64, uint64, and int64")

return self.aggregate(values, "argmin")

def argmax(self, values: pdarray) -> Tuple[groupable, pdarray]:
Expand Down Expand Up @@ -695,8 +694,7 @@ def argmax(self, values: pdarray) -> Tuple[groupable, pdarray]:
>>> g.argmax(b)
(array([2, 3, 4]), array([9, 3, 2]))
"""
if values.dtype == bool:
raise TypeError("argmax is only supported for pdarrays of dtype float64, uint64, and int64")

return self.aggregate(values, "argmax")

def nunique(self, values: groupable) -> Tuple[groupable, pdarray]:
Expand Down
16 changes: 16 additions & 0 deletions src/ReductionMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,14 @@ module ReductionMsg
if (| reduce e.a) { val = "True"; } else { val = "False"; }
repMsg = "bool %s".format(val);
}
when "argmax" {
var (maxVal, maxLoc) = maxloc reduce zip(e.a,e.aD);
repMsg = "int64 %i".format(maxLoc);
}
when "argmin" {
var (minVal, minLoc) = minloc reduce zip(e.a,e.aD);
repMsg = "int64 %i".format(minLoc);
}
otherwise {
var errorMsg = notImplementedError(pn,reductionop,gEnt.dtype);
rmLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
Expand Down Expand Up @@ -485,6 +493,14 @@ module ReductionMsg
var res = segMean(values.a, segments.a);
st.addEntry(rname, new shared SymEntry(res));
}
when "argmin" {
var (vals, locs) = segArgmin(values.a, segments.a);
st.addEntry(rname, new shared SymEntry(locs));
}
when "argmax" {
var (vals, locs) = segArgmax(values.a, segments.a);
st.addEntry(rname, new shared SymEntry(locs));
}
otherwise {
var errorMsg = notImplementedError(pn,op,gVal.dtype);
rmLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
Expand Down
12 changes: 12 additions & 0 deletions tests/extrema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,15 @@ def test_error_handling(self):

with self.assertRaises(ValueError):
ak.argmaxk(ak.array([]), 1)

class ArgMinTest(ArkoudaTest):
    """Check ``argmin`` on a boolean Arkouda array against NumPy."""

    def test_argmin(self):
        """``argmin`` of a mixed True/False array must match NumPy's index."""
        np_arr = np.array([False, False, True, True, False])
        ak_arr = ak.array(np_arr)
        # assertEqual (instead of assertTrue on ``==``) reports both indices
        # when the comparison fails, which makes a mismatch diagnosable.
        self.assertEqual(np_arr.argmin(), ak_arr.argmin())

class ArgMaxTest(ArkoudaTest):
    """Check ``argmax`` on a boolean Arkouda array against NumPy."""

    def test_argmax(self):
        """``argmax`` of a mixed True/False array must match NumPy's index."""
        np_arr = np.array([False, False, True, True, False])
        ak_arr = ak.array(np_arr)
        # assertEqual (instead of assertTrue on ``==``) reports both indices
        # when the comparison fails, which makes a mismatch diagnosable.
        self.assertEqual(np_arr.argmax(), ak_arr.argmax())
18 changes: 12 additions & 6 deletions tests/groupby_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,18 @@ def test_groupby_on_two_levels(self):
"""
self.assertEqual(0, run_test(2, verbose))

def test_argmax_argmin(self):
    """Group boolean values by boolean keys and check per-group argmin/argmax."""
    values = ak.array([True, False, True, True, False, True])
    group_keys = ak.array([True, True, False, True, False, False])
    grouping = ak.GroupBy(group_keys)

    # The expected locations below are positions in the full input array
    # (e.g. 4 exceeds the size of either group), not offsets within a group.
    min_keys, min_locs = grouping.argmin(values)
    self.assertListEqual(min_keys.to_ndarray().tolist(), [False, True])
    self.assertListEqual(min_locs.to_ndarray().tolist(), [4, 1])

    max_keys, max_locs = grouping.argmax(values)
    self.assertListEqual(max_keys.to_ndarray().tolist(), [False, True])
    self.assertListEqual(max_locs.to_ndarray().tolist(), [2, 0])

def test_boolean_arrays(self):
a = ak.array([True, False, True, True, False])
true_ct = a.sum()
Expand Down Expand Up @@ -332,12 +344,6 @@ def test_error_handling(self):
with self.assertRaises(TypeError):
self.igb.max(ak.randint(0, 1, 10, dtype=bool))

with self.assertRaises(TypeError):
self.igb.argmin(ak.randint(0, 1, 10, dtype=bool))

with self.assertRaises(TypeError):
self.igb.argmax(ak.randint(0, 1, 10, dtype=bool))

def test_aggregate_strings(self):
s = ak.array(["a", "b", "a", "b", "c"])
i = ak.arange(s.size)
Expand Down

0 comments on commit 877747b

Please sign in to comment.