Skip to content

Commit

Permalink
Implement chunking along time dimension #81
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Oct 9, 2023
1 parent eb751e4 commit cf31f64
Showing 1 changed file with 53 additions and 37 deletions.
90 changes: 53 additions & 37 deletions odc/stac/_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,15 @@ def __init__(
tyx_bins: Dict[Tuple[int, int, int], List[int]],
gbt: GeoboxTiles,
env: Dict[str, Any],
time_chunks: int = 1,
) -> None:
self.cfg = cfg
self.items = items
self.tyx_bins = tyx_bins
self.gbt = gbt
self.env = env
self._tk = tokenize(items, cfg, gbt, tyx_bins, env)
self._tk = tokenize(items, cfg, gbt, tyx_bins, env, time_chunks)
self.chunk_shape = (time_chunks, *self.gbt.chunk_shape((0, 0)).yx)

def __call__(
self,
Expand All @@ -123,6 +125,11 @@ def __call__(
cfg = self.cfg[name]
assert dtype == cfg.dtype

chunks = unpack_chunks(self.chunk_shape, shape)
tchunk_range = [
range(last - n, last) for last, n in zip(np.cumsum(chunks[0]), chunks[0])
]

cfg_key = f"cfg-{tokenize(cfg)}"
gbt_key = f"grid-{tokenize(self.gbt)}"

Expand All @@ -133,19 +140,23 @@ def __call__(
tk = self._tk
band_key = f"{name}-{tk}"
md_key = f"md-{name}-{tk}"
shape_in_blocks = (shape[0], *self.gbt.shape.yx)
shape_in_blocks = tuple(len(ch) for ch in chunks)
for idx, item in enumerate(self.items):
band = item.get(name, None)
if band is not None:
dsk[md_key, idx] = band

for ti, yi, xi in np.ndindex(shape_in_blocks):
tyx_idx = (ti, yi, xi)
srcs = [
(md_key, idx)
for idx in self.tyx_bins.get(tyx_idx, [])
if (md_key, idx) in dsk
]
srcs = []
for _ti in tchunk_range[ti]:
srcs.append(
[
(md_key, idx)
for idx in self.tyx_bins.get((_ti, yi, xi), [])
if (md_key, idx) in dsk
]
)

dsk[band_key, ti, yi, xi] = (
_dask_loader_tyx,
srcs,
Expand All @@ -155,9 +166,6 @@ def __call__(
self.env,
)

chunk_shape = (1, *self.gbt.chunk_shape((0, 0)).yx)
chunks = unpack_chunks(chunk_shape, shape)

return da.Array(dsk, band_key, chunks, dtype=dtype, shape=shape)


Expand Down Expand Up @@ -522,15 +530,6 @@ def load(
)
dtype = "uint16" if len(_dtypes) == 0 else _dtypes[0]

if chunks is not None:
chunk_shape = _resolve_chunk_shape(gbox, chunks, dtype)
else:
chunk_shape = _resolve_chunk_shape(
gbox,
{dim: DEFAULT_CHUNK_FOR_LOAD for dim in gbox.dimensions},
dtype,
)

if patch_url is not None:
_parsed = [patch_urls(item, edit=patch_url, bands=bands) for item in _parsed]

Expand All @@ -546,9 +545,19 @@ def load(

tss = _extract_timestamps(ndeepmap(2, lambda idx: _parsed[idx], _grouped_idx))

if chunks is not None:
chunk_shape = _resolve_chunk_shape(len(tss), gbox, chunks, dtype)
else:
chunk_shape = _resolve_chunk_shape(
len(tss),
gbox,
{dim: DEFAULT_CHUNK_FOR_LOAD for dim in gbox.dimensions},
dtype,
)

# Spatio-temporal binning
assert isinstance(gbox.crs, CRS)
gbt = GeoboxTiles(gbox, chunk_shape)
gbt = GeoboxTiles(gbox, chunk_shape[1:])
tyx_bins = dict(_tyx_bins(_grouped_idx, _parsed, gbt))
_parsed = [item.strip() for item in _parsed]

Expand All @@ -573,15 +582,6 @@ def _with_debug_info(ds: xr.Dataset, **kw) -> xr.Dataset:
)
return ds

def _task_stream(bands: List[str]) -> Iterator[_LoadChunkTask]:
_shape = (len(_grouped_idx), *gbt.shape)
for band_name in bands:
cfg = load_cfg[band_name]
for ti, yi, xi in np.ndindex(_shape):
tyx_idx = (ti, yi, xi)
srcs = [(idx, band_name) for idx in tyx_bins.get(tyx_idx, [])]
yield _LoadChunkTask(band_name, srcs, cfg, gbt, tyx_idx)

_rio_env = _capture_rio_env()
if chunks is not None:
# Dask case: dummy for now
Expand All @@ -591,9 +591,19 @@ def _task_stream(bands: List[str]) -> Iterator[_LoadChunkTask]:
tyx_bins,
gbt,
_rio_env,
time_chunks=chunk_shape[0],
)
return _with_debug_info(_mk_dataset(gbox, tss, load_cfg, _loader))

def _task_stream(bands: List[str]) -> Iterator[_LoadChunkTask]:
_shape = (len(_grouped_idx), *gbt.shape)
for band_name in bands:
cfg = load_cfg[band_name]
for ti, yi, xi in np.ndindex(_shape):
tyx_idx = (ti, yi, xi)
srcs = [(idx, band_name) for idx in tyx_bins.get(tyx_idx, [])]
yield _LoadChunkTask(band_name, srcs, cfg, gbt, tyx_idx)

ds = _mk_dataset(gbox, tss, load_cfg)
ny, nx = gbt.shape.yx
total_tasks = len(bands) * len(tss) * ny * nx
Expand Down Expand Up @@ -671,17 +681,19 @@ def _resolve(name: str, band: RasterBandMetadata) -> RasterLoadParams:


def _dask_loader_tyx(
srcs: List[RasterSource],
srcs: List[List[RasterSource]],
gbt: GeoboxTiles,
iyx: Tuple[int, int],
cfg: RasterLoadParams,
env: Dict[str, Any],
):
assert cfg.dtype is not None
gbox = cast(GeoBox, gbt[iyx])
chunk = np.empty(gbox.shape.yx, dtype=cfg.dtype)
chunk = np.empty((len(srcs), *gbox.shape.yx), dtype=cfg.dtype)
with rio_env(**env):
return _fill_2d_slice(srcs, gbox, cfg, chunk)[np.newaxis]
for i, plane in enumerate(srcs):
_fill_2d_slice(plane, gbox, cfg, chunk[i, :, :])
return chunk


def _fill_2d_slice(
Expand Down Expand Up @@ -866,12 +878,16 @@ def _tyx_bins(


def _resolve_chunk_shape(
gbox: GeoBox, chunks: Dict[str, int | Literal["auto"]], dtype: Any
) -> Tuple[int, int]:
nt: int, gbox: GeoBox, chunks: Dict[str, int | Literal["auto"]], dtype: Any
) -> Tuple[int, int, int]:
tt = chunks.get("time", 1)
ty, tx = (
chunks.get(dim, chunks.get(fallback_dim, -1))
for dim, fallback_dim in zip(gbox.dimensions, ["y", "x"])
)
ny, nx = (ch[0] for ch in normalize_chunks((ty, tx), gbox.shape.yx, dtype=dtype))
nt, ny, nx = (
ch[0]
for ch in normalize_chunks((tt, ty, tx), (nt, *gbox.shape.yx), dtype=dtype)
)

return ny, nx
return nt, ny, nx

0 comments on commit cf31f64

Please sign in to comment.