1
1
import builtins
2
2
import math
3
3
import numbers
4
+ from dataclasses import dataclass
4
5
from functools import partial
5
6
from itertools import product
6
7
from numbers import Integral , Number
25
26
from cubed .primitive .rechunk import rechunk as primitive_rechunk
26
27
from cubed .spec import spec_from_config
27
28
from cubed .storage .backend import open_backend_array
29
+ from cubed .types import T_RegularChunks , T_Shape
28
30
from cubed .utils import (
29
31
_concatenate2 ,
30
32
array_memory ,
33
+ array_size ,
31
34
get_item ,
32
35
offset_to_block_id ,
33
36
to_chunksize ,
@@ -484,6 +487,11 @@ def merged_chunk_len_for_indexer(ia, c):
484
487
if shape == x .shape :
485
488
# no op case (except possibly newaxis applied below)
486
489
out = x
490
+ elif array_size (shape ) == 0 :
491
+ # empty output case
492
+ from cubed .array_api .creation_functions import empty
493
+
494
+ out = empty (shape , dtype = x .dtype , spec = x .spec )
487
495
else :
488
496
dtype = x .dtype
489
497
chunks = tuple (
@@ -494,21 +502,68 @@ def merged_chunk_len_for_indexer(ia, c):
494
502
495
503
target_chunks = normalize_chunks (chunks , shape , dtype = dtype )
496
504
497
- # memory allocated by reading one chunk from input array
498
- # note that although the output chunk will overlap multiple input chunks, zarr will
499
- # read the chunks in series, reusing the buffer
500
- extra_projected_mem = x .chunkmem
505
+ if _is_chunk_aligned_selection (idx ):
506
+ # use general_blockwise, which allows more opportunities for optimization than map_direct
501
507
502
- out = map_direct (
503
- _read_index_chunk ,
504
- x ,
505
- shape = shape ,
506
- dtype = dtype ,
507
- chunks = target_chunks ,
508
- extra_projected_mem = extra_projected_mem ,
509
- target_chunks = target_chunks ,
510
- selection = selection ,
511
- )
508
+ from cubed .array_api .creation_functions import offsets_virtual_array
509
+
510
+ # general_blockwise doesn't support block_id, so emulate it ourselves
511
+ numblocks = tuple (map (len , target_chunks ))
512
+ offsets = offsets_virtual_array (numblocks , x .spec )
513
+
514
+ def key_function (out_key ):
515
+ out_coords = out_key [1 :]
516
+
517
+ # compute the selection on x required to get the relevant chunk for out_coords
518
+ in_sel = _target_chunk_selection (target_chunks , out_coords , selection )
519
+
520
+ # use a Zarr BasicIndexer to convert this to input coordinates
521
+ indexer = create_basic_indexer (
522
+ in_sel , x .zarray_maybe_lazy .shape , x .zarray_maybe_lazy .chunks
523
+ )
524
+
525
+ offset_in_key = ((offsets .name ,) + out_coords ,)
526
+ return (
527
+ tuple ((x .name ,) + chunk_coords for (chunk_coords , _ , _ ) in indexer )
528
+ + offset_in_key
529
+ )
530
+
531
+ # since selection is chunk-aligned, we know that we only read one block of x
532
+ num_input_blocks = (1 , 1 ) # x, offsets
533
+
534
+ out = general_blockwise (
535
+ _assemble_index_chunk ,
536
+ key_function ,
537
+ x ,
538
+ offsets ,
539
+ shapes = [shape ],
540
+ dtypes = [x .dtype ],
541
+ chunkss = [target_chunks ],
542
+ num_input_blocks = num_input_blocks ,
543
+ target_chunks = target_chunks ,
544
+ selection = selection ,
545
+ in_shape = x .shape ,
546
+ in_chunksize = x .chunksize ,
547
+ )
548
+ else :
549
+ # use map_direct, which can't be fused
550
+ # (note that it should be possible to re-write as general_blockwise with more work)
551
+
552
+ # memory allocated by reading one chunk from input array
553
+ # note that although the output chunk will overlap multiple input chunks, zarr will
554
+ # read the chunks in series, reusing the buffer
555
+ extra_projected_mem = x .chunkmem
556
+
557
+ out = map_direct (
558
+ _read_index_chunk ,
559
+ x ,
560
+ shape = shape ,
561
+ dtype = dtype ,
562
+ chunks = target_chunks ,
563
+ extra_projected_mem = extra_projected_mem ,
564
+ target_chunks = target_chunks ,
565
+ selection = selection ,
566
+ )
512
567
513
568
# merge chunks for any dims with step > 1 so they are
514
569
# the same size as the input (or slightly smaller due to rounding)
@@ -528,6 +583,72 @@ def merged_chunk_len_for_indexer(ia, c):
528
583
return out
529
584
530
585
586
+ def _is_chunk_aligned_selection (idx : ndindex .Tuple ):
587
+ return all (
588
+ isinstance (ia , ndindex .Integer )
589
+ or (
590
+ isinstance (ia , ndindex .Slice )
591
+ and ia .start == 0
592
+ and (ia .step is None or ia .step == 1 )
593
+ )
594
+ for ia in idx .args
595
+ )
596
+
597
+
598
def create_basic_indexer(selection, shape, chunks):
    """Create a Zarr ``BasicIndexer`` for a selection on a regularly chunked array.

    Parameters
    ----------
    selection
        The selection to apply, passed straight through to ``BasicIndexer``.
    shape
        Shape of the array being indexed.
    chunks
        Regular chunk shape of the array being indexed.

    Returns
    -------
    A ``BasicIndexer`` from whichever Zarr major version is installed.
    """
    # Compare the numeric major version rather than the first character of the
    # version string, so that a release other than 3.x (e.g. a future 4.x) is
    # not mistaken for Zarr v2. A non-numeric leading component falls through
    # to the v2 code path.
    major = zarr.__version__.split(".", 1)[0]
    if major.isdigit() and int(major) >= 3:
        from zarr.core.chunk_grids import RegularChunkGrid
        from zarr.core.indexing import BasicIndexer

        return BasicIndexer(selection, shape, RegularChunkGrid(chunk_shape=chunks))
    else:
        from zarr.indexing import BasicIndexer

        # Zarr v2 takes an array-like object instead of a shape and chunk grid,
        # so wrap the shape and chunks in a lightweight adaptor.
        return BasicIndexer(selection, ZarrArrayIndexingAdaptor(shape, chunks))
608
+
609
+
610
@dataclass
class ZarrArrayIndexingAdaptor:
    """Minimal stand-in for a Zarr array that carries only a shape and chunks.

    It is passed to the Zarr v2 ``BasicIndexer`` in place of a real array.
    The underscore-prefixed field names are presumably the attributes that
    indexer reads from its array argument -- confirm against the Zarr v2 source.
    """

    # shape of the array being indexed
    _shape: T_Shape
    # regular chunk shape of the array being indexed
    _chunks: T_RegularChunks

    @classmethod
    def from_zarr_array(cls, zarray):
        """Build an adaptor from the ``shape`` and ``chunks`` of ``zarray``."""
        return cls(_shape=zarray.shape, _chunks=zarray.chunks)
618
+
619
+
620
def _assemble_index_chunk(
    *arrs,
    target_chunks=None,
    selection=None,
    in_shape=None,
    in_chunksize=None,
):
    """Assemble one output chunk of an indexing operation from input blocks.

    Parameters
    ----------
    *arrs
        The input blocks needed for this output chunk, followed by a final
        array holding the offset that identifies the output block.
    target_chunks
        Chunks of the output array: one sequence of chunk lengths per dimension.
    selection
        The selection being applied to the input array.
    in_shape
        Shape of the input array.
    in_chunksize
        Chunk size of the input array.

    Returns
    -------
    A NumPy array with the shape reported by the indexer, filled from the
    input blocks.
    """
    # last array contains the offset for the block_id
    offset = int(arrs[-1])  # convert from 0-d array
    numblocks = tuple(map(len, target_chunks))
    block_id = offset_to_block_id(offset, numblocks)

    arrs = arrs[:-1]  # drop offset array

    # compute the selection on x required to get the relevant chunk for block_id
    in_sel = _target_chunk_selection(target_chunks, block_id, selection)

    # use a Zarr BasicIndexer to convert this to input coordinates
    indexer = create_basic_indexer(in_sel, in_shape, in_chunksize)

    shape = indexer.shape
    out = np.empty_like(arrs[0], shape=shape)

    if array_size(shape) > 0:
        for ai, projection in zip(arrs, indexer):
            # Read the projection fields by position instead of unpacking a
            # fixed number of them: Zarr v2 projections are
            # (chunk_coords, chunk_selection, out_selection), while Zarr v3
            # appends an is_complete_chunk field, which would break a
            # three-way unpack.
            chunk_select = projection[1]
            out_select = projection[2]
            out[out_select] = ai[chunk_select]

    return out
650
+
651
+
531
652
def _read_index_chunk (
532
653
x ,
533
654
* arrays ,
0 commit comments