
GeoSampler updates #2

Open
wants to merge 12 commits into base: vers_working_branch
5 changes: 5 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,5 @@
{
"python.testing.pytestArgs": ["tests"],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
Binary file added tests/data/samplers/filtering_4x4.feather
1 change: 1 addition & 0 deletions tests/data/samplers/filtering_4x4/filtering_4x4.cpg
@@ -0,0 +1 @@
ISO-8859-1
Binary file added tests/data/samplers/filtering_4x4/filtering_4x4.dbf
1 change: 1 addition & 0 deletions tests/data/samplers/filtering_4x4/filtering_4x4.prj
@@ -0,0 +1 @@
PROJCS["NAD_1983_BC_Environment_Albers",GEOGCS["GCS_North_American_1983",DATUM["D_North_American_1983",SPHEROID["GRS_1980",6378137.0,298.257222101]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Albers"],PARAMETER["False_Easting",1000000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",-126.0],PARAMETER["Standard_Parallel_1",50.0],PARAMETER["Standard_Parallel_2",58.5],PARAMETER["Latitude_Of_Origin",45.0],UNIT["Meter",1.0]]
Binary file added tests/data/samplers/filtering_4x4/filtering_4x4.shp
Binary file added tests/data/samplers/filtering_4x4/filtering_4x4.shx
140 changes: 131 additions & 9 deletions tests/samplers/test_single.py
@@ -2,12 +2,15 @@
# Licensed under the MIT License.

import math
from collections.abc import Iterator
import os
from itertools import product

import geopandas as gpd
import pytest
from _pytest.fixtures import SubRequest
from geopandas import GeoDataFrame
from rasterio.crs import CRS
from shapely.geometry import box
from torch.utils.data import DataLoader

from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
@@ -23,11 +26,23 @@

class CustomGeoSampler(GeoSampler):
def __init__(self) -> None:
pass
self.chips = self.get_chips()

def __iter__(self) -> Iterator[BoundingBox]:
def get_chips(self) -> GeoDataFrame:
chips = []
for i in range(len(self)):
yield BoundingBox(i, i, i, i, i, i)
chips.append(
{
'geometry': box(i, i, i, i),
'minx': i,
'miny': i,
'maxx': i,
'maxy': i,
'mint': i,
'maxt': i,
}
)
return GeoDataFrame(chips, crs='3005')

def __len__(self) -> int:
return 2
@@ -43,6 +58,17 @@
return {'index': query}


class CustomGeoDatasetSITS(GeoDataset):
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
super().__init__()
self._crs = crs
self.res = res
self.return_as_ts = True

def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
return {'index': query}


class TestGeoSampler:
@pytest.fixture(scope='class')
def dataset(self) -> CustomGeoDataset:
@@ -54,6 +80,14 @@
def sampler(self) -> CustomGeoSampler:
return CustomGeoSampler()

@pytest.fixture(scope='class')
def datadir(self) -> str:
return os.path.join('tests', 'data', 'samplers')

def test_no_get_chips_implemented(self, dataset: CustomGeoDataset) -> None:
with pytest.raises(TypeError):
GeoSampler(dataset)

def test_iter(self, sampler: CustomGeoSampler) -> None:
assert next(iter(sampler)) == BoundingBox(0, 0, 0, 0, 0, 0)

@@ -64,6 +98,62 @@
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
GeoSampler(dataset) # type: ignore[abstract]

@pytest.mark.parametrize(
'filtering_file', ['filtering_4x4', 'filtering_4x4.feather']
)
def test_filtering_from_path(self, datadir: str, filtering_file: str) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(
ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
)
iterator = iter(sampler)

assert len(sampler) == 4
filtering_path = os.path.join(datadir, filtering_file)
sampler.filter_chips(filtering_path, 'intersects', 'drop')
assert len(sampler) == 3
assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)

def test_filtering_from_gdf(self, datadir: str) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(
ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
)
iterator = iter(sampler)

# Dropping first chip
assert len(sampler) == 4
filtering_gdf = gpd.read_file(os.path.join(datadir, 'filtering_4x4'))
sampler.filter_chips(filtering_gdf, 'intersects', 'drop')
assert len(sampler) == 3
assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)

# Keeping only first chip
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
iterator = iter(sampler)
sampler.filter_chips(filtering_gdf, 'intersects', 'keep')
assert len(sampler) == 1
assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)
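
A minimal usage sketch of the filtering API these two tests exercise, assuming a dataset `ds` already indexed over the 0-10 extent; the sampler arguments, the 'intersects' predicate, and the 'drop'/'keep' actions are the ones shown above, and the file path is the fixture added in this PR:

import os

import geopandas as gpd

from torchgeo.datasets import BoundingBox
from torchgeo.samplers import GridGeoSampler, Units

# 2 x 2 grid of 5 x 5 chips over the dataset extent
sampler = GridGeoSampler(
    ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
)

# drop every chip that intersects the filter geometries read from disk
sampler.filter_chips(
    os.path.join('tests', 'data', 'samplers', 'filtering_4x4'), 'intersects', 'drop'
)

# or, on a fresh sampler, keep only the chips intersecting an in-memory GeoDataFrame
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
filtering_gdf = gpd.read_file(os.path.join('tests', 'data', 'samplers', 'filtering_4x4'))
sampler.filter_chips(filtering_gdf, 'intersects', 'keep')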

def test_set_worker_split(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(
ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10)
)
assert len(sampler) == 4
sampler.set_worker_split(total_workers=4, worker_num=1)
assert len(sampler) == 1
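
A hedged sketch of how set_worker_split could shard one grid of chips across independent processes; only the set_worker_split(total_workers=..., worker_num=...) call is taken from the test, and the WORLD_SIZE/RANK environment variables are illustrative:

import os

sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
# each process keeps roughly len(chips) / total_workers of the grid
sampler.set_worker_split(
    total_workers=int(os.environ.get('WORLD_SIZE', '1')),
    worker_num=int(os.environ.get('RANK', '0')),
)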

def test_save_chips(self, tmpdir_factory) -> None:

Check failure on line 150 in tests/samplers/test_single.py (GitHub Actions / ruff): ANN001 Missing type annotation for function argument `tmpdir_factory`
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
sampler.save(str(tmpdir_factory.mktemp('out').join('chips')))
sampler.save(str(tmpdir_factory.mktemp('out').join('chips.feather')))
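
A small sketch of the new save call, mirroring the two output paths the test uses (the out/ directory is illustrative):

sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
sampler.save('out/chips')          # extension-less path, as in the test
sampler.save('out/chips.feather')  # Feather output, matching the fixture format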

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
@@ -139,6 +229,15 @@
for bbox in sampler:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_return_as_ts(self) -> None:
ds = CustomGeoDatasetSITS()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 10, 0, 10, 15, 20))
sampler = RandomGeoSampler(ds, 1, 5)
for query in sampler:
assert query.mint == ds.bounds.mint == 0
assert query.maxt == ds.bounds.maxt == 20

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
@@ -156,7 +255,7 @@
def dataset(self) -> CustomGeoDataset:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 100, 200, 300, 400, 500))
ds.index.insert(1, (0, 100, 200, 300, 400, 500))
ds.index.insert(1, (0, 100, 200, 300, 500, 600))
return ds

@pytest.fixture(
@@ -197,13 +296,13 @@

assert math.isclose(query.maxx - query.minx, sampler.size[1])
assert math.isclose(query.maxy - query.miny, sampler.size[0])
assert math.isclose(
query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
)
assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt

def test_len(self, sampler: GridGeoSampler) -> None:
rows, cols = tile_to_chips(sampler.roi, sampler.size, sampler.stride)
length = rows * cols * 2 # two items in dataset
length = (
rows * cols * 2
) # two spatially but not temporally overlapping items in dataset
assert len(sampler) == length

def test_roi(self, dataset: CustomGeoDataset) -> None:
@@ -243,6 +342,29 @@
assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)
assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10)

def test_dataset_has_regex(self) -> None:
ds = CustomGeoDataset()
ds.filename_regex = r'.*(?P<my_key>test)'
ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'filepath_containing_key_test')
sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS)
assert 'my_key' in sampler.chips.columns

def test_dataset_has_regex_no_match(self) -> None:
ds = CustomGeoDataset()
ds.filename_regex = r'(?P<my_key>test)'
ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'no_matching_key')
sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS)
assert 'my_key' not in sampler.chips.columns
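
A short sketch of the behaviour these two tests pin down: named groups in the dataset's filename_regex that match the indexed filepaths show up as columns of sampler.chips (the regex and filepath are the ones from the passing test):

ds = CustomGeoDataset()
ds.filename_regex = r'.*(?P<my_key>test)'
ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'filepath_containing_key_test')
sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS)
print(sampler.chips['my_key'].unique())  # the captured value, here 'test'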

def test_return_as_ts(self) -> None:
ds = CustomGeoDatasetSITS()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 10, 0, 10, 15, 20))
sampler = GridGeoSampler(ds, 1, 1)
for query in sampler:
assert query.mint == ds.bounds.mint == 0
assert query.maxt == ds.bounds.maxt == 20

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
5 changes: 5 additions & 0 deletions torchgeo/datasets/geo.py
@@ -98,6 +98,9 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
#: a different file format than what it was originally downloaded as.
filename_glob = '*'

#: Whether to return the dataset as a time series; this adds another dimension to the dataset.
return_as_ts = False

# NOTE: according to the Python docs:
#
# * https://docs.python.org/3/library/exceptions.html#NotImplementedError
@@ -983,6 +986,7 @@ def __init__(
if not isinstance(ds, GeoDataset):
raise ValueError('IntersectionDataset only supports GeoDatasets')

self.return_as_ts = dataset1.return_as_ts or dataset2.return_as_ts
self.crs = dataset1.crs
self.res = dataset1.res

@@ -1143,6 +1147,7 @@ def __init__(
if not isinstance(ds, GeoDataset):
raise ValueError('UnionDataset only supports GeoDatasets')

self.return_as_ts = dataset1.return_as_ts and dataset2.return_as_ts
self.crs = dataset1.crs
self.res = dataset1.res
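
Taken together, the two hunks above give composed datasets opposite defaults: an intersection returns a time series if either operand does, while a union does so only if both do. A minimal sketch, assuming ds_a and ds_b are GeoDataset instances that overlap spatially so the compositions are valid:

ds_a.return_as_ts = True
ds_b.return_as_ts = False

assert (ds_a & ds_b).return_as_ts is True   # IntersectionDataset: dataset1 or dataset2
assert (ds_a | ds_b).return_as_ts is False  # UnionDataset: dataset1 and dataset2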
