Skip to content

Commit

Permalink
tests added for _utils and accessor
Browse files Browse the repository at this point in the history
  • Loading branch information
konstntokas committed Dec 6, 2024
1 parent b13d273 commit d1045b5
Show file tree
Hide file tree
Showing 4 changed files with 311 additions and 25 deletions.
111 changes: 111 additions & 0 deletions test/test_accessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# The MIT License (MIT)
# Copyright (c) 2024 by the xcube development team and contributors
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NON INFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import unittest
from unittest.mock import patch

import dask
import dask.array as da
import xarray as xr
import rasterio
from xcube.core.mldataset import MultiLevelDataset

from xcube_stac.accessor import S3Sentinel2DataAccessor


class TestS3Sentinel2DataAccessor(unittest.TestCase):
def setUp(self):
storage_options = dict(
anon=False,
key="xxx",
secret="xxx",
client_kwargs=dict(endpoint_url="https://eodata.dataspace.copernicus.eu"),
)
self.accessor = S3Sentinel2DataAccessor(
root="eodata", storage_options=storage_options
)

def test_init(self):
self.assertEqual("eodata", self.accessor._root)
self.assertIsInstance(self.accessor.session, rasterio.session.AWSSession)
self.assertIsInstance(self.accessor.env, rasterio.env.Env)
self.assertEqual("single-threaded", dask.config.get("scheduler"))

def test_del(self):
with self.assertLogs("xcube.stac", level="DEBUG") as cm:
del self.accessor
self.assertEqual(1, len(cm.output))
msg = "DEBUG:xcube.stac:Exit rasterio.env.Env for CDSE data access."
self.assertEqual(msg, str(cm.output[-1]))

def test_root(self):
self.assertEqual("eodata", self.accessor.root)

@patch("rioxarray.open_rasterio")
def test_open_data(self, mock_open_rasterio):
# set-up mock
mock_data = {
"band_1": (("y", "x"), da.ones((2048, 2048), chunks=(1024, 1024))),
}
mock_ds = xr.Dataset(mock_data)
mock_open_rasterio.return_value = mock_ds

access_params = dict(protocol="s3", root="eodata", fs_path="test.tif")
ds = self.accessor.open_data(access_params)
mock_open_rasterio.assert_called_once_with(
"s3://eodata/test.tif",
chunks=dict(x=1024, y=1024),
band_as_variable=True,
)
self.assertTrue("band_1" in ds)
self.assertEqual(ds["band_1"].shape, (2048, 2048))
self.assertCountEqual(
[1024, 1024], [ds.chunksizes["x"][0], ds.chunksizes["y"][0]]
)

with self.assertLogs("xcube.stac", level="INFO") as cm:
ds = self.accessor.open_data(access_params, tile_size=(512, 512))
self.assertEqual(1, len(cm.output))
msg = (
"INFO:xcube.stac:The parameter tile_size is set to (1024, 1024), which "
"is the native chunk size of the jp2 files in the Sentinel-2 archive."
)
self.assertEqual(msg, str(cm.output[-1]))
self.assertTrue("band_1" in ds)
self.assertEqual(ds["band_1"].shape, (2048, 2048))
self.assertCountEqual(
[1024, 1024], [ds.chunksizes["x"][0], ds.chunksizes["y"][0]]
)

mlds = self.accessor.open_data(access_params, data_type="mldataset")
self.assertIsInstance(mlds, MultiLevelDataset)
ds = mlds.base_dataset
mock_open_rasterio.assert_called_with(
"s3://eodata/test.tif",
overview_level=None,
chunks=dict(x=1024, y=1024),
band_as_variable=True,
)
self.assertTrue("band_1" in ds)
self.assertEqual(ds["band_1"].shape, (2048, 2048))
self.assertCountEqual(
[1024, 1024], [ds.chunksizes["x"][0], ds.chunksizes["y"][0]]
)
195 changes: 195 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,21 @@
import datetime
import unittest

import dask.array as da
import pystac
import pyproj
import numpy as np
import xarray as xr
from xcube.core.store import DataStoreError

