Skip to content

Commit

Permalink
Support extra dims/coords in stac load config
Browse files Browse the repository at this point in the history
- change from_dict method signature
- if "assets" is present at top level, assume that
  it's a single collection config applicable to
  all collections
  • Loading branch information
Kirill888 committed May 27, 2024
1 parent 71a73c8 commit b4f29ac
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 7 deletions.
6 changes: 4 additions & 2 deletions odc/stac/_mdtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def __init__(
if cfg is None:
cfg = {}

self._cfg = MDParseConfig.from_dict(collection_id, cfg)
self._cfg = MDParseConfig.from_dict(cfg, collection_id)
self.check_proj: bool = not self._cfg.ignore_proj
self.has_proj: Optional[bool] = None
self.collection_id = collection_id
Expand Down Expand Up @@ -547,7 +547,9 @@ def _bootstrap(self, item: pystac.item.Item):

for alias, bkey in self._cfg.aliases.items():
aliases.setdefault(alias, []).insert(0, bkey)
md = RasterGroupMetadata(bands, aliases)
md = RasterGroupMetadata(
bands, aliases, self._cfg.extra_dims, self._cfg.extra_coords
)

# We assume that grouping of data bands into grids is consistent across
# entire collection, so we compute it once and keep it
Expand Down
37 changes: 32 additions & 5 deletions odc/stac/model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Metadata and data loading model classes."""

from __future__ import annotations

import datetime as dt
import math
from copy import copy
from dataclasses import astuple, dataclass, field, replace
from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple

from odc.geo import CRS, Geometry, MaybeCRS
from odc.geo.geobox import GeoBox
Expand All @@ -14,6 +16,7 @@
BandIdentifier,
BandKey,
BandQuery,
FixedCoord,
RasterBandMetadata,
RasterGroupMetadata,
RasterSource,
Expand Down Expand Up @@ -412,27 +415,51 @@ def __dask_tokenize__(self):
class MDParseConfig:
"""Item parsing config."""

band_defaults: RasterBandMetadata = field(default_factory=RasterBandMetadata)
band_defaults: RasterBandMetadata = field(
default_factory=lambda: norm_band_metadata({})
)
band_cfg: Dict[str, RasterBandMetadata] = field(default_factory=dict)
aliases: Dict[str, BandKey] = field(default_factory=dict)
ignore_proj: bool = False
extra_dims: Dict[str, int] = field(default_factory=dict)
extra_coords: Sequence[FixedCoord] = ()

@staticmethod
def from_dict(collection_id: str, cfg=Dict[str, Any]) -> "MDParseConfig":
_cfg = copy(cfg.get("*", {}))
_cfg.update(cfg.get(collection_id, {}))
def from_dict(
cfg: Dict[str, Any], collection_id: str | None = None
) -> "MDParseConfig":
if collection_id is not None:
if "assets" in cfg: # Assume it's a single collection config
_cfg = copy(cfg)
else:
_cfg = copy(cfg.get("*", {}))
_cfg.update(cfg.get(collection_id, {}))
else:
_cfg = copy(cfg)

band_defaults, band_cfg = _norm_band_cfg(_cfg.get("assets", {}))

aliases = {
alias: ((band, 1) if isinstance(band, str) else band)
for alias, band in _cfg.get("aliases", {}).items()
}
ignore_proj: bool = _cfg.get("ignore_proj", False)
extra_dims: Dict[str, int] = _cfg.get("dims", {})
extra_coords: list[FixedCoord] = []
cc: dict[str, list[Any]] = _cfg.get("coords", {})
assert isinstance(cc, dict)

for name, val in cc.items():
assert isinstance(val, list)
extra_coords.append(FixedCoord(name, val))

return MDParseConfig(
band_defaults=band_defaults,
band_cfg=band_cfg,
ignore_proj=ignore_proj,
aliases=aliases,
extra_dims=extra_dims,
extra_coords=tuple(extra_coords),
)


Expand Down
23 changes: 23 additions & 0 deletions tests/test_mdtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from odc.loader.testing.fixtures import FakeMDPlugin
from odc.loader.types import FixedCoord, RasterBandMetadata, RasterGroupMetadata
from odc.stac._mdtools import (
MDParseConfig,
_auto_load_params,
_most_common_gbox,
_normalize_geometry,
Expand All @@ -37,6 +38,28 @@
GBOX = GeoBox.from_bbox((-20, -10, 20, 10), "epsg:3857", shape=(200, 400))


def test_mdparse_config():
assert MDParseConfig() == MDParseConfig()
assert MDParseConfig.from_dict({}) == MDParseConfig()
assert MDParseConfig.from_dict({}, "cc") == MDParseConfig()
assert MDParseConfig().extra_coords == ()
assert MDParseConfig().extra_dims == {}

cfg = {
"assets": {"visual": {"data_type": "uint8", "dims": ["y", "x", "rgb"]}},
"dims": {"rgb": 3},
"coords": {"rgb": ["r", "g", "b"]},
}
assert MDParseConfig.from_dict(cfg) == MDParseConfig.from_dict({"*": cfg}, "cc")
assert MDParseConfig.from_dict(cfg) == MDParseConfig.from_dict({"cc": cfg}, "cc")
assert MDParseConfig() == MDParseConfig.from_dict({"cc": cfg}, "not-in-cfg")
assert MDParseConfig.from_dict(cfg) == MDParseConfig.from_dict(cfg, "irrelevant")

cfg = MDParseConfig.from_dict(cfg)
assert cfg.extra_dims == {"rgb": 3}
assert cfg.extra_coords == (FixedCoord("rgb", ["r", "g", "b"]),)


def test_is_raster_data(sentinel_stac_ms: pystac.item.Item):
item = sentinel_stac_ms
assert "B01" in item.assets
Expand Down

0 comments on commit b4f29ac

Please sign in to comment.