From 552ce70336a678184b21aee0814023a6a1aa96a1 Mon Sep 17 00:00:00 2001 From: Kirill Kouzoubov Date: Mon, 3 Jun 2024 13:10:16 +1000 Subject: [PATCH] Adding DaskRasterReader protocol Optional feature of the ReaderDriver. --- odc/loader/_rio.py | 14 ++++++++++++-- odc/loader/testing/fixtures.py | 5 +++++ odc/loader/testing/mem_reader.py | 5 +++++ odc/loader/types.py | 27 +++++++++++++++++++++++++++ 4 files changed, 49 insertions(+), 2 deletions(-) diff --git a/odc/loader/_rio.py b/odc/loader/_rio.py index 436cd04..29ce31e 100644 --- a/odc/loader/_rio.py +++ b/odc/loader/_rio.py @@ -33,7 +33,13 @@ resolve_src_nodata, same_nodata, ) -from .types import MDParser, RasterLoadParams, RasterSource, ReaderSubsetSelection +from .types import ( + DaskRasterReader, + MDParser, + RasterLoadParams, + RasterSource, + ReaderSubsetSelection, +) log = logging.getLogger(__name__) @@ -143,7 +149,11 @@ def open( return RioReader(src, ctx) @property - def md_parser(self) -> Optional[MDParser]: + def md_parser(self) -> MDParser | None: + return None + + @property + def dask_reader(self) -> DaskRasterReader | None: return None diff --git a/odc/loader/testing/fixtures.py b/odc/loader/testing/fixtures.py index e904bfb..69c495a 100644 --- a/odc/loader/testing/fixtures.py +++ b/odc/loader/testing/fixtures.py @@ -22,6 +22,7 @@ from .._reader import expand_selection from ..types import ( BandKey, + DaskRasterReader, MDParser, RasterGroupMetadata, RasterLoadParams, @@ -239,3 +240,7 @@ def open(self, src: RasterSource, ctx: FakeReader.LoadState) -> FakeReader: @property def md_parser(self) -> MDParser | None: return self._parser + + @property + def dask_reader(self) -> DaskRasterReader | None: + return None diff --git a/odc/loader/testing/mem_reader.py b/odc/loader/testing/mem_reader.py index 0ad0641..8e8e5c2 100644 --- a/odc/loader/testing/mem_reader.py +++ b/odc/loader/testing/mem_reader.py @@ -14,6 +14,7 @@ from ..types import ( BandKey, + DaskRasterReader, FixedCoord, MDParser, 
RasterBandMetadata, @@ -150,6 +151,10 @@ def open(self, src: RasterSource, ctx: Context) -> XrMemReader: def md_parser(self) -> MDParser: return XrMDPlugin(self.src) + @property + def dask_reader(self) -> DaskRasterReader | None: + return None + def band_info(xx: xr.DataArray) -> RasterBandMetadata: """ diff --git a/odc/loader/types.py b/odc/loader/types.py index 551e6fe..5212852 100644 --- a/odc/loader/types.py +++ b/odc/loader/types.py @@ -445,6 +445,30 @@ def read( ) -> tuple[tuple[slice, slice], np.ndarray]: ... +class DaskRasterReader(Protocol): + """ + Protocol for raster readers that produce Dask sub-graphs. + + ``.read`` method should return a Dask future evaluating to a numpy array of + pixels for a given geobox. Alternatively, the Dask future may cover only the + subset of the geobox overlapping with the source, in which case it should + evaluate to a tuple: ``(yx_slice, pixels)``, such that + ``dst_geobox[yx_slice].shape == pixels.shape[ydim:ydim+2]``. + """ + + # pylint: disable=too-few-public-methods + + def read( + self, + cfg: RasterLoadParams, + dst_geobox: GeoBox, + *, + selection: Optional[ReaderSubsetSelection] = None, + ) -> Any: ... + + def open(self, src: RasterSource, ctx: Any) -> "DaskRasterReader": ... + + class ReaderDriver(Protocol): """ Protocol for reader drivers. @@ -470,6 +494,9 @@ def open(self, src: RasterSource, ctx: Any) -> RasterReader: ... @property def md_parser(self) -> MDParser | None: ... + @property + def dask_reader(self) -> DaskRasterReader | None: ... + ReaderDriverSpec = Union[str, ReaderDriver]