Skip to content

Commit

Permalink
Supporting postfix dimensions on load
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Feb 16, 2024
1 parent 04fa2c0 commit 5241fa6
Show file tree
Hide file tree
Showing 9 changed files with 411 additions and 46 deletions.
158 changes: 129 additions & 29 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Iterator,
List,
Literal,
Mapping,
Optional,
Protocol,
Sequence,
Expand All @@ -31,7 +32,14 @@
from ._dask import unpack_chunks
from ._reader import nodata_mask, resolve_src_nodata
from ._utils import SizedIterable, pmap
from .types import MultiBandRasterSource, RasterLoadParams, RasterSource, SomeReader
from .types import (
FixedCoord,
MultiBandRasterSource,
RasterGroupMetadata,
RasterLoadParams,
RasterSource,
SomeReader,
)


class MkArray(Protocol):
Expand All @@ -58,11 +66,12 @@ class LoadChunkTask:
cfg: RasterLoadParams
gbt: GeoboxTiles
idx_tyx: Tuple[int, int, int]
postfix_dims: Tuple[int, ...] = ()

@property
def dst_roi(self):
t, y, x = self.idx_tyx
return (t, *self.gbt.roi[y, x])
return (t, *self.gbt.roi[y, x]) + tuple([slice(None)] * len(self.postfix_dims))

