Commit

WIP
sfalkena committed Sep 17, 2024
1 parent 3d74557 commit e2f19e4
Showing 1 changed file with 83 additions and 109 deletions.
192 changes: 83 additions & 109 deletions torchgeo/samplers/single.py
@@ -8,7 +8,7 @@
import re
import warnings
from collections.abc import Callable, Iterable, Iterator
from typing import Any

import geopandas as gpd
import numpy as np
@@ -94,52 +94,54 @@ class GeoSampler(Sampler[BoundingBox], abc.ABC):
longitude, height, width, projection, coordinate system, and time.
"""

def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None:
    def __init__(
        self,
        dataset: GeoDataset,
        roi: BoundingBox | list[BoundingBox] | None = None,
    ) -> None:
"""Initialize a new Sampler instance.
Args:
dataset: dataset to index from
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
            roi: single or multiple regions of interest to sample from (minx, maxx,
                miny, maxy, mint, maxt). Defaults to the bounds of ``dataset.index``.
"""
# If no roi is provided, use the entire dataset bounds
if roi is None:
self.index = dataset.index
roi = BoundingBox(*self.index.bounds)
else:
# Only keep hits unique in the spatial dimension if return_as_ts is enabled
# else keep hits that are unique in both spatial and temporal dimensions
filter_indices = slice(0, 4) if dataset.return_as_ts else slice(0, 6)

self.index = Index(interleaved=False, properties=Property(dimension=3))
if isinstance(roi, list):
for area in roi:
hits = dataset.index.intersection(tuple(area), objects=True)
for hit in hits:
bbox = BoundingBox(*hit.bounds) & area
# Filter hits
if tuple(bbox)[filter_indices] not in [
tuple(item.bounds[filter_indices])
for item in list(
self.index.intersection(tuple(area), objects=True)
)
]:
self.index.insert(hit.id, tuple(bbox), hit.object)
else:
hits = dataset.index.intersection(tuple(roi), objects=True)
for hit in hits:
bbox = BoundingBox(*hit.bounds) & roi
# Filter hits
if tuple(bbox)[filter_indices] not in [
tuple(item.bounds[filter_indices])
for item in list(self.index.intersection(tuple(roi), objects=True))
]:
self.index.insert(hit.id, tuple(bbox), hit.object)

        self.rois = roi if isinstance(roi, list) else [roi]

# Only keep hits unique in the spatial dimension if return_as_ts is enabled
# else keep hits that are unique in both spatial and temporal dimensions
filter_indices = slice(0, 4) if dataset.return_as_ts else slice(0, 6)

self.index = Index(interleaved=False, properties=Property(dimension=3))

        for roi in self.rois:
            # First find all hits that intersect with the roi
            hits = dataset.index.intersection(tuple(roi), objects=True)
            for hit in hits:
                bbox = BoundingBox(*hit.bounds) & roi
                # Filter out hits that share the same extent and hits with zero area
                if bbox.area > 0 and tuple(bbox)[filter_indices] not in [
                    tuple(item.bounds[filter_indices])
                    for item in self.index.intersection(tuple(roi), objects=True)
                ]:
                    self.index.insert(hit.id, tuple(bbox), hit.object)

self.res = dataset.res
self.roi = roi
self.dataset = dataset
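For illustration, a minimal sketch of constructing a concrete sampler with this fork's multi-ROI support (the ``dataset`` instance and the coordinates are hypothetical):

from torchgeo.datasets import BoundingBox
from torchgeo.samplers import RandomGeoSampler

rois = [
    BoundingBox(0, 1000, 0, 1000, 0, 1),     # first training area
    BoundingBox(5000, 6000, 0, 1000, 0, 1),  # second training area
]
sampler = RandomGeoSampler(dataset, size=256, length=100, roi=rois)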

    def __save_as_gpd_or_feather(
        self, path: str, gdf: GeoDataFrame, driver: str = 'ESRI Shapefile'
    ) -> None:
"""Save a gdf as a file supported by any geopandas driver or feather file"""

Check failure on line 139 in torchgeo/samplers/single.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (D415)

torchgeo/samplers/single.py:139:9: D415 First line should end with a period, question mark, or exclamation point
        if path.endswith('.feather'):
            gdf.to_feather(path)
        else:
            gdf.to_file(path, driver=driver)

@abc.abstractmethod
def get_chips(self, *args: Any, **kwargs: Any) -> GeoDataFrame:
"""Determines the way to get the extend of the chips (samples) of the dataset.
Expand Down Expand Up @@ -203,12 +205,9 @@ def set_worker_split(self, total_workers: int, worker_num: int) -> None:
"""
self.chips = np.array_split(self.chips, total_workers)[worker_num]
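For intuition, a small usage sketch (the ``sampler`` instance and chip count are hypothetical; ``np.array_split`` yields chunks whose sizes differ by at most one):

# Suppose sampler.chips is a GeoDataFrame with 10 chips and there are 4 workers.
# np.array_split(sampler.chips, 4) -> chunks of size [3, 3, 2, 2]
sampler.set_worker_split(total_workers=4, worker_num=2)
assert len(sampler.chips) == 2  # this worker keeps chips 6-7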

def save(self, path: str, driver: str = "Shapefile") -> None:
def save_chips(self, path: str, driver: str = "ESRI Shapefile") -> None:
"""Save the chips as a shapefile or feather file"""
if path.endswith('.feather'):
self.chips.to_feather(path)
else:
self.chips.to_file(path, driver=driver)
self.__save_as_gpd_or_feather(path, self.chips, driver)

    def save_hits(self, path: str, driver: str = 'ESRI Shapefile') -> None:
        """Save the hits as a shapefile or feather file."""
@@ -229,23 +228,21 @@ def save_hits(self, path: str, driver: str = "Shapefile") -> None:
bounds.append(bound)

bounds_gdf = GeoDataFrame(bounds, crs=self.dataset.crs)
if path.endswith('.feather'):
bounds_gdf.to_feather(path)
else:
bounds_gdf.to_file(path, driver=driver)
self.__save_as_gpd_or_feather(path, bounds_gdf, driver)
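A usage sketch of the two save helpers (file names are hypothetical; feather output is picked by extension, everything else goes through the geopandas driver):

sampler.save_chips('chips.feather')            # Feather via GeoDataFrame.to_feather
sampler.save_chips('chips.shp')                # ESRI Shapefile (default driver)
sampler.save_hits('hits.gpkg', driver='GPKG')  # any other OGR driver geopandas supports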

def __iter__(self) -> Iterator[BoundingBox]:
"""Return the index of a dataset.
Returns:
(minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
"""
# sort chips based on hit_id
        # Sort chips by hit_id so that chips located in the same underlying hit
        # are sampled sequentially. Together with caching the (chunk of the) hit
        # in memory, this should speed up dataloading.
        # TODO: also support sorting of chunks.
self.chips = self.chips.sort_values(by=['hit_id'])

for _, chip in self.chips.iterrows():
print("------------------------------------")
print("Chip FID: {}".format(chip["fid"]))
yield BoundingBox(
chip.minx, chip.maxx, chip.miny, chip.maxy, chip.mint, chip.maxt
)
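For reference, the usual torchgeo pattern for consuming these bounding boxes, with illustrative names:

for bbox in sampler:
    sample = dataset[bbox]  # a GeoDataset is indexed with a BoundingBox
    image = sample['image']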
@@ -278,7 +275,7 @@ def __init__(
dataset: GeoDataset,
size: tuple[float, float] | float,
length: int | None = None,
roi: BoundingBox | None = None,
        roi: BoundingBox | list[BoundingBox] | None = None,
units: Units = Units.PIXELS,
) -> None:
"""Initialize a new Sampler instance.
Expand Down Expand Up @@ -314,47 +311,23 @@ def __init__(
self.size = (self.size[0] * self.res, self.size[1] * self.res)

num_chips = 0
self.hits = []
areas = []
        # use only geospatial bounds for filtering in case return_as_ts is set to True
        filter_indices = slice(0, 4) if dataset.return_as_ts else slice(0, 6)

if isinstance(self.roi, list):
for area in self.roi:
for hit in self.index.intersection(tuple(area), objects=True):
# Filter out hits in the index that share the same extent
if hit.bounds[filter_indices] not in [ht.bounds[filter_indices] for ht in self.hits]:
bounds = BoundingBox(*hit.bounds)
if (
bounds.maxx - bounds.minx >= self.size[1]
and bounds.maxy - bounds.miny >= self.size[0]
):
if bounds.area > 0:
rows, cols = tile_to_chips(bounds, self.size)
num_chips += rows * cols
else:
num_chips += 1
self.hits.append(hit)
areas.append(bounds.area)
else:
for hit in self.index.intersection(tuple(self.roi), objects=True):
# Filter out hits in the index that share the same extent
if hit.bounds[filter_indices] not in [ht.bounds[filter_indices] for ht in self.hits]:
bounds = BoundingBox(*hit.bounds)
if (
bounds.maxx - bounds.minx >= self.size[1]
and bounds.maxy - bounds.miny >= self.size[0]
):
if bounds.area > 0:
rows, cols = tile_to_chips(bounds, self.size)
num_chips += rows * cols
else:
num_chips += 1
self.hits.append(hit)
areas.append(bounds.area)
if length is not None:
num_chips = length
self.length = num_chips
self.areas_per_roi = []

for roi in self.rois:
areas = []
for hit in self.index.intersection(tuple(roi), objects=True):
# Filter out hits that are smaller than the chip size
hit_bounds = BoundingBox(*hit.bounds)
if (
hit_bounds.maxx - hit_bounds.minx >= self.size[1]
and hit_bounds.maxy - hit_bounds.miny >= self.size[0]
):
rows, cols = tile_to_chips(hit_bounds, self.size)
num_chips += rows * cols
areas.append(hit_bounds.area)
self.areas_per_roi.append(areas)

        self.length = length if length is not None else num_chips
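For intuition, the per-tile chip count above should follow the usual sliding-window arithmetic; a minimal sketch, assuming ``tile_to_chips`` uses the chip size as the default stride:

import math

def chips_along_axis(extent: float, size: float, stride: float) -> int:
    # ceil((extent - size) / stride) + 1 window positions fit along one axis
    return math.ceil((extent - size) / stride) + 1

# e.g. a 1000 x 1000 m tile sampled with 256 m chips (stride = size):
rows = chips_along_axis(1000.0, 256.0, 256.0)  # 4
cols = chips_along_axis(1000.0, 256.0, 256.0)  # 4
# num_chips grows by rows * cols = 16 for this tile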

print(f"Unique geospatial file hits: {len(self.hits)}")

@@ -363,34 +336,33 @@ def __init__(
if torch.sum(self.areas) == 0:
self.areas += 1

self.chips = self.get_chips(num_samples=num_chips)
self.chips = self.get_chips(num_samples=self.length)

    def get_chips(self, num_samples: int) -> GeoDataFrame:
        """Generate ``num_samples`` random chips and return them as a GeoDataFrame."""
chips = []
print('generating samples... ')
while len(chips) < num_samples:
if isinstance(self.roi, list):
# Choose a random ROI, weighted by area
idx = torch.multinomial(torch.tensor([roi.area for roi in self.roi], dtype=torch.float), 1)
roi = self.roi[idx]
else:
roi = self.roi

# Choose a random bounding box within the ROI
bbox = get_random_bounding_box(roi, self.size, self.res)
minx, maxx, miny, maxy, mint, maxt = tuple(bbox)
            # Choose a random ROI, weighted by area
            roi_idx = torch.multinomial(
                torch.tensor([roi.area for roi in self.rois], dtype=torch.float), 1
            )
            roi = self.rois[roi_idx]

# Find file hits in the index for chosen bounding box
hits = [hit for hit in list(self.index.intersection(tuple(roi), objects=True)) if BoundingBox(*hit.bounds).area > 0]
# Find the hits for the chosen ROI
hits = list(self.index.intersection(tuple(roi), objects=True))

# if the bounding box has no file hit, dont include it in the samples
            # if the roi has no hits, don't try to sample from it
if len(hits) == 0:
continue
else:
# Choose a random hit
hit = random.choice(hits)
                # Choose a random hit, weighted by area
                assert len(self.areas_per_roi[roi_idx]) == len(hits)
                hit_idx = torch.multinomial(
                    torch.tensor(self.areas_per_roi[roi_idx], dtype=torch.float), 1
                )

# in case we are randomly sampling also the temporal dimension, get
# mint, maxt from the randomly chosen hit
                # Choose a random bounding box within the chosen hit
                hit_bounds = BoundingBox(*hits[hit_idx].bounds)
                bbox = get_random_bounding_box(hit_bounds, self.size, self.res)
                minx, maxx, miny, maxy, mint, maxt = tuple(bbox)

                # If the dataset is a SITS dataset, mint and maxt are the same for
                # all hits; if not, we also sample across time.
if self.dataset.return_as_ts:
mint = self.dataset.bounds.mint
maxt = self.dataset.bounds.maxt
Expand All @@ -410,7 +382,6 @@ def get_chips(self, num_samples) -> GeoDataFrame:
}
chips.append(chip)

print('creating geodataframe... ')
chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs)
chips_gdf['fid'] = chips_gdf.index

Expand Down Expand Up @@ -440,7 +411,7 @@ def __init__(
dataset: GeoDataset,
size: tuple[float, float] | float,
stride: tuple[float, float] | float,
roi: BoundingBox | None = None,
        roi: BoundingBox | list[BoundingBox] | None = None,
units: Units = Units.PIXELS,
) -> None:
"""Initialize a new Sampler instance.
Expand Down Expand Up @@ -562,7 +533,10 @@ class PreChippedGeoSampler(GeoSampler):
"""

def __init__(
self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False
        self,
        dataset: GeoDataset,
        roi: BoundingBox | list[BoundingBox] | None = None,
        shuffle: bool = False,
    ) -> None:
"""Initialize a new Sampler instance.
