Skip to content

Commit

Permalink
refactor: DaskBuilder extra dims
Browse files Browse the repository at this point in the history
- use load_tasks in DaskBuilder
  - implements extra dims slicing
- fix in load_tasks():
   - Set `selection=None` when extra dims are not
     sliced into
- fix in stac.load
   - Raster metadata needs to be supplied in query
     band names and not canonical names
  • Loading branch information
Kirill888 committed Jun 3, 2024
1 parent 447b882 commit ace5d34
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 57 deletions.
97 changes: 50 additions & 47 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
RasterReader,
RasterSource,
ReaderDriver,
T,
)


Expand Down Expand Up @@ -122,6 +123,9 @@ def resolve_sources(
out.append(_srcs)
return out

def resolve_sources_dask(self, dask_key: str) -> list[list[tuple[str, int]]]:
    """
    Translate this task's sources into Dask graph keys.

    :param dask_key: Name to pair with every source index
    :return: One list per layer of ``self.srcs``; each holds a
      ``(dask_key, idx)`` tuple for every ``(idx, source)`` entry of that
      layer, in the original order. Empty layers stay as empty lists.
    """

    def _layer_keys(layer) -> list[tuple[str, int]]:
        # only the index is needed, the source object itself is dropped
        return [(dask_key, src_idx) for src_idx, _ in layer]

    return [_layer_keys(layer) for layer in self.srcs]


class DaskGraphBuilder:
"""
Expand All @@ -143,6 +147,10 @@ def __init__(
) -> None:
gbox = gbt.base
assert isinstance(gbox, GeoBox)
# make sure chunks for tyx match our structure
chunk_tyx = (chunks.get("time", 1), *gbt.chunk_shape((0, 0)).yx)
chunks = {**chunks}
chunks.update(dict(zip(["time", "y", "x"], chunk_tyx)))

self.cfg = cfg
self.template = template
Expand All @@ -153,10 +161,20 @@ def __init__(
self.rdr = rdr
self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, chunks)
self._chunks = chunks
self.chunk_tyx = (chunks.get("time", 1), *self.gbt.chunk_shape((0, 0)).yx)
self._load_state = rdr.new_load(
gbox, chunks=dict(zip(["time", "y", "x"], self.chunk_tyx))
self._load_state = rdr.new_load(gbox, chunks=chunks)

def _band_chunks(
    self,
    band: str,
    shape: tuple[int, ...],
    ydim: int,
) -> tuple[tuple[int, ...], ...]:
    """
    Work out chunk sizes for the array of a single band.

    :param band: Band name, used to look up that band's extra dimensions
    :param shape: Full array shape with time as the first axis
    :param ydim: Index of the Y axis within ``shape``; X is at ``ydim + 1``
    :return: Per-axis chunk sizes as returned by ``denorm_ydim``
      (presumably ordered to match ``shape`` -- confirm against
      ``resolve_chunks``/``denorm_ydim``)
    """
    # chunks are resolved against the (time, y, x) extents of the array
    nt, ny, nx = shape[0], shape[ydim], shape[ydim + 1]
    band_extra_dims = self.template.extra_dims_full(band)
    resolved = resolve_chunks((nt, ny, nx), self._chunks, extra_dims=band_extra_dims)
    return denorm_ydim(resolved, ydim)

def build(
self,
Expand Down Expand Up @@ -222,58 +240,34 @@ def __call__(
src_key, load_state = self._prep_sources(name, dsk, deps)

band_key = f"{name}-{tk}"
chunks = self._band_chunks(name, shape, ydim)

postfix_dims = shape[ydim + 2 :]
prefix_dims = shape[1:ydim]

chunk_shape: Tuple[int, ...] = (
self.chunk_tyx[0],
*prefix_dims,
*self.chunk_tyx[1:],
*postfix_dims,
)
assert len(chunk_shape) == len(shape)
chunks: tuple[tuple[int, ...], ...] = normalize_chunks(chunk_shape, shape)
tchunk_range = [
range(last - n, last) for last, n in zip(np.cumsum(chunks[0]), chunks[0])
]

shape_in_blocks = tuple(len(ch) for ch in chunks)

for block_idx in np.ndindex(shape_in_blocks):
ti, yi, xi = block_idx[0], block_idx[ydim], block_idx[ydim + 1]
srcs_keys: list[list[tuple[str, int]]] = []
for _ti in tchunk_range[ti]:
srcs_keys.append(
[
(src_key, src_idx)
for src_idx in self.tyx_bins.get((_ti, yi, xi), [])
if (src_key, src_idx) in dsk
]
)

dsk[(band_key, *block_idx)] = (
for task in self.load_tasks(name, shape[0]):
dsk[(band_key, *task.idx)] = (
_dask_loader_tyx,
srcs_keys,
task.resolve_sources_dask(src_key),
gbt_dask_key,
quote((yi, xi)),
quote(prefix_dims),
quote(postfix_dims),
quote(task.idx_tyx[1:]),
quote(task.prefix_dims),
quote(task.postfix_dims),
cfg_dask_key,
self.rdr,
self.env,
load_state,
task.selection,
)

dsk = HighLevelGraph.from_collections(band_key, dsk, dependencies=deps)

return da.Array(dsk, band_key, chunks, dtype=dtype, shape=shape)

def load_tasks(self, name: str) -> Iterator[LoadChunkTask]:
def load_tasks(self, name: str, nt: int) -> Iterator[LoadChunkTask]:
    """
    Generate load tasks for a single band.

    :param name: Band to generate tasks for
    :param nt: Number of time steps, forwarded as ``nt``
    :return: Iterator of ``LoadChunkTask`` produced by the module-level
      ``load_tasks`` function for this builder's configuration
    """
    band_extra_dims = self.template.extra_dims_full(name)
    # ``load_tasks`` here resolves to the module-level function, not this method
    return load_tasks(
        self.cfg,
        self.tyx_bins,
        self.gbt,
        nt=nt,
        chunks=self._chunks,
        extra_dims=band_extra_dims,
        bands=[name],
    )
Expand All @@ -299,6 +293,7 @@ def _dask_loader_tyx(
rdr: ReaderDriver,
env: Dict[str, Any],
load_state: Any,
selection: Any | None = None,
):
assert cfg.dtype is not None
gbox = cast(GeoBox, gbt[iyx])
Expand All @@ -309,7 +304,9 @@ def _dask_loader_tyx(
ydim = len(prefix_dims)
with rdr.restore_env(env, load_state):
for ti, ti_srcs in enumerate(srcs):
_fill_nd_slice(ti_srcs, gbox, cfg, chunk[ti], ydim=ydim)
_fill_nd_slice(
ti_srcs, gbox, cfg, chunk[ti], ydim=ydim, selection=selection
)
return chunk


Expand All @@ -319,12 +316,14 @@ def _fill_nd_slice(
cfg: RasterLoadParams,
dst: Any,
ydim: int = 0,
selection: Any | None = None,
) -> Any:
# TODO: support masks not just nodata based fusing
#
# ``nodata`` marks missing pixels, but it might be None (everything is valid)
# ``fill_value`` is the initial value to use, it's equal to ``nodata`` when set,
# otherwise defaults to .nan for floats and 0 for integers
# pylint: disable=too-many-locals

assert dst.shape[ydim : ydim + 2] == dst_gbox.shape.yx
postfix_roi = (slice(None),) * len(dst.shape[ydim + 2 :])
Expand All @@ -338,13 +337,13 @@ def _fill_nd_slice(
return dst

src, *rest = srcs
yx_roi, pix = src.read(cfg, dst_gbox, dst=dst)
yx_roi, pix = src.read(cfg, dst_gbox, dst=dst, selection=selection)
assert len(yx_roi) == 2
assert pix.ndim == dst.ndim

for src in rest:
# first valid pixel takes precedence over others
yx_roi, pix = src.read(cfg, dst_gbox)
yx_roi, pix = src.read(cfg, dst_gbox, selection=selection)
assert len(yx_roi) == 2
assert pix.ndim == dst.ndim

Expand Down Expand Up @@ -521,7 +520,7 @@ def dask_chunked_load(
return dask_loader.build(gbox, tss, load_cfg)


def denorm_ydim(x: tuple[int, ...], ydim: int) -> tuple[int, ...]:
def denorm_ydim(x: tuple[T, ...], ydim: int) -> tuple[T, ...]:
ydim = ydim - 1
if ydim == 0:
return x
Expand All @@ -546,14 +545,14 @@ def load_tasks(
instances for every possible time, y, x, bins, including empty ones.
"""
# pylint: disable=too-many-locals
extra_dims = extra_dims or {}
chunks = chunks or {}

if nt is None:
nt = max(t for t, _, _ in tyx_bins) + 1

if extra_dims is None:
extra_dims = {}
if chunks is None:
chunks = {}

chunks = {**chunks}
chunks.update(zip(["y", "x"], gbt.chunk_shape((0, 0)).yx))
base_shape = (nt, *gbt.base.shape.yx)

if bands is None:
Expand Down Expand Up @@ -612,6 +611,8 @@ def load_tasks(
)
if len(selection) == 1:
selection = selection[0]
if shape_in_chunks[3] == 1:
selection = None

yield LoadChunkTask(
band_name,
Expand Down Expand Up @@ -681,6 +682,8 @@ def _do_one(task: LoadChunkTask) -> Tuple[str, int, int, int]:
nt=nt,
extra_dims=template.extra_dims_full(),
)
tasks = list(tasks)
assert len(tasks) == total_tasks

_work = pmap(_do_one, tasks, pool)

Expand Down
16 changes: 16 additions & 0 deletions odc/loader/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,22 @@ def resolve_band_query(
return src.band


def expand_selection(selection: Any, ydim: int) -> tuple[slice, ...]:
    """
    Add Y/X slices to an extra-dimensions selection.

    :param selection: Selection over the non-spatial dimensions: ``None``
      (nothing selected), a single index/slice, or a tuple of them
    :param ydim: Position at which the Y/X pair is inserted; the first
      ``ydim`` selection entries go before it, the remainder after
    :return: Tuple made of the selection entries (passed through unchanged)
      with ``slice(None), slice(None)`` inserted at position ``ydim``
    """
    if selection is None:
        extra: tuple = ()
    elif isinstance(selection, tuple):
        extra = selection
    else:
        # a lone index or slice is treated as a one-element selection
        extra = (selection,)

    full_yx = (slice(None), slice(None))
    return (*extra[:ydim], *full_yx, *extra[ydim:])


def pick_overview(read_shrink: int, overviews: Sequence[int]) -> Optional[int]:
if len(overviews) == 0 or read_shrink < overviews[0]:
return None
Expand Down
8 changes: 7 additions & 1 deletion odc/loader/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,13 @@ def test_mk_dataset(


@pytest.mark.parametrize("bands,extra_coords,extra_dims,expect", rlp_fixtures)
@pytest.mark.parametrize("chunk_extra_dims", [False, True])
def test_dask_builder(
bands: Dict[str, RasterLoadParams],
extra_coords: Sequence[FixedCoord] | None,
extra_dims: Mapping[str, int] | None,
expect: Mapping[str, _sn],
chunk_extra_dims: bool,
):
_bands = {
k: RasterBandMetadata(b.dtype, b.fill_value, dims=b.dims)
Expand Down Expand Up @@ -180,6 +182,10 @@ def test_dask_builder(
srcs = [src_mapper, src_mapper, src_mapper]
tyx_bins = _full_tyx_bins(gbt, nsrcs=len(srcs), nt=len(tss))

chunks = {"time": 1}
if chunk_extra_dims:
chunks = {k: 1 for k in extra_dims}

builder = DaskGraphBuilder(
bands,
template=template,
Expand All @@ -188,7 +194,7 @@ def test_dask_builder(
gbt=gbt,
env=rdr_env,
rdr=rdr,
chunks={"time": 1},
chunks=chunks,
)

xx = builder.build(gbox, tss, bands)
Expand Down
17 changes: 17 additions & 0 deletions odc/loader/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from odc.geo.xr import xr_zeros

from ._reader import (
expand_selection,
pick_overview,
resolve_band_query,
resolve_dst_dtype,
Expand Down Expand Up @@ -99,6 +100,22 @@ def test_resolve_band_query(
assert resolve_band_query(src, n, selection) == expect


@pytest.mark.parametrize(
    "ydim, selection, expect",
    [
        # no selection: only the full Y/X slices are produced
        (0, None, np.s_[:, :]),
        # ydim=0: every selection entry lands after Y/X
        (0, np.s_[:4], np.s_[:, :, :4]),
        # ydim=1: a lone slice or integer index lands before Y/X
        (1, np.s_[:3], np.s_[:3, :, :]),
        (1, np.s_[8], np.s_[8, :, :]),
        # two-entry selections are split around Y/X at position ``ydim``
        (0, np.s_[1:2, 3:4], np.s_[:, :, 1:2, 3:4]),
        (1, np.s_[1:2, 3:4], np.s_[1:2, :, :, 3:4]),
        (2, np.s_[1:2, 3:4], np.s_[1:2, 3:4, :, :]),
    ],
)
def test_expand_selection(ydim, selection, expect):
    """Check that ``expand_selection`` inserts full Y/X slices at ``ydim``."""
    assert expand_selection(selection, ydim) == expect


def test_rio_reader_env():
gbox = GeoBox.from_bbox((-180, -90, 180, 90), shape=(160, 320), tight=True)
rdr = RioDriver()
Expand Down
2 changes: 2 additions & 0 deletions odc/loader/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,13 @@ def test_raster_band():
assert RasterBandMetadata("float32", -9999).nodata == -9999
assert RasterBandMetadata().units == "1"
assert RasterBandMetadata().unit == "1"
assert RasterBandMetadata().ndim == 2
assert RasterBandMetadata("float32").data_type == "float32"
assert RasterBandMetadata("float32").dtype == "float32"
assert RasterBandMetadata(dims=("y", "x", "B")).ydim == 0
assert RasterBandMetadata(dims=("B", "y", "x")).ydim == 1
assert RasterBandMetadata(dims=("B", "y", "x")).extra_dims == ("B",)
assert RasterBandMetadata(dims=("B", "y", "x")).ndim == 3

assert RasterBandMetadata().patch(nodata=-1).nodata == -1
assert RasterBandMetadata(nodata=10).patch(nodata=-1).nodata == -1
Expand Down
18 changes: 11 additions & 7 deletions odc/loader/testing/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from odc.geo.geobox import GeoBox
from odc.geo.xr import ODCExtensionDa

from .._reader import expand_selection
from ..types import (
BandKey,
MDParser,
Expand Down Expand Up @@ -126,26 +127,25 @@ class LoadState:
def __init__(
self,
geobox: GeoBox,
group_md: RasterGroupMetadata,
meta: RasterGroupMetadata,
env: dict[str, Any],
is_dask: bool,
) -> None:
self.geobox = geobox
self.group_md = group_md
self.meta = meta
self.env = env
self.is_dask = is_dask
self.finalised = False

def with_env(self, env: dict[str, Any]) -> "FakeReader.LoadState":
return FakeReader.LoadState(self.geobox, self.group_md, env, self.is_dask)
return FakeReader.LoadState(self.geobox, self.meta, env, self.is_dask)

def __init__(self, src: RasterSource, load_state: "FakeReader.LoadState"):
self._src = src
self._load_state = load_state

def _extra_dims(self) -> Dict[str, int]:
md = self._load_state.group_md
return md.extra_dims_full()
return self._load_state.meta.extra_dims_full()

def read(
self,
Expand All @@ -157,8 +157,6 @@ def read(
) -> tuple[tuple[slice, slice], np.ndarray]:
meta = self._src.meta
assert meta is not None
# TODO: handle selection
assert selection is None

extra_dims = self._extra_dims()
prefix_dims: tuple[int, ...] = ()
Expand All @@ -180,7 +178,13 @@ def read(
else:
assert src_pix.shape == shape

if selection is not None:
src_pix = src_pix[expand_selection(selection, ydim)]
prefix_dims = src_pix.shape[:ydim]
postfix_dims = src_pix.shape[ydim + 2 :]

assert postfix_dims == src_pix.shape[ydim + 2 :]
assert prefix_dims == src_pix.shape[:ydim]

if dst is None:
dst = np.zeros((*prefix_dims, ny, nx, *postfix_dims), dtype=cfg.dtype)
Expand Down
5 changes: 5 additions & 0 deletions odc/loader/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def ydim(self) -> int:
"""Index of y dimension, typically 0."""
return _ydim(self.dims)

@property
def ndim(self) -> int:
    """Total number of dimensions: Y and X plus every extra dimension."""
    n_spatial = 2  # y and x
    return n_spatial + len(self.extra_dims)

@property
def unit(self) -> str:
"""
Expand Down
Loading

0 comments on commit ace5d34

Please sign in to comment.