@@ -2823,9 +2823,6 @@ def groupby_scan(
2823
2823
# nothing to do, no NaNs!
2824
2824
return array
2825
2825
2826
- is_bool_array = np .issubdtype (array .dtype , bool )
2827
- array = array .astype (np .int_ ) if is_bool_array else array
2828
-
2829
2826
if expected_groups is not None :
2830
2827
raise NotImplementedError ("Setting `expected_groups` and binning is not supported yet." )
2831
2828
expected_groups = _validate_expected_groups (nby , expected_groups )
@@ -2855,6 +2852,11 @@ def groupby_scan(
2855
2852
if array .dtype .kind in "Mm" :
2856
2853
cast_to = array .dtype
2857
2854
array = array .view (np .int64 )
2855
+ elif array .dtype .kind == "b" :
2856
+ array = array .view (np .int8 )
2857
+ cast_to = None
2858
+ if agg .preserves_dtype :
2859
+ cast_to = bool
2858
2860
else :
2859
2861
cast_to = None
2860
2862
@@ -2869,6 +2871,7 @@ def groupby_scan(
2869
2871
agg .dtype = np .result_type (array .dtype , np .uint )
2870
2872
else :
2871
2873
agg .dtype = array .dtype if dtype is None else dtype
2874
+ agg .identity = xrdtypes ._get_fill_value (agg .dtype , agg .identity )
2872
2875
2873
2876
(single_axis ,) = axis_ # type: ignore[misc]
2874
2877
# avoid some roundoff error when we can.
@@ -2887,7 +2890,7 @@ def groupby_scan(
2887
2890
2888
2891
if not has_dask :
2889
2892
final_state = chunk_scan (inp , axis = single_axis , agg = agg , dtype = agg .dtype )
2890
- result = _finalize_scan (final_state )
2893
+ result = _finalize_scan (final_state , dtype = agg . dtype )
2891
2894
else :
2892
2895
result = dask_groupby_scan (inp .array , inp .group_idx , axes = axis_ , agg = agg )
2893
2896
@@ -2940,9 +2943,9 @@ def _zip(group_idx: np.ndarray, array: np.ndarray) -> AlignedArrays:
2940
2943
return AlignedArrays (group_idx = group_idx , array = array )
2941
2944
2942
2945
2943
- def _finalize_scan (block : ScanState ) -> np .ndarray :
2946
+ def _finalize_scan (block : ScanState , dtype ) -> np .ndarray :
2944
2947
assert block .result is not None
2945
- return block .result .array
2948
+ return block .result .array . astype ( dtype , copy = False )
2946
2949
2947
2950
2948
2951
def dask_groupby_scan (array , by , axes : T_Axes , agg : Scan ) -> DaskArray :
@@ -2985,7 +2988,7 @@ def dask_groupby_scan(array, by, axes: T_Axes, agg: Scan) -> DaskArray:
2985
2988
)
2986
2989
2987
2990
# 3. Unzip and extract the final result array, discard groups
2988
- result = map_blocks (_finalize_scan , accumulated , dtype = agg .dtype )
2991
+ result = map_blocks (partial ( _finalize_scan , dtype = agg . dtype ) , accumulated , dtype = agg .dtype )
2989
2992
2990
2993
assert result .chunks == array .chunks
2991
2994
0 commit comments