Skip to content

Commit

Permalink
amend me: extra dims support
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Feb 15, 2024
1 parent 04fa2c0 commit 22b48d1
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 16 deletions.
84 changes: 78 additions & 6 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Iterator,
List,
Literal,
Mapping,
Optional,
Protocol,
Sequence,
Expand All @@ -31,7 +32,14 @@
from ._dask import unpack_chunks
from ._reader import nodata_mask, resolve_src_nodata
from ._utils import SizedIterable, pmap
from .types import MultiBandRasterSource, RasterLoadParams, RasterSource, SomeReader
from .types import (
FixedCoord,
MultiBandRasterSource,
RasterGroupMetadata,
RasterLoadParams,
RasterSource,
SomeReader,
)


class MkArray(Protocol):
Expand Down Expand Up @@ -80,6 +88,7 @@ class DaskGraphBuilder:
def __init__(
self,
cfg: Dict[str, RasterLoadParams],
template: RasterGroupMetadata,
srcs: Sequence[MultiBandRasterSource],
tyx_bins: Dict[Tuple[int, int, int], List[int]],
gbt: GeoboxTiles,
Expand All @@ -88,6 +97,7 @@ def __init__(
time_chunks: int = 1,
) -> None:
self.cfg = cfg
self.template = template
self.srcs = srcs
self.tyx_bins = tyx_bins
self.gbt = gbt
Expand All @@ -96,6 +106,21 @@ def __init__(
self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, time_chunks)
self.chunk_shape = (time_chunks, *self.gbt.chunk_shape((0, 0)).yx)

def build(
    self,
    gbox: GeoBox,
    time: Sequence[datetime],
    bands: Dict[str, RasterLoadParams],
):
    """
    Construct the output dataset for this graph builder.

    Delegates to ``mk_dataset``, passing this builder as the array
    allocator and forwarding the extra coordinates and extra dimensions
    recorded on ``self.template``.

    :param gbox: GeoBox of the output.
    :param time: Values of the ``time`` axis.
    :param bands: Load parameters keyed by band name.
    :returns: Whatever ``mk_dataset`` returns for these inputs.
    """
    meta = self.template
    return mk_dataset(
        gbox,
        time,
        bands,
        alloc=self,
        extra_coords=meta.extra_coords,
        extra_dims=meta.extra_dims,
    )

def __call__(
self,
shape: Tuple[int, ...],
Expand Down Expand Up @@ -217,12 +242,29 @@ def mk_dataset(
time: Sequence[datetime],
bands: Dict[str, RasterLoadParams],
alloc: Optional[MkArray] = None,
*,
extra_coords: Sequence[FixedCoord] | None = None,
extra_dims: Mapping[str, int] | None = None,
) -> xr.Dataset:
_shape = (len(time), *gbox.shape.yx)
coords = xr_coords(gbox)
crs_coord_name: Hashable = list(coords)[-1]
coords["time"] = xr.DataArray(time, dims=("time",))
dims = ("time", *gbox.dimensions)
_coords: Mapping[str, xr.DataArray] = {}
_dims: Dict[str, int] = {}

if extra_coords is not None:
_coords = {
coord.name: xr.DataArray(
np.array(coord.values, dtype=coord.dtype),
dims=(coord.name,),
name=coord.name,
)
for coord in extra_coords
}
_dims.update({coord.name: len(coord.values) for coord in extra_coords})

if extra_dims is not None:
_dims.update(extra_dims)

def _alloc(shape: Tuple[int, ...], dtype: str, name: Hashable) -> Any:
if alloc is not None:
Expand All @@ -231,12 +273,35 @@ def _alloc(shape: Tuple[int, ...], dtype: str, name: Hashable) -> Any:

def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray:
assert band.dtype is not None
data = _alloc(_shape, band.dtype, name=name)
band_coords = {**coords}

if band.dims is not None and len(band.dims) > 2:
# TODO: generalize to more dims
assert band.dims[:2] == ("y", "x")
assert len(band.dims) == 3

_, _, extra_dim = band.dims

dims: Tuple[str, ...] = ("time", *gbox.dimensions, extra_dim)
shape: Tuple[int, ...] = (len(time), *gbox.shape.yx, _dims[extra_dim])

band_coords.update(
{
_coords[dim].name: _coords[dim]
for dim in band.dims
if dim not in ("y", "x")
}
)
else:
dims = ("time", *gbox.dimensions)
shape = (len(time), *gbox.shape.yx)

data = _alloc(shape, band.dtype, name=name)
attrs = {}
if band.fill_value is not None:
attrs["nodata"] = band.fill_value

xx = xr.DataArray(data=data, coords=coords, dims=dims, attrs=attrs)
xx = xr.DataArray(data=data, coords=band_coords, dims=dims, attrs=attrs)
xx.encoding.update(grid_mapping=crs_coord_name)
return xx

Expand All @@ -245,6 +310,7 @@ def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray:

def direct_chunked_load(
load_cfg: Dict[str, RasterLoadParams],
template: RasterGroupMetadata,
srcs: Sequence[MultiBandRasterSource],
tyx_bins: Dict[Tuple[int, int, int], List[int]],
gbt: GeoboxTiles,
Expand All @@ -265,7 +331,13 @@ def direct_chunked_load(
bands = list(load_cfg)
gbox = gbt.base
assert isinstance(gbox, GeoBox)
ds = mk_dataset(gbox, tss, load_cfg)
ds = mk_dataset(
gbox,
tss,
load_cfg,
extra_coords=template.extra_coords,
extra_dims=template.extra_dims,
)
ny, nx = gbt.shape.yx
total_tasks = nt * nb * ny * nx

Expand Down
1 change: 1 addition & 0 deletions odc/loader/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def _resolve(name: str, band: RasterBandMetadata) -> RasterLoadParams:
use_overviews=use_overviews,
resampling=_resampling(name, "nearest"),
fail_on_error=fail_on_error,
dims=band.dims,
)

return {name: _resolve(name, band) for name, band in bands.items()}
Expand Down
70 changes: 70 additions & 0 deletions odc/loader/test_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# pylint: disable=missing-function-docstring,missing-module-docstring,too-many-statements
from __future__ import annotations

from datetime import datetime
from types import SimpleNamespace as _sn
from typing import Dict, Mapping, Sequence

import pytest
from odc.geo.geobox import GeoBox

from ._builder import mk_dataset
from .types import FixedCoord, RasterLoadParams

# Shared inputs for the parametrized cases below: a single time step and a
# geobox built from the (-180, -90, 180, 90) bounding box.
time = [datetime(2020, 1, 1)]
gbox = GeoBox.from_bbox((-180, -90, 180, 90), shape=(160, 320), tight=True)
# Shape and dimension names expected for a band without any extra dimension.
shape = (len(time), *gbox.shape.yx)
dims = ("time", *gbox.dimensions)

# Short alias that keeps the parametrize table compact.
_rlp = RasterLoadParams


@pytest.mark.parametrize(
    "bands,extra_coords,extra_dims,expect",
    [
        # Plain band: no extra coordinates, no extra dimensions.
        (
            {"a": _rlp("uint8")},
            None,
            None,
            {"a": _sn(dims=dims, shape=shape)},
        ),
        # Band with a third dimension whose size comes from a coordinate.
        (
            {"a": _rlp("uint8", dims=("y", "x", "B"))},
            [FixedCoord("B", ["r", "g", "b"])],
            None,
            {"a": _sn(dims=(*dims, "B"), shape=(*shape, 3))},
        ),
        # Same, plus an extra_dims entry ("b") that the band does not use.
        (
            {"a": _rlp("uint8", dims=("y", "x", "W"))},
            [FixedCoord("W", ["r", "g", "b", "a"])],
            {"b": 4},
            {"a": _sn(dims=(*dims, "W"), shape=(*shape, 4))},
        ),
    ],
)
def test_mk_dataset(
    bands: Dict[str, RasterLoadParams],
    extra_coords: Sequence[FixedCoord] | None,
    extra_dims: Mapping[str, int] | None,
    expect: Mapping[str, _sn],
):
    """Check dims/shape of every expected band and presence of extra coords."""
    assert gbox.crs == "EPSG:4326"

    ds = mk_dataset(
        gbox,
        time,
        bands=bands,
        extra_coords=extra_coords,
        extra_dims=extra_dims,
    )
    assert ds is not None

    for band_name, expected in expect.items():
        assert band_name in ds.data_vars
        band = ds[band_name]
        assert band.dims == expected.dims
        assert band.shape == expected.shape

    # ``None`` means there are no extra coordinates to verify.
    for coord in extra_coords or ():
        assert coord.name in ds.coords
        assert ds.coords[coord.name].shape == (len(coord.values),)
10 changes: 8 additions & 2 deletions odc/loader/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class FixedCoord:

def __post_init__(self):
if self.dtype is None:
self.dtype = np.array(self.values).dtype.name
self.dtype = np.array(self.values).dtype.str
if self.dim is None:
self.dim = self.name

Expand Down Expand Up @@ -245,6 +245,8 @@ class RasterLoadParams:
Captures data loading configuration.
"""

# pylint: disable=too-many-instance-attributes

dtype: Optional[str] = None
"""Output dtype, default same as source."""

Expand Down Expand Up @@ -281,6 +283,9 @@ class RasterLoadParams:
fail_on_error: bool = True
"""Quit on the first error or continue."""

dims: Optional[Tuple[str, ...]] = None
"""Dimension names for this band."""

@staticmethod
def same_as(src: Union[RasterBandMetadata, RasterSource]) -> "RasterLoadParams":
"""Construct from source object."""
Expand All @@ -293,7 +298,7 @@ def same_as(src: Union[RasterBandMetadata, RasterSource]) -> "RasterLoadParams":
if dtype is None:
dtype = "float32"

return RasterLoadParams(dtype=dtype, fill_value=meta.nodata)
return RasterLoadParams(dtype=dtype, fill_value=meta.nodata, dims=meta.dims)

@property
def nearest(self) -> bool:
Expand All @@ -315,6 +320,7 @@ def _repr_json_(self) -> Dict[str, Any]:
"use_overviews": self.use_overviews,
"resampling": self.resampling,
"fail_on_error": self.fail_on_error,
"dims": self.dims,
}


Expand Down
7 changes: 4 additions & 3 deletions odc/stac/_stac_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from odc.loader import (
DaskGraphBuilder,
direct_chunked_load,
mk_dataset,
reader_driver,
resolve_chunk_shape,
resolve_load_cfg,
Expand Down Expand Up @@ -463,20 +462,22 @@ def _with_debug_info(ds: xr.Dataset, **kw) -> xr.Dataset:

rdr_env = rdr.capture_env()
if chunks is not None:
_loader = DaskGraphBuilder(
dask_loader = DaskGraphBuilder(
load_cfg,
collection.meta,
_parsed,
tyx_bins,
gbt,
rdr_env,
rdr,
time_chunks=chunk_shape[0],
)
return _with_debug_info(mk_dataset(gbox, tss, load_cfg, _loader))
return _with_debug_info(dask_loader.build(gbox, tss, load_cfg))

return _with_debug_info(
direct_chunked_load(
load_cfg,
collection.meta,
_parsed,
tyx_bins,
gbt,
Expand Down
3 changes: 2 additions & 1 deletion odc/stac/testing/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def b_(
dtype="int16",
nodata=None,
unit="1",
dims=None,
uri=None,
bidx=1,
prefix="http://example.com/items/",
Expand All @@ -55,7 +56,7 @@ def b_(
name, _ = band_key
if uri is None:
uri = f"{prefix}{name}.tif"
meta = RasterBandMetadata(dtype, nodata, unit)
meta = RasterBandMetadata(dtype, nodata, unit, dims=dims)
return (band_key, RasterSource(uri, bidx, geobox=geobox, meta=meta))


Expand Down
8 changes: 4 additions & 4 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def test_resolve_load_cfg():
item = mk_parsed_item(
[
b_("a", dtype="int8", nodata=-1),
b_("b", dtype="float64"),
b_("b", dtype="float64", dims=("y", "x", "b")),
]
)

Expand All @@ -229,7 +229,7 @@ def test_resolve_load_cfg():

cfg = resolve_load_cfg(_bands, resampling="average")
assert cfg["a"] == rlp("int8", -1, resampling="average")
assert cfg["b"] == rlp("float64", None, resampling="average")
assert cfg["b"] == rlp("float64", None, resampling="average", dims=("y", "x", "b"))

cfg = resolve_load_cfg(
_bands,
Expand All @@ -238,11 +238,11 @@ def test_resolve_load_cfg():
dtype="int64",
)
assert cfg["a"] == rlp("int64", -999, resampling="mode")
assert cfg["b"] == rlp("int64", -999, resampling="sum")
assert cfg["b"] == rlp("int64", -999, resampling="sum", dims=("y", "x", "b"))

cfg = resolve_load_cfg(
_bands,
dtype={"a": "float32"},
)
assert cfg["a"] == rlp("float32", -1)
assert cfg["b"] == rlp("float64", None)
assert cfg["b"] == rlp("float64", None, dims=_bands["b"].dims)

0 comments on commit 22b48d1

Please sign in to comment.