Skip to content

Commit b65b28c

Browse files
committed
Allow rechunk to take a dict for chunks
1 parent 19ec275 commit b65b28c

File tree

3 files changed: 9 additions (+9) and 19 deletions (-19).

3 files changed: 9 additions (+9) and 19 deletions (-19).

cubed/core/ops.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -590,16 +590,19 @@ def wrap(*a, block_id=None, **kw):
590590

591591

592592
def rechunk(x, chunks, target_store=None):
593-
if x.chunks == normalize_chunks(chunks, x.shape, dtype=x.dtype):
593+
normalized_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
594+
if x.chunks == normalized_chunks:
594595
return x
596+
# normalizing takes care of dict args for chunks
597+
target_chunks = to_chunksize(normalized_chunks)
595598
name = gensym()
596599
spec = x.spec
597600
if target_store is None:
598601
target_store = new_temp_path(name=name, spec=spec)
599602
temp_store = new_temp_path(name=f"{name}-intermediate", spec=spec)
600603
pipeline = primitive_rechunk(
601604
x.zarray_maybe_lazy,
602-
target_chunks=chunks,
605+
target_chunks=target_chunks,
603606
allowed_mem=spec.allowed_mem,
604607
reserved_mem=spec.reserved_mem,
605608
target_store=target_store,

cubed/primitive/rechunk.py

+1-15
Original file line number | Diff line number | Diff line change
@@ -7,11 +7,7 @@
77
from cubed.runtime.pipeline import spec_to_pipeline
88
from cubed.storage.zarr import lazy_empty
99
from cubed.vendor.rechunker.algorithm import rechunking_plan
10-
from cubed.vendor.rechunker.api import (
11-
_get_dims_from_zarr_array,
12-
_shape_dict_to_tuple,
13-
_validate_options,
14-
)
10+
from cubed.vendor.rechunker.api import _validate_options
1511

1612

1713
def rechunk(
@@ -119,16 +115,6 @@ def _setup_array_rechunk(
119115
# this is just a pass-through copy
120116
target_chunks = source_chunks
121117

122-
if isinstance(target_chunks, dict):
123-
array_dims = _get_dims_from_zarr_array(source_array)
124-
try:
125-
target_chunks = _shape_dict_to_tuple(array_dims, target_chunks)
126-
except KeyError:
127-
raise KeyError(
128-
"You must explicitly specify each dimension size in target_chunks. "
129-
f"Got array_dims {array_dims}, target_chunks {target_chunks}."
130-
)
131-
132118
# TODO: rewrite to avoid the hard dependency on dask
133119
max_mem = cubed.vendor.dask.utils.parse_bytes(max_mem)
134120

cubed/tests/test_core.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -197,9 +197,10 @@ def test_multiple_ops(spec, executor):
197197
)
198198

199199

200-
def test_rechunk(spec, executor):
200+
@pytest.mark.parametrize("new_chunks", [(1, 2), {0: 1, 1: 2}])
201+
def test_rechunk(spec, executor, new_chunks):
201202
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 1), spec=spec)
202-
b = a.rechunk((1, 2))
203+
b = a.rechunk(new_chunks)
203204
assert_array_equal(
204205
b.compute(executor=executor),
205206
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),

0 commit comments

Comments
 (0)