Skip to content

Commit

Permalink
update extrac param from torchgeo
Browse files Browse the repository at this point in the history
  • Loading branch information
mpelchat04 committed Sep 6, 2024
1 parent 6afd245 commit 6efd4f1
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 51 deletions.
6 changes: 3 additions & 3 deletions tests/CI/test_gh_actions_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ class Test_GH_Actions(object):
def test_ci(self) -> None:
data_dir = "data"
Path(data_dir).mkdir(exist_ok=True, parents=True)
extract_archive(src="tests/data/spacenet.zip")
extract_archive(src="tests/data/new_brunswick_aerial.zip")
extract_archive(src="tests/data/massachusetts_buildings_kaggle.zip")
extract_archive(from_path="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/new_brunswick_aerial.zip")
extract_archive(from_path="tests/data/massachusetts_buildings_kaggle.zip")

with initialize(config_path="../../config", job_name="test_ci"):
cfg = compose(config_name="gdl_config_template")
Expand Down
56 changes: 28 additions & 28 deletions tests/dataset/test_aoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class Test_AOI(object):
def test_multiband_input(self):
"""Tests reading a multiband raster as input"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = data[0]
aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'])
Expand All @@ -44,7 +44,7 @@ def bands_request(self, request: SubRequest) -> List:

def test_multiband_input_band_selection(self, bands_request: List) -> None:
"""Tests reading a multiband raster as input with band selection"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = data[0]
aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=bands_request)
Expand All @@ -58,7 +58,7 @@ def test_multiband_input_band_selection(self, bands_request: List) -> None:

def test_multiband_input_band_selection_from_letters(self):
"""Tests error when selecting bands from a multiband raster using letters, not integers"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = data[0]
bands_request = ["R", "G"]
Expand All @@ -68,7 +68,7 @@ def test_multiband_input_band_selection_from_letters(self):

def test_multiband_input_band_selection_too_many(self):
"""Tests error when selecting too many bands from a multiband raster"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = data[0]
bands_request = [1, 2, 3, 4, 5]
Expand All @@ -78,7 +78,7 @@ def test_multiband_input_band_selection_too_many(self):

def test_singleband_input(self):
"""Tests reading a singleband raster as input with ${dataset.bands} pattern"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-singleband_ci.csv")
bands = ['R', 'G', 'B']
row = next(iter(data))
Expand All @@ -89,7 +89,7 @@ def test_singleband_input(self):

def test_stac_input(self):
"""Tests singleband raster referenced by stac item as input"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-stac_ci.csv")
bands = ['red', 'green', 'blue']
row = next(iter(data))
Expand All @@ -106,7 +106,7 @@ def test_stac_input(self):

def test_stac_url_input(self):
"""Tests download of singleband raster as url path referenced by a stac item"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-singleband-url_ci.csv")
row = next(iter(data))
aoi = AOI(
Expand All @@ -123,7 +123,7 @@ def test_stac_url_input(self):

def test_missing_label(self):
"""Tests error when provided label file is missing"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = next(iter(data))
row['gpkg'] = "missing_file.gpkg"
Expand All @@ -132,15 +132,15 @@ def test_missing_label(self):

def test_no_label(self):
"""Test when no label are provided. Should pass for inference. """
extract_archive(src="tests/data/new_brunswick_aerial.zip")
extract_archive(from_path="tests/data/new_brunswick_aerial.zip")
csv_path = "tests/inference/inference_segmentation_multiclass_no_label.csv"
aois = aois_from_csv(csv_path=csv_path, bands_requested=[1, 2, 3])
assert aois[0].label is None

def test_parse_input_raster(self) -> None:
"""Tests parsing for three accepted patterns to reference input raster data with band selection"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(src="tests/data/massachusetts_buildings_kaggle.zip")
extract_archive(from_path="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/massachusetts_buildings_kaggle.zip")
raster_raw = {
"tests/data/spacenet/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03.json": [
"red", "green", "blue"],
Expand All @@ -153,7 +153,7 @@ def test_parse_input_raster(self) -> None:

def test_corrupt_raster(self) -> None:
"""Tests error when reading a corrupt file"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = next(iter(data))
row['tif'] = "tests/data/massachusetts_buildings_kaggle/corrupt_file.tif"
Expand All @@ -163,7 +163,7 @@ def test_corrupt_raster(self) -> None:

def test_image_only(self) -> None:
"""Tests AOI creation with image only, ie no label"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = next(iter(data))
aoi = AOI(raster=row['tif'], label=None)
Expand All @@ -172,7 +172,7 @@ def test_image_only(self) -> None:

def test_filter_gdf_by_attribute(self):
"""Tests filtering features from a vector file according to an attribute field and value"""
extract_archive(src="tests/data/new_brunswick_aerial.zip")
extract_archive(from_path="tests/data/new_brunswick_aerial.zip")
data = read_csv("tests/tiling/tiling_segmentation_multiclass_ci.csv")
iterator = iter(data)
row = next(iterator)
Expand All @@ -188,7 +188,7 @@ def test_filter_gdf_by_attribute(self):

def test_missing_raster(self) -> None:
"""Tests error when pointing to missing raster"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = next(iter(data))
row['tif'] = "missing_raster.tif"
Expand All @@ -198,7 +198,7 @@ def test_missing_raster(self) -> None:

def test_wrong_split(self) -> None:
"""Tests error when setting a wrong split, ie not 'trn', 'tst' or 'inference'"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = next(iter(data))
row['split'] = "missing_split"
Expand All @@ -208,7 +208,7 @@ def test_wrong_split(self) -> None:

def test_download_data(self) -> None:
"""Tests download data"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = next(iter(data))
row['tif'] = "http://datacube-stage-data-public.s3.ca-central-1.amazonaws.com/store/imagery/optical/" \
Expand All @@ -221,7 +221,7 @@ def test_download_data(self) -> None:

def test_no_intersection(self) -> None:
"""Tests error testing no intersection between raster and label"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = next(iter(data))
row['gpkg'] = "tests/data/new_brunswick_aerial/BakerLake_2017_clipped.gpkg"
Expand All @@ -231,7 +231,7 @@ def test_no_intersection(self) -> None:

def test_write_multiband_from_single_band(self) -> None:
"""Tests the 'write_multiband' method"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-singleband_ci.csv")
row = data[0]
aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=['R', 'G', 'B'],
Expand All @@ -243,7 +243,7 @@ def test_write_multiband_from_single_band(self) -> None:

def test_write_multiband_from_single_band_url(self) -> None:
"""Tests the 'write_multiband' method with singleband raster as URL"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-singleband-url_ci.csv")
row = next(iter(data))
aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=['R', 'G', 'B'],
Expand All @@ -256,7 +256,7 @@ def test_write_multiband_from_single_band_url(self) -> None:

def test_write_multiband_not_applicable(self) -> None:
"""Tests the skipping of 'write_multiband' method"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = next(iter(data))
aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=[1, 2, 3],
Expand All @@ -266,7 +266,7 @@ def test_write_multiband_not_applicable(self) -> None:

