Skip to content

Commit

Permalink
normans reveiw addressed
Browse files Browse the repository at this point in the history
  • Loading branch information
konstntokas committed May 15, 2024
1 parent cadfd62 commit 16e50f7
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 29 deletions.
20 changes: 10 additions & 10 deletions test/test_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_get_item_collection(self):
"spacenet-buildings-collection/AOI_4_Shanghai_img3344"
]
self.assertIsInstance(items, ItemCollection)
self.assertCountEqual(data_id_items, data_id_items_expected)
self.assertCountEqual(data_id_items_expected, data_id_items)
self.assertEqual(len(items), len(data_id_items))

@pytest.mark.vcr()
Expand All @@ -74,7 +74,7 @@ def test_get_item_collection_open_params(self):
"zanzibar-collection/znz001",
]
self.assertIsInstance(items, ItemCollection)
self.assertCountEqual(data_id_items, data_id_items_expected)
self.assertCountEqual(data_id_items_expected, data_id_items)
self.assertEqual(len(items), len(data_id_items))

items, data_id_items = stac_instance.get_item_collection(
Expand All @@ -83,7 +83,7 @@ def test_get_item_collection_open_params(self):
time_range=["2019-04-28", "2019-04-30"]
)
self.assertIsInstance(items, ItemCollection)
self.assertEqual(len(items), 0)
self.assertEqual(0, len(items))
self.assertEqual(len(items), len(data_id_items))

@pytest.mark.vcr()
Expand All @@ -110,7 +110,7 @@ def test_get_item_collection_searchable_catalog(self):
"sentinel-2-l2a/S2A_32UNU_20200302_0_L2A"
]
self.assertIsInstance(items, ItemCollection)
self.assertCountEqual(data_id_items, data_id_items_expected)
self.assertCountEqual(data_id_items_expected, data_id_items)
self.assertEqual(len(items), len(data_id_items))

@pytest.mark.vcr()
Expand All @@ -125,7 +125,7 @@ def test_get_data_ids(self):
"spacenet-buildings-collection/AOI_4_Shanghai_img3344/raster"
]
for (data_id, data_id_expected) in zip(data_ids, data_ids_expected):
self.assertEqual(data_id, data_id_expected)
self.assertEqual(data_id_expected, data_id)

@pytest.mark.vcr()
def test_get_data_ids_optional_args(self):
Expand All @@ -143,8 +143,8 @@ def test_get_data_ids_optional_args(self):
("zanzibar-collection:znz029:raster", {"title": "znz029_previewcog"})
]
for (data_id_test, data_id) in zip(data_ids_test, data_ids):
self.assertEqual(data_id[0], data_id_test[0])
self.assertDictEqual(data_id[1], data_id_test[1])
self.assertEqual(data_id_test[0], data_id[0])
self.assertDictEqual(data_id_test[1], data_id[1])

@pytest.mark.vcr()
def test_get_data_ids_optional_args_empty_args(self):
Expand All @@ -162,8 +162,8 @@ def test_get_data_ids_optional_args_empty_args(self):
("zanzibar-collection:znz029:raster", {})
]
for (data_id_test, data_id) in zip(data_ids_test, data_ids):
self.assertEqual(data_id[0], data_id_test[0])
self.assertDictEqual(data_id[1], data_id_test[1])
self.assertEqual(data_id_test[0], data_id[0])
self.assertDictEqual(data_id_test[1], data_id[1])

@pytest.mark.vcr()
def test_get_data_ids_from_items(self):
Expand All @@ -181,7 +181,7 @@ def test_get_data_ids_from_items(self):
"zanzibar-collection/znz029/raster"
]
for (data_id_expected, data_id) in zip(data_ids_expected, data_ids):
self.assertEqual(data_id, data_id_expected)
self.assertEqual(data_id_expected, data_id)

@pytest.mark.vcr()
def test_open_data(self):
Expand Down
6 changes: 3 additions & 3 deletions test/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_get_data_ids(self):
"spacenet-buildings-collection/AOI_4_Shanghai_img3344/raster"
]
for (data_id, data_id_expected) in zip(data_ids, data_ids_expected):
self.assertEqual(data_id, data_id_expected)
self.assertEqual(data_id_expected, data_id)

@pytest.mark.vcr()
def test_get_data_ids_optional_args(self):
Expand All @@ -90,7 +90,7 @@ def test_get_data_ids_optional_args(self):
"zanzibar-collection:znz029:raster"
]
for (data_id, data_id_expected) in zip(data_ids, data_ids_expected):
self.assertEqual(data_id, data_id_expected)
self.assertEqual(data_id_expected, data_id)

@pytest.mark.vcr()
def test_has_data(self):
Expand Down Expand Up @@ -122,7 +122,7 @@ def test_get_item_collection(self):
"zanzibar-collection/znz029"
]
self.assertIsInstance(items, ItemCollection)
self.assertListEqual(data_id_items, data_id_items_expected)
self.assertListEqual(data_id_items_expected, data_id_items)
self.assertEqual(len(items), len(data_id_items))

