From 5241fa6ca3eac27fca72f831d1ed51b1508f38d0 Mon Sep 17 00:00:00 2001 From: Kirill Kouzoubov Date: Thu, 15 Feb 2024 18:36:46 +1100 Subject: [PATCH] Supporting postfix dimensions on load --- odc/loader/_builder.py | 158 ++++++++++++++++++++++++++------ odc/loader/_reader.py | 1 + odc/loader/test_builder.py | 163 +++++++++++++++++++++++++++++++++ odc/loader/testing/fixtures.py | 96 ++++++++++++++++++- odc/loader/types.py | 18 ++-- odc/stac/_mdtools.py | 3 +- odc/stac/_stac_load.py | 7 +- odc/stac/testing/stac.py | 3 +- tests/test_load.py | 8 +- 9 files changed, 411 insertions(+), 46 deletions(-) create mode 100644 odc/loader/test_builder.py diff --git a/odc/loader/_builder.py b/odc/loader/_builder.py index cd33f4f..10704f8 100644 --- a/odc/loader/_builder.py +++ b/odc/loader/_builder.py @@ -12,6 +12,7 @@ Iterator, List, Literal, + Mapping, Optional, Protocol, Sequence, @@ -31,7 +32,14 @@ from ._dask import unpack_chunks from ._reader import nodata_mask, resolve_src_nodata from ._utils import SizedIterable, pmap -from .types import MultiBandRasterSource, RasterLoadParams, RasterSource, SomeReader +from .types import ( + FixedCoord, + MultiBandRasterSource, + RasterGroupMetadata, + RasterLoadParams, + RasterSource, + SomeReader, +) class MkArray(Protocol): @@ -58,11 +66,12 @@ class LoadChunkTask: cfg: RasterLoadParams gbt: GeoboxTiles idx_tyx: Tuple[int, int, int] + postfix_dims: Tuple[int, ...] = () @property def dst_roi(self): t, y, x = self.idx_tyx - return (t, *self.gbt.roi[y, x]) + return (t, *self.gbt.roi[y, x]) + tuple([slice(None)] * len(self.postfix_dims)) @property def dst_gbox(self) -> GeoBox: @@ -80,6 +89,7 @@ class DaskGraphBuilder: def __init__( self, cfg: Dict[str, RasterLoadParams], + template: RasterGroupMetadata, srcs: Sequence[MultiBandRasterSource], tyx_bins: Dict[Tuple[int, int, int], List[int]], gbt: GeoboxTiles, @@ -88,6 +98,7 @@ def __init__( time_chunks: int = 1, ) -> None: self.cfg = cfg + self.template = template self.srcs = srcs self.tyx_bins = tyx_bins self.gbt = gbt @@ -96,6 +107,21 @@ def __init__( self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, time_chunks) self.chunk_shape = (time_chunks, *self.gbt.chunk_shape((0, 0)).yx) + def build( + self, + gbox: GeoBox, + time: Sequence[datetime], + bands: Dict[str, RasterLoadParams], + ): + return mk_dataset( + gbox, + time, + bands, + self, + extra_coords=self.template.extra_coords, + extra_dims=self.template.extra_dims, + ) + def __call__( self, shape: Tuple[int, ...], @@ -104,51 +130,57 @@ def __call__( name: Hashable, ) -> Any: # pylint: disable=too-many-locals - assert len(shape) == 3 assert isinstance(name, str) cfg = self.cfg[name] assert dtype == cfg.dtype + # TODO: assumes postfix dims only for now + ydim = 1 + post_fix_dims = shape[ydim + 2 :] - chunks = unpack_chunks(self.chunk_shape, shape) + chunk_shape = (*self.chunk_shape, *post_fix_dims) + assert len(chunk_shape) == len(shape) + chunks = unpack_chunks(chunk_shape, shape) tchunk_range = [ range(last - n, last) for last, n in zip(np.cumsum(chunks[0]), chunks[0]) ] - cfg_key = f"cfg-{tokenize(cfg)}" - gbt_key = f"grid-{tokenize(self.gbt)}" + cfg_dask_key = f"cfg-{tokenize(cfg)}" + gbt_dask_key = f"grid-{tokenize(self.gbt)}" dsk: Dict[Hashable, Any] = { - cfg_key: cfg, - gbt_key: self.gbt, + cfg_dask_key: cfg, + gbt_dask_key: self.gbt, } tk = self._tk band_key = f"{name}-{tk}" md_key = f"md-{name}-{tk}" shape_in_blocks = tuple(len(ch) for ch in chunks) - for idx, src in enumerate(self.srcs): + for src_idx, src in enumerate(self.srcs): band = src.get(name, None) if band is not None: - dsk[md_key, idx] = band + dsk[md_key, src_idx] = band - for ti, yi, xi in np.ndindex(shape_in_blocks): # type: ignore + for block_idx in np.ndindex(shape_in_blocks): + ti, yi, xi = block_idx[0], block_idx[ydim], block_idx[ydim + 1] srcs = [] for _ti in tchunk_range[ti]: srcs.append( [ - (md_key, idx) - for idx in self.tyx_bins.get((_ti, yi, xi), []) - if (md_key, idx) in dsk + (md_key, src_idx) + for src_idx in self.tyx_bins.get((_ti, yi, xi), []) + if (md_key, src_idx) in dsk ] ) - dsk[band_key, ti, yi, xi] = ( + dsk[(band_key, *block_idx)] = ( _dask_loader_tyx, srcs, - gbt_key, + gbt_dask_key, quote((yi, xi)), + quote(post_fix_dims), self.rdr, - cfg_key, + cfg_dask_key, self.env, ) @@ -159,16 +191,17 @@ def _dask_loader_tyx( srcs: Sequence[Sequence[RasterSource]], gbt: GeoboxTiles, iyx: Tuple[int, int], + postfix_dims: Tuple[int, ...], rdr: SomeReader, cfg: RasterLoadParams, env: Dict[str, Any], ): assert cfg.dtype is not None gbox = cast(GeoBox, gbt[iyx]) - chunk = np.empty((len(srcs), *gbox.shape.yx), dtype=cfg.dtype) + chunk = np.empty((len(srcs), *gbox.shape.yx, *postfix_dims), dtype=cfg.dtype) with rdr.restore_env(env): - for i, plane in enumerate(srcs): - fill_2d_slice(plane, gbox, cfg, rdr, chunk[i, :, :]) + for ti, ti_srcs in enumerate(srcs): + fill_2d_slice(ti_srcs, gbox, cfg, rdr, chunk[ti]) return chunk @@ -184,7 +217,11 @@ def fill_2d_slice( # ``nodata`` marks missing pixels, but it might be None (everything is valid) # ``fill_value`` is the initial value to use, it's equal to ``nodata`` when set, # otherwise defaults to .nan for floats and 0 for integers - assert dst.shape == dst_gbox.shape.yx + + # assume dst[y, x, ...] axis order + assert dst.shape[:2] == dst_gbox.shape.yx + postfix_roi = (slice(None),) * len(dst.shape[2:]) + nodata = resolve_src_nodata(cfg.fill_value, cfg) if nodata is None: @@ -197,11 +234,18 @@ def fill_2d_slice( return dst src, *rest = srcs - _roi, pix = rdr.read(src, cfg, dst_gbox, dst=dst) + yx_roi, pix = rdr.read(src, cfg, dst_gbox, dst=dst) + assert len(yx_roi) == 2 + assert pix.ndim == dst.ndim for src in rest: # first valid pixel takes precedence over others - _roi, pix = rdr.read(src, cfg, dst_gbox) + yx_roi, pix = rdr.read(src, cfg, dst_gbox) + assert len(yx_roi) == 2 + assert pix.ndim == dst.ndim + + _roi: Tuple[slice,] = yx_roi + postfix_roi # type: ignore + assert dst[_roi].shape == pix.shape # nodata mask takes care of nan when working with floats # so you can still get proper mask even when nodata is None @@ -217,12 +261,29 @@ def mk_dataset( time: Sequence[datetime], bands: Dict[str, RasterLoadParams], alloc: Optional[MkArray] = None, + *, + extra_coords: Sequence[FixedCoord] | None = None, + extra_dims: Mapping[str, int] | None = None, ) -> xr.Dataset: - _shape = (len(time), *gbox.shape.yx) coords = xr_coords(gbox) crs_coord_name: Hashable = list(coords)[-1] coords["time"] = xr.DataArray(time, dims=("time",)) - dims = ("time", *gbox.dimensions) + _coords: Mapping[str, xr.DataArray] = {} + _dims: Dict[str, int] = {} + + if extra_coords is not None: + _coords = { + coord.name: xr.DataArray( + np.array(coord.values, dtype=coord.dtype), + dims=(coord.name,), + name=coord.name, + ) + for coord in extra_coords + } + _dims.update({coord.name: len(coord.values) for coord in extra_coords}) + + if extra_dims is not None: + _dims.update(extra_dims) def _alloc(shape: Tuple[int, ...], dtype: str, name: Hashable) -> Any: if alloc is not None: @@ -231,12 +292,38 @@ def _alloc(shape: Tuple[int, ...], dtype: str, name: Hashable) -> Any: def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray: assert band.dtype is not None - data = _alloc(_shape, band.dtype, name=name) + band_coords = {**coords} + + if band.dims is not None and len(band.dims) > 2: + # TODO: generalize to more dims + ydim = 0 + postfix_dims = band.dims[ydim + 2 :] + assert band.dims[ydim : ydim + 2] == ("y", "x") + + dims: Tuple[str, ...] = ("time", *gbox.dimensions, *postfix_dims) + shape: Tuple[int, ...] = ( + len(time), + *gbox.shape.yx, + *[_dims[dim] for dim in postfix_dims], + ) + + band_coords.update( + { + _coords[dim].name: _coords[dim] + for dim in postfix_dims + if dim in _coords + } + ) + else: + dims = ("time", *gbox.dimensions) + shape = (len(time), *gbox.shape.yx) + + data = _alloc(shape, band.dtype, name=name) attrs = {} if band.fill_value is not None: attrs["nodata"] = band.fill_value - xx = xr.DataArray(data=data, coords=coords, dims=dims, attrs=attrs) + xx = xr.DataArray(data=data, coords=band_coords, dims=dims, attrs=attrs) xx.encoding.update(grid_mapping=crs_coord_name) return xx @@ -245,6 +332,7 @@ def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray: def direct_chunked_load( load_cfg: Dict[str, RasterLoadParams], + template: RasterGroupMetadata, srcs: Sequence[MultiBandRasterSource], tyx_bins: Dict[Tuple[int, int, int], List[int]], gbt: GeoboxTiles, @@ -265,7 +353,13 @@ def direct_chunked_load( bands = list(load_cfg) gbox = gbt.base assert isinstance(gbox, GeoBox) - ds = mk_dataset(gbox, tss, load_cfg) + ds = mk_dataset( + gbox, + tss, + load_cfg, + extra_coords=template.extra_coords, + extra_dims=template.extra_dims, + ) ny, nx = gbt.shape.yx total_tasks = nt * nb * ny * nx @@ -286,7 +380,13 @@ def _do_one(task: LoadChunkTask) -> Tuple[str, int, int, int]: if src is not None ] with rdr.restore_env(env): - _ = fill_2d_slice(_srcs, task.dst_gbox, task.cfg, rdr, dst_slice) + _ = fill_2d_slice( + _srcs, + task.dst_gbox, + task.cfg, + rdr=rdr, + dst=dst_slice, + ) t, y, x = task.idx_tyx return (task.band, t, y, x) diff --git a/odc/loader/_reader.py b/odc/loader/_reader.py index 868facb..fecc867 100644 --- a/odc/loader/_reader.py +++ b/odc/loader/_reader.py @@ -59,6 +59,7 @@ def _resolve(name: str, band: RasterBandMetadata) -> RasterLoadParams: use_overviews=use_overviews, resampling=_resampling(name, "nearest"), fail_on_error=fail_on_error, + dims=band.dims, ) return {name: _resolve(name, band) for name, band in bands.items()} diff --git a/odc/loader/test_builder.py b/odc/loader/test_builder.py new file mode 100644 index 0000000..09ebeea --- /dev/null +++ b/odc/loader/test_builder.py @@ -0,0 +1,163 @@ +# pylint: disable=missing-function-docstring,missing-module-docstring,too-many-statements +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace as _sn +from typing import Dict, Mapping, Sequence + +import dask +import dask.array as da +import numpy as np +import pytest +import xarray as xr +from odc.geo.geobox import GeoBox, GeoboxTiles + +from ._builder import DaskGraphBuilder, mk_dataset +from .testing.fixtures import FakeMDPlugin, FakeReader +from .types import ( + FixedCoord, + RasterBandMetadata, + RasterGroupMetadata, + RasterLoadParams, + RasterSource, +) + +time = [datetime(2020, 1, 1)] +gbox = GeoBox.from_bbox((-180, -90, 180, 90), shape=(160, 320), tight=True) +gbt = GeoboxTiles(gbox, (80, 80)) +shape = (len(time), *gbox.shape.yx) +dims = ("time", *gbox.dimensions) +tyx_bins = {(0, *idx): [0] for idx in np.ndindex(gbt.shape.yx)} +_rlp = RasterLoadParams + +rlp_fixtures = [ + [ + # Y,X only + {"a": _rlp("uint8")}, + None, + None, + {"a": _sn(dims=dims, shape=shape)}, + ], + [ + # Y,X,B coords only, no dims + {"a": _rlp("uint8", dims=("y", "x", "B"))}, + [FixedCoord("B", ["r", "g", "b"])], + None, + {"a": _sn(dims=(*dims, "B"), shape=(*shape, 3))}, + ], + [ + # Y,X,B dims only + {"a": _rlp("uint8", dims=("y", "x", "W"))}, + None, + {"W": 4}, + {"a": _sn(dims=(*dims, "W"), shape=(*shape, 4))}, + ], + [ + # Y,X,B coords and dims + {"a": _rlp("uint16", dims=("y", "x", "W"))}, + [FixedCoord("W", ["r", "g", "b", "a"])], + {"W": 4}, + {"a": _sn(dims=(*dims, "W"), shape=(*shape, 4))}, + ], +] + + +def check_xx( + xx, + bands: Dict[str, RasterLoadParams], + extra_coords: Sequence[FixedCoord] | None, + extra_dims: Mapping[str, int] | None, + expect: Mapping[str, _sn], +): + assert isinstance(xx, xr.Dataset) + for name, dv in xx.data_vars.items(): + assert isinstance(dv.data, (np.ndarray, da.Array)) + assert name in bands + assert dv.dtype == bands[name].dtype + + assert set(xx.data_vars) == set(bands) + + for n, e in expect.items(): + assert n in xx.data_vars + v = xx[n] + assert v.dims == e.dims + assert v.shape == e.shape + + if extra_coords is not None: + for c in extra_coords: + assert c.name in xx.coords + assert xx.coords[c.name].shape == (len(c.values),) + + if extra_dims is not None: + for n, s in extra_dims.items(): + assert n in xx.dims + assert n in xx.sizes + assert s == xx.sizes[n] + + +@pytest.mark.parametrize("bands,extra_coords,extra_dims,expect", rlp_fixtures) +def test_mk_dataset( + bands: Dict[str, RasterLoadParams], + extra_coords: Sequence[FixedCoord] | None, + extra_dims: Mapping[str, int] | None, + expect: Mapping[str, _sn], +): + assert gbox.crs == "EPSG:4326" + xx = mk_dataset( + gbox, + time, + bands=bands, + extra_coords=extra_coords, + extra_dims=extra_dims, + ) + check_xx(xx, bands, extra_coords, extra_dims, expect) + + +@pytest.mark.parametrize("bands,extra_coords,extra_dims,expect", rlp_fixtures) +def test_dask_builder( + bands: Dict[str, RasterLoadParams], + extra_coords: Sequence[FixedCoord] | None, + extra_dims: Mapping[str, int] | None, + expect: Mapping[str, _sn], +): + _bands = { + k: RasterBandMetadata(b.dtype, b.fill_value, dims=b.dims) + for k, b in bands.items() + } + extra_dims = {**extra_dims} if extra_dims is not None else {} + rgm = RasterGroupMetadata( + {(k, 1): b for k, b in _bands.items()}, + extra_dims=extra_dims, + extra_coords=extra_coords or [], + ) + + rdr = FakeReader(rgm, parser=FakeMDPlugin(rgm, None)) + rdr_env = rdr.capture_env() + + template = RasterGroupMetadata( + {(k, 1): b for k, b in _bands.items()}, + aliases={}, + extra_dims=extra_dims, + extra_coords=extra_coords or (), + ) + src_mapper = { + k: RasterSource("file:///tmp/a.tif", meta=b) for k, b in _bands.items() + } + srcs = [src_mapper, src_mapper, src_mapper] + + builder = DaskGraphBuilder( + bands, + template=template, + srcs=srcs, + tyx_bins=tyx_bins, + gbt=gbt, + env=rdr_env, + rdr=rdr, + time_chunks=1, + ) + + xx = builder.build(gbox, time, bands) + check_xx(xx, bands, extra_coords, extra_dims, expect) + + (yy,) = dask.compute(xx, scheduler="synchronous") + check_xx(yy, bands, extra_coords, extra_dims, expect) diff --git a/odc/loader/testing/fixtures.py b/odc/loader/testing/fixtures.py index be407d9..0f55e3a 100644 --- a/odc/loader/testing/fixtures.py +++ b/odc/loader/testing/fixtures.py @@ -11,13 +11,22 @@ import tempfile from collections import abc from contextlib import contextmanager -from typing import Any, Generator +from typing import Any, ContextManager, Dict, Generator, Optional, Tuple +import numpy as np import rasterio import xarray as xr +from odc.geo.geobox import GeoBox +from odc.geo.roi import NormalizedROI from odc.geo.xr import ODCExtensionDa -from ..types import BandKey, RasterGroupMetadata +from ..types import ( + BandKey, + MDParser, + RasterGroupMetadata, + RasterLoadParams, + RasterSource, +) @contextmanager @@ -100,3 +109,86 @@ def driver_data(self, md, band_key: BandKey) -> Any: if band_key in self._driver_data: return self._driver_data[band_key] return self._driver_data + + +class FakeReader: + """ + Fake reader for testing. + """ + + class Context: + """ + EMIT Context manager. + """ + + def __init__(self, parent: "FakeReader", env: dict[str, Any]) -> None: + self._parent = parent + self.env = env + + def __enter__(self): + assert self._parent._ctx is None + self._parent._ctx = self + + def __exit__(self, type, value, traceback): + # pylint: disable=unused-argument,redefined-builtin + self._parent._ctx = None + + def __init__( + self, + group_md: RasterGroupMetadata, + *, + parser: MDParser | None = None, + ): + self._group_md = group_md + self._parser = parser or FakeMDPlugin(group_md, None) + self._ctx: FakeReader.Context | None = None + + def capture_env(self) -> Dict[str, Any]: + return {} + + def restore_env(self, env: Dict[str, Any]) -> ContextManager[Any]: + return self.Context(self, env) + + def read( + self, + src: RasterSource, + cfg: RasterLoadParams, + dst_geobox: GeoBox, + dst: Optional[np.ndarray] = None, + ) -> Tuple[NormalizedROI, np.ndarray]: + assert self._ctx is not None + assert src.meta is not None + meta = src.meta + extra_dims = self._group_md.extra_dims or { + coord.dim: len(coord.values) for coord in self._group_md.extra_coords + } + postfix_dims: Tuple[int, ...] = () + if meta.dims is not None: + assert set(meta.dims[2:]).issubset(extra_dims) + postfix_dims = tuple(extra_dims[d] for d in meta.dims[2:]) + + ny, nx = dst_geobox.shape.yx + yx_roi = (slice(0, ny), slice(0, nx)) + shape = (ny, nx, *postfix_dims) + + src_pix: np.ndarray | None = src.driver_data + if src_pix is None: + src_pix = np.ones(shape, dtype=cfg.dtype) + else: + assert src_pix.shape == shape + + assert postfix_dims == src_pix.shape[2:] + + if dst is None: + dst = np.zeros((ny, nx, *postfix_dims), dtype=cfg.dtype) + dst[:] = src_pix.astype(dst.dtype) + return yx_roi, dst + + assert dst.shape == src_pix.shape + dst[:] = src_pix.astype(dst.dtype) + + return yx_roi, dst[yx_roi] + + @property + def md_parser(self) -> MDParser | None: + return self._parser diff --git a/odc/loader/types.py b/odc/loader/types.py index 1ac9fc0..0f302cc 100644 --- a/odc/loader/types.py +++ b/odc/loader/types.py @@ -93,14 +93,14 @@ class FixedCoord: name: str values: Sequence[Any] - dtype: Optional[str] = None - dim: Optional[str] = None + dtype: str = "" + dim: str = "" units: str = "1" def __post_init__(self): - if self.dtype is None: - self.dtype = np.array(self.values).dtype.name - if self.dim is None: + if not self.dtype: + self.dtype = np.array(self.values).dtype.str + if not self.dim: self.dim = self.name def _repr_json_(self) -> Dict[str, Any]: @@ -245,6 +245,8 @@ class RasterLoadParams: Captures data loading configuration. """ + # pylint: disable=too-many-instance-attributes + dtype: Optional[str] = None """Output dtype, default same as source.""" @@ -281,6 +283,9 @@ class RasterLoadParams: fail_on_error: bool = True """Quit on the first error or continue.""" + dims: Optional[Tuple[str, ...]] = None + """Dimension names for this band.""" + @staticmethod def same_as(src: Union[RasterBandMetadata, RasterSource]) -> "RasterLoadParams": """Construct from source object.""" @@ -293,7 +298,7 @@ def same_as(src: Union[RasterBandMetadata, RasterSource]) -> "RasterLoadParams": if dtype is None: dtype = "float32" - return RasterLoadParams(dtype=dtype, fill_value=meta.nodata) + return RasterLoadParams(dtype=dtype, fill_value=meta.nodata, dims=meta.dims) @property def nearest(self) -> bool: @@ -315,6 +320,7 @@ def _repr_json_(self) -> Dict[str, Any]: "use_overviews": self.use_overviews, "resampling": self.resampling, "fail_on_error": self.fail_on_error, + "dims": self.dims, } diff --git a/odc/stac/_mdtools.py b/odc/stac/_mdtools.py index cfe454b..1473ff4 100644 --- a/odc/stac/_mdtools.py +++ b/odc/stac/_mdtools.py @@ -871,7 +871,8 @@ def output_geobox( Exposed at top-level for debugging. """ - # pylint: disable=too-many-locals,too-many-branches,too-many-statements,too-many-return-statements + # pylint: disable=too-many-locals,too-many-branches,too-many-statements + # pylint: disable=too-many-return-statements,too-many-arguments # geobox, like --> GeoBox # lon,lat --> geopolygon[epsg:4326] diff --git a/odc/stac/_stac_load.py b/odc/stac/_stac_load.py index 24e0059..105b170 100644 --- a/odc/stac/_stac_load.py +++ b/odc/stac/_stac_load.py @@ -35,7 +35,6 @@ from odc.loader import ( DaskGraphBuilder, direct_chunked_load, - mk_dataset, reader_driver, resolve_chunk_shape, resolve_load_cfg, @@ -463,8 +462,9 @@ def _with_debug_info(ds: xr.Dataset, **kw) -> xr.Dataset: rdr_env = rdr.capture_env() if chunks is not None: - _loader = DaskGraphBuilder( + dask_loader = DaskGraphBuilder( load_cfg, + collection.meta, _parsed, tyx_bins, gbt, @@ -472,11 +472,12 @@ def _with_debug_info(ds: xr.Dataset, **kw) -> xr.Dataset: rdr, time_chunks=chunk_shape[0], ) - return _with_debug_info(mk_dataset(gbox, tss, load_cfg, _loader)) + return _with_debug_info(dask_loader.build(gbox, tss, load_cfg)) return _with_debug_info( direct_chunked_load( load_cfg, + collection.meta, _parsed, tyx_bins, gbt, diff --git a/odc/stac/testing/stac.py b/odc/stac/testing/stac.py index cc7e870..2c84cfe 100644 --- a/odc/stac/testing/stac.py +++ b/odc/stac/testing/stac.py @@ -47,6 +47,7 @@ def b_( dtype="int16", nodata=None, unit="1", + dims=None, uri=None, bidx=1, prefix="http://example.com/items/", @@ -55,7 +56,7 @@ def b_( name, _ = band_key if uri is None: uri = f"{prefix}{name}.tif" - meta = RasterBandMetadata(dtype, nodata, unit) + meta = RasterBandMetadata(dtype, nodata, unit, dims=dims) return (band_key, RasterSource(uri, bidx, geobox=geobox, meta=meta)) diff --git a/tests/test_load.py b/tests/test_load.py index fb15900..2d56aed 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -217,7 +217,7 @@ def test_resolve_load_cfg(): item = mk_parsed_item( [ b_("a", dtype="int8", nodata=-1), - b_("b", dtype="float64"), + b_("b", dtype="float64", dims=("y", "x", "b")), ] ) @@ -229,7 +229,7 @@ def test_resolve_load_cfg(): cfg = resolve_load_cfg(_bands, resampling="average") assert cfg["a"] == rlp("int8", -1, resampling="average") - assert cfg["b"] == rlp("float64", None, resampling="average") + assert cfg["b"] == rlp("float64", None, resampling="average", dims=("y", "x", "b")) cfg = resolve_load_cfg( _bands, @@ -238,11 +238,11 @@ def test_resolve_load_cfg(): dtype="int64", ) assert cfg["a"] == rlp("int64", -999, resampling="mode") - assert cfg["b"] == rlp("int64", -999, resampling="sum") + assert cfg["b"] == rlp("int64", -999, resampling="sum", dims=("y", "x", "b")) cfg = resolve_load_cfg( _bands, dtype={"a": "float32"}, ) assert cfg["a"] == rlp("float32", -1) - assert cfg["b"] == rlp("float64", None) + assert cfg["b"] == rlp("float64", None, dims=_bands["b"].dims)