Skip to content

Commit

Permalink
Consolidate and standardize grouping API (#1212)
Browse files Browse the repository at this point in the history
* formalized grouping API

* reorganize

* fix typing
  • Loading branch information
reuster986 committed Mar 15, 2022
1 parent 960c6d2 commit e479ed0
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 58 deletions.
9 changes: 9 additions & 0 deletions arkouda/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,15 @@ def group(self) -> pdarray:
else:
return self.permutation

def _get_grouping_keys(self):
'''
Private method for generating grouping keys used by GroupBy.
API: this method must be defined by all groupable arrays, and it
must return a list of arrays that can be (co)argsorted.
'''
return [self.codes]

def argsort(self):
#__doc__ = argsort.__doc__
idxperm = argsort(self.categories)
Expand Down
111 changes: 53 additions & 58 deletions arkouda/groupbyclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class GroupBy:
Parameters
----------
keys : (list of) pdarray, int64, Strings, or Categorical
keys : (list of) pdarray, Strings, or Categorical
The array to group by value, or if list, the column arrays to group by row
assume_sorted : bool
If True, assume keys is already sorted (Default: False)
Expand Down Expand Up @@ -88,7 +88,17 @@ class GroupBy:
Notes
-----
Only accepts (list of) pdarrays of int64 dtype, Strings, or Categorical.
Integral pdarrays, Strings, and Categoricals are natively supported, but
float64 and bool arrays are not.
For a user-defined class to be groupable, it must inherit from pdarray
and define or overload the grouping API:
1) a ._get_grouping_keys() method that returns a list of pdarrays
that can be (co)argsorted.
2) (Optional) a .group() method that returns the permutation that
groups the array
If the input is a single array with a .group() method defined, method 2
will be used; otherwise, method 1 will be used.
"""
Reductions = GROUPBY_REDUCTION_TYPES
Expand All @@ -100,38 +110,45 @@ def __init__(self, keys: groupable,
self.assume_sorted = assume_sorted
self.hash_strings = hash_strings
self.keys : groupable
self.permutation : pdarray

if isinstance(keys, pdarray):
if keys.dtype != int64 and keys.dtype != uint64:
raise TypeError('GroupBy only supports pdarrays with a dtype int64 or uint64')
self.keys = cast(pdarray, keys)
# Get all grouping keys, even if not required for finding permutation
# They will be required later for finding segment boundaries
if hasattr(keys, "_get_grouping_keys"):
# Single groupable array
self.nkeys = 1
self.size = cast(int, keys.size)
if assume_sorted:
self.permutation = cast(pdarray, arange(self.size))
else:
self.permutation = cast(pdarray, argsort(keys))
elif hasattr(keys, "group"): # for Strings or Categorical
self.nkeys = 1
self.keys = cast(Union[Strings,Categorical],keys)
self.size = cast(int, self.keys.size) # type: ignore
if assume_sorted:
self.permutation = cast(pdarray,arange(self.size))
else:
self.permutation = cast(Union[Strings, Categorical],keys).group()
self.keys = cast(groupable_element_type, keys)
self.size = cast(int, self.keys.size)
self._grouping_keys = self.keys._get_grouping_keys()
else:
self.keys = cast(Sequence[groupable_element_type],keys)
self.nkeys = len(keys)
self.size = cast(int,keys[0].size) # type: ignore
for k in keys:
# Sequence of groupable arrays
# Because of type checking, this is the only other possibility
self.keys = cast(Sequence[groupable_element_type], keys)
self.nkeys = len(self.keys)
self.size = cast(int, self.keys[0].size)
self._grouping_keys = []
for k in self.keys:
if k.size != self.size:
raise ValueError("Key arrays must all be same size")
if assume_sorted:
self.permutation = cast(pdarray, arange(self.size))
else:
self.permutation = cast(pdarray, coargsort(cast(Sequence[pdarray],keys)))

# self.permuted_keys = self.keys[self.permutation]
if not hasattr(k, "_get_grouping_keys"):
# Type checks should ensure we never get here
raise TypeError("{} does not support grouping".format(type(k)))
self._grouping_keys.extend(cast(list, k._get_grouping_keys()))
# Get permutation
if assume_sorted:
# Permutation is identity
self.permutation = cast(pdarray, arange(self.size))
elif hasattr(self.keys, "group"):
# If an object wants to group itself (e.g. Categoricals),
# let it set the permutation
perm = self.keys.group() # type: ignore
self.permutation = cast(pdarray, perm)
elif len(self._grouping_keys) == 1:
self.permutation = cast(pdarray, argsort(self._grouping_keys[0]))
else:
self.permutation = cast(pdarray, coargsort(self._grouping_keys))

# Finally, get segment offsets and unique keys
self.find_segments()

def find_segments(self) -> None:
Expand All @@ -140,41 +157,17 @@ def find_segments(self) -> None:

if self.nkeys == 1:
# for Categorical
# Most categoricals already store segments and unique keys
if hasattr(self.keys, 'segments') and cast(Categorical,
self.keys).segments is not None:
self.unique_keys = cast(Categorical, self.keys).categories
self.segments = cast(pdarray, cast(Categorical, self.keys).segments)
self.ngroups = self.unique_keys.size
return
else:
mykeys = [self.keys]
else:
mykeys = cast(List[pdarray], self.keys) # type: ignore
keyobjs : List[groupable_element_type] = [] # needed to maintain obj refs esp for h1 and h2 in the strings case
keynames = []
keytypes = []
effectiveKeys = self.nkeys
for k in mykeys:
if isinstance(k, Strings):
if self.hash_strings:
h1, h2 = k.hash()
keyobjs.extend([h1,h2])
keynames.extend([h1.name, h2.name])
keytypes.extend([h1.objtype, h2.objtype])
effectiveKeys += 1
else:
keyobjs.append(k)
keynames.append(k.entry.name)
keytypes.append(k.objtype)
# for Categorical
elif hasattr(k, 'codes'):
keyobjs.append(cast(Categorical, k))
keynames.append(cast(Categorical,k).codes.name)
keytypes.append(cast(Categorical,k).codes.objtype)
elif isinstance(k, pdarray):
keyobjs.append(k)
keynames.append(k.name)
keytypes.append(k.objtype)

keynames = [k.name for k in self._grouping_keys]
keytypes = [k.objtype for k in self._grouping_keys]
effectiveKeys = len(self._grouping_keys)
args = "{} {:n} {} {}".format(self.permutation.name,
effectiveKeys,
' '.join(keynames),
Expand All @@ -192,6 +185,8 @@ def find_segments(self) -> None:
self.unique_keys = cast(groupable,
[k[unique_key_indices] for k in self.keys])
self.ngroups = self.unique_keys[0].size
# Free up memory, because _grouping_keys are not user-facing and no longer needed
del self._grouping_keys


def count(self) -> Tuple[groupable,pdarray]:
Expand Down
12 changes: 12 additions & 0 deletions arkouda/pdarrayclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,18 @@ class method to return a pdarray attached to the registered name in the arkouda
"""
return attach_pdarray(user_defined_name)

