From e479ed03f605438031281c88c70ea094e94c89e9 Mon Sep 17 00:00:00 2001 From: reuster986 Date: Tue, 15 Mar 2022 16:07:30 -0400 Subject: [PATCH] Consolidate and standardize grouping API (#1212) * formalized grouping API * reorganize * fix typing --- arkouda/categorical.py | 9 ++++ arkouda/groupbyclass.py | 111 +++++++++++++++++++--------------------- arkouda/pdarrayclass.py | 12 +++++ arkouda/strings.py | 10 ++++ 4 files changed, 84 insertions(+), 58 deletions(-) diff --git a/arkouda/categorical.py b/arkouda/categorical.py index 2efacd1914..b66256b405 100644 --- a/arkouda/categorical.py +++ b/arkouda/categorical.py @@ -495,6 +495,15 @@ def group(self) -> pdarray: else: return self.permutation + def _get_grouping_keys(self): + ''' + Private method for generating grouping keys used by GroupBy. + + API: this method must be defined by all groupable arrays, and it + must return a list of arrays that can be (co)argsorted. + ''' + return [self.codes] + def argsort(self): #__doc__ = argsort.__doc__ idxperm = argsort(self.categories) diff --git a/arkouda/groupbyclass.py b/arkouda/groupbyclass.py index 4ab47d79db..ec78377feb 100644 --- a/arkouda/groupbyclass.py +++ b/arkouda/groupbyclass.py @@ -59,7 +59,7 @@ class GroupBy: Parameters ---------- - keys : (list of) pdarray, int64, Strings, or Categorical + keys : (list of) pdarray, Strings, or Categorical The array to group by value, or if list, the column arrays to group by row assume_sorted : bool If True, assume keys is already sorted (Default: False) @@ -88,7 +88,17 @@ class GroupBy: Notes ----- - Only accepts (list of) pdarrays of int64 dtype, Strings, or Categorical. + Integral pdarrays, Strings, and Categoricals are natively supported, but + float64 and bool arrays are not. + + For a user-defined class to be groupable, it must inherit from pdarray + and define or overload the grouping API: + 1) a ._get_grouping_keys() method that returns a list of pdarrays + that can be (co)argsorted. + 2) (Optional) a .group() method that returns the permutation that + groups the array + If the input is a single array with a .group() method defined, method 2 + will be used; otherwise, method 1 will be used. """ Reductions = GROUPBY_REDUCTION_TYPES @@ -100,38 +110,45 @@ def __init__(self, keys: groupable, self.assume_sorted = assume_sorted self.hash_strings = hash_strings self.keys : groupable + self.permutation : pdarray - if isinstance(keys, pdarray): - if keys.dtype != int64 and keys.dtype != uint64: - raise TypeError('GroupBy only supports pdarrays with a dtype int64 or uint64') - self.keys = cast(pdarray, keys) + # Get all grouping keys, even if not required for finding permutation + # They will be required later for finding segment boundaries + if hasattr(keys, "_get_grouping_keys"): + # Single groupable array self.nkeys = 1 - self.size = cast(int, keys.size) - if assume_sorted: - self.permutation = cast(pdarray, arange(self.size)) - else: - self.permutation = cast(pdarray, argsort(keys)) - elif hasattr(keys, "group"): # for Strings or Categorical - self.nkeys = 1 - self.keys = cast(Union[Strings,Categorical],keys) - self.size = cast(int, self.keys.size) # type: ignore - if assume_sorted: - self.permutation = cast(pdarray,arange(self.size)) - else: - self.permutation = cast(Union[Strings, Categorical],keys).group() + self.keys = cast(groupable_element_type, keys) + self.size = cast(int, self.keys.size) + self._grouping_keys = self.keys._get_grouping_keys() else: - self.keys = cast(Sequence[groupable_element_type],keys) - self.nkeys = len(keys) - self.size = cast(int,keys[0].size) # type: ignore - for k in keys: + # Sequence of groupable arrays + # Because of type checking, this is the only other possibility + self.keys = cast(Sequence[groupable_element_type], keys) + self.nkeys = len(self.keys) + self.size = cast(int, self.keys[0].size) + self._grouping_keys = [] + for k in self.keys: if k.size != self.size: raise ValueError("Key arrays must all be same size") - if assume_sorted: - self.permutation = cast(pdarray, arange(self.size)) - else: - self.permutation = cast(pdarray, coargsort(cast(Sequence[pdarray],keys))) - - # self.permuted_keys = self.keys[self.permutation] + if not hasattr(k, "_get_grouping_keys"): + # Type checks should ensure we never get here + raise TypeError("{} does not support grouping".format(type(k))) + self._grouping_keys.extend(cast(list, k._get_grouping_keys())) + # Get permutation + if assume_sorted: + # Permutation is identity + self.permutation = cast(pdarray, arange(self.size)) + elif hasattr(self.keys, "group"): + # If an object wants to group itself (e.g. Categoricals), + # let it set the permutation + perm = self.keys.group() # type: ignore + self.permutation = cast(pdarray, perm) + elif len(self._grouping_keys) == 1: + self.permutation = cast(pdarray, argsort(self._grouping_keys[0])) + else: + self.permutation = cast(pdarray, coargsort(self._grouping_keys)) + + # Finally, get segment offsets and unique keys self.find_segments() def find_segments(self) -> None: @@ -140,41 +157,17 @@ def find_segments(self) -> None: if self.nkeys == 1: # for Categorical + # Most categoricals already store segments and unique keys if hasattr(self.keys, 'segments') and cast(Categorical, self.keys).segments is not None: self.unique_keys = cast(Categorical, self.keys).categories self.segments = cast(pdarray, cast(Categorical, self.keys).segments) self.ngroups = self.unique_keys.size return - else: - mykeys = [self.keys] - else: - mykeys = cast(List[pdarray], self.keys) # type: ignore - keyobjs : List[groupable_element_type] = [] # needed to maintain obj refs esp for h1 and h2 in the strings case - keynames = [] - keytypes = [] - effectiveKeys = self.nkeys - for k in mykeys: - if isinstance(k, Strings): - if self.hash_strings: - h1, h2 = k.hash() - keyobjs.extend([h1,h2]) - keynames.extend([h1.name, h2.name]) - keytypes.extend([h1.objtype, h2.objtype]) - effectiveKeys += 1 - else: - keyobjs.append(k) - keynames.append(k.entry.name) - keytypes.append(k.objtype) - # for Categorical - elif hasattr(k, 'codes'): - keyobjs.append(cast(Categorical, k)) - keynames.append(cast(Categorical,k).codes.name) - keytypes.append(cast(Categorical,k).codes.objtype) - elif isinstance(k, pdarray): - keyobjs.append(k) - keynames.append(k.name) - keytypes.append(k.objtype) + + keynames = [k.name for k in self._grouping_keys] + keytypes = [k.objtype for k in self._grouping_keys] + effectiveKeys = len(self._grouping_keys) args = "{} {:n} {} {}".format(self.permutation.name, effectiveKeys, ' '.join(keynames), @@ -192,6 +185,8 @@ def find_segments(self) -> None: self.unique_keys = cast(groupable, [k[unique_key_indices] for k in self.keys]) self.ngroups = self.unique_keys[0].size + # Free up memory, because _grouping_keys are not user-facing and no longer needed + del self._grouping_keys def count(self) -> Tuple[groupable,pdarray]: diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index 003f806a04..580bc57f20 100755 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -1298,6 +1298,18 @@ class method to return a pdarray attached to the registered name in the arkouda """ return attach_pdarray(user_defined_name) + def _get_grouping_keys(self) -> List[pdarray]: + ''' + Private method for generating grouping keys used by GroupBy. + + API: this method must be defined by all groupable arrays, and it + must return a list of arrays that can be (co)argsorted. + ''' + if self.dtype not in (akint64, akuint64): + raise TypeError("Grouping numeric data is only supported on integral types.") + # Integral pdarrays are their own grouping keys + return [self] + #end pdarray class def # creates pdarray object diff --git a/arkouda/strings.py b/arkouda/strings.py index eae060cf48..a9438ce060 100755 --- a/arkouda/strings.py +++ b/arkouda/strings.py @@ -1183,6 +1183,15 @@ def group(self) -> pdarray: args = "{} {}".format(self.objtype, self.entry.name) return create_pdarray(generic_msg(cmd=cmd,args=args)) + def _get_grouping_keys(self) -> List[pdarray]: + ''' + Private method for generating grouping keys used by GroupBy. + + API: this method must be defined by all groupable arrays, and it + must return a list of arrays that can be (co)argsorted. + ''' + return list(self.hash()) + def to_ndarray(self) -> np.ndarray: """ Convert the array to a np.ndarray, transferring array data from the @@ -1559,3 +1568,4 @@ def unregister_strings_by_name(user_defined_name : str) -> None: register, unregister, attach, is_registered """ unregister_pdarray_by_name(user_defined_name) +