From 22232ed30020f10b54fc2b81e48134678f1a40a8 Mon Sep 17 00:00:00 2001 From: Kirill Kouzoubov Date: Sat, 11 May 2024 14:12:26 +1000 Subject: [PATCH] amend me: delayed load_state --- odc/loader/_builder.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/odc/loader/_builder.py b/odc/loader/_builder.py index 27cf74a..f29094c 100644 --- a/odc/loader/_builder.py +++ b/odc/loader/_builder.py @@ -23,8 +23,10 @@ import numpy as np import xarray as xr from dask import array as da +from dask import is_dask_collection from dask.array.core import normalize_chunks from dask.base import quote, tokenize +from dask.highlevelgraph import HighLevelGraph, Key from numpy.typing import DTypeLike from odc.geo.geobox import GeoBox, GeoBoxBase, GeoboxTiles from odc.geo.xr import xr_coords @@ -146,10 +148,16 @@ def __call__( range(last - n, last) for last, n in zip(np.cumsum(chunks[0]), chunks[0]) ] + deps: list[Any] = [] + load_state = self._load_state + if is_dask_collection(load_state): + deps.append(load_state) + load_state = load_state.key + cfg_dask_key = f"cfg-{tokenize(cfg)}" gbt_dask_key = f"grid-{tokenize(self.gbt)}" - dsk: Dict[Hashable, Any] = { + dsk: Dict[Key, Any] = { cfg_dask_key: cfg, gbt_dask_key: self.gbt, } @@ -166,7 +174,7 @@ def __call__( band, self.rdr, self.env, - self._load_state, + load_state, ) for block_idx in np.ndindex(shape_in_blocks): @@ -190,9 +198,11 @@ def __call__( cfg_dask_key, self.rdr, self.env, - self._load_state, + load_state, ) + dsk = HighLevelGraph.from_collections(band_key, dsk, dependencies=deps) + return da.Array(dsk, band_key, chunks, dtype=dtype, shape=shape)