Skip to content

Commit

Permalink
refactor: fill-nd-slice
Browse files Browse the repository at this point in the history
split fusing and reading
  • Loading branch information
Kirill888 committed Jun 4, 2024
1 parent 1d64d46 commit dfc6c9c
Showing 1 changed file with 31 additions and 19 deletions.
50 changes: 31 additions & 19 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Any,
Dict,
Hashable,
Iterable,
Iterator,
List,
Literal,
Expand Down Expand Up @@ -319,6 +320,29 @@ def _dask_loader_tyx(
return chunk


def fuse_nd_slices(
srcs: Iterable[tuple[tuple[slice, slice], np.ndarray]],
fill_value: float | int,
dst: Any,
ydim: int = 0,
prefilled: bool = False,
) -> Any:
postfix_roi = (slice(None),) * len(dst.shape[ydim + 2 :])
prefix_roi = (slice(None),) * ydim

if not prefilled:
np.copyto(dst, fill_value)

for yx_roi, pix in srcs:
_roi: tuple[slice, ...] = prefix_roi + yx_roi + postfix_roi # type: ignore
assert dst[_roi].shape == pix.shape

missing = nodata_mask(dst[_roi], fill_value)
np.copyto(dst[_roi], pix, where=missing)

return dst


def _fill_nd_slice(
srcs: Sequence[RasterReader],
dst_gbox: GeoBox,
Expand All @@ -335,9 +359,6 @@ def _fill_nd_slice(
# pylint: disable=too-many-locals

assert dst.shape[ydim : ydim + 2] == dst_gbox.shape.yx
postfix_roi = (slice(None),) * len(dst.shape[ydim + 2 :])
prefix_roi = (slice(None),) * ydim

nodata = resolve_src_nodata(cfg.fill_value, cfg)
fill_value = resolve_dst_fill_value(dst.dtype, cfg, nodata)

Expand All @@ -350,22 +371,13 @@ def _fill_nd_slice(
assert len(yx_roi) == 2
assert pix.ndim == dst.ndim

for src in rest:
# first valid pixel takes precedence over others
yx_roi, pix = src.read(cfg, dst_gbox, selection=selection)
assert len(yx_roi) == 2
assert pix.ndim == dst.ndim

_roi: Tuple[slice,] = prefix_roi + yx_roi + postfix_roi # type: ignore
assert dst[_roi].shape == pix.shape

# nodata mask takes care of nan when working with floats
# so you can still get proper mask even when nodata is None
# when working with float32 data.
missing = nodata_mask(dst[_roi], nodata)
np.copyto(dst[_roi], pix, where=missing)

return dst
return fuse_nd_slices(
(src.read(cfg, dst_gbox, selection=selection) for src in rest),
fill_value,
dst,
ydim=ydim,
prefilled=True,
)


def mk_dataset(
Expand Down

0 comments on commit dfc6c9c

Please sign in to comment.