Skip to content

Commit

Permalink
amend me: delayed load_state
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed May 11, 2024
1 parent 9b2a90a commit 22232ed
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import numpy as np
import xarray as xr
from dask import array as da
from dask import is_dask_collection
from dask.array.core import normalize_chunks
from dask.base import quote, tokenize
from dask.highlevelgraph import HighLevelGraph, Key
from numpy.typing import DTypeLike
from odc.geo.geobox import GeoBox, GeoBoxBase, GeoboxTiles
from odc.geo.xr import xr_coords
Expand Down Expand Up @@ -146,10 +148,16 @@ def __call__(
range(last - n, last) for last, n in zip(np.cumsum(chunks[0]), chunks[0])
]

deps: list[Any] = []
load_state = self._load_state
if is_dask_collection(load_state):
deps.append(load_state)
load_state = load_state.key

cfg_dask_key = f"cfg-{tokenize(cfg)}"
gbt_dask_key = f"grid-{tokenize(self.gbt)}"

dsk: Dict[Hashable, Any] = {
dsk: Dict[Key, Any] = {
cfg_dask_key: cfg,
gbt_dask_key: self.gbt,
}
Expand All @@ -166,7 +174,7 @@ def __call__(
band,
self.rdr,
self.env,
self._load_state,
load_state,
)

for block_idx in np.ndindex(shape_in_blocks):
Expand All @@ -190,9 +198,11 @@ def __call__(
cfg_dask_key,
self.rdr,
self.env,
self._load_state,
load_state,
)

dsk = HighLevelGraph.from_collections(band_key, dsk, dependencies=deps)

return da.Array(dsk, band_key, chunks, dtype=dtype, shape=shape)


Expand Down

0 comments on commit 22232ed

Please sign in to comment.