Skip to content

Commit

Permalink
refactor: mem reader
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Jun 21, 2024
1 parent 634e590 commit 62295b7
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 28 deletions.
27 changes: 25 additions & 2 deletions odc/loader/test_memreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from odc.geo.geobox import GeoBox
from odc.geo.xr import ODCExtensionDa, rasterize

from odc.loader.testing.mem_reader import XrMemReader, XrMemReaderDriver
from odc.loader.types import RasterGroupMetadata, RasterLoadParams, RasterSource
from .testing.mem_reader import XrMemReader, XrMemReaderDriver, raster_group_md
from .types import FixedCoord, RasterGroupMetadata, RasterLoadParams, RasterSource

# pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison
# pylint: disable=too-many-locals,too-many-statements
Expand Down Expand Up @@ -126,3 +126,26 @@ def test_mem_reader() -> None:
loader = loaders["zz"]
roi, pix = loader.read(cfgs["zz"], gbox, selection=np.s_[:2])
assert pix.shape == (2, *gbox.shape.yx)


def test_raster_group_md():
rgm = raster_group_md(xr.Dataset())
assert rgm.bands == {}
assert rgm.aliases == {}
assert rgm.extra_dims == {}

coord = FixedCoord("band", ["r", "g", "b"], dim="band")

rgm = raster_group_md(
xr.Dataset(), base=RasterGroupMetadata({}, {}, {"band": 3}, [])
)
assert rgm.extra_dims == {"band": 3}
assert len(rgm.extra_coords) == 0

rgm = raster_group_md(
xr.Dataset(), base=RasterGroupMetadata({}, extra_coords=[coord])
)
assert rgm.extra_dims == {}
assert rgm.extra_dims_full() == {"band": 3}
assert len(rgm.extra_coords) == 1
assert rgm.extra_coords[0] == coord
141 changes: 115 additions & 26 deletions odc/loader/testing/mem_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@

from __future__ import annotations

import json
from contextlib import contextmanager
from typing import Any, Iterator
from typing import Any, Iterator, Sequence

import fsspec
import numpy as np
import xarray as xr
from odc.geo.geobox import GeoBox
from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_reproject
from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_coords, xr_reproject

from ..types import (
BandKey,
Expand All @@ -25,6 +27,18 @@
)


def extract_zarr_spec(src: dict[str, Any]) -> dict[str, Any] | None:
if "zarr:metadata" in src:
# TODO: handle zarr:chunks for reference filesystem
zmd = {"zarr_consolidated_format": 1, "metadata": src["zarr:metadata"]}
elif "zarr_consolidated_format" in src:
zmd = src
else:
zmd = {"zarr_consolidated_format": 1, "metadata": src}

return {".zmetadata": json.dumps(zmd)}


class XrMDPlugin:
"""
Convert xarray.Dataset to RasterGroupMetadata.
Expand All @@ -35,22 +49,65 @@ class XrMDPlugin:
- Driver data is xarray.DataArray for each band
"""

def __init__(self, src: xr.Dataset) -> None:
def __init__(
self,
template: RasterGroupMetadata,
src: xr.Dataset | None = None,
) -> None:
self._template = template
self._src = src
self._md = raster_group_md(src)

def _from_zarr_spec(
self,
spec_doc: dict[str, Any],
regen_coords: bool = False,
) -> xr.Dataset:
rfs = fsspec.filesystem("reference", fo=spec_doc)
xx = xr.open_dataset(rfs.get_mapper(""), engine="zarr")
gbox = xx.odc.geobox
if gbox is not None and regen_coords:
# re-gen x,y coords from geobox
xx = xx.assign_coords(xr_coords(gbox))

return xx

def _resolve_src(self, md: Any, regen_coords: bool = False) -> xr.Dataset | None:
src = self._src

if isinstance(md, dict) and (spec_doc := extract_zarr_spec(md)) is not None:
src = self._from_zarr_spec(spec_doc, regen_coords=regen_coords)

if isinstance(md, xr.Dataset):
src = md

# TODO: support stac items and datacube datasets

return src

def extract(self, md: Any) -> RasterGroupMetadata:
"""Fixed description of src dataset."""
assert md is not None
return self._md
if isinstance(md, RasterGroupMetadata):
return md

if (src := self._resolve_src(md, regen_coords=False)) is not None:
return raster_group_md(src, base=self._template)

return self._template

