Skip to content

Commit f347d74

Browse files
authored
Use general_blockwise in case of chunk-aligned selections in index (#586)
This allows the optimizer to perform fusion, which is not possible with the existing `map_direct` implementation.
1 parent d0abfba commit f347d74

File tree

3 files changed

+151
-15
lines changed

3 files changed

+151
-15
lines changed

cubed/core/ops.py

Lines changed: 135 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import builtins
22
import math
33
import numbers
4+
from dataclasses import dataclass
45
from functools import partial
56
from itertools import product
67
from numbers import Integral, Number
@@ -25,9 +26,11 @@
2526
from cubed.primitive.rechunk import rechunk as primitive_rechunk
2627
from cubed.spec import spec_from_config
2728
from cubed.storage.backend import open_backend_array
29+
from cubed.types import T_RegularChunks, T_Shape
2830
from cubed.utils import (
2931
_concatenate2,
3032
array_memory,
33+
array_size,
3134
get_item,
3235
offset_to_block_id,
3336
to_chunksize,
@@ -484,6 +487,11 @@ def merged_chunk_len_for_indexer(ia, c):
484487
if shape == x.shape:
485488
# no op case (except possibly newaxis applied below)
486489
out = x
490+
elif array_size(shape) == 0:
491+
# empty output case
492+
from cubed.array_api.creation_functions import empty
493+
494+
out = empty(shape, dtype=x.dtype, spec=x.spec)
487495
else:
488496
dtype = x.dtype
489497
chunks = tuple(
@@ -494,21 +502,68 @@ def merged_chunk_len_for_indexer(ia, c):
494502

495503
target_chunks = normalize_chunks(chunks, shape, dtype=dtype)
496504

497-
# memory allocated by reading one chunk from input array
498-
# note that although the output chunk will overlap multiple input chunks, zarr will
499-
# read the chunks in series, reusing the buffer
500-
extra_projected_mem = x.chunkmem
505+
if _is_chunk_aligned_selection(idx):
506+
# use general_blockwise, which allows more opportunities for optimization than map_direct
501507

502-
out = map_direct(
503-
_read_index_chunk,
504-
x,
505-
shape=shape,
506-
dtype=dtype,
507-
chunks=target_chunks,
508-
extra_projected_mem=extra_projected_mem,
509-
target_chunks=target_chunks,
510-
selection=selection,
511-
)
508+
from cubed.array_api.creation_functions import offsets_virtual_array
509+
510+
# general_blockwise doesn't support block_id, so emulate it ourselves
511+
numblocks = tuple(map(len, target_chunks))
512+
offsets = offsets_virtual_array(numblocks, x.spec)
513+
514+
def key_function(out_key):
515+
out_coords = out_key[1:]
516+
517+
# compute the selection on x required to get the relevant chunk for out_coords
518+
in_sel = _target_chunk_selection(target_chunks, out_coords, selection)
519+
520+
# use a Zarr BasicIndexer to convert this to input coordinates
521+
indexer = create_basic_indexer(
522+
in_sel, x.zarray_maybe_lazy.shape, x.zarray_maybe_lazy.chunks
523+
)
524+
525+
offset_in_key = ((offsets.name,) + out_coords,)
526+
return (
527+
tuple((x.name,) + chunk_coords for (chunk_coords, _, _) in indexer)
528+
+ offset_in_key
529+
)
530+
531+
# since selection is chunk-aligned, we know that we only read one block of x
532+
num_input_blocks = (1, 1) # x, offsets
533+
534+
out = general_blockwise(
535+
_assemble_index_chunk,
536+
key_function,
537+
x,
538+
offsets,
539+
shapes=[shape],
540+
dtypes=[x.dtype],
541+
chunkss=[target_chunks],
542+
num_input_blocks=num_input_blocks,
543+
target_chunks=target_chunks,
544+
selection=selection,
545+
in_shape=x.shape,
546+
in_chunksize=x.chunksize,
547+
)
548+
else:
549+
# use map_direct, which can't be fused
550+
# (note that it should be possible to re-write as general_blockwise with more work)
551+
552+
# memory allocated by reading one chunk from input array
553+
# note that although the output chunk will overlap multiple input chunks, zarr will
554+
# read the chunks in series, reusing the buffer
555+
extra_projected_mem = x.chunkmem
556+
557+
out = map_direct(
558+
_read_index_chunk,
559+
x,
560+
shape=shape,
561+
dtype=dtype,
562+
chunks=target_chunks,
563+
extra_projected_mem=extra_projected_mem,
564+
target_chunks=target_chunks,
565+
selection=selection,
566+
)
512567

513568
# merge chunks for any dims with step > 1 so they are
514569
# the same size as the input (or slightly smaller due to rounding)
@@ -528,6 +583,72 @@ def merged_chunk_len_for_indexer(ia, c):
528583
return out
529584

530585

586+
def _is_chunk_aligned_selection(idx: ndindex.Tuple):
587+
return all(
588+
isinstance(ia, ndindex.Integer)
589+
or (
590+
isinstance(ia, ndindex.Slice)
591+
and ia.start == 0
592+
and (ia.step is None or ia.step == 1)
593+
)
594+
for ia in idx.args
595+
)
596+
597+
598+
def create_basic_indexer(selection, shape, chunks):
599+
if zarr.__version__[0] == "3":
600+
from zarr.core.chunk_grids import RegularChunkGrid
601+
from zarr.core.indexing import BasicIndexer
602+
603+
return BasicIndexer(selection, shape, RegularChunkGrid(chunk_shape=chunks))
604+
else:
605+
from zarr.indexing import BasicIndexer
606+
607+
return BasicIndexer(selection, ZarrArrayIndexingAdaptor(shape, chunks))
608+
609+
610+
@dataclass
611+
class ZarrArrayIndexingAdaptor:
612+
_shape: T_Shape
613+
_chunks: T_RegularChunks
614+
615+
@classmethod
616+
def from_zarr_array(cls, zarray):
617+
return cls(zarray.shape, zarray.chunks)
618+
619+
620+
def _assemble_index_chunk(
621+
*arrs,
622+
target_chunks=None,
623+
selection=None,
624+
in_shape=None,
625+
in_chunksize=None,
626+
):
627+
# last array contains the offset for the block_id
628+
offset = int(arrs[-1]) # convert from 0-d array
629+
numblocks = tuple(map(len, target_chunks))
630+
block_id = offset_to_block_id(offset, numblocks)
631+
632+
arrs = arrs[:-1] # drop offset array
633+
634+
# compute the selection on x required to get the relevant chunk for out_coords
635+
out_coords = block_id
636+
in_sel = _target_chunk_selection(target_chunks, out_coords, selection)
637+
638+
# use a Zarr BasicIndexer to convert this to input coordinates
639+
indexer = create_basic_indexer(in_sel, in_shape, in_chunksize)
640+
641+
shape = indexer.shape
642+
out = np.empty_like(arrs[0], shape=shape)
643+
644+
if array_size(shape) > 0:
645+
_, lchunk_selection, lout_selection = zip(*indexer)
646+
for ai, chunk_select, out_select in zip(arrs, lchunk_selection, lout_selection):
647+
out[out_select] = ai[chunk_select]
648+
649+
return out
650+
651+
531652
def _read_index_chunk(
532653
x,
533654
*arrays,

cubed/tests/test_mem_utilization.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ def test_index(tmp_path, spec, executor):
8484
run_operation(tmp_path, executor, "index", b)
8585

8686

87+
@pytest.mark.slow
88+
def test_index_chunk_aligned(tmp_path, spec, executor):
89+
a = cubed.random.random(
90+
(10000, 10000), chunks=(5000, 5000), spec=spec
91+
) # 200MB chunks
92+
b = a[0:5000, :]
93+
run_operation(tmp_path, executor, "index_chunk_aligned", b)
94+
95+
8796
@pytest.mark.slow
8897
def test_index_step(tmp_path, spec, executor):
8998
a = cubed.random.random(

cubed/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@
1010
from functools import partial
1111
from itertools import islice
1212
from math import prod
13-
from operator import add
13+
from operator import add, mul
1414
from pathlib import Path
1515
from posixpath import join
1616
from typing import Dict, Tuple, Union, cast
1717
from urllib.parse import quote, unquote, urlsplit, urlunsplit
1818

1919
import numpy as np
2020
import tlz as toolz
21+
from toolz import reduce
2122

2223
from cubed.backend_array_api import namespace as nxp
2324
from cubed.types import T_DType, T_RectangularChunks, T_RegularChunks, T_Shape
@@ -41,6 +42,11 @@ def chunk_memory(arr) -> int:
4142
)
4243

4344

45+
def array_size(shape: T_Shape) -> int:
46+
"""Number of elements in an array."""
47+
return reduce(mul, shape, 1)
48+
49+
4450
def offset_to_block_id(offset: int, numblocks: Tuple[int, ...]) -> Tuple[int, ...]:
4551
"""Convert an index offset to a block ID (chunk coordinates)."""
4652
return tuple(int(i) for i in np.unravel_index(offset, numblocks))

0 commit comments

Comments
 (0)