From ca69b1fbc2147bc1ae5e821cfe78de6dcd6a1c07 Mon Sep 17 00:00:00 2001 From: Kirill Kouzoubov Date: Mon, 20 May 2024 19:21:10 +1000 Subject: [PATCH] Refactor ReaderDriver API ReaderDriver API: - `.new_load` receives GeoBox now - `chunks=` has to be named Reader API: - making `dst=` a kwarg - adding extra parameter for selecting subset of bands to read. - use `(slice, slice)` for spatial subset of the returned data --- odc/loader/_builder.py | 9 +++++-- odc/loader/_rio.py | 44 +++++++++++++++++++++------------- odc/loader/test_reader.py | 5 ++-- odc/loader/testing/fixtures.py | 32 ++++++++++++++++++------- odc/loader/types.py | 20 +++++++++++----- 5 files changed, 75 insertions(+), 35 deletions(-) diff --git a/odc/loader/_builder.py b/odc/loader/_builder.py index 2acd8c6..cd12e90 100644 --- a/odc/loader/_builder.py +++ b/odc/loader/_builder.py @@ -125,6 +125,9 @@ def __init__( rdr: ReaderDriver, time_chunks: int = 1, ) -> None: + gbox = gbt.base + assert isinstance(gbox, GeoBox) + self.cfg = cfg self.template = template self.srcs = srcs @@ -134,7 +137,9 @@ def __init__( self.rdr = rdr self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, time_chunks) self.chunk_tyx = (time_chunks, *self.gbt.chunk_shape((0, 0)).yx) - self._load_state = rdr.new_load(dict(zip(["time", "y", "x"], self.chunk_tyx))) + self._load_state = rdr.new_load( + gbox, chunks=dict(zip(["time", "y", "x"], self.chunk_tyx)) + ) def build( self, @@ -546,7 +551,7 @@ def direct_chunked_load( ) ny, nx = gbt.shape.yx total_tasks = nt * nb * ny * nx - load_state = rdr.new_load() + load_state = rdr.new_load(gbox) def _do_one(task: LoadChunkTask) -> Tuple[str, int, int, int]: dst_slice = ds[task.band].data[task.dst_roi] diff --git a/odc/loader/_rio.py b/odc/loader/_rio.py index 1b7a032..6a604e4 100644 --- a/odc/loader/_rio.py +++ b/odc/loader/_rio.py @@ -10,7 +10,7 @@ import logging import threading from contextlib import contextmanager -from typing import Any, Dict, Iterator, Optional, Tuple, Union +from typing import Any, Dict, Iterator, Optional, Union import numpy as np import rasterio @@ -20,7 +20,7 @@ from odc.geo.converters import rio_geobox from odc.geo.geobox import GeoBox from odc.geo.overlap import ReprojectInfo, compute_reproject_roi -from odc.geo.roi import NormalizedROI, roi_is_empty, roi_shape, w_ +from odc.geo.roi import roi_is_empty, roi_shape, w_ from odc.geo.warp import resampling_s2rio from rasterio.session import AWSSession, Session @@ -32,7 +32,7 @@ resolve_src_nodata, same_nodata, ) -from .types import MDParser, RasterLoadParams, RasterSource +from .types import MDParser, RasterLoadParams, RasterSource, ReaderSubsetSelection log = logging.getLogger(__name__) @@ -86,7 +86,8 @@ class LoaderState: TODO: open file handle cache goes here """ - def __init__(self, is_dask: bool) -> None: + def __init__(self, geobox: GeoBox, is_dask: bool) -> None: + self.geobox = geobox self.is_dask = is_dask def finalise(self) -> None: @@ -100,9 +101,11 @@ def read( self, cfg: RasterLoadParams, dst_geobox: GeoBox, + *, dst: Optional[np.ndarray] = None, - ) -> Tuple[NormalizedROI, np.ndarray]: - return rio_read(self._src, cfg, dst_geobox, dst=dst) + selection: Optional[ReaderSubsetSelection] = None, + ) -> tuple[tuple[slice, slice], np.ndarray]: + return rio_read(self._src, cfg, dst_geobox, dst=dst, selection=selection) class RioDriver: @@ -110,8 +113,13 @@ class RioDriver: Protocol for readers. """ - def new_load(self, chunks: None | Dict[str, int] = None) -> RioReader.LoaderState: - return RioReader.LoaderState(chunks is not None) + def new_load( + self, + geobox: GeoBox, + *, + chunks: None | Dict[str, int] = None, + ) -> RioReader.LoaderState: + return RioReader.LoaderState(geobox, is_dask=chunks is not None) def finalise_load(self, load_state: RioReader.LoaderState) -> Any: return load_state.finalise() @@ -373,7 +381,7 @@ def _do_read( dst_geobox: GeoBox, rr: ReprojectInfo, dst: Optional[np.ndarray] = None, -) -> Tuple[NormalizedROI, np.ndarray]: +) -> tuple[tuple[slice, slice], np.ndarray]: resampling = resampling_s2rio(cfg.resampling) rdr = src.ds @@ -387,15 +395,16 @@ def _do_read( src_nodata0 = rdr.nodatavals[src.bidx - 1] src_nodata = resolve_src_nodata(src_nodata0, cfg) dst_nodata = resolve_dst_nodata(_dst.dtype, cfg, src_nodata) + roi_dst: tuple[slice, slice] = rr.roi_dst # type: ignore - if roi_is_empty(rr.roi_dst): - return (rr.roi_dst, _dst) + if roi_is_empty(roi_dst): + return (roi_dst, _dst) if roi_is_empty(rr.roi_src): # no overlap case if dst_nodata is not None: np.copyto(_dst, dst_nodata) - return (rr.roi_dst, _dst) + return (roi_dst, _dst) if rr.paste_ok and rr.read_shrink == 1: rdr.read(src.bidx, out=_dst, window=w_[rr.roi_src]) @@ -418,7 +427,7 @@ def _do_read( resampling=resampling, ) - return (rr.roi_dst, _dst) + return (roi_dst, _dst) def rio_read( @@ -426,7 +435,8 @@ def rio_read( cfg: RasterLoadParams, dst_geobox: GeoBox, dst: Optional[np.ndarray] = None, -) -> Tuple[NormalizedROI, np.ndarray]: + selection: Optional[ReaderSubsetSelection] = None, +) -> tuple[tuple[slice, slice], np.ndarray]: """ Internal read method. @@ -450,7 +460,7 @@ def rio_read( """ try: - return _rio_read(src, cfg, dst_geobox, dst) + return _rio_read(src, cfg, dst_geobox, dst, selection=selection) except ( rasterio.errors.RasterioIOError, rasterio.errors.RasterBlockError, @@ -491,10 +501,12 @@ def _rio_read( cfg: RasterLoadParams, dst_geobox: GeoBox, dst: Optional[np.ndarray] = None, -) -> Tuple[NormalizedROI, np.ndarray]: + selection: Optional[ReaderSubsetSelection] = None, +) -> tuple[tuple[slice, slice], np.ndarray]: # if resampling is `nearest` then ignore sub-pixel translation when deciding # whether we can just paste source into destination ttol = 0.9 if cfg.nearest else 0.05 + assert selection is None, "Band selection not implemented in rio_read" with rasterio.open(src.uri, "r", sharing=False) as rdr: assert isinstance(rdr, rasterio.DatasetReader) diff --git a/odc/loader/test_reader.py b/odc/loader/test_reader.py index f7a057b..ef1fad5 100644 --- a/odc/loader/test_reader.py +++ b/odc/loader/test_reader.py @@ -69,8 +69,9 @@ def test_pick_overiew(): def test_rio_reader_env(): + gbox = GeoBox.from_bbox((-180, -90, 180, 90), shape=(160, 320), tight=True) rdr = RioDriver() - load_state = rdr.new_load() + load_state = rdr.new_load(gbox) configure_rio(cloud_defaults=True, verbose=True) env = rdr.capture_env() @@ -119,7 +120,7 @@ def test_rio_read(): # Going via RioReader should be the same rdr_driver = RioDriver() - load_state = rdr_driver.new_load() + load_state = rdr_driver.new_load(gbox) with rdr_driver.restore_env(rdr_driver.capture_env(), load_state) as ctx: rdr = rdr_driver.open(src, ctx) _roi, _pix = rdr.read(cfg, gbox) diff --git a/odc/loader/testing/fixtures.py b/odc/loader/testing/fixtures.py index b671572..aaf5c03 100644 --- a/odc/loader/testing/fixtures.py +++ b/odc/loader/testing/fixtures.py @@ -11,13 +11,12 @@ import tempfile from collections import abc from contextlib import contextmanager -from typing import Any, Dict, Generator, Iterator, Optional, Tuple +from typing import Any, Dict, Generator, Iterator, Optional import numpy as np import rasterio import xarray as xr from odc.geo.geobox import GeoBox -from odc.geo.roi import NormalizedROI from odc.geo.xr import ODCExtensionDa from ..types import ( @@ -26,6 +25,7 @@ RasterGroupMetadata, RasterLoadParams, RasterSource, + ReaderSubsetSelection, ) # pylint: disable=too-few-public-methods @@ -124,15 +124,20 @@ class LoadState: """ def __init__( - self, group_md: RasterGroupMetadata, env: dict[str, Any], is_dask: bool + self, + geobox: GeoBox, + group_md: RasterGroupMetadata, + env: dict[str, Any], + is_dask: bool, ) -> None: + self.geobox = geobox self.group_md = group_md self.env = env self.is_dask = is_dask self.finalised = False def with_env(self, env: dict[str, Any]) -> "FakeReader.LoadState": - return FakeReader.LoadState(self.group_md, env, self.is_dask) + return FakeReader.LoadState(self.geobox, self.group_md, env, self.is_dask) def __init__(self, src: RasterSource, load_state: "FakeReader.LoadState"): self._src = src @@ -146,14 +151,18 @@ def read( self, cfg: RasterLoadParams, dst_geobox: GeoBox, + *, dst: Optional[np.ndarray] = None, - ) -> Tuple[NormalizedROI, np.ndarray]: + selection: Optional[ReaderSubsetSelection] = None, + ) -> tuple[tuple[slice, slice], np.ndarray]: meta = self._src.meta assert meta is not None + # TODO: handle selection + assert selection is None extra_dims = self._extra_dims() - prefix_dims: Tuple[int, ...] = () - postfix_dims: Tuple[int, ...] = () + prefix_dims: tuple[int, ...] = () + postfix_dims: tuple[int, ...] = () ydim = cfg.ydim if len(cfg.dims) > 2: @@ -198,8 +207,13 @@ def __init__( self._group_md = group_md self._parser = parser or FakeMDPlugin(group_md, None) - def new_load(self, chunks: None | Dict[str, int] = None) -> FakeReader.LoadState: - return FakeReader.LoadState(self._group_md, {}, chunks is not None) + def new_load( + self, + geobox: GeoBox, + *, + chunks: None | Dict[str, int] = None, + ) -> FakeReader.LoadState: + return FakeReader.LoadState(geobox, self._group_md, {}, chunks is not None) def finalise_load(self, load_state: FakeReader.LoadState) -> Any: assert load_state.finalised is False diff --git a/odc/loader/types.py b/odc/loader/types.py index 8689853..1e418d9 100644 --- a/odc/loader/types.py +++ b/odc/loader/types.py @@ -18,7 +18,6 @@ import numpy as np from odc.geo.geobox import GeoBox -from odc.geo.roi import NormalizedROI T = TypeVar("T") @@ -31,6 +30,8 @@ BandQuery = Optional[Union[str, Sequence[str]]] """One|All|Some bands""" +ReaderSubsetSelection = Any + @dataclass(eq=True, frozen=True) class RasterBandMetadata: @@ -346,7 +347,7 @@ class MDParser(Protocol): Protocol for metadata parsers. - Parse group level metadata - - data bands andn their expected type + - data bands and their expected type - extra dimensions and coordinates - Extract driver specific data """ @@ -366,8 +367,10 @@ def read( self, cfg: RasterLoadParams, dst_geobox: GeoBox, + *, dst: Optional[np.ndarray] = None, - ) -> Tuple[NormalizedROI, np.ndarray]: ... + selection: Optional[ReaderSubsetSelection] = None, + ) -> tuple[tuple[slice, slice], np.ndarray]: ... class ReaderDriver(Protocol): @@ -375,14 +378,19 @@ class ReaderDriver(Protocol): Protocol for reader drivers. """ - def new_load(self, chunks: None | Dict[str, int] = None) -> Any: ... + def new_load( + self, + geobox: GeoBox, + *, + chunks: None | Dict[str, int] = None, + ) -> Any: ... def finalise_load(self, load_state: Any) -> Any: ... - def capture_env(self) -> Dict[str, Any]: ... + def capture_env(self) -> dict[str, Any]: ... def restore_env( - self, env: Dict[str, Any], load_state: Any + self, env: dict[str, Any], load_state: Any ) -> ContextManager[Any]: ... def open(self, src: RasterSource, ctx: Any) -> RasterReader: ...