Skip to content

Commit

Permalink
sqme: xrmemreader
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Jun 21, 2024
1 parent 81a7e32 commit e62615b
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 31 deletions.
27 changes: 25 additions & 2 deletions odc/loader/test_memreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from odc.geo.geobox import GeoBox
from odc.geo.xr import ODCExtensionDa, rasterize

from odc.loader.testing.mem_reader import XrMemReader, XrMemReaderDriver
from odc.loader.types import RasterGroupMetadata, RasterLoadParams, RasterSource
from .testing.mem_reader import XrMemReader, XrMemReaderDriver, raster_group_md
from .types import FixedCoord, RasterGroupMetadata, RasterLoadParams, RasterSource

# pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison
# pylint: disable=too-many-locals,too-many-statements
Expand Down Expand Up @@ -126,3 +126,26 @@ def test_mem_reader() -> None:
loader = loaders["zz"]
roi, pix = loader.read(cfgs["zz"], gbox, selection=np.s_[:2])
assert pix.shape == (2, *gbox.shape.yx)


def test_raster_group_md():
rgm = raster_group_md(xr.Dataset())
assert rgm.bands == {}
assert rgm.aliases == {}
assert rgm.extra_dims == {}

coord = FixedCoord("band", ["r", "g", "b"], dim="band")

rgm = raster_group_md(
xr.Dataset(), base=RasterGroupMetadata({}, {}, {"band": 3}, [])
)
assert rgm.extra_dims == {"band": 3}
assert len(rgm.extra_coords) == 0

rgm = raster_group_md(
xr.Dataset(), base=RasterGroupMetadata({}, extra_coords=[coord])
)
assert rgm.extra_dims == {}
assert rgm.extra_dims_full() == {"band": 3}
assert len(rgm.extra_coords) == 1
assert rgm.extra_coords[0] == coord
135 changes: 106 additions & 29 deletions odc/loader/testing/mem_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

from __future__ import annotations

import json
from contextlib import contextmanager
from typing import Any, Iterator
from typing import Any, Iterator, Sequence

import numpy as np
import xarray as xr
from odc.geo.geobox import GeoBox
from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_reproject
from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_coords, xr_reproject

from ..types import (
BandKey,
Expand All @@ -25,6 +26,18 @@
)


def extract_zarr_spec(src: dict[str, Any]) -> dict[str, Any] | None:
if "zarr:metadata" in src:
# TODO: handle zarr:chunks for reference filesystem
zmd = {"zarr_consolidated_format": 1, "metadata": src["zarr:metadata"]}
elif "zarr_consolidated_format" in src:
zmd = src
else:
zmd = {"zarr_consolidated_format": 1, "metadata": src}

return {".zmetadata": json.dumps(zmd)}


class XrMDPlugin:
"""
Convert xarray.Dataset to RasterGroupMetadata.
Expand All @@ -35,35 +48,64 @@ class XrMDPlugin:
- Driver data is xarray.DataArray for each band
"""

def __init__(self, src: xr.Dataset | None = None) -> None:
def __init__(
self,
template: RasterGroupMetadata,
src: xr.Dataset | None = None,
) -> None:
self._template = template
self._src = src
self._md = (
raster_group_md(src)
if src is not None
else RasterGroupMetadata({}, {}, {}, [])
)

def _from_zarr_spec(
self,
spec_doc: dict[str, Any],
regen_coords: bool = False,
) -> xr.Dataset:
xx = xr.open_dataset(spec_doc, engine="zarr")
gbox = xx.odc.geobox
if gbox is not None and regen_coords:
# re-gen x,y coords from geobox
xx = xx.assign_coords(xr_coords(gbox))

return xx

def _resolve_src(self, md: Any, regen_coords: bool = False) -> xr.Dataset | None:
src = self._src

if isinstance(md, dict) and (spec_doc := extract_zarr_spec(md)) is not None:
src = self._from_zarr_spec(spec_doc, regen_coords=regen_coords)

if isinstance(md, xr.Dataset):
src = md

# TODO: support stac items and datacube datasets

return src