def driver_data(self, md: Any, band_key: BandKey) -> xr.DataArray:
def driver_data(self, md: Any, band_key: BandKey) -> xr.DataArray | None:
"""
Extract driver specific data for a given band.
"""
assert md is not None
name, _ = band_key
return self._src[name]

if isinstance(md, xr.DataArray):
return md

if (src := self._resolve_src(md, regen_coords=True)) is not None:
if (aa := src.data_vars.get(name)) is not None:
return aa

return None


class Context:
Expand All @@ -60,17 +117,15 @@ class Context:

def __init__(
self,
src: xr.Dataset,
geobox: GeoBox,
chunks: None | dict[str, int],
) -> None:
self.src = src
self.geobox = geobox
self.chunks = chunks

def with_env(self, env: dict[str, Any]) -> "Context":
assert isinstance(env, dict)
return Context(self.src, self.geobox, self.chunks)
return Context(self.geobox, self.chunks)


class XrMemReader:
Expand All @@ -81,7 +136,6 @@ class XrMemReader:
# pylint: disable=too-few-public-methods

def __init__(self, src: RasterSource, ctx: Context) -> None:
self._src = src
self._xx: xr.DataArray = src.driver_data
self._ctx = ctx

Expand Down Expand Up @@ -121,16 +175,25 @@ class XrMemReaderDriver:

Reader = XrMemReader

def __init__(self, src: xr.Dataset) -> None:
def __init__(
self,
src: xr.Dataset | None = None,
template: RasterGroupMetadata | None = None,
) -> None:
if src is not None and template is None:
template = raster_group_md(src)
if template is None:
template = RasterGroupMetadata({}, {}, {}, [])
self.src = src
self.template = template

def new_load(
self,
geobox: GeoBox,
*,
chunks: None | dict[str, int] = None,
) -> Context:
return Context(self.src, geobox, chunks)
return Context(geobox, chunks)

def finalise_load(self, load_state: Context) -> Context:
return load_state
Expand All @@ -149,7 +212,7 @@ def open(self, src: RasterSource, ctx: Context) -> XrMemReader:

@property
def md_parser(self) -> MDParser:
return XrMDPlugin(self.src)
return XrMDPlugin(self.template, src=self.src)

@property
def dask_reader(self) -> DaskRasterReader | None:
Expand Down Expand Up @@ -177,25 +240,46 @@ def band_info(xx: xr.DataArray) -> RasterBandMetadata:
)


def raster_group_md(src: xr.Dataset) -> RasterGroupMetadata:
def raster_group_md(
src: xr.Dataset,
*,
base: RasterGroupMetadata | None = None,
aliases: dict[str, list[BandKey]] | None = None,
extra_coords: Sequence[FixedCoord] = (),
extra_dims: dict[str, int] | None = None,
) -> RasterGroupMetadata:
oo: ODCExtensionDs = src.odc
sdims = oo.spatial_dims or ("y", "x")

bands: dict[BandKey, RasterBandMetadata] = {
(str(k), 1): band_info(v) for k, v in src.data_vars.items() if v.ndim >= 2
}
if base is None:
base = RasterGroupMetadata(
bands={},
aliases=aliases or {},
extra_coords=extra_coords,
extra_dims=extra_dims or {},
)

bands = base.bands.copy()
bands.update(
{(str(k), 1): band_info(v) for k, v in src.data_vars.items() if v.ndim >= 2}
)

edims = base.extra_dims.copy()
edims.update({str(name): sz for name, sz in src.sizes.items() if name not in sdims})

extra_dims: dict[str, int] = {
str(name): sz for name, sz in src.sizes.items() if name not in sdims
}
aliases: dict[str, list[BandKey]] = base.aliases.copy()

aliases: dict[str, list[BandKey]] = {}
extra_coords: list[FixedCoord] = list(base.extra_coords)
supplied_coords = set(coord.name for coord in extra_coords)

extra_coords: list[FixedCoord] = []
for coord in src.coords.values():
if len(coord.dims) != 1 or coord.dims[0] in sdims:
# Only 1-d non-spatial coords
continue

if coord.name in supplied_coords:
continue

extra_coords.append(
FixedCoord(
coord.name,
Expand All @@ -205,4 +289,9 @@ def raster_group_md(src: xr.Dataset) -> RasterGroupMetadata:
)
)

return RasterGroupMetadata(bands, aliases, extra_dims, extra_coords)
return RasterGroupMetadata(
bands=bands,
aliases=aliases,
extra_dims=edims,
extra_coords=extra_coords,
)

0 comments on commit 62295b7

Please sign in to comment.