def test_download_true_not_url(self) -> None:
"""Tests AOI creation if download_data set to True, but not necessary (local image)"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-singleband_ci.csv")
row = next(iter(data))
aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], download_data=True,
Expand All @@ -275,7 +275,7 @@ def test_download_true_not_url(self) -> None:

def test_raster_stats_from_stac(self) -> None:
"""Tests the calculation of statistics of raster data as stac item from an AOI instance"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-stac_ci.csv")
bands_request = ['red', 'green', 'blue']
expected_stats = {
Expand All @@ -297,7 +297,7 @@ def test_raster_stats_from_stac(self) -> None:

def test_raster_stats_not_stac(self) -> None:
"""Tests the calculation of statistics of local multiband raster data from an AOI instance"""
extract_archive(src="tests/data/new_brunswick_aerial.zip")
extract_archive(from_path="tests/data/new_brunswick_aerial.zip")
data = read_csv("tests/tiling/tiling_segmentation_multiclass_ci.csv")
expected_stats = {
'band_0': {'statistics': {'minimum': 11, 'maximum': 254, 'mean': 159.36075617930456, 'median': 165.0,
Expand All @@ -318,7 +318,7 @@ def test_raster_stats_not_stac(self) -> None:

def test_to_dict(self):
"""Test the 'to_dict()' method on an AOI instance"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-stac_ci.csv")
bands = ['red', 'green', 'blue']
row = next(iter(data))
Expand Down Expand Up @@ -368,7 +368,7 @@ def test_name_raster(self) -> None:

def test_is_low_contrast(self):
"""Test raster contrast (high | low)"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = data[0]
aoi = AOI(
Expand All @@ -392,7 +392,7 @@ def test_is_low_contrast(self):

def test_equalize_hist_raster(self):
"""Test equalize input raster with CLAHE transform"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = data[0]
aoi = AOI(
Expand All @@ -411,7 +411,7 @@ def test_equalize_hist_raster(self):

def test_equalize_hist_raster_per_band(self):
"""Test equalize input raster per band with CLAHE transform"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
row = data[0]
aoi = AOI(
Expand Down
6 changes: 3 additions & 3 deletions tests/dataset/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class TestDRDataset:
"tests/data/massachusetts_buildings_kaggle/23429155_15_uint8_clipped.tif"]
)
def raster_dataset(self, request: SubRequest) -> DatasetReader:
extract_archive(src="tests/data/massachusetts_buildings_kaggle.zip")
extract_archive(from_path="tests/data/massachusetts_buildings_kaggle.zip")
image = request.param
dr_ds = rasterio.open(image)
return dr_ds
Expand Down Expand Up @@ -95,7 +95,7 @@ def test_init(self):
"tests/data/massachusetts_buildings_kaggle/23429155_15.gpkg"]
)
def vector_dataset(self, request: SubRequest) -> GDLVectorDataset:
extract_archive(src="tests/data/massachusetts_buildings_kaggle.zip")
extract_archive(from_path="tests/data/massachusetts_buildings_kaggle.zip")
fp = request.param
vec_ds = GDLVectorDataset(fp)
return vec_ds
Expand Down Expand Up @@ -158,7 +158,7 @@ class TestIntersectionCustomDatasets:
)
)
def dataset(self, request: SubRequest) -> List:
extract_archive(src="tests/data/massachusetts_buildings_kaggle.zip")
extract_archive(from_path="tests/data/massachusetts_buildings_kaggle.zip")
image = request.param[0]
dr_ds = rasterio.open(image)
raster_ds = DRDataset(dr_ds)
Expand Down
6 changes: 3 additions & 3 deletions tests/dataset/test_stacitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class Test_SingleBandItemEO(object):
def test_stac_input_missing_band(self):
"""Tests error when requesting non-existing singleband input rasters from stac item"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-stac_ci.csv")
row = next(iter(data))
with pytest.raises(ValueError):
Expand All @@ -18,8 +18,8 @@ def test_stac_input_missing_band(self):

def test_stac_input_empty_band_request(self):
"""Tests error when band selection is required (stac item) but missing"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(src="tests/data/massachusetts_buildings_kaggle.zip")
extract_archive(from_path="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/massachusetts_buildings_kaggle.zip")
stac_item_path = "tests/data/spacenet/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03.json"
with pytest.raises(ValueError):
item = SingleBandItemEO(item=pystac.Item.from_file(stac_item_path),
Expand Down
12 changes: 6 additions & 6 deletions tests/tiling/test_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class TestTiler(object):
def test_tiling_per_aoi(self):
extract_archive(src="tests/data/massachusetts_buildings_kaggle.zip")
extract_archive(from_path="tests/data/massachusetts_buildings_kaggle.zip")
img = "tests/data/massachusetts_buildings_kaggle/22978945_15_uint8_clipped.tif"
gt = "tests/data/massachusetts_buildings_kaggle/22978945_15.gpkg"
my_aoi = AOI(raster=img, raster_bands_request=[1, 2, 3], label=gt, split='trn')
Expand All @@ -40,7 +40,7 @@ def test_tiling_per_aoi(self):

def test_passes_min_annot(self):
"""Tests annotated percent calculation"""
extract_archive(src="tests/data/spacenet.zip")
extract_archive(from_path="tests/data/spacenet.zip")
img = "tests/data/spacenet/SN7_global_monthly_2020_01_mosaic_L15-0331E-1257N_1327_3160_13_uint8_clipped.tif"
gt = "tests/data/spacenet/SN7_global_monthly_2020_01_mosaic_L15-0331E-1257N_1327_3160_13_uint8_clipped.gpkg"
my_aoi = AOI(raster=img, raster_bands_request=[1, 2, 3], label=gt, split='trn')
Expand All @@ -62,7 +62,7 @@ def test_passes_min_annot(self):

def test_burn_gt_patch(self):
"""Tests burning a label while using the filter for attribute field and values"""
extract_archive(src="tests/data/new_brunswick_aerial.zip")
extract_archive(from_path="tests/data/new_brunswick_aerial.zip")
img = "tests/data/new_brunswick_aerial/23322E759967N_clipped_1m_1of2.tif"
gt = "tests/data/new_brunswick_aerial/BakerLake_2017_clipped.gpkg"
my_aoi = AOI(
Expand Down Expand Up @@ -178,7 +178,7 @@ def test__parse_torchgeo_batch(self):

def test__define_output_name(self):
""" Test _define_output_name method of the Tiler class """
extract_archive(src="tests/data/massachusetts_buildings_kaggle.zip")
extract_archive(from_path="tests/data/massachusetts_buildings_kaggle.zip")
img = "tests/data/massachusetts_buildings_kaggle/22978945_15_uint8_clipped.tif"
gt = "tests/data/massachusetts_buildings_kaggle/22978945_15.gpkg"
my_aoi = AOI(raster=img, raster_bands_request=[1, 2, 3], label=gt, split='trn')
Expand All @@ -199,7 +199,7 @@ def test__save_vec_mem_tile(self):
""" Test _save_vec_mem_tile method of the Tiler class """
""" Test _define_output_name method of the Tiler class """
try:
extract_archive(src="tests/data/massachusetts_buildings_kaggle_patch.zip")
extract_archive(from_path="tests/data/massachusetts_buildings_kaggle_patch.zip")
except FileNotFoundError:
pass
gt = "tests/data/massachusetts_buildings_kaggle_patch/massachusetts_buildings_kaggle_patch.gpkg"
Expand Down Expand Up @@ -261,7 +261,7 @@ def test__save_vec_mem_tile(self):

def test_tiling_per_aoi_append_mode(self):
"""Tests tiling's append mode"""
extract_archive(src="tests/data/massachusetts_buildings_kaggle.zip")
extract_archive(from_path="tests/data/massachusetts_buildings_kaggle.zip")
img = "tests/data/massachusetts_buildings_kaggle/22978945_15_uint8_clipped.tif"
gt = "tests/data/massachusetts_buildings_kaggle/22978945_15.gpkg"
my_aoi = AOI(raster=img, raster_bands_request=[1, 2, 3], label=gt, split='trn')
Expand Down
Loading

0 comments on commit 6efd4f1

Please sign in to comment.