Skip to content

Commit

Permalink
Support Y,X,B output from rio driver
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed May 27, 2024
1 parent b4f29ac commit a16cb7b
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 7 deletions.
33 changes: 31 additions & 2 deletions odc/loader/_rio.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,39 @@ def rio_read(
mosaic[roi] = pix # if sources are true tiles (no overlaps)
"""
ydim = src.ydim

def prep_dst(dst: Optional[np.ndarray]) -> Optional[np.ndarray]:
if dst is None:
return None
if dst.ndim == 2 or ydim == 1:
# Y,X or B,Y,X
return dst

# Supplied as Y,X,B, but we need B,Y,X
assert ydim == 0 and dst.ndim == 3
return dst.transpose([2, 0, 1])

def fixup_out(
x: tuple[tuple[slice, slice], np.ndarray]
) -> tuple[tuple[slice, slice], np.ndarray]:
roi, out = x
if out.ndim == 2 or ydim == 1:
# Y,X or B,Y,X
return roi, out

# must be Y,X,B on output
if dst is not None:
return roi, dst[roi]

assert ydim == 0 and out.ndim == 3
# B,Y,X -> Y,X,B
return roi, out.transpose([1, 2, 0])

try:
# TODO: deal with Y,X,B order on output
return _rio_read(src, cfg, dst_geobox, dst, selection=selection)
return fixup_out(
_rio_read(src, cfg, dst_geobox, prep_dst(dst), selection=selection)
)
except (
rasterio.errors.RasterioIOError,
rasterio.errors.RasterBlockError,
Expand Down
19 changes: 14 additions & 5 deletions odc/loader/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ def test_reader_ovr():


@pytest.mark.parametrize("resamlpling", ["nearest", "bilinear", "cubic"])
def test_rio_read_rgb(resamlpling):
@pytest.mark.parametrize("dims", [("y", "x", "band"), ("band", "y", "x")])
def test_rio_read_rgb(resamlpling, dims):
gbox = GeoBox.from_bbox((-180, -90, 180, 90), shape=(512, 512), tight=True)

non_zeros_roi = np.s_[30:47, 190:210]
Expand All @@ -273,19 +274,27 @@ def test_rio_read_rgb(resamlpling):

cfg = RasterLoadParams(
dtype="uint8",
dims=("band", "y", "x"),
dims=dims,
resampling=resamlpling,
)
gbox2 = gbox.zoom_to(237)

# whole image from 1/2 overview
with with_temp_tiff(xx, compress=None, overview_levels=[2, 4]) as uri:
src = RasterSource(uri, band=0)
src = RasterSource(
uri,
band=0,
meta=RasterBandMetadata(cfg.dtype, dims=cfg.dims),
)
for gb in [gbox, gbox2]:
expect_shape = (3, *gb.shape) if src.ydim == 1 else (*gb.shape, 3)
expect_shape_2 = (2, *gb.shape) if src.ydim == 1 else (*gb.shape, 2)

roi, pix = rio_read(src, cfg, gb)

assert len(roi) == 2
assert pix.ndim == 3
assert pix.shape == (3, *gb.shape)
assert pix.shape == expect_shape

# again but with dst=
_, pix2 = rio_read(src, cfg, gb, dst=pix)
Expand All @@ -295,7 +304,7 @@ def test_rio_read_rgb(resamlpling):
roi, pix = rio_read(src, cfg, gb, selection=np.s_[:2])
assert len(roi) == 2
assert pix.ndim == 3
assert pix.shape == (2, *gb.shape)
assert pix.shape == expect_shape_2


def test_reader_unhappy_paths():
Expand Down
7 changes: 7 additions & 0 deletions odc/loader/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,13 @@ def strip(self) -> "RasterSource":
driver_data=self.driver_data,
)

@property
def ydim(self) -> int:
"""Index of y dimension, typically 0."""
if self.meta is None:
return 0
return self.meta.ydim

def __dask_tokenize__(self):
return (self.uri, self.band, self.subdataset)

Expand Down

0 comments on commit a16cb7b

Please sign in to comment.