Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add chunk grid logic #141

Merged
merged 16 commits into from
Aug 26, 2021
245 changes: 245 additions & 0 deletions pangeo_forge_recipes/chunk_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
"""
Abstract representation of ND chunked arrays
"""
from __future__ import annotations

from itertools import chain, groupby
from typing import Dict, FrozenSet, Set, Tuple

import numpy as np

from .utils import calc_subsets

# Most of this is probably already in Dask and Zarr!
# However, it's useful to write up our own little model that does just what we need.


class ChunkGrid:
    """A ChunkGrid contains several named ChunkAxis.
    The order of the axes does not matter.

    :param chunks: Dictionary mapping dimension names to chunks in each dimension.
    """

    def __init__(self, chunks: Dict[str, Tuple[int, ...]]):
        self._chunk_axes = {name: ChunkAxis(axis_chunks) for name, axis_chunks in chunks.items()}

    def __eq__(self, other):
        # Comparing against a non-ChunkGrid should defer to the other operand
        # (or fall back to identity) rather than raise AttributeError.
        if not isinstance(other, ChunkGrid):
            return NotImplemented
        if self.dims != other.dims:
            return False
        return all(
            self._chunk_axes[name] == other._chunk_axes[name] for name in self._chunk_axes
        )

    @classmethod
    def from_uniform_grid(cls, chunksize_and_dimsize: Dict[str, Tuple[int, int]]):
        """Create a ChunkGrid with uniform chunk sizes (except possibly the last chunk).

        :param chunksize_and_dimsize: Dictionary whose keys are dimension names and
          whose values are a tuple of `chunk_size, total_dim_size`
        :raises ValueError: If a dimsize is not positive, or a chunksize is not
          in the range ``1..dimsize``.
        """
        all_chunks = {}
        for name, (chunksize, dimsize) in chunksize_and_dimsize.items():
            # Validate with explicit raises; bare asserts are stripped under `python -O`.
            if dimsize <= 0:
                raise ValueError(f"invalid dimsize {dimsize} for dimension {name!r}")
            if not 0 < chunksize <= dimsize:
                raise ValueError(f"invalid chunksize {chunksize} and dimsize {dimsize}")
            chunks = (dimsize // chunksize) * (chunksize,)
            remainder = dimsize % chunksize
            if remainder > 0:
                # final, possibly-smaller chunk covers the leftover elements
                chunks = chunks + (remainder,)
            all_chunks[name] = chunks
        return cls(all_chunks)

    @property
    def dims(self) -> FrozenSet[str]:
        """The set of dimension names in this grid."""
        return frozenset(self._chunk_axes)

    @property
    def shape(self) -> Dict[str, int]:
        """Total array-space length along each dimension."""
        return {name: len(ca) for name, ca in self._chunk_axes.items()}

    @property
    def nchunks(self) -> Dict[str, int]:
        """Number of chunks along each dimension."""
        return {name: ca.nchunks for name, ca in self._chunk_axes.items()}

    @property
    def ndim(self) -> int:
        """Number of dimensions in the grid."""
        return len(self._chunk_axes)

    def consolidate(self, factors: Dict[str, int]) -> ChunkGrid:
        """Return a new ChunkGrid with chunks consolidated by a given factor
        along specified dimensions."""

        # Construct a fresh grid from explicit chunk tuples instead of
        # patching a half-built instance's private state.
        return self.__class__(
            {
                name: (ca.consolidate(factors[name]) if name in factors else ca).chunks
                for name, ca in self._chunk_axes.items()
            }
        )

    def subset(self, factors: Dict[str, int]) -> ChunkGrid:
        """Return a new ChunkGrid with chunks decimated by a given subset factor
        along specified dimensions."""

        return self.__class__(
            {
                name: (ca.subset(factors[name]) if name in factors else ca).chunks
                for name, ca in self._chunk_axes.items()
            }
        )

    def chunk_index_to_array_slice(self, chunk_index: Dict[str, int]) -> Dict[str, slice]:
        """Convert a single index from chunk space to a slice in array space
        for each specified dimension."""

        return {
            name: self._chunk_axes[name].chunk_index_to_array_slice(idx)
            for name, idx in chunk_index.items()
        }

    def array_index_to_chunk_index(self, array_index: Dict[str, int]) -> Dict[str, int]:
        """Figure out which chunk a single array-space index comes from
        for each specified dimension."""
        return {
            name: self._chunk_axes[name].array_index_to_chunk_index(idx)
            for name, idx in array_index.items()
        }

    def array_slice_to_chunk_slice(self, array_slices: Dict[str, slice]) -> Dict[str, slice]:
        """Find all chunks that intersect with a given array-space slice
        for each specified dimension."""
        return {
            name: self._chunk_axes[name].array_slice_to_chunk_slice(array_slice)
            for name, array_slice in array_slices.items()
        }

    def chunk_conflicts(self, chunk_index: Dict[str, int], other: ChunkGrid) -> Dict[str, Set[int]]:
        """Figure out which chunks from the other ChunkGrid might potentially
        be written by other chunks from this array.
        Returns a set of chunk indexes from the _other_ ChunkGrid that would
        need to be locked for a safe write.

        :param chunk_index: The index of the chunk we want to write
        :param other: The other ChunkGrid
        """

        return {
            name: self._chunk_axes[name].chunk_conflicts(idx, other._chunk_axes[name])
            for name, idx in chunk_index.items()
        }


class ChunkAxis:
    """A ChunkAxis has two index spaces.

    Array index space is a regular python index of an array / list.
    Chunk index space describes chunk positions.

    A ChunkAxis helps translate between these two spaces.

    :param chunks: The explicit size of each chunk
    """

    def __init__(self, chunks: Tuple[int, ...]):
        self.chunks = tuple(chunks)  # want this immutable
        # Cumulative boundaries: chunk i occupies array indices
        # _chunk_bounds[i] (inclusive) to _chunk_bounds[i + 1] (exclusive).
        self._chunk_bounds = np.hstack([0, np.cumsum(self.chunks)])

    def __eq__(self, other):
        # Defer to the other operand for non-ChunkAxis comparisons rather
        # than raising AttributeError.
        if not isinstance(other, ChunkAxis):
            return NotImplemented
        return self.chunks == other.chunks

    def __len__(self):
        # Total array-space length covered by all chunks.
        return self._chunk_bounds[-1].item()

    def subset(self, factor: int) -> ChunkAxis:
        """Return a copy with chunks decimated by a subset factor."""

        new_chunks = tuple(chain(*(calc_subsets(c, factor) for c in self.chunks)))
        return self.__class__(new_chunks)

    def consolidate(self, factor: int) -> ChunkAxis:
        """Return a copy with adjacent chunks merged in runs of ``factor``."""

        def grouper(val):
            # Consecutive chunk indices map to the same group in runs of `factor`.
            return val[0] // factor

        new_chunks = tuple(
            sum(size for _, size in group)
            for _, group in groupby(enumerate(self.chunks), grouper)
        )
        return self.__class__(new_chunks)

    @property
    def nchunks(self) -> int:
        """Number of chunks along this axis."""
        return len(self.chunks)

    def chunk_index_to_array_slice(self, chunk_index: int) -> slice:
        """Convert a single index from chunk space to a slice in array space.

        :raises IndexError: If ``chunk_index`` is out of range.
        """

        if chunk_index < 0 or chunk_index >= self.nchunks:
            raise IndexError("chunk_index out of range")
        # Cast to plain int so callers get a builtin slice, not numpy scalars.
        return slice(
            int(self._chunk_bounds[chunk_index]), int(self._chunk_bounds[chunk_index + 1])
        )

    def array_index_to_chunk_index(self, array_index: int) -> int:
        """Figure out which chunk a single array-space index comes from.

        :raises IndexError: If ``array_index`` is out of range.
        """

        if array_index < 0 or array_index >= len(self):
            raise IndexError("Index out of range")
        # The rightmost boundary <= array_index identifies the containing chunk.
        return int(self._chunk_bounds.searchsorted(array_index, side="right") - 1)

    def array_slice_to_chunk_slice(self, sl: slice) -> slice:
        """Find all chunks that intersect with a given array-space slice.

        Only contiguous (step 1), in-bounds, non-empty slices are supported.

        :raises IndexError: If the slice has a step, is out of bounds, or is empty.
        """

        if sl.step != 1 and sl.step is not None:
            raise IndexError("Only works with step=1 or None")
        if sl.start < 0:
            # message fixed: the check rejects negative starts, 0 is allowed
            raise IndexError("Slice start must be >= 0")
        if sl.stop <= sl.start:
            raise IndexError("Stop must be greater than start")
        if sl.stop > len(self):
            raise IndexError(f"Stop must be <= {len(self)}")
        first = self._chunk_bounds.searchsorted(sl.start, side="right") - 1
        last = self._chunk_bounds.searchsorted(sl.stop, side="left")
        return slice(int(first), int(last))

    def chunk_conflicts(self, chunk_index: int, other: ChunkAxis) -> Set[int]:
        """Figure out which chunks from the other ChunkAxis might potentially
        be written by other chunks from this array.
        Returns a set of chunk indexes from the _other_ ChunkAxis that would
        need to be locked for a safe write.
        If there are no potential conflicts, return an empty set.
        There are at most two other-axis chunks with conflicts;
        one at each edge of this chunk.

        :param chunk_index: The index of the chunk we want to write
        :param other: The other ChunkAxis
        :raises ValueError: If the two axes cover different array lengths.
        """

        if len(other) != len(self):
            raise ValueError("Can't compute conflict for ChunkAxes of different size.")

        other_chunk_conflicts = set()

        array_slice = self.chunk_index_to_array_slice(chunk_index)
        # The chunks from the other axis that we need to touch:
        other_chunk_indexes = other.array_slice_to_chunk_slice(array_slice)
        # Which other chunks from this array might also have to touch those chunks?
        # To answer this, we need to know if those chunks overlap any of the
        # other chunks from this array.
        # Since the slice is contiguous, we only have to worry about the first and last chunks.

        # Left edge: does the leftmost touched other-chunk extend into an
        # earlier chunk of this axis?
        other_chunk_left = other_chunk_indexes.start
        array_slice_left = other.chunk_index_to_array_slice(other_chunk_left)
        chunk_slice_left = self.array_slice_to_chunk_slice(array_slice_left)
        if chunk_slice_left.start < chunk_index:
            other_chunk_conflicts.add(other_chunk_left)

        # Right edge: does the rightmost touched other-chunk extend into a
        # later chunk of this axis?
        other_chunk_right = other_chunk_indexes.stop - 1
        array_slice_right = other.chunk_index_to_array_slice(other_chunk_right)
        chunk_slice_right = self.array_slice_to_chunk_slice(array_slice_right)
        if chunk_slice_right.stop > chunk_index + 1:
            other_chunk_conflicts.add(other_chunk_right)

        return other_chunk_conflicts
83 changes: 33 additions & 50 deletions pangeo_forge_recipes/recipes/xarray_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,10 @@
import xarray as xr
import zarr

from ..chunk_grid import ChunkGrid
from ..patterns import CombineOp, DimIndex, FilePattern, Index, prune_pattern
from ..storage import AbstractTarget, CacheFSSpecTarget, MetadataTarget, file_opener
from ..utils import (
calc_subsets,
chunk_bounds_and_conflicts,
fix_scalar_attr_encoding,
lock_for_conflicts,
)
from ..utils import calc_subsets, fix_scalar_attr_encoding, lock_for_conflicts
from .base import BaseRecipe

# use this filename to store global recipe metadata in the metadata_cache
Expand Down Expand Up @@ -55,6 +51,9 @@ def _input_metadata_fname(input_key):
def inputs_for_chunk(
chunk_key: ChunkKey, inputs_per_chunk: int, ninputs: int
) -> Sequence[InputKey]:
"""For a chunk key, figure out which inputs belong to it.
Returns at least one InputKey."""

merge_dims = [dim_idx for dim_idx in chunk_key if dim_idx.operation == CombineOp.MERGE]
concat_dims = [dim_idx for dim_idx in chunk_key if dim_idx.operation == CombineOp.CONCAT]
# Ignore subset dims, we don't need them here
Expand Down Expand Up @@ -107,11 +106,13 @@ def open_target(target: CacheFSSpecTarget) -> xr.Dataset:
return xr.open_zarr(target.get_mapper())


def input_position(input_key):
def input_position(input_key: InputKey) -> int:
"""Return the position of the input within the input sequence."""
for dim_idx in input_key:
# assumes there is one and only one concat dim
if dim_idx.operation == CombineOp.CONCAT:
return dim_idx.index
return -1 # make mypy happy


def cache_input_metadata(
Expand Down Expand Up @@ -191,67 +192,45 @@ def region_and_conflicts_for_chunk(
concat_dim: str,
metadata_cache: Optional[MetadataTarget],
subset_inputs: Optional[SubsetSpec],
) -> Tuple[Dict[str, slice], Set[int]]:
) -> Tuple[Dict[str, slice], Dict[str, Set[int]]]:
# return a dict suitable to pass to xr.to_zarr(region=...)
# specifies where in the overall array to put this chunk's data
# also return the conflicts with other chunks

ninputs = file_pattern.dims[file_pattern.concat_dims[0]]
input_keys = inputs_for_chunk(chunk_key, inputs_per_chunk, ninputs)

if nitems_per_input:
input_sequence_lens = (nitems_per_input,) * file_pattern.dims[concat_dim] # type: ignore
else:
assert metadata_cache is not None # for mypy
global_metadata = metadata_cache[_GLOBAL_METADATA_KEY]
input_sequence_lens = global_metadata["input_sequence_lens"]
total_len = sum(input_sequence_lens)

input_chunk_grid = ChunkGrid({concat_dim: input_sequence_lens})
if subset_inputs and concat_dim in subset_inputs:
# scenario I: there is a single input per chunk, possibly with subsetting
assert (
inputs_per_chunk == 1
), "Doesn't make sense to have multiple inputs per chunk plus subsetting"
subset_factor = subset_inputs[concat_dim]
input_sequence_lens = tuple(
chain(*(calc_subsets(input_len, subset_factor) for input_len in input_sequence_lens))
)
subset_idx = [
dim_idx.index
for dim_idx in chunk_key
if dim_idx.operation == CombineOp.SUBSET and dim_idx.name == concat_dim
][0]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note how we no longer need this opaque logic in this function. It now lives in ChunkGrid.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(and the new chunk_position function, which translates a ChunkKey to a chunk index)

chunk_grid = input_chunk_grid.consolidate(subset_inputs)
elif inputs_per_chunk > 1:
chunk_grid = input_chunk_grid.consolidate({concat_dim: inputs_per_chunk})
else:
subset_factor = 1
subset_idx = 0 # unused
chunk_grid = input_chunk_grid
assert chunk_grid.shape[concat_dim] == total_len

# for now this will just have one key since we only allow one concat_dim
# but it should expand easily to accommodate multiple concat dims
chunk_index = {
dim_idx.name: dim_idx.index
for dim_idx in chunk_key
if dim_idx.operation == CombineOp.CONCAT
}
region = chunk_grid.chunk_index_to_array_slice(chunk_index)

assert len(input_sequence_lens) == ninputs * subset_factor
assert concat_dim_chunks is not None
target_grid = ChunkGrid.from_uniform_grid({concat_dim: (concat_dim_chunks, total_len)})
conflicts = chunk_grid.chunk_conflicts(chunk_index, target_grid)

chunk_bounds, all_chunk_conflicts = chunk_bounds_and_conflicts(
chunks=input_sequence_lens, zchunks=concat_dim_chunks
)
input_positions = [input_position(input_key) for input_key in input_keys]

if subset_factor > 1:
assert len(input_positions) == 1
start_idx = subset_factor * input_positions[0] + subset_idx
stop_idx = start_idx + 1
else:
start_idx = min(input_positions)
stop_idx = max(input_positions) + 1

start = chunk_bounds[start_idx]
stop = chunk_bounds[stop_idx]
region_slice = slice(start, stop)

this_chunk_conflicts = set()
for idx in range(start_idx, stop_idx):
conflict = all_chunk_conflicts[idx]
if conflict:
for conflict_index in conflict:
this_chunk_conflicts.add(conflict_index)

return {concat_dim: region_slice}, this_chunk_conflicts
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same. All this code goes away.

return region, conflicts


@contextmanager
Expand Down Expand Up @@ -617,7 +596,11 @@ def store_chunk(
var.data
) # TODO: can we buffer large data rather than loading it all?
zarr_region = tuple(write_region.get(dim, slice(None)) for dim in var.dims)
lock_keys = [f"{vname}-{c}" for c in conflicts]
lock_keys = [
f"{vname}-{dim}-{c}"
for dim, dim_conflicts in conflicts.items()
for c in dim_conflicts
]
logger.debug(f"Acquiring locks {lock_keys}")
with lock_for_conflicts(lock_keys, timeout=lock_timeout):
logger.info(
Expand Down
Loading