diff --git a/pangeo_forge_recipes/chunk_grid.py b/pangeo_forge_recipes/chunk_grid.py
new file mode 100644
index 00000000..34607968
--- /dev/null
+++ b/pangeo_forge_recipes/chunk_grid.py
@@ -0,0 +1,245 @@
+"""
+Abstract representation of ND chunked arrays
+"""
+from __future__ import annotations
+
+from itertools import chain, groupby
+from typing import Dict, FrozenSet, Set, Tuple
+
+import numpy as np
+
+from .utils import calc_subsets
+
+# Most of this is probably already in Dask and Zarr!
+# However, it's useful to write up our own little model that does just what we need.
+
+
+class ChunkGrid:
+    """A ChunkGrid contains several named ChunkAxes.
+    The order of the axes does not matter.
+
+    :param chunks: Dictionary mapping dimension names to chunks along each dimension.
+    """
+
+    def __init__(self, chunks: Dict[str, Tuple[int, ...]]):
+        self._chunk_axes = {name: ChunkAxis(axis_chunks) for name, axis_chunks in chunks.items()}
+
+    def __eq__(self, other):
+        if self.dims != other.dims:
+            return False
+        for name in self._chunk_axes:
+            if self._chunk_axes[name] != other._chunk_axes[name]:
+                return False
+        return True
+
+    @classmethod
+    def from_uniform_grid(cls, chunksize_and_dimsize: Dict[str, Tuple[int, int]]):
+        """Create a ChunkGrid with uniform chunk sizes (except possibly the last chunk).
+
+        :param chunksize_and_dimsize: Dictionary whose keys are dimension names and
+          whose values are a tuple of `chunk_size, total_dim_size`
+        """
+        all_chunks = {}
+        for name, (chunksize, dimsize) in chunksize_and_dimsize.items():
+            assert dimsize > 0
+            assert (
+                chunksize > 0 and chunksize <= dimsize
+            ), f"invalid chunksize {chunksize} and dimsize {dimsize}"
+            chunks = (dimsize // chunksize) * (chunksize,)
+            if dimsize % chunksize > 0:
+                chunks = chunks + (dimsize % chunksize,)
+            all_chunks[name] = chunks
+        return cls(all_chunks)
+
+    @property
+    def dims(self) -> FrozenSet[str]:
+        return frozenset(self._chunk_axes)
+
+    @property
+    def shape(self) -> Dict[str, int]:
+        return {name: len(ca) for name, ca in self._chunk_axes.items()}
+
+    @property
+    def nchunks(self) -> Dict[str, int]:
+        return {name: ca.nchunks for name, ca in self._chunk_axes.items()}
+
+    @property
+    def ndim(self):
+        return len(self._chunk_axes)
+
+    def consolidate(self, factors: Dict[str, int]) -> ChunkGrid:
+        """Return a new ChunkGrid with chunks consolidated by a given factor
+        along specified dimensions."""
+
+        # doesn't seem like the kosher way to do this but /shrug
+        new = self.__class__({})
+        new._chunk_axes = {
+            name: ca.consolidate(factors[name]) if name in factors else ca
+            for name, ca in self._chunk_axes.items()
+        }
+        return new
+
+    def subset(self, factors: Dict[str, int]) -> ChunkGrid:
+        """Return a new ChunkGrid with chunks decimated by a given subset factor
+        along specified dimensions."""
+
+        # doesn't seem like the kosher way to do this but /shrug
+        new = self.__class__({})
+        new._chunk_axes = {
+            name: ca.subset(factors[name]) if name in factors else ca
+            for name, ca in self._chunk_axes.items()
+        }
+        return new
+
+    def chunk_index_to_array_slice(self, chunk_index: Dict[str, int]) -> Dict[str, slice]:
+        """Convert a single index from chunk space to a slice in array space
+        for each specified dimension."""
+
+        return {
+            name: self._chunk_axes[name].chunk_index_to_array_slice(idx)
+            for name, idx in chunk_index.items()
+        }
+
+    def array_index_to_chunk_index(self, array_index: Dict[str, int]) -> Dict[str, int]:
+        """Figure out which chunk a single array-space index comes from
+        for each specified dimension."""
+        return {
+            name: self._chunk_axes[name].array_index_to_chunk_index(idx)
+            for name, idx in array_index.items()
+        }
+
+    def array_slice_to_chunk_slice(self, array_slices: Dict[str, slice]) -> Dict[str, slice]:
+        """Find all chunks that intersect with a given array-space slice
+        for each specified dimension."""
+        return {
+            name: self._chunk_axes[name].array_slice_to_chunk_slice(array_slice)
+            for name, array_slice in array_slices.items()
+        }
+
+    def chunk_conflicts(self, chunk_index: Dict[str, int], other: ChunkGrid) -> Dict[str, Set[int]]:
+        """Figure out which chunks from the other ChunkGrid might also
+        be written by other chunks from this grid.
+        Returns, for each dimension, the set of chunk indexes from the _other_
+        ChunkGrid that would need to be locked for a safe write.
+
+        :param chunk_index: The index of the chunk we want to write
+        :param other: The other ChunkGrid
+        """
+
+        return {
+            name: self._chunk_axes[name].chunk_conflicts(idx, other._chunk_axes[name])
+            for name, idx in chunk_index.items()
+        }
+
+
+class ChunkAxis:
+    """A ChunkAxis has two index spaces.
+
+    Array index space is a regular python index of an array / list.
+    Chunk index space describes chunk positions.
+
+    A ChunkAxis helps translate between these two spaces.
+
+    :param chunks: The explicit size of each chunk
+    """
+
+    def __init__(self, chunks: Tuple[int, ...]):
+        self.chunks = tuple(chunks)  # want this immutable
+        self._chunk_bounds = np.hstack([0, np.cumsum(self.chunks)])
+
+    def __eq__(self, other):
+        return self.chunks == other.chunks
+
+    def __len__(self):
+        return self._chunk_bounds[-1].item()
+
+    def subset(self, factor: int) -> ChunkAxis:
+        """Return a copy with chunks decimated by a subset factor."""
+
+        new_chunks = tuple(chain(*(calc_subsets(c, factor) for c in self.chunks)))
+        return self.__class__(new_chunks)
+
+    def consolidate(self, factor: int) -> ChunkAxis:
+        """Return a copy with chunks consolidated by a given factor."""
+
+        new_chunks = []
+
+        def grouper(val):
+            return val[0] // factor
+
+        for _, gobj in groupby(enumerate(self.chunks), grouper):
+            new_chunks.append(sum(f[1] for f in gobj))
+        return self.__class__(tuple(new_chunks))
+
+    @property
+    def nchunks(self):
+        return len(self.chunks)
+
+    def chunk_index_to_array_slice(self, chunk_index: int) -> slice:
+        """Convert a single index from chunk space to a slice in array space."""
+
+        if chunk_index < 0 or chunk_index >= self.nchunks:
+            raise IndexError("chunk_index out of range")
+        return slice(self._chunk_bounds[chunk_index], self._chunk_bounds[chunk_index + 1])
+
+    def array_index_to_chunk_index(self, array_index: int) -> int:
+        """Figure out which chunk a single array-space index comes from."""
+
+        if array_index < 0 or array_index >= len(self):
+            raise IndexError("Index out of range")
+        return self._chunk_bounds.searchsorted(array_index, side="right") - 1
+
+    def array_slice_to_chunk_slice(self, sl: slice) -> slice:
+        """Find all chunks that intersect with a given array-space slice."""
+
+        if sl.step != 1 and sl.step is not None:
+            raise IndexError("Only works with step=1 or None")
+        if sl.start < 0:
+            raise IndexError("Slice start must be >= 0")
+        if sl.stop <= sl.start:
+            raise IndexError("Stop must be greater than start")
+        if sl.stop > len(self):
+            raise IndexError(f"Stop must be <= {len(self)}")
+        first = self._chunk_bounds.searchsorted(sl.start, side="right") - 1
+        last = self._chunk_bounds.searchsorted(sl.stop, side="left")
+        return slice(first, last)
+
+    def chunk_conflicts(self, chunk_index: int, other: ChunkAxis) -> Set[int]:
+        """Figure out which chunks from the other ChunkAxis might also
+        be written by other chunks from this axis.
+        Returns a set of chunk indexes from the _other_ ChunkAxis that would
+        need to be locked for a safe write.
+        If there are no potential conflicts, returns an empty set.
+        There are at most two other-axis chunks with conflicts;
+        one at each edge of this chunk.
+
+        :param chunk_index: The index of the chunk we want to write
+        :param other: The other ChunkAxis
+        """
+
+        if len(other) != len(self):
+            raise ValueError("Can't compute conflicts for ChunkAxes of different size.")
+
+        other_chunk_conflicts = set()
+
+        array_slice = self.chunk_index_to_array_slice(chunk_index)
+        # The chunks from the other axis that we need to touch:
+        other_chunk_indexes = other.array_slice_to_chunk_slice(array_slice)
+        # Which other chunks from this axis might also have to touch those chunks?
+        # To answer this, we need to know whether those chunks overlap any of the
+        # other chunks from this axis.
+        # Since the slice is contiguous, we only have to worry about the first and last chunks.
+
+        other_chunk_left = other_chunk_indexes.start
+        array_slice_left = other.chunk_index_to_array_slice(other_chunk_left)
+        chunk_slice_left = self.array_slice_to_chunk_slice(array_slice_left)
+        if chunk_slice_left.start < chunk_index:
+            other_chunk_conflicts.add(other_chunk_left)
+
+        other_chunk_right = other_chunk_indexes.stop - 1
+        array_slice_right = other.chunk_index_to_array_slice(other_chunk_right)
+        chunk_slice_right = self.array_slice_to_chunk_slice(array_slice_right)
+        if chunk_slice_right.stop > chunk_index + 1:
+            other_chunk_conflicts.add(other_chunk_right)
+
+        return other_chunk_conflicts
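A minimal usage sketch of the new module, with illustrative sizes (none of these numbers come from a real recipe; only the classes and methods defined above are assumed):

```python
from pangeo_forge_recipes.chunk_grid import ChunkGrid

# Ten inputs, each contributing 5 items along "time" (50 items total).
input_grid = ChunkGrid({"time": (5,) * 10})

# Consolidating by 2 models inputs_per_chunk=2: five write chunks of 10 items.
chunk_grid = input_grid.consolidate({"time": 2})
assert chunk_grid.nchunks == {"time": 5}

# Where does write chunk 3 land in array space?
assert chunk_grid.chunk_index_to_array_slice({"time": 3}) == {"time": slice(30, 40)}

# Against a uniform target grid with chunk size 15, writing chunk 3 touches
# target chunk 2 (items 30-45), which neighboring writes also touch:
target_grid = ChunkGrid.from_uniform_grid({"time": (15, 50)})
assert chunk_grid.chunk_conflicts({"time": 3}, target_grid) == {"time": {2}}
```

This is exactly the translation-plus-conflict workflow the recipe changes below rely on.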
diff --git a/pangeo_forge_recipes/recipes/xarray_zarr.py b/pangeo_forge_recipes/recipes/xarray_zarr.py
index 70cf5cdc..23907936 100644
--- a/pangeo_forge_recipes/recipes/xarray_zarr.py
+++ b/pangeo_forge_recipes/recipes/xarray_zarr.py
@@ -17,14 +17,10 @@
 import xarray as xr
 import zarr
 
+from ..chunk_grid import ChunkGrid
 from ..patterns import CombineOp, DimIndex, FilePattern, Index, prune_pattern
 from ..storage import AbstractTarget, CacheFSSpecTarget, MetadataTarget, file_opener
-from ..utils import (
-    calc_subsets,
-    chunk_bounds_and_conflicts,
-    fix_scalar_attr_encoding,
-    lock_for_conflicts,
-)
+from ..utils import calc_subsets, fix_scalar_attr_encoding, lock_for_conflicts
 from .base import BaseRecipe
 
 # use this filename to store global recipe metadata in the metadata_cache
@@ -55,6 +51,9 @@ def _input_metadata_fname(input_key):
 def inputs_for_chunk(
     chunk_key: ChunkKey, inputs_per_chunk: int, ninputs: int
 ) -> Sequence[InputKey]:
+    """For a chunk key, figure out which inputs belong to it.
+    Returns at least one InputKey."""
+
     merge_dims = [dim_idx for dim_idx in chunk_key if dim_idx.operation == CombineOp.MERGE]
     concat_dims = [dim_idx for dim_idx in chunk_key if dim_idx.operation == CombineOp.CONCAT]
     # Ignore subset dims, we don't need them here
@@ -107,11 +106,33 @@ def open_target(target: CacheFSSpecTarget) -> xr.Dataset:
     return xr.open_zarr(target.get_mapper())
 
 
-def input_position(input_key):
+def input_position(input_key: InputKey) -> int:
+    """Return the position of the input within the input sequence."""
     for dim_idx in input_key:
         # assumes there is one and only one concat dim
         if dim_idx.operation == CombineOp.CONCAT:
             return dim_idx.index
+    return -1  # make mypy happy
+
+
+def chunk_position(chunk_key: ChunkKey) -> int:
+    """Return the position of the chunk within the chunk sequence."""
+    concat_idx = -1
+    for dim_idx in chunk_key:
+        # assumes there is one and only one concat dim
+        if dim_idx.operation == CombineOp.CONCAT:
+            concat_idx = dim_idx.index
+            concat_dim = dim_idx.name
+    if concat_idx == -1:
+        raise ValueError("Couldn't find concat_dim")
+    subset_idx = 0
+    subset_factor = 1
+    for dim_idx in chunk_key:
+        if dim_idx.operation == CombineOp.SUBSET:
+            if dim_idx.name == concat_dim:
+                subset_idx = dim_idx.index
+                subset_factor = dim_idx.sequence_len
+    return subset_factor * concat_idx + subset_idx
 
 
 def cache_input_metadata(
@@ -191,67 +212,42 @@ def region_and_conflicts_for_chunk(
     concat_dim: str,
     metadata_cache: Optional[MetadataTarget],
     subset_inputs: Optional[SubsetSpec],
-) -> Tuple[Dict[str, slice], Set[int]]:
+) -> Tuple[Dict[str, slice], Dict[str, Set[int]]]:
     # return a dict suitable to pass to xr.to_zarr(region=...)
     # specifies where in the overall array to put this chunk's data
     # also return the conflicts with other chunks
-    ninputs = file_pattern.dims[file_pattern.concat_dims[0]]
-    input_keys = inputs_for_chunk(chunk_key, inputs_per_chunk, ninputs)
-
     if nitems_per_input:
         input_sequence_lens = (nitems_per_input,) * file_pattern.dims[concat_dim]  # type: ignore
     else:
         assert metadata_cache is not None  # for mypy
         global_metadata = metadata_cache[_GLOBAL_METADATA_KEY]
         input_sequence_lens = global_metadata["input_sequence_lens"]
+    total_len = sum(input_sequence_lens)
+
+    # for now this will just have one key since we only allow one concat_dim
+    # but it could expand to accommodate multiple concat dims
+    chunk_index = {concat_dim: chunk_position(chunk_key)}
 
+    input_chunk_grid = ChunkGrid({concat_dim: input_sequence_lens})
     if subset_inputs and concat_dim in subset_inputs:
-        # scenario I: there is a single input per chunk, possibly with subsetting
         assert (
             inputs_per_chunk == 1
         ), "Doesn't make sense to have multiple inputs per chunk plus subsetting"
-        subset_factor = subset_inputs[concat_dim]
-        input_sequence_lens = tuple(
-            chain(*(calc_subsets(input_len, subset_factor) for input_len in input_sequence_lens))
-        )
-        subset_idx = [
-            dim_idx.index
-            for dim_idx in chunk_key
-            if dim_idx.operation == CombineOp.SUBSET and dim_idx.name == concat_dim
-        ][0]
+        chunk_grid = input_chunk_grid.subset(subset_inputs)
+    elif inputs_per_chunk > 1:
+        chunk_grid = input_chunk_grid.consolidate({concat_dim: inputs_per_chunk})
     else:
-        subset_factor = 1
-        subset_idx = 0  # unused
+        chunk_grid = input_chunk_grid
+    assert chunk_grid.shape[concat_dim] == total_len
 
-    assert len(input_sequence_lens) == ninputs * subset_factor
-    assert concat_dim_chunks is not None
+    region = chunk_grid.chunk_index_to_array_slice(chunk_index)
 
-    chunk_bounds, all_chunk_conflicts = chunk_bounds_and_conflicts(
-        chunks=input_sequence_lens, zchunks=concat_dim_chunks
-    )
-    input_positions = [input_position(input_key) for input_key in input_keys]
-
-    if subset_factor > 1:
-        assert len(input_positions) == 1
-        start_idx = subset_factor * input_positions[0] + subset_idx
-        stop_idx = start_idx + 1
-    else:
-        start_idx = min(input_positions)
-        stop_idx = max(input_positions) + 1
-
-    start = chunk_bounds[start_idx]
-    stop = chunk_bounds[stop_idx]
-    region_slice = slice(start, stop)
-
-    this_chunk_conflicts = set()
-    for idx in range(start_idx, stop_idx):
-        conflict = all_chunk_conflicts[idx]
-        if conflict:
-            for conflict_index in conflict:
-                this_chunk_conflicts.add(conflict_index)
+    assert concat_dim_chunks is not None
+    target_grid = ChunkGrid.from_uniform_grid({concat_dim: (concat_dim_chunks, total_len)})
+    conflicts = chunk_grid.chunk_conflicts(chunk_index, target_grid)
 
-    return {concat_dim: region_slice}, this_chunk_conflicts
+    return region, conflicts
 
 
 @contextmanager
@@ -617,7 +613,11 @@ def store_chunk(
                 var.data
             )  # TODO: can we buffer large data rather than loading it all?
             zarr_region = tuple(write_region.get(dim, slice(None)) for dim in var.dims)
-            lock_keys = [f"{vname}-{c}" for c in conflicts]
+            lock_keys = [
+                f"{vname}-{dim}-{c}"
+                for dim, dim_conflicts in conflicts.items()
+                for c in dim_conflicts
+            ]
            logger.debug(f"Acquiring locks {lock_keys}")
            with lock_for_conflicts(lock_keys, timeout=lock_timeout):
                logger.info(
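To make the new data flow concrete, here is a standalone sketch of what `region_and_conflicts_for_chunk` now computes, using made-up sizes (six inputs of four items, two inputs per chunk, target zarr chunks of six; none of these values come from a real recipe):

```python
from pangeo_forge_recipes.chunk_grid import ChunkGrid

concat_dim = "time"
input_sequence_lens = (4, 4, 4, 4, 4, 4)  # six inputs, four items each
inputs_per_chunk = 2
concat_dim_chunks = 6                     # target zarr chunk size
total_len = sum(input_sequence_lens)      # 24

input_grid = ChunkGrid({concat_dim: input_sequence_lens})
chunk_grid = input_grid.consolidate({concat_dim: inputs_per_chunk})  # write chunks of 8

chunk_index = {concat_dim: 1}             # the second write chunk
region = chunk_grid.chunk_index_to_array_slice(chunk_index)
assert region == {"time": slice(8, 16)}   # passed to xr.Dataset.to_zarr(region=...)

target_grid = ChunkGrid.from_uniform_grid({concat_dim: (concat_dim_chunks, total_len)})
conflicts = chunk_grid.chunk_conflicts(chunk_index, target_grid)
assert conflicts == {"time": {1, 2}}      # zarr chunks shared with neighboring writes
```

With the new dict-of-sets return value, `store_chunk` then builds per-dimension lock keys (here they would be `{vname}-time-1` and `{vname}-time-2` for each variable), which is what the reworked `lock_keys` comprehension above produces.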
diff --git a/tests/test_chunk_grid.py b/tests/test_chunk_grid.py
new file mode 100644
index 00000000..f0443bf2
--- /dev/null
+++ b/tests/test_chunk_grid.py
@@ -0,0 +1,140 @@
+import pytest
+
+from pangeo_forge_recipes.chunk_grid import ChunkAxis, ChunkGrid
+
+
+def test_chunk_axis():
+    ca = ChunkAxis(chunks=(2, 4, 3))
+    assert len(ca) == 9
+    assert ca.nchunks == 3
+
+    # yes we could parameterize this but writing it out helps understanding
+    with pytest.raises(IndexError):
+        _ = ca.array_index_to_chunk_index(-1)
+    assert ca.array_index_to_chunk_index(0) == 0
+    assert ca.array_index_to_chunk_index(1) == 0
+    assert ca.array_index_to_chunk_index(2) == 1
+    assert ca.array_index_to_chunk_index(3) == 1
+    assert ca.array_index_to_chunk_index(4) == 1
+    assert ca.array_index_to_chunk_index(5) == 1
+    assert ca.array_index_to_chunk_index(6) == 2
+    assert ca.array_index_to_chunk_index(7) == 2
+    assert ca.array_index_to_chunk_index(8) == 2
+    with pytest.raises(IndexError):
+        _ = ca.array_index_to_chunk_index(9)
+
+    bad_array_slices = slice(0, 5, 2), slice(-1, 5), slice(5, 4), slice(5, 10)
+    for sl in bad_array_slices:
+        with pytest.raises(IndexError):
+            _ = ca.array_slice_to_chunk_slice(sl)
+
+    assert ca.array_slice_to_chunk_slice(slice(0, 9)) == slice(0, 3)
+    assert ca.array_slice_to_chunk_slice(slice(1, 9)) == slice(0, 3)
+    assert ca.array_slice_to_chunk_slice(slice(2, 9)) == slice(1, 3)
+    assert ca.array_slice_to_chunk_slice(slice(2, 8)) == slice(1, 3)
+    assert ca.array_slice_to_chunk_slice(slice(2, 6)) == slice(1, 2)
+    assert ca.array_slice_to_chunk_slice(slice(2, 5)) == slice(1, 2)
+    assert ca.array_slice_to_chunk_slice(slice(6, 7)) == slice(2, 3)
+
+    with pytest.raises(IndexError):
+        _ = ca.chunk_index_to_array_slice(-1)
+    assert ca.chunk_index_to_array_slice(0) == slice(0, 2)
+    assert ca.chunk_index_to_array_slice(1) == slice(2, 6)
+    assert ca.chunk_index_to_array_slice(2) == slice(6, 9)
+    with pytest.raises(IndexError):
+        _ = ca.chunk_index_to_array_slice(3)
+
+
+def test_chunk_axis_subset():
+    ca = ChunkAxis(chunks=(2, 4, 3))
+    cas = ca.subset(2)
+    assert cas.chunks == (1, 1, 2, 2, 1, 2)
+
+
+def test_chunk_axis_consolidate():
+    ca = ChunkAxis(chunks=(2, 4, 3, 4, 2))
+    cac = ca.consolidate(2)
+    assert cac.chunks == (6, 7, 2)
+    cad = ca.consolidate(3)
+    assert cad.chunks == (9, 6)
+
+
+def test_chunk_grid():
+    cg = ChunkGrid({"x": (2, 4, 3), "time": (7, 8)})
+    assert cg.dims == {"x", "time"}
+    assert cg.shape == {"x": 9, "time": 15}
+    assert cg.nchunks == {"x": 3, "time": 2}
+    assert cg.ndim == 2
+
+    assert cg.array_index_to_chunk_index({"x": 2}) == {"x": 1}
+    assert cg.array_index_to_chunk_index({"time": 10}) == {"time": 1}
+    assert cg.array_index_to_chunk_index({"x": 7, "time": 10}) == {"x": 2, "time": 1}
+
+    assert cg.array_slice_to_chunk_slice({"x": slice(0, 9)}) == {"x": slice(0, 3)}
+    assert cg.array_slice_to_chunk_slice({"time": slice(0, 15)}) == {"time": slice(0, 2)}
+    assert cg.array_slice_to_chunk_slice({"x": slice(0, 9), "time": slice(0, 15)}) == {
+        "x": slice(0, 3),
+        "time": slice(0, 2),
+    }
+
+    assert cg.chunk_index_to_array_slice({"x": 1}) == {"x": slice(2, 6)}
+    assert cg.chunk_index_to_array_slice({"time": 1}) == {"time": slice(7, 15)}
+    assert cg.chunk_index_to_array_slice({"x": 1, "time": 1}) == {
+        "x": slice(2, 6),
+        "time": slice(7, 15),
+    }
+
+
+def test_chunk_grid_from_uniform_grid():
+    cg1 = ChunkGrid({"x": (2, 2), "y": (3, 3, 3, 1)})
+    cg2 = ChunkGrid.from_uniform_grid({"x": (2, 4), "y": (3, 10)})
+    assert cg1 == cg2
+
+
+def test_chunk_grid_consolidate_subset():
+    cg = ChunkGrid({"x": (2, 4, 3), "time": (7, 8)})
+
+    assert cg.consolidate({}) == cg
+    cgc1 = cg.consolidate({"x": 2})
+    assert cg.shape == cgc1.shape
+    assert cgc1.nchunks == {"x": 2, "time": 2}
+    cgc2 = cg.consolidate({"x": 2, "time": 2})
+    assert cg.shape == cgc2.shape
+    assert cgc2.nchunks == {"x": 2, "time": 1}
+
+    assert cg.subset({}) == cg
+    cgs1 = cg.subset({"x": 2})
+    assert cg.shape == cgs1.shape
+    assert cgs1.nchunks == {"x": 6, "time": 2}
+    cgs2 = cg.subset({"x": 2, "time": 2})
+    assert cg.shape == cgs2.shape
+    assert cgs2.nchunks == {"x": 6, "time": 4}
+
+
+def test_chunk_axis_conflicts():
+    ca1 = ChunkAxis(chunks=(2, 4, 3, 4, 2))  # len 15
+    ca2 = ChunkAxis(chunks=(5, 4, 6))
+
+    for n in range(ca1.nchunks):
+        assert ca1.chunk_conflicts(n, ca1) == set()
+
+    assert ca1.chunk_conflicts(0, ca2) == {0}
+    assert ca1.chunk_conflicts(1, ca2) == {0, 1}
+    assert ca1.chunk_conflicts(2, ca2) == {1}
+    assert ca1.chunk_conflicts(3, ca2) == {2}
+    assert ca1.chunk_conflicts(4, ca2) == {2}
+    assert ca2.chunk_conflicts(0, ca1) == {1}
+    assert ca2.chunk_conflicts(1, ca1) == {1}
+    assert ca2.chunk_conflicts(2, ca1) == set()
+
+    with pytest.raises(ValueError):
+        _ = ca1.chunk_conflicts(0, ChunkAxis((14,)))
+
+
+def test_chunk_grid_conflicts():
+    cg1 = ChunkGrid({"x": (2, 4, 3, 4, 2), "y": (10, 10, 10)})
+    cg2 = ChunkGrid({"x": (5, 4, 6), "y": (11, 9, 10)})
+
+    assert cg1.chunk_conflicts({"x": 0}, cg2) == {"x": {0}}
+    assert cg1.chunk_conflicts({"x": 0, "y": 0}, cg2) == {"x": {0}, "y": {0}}
+    assert cg1.chunk_conflicts({"y": 2}, cg2) == {"y": set()}
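The hard-coded conflict expectations above can be cross-checked by hand with plain interval arithmetic. A small sketch, independent of ChunkAxis (`bounds` and `overlapping_chunks` are hypothetical helpers written here for illustration, not part of the library):

```python
import numpy as np

def bounds(chunks):
    # cumulative chunk boundaries, e.g. (2, 4, 3) -> [0, 2, 6, 9]
    return np.hstack([0, np.cumsum(chunks)])

def overlapping_chunks(chunks, start, stop):
    # indexes of all chunks whose [lo, hi) interval intersects [start, stop)
    b = bounds(chunks)
    return {i for i in range(len(chunks)) if b[i] < stop and b[i + 1] > start}

ca1, ca2 = (2, 4, 3, 4, 2), (5, 4, 6)
b1 = bounds(ca1)

# ca1 chunk 1 spans [2, 6); it overlaps ca2 chunks 0 ([0, 5)) and 1 ([5, 9)).
assert overlapping_chunks(ca2, b1[1], b1[2]) == {0, 1}

# Both of those ca2 chunks are also overlapped by other ca1 chunks
# (chunk 0 on the left, chunk 2 on the right), matching
# ca1.chunk_conflicts(1, ca2) == {0, 1} asserted above.
```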
diff --git a/tests/test_recipes.py b/tests/test_recipes.py
index 4efaeecb..53c8f027 100644
--- a/tests/test_recipes.py
+++ b/tests/test_recipes.py
@@ -124,7 +124,7 @@ def test_process(recipe_fixture, execute_recipe, process_input, process_chunk):
         ({"lon": 12}, True, does_not_raise()),
         ({"lon": 12, "time": 1}, True, does_not_raise()),
         ({"lon": 12, "time": 3}, True, does_not_raise()),
-        ({"time": 100}, True, does_not_raise()),  # only one big chunk
+        ({"time": 10}, True, does_not_raise()),  # only one big chunk
         ({"lon": 12, "time": 1}, False, does_not_raise()),
         ({"lon": 12, "time": 3}, False, does_not_raise()),
         # can't determine target chunks for the next two because 'time' missing from target_chunks
@@ -201,7 +201,6 @@ def test_chunks(
         assert all([item == chunk_len for item in ds_actual.chunks[other_dim][:-1]])
 
     ds_actual.load()
-    print(ds_actual)
     xr.testing.assert_identical(ds_actual, ds_expected)
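A note on the `{"time": 100}` to `{"time": 10}` change above: `ChunkGrid.from_uniform_grid` asserts `chunksize <= dimsize`, so (assuming the test fixture's time dimension has length 10, which the new value suggests) an oversized target chunk is now rejected up front rather than silently treated as one big chunk. A sketch of the behavior:

```python
from pangeo_forge_recipes.chunk_grid import ChunkGrid

# One chunk covering the whole dimension: valid.
grid = ChunkGrid.from_uniform_grid({"time": (10, 10)})
assert grid.nchunks == {"time": 1}

# A chunk size larger than the dimension now fails fast:
try:
    ChunkGrid.from_uniform_grid({"time": (100, 10)})
except AssertionError as e:
    print(e)  # invalid chunksize 100 and dimsize 10
```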