Skip to content

Commit

Permalink
Big refactor in load_tasks to support extra dims
Browse files Browse the repository at this point in the history
- Extend LoadChunkTask to support extra dims
- Chunking extra dims
- more tests
  • Loading branch information
Kirill888 committed May 28, 2024
1 parent 69982f2 commit abcf9c0
Show file tree
Hide file tree
Showing 3 changed files with 301 additions and 77 deletions.
239 changes: 165 additions & 74 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __call__(
dtype: DTypeLike,
/,
name: Hashable,
ydim: int,
) -> Any: ... # pragma: no cover


Expand All @@ -64,13 +65,29 @@ class LoadChunkTask:
Unit of work for dask graph builder.
"""

# pylint: disable=too-many-instance-attributes

band: str
srcs: List[List[Tuple[int, str]]]
cfg: RasterLoadParams
gbt: GeoboxTiles
idx_tyx: Tuple[int, int, int]
prefix_dims: Tuple[int, ...] = ()
postfix_dims: Tuple[int, ...] = ()
idx: Tuple[int, ...]
shape: Tuple[int, ...]
ydim: int = 1
selection: Any = None # optional slice into extra dims

@property
def idx_tyx(self) -> Tuple[int, int, int]:
ydim = self.ydim
return self.idx[0], self.idx[ydim], self.idx[ydim + 1]

@property
def prefix_dims(self) -> tuple[int, ...]:
return self.shape[1 : self.ydim]

@property
def postfix_dims(self) -> tuple[int, ...]:
return self.shape[self.ydim + 2 :]

@property
def dst_roi(self) -> Tuple[slice, ...]:
Expand Down Expand Up @@ -116,14 +133,14 @@ class DaskGraphBuilder:

def __init__(
self,
cfg: Dict[str, RasterLoadParams],
cfg: Mapping[str, RasterLoadParams],
template: RasterGroupMetadata,
srcs: Sequence[MultiBandRasterSource],
tyx_bins: Dict[Tuple[int, int, int], List[int]],
tyx_bins: Mapping[Tuple[int, int, int], List[int]],
gbt: GeoboxTiles,
env: Dict[str, Any],
rdr: ReaderDriver,
chunks: dict[str, int],
chunks: Mapping[str, int],
) -> None:
gbox = gbt.base
assert isinstance(gbox, GeoBox)
Expand All @@ -145,7 +162,7 @@ def build(
self,
gbox: GeoBox,
time: Sequence[datetime],
bands: Dict[str, RasterLoadParams],
bands: Mapping[str, RasterLoadParams],
):
return mk_dataset(
gbox,
Expand All @@ -161,12 +178,13 @@ def __call__(
dtype: DTypeLike,
/,
name: Hashable,
ydim: int,
) -> Any:
# pylint: disable=too-many-locals
assert isinstance(name, str)
cfg = self.cfg[name]
assert dtype == cfg.dtype
ydim = cfg.ydim + 1 # +1 for time dimension
assert ydim == cfg.ydim + 1 # +1 for time dimension
postfix_dims = shape[ydim + 2 :]
prefix_dims = shape[1:ydim]

Expand Down Expand Up @@ -325,7 +343,7 @@ def _fill_nd_slice(
def mk_dataset(
gbox: GeoBox,
time: Sequence[datetime],
bands: Dict[str, RasterLoadParams],
bands: Mapping[str, RasterLoadParams],
alloc: Optional[MkArray] = None,
*,
template: RasterGroupMetadata,
Expand All @@ -344,17 +362,17 @@ def mk_dataset(
for coord in template.extra_coords
}

def _alloc(shape: Tuple[int, ...], dtype: str, name: Hashable) -> Any:
def _alloc(shape: Tuple[int, ...], dtype: str, name: Hashable, ydim: int) -> Any:
if alloc is not None:
return alloc(shape, dtype, name=name)
return alloc(shape, dtype, name=name, ydim=ydim)
return np.empty(shape, dtype=dtype)

def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray:
assert band.dtype is not None
band_coords = {**coords}
ydim = band.ydim

if len(band.dims) > 2:
ydim = band.ydim
assert band.dims[ydim : ydim + 2] == ("y", "x")
prefix_dims = band.dims[:ydim]
postfix_dims = band.dims[ydim + 2 :]
Expand Down Expand Up @@ -383,7 +401,12 @@ def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray:
dims = ("time", *gbox.dimensions)
shape = (len(time), *gbox.shape.yx)

data = _alloc(shape, band.dtype, name=name)
data = _alloc(
shape,
band.dtype,
name=name,
ydim=ydim + 1, # +1 for time dimension
)
attrs = {}
if band.fill_value is not None:
attrs["nodata"] = band.fill_value
Expand All @@ -396,16 +419,16 @@ def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray:


def chunked_load(
load_cfg: Dict[str, RasterLoadParams],
load_cfg: Mapping[str, RasterLoadParams],
template: RasterGroupMetadata,
srcs: Sequence[MultiBandRasterSource],
tyx_bins: Dict[Tuple[int, int, int], List[int]],
tyx_bins: Mapping[Tuple[int, int, int], List[int]],
gbt: GeoboxTiles,
tss: Sequence[datetime],
env: Dict[str, Any],
rdr: ReaderDriver,
*,
chunks: Dict[str, int | Literal["auto"]] | None = None,
chunks: Mapping[str, int | Literal["auto"]] | None = None,
pool: ThreadPoolExecutor | int | None = None,
progress: Optional[Any] = None,
) -> xr.Dataset:
Expand Down Expand Up @@ -440,16 +463,16 @@ def chunked_load(


def dask_chunked_load(
load_cfg: Dict[str, RasterLoadParams],
load_cfg: Mapping[str, RasterLoadParams],
template: RasterGroupMetadata,
srcs: Sequence[MultiBandRasterSource],
tyx_bins: Dict[Tuple[int, int, int], List[int]],
tyx_bins: Mapping[Tuple[int, int, int], List[int]],
gbt: GeoboxTiles,
tss: Sequence[datetime],
env: Dict[str, Any],
rdr: ReaderDriver,
*,
chunks: Dict[str, int | Literal["auto"]] | None = None,
chunks: Mapping[str, int | Literal["auto"]] | None = None,
) -> xr.Dataset:
"""Builds Dask graph for data loading."""
if chunks is None:
Expand Down Expand Up @@ -478,13 +501,21 @@ def dask_chunked_load(
return dask_loader.build(gbox, tss, load_cfg)


def denorm_ydim(x: tuple[int, ...], ydim: int) -> tuple[int, ...]:
    """
    Move the ``y, x`` pair of a ``(t, y, x, *extra)`` tuple to position ``ydim``.

    The input is laid out with the spatial pair directly after the leading
    (time) entry and any extra entries trailing. The output keeps the first
    entry in place, then emits the first ``ydim - 1`` extra entries, then
    ``y, x``, then the remaining extra entries.

    :param x: Tuple ordered as ``(t, y, x, *extra)``.
    :param ydim: Index that ``y`` should occupy in the result (counting the
                 leading entry), e.g. ``1`` means "already in place".
    :returns: Re-ordered tuple; the input itself is returned when ``ydim == 1``.
    """
    # Number of extra entries that go between the leading entry and ``y``.
    n_prefix = ydim - 1
    if n_prefix == 0:
        return x

    # Use distinct local names: rebinding the ``x`` parameter to a single
    # element would change its type mid-function.
    t, yy, xx, *rest = x
    return (t, *rest[:n_prefix], yy, xx, *rest[n_prefix:])


def load_tasks(
load_cfg: Dict[str, RasterLoadParams],
tyx_bins: Dict[Tuple[int, int, int], List[int]],
load_cfg: Mapping[str, RasterLoadParams],
tyx_bins: Mapping[Tuple[int, int, int], List[int]],
gbt: GeoboxTiles,
*,
nt: Optional[int] = None,
chunks: dict[str, int] | None = None,
chunks: Mapping[str, int] | None = None,
extra_dims: Mapping[str, int] | None = None,
) -> Iterator[LoadChunkTask]:
"""
Expand All @@ -502,42 +533,78 @@ def load_tasks(
if chunks is None:
chunks = {}

time_chunks = chunks.get("time", 1)

shape_in_chunks: Tuple[int, int, int] = (
(nt + time_chunks - 1) // time_chunks,
*gbt.shape.yx,
)
base_shape = (nt, *gbt.base.shape.yx)

for band_name, cfg in load_cfg.items():
ydim = cfg.ydim
prefix_dims = tuple(extra_dims[dim] for dim in cfg.dims[:ydim])
postfix_dims = tuple(extra_dims[dim] for dim in cfg.dims[ydim + 2 :])
_edims: Mapping[str, int] = {}

for idx in np.ndindex(shape_in_chunks):
if _dims := cfg.extra_dims:
_edims = dict((k, v) for k, v in extra_dims.items() if k in _dims)

_chunks = resolve_chunks(base_shape, chunks, dtype=cfg.dtype, extra_dims=_edims)
_offsets: list[tuple[int, ...]] = [
(0, *np.cumsum(ch, dtype="int64").tolist()) for ch in _chunks
]
shape_in_chunks = tuple(len(ch) for ch in _chunks) # T,Y,X[,B]
ndim = len(shape_in_chunks)
ydim = cfg.ydim + 1

for idx in np.ndindex(shape_in_chunks[:3]):
tBi, yi, xi = idx # type: ignore
srcs: List[List[Tuple[int, str]]] = []
t0 = tBi * time_chunks
for ti in range(t0, min(t0 + time_chunks, nt)):
t0, nt = _offsets[0][tBi], _chunks[0][tBi]
for ti in range(t0, t0 + nt):
tyx_idx = (ti, yi, xi)
srcs.append([(idx, band_name) for idx in tyx_bins.get(tyx_idx, [])])

yield LoadChunkTask(
band_name,
srcs,
cfg,
gbt,
(tBi, yi, xi),
prefix_dims=prefix_dims,
postfix_dims=postfix_dims,
chunk_shape_tyx: tuple[int, ...] = tuple(
_chunks[dim][i_chunk] for dim, i_chunk in enumerate(idx)
)

if ndim == 3:
yield LoadChunkTask(
band_name,
srcs,
cfg,
gbt,
idx,
chunk_shape_tyx,
)
continue

for extra_idx in np.ndindex(shape_in_chunks[3:]):
extra_chunk_shape = tuple(
_chunks[dim][i_chunk]
for dim, i_chunk in enumerate(extra_idx, start=3)
)
extra_chunk_offset = (
_offsets[dim][i_chunk]
for dim, i_chunk in enumerate(extra_idx, start=3)
)
selection: Any = tuple(
slice(o, o + n)
for o, n in zip(extra_chunk_offset, extra_chunk_shape)
)
if len(selection) == 1:
selection = selection[0]

yield LoadChunkTask(
band_name,
srcs,
cfg,
gbt,
denorm_ydim(idx + extra_idx, ydim),
denorm_ydim(chunk_shape_tyx + extra_chunk_shape, ydim),
ydim=ydim,
selection=selection,
)


def direct_chunked_load(
load_cfg: Dict[str, RasterLoadParams],
load_cfg: Mapping[str, RasterLoadParams],
template: RasterGroupMetadata,
srcs: Sequence[MultiBandRasterSource],
tyx_bins: Dict[Tuple[int, int, int], List[int]],
tyx_bins: Mapping[Tuple[int, int, int], List[int]],
gbt: GeoboxTiles,
tss: Sequence[datetime],
env: Dict[str, Any],
Expand Down Expand Up @@ -602,46 +669,70 @@ def _do_one(task: LoadChunkTask) -> Tuple[str, int, int, int]:
return ds


def _largest_dtype(
    cfg: Mapping[str, RasterLoadParams] | None,
    fallback: str | np.dtype = "float32",
) -> np.dtype:
    """
    Pick the dtype with the largest item size among configured bands.

    :param cfg: Mapping of band name to load parameters; only the ``.dtype``
                attribute of each value is consulted, entries whose dtype is
                ``None`` are ignored.
    :param fallback: dtype to return when ``cfg`` is ``None`` or no band
                     declares a dtype.
    :returns: ``np.dtype`` with the largest ``itemsize``. When several dtypes
              share the largest size, the one that appears first in ``cfg``
              iteration order wins.
    """
    # ``np.dtype`` accepts both strings and dtype instances, so the result is
    # always a proper ``np.dtype`` regardless of how fallback was supplied.
    fallback_dtype = np.dtype(fallback)

    if cfg is None:
        return fallback_dtype

    candidates = [
        np.dtype(band.dtype) for band in cfg.values() if band.dtype is not None
    ]
    if not candidates:
        return fallback_dtype

    # ``max`` returns the first maximal element, which makes tie-breaking
    # follow band order instead of the hash order of an intermediate set.
    return max(candidates, key=lambda dt: dt.itemsize)


def resolve_chunks(
    base_shape: tuple[int, int, int],
    chunks: Mapping[str, int | Literal["auto"]],
    dtype: Any | None = None,
    extra_dims: Mapping[str, int] | None = None,
    limit: Any | None = None,
) -> tuple[tuple[int, ...], ...]:
    """
    Resolve user chunk settings into per-dimension chunk sizes.

    Dimensions are ordered ``(time, y, x, *extra_dims)``. Chunk sizes are
    looked up in ``chunks`` by name: ``"time"`` defaults to ``1``; ``"y"``,
    ``"x"`` and every extra dimension default to ``-1``.

    :param base_shape: Array shape for the ``(time, y, x)`` dimensions.
    :param chunks: Requested chunk size per dimension name.
    :param dtype: Forwarded to ``normalize_chunks``.
    :param extra_dims: Mapping of extra dimension name to its size; iteration
                       order defines the order of the trailing dimensions.
    :param limit: Forwarded to ``normalize_chunks``.
    :returns: Whatever ``normalize_chunks`` returns for the assembled chunk
              spec and shape -- per the annotation, one tuple of chunk sizes
              per dimension.
    """
    # NOTE(review): normalize_chunks is presumably dask's; under that
    # assumption -1 means "whole dimension in one chunk" -- confirm.
    if extra_dims is None:
        extra_dims = {}

    # Keep the ``chunks`` mapping intact and assemble the positional spec
    # under its own name rather than rebinding the parameter to a tuple.
    chunk_spec = (
        chunks.get("time", 1),
        chunks.get("y", -1),
        chunks.get("x", -1),
        *(chunks.get(dim, -1) for dim in extra_dims),
    )
    shape = (*base_shape, *extra_dims.values())
    return normalize_chunks(chunk_spec, shape, dtype=dtype, limit=limit)


def resolve_chunk_shape(
nt: int,
gbox: GeoBoxBase,
chunks: Dict[str, int | Literal["auto"]],
chunks: Mapping[str, int | Literal["auto"]],
dtype: Any | None = None,
cfg: Mapping[str, RasterLoadParams] | None = None,
extra_dims: Mapping[str, int] | None = None,
) -> Tuple[int, ...]:
"""
Compute chunk size for time, y and x dimensions.
"""
if cfg is None:
cfg = {}
Compute chunk size for time, y and x dimensions and extra dims.
if dtype is None:
_dtypes = sorted(
set(cfg.dtype for cfg in cfg.values() if cfg.dtype is not None),
key=lambda x: np.dtype(x).itemsize,
reverse=True,
)
dtype = "uint16" if len(_dtypes) == 0 else _dtypes[0]
Spatial dimension chunks need to be supplied with ``y,x`` keys.
tt = chunks.get("time", 1)
ty, tx = (
chunks.get(dim, chunks.get(fallback_dim, -1))
for dim, fallback_dim in zip(gbox.dimensions, ["y", "x"])
)
postfix_shape: tuple[int, ...] = ()
postfix_chunks: tuple[int, ...] = ()
:returns: Chunk shape in (T,Y,X, *extra_dims) order
"""
if dtype is None and cfg:
dtype = _largest_dtype(cfg, "float32")

if extra_dims:
postfix_shape, postfix_chunks = zip(
*[(sz, chunks.get(dim, -1)) for dim, sz in extra_dims.items()]
)
chunks = {**chunks}
for s, d in zip(gbox.dimensions, ["y", "x"]):
if s != d and s in chunks:
chunks[d] = chunks[s]

return tuple(
int(ch[0])
for ch in normalize_chunks(
(tt, ty, tx, *postfix_chunks),
(nt, *gbox.shape.yx, *postfix_shape),
dtype=dtype,
)
resolved_chunks = resolve_chunks(
(nt, *gbox.shape.yx),
chunks,
dtype=dtype,
extra_dims=extra_dims,
)
return tuple(int(ch[0]) for ch in resolved_chunks)
Loading

0 comments on commit abcf9c0

Please sign in to comment.