@property
def dst_gbox(self) -> GeoBox:
Expand All @@ -80,6 +89,7 @@ class DaskGraphBuilder:
def __init__(
self,
cfg: Dict[str, RasterLoadParams],
template: RasterGroupMetadata,
srcs: Sequence[MultiBandRasterSource],
tyx_bins: Dict[Tuple[int, int, int], List[int]],
gbt: GeoboxTiles,
Expand All @@ -88,6 +98,7 @@ def __init__(
time_chunks: int = 1,
) -> None:
self.cfg = cfg
self.template = template
self.srcs = srcs
self.tyx_bins = tyx_bins
self.gbt = gbt
Expand All @@ -96,6 +107,21 @@ def __init__(
self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, time_chunks)
self.chunk_shape = (time_chunks, *self.gbt.chunk_shape((0, 0)).yx)

def build(
self,
gbox: GeoBox,
time: Sequence[datetime],
bands: Dict[str, RasterLoadParams],
):
return mk_dataset(
gbox,
time,
bands,
self,
extra_coords=self.template.extra_coords,
extra_dims=self.template.extra_dims,
)

def __call__(
self,
shape: Tuple[int, ...],
Expand All @@ -104,51 +130,57 @@ def __call__(
name: Hashable,
) -> Any:
# pylint: disable=too-many-locals
assert len(shape) == 3
assert isinstance(name, str)
cfg = self.cfg[name]
assert dtype == cfg.dtype
# TODO: assumes postfix dims only for now
ydim = 1
post_fix_dims = shape[ydim + 2 :]

chunks = unpack_chunks(self.chunk_shape, shape)
chunk_shape = (*self.chunk_shape, *post_fix_dims)
assert len(chunk_shape) == len(shape)
chunks = unpack_chunks(chunk_shape, shape)
tchunk_range = [
range(last - n, last) for last, n in zip(np.cumsum(chunks[0]), chunks[0])
]

cfg_key = f"cfg-{tokenize(cfg)}"
gbt_key = f"grid-{tokenize(self.gbt)}"
cfg_dask_key = f"cfg-{tokenize(cfg)}"
gbt_dask_key = f"grid-{tokenize(self.gbt)}"

dsk: Dict[Hashable, Any] = {
cfg_key: cfg,
gbt_key: self.gbt,
cfg_dask_key: cfg,
gbt_dask_key: self.gbt,
}
tk = self._tk
band_key = f"{name}-{tk}"
md_key = f"md-{name}-{tk}"
shape_in_blocks = tuple(len(ch) for ch in chunks)

for idx, src in enumerate(self.srcs):
for src_idx, src in enumerate(self.srcs):
band = src.get(name, None)
if band is not None:
dsk[md_key, idx] = band
dsk[md_key, src_idx] = band

for ti, yi, xi in np.ndindex(shape_in_blocks): # type: ignore
for block_idx in np.ndindex(shape_in_blocks):
ti, yi, xi = block_idx[0], block_idx[ydim], block_idx[ydim + 1]
srcs = []
for _ti in tchunk_range[ti]:
srcs.append(
[
(md_key, idx)
for idx in self.tyx_bins.get((_ti, yi, xi), [])
if (md_key, idx) in dsk
(md_key, src_idx)
for src_idx in self.tyx_bins.get((_ti, yi, xi), [])
if (md_key, src_idx) in dsk
]
)

dsk[band_key, ti, yi, xi] = (
dsk[(band_key, *block_idx)] = (
_dask_loader_tyx,
srcs,
gbt_key,
gbt_dask_key,
quote((yi, xi)),
quote(post_fix_dims),
self.rdr,
cfg_key,
cfg_dask_key,
self.env,
)

Expand All @@ -159,16 +191,17 @@ def _dask_loader_tyx(
srcs: Sequence[Sequence[RasterSource]],
gbt: GeoboxTiles,
iyx: Tuple[int, int],
postfix_dims: Tuple[int, ...],
rdr: SomeReader,
cfg: RasterLoadParams,
env: Dict[str, Any],
):
assert cfg.dtype is not None
gbox = cast(GeoBox, gbt[iyx])
chunk = np.empty((len(srcs), *gbox.shape.yx), dtype=cfg.dtype)
chunk = np.empty((len(srcs), *gbox.shape.yx, *postfix_dims), dtype=cfg.dtype)
with rdr.restore_env(env):
for i, plane in enumerate(srcs):
fill_2d_slice(plane, gbox, cfg, rdr, chunk[i, :, :])
for ti, ti_srcs in enumerate(srcs):
fill_2d_slice(ti_srcs, gbox, cfg, rdr, chunk[ti])
return chunk


Expand All @@ -184,7 +217,11 @@ def fill_2d_slice(
# ``nodata`` marks missing pixels, but it might be None (everything is valid)
# ``fill_value`` is the initial value to use, it's equal to ``nodata`` when set,
# otherwise defaults to .nan for floats and 0 for integers
assert dst.shape == dst_gbox.shape.yx

# assume dst[y, x, ...] axis order
assert dst.shape[:2] == dst_gbox.shape.yx
postfix_roi = (slice(None),) * len(dst.shape[2:])

nodata = resolve_src_nodata(cfg.fill_value, cfg)

if nodata is None:
Expand All @@ -197,11 +234,18 @@ def fill_2d_slice(
return dst

src, *rest = srcs
_roi, pix = rdr.read(src, cfg, dst_gbox, dst=dst)
yx_roi, pix = rdr.read(src, cfg, dst_gbox, dst=dst)
assert len(yx_roi) == 2
assert pix.ndim == dst.ndim

for src in rest:
# first valid pixel takes precedence over others
_roi, pix = rdr.read(src, cfg, dst_gbox)
yx_roi, pix = rdr.read(src, cfg, dst_gbox)
assert len(yx_roi) == 2
assert pix.ndim == dst.ndim

_roi: Tuple[slice,] = yx_roi + postfix_roi # type: ignore
assert dst[_roi].shape == pix.shape

# nodata mask takes care of nan when working with floats
# so you can still get proper mask even when nodata is None
Expand All @@ -217,12 +261,29 @@ def mk_dataset(
time: Sequence[datetime],
bands: Dict[str, RasterLoadParams],
alloc: Optional[MkArray] = None,
*,
extra_coords: Sequence[FixedCoord] | None = None,
extra_dims: Mapping[str, int] | None = None,
) -> xr.Dataset:
_shape = (len(time), *gbox.shape.yx)
coords = xr_coords(gbox)
crs_coord_name: Hashable = list(coords)[-1]
coords["time"] = xr.DataArray(time, dims=("time",))
dims = ("time", *gbox.dimensions)
_coords: Mapping[str, xr.DataArray] = {}
_dims: Dict[str, int] = {}

if extra_coords is not None:
_coords = {
coord.name: xr.DataArray(
np.array(coord.values, dtype=coord.dtype),
dims=(coord.name,),
name=coord.name,
)
for coord in extra_coords
}
_dims.update({coord.name: len(coord.values) for coord in extra_coords})

if extra_dims is not None:
_dims.update(extra_dims)

def _alloc(shape: Tuple[int, ...], dtype: str, name: Hashable) -> Any:
if alloc is not None:
Expand All @@ -231,12 +292,38 @@ def _alloc(shape: Tuple[int, ...], dtype: str, name: Hashable) -> Any:

def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray:
assert band.dtype is not None
data = _alloc(_shape, band.dtype, name=name)
band_coords = {**coords}

if band.dims is not None and len(band.dims) > 2:
# TODO: generalize to more dims
ydim = 0
postfix_dims = band.dims[ydim + 2 :]
assert band.dims[ydim : ydim + 2] == ("y", "x")

dims: Tuple[str, ...] = ("time", *gbox.dimensions, *postfix_dims)
shape: Tuple[int, ...] = (
len(time),
*gbox.shape.yx,
*[_dims[dim] for dim in postfix_dims],
)

band_coords.update(
{
_coords[dim].name: _coords[dim]
for dim in postfix_dims
if dim in _coords
}
)
else:
dims = ("time", *gbox.dimensions)
shape = (len(time), *gbox.shape.yx)

data = _alloc(shape, band.dtype, name=name)
attrs = {}
if band.fill_value is not None:
attrs["nodata"] = band.fill_value

xx = xr.DataArray(data=data, coords=coords, dims=dims, attrs=attrs)
xx = xr.DataArray(data=data, coords=band_coords, dims=dims, attrs=attrs)
xx.encoding.update(grid_mapping=crs_coord_name)
return xx

Expand All @@ -245,6 +332,7 @@ def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray:

def direct_chunked_load(
load_cfg: Dict[str, RasterLoadParams],
template: RasterGroupMetadata,
srcs: Sequence[MultiBandRasterSource],
tyx_bins: Dict[Tuple[int, int, int], List[int]],
gbt: GeoboxTiles,
Expand All @@ -265,7 +353,13 @@ def direct_chunked_load(
bands = list(load_cfg)
gbox = gbt.base
assert isinstance(gbox, GeoBox)
ds = mk_dataset(gbox, tss, load_cfg)
ds = mk_dataset(
gbox,
tss,
load_cfg,
extra_coords=template.extra_coords,
extra_dims=template.extra_dims,
)
ny, nx = gbt.shape.yx
total_tasks = nt * nb * ny * nx

Expand All @@ -286,7 +380,13 @@ def _do_one(task: LoadChunkTask) -> Tuple[str, int, int, int]:
if src is not None
]
with rdr.restore_env(env):
_ = fill_2d_slice(_srcs, task.dst_gbox, task.cfg, rdr, dst_slice)
_ = fill_2d_slice(
_srcs,
task.dst_gbox,
task.cfg,
rdr=rdr,
dst=dst_slice,
)
t, y, x = task.idx_tyx
return (task.band, t, y, x)

Expand Down
1 change: 1 addition & 0 deletions odc/loader/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def _resolve(name: str, band: RasterBandMetadata) -> RasterLoadParams:
use_overviews=use_overviews,
resampling=_resampling(name, "nearest"),
fail_on_error=fail_on_error,
dims=band.dims,
)

return {name: _resolve(name, band) for name, band in bands.items()}
Expand Down
Loading

0 comments on commit 5241fa6

Please sign in to comment.