Skip to content

Commit

Permalink
Add "dask graph" loading mode to the loader builder (reader-provided dask futures alongside the existing in-memory path)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Jun 4, 2024
1 parent dfc6c9c commit 203a320
Showing 1 changed file with 113 additions and 21 deletions.
134 changes: 113 additions & 21 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,15 @@
from odc.geo.geobox import GeoBox, GeoBoxBase, GeoboxTiles
from odc.geo.xr import xr_coords

from ._reader import nodata_mask, resolve_dst_fill_value, resolve_src_nodata
from ._reader import (
ReaderDaskAdaptor,
nodata_mask,
resolve_dst_fill_value,
resolve_src_nodata,
)
from ._utils import SizedIterable, pmap
from .types import (
DaskRasterReader,
MultiBandRasterSource,
RasterGroupMetadata,
RasterLoadParams,
Expand Down Expand Up @@ -154,13 +160,16 @@ def __init__(
env: Dict[str, Any],
rdr: ReaderDriver,
chunks: Mapping[str, int],
mode: Literal["auto"] | Literal["mem"] | Literal["concurrency"] = "auto",
) -> None:
gbox = gbt.base
assert isinstance(gbox, GeoBox)
# make sure chunks for tyx match our structure
chunk_tyx = (chunks.get("time", 1), *gbt.chunk_shape((0, 0)).yx)
chunks = {**chunks}
chunks.update(dict(zip(["time", "y", "x"], chunk_tyx)))
if mode == "auto":
mode = "mem" # current default

self.cfg = cfg
self.template = template
Expand All @@ -169,8 +178,9 @@ def __init__(
self.gbt = gbt
self.env = env
self.rdr = rdr
self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, chunks)
self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, chunks, mode)
self._chunks = chunks
self._mode = mode
self._load_state = rdr.new_load(gbox, chunks=chunks)

def _band_chunks(
Expand All @@ -191,7 +201,7 @@ def build(
gbox: GeoBox,
time: Sequence[datetime],
bands: Mapping[str, RasterLoadParams],
):
) -> xr.Dataset:
return mk_dataset(
gbox,
time,
Expand All @@ -200,13 +210,17 @@ def build(
template=self.template,
)

def _prep_sources(
self, name: str, dsk: dict[Key, Any], deps: list[Any]
) -> tuple[str, Any]:
def _norm_load_state(self, deps: list[Any]) -> Any:
    """
    Return the load-state in a form safe to embed in a dask task tuple.

    A plain load-state object is passed through unchanged.  If it is a
    dask collection it cannot be embedded directly, so it is recorded in
    ``deps`` (making it a graph dependency) and referenced by its key.
    """
    state = self._load_state
    if not is_dask_collection(state):
        return state
    # Track the collection as a dependency and substitute its graph key.
    deps.append(state)
    return state.key

def _prep_sources(
self, name: str, dsk: dict[Key, Any], deps: list[Any]
) -> tuple[str, Any]:
load_state = self._norm_load_state(deps)

tk = self._tk
src_key = f"open-{name}-{tk}"
Expand All @@ -223,6 +237,35 @@ def _prep_sources(
)
return src_key, load_state

def _dask_rdr(self) -> DaskRasterReader:
if (dask_reader := self.rdr.dask_reader) is not None:
return dask_reader
return ReaderDaskAdaptor(self.rdr, self.env)

def _task_futures(
self,
task: LoadChunkTask,
dask_reader: DaskRasterReader,
deps: list[Any],
) -> list[list[Any]]:
srcs = task.resolve_sources(self.srcs)
out: list[list[Any]] = []
ctx = self._load_state
cfg = task.cfg
dst_gbox = task.dst_gbox

for layer in srcs:
read_futures: list[Any] = []
for src in layer:
rdr = dask_reader.open(src, ctx)
fut = rdr.read(cfg, dst_gbox, selection=task.selection)
read_futures.append(fut)
deps.append(fut)

out.append(read_futures)

return out

def __call__(
self,
shape: Tuple[int, ...],
Expand All @@ -238,6 +281,8 @@ def __call__(
assert dtype == cfg.dtype
assert ydim == cfg.ydim + 1 # +1 for time dimension

chunks = self._band_chunks(name, shape, ydim)

tk = self._tk
deps: list[Any] = []
cfg_dask_key = f"cfg-{tokenize(cfg)}"
Expand All @@ -247,25 +292,48 @@ def __call__(
cfg_dask_key: cfg,
gbt_dask_key: self.gbt,
}
src_key, load_state = self._prep_sources(name, dsk, deps)
dask_reader: DaskRasterReader | None = None
if self._mode == "mem":
src_key, load_state = self._prep_sources(name, dsk, deps)
else:
dask_reader = self._dask_rdr()
src_key, load_state = "", self._load_state

fill_value = resolve_dst_fill_value(
np.dtype(dtype),
cfg,
resolve_src_nodata(cfg.fill_value, cfg),
)

band_key = f"{name}-{tk}"
chunks = self._band_chunks(name, shape, ydim)

for task in self.load_tasks(name, shape[0]):
dsk[(band_key, *task.idx)] = (
_dask_loader_tyx,
task.resolve_sources_dask(src_key, dsk),
gbt_dask_key,
quote(task.idx_tyx[1:]),
quote(task.prefix_dims),
quote(task.postfix_dims),
cfg_dask_key,
self.rdr,
self.env,
load_state,
task.selection,
)
task_key: Key = (band_key, *task.idx)
if dask_reader is None:
dsk[task_key] = (
_dask_loader_tyx,
task.resolve_sources_dask(src_key, dsk),
gbt_dask_key,
quote(task.idx_tyx[1:]),
quote(task.prefix_dims),
quote(task.postfix_dims),
cfg_dask_key,
self.rdr,
self.env,
load_state,
task.selection,
)
else:
srcs_futures = self._task_futures(task, dask_reader, deps)

dsk[task_key] = (
_dask_fuser,
srcs_futures,
task.shape,
dtype,
fill_value,
ydim - 1,
)

dsk = HighLevelGraph.from_collections(band_key, dsk, dependencies=deps)

Expand Down Expand Up @@ -320,6 +388,30 @@ def _dask_loader_tyx(
return chunk


def _dask_fuser(
    srcs: list[list[Any]],
    shape: tuple[int, ...],
    dtype: DTypeLike,
    fill_value: float | int,
    src_ydim: int = 0,
):
    """
    Fuse already-read pixel chunks into a single output array.

    ``srcs`` holds one list of read results per time slice
    (presumably ``((y-slice, x-slice), pixels)`` pairs as consumed by
    :func:`fuse_nd_slices` -- confirm against that helper).  The output
    is pre-filled with ``fill_value`` and each time slice is overlaid
    in place.

    :param shape: full output shape ``(time, ..., y, x, ...)``.
    :param src_ydim: index of the y axis within a single time slice.
    :return: array of ``shape``/``dtype`` with all chunks fused in.
    """
    n_time = len(srcs)
    assert n_time == shape[0]
    assert len(shape) >= 3  # at least (time, y, x)

    dst = np.full(shape, fill_value, dtype=dtype)
    for t_idx in range(n_time):
        fuse_nd_slices(
            srcs[t_idx],
            fill_value,
            dst[t_idx],
            ydim=src_ydim,
            prefilled=True,
        )
    return dst


def fuse_nd_slices(
srcs: Iterable[tuple[tuple[slice, slice], np.ndarray]],
fill_value: float | int,
Expand Down

0 comments on commit 203a320

Please sign in to comment.