Patch summary (free text ahead of the first "diff --git" header).

odc/loader/_builder.py: when self._load_state is a dask collection,
__call__ records it in a dependency list, hands its .key (rather than the
object itself) to the task tuples, and wraps the task dict with
HighLevelGraph.from_collections before building the dask array.
_fill_nd_slice takes its initial fill value from resolve_dst_fill_value.

odc/loader/_reader.py: adds resolve_dst_fill_value, which returns the
result of resolve_dst_nodata or, when that is None, dst_dtype.type(0).

NOTE(review): leading indentation of context lines was reconstructed from
the code structure; confirm against the target tree before applying.

diff --git a/odc/loader/_builder.py b/odc/loader/_builder.py
index 27cf74a..56c5abc 100644
--- a/odc/loader/_builder.py
+++ b/odc/loader/_builder.py
@@ -23,14 +23,16 @@
 import numpy as np
 import xarray as xr
 from dask import array as da
+from dask import is_dask_collection
 from dask.array.core import normalize_chunks
 from dask.base import quote, tokenize
+from dask.highlevelgraph import HighLevelGraph, Key
 from numpy.typing import DTypeLike
 from odc.geo.geobox import GeoBox, GeoBoxBase, GeoboxTiles
 from odc.geo.xr import xr_coords
 
 from ._dask import unpack_chunks
-from ._reader import nodata_mask, resolve_src_nodata
+from ._reader import nodata_mask, resolve_dst_fill_value, resolve_src_nodata
 from ._utils import SizedIterable, pmap
 from .types import (
     FixedCoord,
@@ -146,10 +148,16 @@ def __call__(
             range(last - n, last) for last, n in zip(np.cumsum(chunks[0]), chunks[0])
         ]
 
+        deps: list[Any] = []
+        load_state = self._load_state
+        if is_dask_collection(load_state):
+            deps.append(load_state)
+            load_state = load_state.key
+
         cfg_dask_key = f"cfg-{tokenize(cfg)}"
         gbt_dask_key = f"grid-{tokenize(self.gbt)}"
 
-        dsk: Dict[Hashable, Any] = {
+        dsk: Dict[Key, Any] = {
             cfg_dask_key: cfg,
             gbt_dask_key: self.gbt,
         }
@@ -166,7 +174,7 @@ def __call__(
                     band,
                     self.rdr,
                     self.env,
-                    self._load_state,
+                    load_state,
                 )
 
         for block_idx in np.ndindex(shape_in_blocks):
@@ -190,9 +198,11 @@ def __call__(
                 cfg_dask_key,
                 self.rdr,
                 self.env,
-                self._load_state,
+                load_state,
             )
 
+        dsk = HighLevelGraph.from_collections(band_key, dsk, dependencies=deps)
+
         return da.Array(dsk, band_key, chunks, dtype=dtype, shape=shape)
 
 
@@ -242,11 +252,7 @@ def _fill_nd_slice(
     postfix_roi = (slice(None),) * len(dst.shape[2:])
 
     nodata = resolve_src_nodata(cfg.fill_value, cfg)
-
-    if nodata is None:
-        fill_value = float("nan") if dst.dtype.kind == "f" else 0
-    else:
-        fill_value = nodata
+    fill_value = resolve_dst_fill_value(dst.dtype, cfg, nodata)
 
     np.copyto(dst, fill_value)
     if len(srcs) == 0:
diff --git a/odc/loader/_reader.py b/odc/loader/_reader.py
index fecc867..b2d462b 100644
--- a/odc/loader/_reader.py
+++ b/odc/loader/_reader.py
@@ -101,6 +101,17 @@ def resolve_dst_nodata(
     return None
 
 
+def resolve_dst_fill_value(
+    dst_dtype: np.dtype,
+    cfg: RasterLoadParams,
+    src_nodata: Optional[float] = None,
+) -> float:
+    nodata = resolve_dst_nodata(dst_dtype, cfg, src_nodata)
+    if nodata is None:
+        return dst_dtype.type(0)
+    return nodata
+
+
 def pick_overview(read_shrink: int, overviews: Sequence[int]) -> Optional[int]:
     if len(overviews) == 0 or read_shrink < overviews[0]:
         return None