From 7c4db04d76e8eb32c1b9d254740d3cab180ec392 Mon Sep 17 00:00:00 2001 From: Kirill Kouzoubov Date: Mon, 3 Jun 2024 19:17:21 +1000 Subject: [PATCH] amend me: dask graph mode impl --- odc/loader/_builder.py | 56 +++++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/odc/loader/_builder.py b/odc/loader/_builder.py index ed9b4c9..c318ad7 100644 --- a/odc/loader/_builder.py +++ b/odc/loader/_builder.py @@ -153,6 +153,7 @@ def __init__( env: Dict[str, Any], rdr: ReaderDriver, chunks: Mapping[str, int], + mode: Literal["auto"] | Literal["mem"] | Literal["concurrency"] = "auto", ) -> None: gbox = gbt.base assert isinstance(gbox, GeoBox) @@ -160,6 +161,8 @@ def __init__( chunk_tyx = (chunks.get("time", 1), *gbt.chunk_shape((0, 0)).yx) chunks = {**chunks} chunks.update(dict(zip(["time", "y", "x"], chunk_tyx))) + if mode == "auto": + mode = "mem" # current default self.cfg = cfg self.template = template @@ -168,8 +171,9 @@ def __init__( self.gbt = gbt self.env = env self.rdr = rdr - self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, chunks) + self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, chunks, mode) self._chunks = chunks + self._mode = mode self._load_state = rdr.new_load(gbox, chunks=chunks) def _band_chunks( @@ -190,7 +194,7 @@ def build( gbox: GeoBox, time: Sequence[datetime], bands: Mapping[str, RasterLoadParams], - ): + ) -> xr.Dataset: return mk_dataset( gbox, time, @@ -199,13 +203,17 @@ def build( template=self.template, ) - def _prep_sources( - self, name: str, dsk: dict[Key, Any], deps: list[Any] - ) -> tuple[str, Any]: + def _norm_load_state(self, deps: list[Any]) -> Any: load_state = self._load_state if is_dask_collection(load_state): deps.append(load_state) load_state = load_state.key + return load_state + + def _prep_sources( + self, name: str, dsk: dict[Key, Any], deps: list[Any] + ) -> tuple[str, Any]: + load_state = self._norm_load_state(deps) tk = self._tk src_key = f"open-{name}-{tk}" @@ -237,6 +245,8 @@ def __call__( assert dtype == cfg.dtype assert ydim == cfg.ydim + 1 # +1 for time dimension + chunks = self._band_chunks(name, shape, ydim) + tk = self._tk deps: list[Any] = [] cfg_dask_key = f"cfg-{tokenize(cfg)}" @@ -246,25 +256,31 @@ def __call__( cfg_dask_key: cfg, gbt_dask_key: self.gbt, } - src_key, load_state = self._prep_sources(name, dsk, deps) + if self._mode == "mem": + src_key, load_state = self._prep_sources(name, dsk, deps) + else: + src_key, load_state = "", self._norm_load_state(deps) band_key = f"{name}-{tk}" - chunks = self._band_chunks(name, shape, ydim) for task in self.load_tasks(name, shape[0]): - dsk[(band_key, *task.idx)] = ( - _dask_loader_tyx, - task.resolve_sources_dask(src_key, dsk), - gbt_dask_key, - quote(task.idx_tyx[1:]), - quote(task.prefix_dims), - quote(task.postfix_dims), - cfg_dask_key, - self.rdr, - self.env, - load_state, - task.selection, - ) + task_key: Key = (band_key, *task.idx) + if self._mode == "mem": + dsk[task_key] = ( + _dask_loader_tyx, + task.resolve_sources_dask(src_key, dsk), + gbt_dask_key, + quote(task.idx_tyx[1:]), + quote(task.prefix_dims), + quote(task.postfix_dims), + cfg_dask_key, + self.rdr, + self.env, + load_state, + task.selection, + ) + else: + raise NotImplementedError("concurrency mode not implemented") dsk = HighLevelGraph.from_collections(band_key, dsk, dependencies=deps)