Skip to content

Commit

Permalink
amend me: delayed load_state
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed May 17, 2024
1 parent b0a8226 commit 2adbe5f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
24 changes: 15 additions & 9 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@
import numpy as np
import xarray as xr
from dask import array as da
from dask import is_dask_collection
from dask.array.core import normalize_chunks
from dask.base import quote, tokenize
from dask.highlevelgraph import HighLevelGraph, Key
from numpy.typing import DTypeLike
from odc.geo.geobox import GeoBox, GeoBoxBase, GeoboxTiles
from odc.geo.xr import xr_coords

from ._dask import unpack_chunks
from ._reader import nodata_mask, resolve_src_nodata
from ._reader import nodata_mask, resolve_dst_fill_value, resolve_src_nodata
from ._utils import SizedIterable, pmap
from .types import (
FixedCoord,
Expand Down Expand Up @@ -146,10 +148,16 @@ def __call__(
range(last - n, last) for last, n in zip(np.cumsum(chunks[0]), chunks[0])
]

deps: list[Any] = []
load_state = self._load_state
if is_dask_collection(load_state):
deps.append(load_state)
load_state = load_state.key

cfg_dask_key = f"cfg-{tokenize(cfg)}"
gbt_dask_key = f"grid-{tokenize(self.gbt)}"

dsk: Dict[Hashable, Any] = {
dsk: Dict[Key, Any] = {
cfg_dask_key: cfg,
gbt_dask_key: self.gbt,
}
Expand All @@ -166,7 +174,7 @@ def __call__(
band,
self.rdr,
self.env,
self._load_state,
load_state,
)

for block_idx in np.ndindex(shape_in_blocks):
Expand All @@ -190,9 +198,11 @@ def __call__(
cfg_dask_key,
self.rdr,
self.env,
self._load_state,
load_state,
)

dsk = HighLevelGraph.from_collections(band_key, dsk, dependencies=deps)

return da.Array(dsk, band_key, chunks, dtype=dtype, shape=shape)


Expand Down Expand Up @@ -242,11 +252,7 @@ def _fill_nd_slice(
postfix_roi = (slice(None),) * len(dst.shape[2:])

nodata = resolve_src_nodata(cfg.fill_value, cfg)

if nodata is None:
fill_value = float("nan") if dst.dtype.kind == "f" else 0
else:
fill_value = nodata
fill_value = resolve_dst_fill_value(dst.dtype, cfg, nodata)

np.copyto(dst, fill_value)
if len(srcs) == 0:
Expand Down
11 changes: 11 additions & 0 deletions odc/loader/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ def resolve_dst_nodata(
return None


def resolve_dst_fill_value(
dst_dtype: np.dtype,
cfg: RasterLoadParams,
src_nodata: Optional[float] = None,
) -> float:
nodata = resolve_dst_nodata(dst_dtype, cfg, src_nodata)
if nodata is None:
return dst_dtype.type(0)
return nodata


def pick_overview(read_shrink: int, overviews: Sequence[int]) -> Optional[int]:
if len(overviews) == 0 or read_shrink < overviews[0]:
return None
Expand Down

0 comments on commit 2adbe5f

Please sign in to comment.