def _get_grouping_keys(self) -> List[pdarray]:
'''
Private method for generating grouping keys used by GroupBy.
API: this method must be defined by all groupable arrays, and it
must return a list of arrays that can be (co)argsorted.
'''
if self.dtype not in (akint64, akuint64):
raise TypeError("Grouping numeric data is only supported on integral types.")
# Integral pdarrays are their own grouping keys
return [self]

#end pdarray class def

# creates pdarray object
Expand Down
10 changes: 10 additions & 0 deletions arkouda/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,6 +1183,15 @@ def group(self) -> pdarray:
args = "{} {}".format(self.objtype, self.entry.name)
return create_pdarray(generic_msg(cmd=cmd,args=args))

def _get_grouping_keys(self) -> List[pdarray]:
'''
Private method for generating grouping keys used by GroupBy.
API: this method must be defined by all groupable arrays, and it
must return a list of arrays that can be (co)argsorted.
'''
return list(self.hash())

def to_ndarray(self) -> np.ndarray:
"""
Convert the array to a np.ndarray, transferring array data from the
Expand Down Expand Up @@ -1559,3 +1568,4 @@ def unregister_strings_by_name(user_defined_name : str) -> None:
register, unregister, attach, is_registered
"""
unregister_pdarray_by_name(user_defined_name)

0 comments on commit e479ed0

Please sign in to comment.