Skip to content

Commit

Permalink
sqme: driver
Browse files (browse the repository at this point in the history)
  • Loading branch information
Kirill888 committed Feb 19, 2024
1 parent 1a725a4 commit 5358900
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 101 deletions.
70 changes: 44 additions & 26 deletions odc/loader/_rio.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,7 +9,8 @@

import logging
import threading
from typing import Any, ContextManager, Dict, Optional, Tuple, Union
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Optional, Tuple, Union

import numpy as np
import rasterio
Expand Down Expand Up @@ -70,50 +71,67 @@
"GDAL_HTTP_RETRY_DELAY": "0.5",
}

# pylint: disable=too-few-public-methods

class RioDriver:

class RioReader:
"""
Protocol for readers.
Reader part of RIO driver.
"""

class Reader:
class LoaderState:
"""
Reader part of RIO driver.
Shared across all Readers for single ``.load``.
TODO: open file handle cache goes here
"""

# pylint: disable=too-few-public-methods
def __init__(self, is_dask: bool) -> None:
self.is_dask = is_dask

def finalise(self) -> None:
pass

def __init__(self, src: RasterSource, ctx: "RioReader.LoaderState") -> None:
self._src = src
self._ctx = ctx

def read(
self,
cfg: RasterLoadParams,
dst_geobox: GeoBox,
dst: Optional[np.ndarray] = None,
) -> Tuple[NormalizedROI, np.ndarray]:
return rio_read(self._src, cfg, dst_geobox, dst=dst)

def __init__(self, src: RasterSource, ctx: Any) -> None:
self._src = src
self._ctx = ctx

def read(
self,
cfg: RasterLoadParams,
dst_geobox: GeoBox,
dst: Optional[np.ndarray] = None,
) -> Tuple[NormalizedROI, np.ndarray]:
return rio_read(self._src, cfg, dst_geobox, dst=dst)
class RioDriver:
"""
Protocol for readers.
"""

def new_load(self, chunks: None | Dict[str, int] = None) -> Any:
return {"is_dask": chunks is not None}
def new_load(self, chunks: None | Dict[str, int] = None) -> RioReader.LoaderState:
return RioReader.LoaderState(chunks is not None)

def finalise_load(self, load_state: Any) -> Any:
return load_state
def finalise_load(self, load_state: RioReader.LoaderState) -> Any:
return load_state.finalise()

def capture_env(self) -> Dict[str, Any]:
return capture_rio_env()

def restore_env(self, env: Dict[str, Any], load_state: Any) -> ContextManager[Any]:
assert load_state is not None
return rio_env(**env)
@contextmanager
def restore_env(
self, env: Dict[str, Any], load_state: RioReader.LoaderState
) -> Iterator[RioReader.LoaderState]:
with rio_env(**env):
yield load_state

def open(
self,
src: RasterSource,
ctx: Any,
) -> "RioDriver.Reader":
return RioDriver.Reader(src, ctx)
ctx: RioReader.LoaderState,
) -> RioReader:
return RioReader(src, ctx)

@property
def md_parser(self) -> Optional[MDParser]:
Expand Down
143 changes: 68 additions & 75 deletions odc/loader/testing/fixtures.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,7 +11,7 @@
import tempfile
from collections import abc
from contextlib import contextmanager
from typing import Any, ContextManager, Dict, Generator, Optional, Tuple
from typing import Any, Dict, Generator, Iterator, Optional, Tuple

import numpy as np
import rasterio
Expand All @@ -28,6 +28,8 @@
RasterSource,
)

# pylint: disable=too-few-public-methods


@contextmanager
def with_temp_tiff(data: xr.DataArray, **cog_opts) -> Generator[str, None, None]:
Expand Down Expand Up @@ -111,87 +113,78 @@ def driver_data(self, md, band_key: BandKey) -> Any:
return self._driver_data


class FakeReaderDriver:
class FakeReader:
"""
Fake reader for testing.
"""

class Context:
class LoadState:
"""
EMIT Context manager.
Shared state for all readers for a given load.
"""

def __init__(
self,
parent: "FakeReaderDriver",
env: dict[str, Any],
load_state: Any,
self, group_md: RasterGroupMetadata, env: dict[str, Any], is_dask: bool
) -> None:
self._parent = parent
self.group_md = group_md
self.env = env
self.load_state = load_state
self.is_dask = is_dask

def __enter__(self):
return self._parent, self.env, self.load_state
def with_env(self, env: dict[str, Any]) -> "FakeReader.LoadState":
return FakeReader.LoadState(self.group_md, env, self.is_dask)

def __exit__(self, type, value, traceback):
# pylint: disable=unused-argument,redefined-builtin
pass
def __init__(self, src: RasterSource, load_state: "FakeReader.LoadState"):
self._src = src
self._load_state = load_state

