Skip to content

Commit

Permalink
refactor: DaskBuilder extra dims
Browse files Browse the repository at this point in the history
- use load_tasks in DaskBuilder
  - implements extra dims slicing
- fix in load_tasks():
   - Set `selection=None` when extra dims are not
     sliced into
- fix in stac.load
   - Raster metadata needs to be keyed by the band names used in the
     query, not by canonical names
  • Loading branch information
Kirill888 committed May 29, 2024
1 parent 447b882 commit 0d2e58e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 25 deletions.
50 changes: 26 additions & 24 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ def resolve_sources(
out.append(_srcs)
return out

def resolve_sources_dask(self, dask_key: str) -> list[list[tuple[str, int]]]:
    """
    Build Dask keys for every source referenced by this task.

    ``self.srcs`` is walked layer by layer; each entry of a layer is
    unpacked as ``(idx, _)`` and only the index is kept.

    :param dask_key: name to pair with each source index.
    :return: one list per layer in ``self.srcs``, each holding
        ``(dask_key, idx)`` tuples in the original order.
    """
    resolved: list[list[tuple[str, int]]] = []
    for layer in self.srcs:
        # second element of each entry is not needed to form the key
        resolved.append([(dask_key, idx) for idx, _ in layer])
    return resolved


class DaskGraphBuilder:
"""
Expand Down Expand Up @@ -234,46 +237,38 @@ def __call__(
)
assert len(chunk_shape) == len(shape)
chunks: tuple[tuple[int, ...], ...] = normalize_chunks(chunk_shape, shape)
tchunk_range = [
range(last - n, last) for last, n in zip(np.cumsum(chunks[0]), chunks[0])
]

shape_in_blocks = tuple(len(ch) for ch in chunks)

for block_idx in np.ndindex(shape_in_blocks):
ti, yi, xi = block_idx[0], block_idx[ydim], block_idx[ydim + 1]
srcs_keys: list[list[tuple[str, int]]] = []
for _ti in tchunk_range[ti]:
srcs_keys.append(
[
(src_key, src_idx)
for src_idx in self.tyx_bins.get((_ti, yi, xi), [])
if (src_key, src_idx) in dsk
]
)

dsk[(band_key, *block_idx)] = (
for task in self.load_tasks(name):
dsk[(band_key, *task.idx)] = (
_dask_loader_tyx,
srcs_keys,
task.resolve_sources_dask(src_key),
gbt_dask_key,
quote((yi, xi)),
quote(prefix_dims),
quote(postfix_dims),
quote(task.idx_tyx),
quote(task.prefix_dims),
quote(task.postfix_dims),
cfg_dask_key,
self.rdr,
self.env,
load_state,
task.selection,
)

dsk = HighLevelGraph.from_collections(band_key, dsk, dependencies=deps)

return da.Array(dsk, band_key, chunks, dtype=dtype, shape=shape)

def load_tasks(self, name: str) -> Iterator[LoadChunkTask]:
    """
    Produce load tasks for a single band.

    Thin wrapper around the module-level ``load_tasks`` function, called
    with this builder's configuration, bins and geobox tiles, and
    restricted to the one requested band.

    :param name: band to generate tasks for; also used to look up the
        band's extra dimensions on ``self.template``.
    :return: whatever the module-level ``load_tasks`` returns for
        ``bands=[name]``.
    """
    # Copy first so that ``self._chunks`` is left untouched, then
    # make sure chunks for tyx match our structure
    chunks = {**self._chunks, **dict(zip(["time", "y", "x"], self.chunk_tyx))}

    # NOTE(review): ``nt`` is taken as the number of entries in
    # ``self.tyx_bins`` -- confirm this equals the number of time steps.
    return load_tasks(
        self.cfg,
        self.tyx_bins,
        self.gbt,
        nt=len(self.tyx_bins),
        chunks=chunks,
        extra_dims=self.template.extra_dims_full(name),
        bands=[name],
    )
Expand All @@ -299,6 +294,7 @@ def _dask_loader_tyx(
rdr: ReaderDriver,
env: Dict[str, Any],
load_state: Any,
selection: Any | None = None,
):
assert cfg.dtype is not None
gbox = cast(GeoBox, gbt[iyx])
Expand All @@ -309,7 +305,9 @@ def _dask_loader_tyx(
ydim = len(prefix_dims)
with rdr.restore_env(env, load_state):
for ti, ti_srcs in enumerate(srcs):
_fill_nd_slice(ti_srcs, gbox, cfg, chunk[ti], ydim=ydim)
_fill_nd_slice(
ti_srcs, gbox, cfg, chunk[ti], ydim=ydim, selection=selection
)
return chunk


Expand All @@ -319,12 +317,14 @@ def _fill_nd_slice(
cfg: RasterLoadParams,
dst: Any,
ydim: int = 0,
selection: Any | None = None,
) -> Any:
# TODO: support masks not just nodata based fusing
#
# ``nodata`` marks missing pixels, but it might be None (everything is valid)
# ``fill_value`` is the initial value to use, it's equal to ``nodata`` when set,
# otherwise defaults to .nan for floats and 0 for integers
# pylint: disable=too-many-locals

assert dst.shape[ydim : ydim + 2] == dst_gbox.shape.yx
postfix_roi = (slice(None),) * len(dst.shape[ydim + 2 :])
Expand All @@ -338,7 +338,7 @@ def _fill_nd_slice(
return dst

src, *rest = srcs
yx_roi, pix = src.read(cfg, dst_gbox, dst=dst)
yx_roi, pix = src.read(cfg, dst_gbox, dst=dst, selection=selection)
assert len(yx_roi) == 2
assert pix.ndim == dst.ndim

Expand Down Expand Up @@ -612,6 +612,8 @@ def load_tasks(
)
if len(selection) == 1:
selection = selection[0]
if shape_in_chunks[3] == 1:
selection = None

yield LoadChunkTask(
band_name,
Expand Down
2 changes: 1 addition & 1 deletion odc/stac/_stac_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def _with_debug_info(ds: xr.Dataset, **kw) -> xr.Dataset:
return _with_debug_info(
chunked_load(
load_cfg,
collection.meta,
collection.meta_for(bands),
_parsed,
tyx_bins,
gbt,
Expand Down
11 changes: 11 additions & 0 deletions odc/stac/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,17 @@ def __getitem__(self, band: BandIdentifier) -> RasterBandMetadata:
def bands(self) -> Dict[BandKey, RasterBandMetadata]:
return self.meta.bands

def meta_for(self, bands: BandQuery = None) -> RasterGroupMetadata:
    """
    Extract raster group metadata for a subset of bands.

    Output uses supplied band names as keys, effectively replacing canonical
    names with aliases supplied by the user.

    :param bands: band query, normalized with ``self.normalize_band_query``
        before use.
    :return: result of ``self.meta.patch`` with ``bands`` replaced by a
        mapping from ``norm_key(b)`` to ``self[b]`` for every requested
        band ``b``.
    """
    requested = self.normalize_band_query(bands)
    # key by the name the caller asked for, look the metadata up via self[...]
    by_query_name = {norm_key(b): self[b] for b in requested}
    return self.meta.patch(bands=by_query_name)

@property
def aliases(self) -> Dict[str, List[BandKey]]:
return self.meta.aliases
Expand Down

0 comments on commit 0d2e58e

Please sign in to comment.