Refactor to enable prefix dims
- generalize LoadChunkTask
  - single chunk can have multiple temporal slices
  - prefix/postfix dims
- extract load task generation logic from internals
Kirill888 committed May 21, 2024
1 parent ec41200 commit e284eae
Showing 5 changed files with 198 additions and 80 deletions.
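
In array terms, the refactor below generalizes the layout of a loaded chunk from (time, y, x, *postfix) to (time, *prefix, y, x, *postfix), with the position of the y/x axes taken from each band's ydim. A minimal numpy sketch of that layout (the dimension names and sizes are illustrative only, not taken from the commit):

```python
import numpy as np

# Hypothetical band with one prefix dim ("band") and one postfix dim ("wavelength").
nt, ny, nx = 3, 256, 256   # temporal slices in the chunk, tile height, tile width
prefix_dims = (4,)         # axes that come before y/x, e.g. a "band" axis
postfix_dims = (7,)        # axes that come after y/x, e.g. a "wavelength" axis

chunk = np.empty((nt, *prefix_dims, ny, nx, *postfix_dims), dtype="uint16")
print(chunk.shape)         # (3, 4, 256, 256, 7)
```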
odc/loader/_builder.py: 217 changes (150 additions, 67 deletions)
@@ -36,7 +36,6 @@
from ._reader import nodata_mask, resolve_dst_fill_value, resolve_src_nodata
from ._utils import SizedIterable, pmap
from .types import (
-    FixedCoord,
    MultiBandRasterSource,
    RasterGroupMetadata,
    RasterLoadParams,
@@ -66,22 +65,47 @@ class LoadChunkTask:
    """

    band: str
-    srcs: List[Tuple[int, str]]
+    srcs: List[List[Tuple[int, str]]]
    cfg: RasterLoadParams
    gbt: GeoboxTiles
    idx_tyx: Tuple[int, int, int]
+    prefix_dims: Tuple[int, ...] = ()
    postfix_dims: Tuple[int, ...] = ()

    @property
-    def dst_roi(self):
+    def dst_roi(self) -> Tuple[slice, ...]:
        t, y, x = self.idx_tyx
-        return (t, *self.gbt.roi[y, x]) + tuple([slice(None)] * len(self.postfix_dims))
+        iy, ix = self.gbt.roi[y, x]
+        return (
+            slice(t, t + len(self.srcs)),
+            *[slice(None) for _ in self.prefix_dims],
+            iy,
+            ix,
+            *[slice(None) for _ in self.postfix_dims],
+        )

    @property
    def dst_gbox(self) -> GeoBox:
        _, y, x = self.idx_tyx
        return cast(GeoBox, self.gbt[y, x])

+    def __bool__(self) -> bool:
+        return len(self.srcs) > 0 and any(len(src) > 0 for src in self.srcs)
+
+    def resolve_sources(
+        self, srcs: Sequence[MultiBandRasterSource]
+    ) -> List[List[RasterSource]]:
+        out: List[List[RasterSource]] = []
+
+        for layer in self.srcs:
+            _srcs: List[RasterSource] = []
+            for idx, b in layer:
+                src = srcs[idx].get(b, None)
+                if src is not None:
+                    _srcs.append(src)
+            out.append(_srcs)
+        return out


class DaskGraphBuilder:
    """
@@ -109,8 +133,8 @@ def __init__(
        self.env = env
        self.rdr = rdr
        self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, time_chunks)
-        self.chunk_shape = (time_chunks, *self.gbt.chunk_shape((0, 0)).yx)
-        self._load_state = rdr.new_load(dict(zip(["time", "y", "x"], self.chunk_shape)))
+        self.chunk_tyx = (time_chunks, *self.gbt.chunk_shape((0, 0)).yx)
+        self._load_state = rdr.new_load(dict(zip(["time", "y", "x"], self.chunk_tyx)))

    def build(
        self,
@@ -123,8 +147,7 @@ def build(
            time,
            bands,
            self,
-            extra_coords=self.template.extra_coords,
-            extra_dims=self.template.extra_dims,
+            template=self.template,
        )

    def __call__(
@@ -138,11 +161,16 @@ def __call__(
        assert isinstance(name, str)
        cfg = self.cfg[name]
        assert dtype == cfg.dtype
-        # TODO: assumes postfix dims only for now
-        ydim = 1
-        post_fix_dims = shape[ydim + 2 :]
+        ydim = cfg.ydim + 1
+        postfix_dims = shape[ydim + 2 :]
+        prefix_dims = shape[1:ydim]

-        chunk_shape = (*self.chunk_shape, *post_fix_dims)
+        chunk_shape: Tuple[int, ...] = (
+            self.chunk_tyx[0],
+            *prefix_dims,
+            *self.chunk_tyx[1:],
+            *postfix_dims,
+        )
        assert len(chunk_shape) == len(shape)
        chunks = unpack_chunks(chunk_shape, shape)
        tchunk_range = [
Expand Down Expand Up @@ -195,7 +223,8 @@ def __call__(
srcs,
gbt_dask_key,
quote((yi, xi)),
quote(post_fix_dims),
quote(prefix_dims),
quote(postfix_dims),
cfg_dask_key,
self.rdr,
self.env,
@@ -221,6 +250,7 @@ def _dask_loader_tyx(
    srcs: Sequence[Sequence[RasterReader]],
    gbt: GeoboxTiles,
    iyx: Tuple[int, int],
+    prefix_dims: Tuple[int, ...],
    postfix_dims: Tuple[int, ...],
    cfg: RasterLoadParams,
    rdr: ReaderDriver,
@@ -229,10 +259,14 @@
):
    assert cfg.dtype is not None
    gbox = cast(GeoBox, gbt[iyx])
-    chunk = np.empty((len(srcs), *gbox.shape.yx, *postfix_dims), dtype=cfg.dtype)
+    chunk = np.empty(
+        (len(srcs), *prefix_dims, *gbox.shape.yx, *postfix_dims),
+        dtype=cfg.dtype,
+    )
+    ydim = len(prefix_dims)
    with rdr.restore_env(env, load_state):
        for ti, ti_srcs in enumerate(srcs):
-            _fill_nd_slice(ti_srcs, gbox, cfg, chunk[ti])
+            _fill_nd_slice(ti_srcs, gbox, cfg, chunk[ti], ydim=ydim)
    return chunk


@@ -241,6 +275,7 @@ def _fill_nd_slice(
    dst_gbox: GeoBox,
    cfg: RasterLoadParams,
    dst: Any,
+    ydim: int = 0,
) -> Any:
    # TODO: support masks not just nodata based fusing
    #
@@ -249,8 +284,9 @@
    # otherwise defaults to .nan for floats and 0 for integers

    # assume dst[y, x, ...] axis order
-    assert dst.shape[:2] == dst_gbox.shape.yx
-    postfix_roi = (slice(None),) * len(dst.shape[2:])
+    assert dst.shape[ydim : ydim + 2] == dst_gbox.shape.yx
+    postfix_roi = (slice(None),) * len(dst.shape[ydim + 2 :])
+    prefix_roi = (slice(None),) * ydim

    nodata = resolve_src_nodata(cfg.fill_value, cfg)
    fill_value = resolve_dst_fill_value(dst.dtype, cfg, nodata)
@@ -270,7 +306,7 @@ def _fill_nd_slice(
        assert len(yx_roi) == 2
        assert pix.ndim == dst.ndim

-        _roi: Tuple[slice,] = yx_roi + postfix_roi  # type: ignore
+        _roi: Tuple[slice,] = prefix_roi + yx_roi + postfix_roi  # type: ignore
        assert dst[_roi].shape == pix.shape

        # nodata mask takes care of nan when working with floats
@@ -288,28 +324,21 @@ def mk_dataset(
    bands: Dict[str, RasterLoadParams],
    alloc: Optional[MkArray] = None,
    *,
-    extra_coords: Sequence[FixedCoord] | None = None,
-    extra_dims: Mapping[str, int] | None = None,
+    template: RasterGroupMetadata,
) -> xr.Dataset:
    coords = xr_coords(gbox)
    crs_coord_name: Hashable = list(coords)[-1]
    coords["time"] = xr.DataArray(time, dims=("time",))
-    _coords: Mapping[str, xr.DataArray] = {}
-    _dims: Dict[str, int] = {}
-
-    if extra_coords is not None:
-        _coords = {
-            coord.name: xr.DataArray(
-                np.array(coord.values, dtype=coord.dtype),
-                dims=(coord.name,),
-                name=coord.name,
-            )
-            for coord in extra_coords
-        }
-        _dims.update({coord.name: len(coord.values) for coord in extra_coords})
+    _dims = template.extra_dims_full()

-    if extra_dims is not None:
-        _dims.update(extra_dims)
+    _coords = {
+        coord.name: xr.DataArray(
+            np.array(coord.values, dtype=coord.dtype),
+            dims=(coord.dim,),
+            name=coord.name,
+        )
+        for coord in template.extra_coords
+    }

    def _alloc(shape: Tuple[int, ...], dtype: str, name: Hashable) -> Any:
        if alloc is not None:
@@ -320,23 +349,29 @@ def _maker(name: Hashable, band: RasterLoadParams) -> xr.DataArray:
        assert band.dtype is not None
        band_coords = {**coords}

-        if band.dims is not None and len(band.dims) > 2:
-            # TODO: generalize to more dims
-            ydim = 0
-            postfix_dims = band.dims[ydim + 2 :]
+        if len(band.dims) > 2:
+            ydim = band.ydim
+            assert band.dims[ydim : ydim + 2] == ("y", "x")
+            prefix_dims = band.dims[:ydim]
+            postfix_dims = band.dims[ydim + 2 :]

-            dims: Tuple[str, ...] = ("time", *gbox.dimensions, *postfix_dims)
+            dims: Tuple[str, ...] = (
+                "time",
+                *prefix_dims,
+                *gbox.dimensions,
+                *postfix_dims,
+            )
            shape: Tuple[int, ...] = (
                len(time),
+                *[_dims[dim] for dim in prefix_dims],
                *gbox.shape.yx,
                *[_dims[dim] for dim in postfix_dims],
            )

            band_coords.update(
                {
                    _coords[dim].name: _coords[dim]
-                    for dim in postfix_dims
+                    for dim in (prefix_dims + postfix_dims)
                    if dim in _coords
                }
            )
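
For a concrete picture of the dims/shape/coords bookkeeping in _maker, here is a small self-contained xarray sketch; the dimension names, sizes and coordinate values are invented for illustration:

```python
import numpy as np
import xarray as xr

nt, nb, ny, nx, nw = 2, 4, 64, 64, 7
# "band" plays the role of a prefix dim, "wavelength" of a postfix dim
dims = ("time", "band", "y", "x", "wavelength")

da = xr.DataArray(
    np.zeros((nt, nb, ny, nx, nw), dtype="uint16"),
    dims=dims,
    coords={"wavelength": ("wavelength", np.linspace(400.0, 700.0, nw))},
)
print(dict(da.sizes))  # {'time': 2, 'band': 4, 'y': 64, 'x': 64, 'wavelength': 7}
```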
@@ -432,6 +467,57 @@ def dask_chunked_load(
    return dask_loader.build(gbox, tss, load_cfg)


+def load_tasks(
+    load_cfg: Dict[str, RasterLoadParams],
+    tyx_bins: Dict[Tuple[int, int, int], List[int]],
+    gbt: GeoboxTiles,
+    *,
+    nt: Optional[int] = None,
+    time_chunks: int = 1,
+    extra_dims: Mapping[str, int] | None = None,
+) -> Iterator[LoadChunkTask]:
+    """
+    Convert tyx_bins into a complete set of load tasks.
+    This is a generator that yields :py:class:`~odc.loader.LoadChunkTask`
+    instances for every possible time, y, x, bins, including empty ones.
+    """
+    # pylint: disable=too-many-locals
+    if nt is None:
+        nt = max(t for t, _, _ in tyx_bins) + 1
+
+    if extra_dims is None:
+        extra_dims = {}
+
+    shape_in_chunks: Tuple[int, int, int] = (
+        (nt + time_chunks - 1) // time_chunks,
+        *gbt.shape.yx,
+    )
+
+    for band_name, cfg in load_cfg.items():
+        ydim = cfg.ydim
+        prefix_dims = tuple(extra_dims[dim] for dim in cfg.dims[:ydim])
+        postfix_dims = tuple(extra_dims[dim] for dim in cfg.dims[ydim + 2 :])
+
+        for idx in np.ndindex(shape_in_chunks):
+            tBi, yi, xi = idx  # type: ignore
+            srcs: List[List[Tuple[int, str]]] = []
+            t0 = tBi * time_chunks
+            for ti in range(t0, min(t0 + time_chunks, nt)):
+                tyx_idx = (ti, yi, xi)
+                srcs.append([(idx, band_name) for idx in tyx_bins.get(tyx_idx, [])])
+
+            yield LoadChunkTask(
+                band_name,
+                srcs,
+                cfg,
+                gbt,
+                (tBi, yi, xi),
+                prefix_dims=prefix_dims,
+                postfix_dims=postfix_dims,
+            )
+
+
def direct_chunked_load(
    load_cfg: Dict[str, RasterLoadParams],
    template: RasterGroupMetadata,
@@ -451,48 +537,45 @@ def direct_chunked_load(
    # pylint: disable=too-many-locals
    nt = len(tss)
    nb = len(load_cfg)
-    bands = list(load_cfg)
    gbox = gbt.base
    assert isinstance(gbox, GeoBox)
    ds = mk_dataset(
        gbox,
        tss,
        load_cfg,
-        extra_coords=template.extra_coords,
-        extra_dims=template.extra_dims,
+        template=template,
    )
    ny, nx = gbt.shape.yx
    total_tasks = nt * nb * ny * nx
    load_state = rdr.new_load()

-    def _task_stream(bands: List[str]) -> Iterator[LoadChunkTask]:
-        _shape: Tuple[int, int, int] = (nt, *gbt.shape.yx)
-        for band_name in bands:
-            cfg = load_cfg[band_name]
-            for ti, yi, xi in np.ndindex(_shape):  # type: ignore
-                tyx_idx = (ti, yi, xi)
-                _srcs = [(idx, band_name) for idx in tyx_bins.get(tyx_idx, [])]
-                yield LoadChunkTask(band_name, _srcs, cfg, gbt, tyx_idx)
-
    def _do_one(task: LoadChunkTask) -> Tuple[str, int, int, int]:
        dst_slice = ds[task.band].data[task.dst_roi]
-        _srcs = [
-            src
-            for src in (srcs[idx].get(band, None) for idx, band in task.srcs)
-            if src is not None
-        ]
+        layers = task.resolve_sources(srcs)
+        ydim = len(task.prefix_dims)

        with rdr.restore_env(env, load_state) as ctx:
-            loaders = [rdr.open(src, ctx) for src in _srcs]
-            _ = _fill_nd_slice(
-                loaders,
-                task.dst_gbox,
-                task.cfg,
-                dst=dst_slice,
-            )
+            for t_idx, layer in enumerate(layers):
+                loaders = [rdr.open(src, ctx) for src in layer]
+                _ = _fill_nd_slice(
+                    loaders,
+                    task.dst_gbox,
+                    task.cfg,
+                    dst=dst_slice[t_idx],
+                    ydim=ydim,
+                )
        t, y, x = task.idx_tyx
        return (task.band, t, y, x)

-    _work = pmap(_do_one, _task_stream(bands), pool)
+    tasks = load_tasks(
+        load_cfg,
+        tyx_bins,
+        gbt,
+        nt=nt,
+        extra_dims=template.extra_dims_full(),
+    )
+
+    _work = pmap(_do_one, tasks, pool)

    if progress is not None:
        _work = progress(SizedIterable(_work, total_tasks))
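
As a rough usage sketch, the newly extracted load_tasks generator could be driven directly along these lines. The import paths and the RasterLoadParams construction are assumptions based on the signatures visible in this diff, not something the commit documents:

```python
from odc.geo.geobox import GeoBox, GeoboxTiles
from odc.loader.types import RasterLoadParams
from odc.loader._builder import load_tasks  # module path assumed from this diff

# 200x200 pixel grid split into 2x2 tiles of up to 128x128 pixels
gbox = GeoBox.from_bbox((0, 0, 2, 2), "EPSG:4326", resolution=0.01)
gbt = GeoboxTiles(gbox, (128, 128))

load_cfg = {"red": RasterLoadParams(dtype="uint16", dims=("y", "x"))}

# (t, iy, ix) -> indices of the sources that land in that bin;
# three hypothetical scenes cover tile (0, 0) across two time steps
tyx_bins = {(0, 0, 0): [0, 1], (1, 0, 0): [2]}

for task in load_tasks(load_cfg, tyx_bins, gbt, nt=2, time_chunks=2):
    if not task:  # empty chunks are yielded too; __bool__ filters them out
        continue
    # srcs holds one list per temporal slice fused into this chunk
    print(task.band, task.idx_tyx, [len(layer) for layer in task.srcs])
```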