class Reader:
"""
Fake reader for testing.
"""
def _extra_dims(self) -> Dict[str, int]:
md = self._load_state.group_md
return md.extra_dims or {
coord.dim: len(coord.values) for coord in md.extra_coords
}

def read(
self,
cfg: RasterLoadParams,
dst_geobox: GeoBox,
dst: Optional[np.ndarray] = None,
) -> Tuple[NormalizedROI, np.ndarray]:
meta = self._src.meta
assert meta is not None

extra_dims = self._extra_dims()
postfix_dims: Tuple[int, ...] = ()
if meta.dims is not None:
assert set(meta.dims[2:]).issubset(extra_dims)
postfix_dims = tuple(extra_dims[d] for d in meta.dims[2:])

ny, nx = dst_geobox.shape.yx
yx_roi = (slice(0, ny), slice(0, nx))
shape = (ny, nx, *postfix_dims)

src_pix: np.ndarray | None = self._src.driver_data
if src_pix is None:
src_pix = np.ones(shape, dtype=cfg.dtype)
else:
assert src_pix.shape == shape

assert postfix_dims == src_pix.shape[2:]

# pylint: disable=too-few-public-methods

def __init__(self, src: RasterSource, load_state: Dict[str, Any]):
self._src = src
self._group_md: RasterGroupMetadata = load_state["group_md"]
self._is_dask = load_state["is_dask"]

def _extra_dims(self) -> Dict[str, int]:
md = self._group_md
return md.extra_dims or {
coord.dim: len(coord.values) for coord in md.extra_coords
}

def read(
self,
cfg: RasterLoadParams,
dst_geobox: GeoBox,
dst: Optional[np.ndarray] = None,
) -> Tuple[NormalizedROI, np.ndarray]:
meta = self._src.meta
assert meta is not None

extra_dims = self._extra_dims()
postfix_dims: Tuple[int, ...] = ()
if meta.dims is not None:
assert set(meta.dims[2:]).issubset(extra_dims)
postfix_dims = tuple(extra_dims[d] for d in meta.dims[2:])

ny, nx = dst_geobox.shape.yx
yx_roi = (slice(0, ny), slice(0, nx))
shape = (ny, nx, *postfix_dims)

src_pix: np.ndarray | None = self._src.driver_data
if src_pix is None:
src_pix = np.ones(shape, dtype=cfg.dtype)
else:
assert src_pix.shape == shape

assert postfix_dims == src_pix.shape[2:]

if dst is None:
dst = np.zeros((ny, nx, *postfix_dims), dtype=cfg.dtype)
dst[:] = src_pix.astype(dst.dtype)
return yx_roi, dst

assert dst.shape == src_pix.shape
if dst is None:
dst = np.zeros((ny, nx, *postfix_dims), dtype=cfg.dtype)
dst[:] = src_pix.astype(dst.dtype)
return yx_roi, dst

return yx_roi, dst[yx_roi]
assert dst.shape == src_pix.shape
dst[:] = src_pix.astype(dst.dtype)

return yx_roi, dst[yx_roi]


class FakeReaderDriver:
"""
Fake reader for testing.
"""

def __init__(
self,
Expand All @@ -202,8 +195,8 @@ def __init__(
self._group_md = group_md
self._parser = parser or FakeMDPlugin(group_md, None)

def new_load(self, chunks: None | Dict[str, int] = None) -> Any:
return {"is_dask": chunks is not None, "group_md": self._group_md}
def new_load(self, chunks: None | Dict[str, int] = None) -> FakeReader.LoadState:
return FakeReader.LoadState(self._group_md, {}, chunks is not None)

def finalise_load(self, load_state: Any) -> Any:
assert "findalised" not in load_state
Expand All @@ -213,14 +206,14 @@ def finalise_load(self, load_state: Any) -> Any:
def capture_env(self) -> Dict[str, Any]:
return {}

def restore_env(self, env: Dict[str, Any], load_state: Any) -> ContextManager[Any]:
return self.Context(self, env, load_state)
@contextmanager
def restore_env(
self, env: Dict[str, Any], load_state: FakeReader.LoadState
) -> Iterator[FakeReader.LoadState]:
yield load_state.with_env(env)

def open(self, src: RasterSource, ctx: Any) -> "FakeReaderDriver.Reader":
_self, env, load_state = ctx
assert _self is self
assert env == {}
return FakeReaderDriver.Reader(src, load_state)
def open(self, src: RasterSource, ctx: FakeReader.LoadState) -> FakeReader:
return FakeReader(src, ctx)

@property
def md_parser(self) -> MDParser | None:
Expand Down

0 comments on commit 5358900

Please sign in to comment.