@pytest.mark.vcr()
Expand Down
25 changes: 12 additions & 13 deletions xcube_stac/stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import pandas as pd
import pystac
from pystac import Catalog, Collection, ItemCollection, Item
import pystac_client
from shapely.geometry import box
import xarray as xr
Expand Down Expand Up @@ -70,7 +69,7 @@ def __init__(
# open_data(), which will be used to open the hrefs

@property
def catalog(self) -> Catalog:
def catalog(self) -> pystac.Catalog:
return self._catalog

def get_open_data_params_schema(self, data_id: str = None) -> JsonObjectSchema:
Expand All @@ -87,7 +86,7 @@ def get_open_data_params_schema(self, data_id: str = None) -> JsonObjectSchema:

def get_item_collection(
self, **open_params
) -> Tuple[ItemCollection, List[str]]:
) -> Tuple[pystac.ItemCollection, List[str]]:
"""Collects all items within the given STAC catalog
using the supplied *open_params*.
Expand All @@ -108,11 +107,11 @@ def get_item_collection(
self.catalog,
**open_params
)
items = ItemCollection(items)
items = pystac.ItemCollection(items)
item_data_ids = self.list_item_data_ids(items)
return items, item_data_ids

def get_item_data_id(self, item: Item) -> str:
def get_item_data_id(self, item: pystac.Item) -> str:
"""Generates the data ID of an item, which follows the structure:
`collection_id_0/../collection_id_n/item_id`
Expand All @@ -131,7 +130,7 @@ def get_item_data_id(self, item: Item) -> str:
id_parts.reverse()
return self._data_id_delimiter.join(id_parts)

def get_item_data_ids(self, items: Iterable[Item]) -> Iterator[str]:
def get_item_data_ids(self, items: Iterable[pystac.Item]) -> Iterator[str]:
"""Generates the data IDs of an item collection,
which follows the structure:
Expand All @@ -146,7 +145,7 @@ def get_item_data_ids(self, items: Iterable[Item]) -> Iterator[str]:
for item in items:
yield self.get_item_data_id(item)

def list_item_data_ids(self, items: Iterable[Item]) -> List[str]:
def list_item_data_ids(self, items: Iterable[pystac.Item]) -> List[str]:
"""Generates a list of data IDs for a given item collection,
which follows the structure:
Expand All @@ -162,7 +161,7 @@ def list_item_data_ids(self, items: Iterable[Item]) -> List[str]:

def get_data_ids(
self,
items: Iterable[Item] = None,
items: Iterable[pystac.Item] = None,
item_data_ids: Iterable[str] = None,
include_attrs: Container[str] = None,
**open_params
Expand Down Expand Up @@ -213,7 +212,7 @@ def get_data_ids(

def get_assets_from_item(
self,
item: Item,
item: pystac.Item,
include_attrs: Container[str] = None,
**open_params
) -> Iterator[str]:
Expand Down Expand Up @@ -268,10 +267,10 @@ def open_data(self, data_id: str, **open_params) -> xr.Dataset:

def _get_items_nonsearchable_catalog(
self,
pystac_object: Union[Catalog, Collection],
pystac_object: Union[pystac.Catalog, pystac.Collection],
recursive: bool = True,
**open_params
) -> Iterator[Tuple[Item, str]]:
) -> Iterator[Tuple[pystac.Item, str]]:
"""Get the items of a catalog which does not implement the
"STAC API - Item Search" conformance class.
Expand Down Expand Up @@ -318,7 +317,7 @@ def _get_items_nonsearchable_catalog(
# iterate through assets of item
yield item

def _is_datetime_in_range(self, item: Item, **open_params) -> bool:
def _is_datetime_in_range(self, item: pystac.Item, **open_params) -> bool:
"""Determine whether the datetime or datetime range of an item
intersects to the 'time_range' given by *open_params*.
Expand Down Expand Up @@ -352,7 +351,7 @@ def _is_datetime_in_range(self, item: Item, **open_params) -> bool:
dt_data = pd.Timestamp(item.properties["datetime"]).to_pydatetime()
return dt_start <= dt_data <= dt_end

def _do_bboxes_intersect(self, item: Item, **open_params) -> bool:
def _do_bboxes_intersect(self, item: pystac.Item, **open_params) -> bool:
"""Determine whether two bounding boxes intersect.
Args:
Expand Down
6 changes: 3 additions & 3 deletions xcube_stac/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from typing import Any, Container, Dict, Iterable, Iterator, List, Tuple, Union

from pystac import ItemCollection, Item
import pystac
import xarray as xr
from xcube.core.store import (
DATASET_TYPE,
Expand Down Expand Up @@ -92,7 +92,7 @@ def get_data_types_for_data(self, data_id: str) -> Tuple[str, ...]:

def get_item_collection(
self, **open_params
) -> Tuple[ItemCollection, List[str]]:
) -> Tuple[pystac.ItemCollection, List[str]]:
"""Collects all items within the given STAC catalog
using the supplied *open_params*.
Expand All @@ -105,7 +105,7 @@ def get_item_collection(
def get_data_ids(
self,
data_type: DataTypeLike = None,
items: Iterable[Item] = None,
items: Iterable[pystac.Item] = None,
item_data_ids: Iterable[str] = None,
include_attrs: Container[str] = None,
**open_params
Expand Down

0 comments on commit 16e50f7

Please sign in to comment.