Skip to content

Commit

Permalink
refactor: DaskBuilder extra dims
Browse files Browse the repository at this point in the history
- use load_tasks in DaskBuilder
  - implements extra dims slicing
- fix in load_tasks():
   - Set `selection=None` when extra dims are not
     sliced into
- fix in stac.load
   - Raster metadata needs to be supplied using query
     band names rather than canonical names
  • Loading branch information
Kirill888 committed May 31, 2024
1 parent 447b882 commit 4b01457
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 43 deletions.
83 changes: 42 additions & 41 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
RasterReader,
RasterSource,
ReaderDriver,
T,
)


Expand Down Expand Up @@ -122,6 +123,9 @@ def resolve_sources(
out.append(_srcs)
return out

def resolve_sources_dask(self, dask_key: str) -> list[list[tuple[str, int]]]:
    """
    Rewrite ``self.srcs`` into references to dask tasks.

    Every ``(idx, ...)`` entry in each layer of ``self.srcs`` becomes a
    ``(dask_key, idx)`` tuple pointing at the task that holds that source.
    """
    keyed_layers: list[list[tuple[str, int]]] = []
    for layer in self.srcs:
        keyed: list[tuple[str, int]] = []
        for idx, _ in layer:
            keyed.append((dask_key, idx))
        keyed_layers.append(keyed)
    return keyed_layers


class DaskGraphBuilder:
"""
Expand All @@ -143,6 +147,10 @@ def __init__(
) -> None:
gbox = gbt.base
assert isinstance(gbox, GeoBox)
# make sure chunks for tyx match our structure
chunk_tyx = (chunks.get("time", 1), *gbt.chunk_shape((0, 0)).yx)
chunks = {**chunks}
chunks.update(dict(zip(["time", "y", "x"], chunk_tyx)))

self.cfg = cfg
self.template = template
Expand All @@ -153,10 +161,20 @@ def __init__(
self.rdr = rdr
self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, chunks)
self._chunks = chunks
self.chunk_tyx = (chunks.get("time", 1), *self.gbt.chunk_shape((0, 0)).yx)
self._load_state = rdr.new_load(
gbox, chunks=dict(zip(["time", "y", "x"], self.chunk_tyx))
self._load_state = rdr.new_load(gbox, chunks=chunks)

def _band_chunks(
    self,
    band: str,
    shape: tuple[int, ...],
    ydim: int,
) -> tuple[tuple[int, ...], ...]:
    """
    Compute per-dimension dask chunking for a single band.

    Chunks are resolved against the canonical ``(nt, ny, nx)`` triple plus
    this band's extra dimensions, then re-arranged with :func:`denorm_ydim`
    so that the y/x chunk tuples land at positions ``ydim``/``ydim + 1``
    of the output array.
    """
    # canonical time/y/x extents pulled out of the full array shape
    tyx_shape = (shape[0], shape[ydim], shape[ydim + 1])
    normed = resolve_chunks(
        tyx_shape,
        self._chunks,
        extra_dims=self.template.extra_dims_full(band),
    )
    return denorm_ydim(normed, ydim)

def build(
self,
Expand Down Expand Up @@ -222,58 +240,34 @@ def __call__(
src_key, load_state = self._prep_sources(name, dsk, deps)

band_key = f"{name}-{tk}"
chunks = self._band_chunks(name, shape, ydim)

postfix_dims = shape[ydim + 2 :]
prefix_dims = shape[1:ydim]

chunk_shape: Tuple[int, ...] = (
self.chunk_tyx[0],
*prefix_dims,
*self.chunk_tyx[1:],
*postfix_dims,
)
assert len(chunk_shape) == len(shape)
chunks: tuple[tuple[int, ...], ...] = normalize_chunks(chunk_shape, shape)
tchunk_range = [
range(last - n, last) for last, n in zip(np.cumsum(chunks[0]), chunks[0])
]

shape_in_blocks = tuple(len(ch) for ch in chunks)

for block_idx in np.ndindex(shape_in_blocks):
ti, yi, xi = block_idx[0], block_idx[ydim], block_idx[ydim + 1]
srcs_keys: list[list[tuple[str, int]]] = []
for _ti in tchunk_range[ti]:
srcs_keys.append(
[
(src_key, src_idx)
for src_idx in self.tyx_bins.get((_ti, yi, xi), [])
if (src_key, src_idx) in dsk
]
)

dsk[(band_key, *block_idx)] = (
for task in self.load_tasks(name, shape[0]):
dsk[(band_key, *task.idx)] = (
_dask_loader_tyx,
srcs_keys,
task.resolve_sources_dask(src_key),
gbt_dask_key,
quote((yi, xi)),
quote(prefix_dims),
quote(postfix_dims),
quote(task.idx_tyx),
quote(task.prefix_dims),
quote(task.postfix_dims),
cfg_dask_key,
self.rdr,
self.env,
load_state,
task.selection,
)

dsk = HighLevelGraph.from_collections(band_key, dsk, dependencies=deps)

return da.Array(dsk, band_key, chunks, dtype=dtype, shape=shape)

def load_tasks(self, name: str) -> Iterator[LoadChunkTask]:
def load_tasks(self, name: str, nt: int) -> Iterator[LoadChunkTask]:
    """
    Generate chunk-load tasks for a single band.

    :param name: band name to generate tasks for
    :param nt: number of time slots to cover
    """
    # delegate to the module-level load_tasks, restricted to this one band
    extra_dims = self.template.extra_dims_full(name)
    return load_tasks(
        self.cfg,
        self.tyx_bins,
        self.gbt,
        nt=nt,
        chunks=self._chunks,
        extra_dims=extra_dims,
        bands=[name],
    )
Expand All @@ -299,6 +293,7 @@ def _dask_loader_tyx(
rdr: ReaderDriver,
env: Dict[str, Any],
load_state: Any,
selection: Any | None = None,
):
assert cfg.dtype is not None
gbox = cast(GeoBox, gbt[iyx])
Expand All @@ -309,7 +304,9 @@ def _dask_loader_tyx(
ydim = len(prefix_dims)
with rdr.restore_env(env, load_state):
for ti, ti_srcs in enumerate(srcs):
_fill_nd_slice(ti_srcs, gbox, cfg, chunk[ti], ydim=ydim)
_fill_nd_slice(
ti_srcs, gbox, cfg, chunk[ti], ydim=ydim, selection=selection
)
return chunk


Expand All @@ -319,12 +316,14 @@ def _fill_nd_slice(
cfg: RasterLoadParams,
dst: Any,
ydim: int = 0,
selection: Any | None = None,
) -> Any:
# TODO: support masks not just nodata based fusing
#
# ``nodata`` marks missing pixels, but it might be None (everything is valid)
# ``fill_value`` is the initial value to use, it's equal to ``nodata`` when set,
# otherwise defaults to .nan for floats and 0 for integers
# pylint: disable=too-many-locals

assert dst.shape[ydim : ydim + 2] == dst_gbox.shape.yx
postfix_roi = (slice(None),) * len(dst.shape[ydim + 2 :])
Expand All @@ -338,7 +337,7 @@ def _fill_nd_slice(
return dst

src, *rest = srcs
yx_roi, pix = src.read(cfg, dst_gbox, dst=dst)
yx_roi, pix = src.read(cfg, dst_gbox, dst=dst, selection=selection)
assert len(yx_roi) == 2
assert pix.ndim == dst.ndim

Expand Down Expand Up @@ -521,7 +520,7 @@ def dask_chunked_load(
return dask_loader.build(gbox, tss, load_cfg)


def denorm_ydim(x: tuple[int, ...], ydim: int) -> tuple[int, ...]:
def denorm_ydim(x: tuple[T, ...], ydim: int) -> tuple[T, ...]:
ydim = ydim - 1
if ydim == 0:
return x
Expand Down Expand Up @@ -612,6 +611,8 @@ def load_tasks(
)
if len(selection) == 1:
selection = selection[0]
if shape_in_chunks[3] == 1:
selection = None

yield LoadChunkTask(
band_name,
Expand Down
12 changes: 10 additions & 2 deletions odc/stac/_stac_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,17 @@ def load(
)

tss = _extract_timestamps(ndeepmap(2, lambda idx: _parsed[idx], _grouped_idx))
meta = collection.meta_for(bands)

if chunks is not None:
chunk_shape = resolve_chunk_shape(len(tss), gbox, chunks, dtype, cfg=load_cfg)
chunk_shape = resolve_chunk_shape(
len(tss),
gbox,
chunks,
dtype,
cfg=load_cfg,
extra_dims=meta.extra_dims_full(),
)
else:
chunk_shape = (1, DEFAULT_CHUNK_FOR_LOAD, DEFAULT_CHUNK_FOR_LOAD)

Expand Down Expand Up @@ -442,7 +450,7 @@ def _with_debug_info(ds: xr.Dataset, **kw) -> xr.Dataset:
return _with_debug_info(
chunked_load(
load_cfg,
collection.meta,
meta,
_parsed,
tyx_bins,
gbt,
Expand Down
11 changes: 11 additions & 0 deletions odc/stac/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,17 @@ def __getitem__(self, band: BandIdentifier) -> RasterBandMetadata:
def bands(self) -> Dict[BandKey, RasterBandMetadata]:
return self.meta.bands

def meta_for(self, bands: BandQuery = None) -> RasterGroupMetadata:
    """
    Extract raster group metadata for a subset of bands.

    The returned metadata is keyed by the supplied band names, so
    user-provided aliases replace the canonical band names.
    """
    requested = self.normalize_band_query(bands)
    renamed = {norm_key(name): self[name] for name in requested}
    return self.meta.patch(bands=renamed)

@property
def aliases(self) -> Dict[str, List[BandKey]]:
    # Alias name -> candidate canonical band keys; delegates to the
    # underlying group metadata object held in ``self.meta``.
    return self.meta.aliases
Expand Down

0 comments on commit 4b01457

Please sign in to comment.