@@ -17,6 +17,7 @@
     Callable,
     Literal,
     TypedDict,
+    TypeVar,
     Union,
     overload,
 )
@@ -96,6 +97,7 @@
 T_MethodOpt = None | Literal["map-reduce", "blockwise", "cohorts"]
 T_IsBins = Union[bool | Sequence[bool]]

+T = TypeVar("T")

 IntermediateDict = dict[Union[str, Callable], Any]
 FinalResultsDict = dict[str, Union["DaskArray", "CubedArray", np.ndarray]]
@@ -140,6 +142,10 @@ def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
     return result


+def identity(x: T) -> T:
+    return x
+
+
 def _issorted(arr: np.ndarray) -> bool:
     return bool((arr[:-1] <= arr[1:]).all())

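Aside (not part of the diff): the new `identity` helper, typed with the `TypeVar` added above, serves as the default `reindexer` for `subset_to_blocks` in the next hunk, so callers that pass nothing keep the old pass-through behavior. A minimal self-contained sketch of the pattern; the assertions are illustrative only:

```python
from typing import TypeVar

T = TypeVar("T")

def identity(x: T) -> T:
    # The TypeVar ties the return type to the argument type, so a no-op
    # reindexer preserves e.g. np.ndarray for type checkers.
    return x

assert identity(3) == 3
assert identity("blocks") == "blocks"
```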
@@ -1438,7 +1444,10 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:


 def subset_to_blocks(
-    array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
+    array: DaskArray,
+    flatblocks: Sequence[int],
+    blkshape: tuple[int] | None = None,
+    reindexer=identity,
 ) -> DaskArray:
     """
     Advanced indexing of .blocks such that we always get a regular array back.
@@ -1464,20 +1473,21 @@ def subset_to_blocks(
     index = _normalize_indexes(array, flatblocks, blkshape)

     if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
-        return array
+        return dask.array.map_blocks(reindexer, array, meta=array._meta)

     # The rest is copied from dask.array.core.py with slight modifications
     index = normalize_index(index, array.numblocks)
     index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)

-    name = "blocks-" + tokenize(array, index)
+    name = "groupby-cohort-" + tokenize(array, index)
     new_keys = array._key_array[index]

     squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index)
     chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed))

     keys = itertools.product(*(range(len(c)) for c in chunks))
-    layer: Graph = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
+    layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys}
+
     graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])

     return dask.array.Array(graph, name, chunks, meta=array)
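Context for the `layer` change (a sketch, not part of the diff): dask represents a task as a tuple whose head is a callable, so `(reindexer, key)` makes the scheduler call `reindexer` on the materialized block for `key`. Wrapping the subset keys this way fuses reindexing into the subsetting layer itself, replacing the separate `map_blocks(reindex_intermediates, ...)` pass; the early-return branch applies the same callable via `map_blocks` when no subsetting is needed. A self-contained illustration of the tuple-task convention, with the `inc` name purely hypothetical:

```python
import dask

def inc(x):
    # stand-in for `reindexer`: called on the value produced for key "a"
    return x + 1

# (callable, key) tuples are tasks: "a" is resolved first, then inc is
# applied to its value when "b" is computed
graph = {"a": 1, "b": (inc, "a")}
assert dask.get(graph, "b") == 2
```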
@@ -1651,26 +1661,26 @@ def dask_groupby_agg(

     elif method == "cohorts":
         assert chunks_cohorts
+        block_shape = array.blocks.shape[-len(axis) :]
+
         reduced_ = []
         groups_ = []
         for blks, cohort in chunks_cohorts.items():
-            index = pd.Index(cohort)
-            subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
-            reindexed = dask.array.map_blocks(
-                reindex_intermediates, subset, agg, index, meta=subset._meta
-            )
+            cohort_index = pd.Index(cohort)
+            reindexer = partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
+            reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer)
             # now that we have reindexed, we can set reindex=True explicitly
             reduced_.append(
                 tree_reduce(
                     reindexed,
                     combine=partial(combine, agg=agg, reindex=True),
-                    aggregate=partial(aggregate, expected_groups=index, reindex=True),
+                    aggregate=partial(aggregate, expected_groups=cohort_index, reindex=True),
                 )
             )
             # This is done because pandas promotes to 64-bit types when an Index is created
             # So we use the index to generate the return value for consistency with "map-reduce"
             # This is important on windows
-            groups_.append(index.values)
+            groups_.append(cohort_index.values)

         reduced = dask.array.concatenate(reduced_, axis=-1)
         groups = (np.concatenate(groups_),)
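For the cohorts loop (an illustrative sketch, not runnable flox code): `functools.partial` freezes `agg` and the cohort's `unique_groups` into a one-argument callable, which is exactly the shape `subset_to_blocks` expects for its new `reindexer` parameter. A self-contained toy, where `toy_reindex` is a hypothetical stand-in for `reindex_intermediates`:

```python
from functools import partial

import numpy as np
import pandas as pd

def toy_reindex(block: np.ndarray, *, unique_groups: pd.Index) -> np.ndarray:
    # hypothetical stand-in: pad the group axis out to len(unique_groups)
    # so every block in a cohort has the same width before tree_reduce
    out = np.zeros(len(unique_groups), dtype=block.dtype)
    out[: block.size] = block
    return out

cohort_index = pd.Index([0, 3, 7])
reindexer = partial(toy_reindex, unique_groups=cohort_index)
print(reindexer(np.array([1.0, 2.0])))  # [1. 2. 0.]
```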