add creodias vm support
konstntokas committed Dec 16, 2024
1 parent 5ba6b6f commit ac5abc0
Showing 10 changed files with 411 additions and 51 deletions.
1 change: 1 addition & 0 deletions environment.yml
@@ -8,6 +8,7 @@ dependencies:
- dask
- s3fs
- numpy
+ - odc-stac
- rasterio
- rioxarray
- requests
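Context: odc-stac backs the new stack_mode option added in xcube_stac/constants.py below. The following is a minimal sketch of the kind of stacking it provides, written against the public odc-stac API; the endpoint, bbox, and time range are illustrative, and this is not this store's internal code.

    import odc.stac
    import pystac_client

    # Query a public STAC API (illustrative endpoint and search window).
    catalog = pystac_client.Client.open("https://earth-search.aws.element84.com/v1")
    items = catalog.search(
        collections=["sentinel-2-l2a"],
        bbox=[9.0, 47.0, 9.2, 47.2],
        datetime="2024-06-01/2024-06-30",
    ).items()

    # Stack all matching items along time into one lazily chunked dataset.
    ds = odc.stac.load(items, bands=["red", "green", "blue"], chunks={})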
14 changes: 7 additions & 7 deletions test/test_store.py
@@ -530,11 +530,11 @@ def test_open_data_abfs(self):
_ = store.open_data(data_id, asset_names=["surface_air_pressure"])
self.assertEqual(
(
"Only 's3' and 'https' protocols are supported, not 'abfs'. The asset "
"'surface_air_pressure' has a href 'abfs://era5/ERA5/2020/12/"
"surface_air_pressure.zarr'. The item's url is given by "
"'https://planetarycomputer.microsoft.com/api/stac/v1/collections/"
"era5-pds/items/era5-pds-2020-12-an'."
"Only 'file', 's3' and 'https' protocols are supported, not 'abfs'. "
"The asset 'surface_air_pressure' has a href 'abfs://era5/ERA5/2020/"
"12/surface_air_pressure.zarr'. The item's url is given by "
"'https://planetarycomputer.microsoft.com/api/stac/v1/collectio"
"ns/era5-pds/items/era5-pds-2020-12-an'."
),
f"{cm.exception}",
)
@@ -655,11 +655,11 @@ def test_open_data_stack_mode(self):
self.assertIsInstance(ds, xr.Dataset)
self.assertCountEqual(["red", "green", "blue"], list(ds.data_vars))
self.assertCountEqual(
- [4, 16, 16],
+ [4, 15, 15],
[ds.sizes["time"], ds.sizes["y"], ds.sizes["x"]],
)
self.assertCountEqual(
- [1, 16, 16],
+ [1, 15, 15],
[ds.chunksizes["time"][0], ds.chunksizes["y"][0], ds.chunksizes["x"][0]],
)

7 changes: 4 additions & 3 deletions xcube_stac/_utils.py
@@ -134,6 +134,7 @@ def search_items(
if searchable:
# rewrite to "datetime"
search_params["datetime"] = search_params.pop("time_range", None)
print(f"{datetime.datetime.now()}: {search_params}")
items = catalog.search(**search_params).items()
else:
items = search_nonsearchable_catalog(catalog, **search_params)
@@ -551,11 +552,11 @@ def get_gridmapping(
crs: Union[str, pyproj.crs.CRS],
tile_size: Union[int, tuple[int, int]] = TILE_SIZE,
) -> GridMapping:
- x_size = int((bbox[2] - bbox[0]) / spatial_res) + 1
- y_size = int(abs(bbox[3] - bbox[1]) / spatial_res) + 1
+ x_size = int((bbox[2] - bbox[0]) / spatial_res)
+ y_size = int(abs(bbox[3] - bbox[1]) / spatial_res)
return GridMapping.regular(
size=(x_size, y_size),
- xy_min=(bbox[0] - spatial_res / 2, bbox[1] - spatial_res / 2),
+ xy_min=(bbox[0], bbox[1]),
xy_res=spatial_res,
crs=crs,
tile_size=tile_size,
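The dropped "+ 1" and half-pixel origin shift are what change the expected raster sizes in test_open_data_stack_mode above from 16 to 15. A short sketch of the arithmetic, with illustrative numbers:

    # Illustrative values: a 15-unit box at a resolution of 1.0.
    bbox = (0.0, 0.0, 15.0, 15.0)  # x_min, y_min, x_max, y_max
    spatial_res = 1.0

    # Before: one extra pixel, origin shifted half a pixel outward.
    x_size_old = int((bbox[2] - bbox[0]) / spatial_res) + 1              # 16
    xy_min_old = (bbox[0] - spatial_res / 2, bbox[1] - spatial_res / 2)  # (-0.5, -0.5)

    # After: the pixel count exactly tiles the bbox, anchored at its corner.
    x_size_new = int((bbox[2] - bbox[0]) / spatial_res)  # 15
    xy_min_new = (bbox[0], bbox[1])                      # (0.0, 0.0)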
47 changes: 42 additions & 5 deletions xcube_stac/accessor.py
@@ -72,6 +72,42 @@ def open_data(
)


+ class Sentinel2DataAccessor:
+ """Implementation of the data accessor supporting
+ the jp2 format of Sentinel-2 data.
+ """
+
+ def __init__(self, root):
+ self._root = root
+
+ @property
+ def root(self) -> str:
+ return self._root
+
+ def open_data(
+ self,
+ access_params: dict,
+ opener_id: str = None,
+ data_type: DataTypeLike = None,
+ **open_params,
+ ) -> Union[xr.Dataset, MultiLevelDataset]:
+ if opener_id is None:
+ opener_id = ""
+ if "tile_size" in open_params:
+ LOG.info(
+ "The parameter tile_size is set to (1024, 1024), which is the "
+ "native chunk size of the jp2 files in the Sentinel-2 archive."
+ )
+ if is_valid_ml_data_type(data_type) or opener_id.split(":")[0] == "mldataset":
+ return Jp2MultiLevelDataset(access_params["fs_path"], **open_params)
+ else:
+ return rioxarray.open_rasterio(
+ access_params["fs_path"],
+ chunks=dict(x=1024, y=1024),
+ band_as_variable=True,
+ )


class S3DataAccessor:
"""Implementation of the data accessor supporting
the zarr, geotiff and netcdf format via the AWS S3 protocol.
@@ -160,14 +196,15 @@ def open_data(
"The parameter tile_size is set to (1024, 1024), which is the "
"native chunk size of the jp2 files in the Sentinel-2 archive."
)
+ file_path = (
+ f"{access_params["protocol"]}://{access_params["root"]}/"
+ f"{access_params["fs_path"]}"
+ )
if is_valid_ml_data_type(data_type) or opener_id.split(":")[0] == "mldataset":
- return Jp2MultiLevelDataset(access_params, **open_params)
+ return Jp2MultiLevelDataset(file_path, **open_params)
else:
return rioxarray.open_rasterio(
- (
- f"{access_params["protocol"]}://{access_params["root"]}/"
- f"{access_params["fs_path"]}"
- ),
+ file_path,
chunks=dict(x=1024, y=1024),
band_as_variable=True,
)
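A hypothetical usage sketch of the new Sentinel2DataAccessor; the root and file path are illustrative and assume a CREODIAS VM with the EO archive mounted under /eodata:

    from xcube_stac.accessor import Sentinel2DataAccessor

    # Hypothetical parameters: JP2 files are read straight from the VM's
    # mounted file system instead of via S3.
    accessor = Sentinel2DataAccessor(root="/eodata")
    ds = accessor.open_data(
        access_params={"fs_path": "/eodata/Sentinel-2/.../B04_10m.jp2"},  # illustrative
    )
    print(ds.data_vars)  # one variable per band, chunked 1024 x 1024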
18 changes: 14 additions & 4 deletions xcube_stac/constants.py
@@ -30,6 +30,7 @@
JsonNumberSchema,
JsonObjectSchema,
JsonStringSchema,
+ JsonComplexSchema,
)

# general stac constants
@@ -92,10 +93,19 @@
# parameter schemas
STAC_STORE_PARAMETERS = dict(
url=JsonStringSchema(title="URL to STAC catalog"),
- stack_mode=JsonBooleanSchema(
- title="Decide if stacking of STAC items is applied",
- description="If True, 'odc-stac' is used as a default backend.",
- default=False,
+ stack_mode=JsonComplexSchema(
+ one_of=[
+ JsonStringSchema(
+ title="Backend for stacking STAC items",
+ description="So far, only 'odc-stac' is supported as a backend.",
+ const="odc-stac",
+ ),
+ JsonBooleanSchema(
+ title="Decide if stacking of STAC items is applied",
+ description="If True, 'odc-stac' is used as a default backend.",
+ default=False,
+ ),
+ ],
),
)
STAC_STORE_PARAMETERS.update(S3FsAccessor.get_storage_options_schema().properties)
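With the JsonComplexSchema in place, stack_mode accepts either form. A sketch of the two call styles; the store identifier "stac" and this new_data_store usage are assumptions for illustration:

    from xcube.core.store import new_data_store

    # Boolean form: enable stacking, falling back to the default backend.
    store = new_data_store("stac", url="https://example.com/stac", stack_mode=True)

    # String form: name the backend explicitly; only "odc-stac" is accepted so far.
    store = new_data_store("stac", url="https://example.com/stac", stack_mode="odc-stac")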
166 changes: 165 additions & 1 deletion xcube_stac/helper.py
@@ -5,9 +5,11 @@
import pystac_client.client
from xcube.core.store import DataStoreError
import s3fs
+ import fsspec

from .accessor import S3DataAccessor
from .accessor import S3Sentinel2DataAccessor
+ from .accessor import Sentinel2DataAccessor
from .constants import MAP_CDSE_COLLECTION_FORMAT
from .constants import MLDATASET_FORMATS
from .constants import STAC_SEARCH_PARAMETERS
@@ -17,6 +19,7 @@
from .constants import SCHEMA_PROCESSING_LEVEL
from .constants import SCHEMA_COLLECTIONS
from .constants import SCHEMA_SPATIAL_RES
+ from .constants import LOG
from .sen2.constants import CDSE_SENITNEL_2_BANDS
from .sen2.constants import CDSE_SENTINEL_2_LEVEL_BAND_RESOLUTIONS
from .sen2.constants import CDSE_SENTINEL_2_MIN_RESOLUTIONS
@@ -42,6 +45,8 @@ def __init__(self):
self.schema_search_params = STAC_SEARCH_PARAMETERS
self.schema_search_params_stack = STAC_SEARCH_PARAMETERS_STACK_MODE
self.s3_accessor = S3DataAccessor
+ self.file_accessor = None
+ self.limit_split_timerange = np.inf

def parse_item(self, item: pystac.Item, **open_params) -> pystac.Item:
return item
@@ -173,6 +178,7 @@ def __init__(self, **storage_options_s3):
secret=storage_options_s3["secret"],
)
self.s3_accessor = S3Sentinel2DataAccessor
+ self.limit_split_timerange = 190

def parse_item(self, item: pystac.Item, **open_params) -> pystac.Item:
processing_level = open_params.pop("processing_level", "L2A")
@@ -195,7 +201,18 @@ def parse_item(self, item: pystac.Item, **open_params) -> pystac.Item:
hrefs = self._fs.glob(
f"{href_base}/**/*_{asset_name}_{res_select}m.jp2"
)
- assert len(hrefs) == 1, "No unique jp2 file found"
+ if len(hrefs) == 0:
+ LOG.warning(
+ "No jp2 file found for "
+ f"{href_base}/**/*_{asset_name}_{res_select}m.jp2"
+ )
+ hrefs = self._fs.glob(f"{href_base}/**/*_{asset_name}_*.jp2")
+ if len(hrefs) > 1:
+ LOG.warning(
+ f"Multiple hrefs {hrefs} are found "
+ f"for {href_base}/**/*_{asset_name}_{res_select}m.jp2. "
+ f"Href {hrefs[0]} is taken."
+ )
href_mod = hrefs[0]
time_end = hrefs[0].split("/IMG_DATA/")[0][-15:]
else:
@@ -284,3 +301,150 @@ def search_items(
if not processing_level[1:] in item.properties["processingLevel"]:
continue
yield item


+ class HelperCdseCreodiasVM(Helper):
+
+ def __init__(self):
+ super().__init__()
+ self.supported_protocols = ["file"]
+ self.supported_format_ids = ["netcdf", "zarr", "geotiff", "jp2"]
+ self.schema_open_params = dict(
+ **STAC_OPEN_PARAMETERS, spatial_res=SCHEMA_SPATIAL_RES
+ )
+ open_params_stack = dict(
+ **STAC_OPEN_PARAMETERS_STACK_MODE, processing_level=SCHEMA_PROCESSING_LEVEL
+ )
+ del open_params_stack["query"]
+ self.schema_open_params_stack = open_params_stack
+ self.schema_search_params = dict(
+ **STAC_SEARCH_PARAMETERS_STACK_MODE,
+ collections=SCHEMA_COLLECTIONS,
+ processing_level=SCHEMA_PROCESSING_LEVEL,
+ )
+ self._fs = fsspec.filesystem("file")
+ self.file_accessor = Sentinel2DataAccessor
+ self.limit_split_timerange = 190
+
+ def parse_item(self, item: pystac.Item, **open_params) -> pystac.Item:
+ processing_level = open_params.pop("processing_level", "L2A")
+ open_params["asset_names"] = open_params.get(
+ "asset_names", CDSE_SENITNEL_2_BANDS[processing_level]
+ )
+ href_base = item.assets["PRODUCT"].extra_fields["alternate"]["s3"]["href"][1:]
+ href_base = f"/{href_base}"
+ res_want = open_params.get("spatial_res", CDSE_SENTINEL_2_MIN_RESOLUTIONS)
+ if "crs" in open_params:
+ target_crs = normalize_crs(open_params["crs"])
+ if target_crs.is_geographic:
+ res_want = open_params["spatial_res"] * 111320
+ time_end = None
+ for asset_name in open_params["asset_names"]:
+ res_avail = CDSE_SENTINEL_2_LEVEL_BAND_RESOLUTIONS[processing_level][
+ asset_name
+ ]
+ res_select = res_avail[np.argmin(abs(np.array(res_avail) - res_want))]
+ if time_end is None:
+ hrefs = self._fs.glob(
+ f"{href_base}/**/*_{asset_name}_{res_select}m.jp2"
+ )
+ if len(hrefs) == 0:
+ LOG.warning(
+ "No jp2 file found for "
+ f"{href_base}/**/*_{asset_name}_{res_select}m.jp2"
+ )
+ hrefs = self._fs.glob(f"{href_base}/**/*_{asset_name}_*.jp2")
+ if len(hrefs) > 1:
+ LOG.warning(
+ f"Multiple hrefs {hrefs} are found "
+ f"for {href_base}/**/*_{asset_name}_{res_select}m.jp2. "
+ f"Href {hrefs[0]} is taken."
+ )
+ href_mod = hrefs[0]
+ time_end = hrefs[0].split("/IMG_DATA/")[0][-15:]
+ else:
+ id_parts = item.id.split("_")
+ href_mod = (
+ f"{href_base}/GRANULE/L2A_T{item.properties["tileId"]}_"
+ f"A{item.properties["orbitNumber"]:06}_{time_end}/IMG_DATA/"
+ f"R{res_select}m/T{item.properties["tileId"]}_"
+ f"{id_parts[2]}_{asset_name}_{res_select}m.jp2"
+ )
+ if float(item.properties["processorVersion"]) >= 4.00:
+ offset = CDSE_SENITNEL_2_OFFSET_400[asset_name]
+ else:
+ offset = 0
+ item.assets[asset_name] = pystac.Asset(
+ href_mod,
+ asset_name,
+ media_type="image/jp2",
+ roles=["data"],
+ extra_fields={
+ "cdse": True,
+ "raster:bands": [
+ dict(
+ nodata=CDSE_SENITNEL_2_NO_DATA,
+ scale=1 / CDSE_SENITNEL_2_SCALE[asset_name],
+ offset=offset / CDSE_SENITNEL_2_SCALE[asset_name],
+ )
+ ],
+ },
+ )
+ # add asset for meta data for angles
+ item.assets["granule_metadata"] = pystac.Asset(
+ f"{href_base}/GRANULE/MTD_TL.xml",
+ "granule_metadata",
+ media_type="application/xml",
+ roles=["metadata"],
+ extra_fields={"cdse": True},
+ )
+ return item
+
+ def get_data_access_params(self, item: pystac.Item, **open_params) -> dict:
+ processing_level = open_params.pop("processing_level", "L2A")
+ asset_names = open_params.get(
+ "asset_names", CDSE_SENITNEL_2_BANDS[processing_level]
+ )
+ data_access_params = {}
+ for asset_name in asset_names:
+ protocol = "file"
+ href_components = item.assets[asset_name].href.split("/")
+ root = ""
+ instrument = href_components[2]
+ format_id = MAP_CDSE_COLLECTION_FORMAT[instrument]
+ fs_path = "/".join(href_components)
+ storage_options = {}
+ data_access_params[asset_name] = dict(
+ name=asset_name,
+ protocol=protocol,
+ root=root,
+ fs_path=fs_path,
+ storage_options=storage_options,
+ format_id=format_id,
+ href=item.assets[asset_name].href,
+ )
+ return data_access_params
+
+ def get_protocols(self, item: pystac.Item, **open_params) -> list[str]:
+ return ["file"]
+
+ def get_format_ids(self, item: pystac.Item, **open_params) -> list[str]:
+ return ["jp2"]
+
+ def is_mldataset_available(self, item: pystac.Item, **open_params) -> bool:
+ return True
+
+ def search_items(
+ self,
+ catalog: Union[pystac.Catalog, pystac_client.client.Client],
+ searchable: bool,
+ **search_params,
+ ) -> Iterator[pystac.Item]:
+ processing_level = search_params.pop("processing_level", "L2A")
+ if "sortby" not in search_params:
+ search_params["sortby"] = "+datetime"
+ items = search_items(catalog, searchable, **search_params)
+ for item in items:
+ if not processing_level[1:] in item.properties["processingLevel"]:
+ continue
+ yield item
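For orientation, a sketch of the path handling HelperCdseCreodiasVM relies on, with illustrative paths: CDSE items carry an alternate S3 href that, on a CREODIAS VM, doubles as a locally mounted /eodata path, so assets resolve with a plain file-system glob.

    import fsspec
    import pystac

    fs = fsspec.filesystem("file")

    def local_href_base(item: pystac.Item) -> str:
        # Same lookup as parse_item above: strip and re-add the leading slash
        # of the PRODUCT asset's alternate S3 href to get a local path.
        href = item.assets["PRODUCT"].extra_fields["alternate"]["s3"]["href"]
        return f"/{href[1:]}"  # e.g. "/eodata/Sentinel-2/MSI/L2A/..." (illustrative)

    # Resolving a band's JP2 then mirrors the helper's glob (pattern illustrative):
    # hrefs = fs.glob(f"{local_href_base(item)}/**/*_B04_10m.jp2")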
7 changes: 1 addition & 6 deletions xcube_stac/mldataset.py
@@ -91,15 +91,10 @@ class Jp2MultiLevelDataset(LazyMultiLevelDataset):

def __init__(
self,
- access_params: dict,
+ file_path: str,
**open_params: dict[str, Any],
):
- file_path = (
- f"{access_params["protocol"]}://{access_params["root"]}"
- f"/{access_params["fs_path"]}"
- )
self._file_path = file_path
- self._access_params = access_params
self._open_params = open_params
super().__init__(ds_id=file_path)

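Since Jp2MultiLevelDataset now receives a ready-made path string, its callers in accessor.py build that string up front: the S3 accessor assembles "protocol://root/fs_path", while the new local accessor passes fs_path through unchanged. A minimal sketch with illustrative paths:

    from xcube_stac.mldataset import Jp2MultiLevelDataset

    # Remote access via S3 (path illustrative):
    ml_ds = Jp2MultiLevelDataset("s3://eodata/Sentinel-2/.../T32UME_B04_10m.jp2")

    # Local access on a CREODIAS VM (path illustrative):
    ml_ds = Jp2MultiLevelDataset("/eodata/Sentinel-2/.../T32UME_B04_10m.jp2")

    ds = ml_ds.get_dataset(0)  # level 0 is the native, highest resolution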