Skip to content

Commit

Permalink
amend me: dask graph mode impl
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Jun 4, 2024
1 parent 303155d commit 7c4db04
Showing 1 changed file with 36 additions and 20 deletions.
56 changes: 36 additions & 20 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,16 @@ def __init__(
env: Dict[str, Any],
rdr: ReaderDriver,
chunks: Mapping[str, int],
mode: Literal["auto"] | Literal["mem"] | Literal["concurrency"] = "auto",
) -> None:
gbox = gbt.base
assert isinstance(gbox, GeoBox)
# make sure chunks for tyx match our structure
chunk_tyx = (chunks.get("time", 1), *gbt.chunk_shape((0, 0)).yx)
chunks = {**chunks}
chunks.update(dict(zip(["time", "y", "x"], chunk_tyx)))
if mode == "auto":
mode = "mem" # current default

self.cfg = cfg
self.template = template
Expand All @@ -168,8 +171,9 @@ def __init__(
self.gbt = gbt
self.env = env
self.rdr = rdr
self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, chunks)
self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, chunks, mode)
self._chunks = chunks
self._mode = mode
self._load_state = rdr.new_load(gbox, chunks=chunks)

def _band_chunks(
Expand All @@ -190,7 +194,7 @@ def build(
gbox: GeoBox,
time: Sequence[datetime],
bands: Mapping[str, RasterLoadParams],
):
) -> xr.Dataset:
return mk_dataset(
gbox,
time,
Expand All @@ -199,13 +203,17 @@ def build(
template=self.template,
)

def _prep_sources(
self, name: str, dsk: dict[Key, Any], deps: list[Any]
) -> tuple[str, Any]:
def _norm_load_state(self, deps: list[Any]) -> Any:
    """
    Return a graph-ready reference to the reader's load state.

    When ``self._load_state`` is a dask collection it is appended to
    ``deps`` (mutated in place) and its ``.key`` is returned instead of
    the collection itself. Any other value is returned unchanged and
    ``deps`` is left untouched.

    :param deps: list of dependencies being accumulated by the caller.
    :return: ``self._load_state`` itself, or its ``.key`` when it is a
        dask collection.
    """
    state = self._load_state
    if not is_dask_collection(state):
        # Not a dask collection: hand the value back as-is.
        return state

    # Record the collection as a dependency and refer to it by key.
    deps.append(state)
    return state.key

def _prep_sources(
self, name: str, dsk: dict[Key, Any], deps: list[Any]
) -> tuple[str, Any]:
load_state = self._norm_load_state(deps)

tk = self._tk
src_key = f"open-{name}-{tk}"
Expand Down Expand Up @@ -237,6 +245,8 @@ def __call__(
assert dtype == cfg.dtype
assert ydim == cfg.ydim + 1 # +1 for time dimension

chunks = self._band_chunks(name, shape, ydim)

tk = self._tk
deps: list[Any] = []
cfg_dask_key = f"cfg-{tokenize(cfg)}"
Expand All @@ -246,25 +256,31 @@ def __call__(
cfg_dask_key: cfg,
gbt_dask_key: self.gbt,
}
src_key, load_state = self._prep_sources(name, dsk, deps)
if self._mode == "mem":
src_key, load_state = self._prep_sources(name, dsk, deps)
else:
src_key, load_state = "", self._norm_load_state(deps)

band_key = f"{name}-{tk}"
chunks = self._band_chunks(name, shape, ydim)

for task in self.load_tasks(name, shape[0]):
dsk[(band_key, *task.idx)] = (
_dask_loader_tyx,
task.resolve_sources_dask(src_key, dsk),
gbt_dask_key,
quote(task.idx_tyx[1:]),
quote(task.prefix_dims),
quote(task.postfix_dims),
cfg_dask_key,
self.rdr,
self.env,
load_state,
task.selection,
)
task_key: Key = (band_key, *task.idx)
if self._mode == "mem":
dsk[task_key] = (
_dask_loader_tyx,
task.resolve_sources_dask(src_key, dsk),
gbt_dask_key,
quote(task.idx_tyx[1:]),
quote(task.prefix_dims),
quote(task.postfix_dims),
cfg_dask_key,
self.rdr,
self.env,
load_state,
task.selection,
)
else:
raise NotImplementedError("concurrency mode not implemented")

dsk = HighLevelGraph.from_collections(band_key, dsk, dependencies=deps)

Expand Down

0 comments on commit 7c4db04

Please sign in to comment.