Skip to content

Commit

Permalink
Support B,Y,X output from rio driver
Browse files Browse the repository at this point in the history
no longer limited to reading a single plane of
pixels:

- band=0 indicates all bands
- currently output is always in B,Y,X order
- can read a subset if selection is supplied
  • Loading branch information
Kirill888 committed May 25, 2024
1 parent 2e3dbd3 commit 939eec7
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 18 deletions.
44 changes: 42 additions & 2 deletions odc/loader/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@
from __future__ import annotations

import math
from typing import Optional, Sequence
from typing import Any, Optional, Sequence

import numpy as np
from numpy.typing import DTypeLike

from .types import RasterBandMetadata, RasterLoadParams, with_default
from .types import (
RasterBandMetadata,
RasterLoadParams,
RasterSource,
ReaderSubsetSelection,
with_default,
)


def resolve_load_cfg(
Expand Down Expand Up @@ -112,6 +118,40 @@ def resolve_dst_fill_value(
return nodata


def _selection_to_bands(selection: Any, n: int) -> list[int]:
if isinstance(selection, list):
return selection

bidx = np.arange(1, n + 1)
if isinstance(selection, int):
return [int(bidx[selection])]
return bidx[selection].tolist()


def resolve_band_query(
src: RasterSource,
n: int,
selection: ReaderSubsetSelection | None = None,
) -> int | list[int]:
if src.band > n:
raise ValueError(
f"Requested band {src.band} from {src.uri} with only {n} bands"
)

if src.band == 0:
if selection:
return _selection_to_bands(selection, n)
return list(range(1, n + 1))

meta = src.meta
if meta is None:
return src.band
if meta.extra_dims:
return [src.band]

return src.band


def pick_overview(read_shrink: int, overviews: Sequence[int]) -> Optional[int]:
if len(overviews) == 0 or read_shrink < overviews[0]:
return None
Expand Down
37 changes: 22 additions & 15 deletions odc/loader/_rio.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ._reader import (
nodata_mask,
pick_overview,
resolve_band_query,
resolve_dst_dtype,
resolve_dst_nodata,
resolve_src_nodata,
Expand Down Expand Up @@ -384,18 +385,27 @@ def _do_read(
) -> tuple[tuple[slice, slice], np.ndarray]:
resampling = resampling_s2rio(cfg.resampling)
rdr = src.ds
roi_dst: tuple[slice, slice] = rr.roi_dst # type: ignore
ndim = 2
prefix: tuple[int, ...] = ()

if isinstance(src.bidx, int):
src_nodata0 = rdr.nodatavals[src.bidx - 1]
else:
ndim = 3
prefix = (len(src.bidx),)
(src_nodata0,) = (rdr.nodatavals[b - 1] for b in src.bidx[:1])

if dst is not None:
_dst = dst[rr.roi_dst] # type: ignore
# Assumes Y,X or B,Y,X order
assert dst.ndim == ndim
_dst = dst[(...,) + roi_dst] # type: ignore
else:
_dst = np.ndarray(
roi_shape(rr.roi_dst), dtype=resolve_dst_dtype(src.dtype, cfg)
)
shape = prefix + roi_shape(rr.roi_dst)
_dst = np.ndarray(shape, dtype=resolve_dst_dtype(src.dtype, cfg))

src_nodata0 = rdr.nodatavals[src.bidx - 1]
src_nodata = resolve_src_nodata(src_nodata0, cfg)
dst_nodata = resolve_dst_nodata(_dst.dtype, cfg, src_nodata)
roi_dst: tuple[slice, slice] = rr.roi_dst # type: ignore

if roi_is_empty(roi_dst):
return (roi_dst, _dst)
Expand Down Expand Up @@ -460,6 +470,7 @@ def rio_read(
"""

try:
# TODO: deal with Y,X,B order on output
return _rio_read(src, cfg, dst_geobox, dst, selection=selection)
except (
rasterio.errors.RasterioIOError,
Expand Down Expand Up @@ -506,25 +517,21 @@ def _rio_read(
# if resampling is `nearest` then ignore sub-pixel translation when deciding
# whether we can just paste source into destination
ttol = 0.9 if cfg.nearest else 0.05
assert selection is None, "Band selection not implemented in rio_read"

with rasterio.open(src.uri, "r", sharing=False) as rdr:
assert isinstance(rdr, rasterio.DatasetReader)
ovr_idx: Optional[int] = None

if src.band > rdr.count:
raise ValueError(f"No band {src.band} in '{src.uri}'")

bidx = resolve_band_query(src, rdr.count, selection=selection)
rr = _reproject_info_from_rio(rdr, dst_geobox, ttol=ttol)

if cfg.use_overviews and rr.read_shrink > 1:
ovr_idx = pick_overview(rr.read_shrink, rdr.overviews(src.band))
first_band = bidx if isinstance(bidx, int) else bidx[0]
ovr_idx = pick_overview(rr.read_shrink, rdr.overviews(first_band))

if ovr_idx is None:
with rio_env(VSI_CACHE=False):
return _do_read(
rasterio.band(rdr, src.band), cfg, dst_geobox, rr, dst=dst
)
return _do_read(rasterio.band(rdr, bidx), cfg, dst_geobox, rr, dst=dst)

# read from overview
with rasterio.open(
Expand All @@ -533,7 +540,7 @@ def _rio_read(
rr = _reproject_info_from_rio(rdr, dst_geobox, ttol=ttol)
with rio_env(VSI_CACHE=False):
return _do_read(
rasterio.band(rdr_ovr, src.band), cfg, dst_geobox, rr, dst=dst
rasterio.band(rdr_ovr, bidx), cfg, dst_geobox, rr, dst=dst
)


Expand Down
73 changes: 72 additions & 1 deletion odc/loader/test_reader.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
# pylint: disable=missing-function-docstring,missing-module-docstring,too-many-statements,too-many-locals
from __future__ import annotations

from math import isnan
from typing import Any

import numpy as np
import pytest
import rasterio
import xarray as xr
from numpy import ma
from numpy.testing import assert_array_equal
from odc.geo.geobox import GeoBox
from odc.geo.xr import xr_zeros

from ._reader import (
pick_overview,
resolve_band_query,
resolve_dst_dtype,
resolve_dst_nodata,
resolve_src_nodata,
same_nodata,
)
from ._rio import RioDriver, configure_rio, get_rio_env, rio_read
from .testing.fixtures import with_temp_tiff
from .types import RasterLoadParams, RasterSource
from .types import RasterBandMetadata, RasterLoadParams, RasterSource


def test_same_nodata():
Expand Down Expand Up @@ -68,6 +73,32 @@ def test_pick_overiew():
assert pick_overview(20, [2, 4, 8]) == 2


@pytest.mark.parametrize(
"n, dims, band, selection, expect",
[
(3, (), 1, None, 1),
(3, (), 3, None, 3),
(3, ("b", "y", "x"), 3, None, [3]),
(3, ("b", "y", "x"), 1, None, [1]),
(3, ("b", "y", "x"), 0, None, [1, 2, 3]),
(4, ("b", "y", "x"), 0, np.s_[:2], [1, 2]),
(4, ("b", "y", "x"), 0, (slice(1, 4),), [2, 3, 4]),
(4, ("b", "y", "x"), 0, [1, 3], [1, 3]),
(2, ("b", "y", "x"), 0, 1, [2]),
(5, ("b", "y", "x"), 0, -1, [5]),
],
)
def test_resolve_band_query(
n: int,
dims: tuple[str, ...],
band: int,
selection: Any,
expect: Any,
):
src = RasterSource("", band=band, meta=RasterBandMetadata(dims=dims))
assert resolve_band_query(src, n, selection) == expect


def test_rio_reader_env():
gbox = GeoBox.from_bbox((-180, -90, 180, 90), shape=(160, 320), tight=True)
rdr = RioDriver()
Expand Down Expand Up @@ -227,6 +258,46 @@ def test_reader_ovr():
assert _gbox[roi] == _gbox


@pytest.mark.parametrize("resamlpling", ["nearest", "bilinear", "cubic"])
def test_rio_read_rgb(resamlpling):
gbox = GeoBox.from_bbox((-180, -90, 180, 90), shape=(512, 512), tight=True)

non_zeros_roi = np.s_[30:47, 190:210]

xx = xr_zeros(gbox, dtype="uint8")
xx.values[non_zeros_roi] = 255
xx = xx.expand_dims("band", 2)
xx = xr.concat([xx, xx, xx], "band").assign_coords(band=["r", "g", "b"])

assert xx.odc.geobox == gbox

cfg = RasterLoadParams(
dtype="uint8",
dims=("band", "y", "x"),
resampling=resamlpling,
)
gbox2 = gbox.zoom_to(237)

# whole image from 1/2 overview
with with_temp_tiff(xx, compress=None, overview_levels=[2, 4]) as uri:
src = RasterSource(uri, band=0)
for gb in [gbox, gbox2]:
roi, pix = rio_read(src, cfg, gb)
assert len(roi) == 2
assert pix.ndim == 3
assert pix.shape == (3, *gb.shape)

# again but with dst=
_, pix2 = rio_read(src, cfg, gb, dst=pix)
assert pix2.shape == pix.shape

# again but with selection
roi, pix = rio_read(src, cfg, gb, selection=np.s_[:2])
assert len(roi) == 2
assert pix.ndim == 3
assert pix.shape == (2, *gb.shape)


def test_reader_unhappy_paths():
gbox = GeoBox.from_bbox((-180, -90, 180, 90), shape=(160, 320), tight=True)
xx = xr_zeros(gbox, dtype="int16")
Expand Down

0 comments on commit 939eec7

Please sign in to comment.