diff --git a/odc/loader/_builder.py b/odc/loader/_builder.py index 7f431f3..8b4d999 100644 --- a/odc/loader/_builder.py +++ b/odc/loader/_builder.py @@ -33,9 +33,15 @@ from odc.geo.geobox import GeoBox, GeoBoxBase, GeoboxTiles from odc.geo.xr import xr_coords -from ._reader import nodata_mask, resolve_dst_fill_value, resolve_src_nodata +from ._reader import ( + ReaderDaskAdaptor, + nodata_mask, + resolve_dst_fill_value, + resolve_src_nodata, +) from ._utils import SizedIterable, pmap from .types import ( + DaskRasterReader, MultiBandRasterSource, RasterGroupMetadata, RasterLoadParams, @@ -154,6 +160,7 @@ def __init__( env: Dict[str, Any], rdr: ReaderDriver, chunks: Mapping[str, int], + mode: Literal["auto"] | Literal["mem"] | Literal["concurrency"] = "auto", ) -> None: gbox = gbt.base assert isinstance(gbox, GeoBox) @@ -161,6 +168,8 @@ def __init__( chunk_tyx = (chunks.get("time", 1), *gbt.chunk_shape((0, 0)).yx) chunks = {**chunks} chunks.update(dict(zip(["time", "y", "x"], chunk_tyx))) + if mode == "auto": + mode = "mem" # current default self.cfg = cfg self.template = template @@ -169,8 +178,9 @@ def __init__( self.gbt = gbt self.env = env self.rdr = rdr - self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, chunks) + self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, chunks, mode) self._chunks = chunks + self._mode = mode self._load_state = rdr.new_load(gbox, chunks=chunks) def _band_chunks( @@ -191,7 +201,7 @@ def build( gbox: GeoBox, time: Sequence[datetime], bands: Mapping[str, RasterLoadParams], - ): + ) -> xr.Dataset: return mk_dataset( gbox, time, @@ -200,13 +210,17 @@ def build( template=self.template, ) - def _prep_sources( - self, name: str, dsk: dict[Key, Any], deps: list[Any] - ) -> tuple[str, Any]: + def _norm_load_state(self, deps: list[Any]) -> Any: load_state = self._load_state if is_dask_collection(load_state): deps.append(load_state) load_state = load_state.key + return load_state + + def _prep_sources( + self, name: str, dsk: dict[Key, Any], deps: list[Any] + ) -> tuple[str, Any]: + load_state = self._norm_load_state(deps) tk = self._tk src_key = f"open-{name}-{tk}" @@ -223,6 +237,35 @@ def _prep_sources( ) return src_key, load_state + def _dask_rdr(self) -> DaskRasterReader: + if (dask_reader := self.rdr.dask_reader) is not None: + return dask_reader + return ReaderDaskAdaptor(self.rdr, self.env) + + def _task_futures( + self, + task: LoadChunkTask, + dask_reader: DaskRasterReader, + deps: list[Any], + ) -> list[list[Any]]: + srcs = task.resolve_sources(self.srcs) + out: list[list[Any]] = [] + ctx = self._load_state + cfg = task.cfg + dst_gbox = task.dst_gbox + + for layer in srcs: + read_futures: list[Any] = [] + for src in layer: + rdr = dask_reader.open(src, ctx) + fut = rdr.read(cfg, dst_gbox, selection=task.selection) + read_futures.append(fut) + deps.append(fut) + + out.append(read_futures) + + return out + def __call__( self, shape: Tuple[int, ...], @@ -238,6 +281,8 @@ def __call__( assert dtype == cfg.dtype assert ydim == cfg.ydim + 1 # +1 for time dimension + chunks = self._band_chunks(name, shape, ydim) + tk = self._tk deps: list[Any] = [] cfg_dask_key = f"cfg-{tokenize(cfg)}" @@ -247,25 +292,48 @@ def __call__( cfg_dask_key: cfg, gbt_dask_key: self.gbt, } - src_key, load_state = self._prep_sources(name, dsk, deps) + dask_reader: DaskRasterReader | None = None + if self._mode == "mem": + src_key, load_state = self._prep_sources(name, dsk, deps) + else: + dask_reader = self._dask_rdr() + src_key, load_state = "", self._load_state + + fill_value = resolve_dst_fill_value( + np.dtype(dtype), + cfg, + resolve_src_nodata(cfg.fill_value, cfg), + ) band_key = f"{name}-{tk}" - chunks = self._band_chunks(name, shape, ydim) for task in self.load_tasks(name, shape[0]): - dsk[(band_key, *task.idx)] = ( - _dask_loader_tyx, - task.resolve_sources_dask(src_key, dsk), - gbt_dask_key, - quote(task.idx_tyx[1:]), - quote(task.prefix_dims), - quote(task.postfix_dims), - cfg_dask_key, - self.rdr, - self.env, - load_state, - task.selection, - ) + task_key: Key = (band_key, *task.idx) + if dask_reader is None: + dsk[task_key] = ( + _dask_loader_tyx, + task.resolve_sources_dask(src_key, dsk), + gbt_dask_key, + quote(task.idx_tyx[1:]), + quote(task.prefix_dims), + quote(task.postfix_dims), + cfg_dask_key, + self.rdr, + self.env, + load_state, + task.selection, + ) + else: + srcs_futures = self._task_futures(task, dask_reader, deps) + + dsk[task_key] = ( + _dask_fuser, + srcs_futures, + task.shape, + dtype, + fill_value, + ydim - 1, + ) dsk = HighLevelGraph.from_collections(band_key, dsk, dependencies=deps) @@ -320,6 +388,30 @@ def _dask_loader_tyx( return chunk +def _dask_fuser( + srcs: list[list[Any]], + shape: tuple[int, ...], + dtype: DTypeLike, + fill_value: float | int, + src_ydim: int = 0, +): + assert shape[0] == len(srcs) + assert len(shape) >= 3 # time, ..., y, x, ... + + dst = np.full(shape, fill_value, dtype=dtype) + + for ti, layer in enumerate(srcs): + fuse_nd_slices( + layer, + fill_value, + dst[ti], + ydim=src_ydim, + prefilled=True, + ) + + return dst + + def fuse_nd_slices( srcs: Iterable[tuple[tuple[slice, slice], np.ndarray]], fill_value: float | int,