From 22b48d15318120aa2baebd1f71b1a36f0f487fa5 Mon Sep 17 00:00:00 2001 From: Kirill Kouzoubov Date: Thu, 15 Feb 2024 18:36:46 +1100 Subject: [PATCH] amend me: extra dims support --- odc/loader/_builder.py | 84 +++++++++++++++++++++++++++++++++++--- odc/loader/_reader.py | 1 + odc/loader/test_builder.py | 70 +++++++++++++++++++++++++++++++ odc/loader/types.py | 10 ++++- odc/stac/_stac_load.py | 7 ++-- odc/stac/testing/stac.py | 3 +- tests/test_load.py | 8 ++-- 7 files changed, 167 insertions(+), 16 deletions(-) create mode 100644 odc/loader/test_builder.py diff --git a/odc/loader/_builder.py b/odc/loader/_builder.py index cd33f4f..1eeb8f8 100644 --- a/odc/loader/_builder.py +++ b/odc/loader/_builder.py @@ -12,6 +12,7 @@ Iterator, List, Literal, + Mapping, Optional, Protocol, Sequence, @@ -31,7 +32,14 @@ from ._dask import unpack_chunks from ._reader import nodata_mask, resolve_src_nodata from ._utils import SizedIterable, pmap -from .types import MultiBandRasterSource, RasterLoadParams, RasterSource, SomeReader +from .types import ( + FixedCoord, + MultiBandRasterSource, + RasterGroupMetadata, + RasterLoadParams, + RasterSource, + SomeReader, +) class MkArray(Protocol): @@ -80,6 +88,7 @@ class DaskGraphBuilder: def __init__( self, cfg: Dict[str, RasterLoadParams], + template: RasterGroupMetadata, srcs: Sequence[MultiBandRasterSource], tyx_bins: Dict[Tuple[int, int, int], List[int]], gbt: GeoboxTiles, @@ -88,6 +97,7 @@ def __init__( time_chunks: int = 1, ) -> None: self.cfg = cfg + self.template = template self.srcs = srcs self.tyx_bins = tyx_bins self.gbt = gbt @@ -96,6 +106,21 @@ def __init__( self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, time_chunks) self.chunk_shape = (time_chunks, *self.gbt.chunk_shape((0, 0)).yx) + def build( + self, + gbox: GeoBox, + time: Sequence[datetime], + bands: Dict[str, RasterLoadParams], + ): + return mk_dataset( + gbox, + time, + bands, + self, + extra_coords=self.template.extra_coords, + extra_dims=self.template.extra_dims, + ) + def __call__( self, shape: Tuple[int, ...], @@ -217,12 +242,29 @@ def mk_dataset( time: Sequence[datetime], bands: Dict[str, RasterLoadParams], alloc: Optional[MkArray] = None, + *, + extra_coords: Sequence[FixedCoord] | None = None, + extra_dims: Mapping[str, int] | None = None, ) -> xr.Dataset: - _shape = (len(time), *gbox.shape.yx) coords = xr_coords(gbox) crs_coord_name: Hashable = list(coords)[-1] coords["time"] = xr.DataArray(time, dims=("time",)) - dims = ("time", *gbox.dimensions) + _coords: Mapping[str, xr.DataArray] = {} + _dims: Dict[str, int] = {} + + if extra_coords is not None: + _coords = { + coord.name: xr.DataArray( + np.array(coord.values, dtype=coord.dtype), + dims=(coord.name,), + name=coord.name, + ) + for coord in extra_coords + } + _dims.update({coord.name: len(coord.values) for coord in extra_coords}) + + if extra_dims is not None: + _dims.update(extra_dims) def _alloc(shape: Tuple[int, ...], dtype: str, name: Hashable) -> Any: if alloc is not None: @@ -231,12 +273,35 @@ def _alloc(shape: Tuple[int, ...], dtype: str, name: Hashable) -> Any: def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray: assert band.dtype is not None - data = _alloc(_shape, band.dtype, name=name) + band_coords = {**coords} + + if band.dims is not None and len(band.dims) > 2: + # TODO: generalize to more dims + assert band.dims[:2] == ("y", "x") + assert len(band.dims) == 3 + + _, _, extra_dim = band.dims + + dims: Tuple[str, ...] = ("time", *gbox.dimensions, extra_dim) + shape: Tuple[int, ...] = (len(time), *gbox.shape.yx, _dims[extra_dim]) + + band_coords.update( + { + _coords[dim].name: _coords[dim] + for dim in band.dims + if dim not in ("y", "x") + } + ) + else: + dims = ("time", *gbox.dimensions) + shape = (len(time), *gbox.shape.yx) + + data = _alloc(shape, band.dtype, name=name) attrs = {} if band.fill_value is not None: attrs["nodata"] = band.fill_value - xx = xr.DataArray(data=data, coords=coords, dims=dims, attrs=attrs) + xx = xr.DataArray(data=data, coords=band_coords, dims=dims, attrs=attrs) xx.encoding.update(grid_mapping=crs_coord_name) return xx @@ -245,6 +310,7 @@ def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray: def direct_chunked_load( load_cfg: Dict[str, RasterLoadParams], + template: RasterGroupMetadata, srcs: Sequence[MultiBandRasterSource], tyx_bins: Dict[Tuple[int, int, int], List[int]], gbt: GeoboxTiles, @@ -265,7 +331,13 @@ def direct_chunked_load( bands = list(load_cfg) gbox = gbt.base assert isinstance(gbox, GeoBox) - ds = mk_dataset(gbox, tss, load_cfg) + ds = mk_dataset( + gbox, + tss, + load_cfg, + extra_coords=template.extra_coords, + extra_dims=template.extra_dims, + ) ny, nx = gbt.shape.yx total_tasks = nt * nb * ny * nx diff --git a/odc/loader/_reader.py b/odc/loader/_reader.py index 868facb..fecc867 100644 --- a/odc/loader/_reader.py +++ b/odc/loader/_reader.py @@ -59,6 +59,7 @@ def _resolve(name: str, band: RasterBandMetadata) -> RasterLoadParams: use_overviews=use_overviews, resampling=_resampling(name, "nearest"), fail_on_error=fail_on_error, + dims=band.dims, ) return {name: _resolve(name, band) for name, band in bands.items()} diff --git a/odc/loader/test_builder.py b/odc/loader/test_builder.py new file mode 100644 index 0000000..300932b --- /dev/null +++ b/odc/loader/test_builder.py @@ -0,0 +1,70 @@ +# pylint: disable=missing-function-docstring,missing-module-docstring,too-many-statements +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace as _sn +from typing import Dict, Mapping, Sequence + +import pytest +from odc.geo.geobox import GeoBox + +from ._builder import mk_dataset +from .types import FixedCoord, RasterLoadParams + +time = [datetime(2020, 1, 1)] +gbox = GeoBox.from_bbox((-180, -90, 180, 90), shape=(160, 320), tight=True) +shape = (len(time), *gbox.shape.yx) +dims = ("time", *gbox.dimensions) + +_rlp = RasterLoadParams + + +@pytest.mark.parametrize( + "bands,extra_coords,extra_dims,expect", + [ + [ + {"a": _rlp("uint8")}, + None, + None, + {"a": _sn(dims=dims, shape=shape)}, + ], + [ + {"a": _rlp("uint8", dims=("y", "x", "B"))}, + [FixedCoord("B", ["r", "g", "b"])], + None, + {"a": _sn(dims=(*dims, "B"), shape=(*shape, 3))}, + ], + [ + {"a": _rlp("uint8", dims=("y", "x", "W"))}, + [FixedCoord("W", ["r", "g", "b", "a"])], + {"b": 4}, + {"a": _sn(dims=(*dims, "W"), shape=(*shape, 4))}, + ], + ], +) +def test_mk_dataset( + bands: Dict[str, RasterLoadParams], + extra_coords: Sequence[FixedCoord] | None, + extra_dims: Mapping[str, int] | None, + expect: Mapping[str, _sn], +): + assert gbox.crs == "EPSG:4326" + xx = mk_dataset( + gbox, + time, + bands=bands, + extra_coords=extra_coords, + extra_dims=extra_dims, + ) + assert xx is not None + + for n, e in expect.items(): + assert n in xx.data_vars + v = xx[n] + assert v.dims == e.dims + assert v.shape == e.shape + + if extra_coords is not None: + for c in extra_coords: + assert c.name in xx.coords + assert xx.coords[c.name].shape == (len(c.values),) diff --git a/odc/loader/types.py b/odc/loader/types.py index 1ac9fc0..4f7558a 100644 --- a/odc/loader/types.py +++ b/odc/loader/types.py @@ -99,7 +99,7 @@ class FixedCoord: def __post_init__(self): if self.dtype is None: - self.dtype = np.array(self.values).dtype.name + self.dtype = np.array(self.values).dtype.str if self.dim is None: self.dim = self.name @@ -245,6 +245,8 @@ class RasterLoadParams: Captures data loading configuration. """ + # pylint: disable=too-many-instance-attributes + dtype: Optional[str] = None """Output dtype, default same as source.""" @@ -281,6 +283,9 @@ class RasterLoadParams: fail_on_error: bool = True """Quit on the first error or continue.""" + dims: Optional[Tuple[str, ...]] = None + """Dimension names for this band.""" + @staticmethod def same_as(src: Union[RasterBandMetadata, RasterSource]) -> "RasterLoadParams": """Construct from source object.""" @@ -293,7 +298,7 @@ def same_as(src: Union[RasterBandMetadata, RasterSource]) -> "RasterLoadParams": if dtype is None: dtype = "float32" - return RasterLoadParams(dtype=dtype, fill_value=meta.nodata) + return RasterLoadParams(dtype=dtype, fill_value=meta.nodata, dims=meta.dims) @property def nearest(self) -> bool: @@ -315,6 +320,7 @@ def _repr_json_(self) -> Dict[str, Any]: "use_overviews": self.use_overviews, "resampling": self.resampling, "fail_on_error": self.fail_on_error, + "dims": self.dims, } diff --git a/odc/stac/_stac_load.py b/odc/stac/_stac_load.py index 24e0059..105b170 100644 --- a/odc/stac/_stac_load.py +++ b/odc/stac/_stac_load.py @@ -35,7 +35,6 @@ from odc.loader import ( DaskGraphBuilder, direct_chunked_load, - mk_dataset, reader_driver, resolve_chunk_shape, resolve_load_cfg, @@ -463,8 +462,9 @@ def _with_debug_info(ds: xr.Dataset, **kw) -> xr.Dataset: rdr_env = rdr.capture_env() if chunks is not None: - _loader = DaskGraphBuilder( + dask_loader = DaskGraphBuilder( load_cfg, + collection.meta, _parsed, tyx_bins, gbt, @@ -472,11 +472,12 @@ def _with_debug_info(ds: xr.Dataset, **kw) -> xr.Dataset: rdr, time_chunks=chunk_shape[0], ) - return _with_debug_info(mk_dataset(gbox, tss, load_cfg, _loader)) + return _with_debug_info(dask_loader.build(gbox, tss, load_cfg)) return _with_debug_info( direct_chunked_load( load_cfg, + collection.meta, _parsed, tyx_bins, gbt, diff --git a/odc/stac/testing/stac.py b/odc/stac/testing/stac.py index cc7e870..2c84cfe 100644 --- a/odc/stac/testing/stac.py +++ b/odc/stac/testing/stac.py @@ -47,6 +47,7 @@ def b_( dtype="int16", nodata=None, unit="1", + dims=None, uri=None, bidx=1, prefix="http://example.com/items/", @@ -55,7 +56,7 @@ def b_( name, _ = band_key if uri is None: uri = f"{prefix}{name}.tif" - meta = RasterBandMetadata(dtype, nodata, unit) + meta = RasterBandMetadata(dtype, nodata, unit, dims=dims) return (band_key, RasterSource(uri, bidx, geobox=geobox, meta=meta)) diff --git a/tests/test_load.py b/tests/test_load.py index fb15900..2d56aed 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -217,7 +217,7 @@ def test_resolve_load_cfg(): item = mk_parsed_item( [ b_("a", dtype="int8", nodata=-1), - b_("b", dtype="float64"), + b_("b", dtype="float64", dims=("y", "x", "b")), ] ) @@ -229,7 +229,7 @@ def test_resolve_load_cfg(): cfg = resolve_load_cfg(_bands, resampling="average") assert cfg["a"] == rlp("int8", -1, resampling="average") - assert cfg["b"] == rlp("float64", None, resampling="average") + assert cfg["b"] == rlp("float64", None, resampling="average", dims=("y", "x", "b")) cfg = resolve_load_cfg( _bands, @@ -238,11 +238,11 @@ def test_resolve_load_cfg(): dtype="int64", ) assert cfg["a"] == rlp("int64", -999, resampling="mode") - assert cfg["b"] == rlp("int64", -999, resampling="sum") + assert cfg["b"] == rlp("int64", -999, resampling="sum", dims=("y", "x", "b")) cfg = resolve_load_cfg( _bands, dtype={"a": "float32"}, ) assert cfg["a"] == rlp("float32", -1) - assert cfg["b"] == rlp("float64", None) + assert cfg["b"] == rlp("float64", None, dims=_bands["b"].dims)