def extract(self, md: Any) -> RasterGroupMetadata:
"""Fixed description of src dataset."""
if isinstance(md, RasterGroupMetadata):
return md
if isinstance(md, xr.Dataset):
return raster_group_md(md)
return self._md

def driver_data(self, md: Any, band_key: BandKey) -> xr.DataArray:
if (src := self._resolve_src(md, regen_coords=False)) is not None:
return raster_group_md(src, base=self._template)

return self._template

def driver_data(self, md: Any, band_key: BandKey) -> xr.DataArray | None:
"""
Extract driver specific data for a given band.
"""
assert md is not None
name, _ = band_key
if isinstance(md, xr.Dataset):
return md[name]

if isinstance(md, xr.DataArray):
return md

assert self._src is not None
return self._src[name]
if (src := self._resolve_src(md, regen_coords=True)) is not None:
if (aa := src.data_vars.get(name)) is not None:
return aa

return None


class Context:
Expand Down Expand Up @@ -131,8 +173,17 @@ class XrMemReaderDriver:

Reader = XrMemReader

def __init__(self, src: xr.Dataset | None = None) -> None:
def __init__(
self,
src: xr.Dataset | None = None,
template: RasterGroupMetadata | None = None,
) -> None:
if src is not None and template is None:
template = raster_group_md(src)
if template is None:
template = RasterGroupMetadata({}, {}, {}, [])
self.src = src
self.template = template

def new_load(
self,
Expand All @@ -159,7 +210,7 @@ def open(self, src: RasterSource, ctx: Context) -> XrMemReader:

@property
def md_parser(self) -> MDParser:
return XrMDPlugin(self.src)
return XrMDPlugin(self.template, src=self.src)

@property
def dask_reader(self) -> DaskRasterReader | None:
Expand Down Expand Up @@ -187,25 +238,46 @@ def band_info(xx: xr.DataArray) -> RasterBandMetadata:
)


def raster_group_md(src: xr.Dataset) -> RasterGroupMetadata:
def raster_group_md(
src: xr.Dataset,
*,
base: RasterGroupMetadata | None = None,
aliases: dict[str, list[BandKey]] | None = None,
extra_coords: Sequence[FixedCoord] = (),
extra_dims: dict[str, int] | None = None,
) -> RasterGroupMetadata:
oo: ODCExtensionDs = src.odc
sdims = oo.spatial_dims or ("y", "x")

bands: dict[BandKey, RasterBandMetadata] = {
(str(k), 1): band_info(v) for k, v in src.data_vars.items() if v.ndim >= 2
}
if base is None:
base = RasterGroupMetadata(
bands={},
aliases=aliases or {},
extra_coords=extra_coords,
extra_dims=extra_dims or {},
)

bands = base.bands.copy()
bands.update(
{(str(k), 1): band_info(v) for k, v in src.data_vars.items() if v.ndim >= 2}
)

edims = base.extra_dims.copy()
edims.update({str(name): sz for name, sz in src.sizes.items() if name not in sdims})

extra_dims: dict[str, int] = {
str(name): sz for name, sz in src.sizes.items() if name not in sdims
}
aliases: dict[str, list[BandKey]] = base.aliases.copy()

aliases: dict[str, list[BandKey]] = {}
extra_coords: list[FixedCoord] = list(base.extra_coords)
supplied_coords = set(coord.name for coord in extra_coords)

extra_coords: list[FixedCoord] = []
for coord in src.coords.values():
if len(coord.dims) != 1 or coord.dims[0] in sdims:
# Only 1-d non-spatial coords
continue

if coord.name in supplied_coords:
continue

extra_coords.append(
FixedCoord(
coord.name,
Expand All @@ -215,4 +287,9 @@ def raster_group_md(src: xr.Dataset) -> RasterGroupMetadata:
)
)

return RasterGroupMetadata(bands, aliases, extra_dims, extra_coords)
return RasterGroupMetadata(
bands=bands,
aliases=aliases,
extra_dims=edims,
extra_coords=extra_coords,
)

0 comments on commit e62615b

Please sign in to comment.