Skip to content

Commit b57fa0c

Browse files
authored
Limited implementation of map_overlap (#462)
* Limited implementation of map_overlap * Change to Array API function name `concat`
1 parent 6b99959 commit b57fa0c

File tree

5 files changed

+259
-0
lines changed

5 files changed

+259
-0
lines changed

cubed/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .core.gufunc import apply_gufunc
1919
from .core.ops import from_array, from_zarr, map_blocks, store, to_zarr
2020
from .nan_functions import nanmean, nansum
21+
from .overlap import map_overlap
2122
from .runtime.types import Callback, TaskEndEvent
2223
from .spec import Spec
2324

@@ -33,6 +34,7 @@
3334
"from_array",
3435
"from_zarr",
3536
"map_blocks",
37+
"map_overlap",
3638
"measure_reserved_mem",
3739
"nanmean",
3840
"nansum",

cubed/overlap.py

+139
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
from typing import Tuple
2+
3+
from cubed.backend_array_api import namespace as nxp
4+
from cubed.core.ops import map_direct
5+
from cubed.types import T_RectangularChunks
6+
from cubed.utils import _cumsum
7+
from cubed.vendor.dask.array.core import normalize_chunks
8+
from cubed.vendor.dask.array.overlap import coerce_boundary, coerce_depth
9+
from cubed.vendor.dask.utils import has_keyword
10+
11+
12+
def map_overlap(
    func,
    *args,
    dtype=None,
    chunks=None,
    depth=None,
    boundary=None,
    trim=False,
    **kwargs,
):
    """Map a function over the blocks of an array, extending each block with overlap.

    Each call to ``func`` receives a block of the input that has been widened
    by ``depth`` elements taken from its neighbors, with ``boundary`` values
    used where the block lies on the edge of the array.

    Parameters
    ----------
    func : callable
        Function applied to every (overlapping) block to produce the
        corresponding block of the output array. If it accepts a ``block_id``
        keyword, the index of the block is passed to it.
    args : arrays
        The Cubed arrays to map over. Currently only one array may be
        specified; only the first one is read.
    dtype : np.dtype
        The ``dtype`` of the output array.
    chunks : tuple
        Chunk shape of blocks in the output array. The shape of the output is
        the per-axis sum of these chunks.
    depth : int, tuple, dict or list
        The number of elements that each block should share with its
        neighbors. A list is interpreted as one specification per array.
    boundary : value type, tuple, dict or list
        How to handle the boundaries. Only constant values are currently
        supported. A list is interpreted as one specification per array.
    trim : bool
        Whether to trim ``depth`` elements from each block after calling
        ``func``. Only ``False`` is currently supported.
    **kwargs : dict
        Extra keyword arguments to pass to ``func``.

    Raises
    ------
    ValueError
        If ``trim`` is true.
    """
    if trim:
        raise ValueError("trim is not supported")

    chunks = normalize_chunks(chunks, dtype=dtype)
    shape = tuple(sum(axis_chunks) for axis_chunks in chunks)

    def per_array(spec, normalize):
        # A list already holds one specification per array argument; any
        # other value is shared by all of the array arguments.
        specs = spec if isinstance(spec, list) else [spec] * len(args)
        return [normalize(arr.ndim, s) for arr, s in zip(args, specs)]

    depth = per_array(depth, coerce_depth)
    boundary = per_array(boundary, coerce_boundary)

    return map_direct(
        _overlap,
        *args,
        shape=shape,
        dtype=dtype,
        chunks=chunks,
        # Memory allocated by reading one chunk from the input array. Although
        # an output chunk overlaps multiple input chunks, zarr reads the
        # chunks in series, reusing the buffer.
        extra_projected_mem=args[0].chunkmem,  # TODO: support multiple
        overlap_func=func,
        depth=depth,
        boundary=boundary,
        has_block_id_kw=has_keyword(func, "block_id"),
        **kwargs,
    )
80+
81+
82+
def _overlap(
    x,
    *arrays,
    overlap_func=None,
    depth=None,
    boundary=None,
    has_block_id_kw=False,
    block_id=None,
    **kwargs,
):
    """Read one overlapping block of the first input array and apply ``overlap_func``.

    ``x`` is accepted but not read here. ``depth`` and ``boundary`` are
    per-array lists; only the entries for the first array are used.
    ``block_id`` is forwarded to ``overlap_func`` only when
    ``has_block_id_kw`` is true.
    """
    # TODO: support multiple
    source = arrays[0]
    source_depth = depth[0]
    source_boundary = boundary[0]

    # Read the chunk together with the overlap given by depth first, and pad
    # the boundaries second. Doing it in this order lets everything happen in
    # one blockwise operation. The alternative, padding the entire array first
    # (via concatenate), would result in at least one extra copy.
    store = source.zarray
    selection = get_item_with_depth(source.chunks, block_id, source_depth)
    block = _pad_boundaries(
        store[selection],
        source_depth,
        source_boundary,
        source.numblocks,
        block_id,
    )

    block_id_kw = {"block_id": block_id} if has_block_id_kw else {}
    return overlap_func(block, **block_id_kw, **kwargs)
105+
106+
107+
def _clamp(minimum: int, x: int, maximum: int) -> int:
    """Limit ``x`` to at most ``maximum``, then to at least ``minimum``."""
    capped = x if x <= maximum else maximum
    return capped if capped >= minimum else minimum
109+
110+
111+
def get_item_with_depth(
    chunks: T_RectangularChunks, idx: Tuple[int, ...], depth
) -> Tuple[slice, ...]:
    """Convert a chunk index to a tuple of slices with depth offsets.

    For every axis, the slice covers block ``idx[axis]`` widened by
    ``depth[axis]`` elements on both sides, clamped to lie between 0 and the
    last entry of that axis's ``_cumsum`` offsets.
    """
    offsets_per_axis = tuple(_cumsum(c, initial_zero=True) for c in chunks)

    selection = []
    for axis, (block, offsets) in enumerate(zip(idx, offsets_per_axis)):
        lower = _clamp(0, offsets[block] - depth[axis], offsets[-1])
        upper = _clamp(0, offsets[block + 1] + depth[axis], offsets[-1])
        selection.append(slice(lower, upper, None))
    return tuple(selection)
124+
125+
126+
def _pad_boundaries(x, depth, boundary, numblocks, block_id):
    """Pad a block with constant boundary values on the edges of the array.

    Parameters
    ----------
    x : array
        The block to pad.
    depth : dict
        Maps axis number to the number of elements to pad; axes that are
        missing or map to 0 are left untouched.
    boundary : mapping or sequence
        The constant fill value for each axis, indexed by axis number.
    numblocks : sequence of int
        Number of blocks along each axis.
    block_id : sequence of int
        Index of this block along each axis.

    Returns
    -------
    array
        ``x`` with ``depth[i]`` extra elements before it on axis ``i`` if it
        is the first block on that axis, and after it if it is the last block.
        A block that is both first and last (an axis with a single block) is
        padded on both sides.
    """
    for i in range(x.ndim):
        d = depth.get(i, 0)
        if d == 0:
            continue

        # Decide the two sides independently: on an axis with a single block
        # the block is both first and last, so both ends need padding.
        is_first = block_id[i] == 0
        is_last = block_id[i] == numblocks[i] - 1
        if not (is_first or is_last):
            continue

        # The pad uses the current (possibly already padded) shape of x, so
        # corner regions are filled with the value of the later axis.
        pad_shape = tuple(d if ax == i else n for ax, n in enumerate(x.shape))
        p = nxp.full_like(x, fill_value=boundary[i], shape=pad_shape)

        pieces = []
        if is_first:
            pieces.append(p)
        pieces.append(x)
        if is_last:
            pieces.append(p)
        x = nxp.concat(pieces, axis=i)
    return x

cubed/tests/test_overlap.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import numpy as np
2+
from numpy.testing import assert_array_equal
3+
4+
import cubed
5+
import cubed.array_api as xp
6+
7+
8+
def test_map_overlap_1d():
    """Identity mapping over two 1-d blocks returns each block with its overlap and padding."""
    source = xp.asarray(np.arange(6), chunks=(3,))

    def identity(block):
        return block

    result = cubed.map_overlap(
        identity,
        source,
        dtype=source.dtype,
        chunks=((5, 5),),
        depth=1,
        boundary=0,
        trim=False,
    )

    # Each block of three grows to five: one neighboring element on the inner
    # side and one constant boundary value (0) on the outer side.
    expected = np.array([0, 0, 1, 2, 3, 2, 3, 4, 5, 0])
    assert_array_equal(result.compute(), expected)
23+
24+
25+
def test_map_overlap_2d():
    """Identity mapping over 2x2 blocks with a different depth and boundary per axis."""
    source = xp.asarray(np.arange(36).reshape((6, 6)), chunks=(3, 3))

    def identity(block):
        return block

    result = cubed.map_overlap(
        identity,
        source,
        dtype=source.dtype,
        chunks=((7, 7), (5, 5)),
        depth={0: 2, 1: 1},
        boundary={0: 100, 1: 200},
        trim=False,
    )

    # Axis 0 is padded with two rows of 100 and axis 1 with one column of 200;
    # the corners take the axis-1 value.
    edge_row = [200, 100, 100, 100, 100, 100, 100, 100, 100, 200]
    expected = np.array(
        [
            edge_row,
            edge_row,
            [200, 0, 1, 2, 3, 2, 3, 4, 5, 200],
            [200, 6, 7, 8, 9, 8, 9, 10, 11, 200],
            [200, 12, 13, 14, 15, 14, 15, 16, 17, 200],
            [200, 18, 19, 20, 21, 20, 21, 22, 23, 200],
            [200, 24, 25, 26, 27, 26, 27, 28, 29, 200],
            [200, 6, 7, 8, 9, 8, 9, 10, 11, 200],
            [200, 12, 13, 14, 15, 14, 15, 16, 17, 200],
            [200, 18, 19, 20, 21, 20, 21, 22, 23, 200],
            [200, 24, 25, 26, 27, 26, 27, 28, 29, 200],
            [200, 30, 31, 32, 33, 32, 33, 34, 35, 200],
            edge_row,
            edge_row,
        ]
    )

    assert_array_equal(result.compute(), expected)
59+
60+
61+
def test_map_overlap_trim():
    """A mapped function can trim the overlap itself, since ``trim=True`` is unsupported."""
    values = np.array([1, 1, 2, 3, 5, 8, 13, 21])
    source = xp.asarray(values, chunks=5)

    def derivative(block):
        # Difference with the previous element; the first and last entries
        # belong to the overlap and are dropped (manual trim).
        shifted = np.roll(block, 1)
        return (block - shifted)[1:-1]

    result = cubed.map_overlap(
        derivative,
        source,
        dtype=source.dtype,
        chunks=source.chunks,
        depth=1,
        boundary=0,
        trim=False,
    )

    expected = np.array([1, 0, 1, 1, 2, 3, 5, 8])
    assert_array_equal(result.compute(), expected)

cubed/vendor/dask/array/overlap.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from __future__ import annotations
2+
3+
from numbers import Integral
4+
5+
6+
def coerce_depth(ndim, depth):
    """Normalize a depth specification to a dict with one entry per axis.

    ``None`` is treated as 0, an integer is repeated for every axis, a tuple
    is matched to axes by position, and a dict is completed with 0 for any
    axis in ``range(ndim)`` it does not mention. The result is then passed
    through ``coerce_depth_type``.
    """
    fallback = 0
    spec = fallback if depth is None else depth

    if isinstance(spec, Integral):
        spec = {axis: spec for axis in range(ndim)}
    elif isinstance(spec, tuple):
        spec = dict(zip(range(ndim), spec))

    if isinstance(spec, dict):
        # Keep exactly the axes 0..ndim-1, filling in the ones not given.
        spec = {axis: spec.get(axis, fallback) for axis in range(ndim)}

    return coerce_depth_type(ndim, spec)
17+
18+
19+
def coerce_depth_type(ndim, depth):
    """Convert the depth of axes ``0..ndim-1`` to ``int``, in place.

    A tuple entry is converted element by element and stays a tuple. ``depth``
    itself is modified and returned.
    """
    for axis in range(ndim):
        entry = depth[axis]
        if isinstance(entry, tuple):
            depth[axis] = tuple(map(int, entry))
        else:
            depth[axis] = int(entry)
    return depth
26+
27+
28+
def coerce_boundary(ndim, boundary):
    """Normalize a boundary specification to a dict with one entry per axis.

    ``None`` is treated as ``"none"``, a value that is neither a tuple nor a
    dict is repeated for every axis, a tuple is matched to axes by position,
    and any axis in ``range(ndim)`` left unspecified gets ``"none"``.
    """
    fallback = "none"
    spec = fallback if boundary is None else boundary

    if not isinstance(spec, (tuple, dict)):
        spec = (spec,) * ndim
    if isinstance(spec, tuple):
        spec = dict(zip(range(ndim), spec))

    # spec is a dict by now; keep exactly the axes 0..ndim-1.
    return {axis: spec.get(axis, fallback) for axis in range(ndim)}

docs/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ Chunk-specific functions
4545

4646
apply_gufunc
4747
map_blocks
48+
map_overlap
4849

4950
Non-standardised functions
5051
==========================

0 commit comments

Comments
 (0)