from xcube_stac._utils import (
get_format_id,
list_assets_from_item,
get_format_from_path,
reproject_bbox,
get_spatial_dims,
merge_datasets,
normalize_crs,
convert_datetime2str,
convert_str2datetime,
do_bboxes_intersect,
Expand All @@ -37,6 +48,71 @@

class UtilsTest(unittest.TestCase):

def test_get_format_id(self):
asset = pystac.Asset(
href=f"https://example.com/data/test.tif",
media_type="image/tiff",
roles=["data"],
extra_fields=dict(id="test"),
)
self.assertEqual("geotiff", get_format_id(asset))
asset = pystac.Asset(
href=f"https://example.com/data/test.tif",
roles=["data"],
extra_fields=dict(id="test"),
)
self.assertEqual("geotiff", get_format_id(asset))
asset = pystac.Asset(
href=f"https://example.com/data/test.xml",
roles=["meta"],
extra_fields=dict(id="test"),
)
with self.assertRaises(DataStoreError) as cm:
get_format_id(asset)
self.assertEqual(
"No format_id found for asset test",
f"{cm.exception}",
)

def test_list_assets_from_item(self):
geometry = {
"type": "Polygon",
"coordinates": [
[[100.0, 0.0], [101.0, 0.0], [101.0, 1.0], [100.0, 1.0], [100.0, 0.0]]
],
}
bbox = [100.0, 0.0, 101.0, 1.0]
dt = datetime.datetime(2023, 1, 1, 0, 0, 0)
item = pystac.Item(
id="test_item", geometry=geometry, bbox=bbox, datetime=dt, properties={}
)
supported_format_ids = ["geotiff", "netcdf"]

asset_names = ["asset1", "asset2", "asset3"]
media_types = ["image/tiff", "application/zarr", "meta/xml"]
for asset_name, media_type in zip(asset_names, media_types):
asset_href = f"https://example.com/data/{asset_name}.tif"
asset = pystac.Asset(href=asset_href, media_type=media_type, roles=["data"])
item.add_asset(asset_name, asset)
list_assets = list_assets_from_item(item)
self.assertCountEqual(
["asset1", "asset2"], [asset.extra_fields["id"] for asset in list_assets]
)
list_assets = list_assets_from_item(item, asset_names=["asset2"])
self.assertCountEqual(
["asset2"], [asset.extra_fields["id"] for asset in list_assets]
)
list_assets = list_assets_from_item(
item, supported_format_ids=supported_format_ids
)
self.assertCountEqual(
["asset1"], [asset.extra_fields["id"] for asset in list_assets]
)
list_assets = list_assets_from_item(
item, supported_format_ids=supported_format_ids, asset_names=["asset2"]
)
self.assertCountEqual([], [asset.extra_fields["id"] for asset in list_assets])

def test_convert_datetime2str(self):
dt = datetime.datetime(2024, 1, 1, 12, 00, 00)
self.assertEqual("2024-01-01T12:00:00", convert_datetime2str(dt))
Expand Down Expand Up @@ -229,8 +305,127 @@ def test_do_bboxes_intersect(self):
for west, south, east, north, fun in item_test_paramss:
fun(do_bboxes_intersect(item.bbox, bbox=[west, south, east, north]))

def test_get_format_from_path(self):
path = "https://example/data/file.tif"
self.assertEqual("geotiff", get_format_from_path(path))
path = "https://example/data/file.zarr"
self.assertEqual("zarr", get_format_from_path(path))

def test_update_nested_dict(self):
dic = dict(a=1, b=dict(c=3))
dic_update = dict(d=1, b=dict(c=5, e=8))
dic_expected = dict(a=1, d=1, b=dict(c=5, e=8))
self.assertDictEqual(dic_expected, update_dict(dic, dic_update))

def test_reproject_bbox(self):
bbox_wgs84 = [2, 50, 3, 51]
crs_wgs84 = "EPSG:4326"
crs_3035 = "EPSG:3035"
bbox_3035 = [
3748675.952977144,
3018751.225593612,
3830472.135997862,
3122243.2680214494,
]
self.assertEqual(bbox_wgs84, reproject_bbox(bbox_wgs84, crs_wgs84, crs_wgs84))
self.assertEqual(bbox_3035, reproject_bbox(bbox_3035, crs_3035, crs_3035))
np.testing.assert_almost_equal(
bbox_3035, reproject_bbox(bbox_wgs84, crs_wgs84, crs_3035)
)
np.testing.assert_almost_equal(
bbox_wgs84, reproject_bbox(bbox_3035, crs_3035, crs_wgs84)
)

def test_normalize_crs(self):
crs_str = "EPSG:4326"
crs_pyproj = pyproj.CRS.from_string(crs_str)
self.assertEqual(crs_pyproj, normalize_crs(crs_str))
self.assertEqual(crs_pyproj, normalize_crs(crs_pyproj))

def test_merge_datasets(self):
ds1 = xr.Dataset()
ds1["B01"] = xr.DataArray(
data=da.ones((3, 3)),
dims=("y", "x"),
coords=dict(x=[1000, 1020, 1040], y=[1000, 1020, 1040]),
)
ds2 = xr.Dataset()
ds2["B02"] = xr.DataArray(
data=da.ones((5, 5)),
dims=("y", "x"),
coords=dict(
x=[995, 1005, 1015, 1025, 1035],
y=[995, 1005, 1015, 1025, 1035],
),
)
ds3 = xr.Dataset()
ds3["B03"] = xr.DataArray(
data=da.ones((5, 5)),
dims=("y", "x"),
coords=dict(
x=[995, 1005, 1015, 1025, 1035],
y=[995, 1005, 1015, 1025, 1035],
),
)
ds_list = [ds1, ds2, ds3]
wkt = (
'PROJCRS["ETRS89 / LAEA Europe",'
'BASEGEOGCRS["ETRS89",'
'DATUM["European Terrestrial Reference System 1989",'
'ELLIPSOID["GRS 1980",6378137,298.257222101,LENGTHUNIT["metre",1]]]],'
'CONVERSION["Europe Equal Area",'
'METHOD["Lambert Azimuthal Equal Area"],'
'PARAMETER["Latitude of natural origin",52,'
'ANGLEUNIT["degree",0.0174532925199433]],'
'PARAMETER["Longitude of natural origin",10,'
'ANGLEUNIT["degree",0.0174532925199433]],'
'PARAMETER["False easting",4321000,LENGTHUNIT["metre",1]],'
'PARAMETER["False northing",3210000,LENGTHUNIT["metre",1]]],'
"CS[Cartesian,2],"
'AXIS["easting (X)",east,ORDER[1]],'
'AXIS["northing (Y)",north,ORDER[2]],'
'LENGTHUNIT["metre",1]]'
)
for ds in ds_list:
ds["crs"] = xr.DataArray(
data=0,
attrs={
"long_name": "Coordinate Reference System",
"description": "WKT representation of EPSG:3035",
"grid_mapping_name": "lambert_azimuthal_equal_area",
"crs_wkt": wkt,
},
)
ds_merged = merge_datasets(ds_list)
ds_merged = ds_merged.drop_vars("crs")
ds_merged_expected = xr.Dataset()
ds_merged_expected["B01"] = ds3["B03"]
ds_merged_expected["B02"] = ds3["B03"]
ds_merged_expected["B03"] = ds3["B03"]
xr.testing.assert_allclose(ds_merged_expected.B01, ds_merged.B01)

def test_get_spatial_dims(self):
ds = xr.Dataset()
ds["var"] = xr.DataArray(
data=np.ones((2, 2)), dims=("y", "x"), coords=dict(y=[0, 10], x=[0, 10])
)
self.assertEqual(("y", "x"), get_spatial_dims(ds))
ds = xr.Dataset()
ds["var"] = xr.DataArray(
data=np.ones((2, 2)),
dims=("lat", "lon"),
coords=dict(lat=[0, 10], lon=[0, 10]),
)
self.assertEqual(("lat", "lon"), get_spatial_dims(ds))
ds = xr.Dataset()
ds["var"] = xr.DataArray(
data=np.ones((2, 2)),
dims=("dim_false0", "dim_false1"),
coords=dict(dim_false0=[0, 10], dim_false1=[0, 10]),
)
with self.assertRaises(DataStoreError) as cm:
get_spatial_dims(ds)
self.assertEqual(
"No spatial dimensions found in dataset.",
f"{cm.exception}",
)
29 changes: 5 additions & 24 deletions xcube_stac/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@


def get_format_id(asset: pystac.Asset) -> str:
if hasattr(asset, "media_type"):
format_id = MAP_MIME_TYP_FORMAT.get(asset.media_type.split("; ")[0])
if asset.media_type is None:
format_id = get_format_from_path(asset.href)
else:
_, file_extension = os.path.splitext(asset.href)
format_id = MAP_FILE_EXTENSION_FORMAT.get(file_extension)
format_id = MAP_MIME_TYP_FORMAT.get(asset.media_type.split("; ")[0])
if format_id is None:
raise DataStoreError(f"No format_id found for asset {asset.extra_fields['id']}")
return format_id
Expand Down Expand Up @@ -412,25 +411,7 @@ def get_url_from_pystac_object(

def get_format_from_path(path: str) -> str:
_, file_extension = os.path.splitext(path)
return MAP_FILE_EXTENSION_FORMAT[file_extension]


def xarray_rename_vars(
ds: Union[xr.Dataset, xr.DataArray], name_dict: dict
) -> Union[xr.Dataset, xr.DataArray]:
"""Auxiliary functions which turns the method xarray.Dataset.rename_vars and
xarray.DataArray.rename_vars into a function which takes the Dataset or DataArray
as argument.
Args:
ds: Dataset or DataArray
name_dict: Dictionary whose keys are current variable names and whose values
are the desired names.
Returns:
Dataset with renamed variables
"""
return ds.rename_vars(name_dict)
return MAP_FILE_EXTENSION_FORMAT.get(file_extension)


def is_valid_data_type(data_type: DataTypeLike) -> bool:
Expand Down Expand Up @@ -622,7 +603,7 @@ def get_spatial_dims(ds: xr.Dataset) -> (str, str):
elif "y" in ds and "x" in ds:
y_coord, x_coord = "y", "x"
else:
raise DataStoreError(f"No spatial dimensions found in dataset {ds[0]}")
raise DataStoreError(f"No spatial dimensions found in dataset.")
return y_coord, x_coord


Expand Down
1 change: 0 additions & 1 deletion xcube_stac/store_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
get_data_id_from_pystac_object,
get_url_from_pystac_object,
is_valid_ml_data_type,
list_assets_from_item,
reproject_bbox,
search_collections,
update_dict,
Expand Down

0 comments on commit d1045b5

Please sign in to comment.