diff --git a/odc/loader/_builder.py b/odc/loader/_builder.py index bb1fee5..3d698a8 100644 --- a/odc/loader/_builder.py +++ b/odc/loader/_builder.py @@ -55,6 +55,7 @@ def __call__( dtype: DTypeLike, /, name: Hashable, + ydim: int, ) -> Any: ... # pragma: no cover @@ -64,13 +65,29 @@ class LoadChunkTask: Unit of work for dask graph builder. """ + # pylint: disable=too-many-instance-attributes + band: str srcs: List[List[Tuple[int, str]]] cfg: RasterLoadParams gbt: GeoboxTiles - idx_tyx: Tuple[int, int, int] - prefix_dims: Tuple[int, ...] = () - postfix_dims: Tuple[int, ...] = () + idx: Tuple[int, ...] + shape: Tuple[int, ...] + ydim: int = 1 + selection: Any = None # optional slice into extra dims + + @property + def idx_tyx(self) -> Tuple[int, int, int]: + ydim = self.ydim + return self.idx[0], self.idx[ydim], self.idx[ydim + 1] + + @property + def prefix_dims(self) -> tuple[int, ...]: + return self.shape[1 : self.ydim] + + @property + def postfix_dims(self) -> tuple[int, ...]: + return self.shape[self.ydim + 2 :] @property def dst_roi(self) -> Tuple[slice, ...]: @@ -116,14 +133,14 @@ class DaskGraphBuilder: def __init__( self, - cfg: Dict[str, RasterLoadParams], + cfg: Mapping[str, RasterLoadParams], template: RasterGroupMetadata, srcs: Sequence[MultiBandRasterSource], - tyx_bins: Dict[Tuple[int, int, int], List[int]], + tyx_bins: Mapping[Tuple[int, int, int], List[int]], gbt: GeoboxTiles, env: Dict[str, Any], rdr: ReaderDriver, - chunks: dict[str, int], + chunks: Mapping[str, int], ) -> None: gbox = gbt.base assert isinstance(gbox, GeoBox) @@ -145,7 +162,7 @@ def build( self, gbox: GeoBox, time: Sequence[datetime], - bands: Dict[str, RasterLoadParams], + bands: Mapping[str, RasterLoadParams], ): return mk_dataset( gbox, @@ -161,12 +178,13 @@ def __call__( dtype: DTypeLike, /, name: Hashable, + ydim: int, ) -> Any: # pylint: disable=too-many-locals assert isinstance(name, str) cfg = self.cfg[name] assert dtype == cfg.dtype - ydim = cfg.ydim + 1 # +1 for time 
dimension + assert ydim == cfg.ydim + 1 # +1 for time dimension postfix_dims = shape[ydim + 2 :] prefix_dims = shape[1:ydim] @@ -325,7 +343,7 @@ def _fill_nd_slice( def mk_dataset( gbox: GeoBox, time: Sequence[datetime], - bands: Dict[str, RasterLoadParams], + bands: Mapping[str, RasterLoadParams], alloc: Optional[MkArray] = None, *, template: RasterGroupMetadata, @@ -344,17 +362,17 @@ def mk_dataset( for coord in template.extra_coords } - def _alloc(shape: Tuple[int, ...], dtype: str, name: Hashable) -> Any: + def _alloc(shape: Tuple[int, ...], dtype: str, name: Hashable, ydim: int) -> Any: if alloc is not None: - return alloc(shape, dtype, name=name) + return alloc(shape, dtype, name=name, ydim=ydim) return np.empty(shape, dtype=dtype) def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray: assert band.dtype is not None band_coords = {**coords} + ydim = band.ydim if len(band.dims) > 2: - ydim = band.ydim assert band.dims[ydim : ydim + 2] == ("y", "x") prefix_dims = band.dims[:ydim] postfix_dims = band.dims[ydim + 2 :] @@ -383,7 +401,12 @@ def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray: dims = ("time", *gbox.dimensions) shape = (len(time), *gbox.shape.yx) - data = _alloc(shape, band.dtype, name=name) + data = _alloc( + shape, + band.dtype, + name=name, + ydim=ydim + 1, # +1 for time dimension + ) attrs = {} if band.fill_value is not None: attrs["nodata"] = band.fill_value @@ -396,16 +419,16 @@ def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray: def chunked_load( - load_cfg: Dict[str, RasterLoadParams], + load_cfg: Mapping[str, RasterLoadParams], template: RasterGroupMetadata, srcs: Sequence[MultiBandRasterSource], - tyx_bins: Dict[Tuple[int, int, int], List[int]], + tyx_bins: Mapping[Tuple[int, int, int], List[int]], gbt: GeoboxTiles, tss: Sequence[datetime], env: Dict[str, Any], rdr: ReaderDriver, *, - chunks: Dict[str, int | Literal["auto"]] | None = None, + chunks: Mapping[str, int | Literal["auto"]] | None = 
None, pool: ThreadPoolExecutor | int | None = None, progress: Optional[Any] = None, ) -> xr.Dataset: @@ -440,16 +463,16 @@ def chunked_load( def dask_chunked_load( - load_cfg: Dict[str, RasterLoadParams], + load_cfg: Mapping[str, RasterLoadParams], template: RasterGroupMetadata, srcs: Sequence[MultiBandRasterSource], - tyx_bins: Dict[Tuple[int, int, int], List[int]], + tyx_bins: Mapping[Tuple[int, int, int], List[int]], gbt: GeoboxTiles, tss: Sequence[datetime], env: Dict[str, Any], rdr: ReaderDriver, *, - chunks: Dict[str, int | Literal["auto"]] | None = None, + chunks: Mapping[str, int | Literal["auto"]] | None = None, ) -> xr.Dataset: """Builds Dask graph for data loading.""" if chunks is None: @@ -478,13 +501,21 @@ def dask_chunked_load( return dask_loader.build(gbox, tss, load_cfg) +def denorm_ydim(x: tuple[int, ...], ydim: int) -> tuple[int, ...]: + ydim = ydim - 1 + if ydim == 0: + return x + t, y, x, *rest = x + return (t, *rest[:ydim], y, x, *rest[ydim:]) + + def load_tasks( - load_cfg: Dict[str, RasterLoadParams], - tyx_bins: Dict[Tuple[int, int, int], List[int]], + load_cfg: Mapping[str, RasterLoadParams], + tyx_bins: Mapping[Tuple[int, int, int], List[int]], gbt: GeoboxTiles, *, nt: Optional[int] = None, - chunks: dict[str, int] | None = None, + chunks: Mapping[str, int] | None = None, extra_dims: Mapping[str, int] | None = None, ) -> Iterator[LoadChunkTask]: """ @@ -502,42 +533,78 @@ def load_tasks( if chunks is None: chunks = {} - time_chunks = chunks.get("time", 1) - - shape_in_chunks: Tuple[int, int, int] = ( - (nt + time_chunks - 1) // time_chunks, - *gbt.shape.yx, - ) + base_shape = (nt, *gbt.base.shape.yx) for band_name, cfg in load_cfg.items(): - ydim = cfg.ydim - prefix_dims = tuple(extra_dims[dim] for dim in cfg.dims[:ydim]) - postfix_dims = tuple(extra_dims[dim] for dim in cfg.dims[ydim + 2 :]) + _edims: Mapping[str, int] = {} - for idx in np.ndindex(shape_in_chunks): + if _dims := cfg.extra_dims: + _edims = dict((k, v) for k, v in 
extra_dims.items() if k in _dims) + + _chunks = resolve_chunks(base_shape, chunks, dtype=cfg.dtype, extra_dims=_edims) + _offsets: list[tuple[int, ...]] = [ + (0, *np.cumsum(ch, dtype="int64").tolist()) for ch in _chunks + ] + shape_in_chunks = tuple(len(ch) for ch in _chunks) # T,Y,X[,B] + ndim = len(shape_in_chunks) + ydim = cfg.ydim + 1 + + for idx in np.ndindex(shape_in_chunks[:3]): tBi, yi, xi = idx # type: ignore srcs: List[List[Tuple[int, str]]] = [] - t0 = tBi * time_chunks - for ti in range(t0, min(t0 + time_chunks, nt)): + t0, nt = _offsets[0][tBi], _chunks[0][tBi] + for ti in range(t0, t0 + nt): tyx_idx = (ti, yi, xi) srcs.append([(idx, band_name) for idx in tyx_bins.get(tyx_idx, [])]) - yield LoadChunkTask( - band_name, - srcs, - cfg, - gbt, - (tBi, yi, xi), - prefix_dims=prefix_dims, - postfix_dims=postfix_dims, + chunk_shape_tyx: tuple[int, ...] = tuple( + _chunks[dim][i_chunk] for dim, i_chunk in enumerate(idx) ) + if ndim == 3: + yield LoadChunkTask( + band_name, + srcs, + cfg, + gbt, + idx, + chunk_shape_tyx, + ) + continue + + for extra_idx in np.ndindex(shape_in_chunks[3:]): + extra_chunk_shape = tuple( + _chunks[dim][i_chunk] + for dim, i_chunk in enumerate(extra_idx, start=3) + ) + extra_chunk_offset = ( + _offsets[dim][i_chunk] + for dim, i_chunk in enumerate(extra_idx, start=3) + ) + selection: Any = tuple( + slice(o, o + n) + for o, n in zip(extra_chunk_offset, extra_chunk_shape) + ) + if len(selection) == 1: + selection = selection[0] + + yield LoadChunkTask( + band_name, + srcs, + cfg, + gbt, + denorm_ydim(idx + extra_idx, ydim), + denorm_ydim(chunk_shape_tyx + extra_chunk_shape, ydim), + ydim=ydim, + selection=selection, + ) + def direct_chunked_load( - load_cfg: Dict[str, RasterLoadParams], + load_cfg: Mapping[str, RasterLoadParams], template: RasterGroupMetadata, srcs: Sequence[MultiBandRasterSource], - tyx_bins: Dict[Tuple[int, int, int], List[int]], + tyx_bins: Mapping[Tuple[int, int, int], List[int]], gbt: GeoboxTiles, tss: 
Sequence[datetime], env: Dict[str, Any], @@ -602,46 +669,70 @@ def _do_one(task: LoadChunkTask) -> Tuple[str, int, int, int]: return ds +def _largest_dtype( + cfg: Mapping[str, RasterLoadParams] | None, + fallback: str | np.dtype = "float32", +) -> np.dtype: + if isinstance(fallback, str): + fallback = np.dtype(fallback) + + if cfg is None: + return fallback + + _dtypes = sorted( + set(np.dtype(cfg.dtype) for cfg in cfg.values() if cfg.dtype is not None), + key=lambda x: x.itemsize, + reverse=True, + ) + if _dtypes: + return _dtypes[0] + + return fallback + + +def resolve_chunks( + base_shape: tuple[int, int, int], + chunks: Mapping[str, int | Literal["auto"]], + dtype: Any | None = None, + extra_dims: Mapping[str, int] | None = None, + limit: Any | None = None, +) -> tuple[tuple[int, ...], ...]: + if extra_dims is None: + extra_dims = {} + tt = chunks.get("time", 1) + ty, tx = (chunks.get(dim, -1) for dim in ["y", "x"]) + chunks = (tt, ty, tx) + tuple((chunks.get(dim, -1) for dim in extra_dims)) + shape = base_shape + tuple(extra_dims.values()) + return normalize_chunks(chunks, shape, dtype=dtype, limit=limit) + + def resolve_chunk_shape( nt: int, gbox: GeoBoxBase, - chunks: Dict[str, int | Literal["auto"]], + chunks: Mapping[str, int | Literal["auto"]], dtype: Any | None = None, cfg: Mapping[str, RasterLoadParams] | None = None, extra_dims: Mapping[str, int] | None = None, ) -> Tuple[int, ...]: """ - Compute chunk size for time, y and x dimensions. - """ - if cfg is None: - cfg = {} + Compute chunk size for time, y and x dimensions and extra dims. - if dtype is None: - _dtypes = sorted( - set(cfg.dtype for cfg in cfg.values() if cfg.dtype is not None), - key=lambda x: np.dtype(x).itemsize, - reverse=True, - ) - dtype = "uint16" if len(_dtypes) == 0 else _dtypes[0] + Spatial dimension chunks need to be supplied with ``y,x`` keys. 
- tt = chunks.get("time", 1) - ty, tx = ( - chunks.get(dim, chunks.get(fallback_dim, -1)) - for dim, fallback_dim in zip(gbox.dimensions, ["y", "x"]) - ) - postfix_shape: tuple[int, ...] = () - postfix_chunks: tuple[int, ...] = () + :returns: Chunk shape in (T,Y,X, *extra_dims) order + """ + if dtype is None and cfg: + dtype = _largest_dtype(cfg, "float32") - if extra_dims: - postfix_shape, postfix_chunks = zip( - *[(sz, chunks.get(dim, -1)) for dim, sz in extra_dims.items()] - ) + chunks = {**chunks} + for s, d in zip(gbox.dimensions, ["y", "x"]): + if s != d and s in chunks: + chunks[d] = chunks[s] - return tuple( - int(ch[0]) - for ch in normalize_chunks( - (tt, ty, tx, *postfix_chunks), - (nt, *gbox.shape.yx, *postfix_shape), - dtype=dtype, - ) + resolved_chunks = resolve_chunks( + (nt, *gbox.shape.yx), + chunks, + dtype=dtype, + extra_dims=extra_dims, ) + return tuple(int(ch[0]) for ch in resolved_chunks) diff --git a/odc/loader/test_builder.py b/odc/loader/test_builder.py index 9462200..cd114aa 100644 --- a/odc/loader/test_builder.py +++ b/odc/loader/test_builder.py @@ -1,9 +1,10 @@ # pylint: disable=missing-function-docstring,missing-module-docstring,too-many-statements,too-many-locals +# pylint: disable=redefined-outer-name,unused-argument from __future__ import annotations from datetime import datetime from types import SimpleNamespace as _sn -from typing import Dict, Mapping, Sequence +from typing import Any, Dict, Literal, Mapping, Sequence import dask import dask.array as da @@ -13,7 +14,14 @@ from odc.geo.geobox import GeoBox, GeoboxTiles from . 
import chunked_load -from ._builder import DaskGraphBuilder, mk_dataset, resolve_chunk_shape +from ._builder import ( + DaskGraphBuilder, + _largest_dtype, + load_tasks, + mk_dataset, + resolve_chunk_shape, + resolve_chunks, +) from .testing.fixtures import FakeMDPlugin, FakeReaderDriver from .types import ( FixedCoord, @@ -37,6 +45,10 @@ def _full_tyx_bins( return {idx: list(range(nsrcs)) for idx in np.ndindex((nt, *tiles.shape.yx))} # type: ignore +def _num_chunks(chunk: int, sz: int) -> int: + return (sz + chunk - 1) // chunk + + # bands,extra_coords,extra_dims,expect rlp_fixtures = [ [ @@ -194,6 +206,55 @@ def test_dask_builder( check_xx(xx_dasked, bands, extra_coords, extra_dims, expect) +@pytest.mark.parametrize( + "cfg,fallback,expect", + [ + ({}, "uint8", "uint8"), + ({}, "float64", "float64"), + (None, "float64", "float64"), + ({"a": _rlp("uint16")}, "float32", "uint16"), + ({"a": _rlp("uint16"), "b": _rlp("int32")}, "float32", "int32"), + ({"a": _rlp()}, "float32", "float32"), + ], +) +def test_largest_dtype(cfg, fallback, expect): + assert _largest_dtype(cfg, fallback) == expect + + +@pytest.mark.parametrize( + "base_shape,chunks,extra_dims,expect", + [ + ((1, 200, 300), {}, None, ((1,), (200,), (300,))), + ((1, 200, 300), {"y": 100}, None, ((1,), (100, 100), (300,))), + ( + (3, 200, 300), + {"y": 100, "x": 200, "time": 2}, + None, + ((2, 1), (100, 100), (200, 100)), + ), + ((1, 200, 300), {}, {"b": 3}, ((1,), (200,), (300,), (3,))), + ((1, 200, 300), {"b": 1}, {"b": 3}, ((1,), (200,), (300,), (1, 1, 1))), + ], +) +@pytest.mark.parametrize("dtype", [None, "uint8", "uint16"]) +def test_resolve_chunks( + base_shape: tuple[int, int, int], + chunks: Mapping[str, int | Literal["auto"]], + extra_dims: Mapping[str, int] | None, + expect: tuple[int, ...], + dtype: Any | None, +): + normed_chunks = resolve_chunks(base_shape, chunks, dtype, extra_dims) + assert isinstance(normed_chunks, tuple) + assert len(normed_chunks) == len(base_shape) + len(extra_dims or {}) 
+ assert all( + (isinstance(ii, tuple) and isinstance(ii[0], int)) for ii in normed_chunks + ) + + if expect is not None: + assert normed_chunks == expect + + def test_resolve_chunk_shape(): # pylint: disable=redefined-outer-name nt = 7 @@ -202,6 +263,11 @@ def test_resolve_chunk_shape(): assert resolve_chunk_shape(nt, gbox, {}) == (1, *yx_shape) assert resolve_chunk_shape(nt, gbox, {"time": 3}) == (3, *yx_shape) assert resolve_chunk_shape(nt, gbox, {"y": 10, "x": 20}) == (1, 10, 20) + assert resolve_chunk_shape( + nt, + gbox, + dict(zip(gbox.dimensions, [10, 20])), + ) == (1, 10, 20) # extra chunks without extra_dims should be ignored assert resolve_chunk_shape(nt, gbox, {"y": 10, "x": 20, "b": 3}) == (1, 10, 20) @@ -218,3 +284,66 @@ def test_resolve_chunk_shape(): {"y": 10, "x": 20}, extra_dims={"b": 100}, ) == (1, 10, 20, 100) + + +@pytest.mark.parametrize( + "chunks,dims,nt,nsrcs,extra_dims", + [ + ({}, (), 3, 1, {}), + ({"time": 2}, (), 3, 1, {}), + ({"time": 2, "y": 80, "x": 80}, (), 3, 1, {}), + ({"b": 2, "y": 80, "x": 80}, ("b", "y", "x"), 2, 1, {"b": 5}), + ({"time": 2, "y": 100, "b": 1}, ("y", "x", "b"), 4, 3, {"b": 4}), + ({"time": 2, "y": 100, "b": 1}, ("y", "x", "b"), 1, 3, {"b": 4}), + ], +) +def test_load_tasks( + chunks: Mapping[str, int], + dims: tuple[str, ...], + extra_dims: Mapping[str, int], + nt: int, + nsrcs: int, +): + var_name = "xx" + cfg = {var_name: RasterLoadParams("uint8", dims=dims)} + + ydim = 1 + (dims.index("y") if dims else 0) + assert ydim in (1, 2) + + _nt, ny, nx, *extra_chunks = resolve_chunk_shape( + nt, gbox, chunks, extra_dims=extra_dims + ) + assert len(extra_chunks) in (0, 1) + assert _nt == min(chunks.get("time", 1), nt) + nt_chunks = _num_chunks(_nt, nt) + nb_chunks = 1 + if extra_chunks: + nb_chunks = _num_chunks(extra_chunks[0], list(extra_dims.values())[0]) + + gbt = GeoboxTiles(gbox, (ny, nx)) + + tyx_bins = _full_tyx_bins(gbt, nsrcs=nsrcs, nt=nt) + assert len(tyx_bins) == nt * gbt.shape.y * gbt.shape.x + + tasks = 
load_tasks(cfg, tyx_bins, gbt, nt=nt, chunks=chunks, extra_dims=extra_dims) + tasks = list(tasks) + assert len(tasks) == nt_chunks * gbt.shape.y * gbt.shape.x * nb_chunks + + for t in tasks: + assert t.band == var_name + assert t.gbt is gbt + assert t.idx_tyx in tyx_bins + assert t.ydim == ydim + assert len(t.srcs) > 0 + assert isinstance(t.srcs[0], list) + assert len(t.idx) == len(t.postfix_dims) + len(t.prefix_dims) + 2 + 1 + + if dims: + assert len(t.postfix_dims) + len(t.prefix_dims) == len(dims) - 2 + assert len(t.idx) == len(dims) + 1 + else: + assert t.idx == t.idx_tyx + assert len(t.postfix_dims) + len(t.prefix_dims) == 0 + + assert gbox.enclosing(t.dst_gbox.boundingbox) == t.dst_gbox + assert gbox[t.dst_roi[ydim : ydim + 2]] == t.dst_gbox diff --git a/odc/loader/types.py b/odc/loader/types.py index a4e9750..e75d6f4 100644 --- a/odc/loader/types.py +++ b/odc/loader/types.py @@ -207,12 +207,16 @@ def _repr_json_(self) -> Dict[str, Any]: "extra_coords": [c._repr_json_() for c in self.extra_coords], } - def extra_dims_full(self) -> Dict[str, int]: + def extra_dims_full(self, band: BandIdentifier | None = None) -> dict[str, int]: dims = {**self.extra_dims} for coord in self.extra_coords: if coord.dim not in dims: dims[coord.dim] = len(coord.values) + if band is not None: + band_dims = self.bands[norm_key(band)].dims + dims = {k: v for k, v in dims.items() if k in band_dims} + return dims