Skip to content

Commit

Permalink
Refactor ReaderDriver API
Browse files Browse the repository at this point in the history
ReaderDriver API:
- `.new_load` receives GeoBox now
- `chunks=` has to be named

Reader API:
- making `dst=` a kwarg
- adding extra parameter for selecting subset of
bands to read.
- use `(slice, slice)` for spatial subset of the
  returned data
  • Loading branch information
Kirill888 committed May 23, 2024
1 parent 713e6f2 commit ca69b1f
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 35 deletions.
9 changes: 7 additions & 2 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def __init__(
rdr: ReaderDriver,
time_chunks: int = 1,
) -> None:
gbox = gbt.base
assert isinstance(gbox, GeoBox)

self.cfg = cfg
self.template = template
self.srcs = srcs
Expand All @@ -134,7 +137,9 @@ def __init__(
self.rdr = rdr
self._tk = tokenize(srcs, cfg, gbt, tyx_bins, env, time_chunks)
self.chunk_tyx = (time_chunks, *self.gbt.chunk_shape((0, 0)).yx)
self._load_state = rdr.new_load(dict(zip(["time", "y", "x"], self.chunk_tyx)))
self._load_state = rdr.new_load(
gbox, chunks=dict(zip(["time", "y", "x"], self.chunk_tyx))
)

def build(
self,
Expand Down Expand Up @@ -546,7 +551,7 @@ def direct_chunked_load(
)
ny, nx = gbt.shape.yx
total_tasks = nt * nb * ny * nx
load_state = rdr.new_load()
load_state = rdr.new_load(gbox)

def _do_one(task: LoadChunkTask) -> Tuple[str, int, int, int]:
dst_slice = ds[task.band].data[task.dst_roi]
Expand Down
44 changes: 28 additions & 16 deletions odc/loader/_rio.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import logging
import threading
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Optional, Tuple, Union
from typing import Any, Dict, Iterator, Optional, Union

import numpy as np
import rasterio
Expand All @@ -20,7 +20,7 @@
from odc.geo.converters import rio_geobox
from odc.geo.geobox import GeoBox
from odc.geo.overlap import ReprojectInfo, compute_reproject_roi
from odc.geo.roi import NormalizedROI, roi_is_empty, roi_shape, w_
from odc.geo.roi import roi_is_empty, roi_shape, w_
from odc.geo.warp import resampling_s2rio
from rasterio.session import AWSSession, Session

Expand All @@ -32,7 +32,7 @@
resolve_src_nodata,
same_nodata,
)
from .types import MDParser, RasterLoadParams, RasterSource
from .types import MDParser, RasterLoadParams, RasterSource, ReaderSubsetSelection

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,7 +86,8 @@ class LoaderState:
TODO: open file handle cache goes here
"""

def __init__(self, is_dask: bool) -> None:
def __init__(self, geobox: GeoBox, is_dask: bool) -> None:
self.geobox = geobox
self.is_dask = is_dask

def finalise(self) -> None:
Expand All @@ -100,18 +101,25 @@ def read(
self,
cfg: RasterLoadParams,
dst_geobox: GeoBox,
*,
dst: Optional[np.ndarray] = None,
) -> Tuple[NormalizedROI, np.ndarray]:
return rio_read(self._src, cfg, dst_geobox, dst=dst)
selection: Optional[ReaderSubsetSelection] = None,
) -> tuple[tuple[slice, slice], np.ndarray]:
return rio_read(self._src, cfg, dst_geobox, dst=dst, selection=selection)


class RioDriver:
"""
Protocol for readers.
"""

def new_load(self, chunks: None | Dict[str, int] = None) -> RioReader.LoaderState:
return RioReader.LoaderState(chunks is not None)
def new_load(
self,
geobox: GeoBox,
*,
chunks: None | Dict[str, int] = None,
) -> RioReader.LoaderState:
return RioReader.LoaderState(geobox, is_dask=chunks is not None)

def finalise_load(self, load_state: RioReader.LoaderState) -> Any:
return load_state.finalise()
Expand Down Expand Up @@ -373,7 +381,7 @@ def _do_read(
dst_geobox: GeoBox,
rr: ReprojectInfo,
dst: Optional[np.ndarray] = None,
) -> Tuple[NormalizedROI, np.ndarray]:
) -> tuple[tuple[slice, slice], np.ndarray]:
resampling = resampling_s2rio(cfg.resampling)
rdr = src.ds

Expand All @@ -387,15 +395,16 @@ def _do_read(
src_nodata0 = rdr.nodatavals[src.bidx - 1]
src_nodata = resolve_src_nodata(src_nodata0, cfg)
dst_nodata = resolve_dst_nodata(_dst.dtype, cfg, src_nodata)
roi_dst: tuple[slice, slice] = rr.roi_dst # type: ignore

if roi_is_empty(rr.roi_dst):
return (rr.roi_dst, _dst)
if roi_is_empty(roi_dst):
return (roi_dst, _dst)

if roi_is_empty(rr.roi_src):
# no overlap case
if dst_nodata is not None:
np.copyto(_dst, dst_nodata)
return (rr.roi_dst, _dst)
return (roi_dst, _dst)

if rr.paste_ok and rr.read_shrink == 1:
rdr.read(src.bidx, out=_dst, window=w_[rr.roi_src])
Expand All @@ -418,15 +427,16 @@ def _do_read(
resampling=resampling,
)

return (rr.roi_dst, _dst)
return (roi_dst, _dst)


def rio_read(
src: RasterSource,
cfg: RasterLoadParams,
dst_geobox: GeoBox,
dst: Optional[np.ndarray] = None,
) -> Tuple[NormalizedROI, np.ndarray]:
selection: Optional[ReaderSubsetSelection] = None,
) -> tuple[tuple[slice, slice], np.ndarray]:
"""
Internal read method.
Expand All @@ -450,7 +460,7 @@ def rio_read(
"""

try:
return _rio_read(src, cfg, dst_geobox, dst)
return _rio_read(src, cfg, dst_geobox, dst, selection=selection)
except (
rasterio.errors.RasterioIOError,
rasterio.errors.RasterBlockError,
Expand Down Expand Up @@ -491,10 +501,12 @@ def _rio_read(
cfg: RasterLoadParams,
dst_geobox: GeoBox,
dst: Optional[np.ndarray] = None,
) -> Tuple[NormalizedROI, np.ndarray]:
selection: Optional[ReaderSubsetSelection] = None,
) -> tuple[tuple[slice, slice], np.ndarray]:
# if resampling is `nearest` then ignore sub-pixel translation when deciding
# whether we can just paste source into destination
ttol = 0.9 if cfg.nearest else 0.05
assert selection is None, "Band selection not implemented in rio_read"

with rasterio.open(src.uri, "r", sharing=False) as rdr:
assert isinstance(rdr, rasterio.DatasetReader)
Expand Down
5 changes: 3 additions & 2 deletions odc/loader/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ def test_pick_overiew():


def test_rio_reader_env():
gbox = GeoBox.from_bbox((-180, -90, 180, 90), shape=(160, 320), tight=True)
rdr = RioDriver()
load_state = rdr.new_load()
load_state = rdr.new_load(gbox)

configure_rio(cloud_defaults=True, verbose=True)
env = rdr.capture_env()
Expand Down Expand Up @@ -119,7 +120,7 @@ def test_rio_read():

# Going via RioReader should be the same
rdr_driver = RioDriver()
load_state = rdr_driver.new_load()
load_state = rdr_driver.new_load(gbox)
with rdr_driver.restore_env(rdr_driver.capture_env(), load_state) as ctx:
rdr = rdr_driver.open(src, ctx)
_roi, _pix = rdr.read(cfg, gbox)
Expand Down
32 changes: 23 additions & 9 deletions odc/loader/testing/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
import tempfile
from collections import abc
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, Optional, Tuple
from typing import Any, Dict, Generator, Iterator, Optional

import numpy as np
import rasterio
import xarray as xr
from odc.geo.geobox import GeoBox
from odc.geo.roi import NormalizedROI
from odc.geo.xr import ODCExtensionDa

from ..types import (
Expand All @@ -26,6 +25,7 @@
RasterGroupMetadata,
RasterLoadParams,
RasterSource,
ReaderSubsetSelection,
)

# pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -124,15 +124,20 @@ class LoadState:
"""

def __init__(
self, group_md: RasterGroupMetadata, env: dict[str, Any], is_dask: bool
self,
geobox: GeoBox,
group_md: RasterGroupMetadata,
env: dict[str, Any],
is_dask: bool,
) -> None:
self.geobox = geobox
self.group_md = group_md
self.env = env
self.is_dask = is_dask
self.finalised = False

def with_env(self, env: dict[str, Any]) -> "FakeReader.LoadState":
return FakeReader.LoadState(self.group_md, env, self.is_dask)
return FakeReader.LoadState(self.geobox, self.group_md, env, self.is_dask)

def __init__(self, src: RasterSource, load_state: "FakeReader.LoadState"):
self._src = src
Expand All @@ -146,14 +151,18 @@ def read(
self,
cfg: RasterLoadParams,
dst_geobox: GeoBox,
*,
dst: Optional[np.ndarray] = None,
) -> Tuple[NormalizedROI, np.ndarray]:
selection: Optional[ReaderSubsetSelection] = None,
) -> tuple[tuple[slice, slice], np.ndarray]:
meta = self._src.meta
assert meta is not None
# TODO: handle selection
assert selection is None

extra_dims = self._extra_dims()
prefix_dims: Tuple[int, ...] = ()
postfix_dims: Tuple[int, ...] = ()
prefix_dims: tuple[int, ...] = ()
postfix_dims: tuple[int, ...] = ()
ydim = cfg.ydim

if len(cfg.dims) > 2:
Expand Down Expand Up @@ -198,8 +207,13 @@ def __init__(
self._group_md = group_md
self._parser = parser or FakeMDPlugin(group_md, None)

def new_load(self, chunks: None | Dict[str, int] = None) -> FakeReader.LoadState:
return FakeReader.LoadState(self._group_md, {}, chunks is not None)
def new_load(
self,
geobox: GeoBox,
*,
chunks: None | Dict[str, int] = None,
) -> FakeReader.LoadState:
return FakeReader.LoadState(geobox, self._group_md, {}, chunks is not None)

def finalise_load(self, load_state: FakeReader.LoadState) -> Any:
assert load_state.finalised is False
Expand Down
20 changes: 14 additions & 6 deletions odc/loader/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import numpy as np
from odc.geo.geobox import GeoBox
from odc.geo.roi import NormalizedROI

T = TypeVar("T")

Expand All @@ -31,6 +30,8 @@
BandQuery = Optional[Union[str, Sequence[str]]]
"""One|All|Some bands"""

ReaderSubsetSelection = Any


@dataclass(eq=True, frozen=True)
class RasterBandMetadata:
Expand Down Expand Up @@ -346,7 +347,7 @@ class MDParser(Protocol):
Protocol for metadata parsers.
- Parse group level metadata
- data bands andn their expected type
- data bands and their expected type
- extra dimensions and coordinates
- Extract driver specific data
"""
Expand All @@ -366,23 +367,30 @@ def read(
self,
cfg: RasterLoadParams,
dst_geobox: GeoBox,
*,
dst: Optional[np.ndarray] = None,
) -> Tuple[NormalizedROI, np.ndarray]: ...
selection: Optional[ReaderSubsetSelection] = None,
) -> tuple[tuple[slice, slice], np.ndarray]: ...


class ReaderDriver(Protocol):
"""
Protocol for reader drivers.
"""

def new_load(self, chunks: None | Dict[str, int] = None) -> Any: ...
def new_load(
self,
geobox: GeoBox,
*,
chunks: None | Dict[str, int] = None,
) -> Any: ...

def finalise_load(self, load_state: Any) -> Any: ...

def capture_env(self) -> Dict[str, Any]: ...
def capture_env(self) -> dict[str, Any]: ...

def restore_env(
self, env: Dict[str, Any], load_state: Any
self, env: dict[str, Any], load_state: Any
) -> ContextManager[Any]: ...

def open(self, src: RasterSource, ctx: Any) -> RasterReader: ...
Expand Down

0 comments on commit ca69b1f

Please sign in to comment.