1
+ import functools
1
2
from typing import Any , Callable , Hashable , Iterable , Optional , Tuple , Union
2
3
3
- import dask .array as da
4
4
import numpy as np
5
5
from xarray import Dataset
6
6
7
+ import sgkit .distarray as da
7
8
from sgkit import variables
8
9
from sgkit .model import get_contigs , num_contigs
9
10
from sgkit .utils import conditional_merge_datasets , create_dataset
@@ -510,8 +511,15 @@ def window_statistic(
510
511
and len (window_stops ) == 1
511
512
and window_stops == np .array ([values .shape [0 ]])
512
513
):
514
+ out = da .map_blocks (
515
+ functools .partial (statistic , ** kwargs ),
516
+ values ,
517
+ dtype = dtype ,
518
+ chunks = values .chunks [1 :],
519
+ drop_axis = 0 ,
520
+ )
513
521
# call expand_dims to add back window dimension (size 1)
514
- return da .expand_dims (statistic ( values , ** kwargs ) , axis = 0 )
522
+ return da .expand_dims (out , axis = 0 )
515
523
516
524
window_lengths = window_stops - window_starts
517
525
depth = np .max (window_lengths ) # type: ignore[no-untyped-call]
@@ -536,10 +544,10 @@ def window_statistic(
536
544
537
545
chunk_offsets = _sizes_to_start_offsets (windows_per_chunk )
538
546
539
- def blockwise_moving_stat (x : ArrayLike , block_info : Any = None ) -> ArrayLike :
540
- if block_info is None or len ( block_info ) == 0 :
547
+ def blockwise_moving_stat (x : ArrayLike , block_id : Any = None ) -> ArrayLike :
548
+ if block_id is None :
541
549
return np .array ([])
542
- chunk_number = block_info [ 0 ][ "chunk-location" ] [0 ]
550
+ chunk_number = block_id [0 ]
543
551
chunk_offset_start = chunk_offsets [chunk_number ]
544
552
chunk_offset_stop = chunk_offsets [chunk_number + 1 ]
545
553
chunk_window_starts = rel_window_starts [chunk_offset_start :chunk_offset_stop ]
@@ -559,8 +567,9 @@ def blockwise_moving_stat(x: ArrayLike, block_info: Any = None) -> ArrayLike:
559
567
depth = {0 : depth }
560
568
# new chunks are same except in first axis
561
569
new_chunks = tuple ([tuple (windows_per_chunk )] + list (desired_chunks [1 :])) # type: ignore
562
- return values .map_overlap (
570
+ return da .map_overlap (
563
571
blockwise_moving_stat ,
572
+ values ,
564
573
dtype = dtype ,
565
574
chunks = new_chunks ,
566
575
depth = depth ,
0 commit comments