diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index f8cd72b2fa..b3003a0d5b 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -8,7 +8,7 @@ import re import warnings from collections.abc import Callable, Iterable, Iterator -from typing import Any +from typing import Any, List import geopandas as gpd import numpy as np @@ -94,52 +94,54 @@ class GeoSampler(Sampler[BoundingBox], abc.ABC): longitude, height, width, projection, coordinate system, and time. """ - def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None: + def __init__(self, + dataset: GeoDataset, + roi: BoundingBox | List[BoundingBox] | None = None + ) -> None: """Initialize a new Sampler instance. Args: dataset: dataset to index from - roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) - (defaults to the bounds of ``dataset.index``) + roi: single or multiple regions of interest to sample from (minx, maxx, + miny, maxy, mint, maxt). Defaults to the bounds of ``dataset.index`` """ + # If no roi is provided, use the entire dataset bounds if roi is None: self.index = dataset.index roi = BoundingBox(*self.index.bounds) - else: - # Only keep hits unique in the spatial dimension if return_as_ts is enabled - # else keep hits that are unique in both spatial and temporal dimensions - filter_indices = slice(0, 4) if dataset.return_as_ts else slice(0, 6) - - self.index = Index(interleaved=False, properties=Property(dimension=3)) - if isinstance(roi, list): - for area in roi: - hits = dataset.index.intersection(tuple(area), objects=True) - for hit in hits: - bbox = BoundingBox(*hit.bounds) & area - # Filter hits - if tuple(bbox)[filter_indices] not in [ - tuple(item.bounds[filter_indices]) - for item in list( - self.index.intersection(tuple(area), objects=True) - ) - ]: - self.index.insert(hit.id, tuple(bbox), hit.object) - else: - hits = dataset.index.intersection(tuple(roi), objects=True) - for hit in hits: - bbox = BoundingBox(*hit.bounds) & roi - # Filter hits - if tuple(bbox)[filter_indices] not in [ - tuple(item.bounds[filter_indices]) - for item in list(self.index.intersection(tuple(roi), objects=True)) - ]: - self.index.insert(hit.id, tuple(bbox), hit.object) + + self.rois = roi if isinstance(roi, List) else [roi] + # Only keep hits unique in the spatial dimension if return_as_ts is enabled + # else keep hits that are unique in both spatial and temporal dimensions + filter_indices = slice(0, 4) if dataset.return_as_ts else slice(0, 6) + + self.index = Index(interleaved=False, properties=Property(dimension=3)) + + for roi in self.rois: + # First find all hits that intersect with the roi + print(roi) + hits = dataset.index.intersection(tuple(roi), objects=True) + for hit in hits: + bbox = BoundingBox(*hit.bounds) & roi + # Filter out hits that share the same extent and hits with zero area + if tuple(bbox)[filter_indices] not in [ + tuple(item.bounds[filter_indices]) for item in list( + self.index.intersection(tuple(roi), objects=True) + )] and bbox.area > 0: + self.index.insert(hit.id, tuple(bbox), hit.object) print(f"Index Size: {self.index.get_size()}") self.res = dataset.res - self.roi = roi self.dataset = dataset + @staticmethod + def __save_as_gpd_or_feather(self, path: str, gdf: GeoDataFrame, driver='ESRI Shapefile') -> None: + """Save a gdf as a file supported by any geopandas driver or feather file""" + if path.endswith('.feather'): + chips.to_feather(path) + else: + chips.to_file(path, driver=driver) + @abc.abstractmethod def get_chips(self, *args: Any, **kwargs: Any) -> GeoDataFrame: """Determines the way to get the extend of the chips (samples) of the dataset. @@ -203,12 +205,9 @@ def set_worker_split(self, total_workers: int, worker_num: int) -> None: """ self.chips = np.array_split(self.chips, total_workers)[worker_num] - def save(self, path: str, driver: str = "Shapefile") -> None: + def save_chips(self, path: str, driver: str = "ESRI Shapefile") -> None: """Save the chips as a shapefile or feather file""" - if path.endswith('.feather'): - self.chips.to_feather(path) - else: - self.chips.to_file(path, driver=driver) + self.__save_as_gpd_or_feather(path, self.chips, driver) def save_hits(self, path: str, driver: str = "Shapefile") -> None: """Save the hits as a shapefile or feather file""" @@ -229,10 +228,7 @@ def save_hits(self, path: str, driver: str = "Shapefile") -> None: bounds.append(bound) bounds_gdf = GeoDataFrame(bounds, crs=self.dataset.crs) - if path.endswith('.feather'): - bounds_gdf.to_feather(path) - else: - bounds_gdf.to_file(path, driver=driver) + self.__save_as_gpd_or_feather(path, bounds_gdf, driver) def __iter__(self) -> Iterator[BoundingBox]: """Return the index of a dataset. @@ -240,12 +236,13 @@ def __iter__(self) -> Iterator[BoundingBox]: Returns: (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset """ - # sort chips based on hit_id + # sort chips based on hit_id.The idea behind this is to ensure that chips that are located + # in the same underlying hit are sampled sequentially. Together with keeping the (chunk of the) + # hit in memory by caching it, this should lead to a speed up in dataloading. TODO: support + # sorting of chunks as well. self.chips = self.chips.sort_values(by=['hit_id']) for _, chip in self.chips.iterrows(): - print("------------------------------------") - print("Chip FID: {}".format(chip["fid"])) yield BoundingBox( chip.minx, chip.maxx, chip.miny, chip.maxy, chip.mint, chip.maxt ) @@ -278,7 +275,7 @@ def __init__( dataset: GeoDataset, size: tuple[float, float] | float, length: int | None = None, - roi: BoundingBox | None = None, + roi: BoundingBox | List[BoundingBox] | None = None, units: Units = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -314,47 +311,23 @@ def __init__( self.size = (self.size[0] * self.res, self.size[1] * self.res) num_chips = 0 - self.hits = [] - areas = [] - # use only geospatial bounds for filering in case return_as_ts is set to True - filter_indices = slice(0,4) if dataset.return_as_ts else slice(0,6) - - if isinstance(self.roi, list): - for area in self.roi: - for hit in self.index.intersection(tuple(area), objects=True): - # Filter out hits in the index that share the same extent - if hit.bounds[filter_indices] not in [ht.bounds[filter_indices] for ht in self.hits]: - bounds = BoundingBox(*hit.bounds) - if ( - bounds.maxx - bounds.minx >= self.size[1] - and bounds.maxy - bounds.miny >= self.size[0] - ): - if bounds.area > 0: - rows, cols = tile_to_chips(bounds, self.size) - num_chips += rows * cols - else: - num_chips += 1 - self.hits.append(hit) - areas.append(bounds.area) - else: - for hit in self.index.intersection(tuple(self.roi), objects=True): - # Filter out hits in the index that share the same extent - if hit.bounds[filter_indices] not in [ht.bounds[filter_indices] for ht in self.hits]: - bounds = BoundingBox(*hit.bounds) - if ( - bounds.maxx - bounds.minx >= self.size[1] - and bounds.maxy - bounds.miny >= self.size[0] - ): - if bounds.area > 0: - rows, cols = tile_to_chips(bounds, self.size) - num_chips += rows * cols - else: - num_chips += 1 - self.hits.append(hit) - areas.append(bounds.area) - if length is not None: - num_chips = length - self.length = num_chips + self.areas_per_roi = [] + + for roi in self.rois: + areas = [] + for hit in self.index.intersection(tuple(roi), objects=True): + # Filter out hits that are smaller than the chip size + hit_bounds = BoundingBox(*hit.bounds) + if ( + hit_bounds.maxx - hit_bounds.minx >= self.size[1] + and hit_bounds.maxy - hit_bounds.miny >= self.size[0] + ): + rows, cols = tile_to_chips(hit_bounds, self.size) + num_chips += rows * cols + areas.append(hit_bounds.area) + self.areas_per_roi.append(areas) + + self.length = length or num_chips print(f"Unique geospatial file hits: {len(self.hits)}") @@ -363,34 +336,33 @@ def __init__( if torch.sum(self.areas) == 0: self.areas += 1 - self.chips = self.get_chips(num_samples=num_chips) + self.chips = self.get_chips(num_samples=self.length) def get_chips(self, num_samples) -> GeoDataFrame: chips = [] + print('generating samples... ') while len(chips) < num_samples: - if isinstance(self.roi, list): - # Choose a random ROI, weighted by area - idx = torch.multinomial(torch.tensor([roi.area for roi in self.roi], dtype=torch.float), 1) - roi = self.roi[idx] - else: - roi = self.roi - - # Choose a random bounding box within the ROI - bbox = get_random_bounding_box(roi, self.size, self.res) - minx, maxx, miny, maxy, mint, maxt = tuple(bbox) + # Choose a random ROI, weighted by area + roi_idx = torch.multinomial(torch.tensor([roi.area for roi in self.rois], dtype=torch.float), 1) + roi = self.rois[roi_idx] - # Find file hits in the index for chosen bounding box - hits = [hit for hit in list(self.index.intersection(tuple(roi), objects=True)) if BoundingBox(*hit.bounds).area > 0] + # Find the hits for the chosen ROI + hits = list(self.index.intersection(tuple(roi), objects=True)) - # if the bounding box has no file hit, dont include it in the samples + # if the roi has no hits, dont try to sample from it if len(hits) == 0: continue else: - # Choose a random hit - hit = random.choice(hits) + # Choose a random hit, weighted by area + hit_idx = torch.multinomial(torch.tensor(self.areas_per_roi[roi_idx], dtype=torch.float), 1) + assert len(self.areas_per_roi[roi_idx]) == len(hits) - # in case we are randomly sampling also the temporal dimension, get - # mint, maxt from the randomly chosen hit + # Choose a random bounding box within the hit + bbox = get_random_bounding_box(hits[hit_idx], self.size, self.res) + minx, maxx, miny, maxy, mint, maxt = tuple(bbox) + + # If the dataset is a SITS dataset, the mint and maxt are the same for all hits + # If not, we also sample across time. if self.dataset.return_as_ts: mint = self.dataset.bounds.mint maxt = self.dataset.bounds.maxt @@ -410,7 +382,6 @@ def get_chips(self, num_samples) -> GeoDataFrame: } chips.append(chip) - print('creating geodataframe... ') chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) chips_gdf['fid'] = chips_gdf.index @@ -440,7 +411,7 @@ def __init__( dataset: GeoDataset, size: tuple[float, float] | float, stride: tuple[float, float] | float, - roi: BoundingBox | None = None, + roi: BoundingBox | List[BoundingBox] | None = None, units: Units = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -562,7 +533,10 @@ class PreChippedGeoSampler(GeoSampler): """ def __init__( - self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False + self, + dataset: GeoDataset, + roi: BoundingBox | List[BoundingBox] | None = None, + shuffle: bool = False ) -> None: """Initialize a new Sampler instance.