From 042539b0bebd36413183a0676e86120369e54f0b Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Tue, 27 Aug 2024 11:37:52 +0200 Subject: [PATCH 01/10] add return_as_ts, support ts in randomsampler --- torchgeo/datasets/geo.py | 3 +++ torchgeo/samplers/single.py | 13 ++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 64f7ed1ceb0..dc9178ad223 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -96,6 +96,9 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): #: other datasets. It should not include a file extension, as the dataset may be in #: a different file format than what it was originally downloaded as. filename_glob = '*' + + # Whether to return the dataset as a Timeseries, this will add another dimension to the dataset + return_as_ts = False # NOTE: according to the Python docs: # diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 094142cb647..2617a6d41b3 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -53,9 +53,11 @@ def __iter__(self) -> Iterator[BoundingBox]: (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset """ - class RandomGeoSampler(GeoSampler): - """Samples elements from a region of interest randomly. + """Differs from TrochGeo's RandomGeoSampler in that it can sample SITS data. + + Documentation from TorchGeo: + Samples elements from a region of interest randomly. This is particularly useful during training when you want to maximize the size of the dataset and return as many random :term:`chips ` as possible. Note that @@ -101,6 +103,7 @@ def __init__( """ super().__init__(dataset, roi) self.size = _to_tuple(size) + self.dataset = dataset if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) @@ -139,8 +142,12 @@ def __iter__(self) -> Iterator[BoundingBox]: # Choose a random tile, weighted by area idx = torch.multinomial(self.areas, 1) hit = self.hits[idx] - bounds = BoundingBox(*hit.bounds) + hit_bounds = hit.bounds + if self.dataset.return_as_ts: + hit_bounds[-2] = self.dataset.bounds.mint + hit_bounds[-1] = self.dataset.bounds.maxt + bounds = BoundingBox(*hit_bounds) # Choose a random index within that tile bounding_box = get_random_bounding_box(bounds, self.size, self.res) From d233b3a9e8404873268f83dcd9799b8e5bf3f8ab Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Tue, 27 Aug 2024 16:17:53 +0200 Subject: [PATCH 02/10] Move VERS samplers into torchgeo samplers, implement pre-chipping everywhere --- torchgeo/datasets/geo.py | 2 + torchgeo/samplers/single.py | 270 ++++++++++++++++++++++++++++-------- 2 files changed, 217 insertions(+), 55 deletions(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index dc9178ad223..f0d217ea640 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -983,6 +983,7 @@ def __init__( if not isinstance(ds, GeoDataset): raise ValueError('IntersectionDataset only supports GeoDatasets') + self.return_as_ts = dataset1.return_as_ts or dataset2.return_as_ts self.crs = dataset1.crs self.res = dataset1.res @@ -1143,6 +1144,7 @@ def __init__( if not isinstance(ds, GeoDataset): raise ValueError('UnionDataset only supports GeoDatasets') + self.return_as_ts = dataset1.return_as_ts and dataset2.return_as_ts self.crs = dataset1.crs self.res = dataset1.res diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 2617a6d41b3..b4dc4befac7 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -13,7 +13,46 @@ from ..datasets import BoundingBox, GeoDataset from .constants import Units from .utils import _to_tuple, get_random_bounding_box, tile_to_chips +from geopandas import GeoDataFrame +from tqdm import tqdm +from shapely.geometry import box +import re +import pandas as pd +def _get_regex_groups_as_df(dataset, hits): + """ + Extracts the regex metadata from a list of hits. + + Args: + dataset (GeoDataset): The dataset to sample from. + hits (list): A list of hits. + + Returns: + pandas.DataFrame: A DataFrame containing the extracted file metadata. + """ + has_filename_regex = bool(getattr(dataset, "filename_regex", None)) + if has_filename_regex: + filename_regex = re.compile(dataset.filename_regex, re.VERBOSE) + file_metadata = [] + for hit in hits: + if has_filename_regex: + match = re.match(filename_regex, str(hit.object)) + if match: + meta = match.groupdict() + else: + meta = {} + meta.update( + { + "minx": hit.bounds[0], + "maxx": hit.bounds[1], + "miny": hit.bounds[2], + "maxy": hit.bounds[3], + "mint": hit.bounds[4], + "maxt": hit.bounds[5], + } + ) + file_metadata.append(meta) + return pd.DataFrame(file_metadata) class GeoSampler(Sampler[BoundingBox], abc.ABC): """Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`. @@ -44,14 +83,80 @@ def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None: self.res = dataset.res self.roi = roi + self.dataset = dataset + + @abc.abstractmethod + def get_chips(self) -> GeoDataFrame: + """Determines the way to get the extend of the chips (samples) of the dataset. + Should return a GeoDataFrame with the extend of the chips with the columns + geometry, minx, miny, maxx, maxy, mint, maxt, fid. Each row is a chip.""" + raise NotImplementedError + + + def filter_chips( + self, + filter_by: str | GeoDataFrame, + predicate: str = "intersects", + action: str = "keep", + ) -> None: + """Filter the default set of chips in the sampler down to a specific subset by + specifying files supported by geopandas such as shapefiles, geodatabases or + feather files. + + Args: + filter_by: The file or geodataframe for which the geometries will be used during filtering + predicate: Predicate as used in Geopandas sindex.query_bulk + action: What to do with the chips that satisfy the condition by the predicacte. + Can either be "drop" or "keep". + """ + prefilter_leng = len(self.chips) + filtering_gdf = load_file(filter_by).to_crs(self.dataset.crs) + self.chips = filter_tiles( + self.chips, filtering_gdf, predicate, action + ).reset_index(drop=True) + self.chips.fid = self.chips.index + print(f"Filter step reduced chips from {prefilter_leng} to {len(self.chips)}") + assert not self.chips.empty, "No chips left after filtering!" + + def set_worker_split(self, total_workers: int, worker_num: int) -> None: + """Splits the chips in n equal parts for the number of workers and keeps the set of + chips for the specific worker id, convenient if you want to split the chips across + multiple dataloaders for multi-gpu inference. + + Args: + total_workers: The total number of parts to split the chips + worker_num: The id of the worker (which part to keep), starts from 0 + + """ + self.chips = np.array_split(self.chips, total_workers)[worker_num] + + def save(self, + path: str, + driver: str = None) -> None: + """Save the chips as a shapefile or feather file""" + if path.endswith(".feather"): + self.chips.to_feather(path) + else: + self.chips.to_file(path, driver=driver) - @abc.abstractmethod def __iter__(self) -> Iterator[BoundingBox]: """Return the index of a dataset. Returns: (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset """ + for _, chip in self.chips.iterrows(): + yield BoundingBox( + chip.minx, chip.maxx, chip.miny, chip.maxy, chip.mint, chip.maxt + ) + + def __len__(self) -> int: + """Return the number of samples over the ROI. + + Returns: + number of patches that will be sampled + """ + return len(self.chips) class RandomGeoSampler(GeoSampler): """Differs from TrochGeo's RandomGeoSampler in that it can sample SITS data. @@ -103,7 +208,6 @@ def __init__( """ super().__init__(dataset, roi) self.size = _to_tuple(size) - self.dataset = dataset if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) @@ -132,13 +236,12 @@ def __init__( if torch.sum(self.areas) == 0: self.areas += 1 - def __iter__(self) -> Iterator[BoundingBox]: - """Return the index of a dataset. + self.chips = self.get_chips() + - Returns: - (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset - """ - for _ in range(len(self)): + def get_chips(self) -> GeoDataFrame: + chips = [] + for _ in tqdm(range(len(self))): # Choose a random tile, weighted by area idx = torch.multinomial(self.areas, 1) hit = self.hits[idx] @@ -149,9 +252,28 @@ def __iter__(self) -> Iterator[BoundingBox]: bounds = BoundingBox(*hit_bounds) # Choose a random index within that tile - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bbox = get_random_bounding_box(bounds, self.size, self.res) + minx, maxx, miny, maxy, mint, maxt = tuple(bbox) + chip = { + "geometry": box(minx, miny, maxx, maxy), + "minx": minx, + "miny": miny, + "maxx": maxx, + "maxy": maxy, + "mint": mint, + "maxt": maxt, + } + chips.append(chip) + + if chips: + print("creating geodataframe... ") + chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) + chips_gdf["fid"] = chips_gdf.index - yield bounding_box + else: + warnings.warn("Sampler has no chips, check your inputs") + chips_gdf = GeoDataFrame() + return chips_gdf def __len__(self) -> int: """Return the number of samples in a single epoch. @@ -213,33 +335,38 @@ def __init__( self.size = (self.size[0] * self.res, self.size[1] * self.res) self.stride = (self.stride[0] * self.res, self.stride[1] * self.res) - self.hits = [] - for hit in self.index.intersection(tuple(self.roi), objects=True): - bounds = BoundingBox(*hit.bounds) - if ( - bounds.maxx - bounds.minx >= self.size[1] - and bounds.maxy - bounds.miny >= self.size[0] - ): - self.hits.append(hit) + hits = self.index.intersection(tuple(self.roi), objects=True) + df_path = _get_regex_groups_as_df(self.dataset, hits) - self.length = 0 - for hit in self.hits: - bounds = BoundingBox(*hit.bounds) - rows, cols = tile_to_chips(bounds, self.size, self.stride) - self.length += rows * cols + # Filter out tiles smaller than the chip size + self.df_path = df_path[ + (df_path.maxx - df_path.minx >= self.size[1]) + & (df_path.maxy - df_path.miny >= self.size[0]) + ] - def __iter__(self) -> Iterator[BoundingBox]: - """Return the index of a dataset. + # Filter out hits in the index that share the same extent + if self.dataset.return_as_ts: + self.df_path.drop_duplicates( + subset=["minx", "maxx", "miny", "maxy"], inplace=True + ) + else: + self.df_path.drop_duplicates( + subset=["minx", "maxx", "miny", "maxy", "mint", "maxt"], inplace=True + ) + + self.chips = self.get_chips() - Returns: - (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset - """ - # For each tile... - for hit in self.hits: - bounds = BoundingBox(*hit.bounds) + + def get_chips(self) -> GeoDataFrame: + print("generating samples... ") + optional_keys = ["tile", "date"] + self.length = 0 + chips = [] + for _, row in tqdm(self.df_path.iterrows(), total=len(self.df_path)): + bounds = BoundingBox( + row.minx, row.maxx, row.miny, row.maxy, row.mint, row.maxt + ) rows, cols = tile_to_chips(bounds, self.size, self.stride) - mint = bounds.mint - maxt = bounds.maxt # For each row... for i in range(rows): @@ -251,15 +378,37 @@ def __iter__(self) -> Iterator[BoundingBox]: minx = bounds.minx + j * self.stride[1] maxx = minx + self.size[1] - yield BoundingBox(minx, maxx, miny, maxy, mint, maxt) - - def __len__(self) -> int: - """Return the number of samples over the ROI. + if self.dataset.return_as_ts: + mint = self.dataset.bounds.mint + maxt = self.dataset.bounds.maxt + else: + mint = bounds.mint + maxt = bounds.maxt + + chip = { + "geometry": box(minx, miny, maxx, maxy), + "minx": minx, + "miny": miny, + "maxx": maxx, + "maxy": maxy, + "mint": mint, + "maxt": maxt, + } + for key in optional_keys: + if key in row.keys(): + chip[key] = row[key] + + chips.append(chip) + + if chips: + print("creating geodataframe... ") + chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) + chips_gdf["fid"] = chips_gdf.index - Returns: - number of patches that will be sampled - """ - return self.length + else: + warnings.warn("Sampler has no chips, check your inputs") + chips_gdf = GeoDataFrame() + return chips_gdf class PreChippedGeoSampler(GeoSampler): @@ -294,25 +443,36 @@ def __init__( self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): - self.hits.append(hit) + self.hits.append(hit)\ + + self.chips = get_chips(self) - def __iter__(self) -> Iterator[BoundingBox]: - """Return the index of a dataset. + def get_chips(self) -> GeoDataFrame: - Returns: - (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset - """ generator: Callable[[int], Iterable[int]] = range if self.shuffle: generator = torch.randperm + chips = [] for idx in generator(len(self)): - yield BoundingBox(*self.hits[idx].bounds) - - def __len__(self) -> int: - """Return the number of samples over the ROI. + minx, maxx, miny, maxy, mint, maxt = self.hits[idx].bounds + chip = { + "geometry": box(minx, miny, maxx, maxy), + "minx": minx, + "miny": miny, + "maxx": maxx, + "maxy": maxy, + "mint": mint, + "maxt": maxt, + } + chips.append(chip) + + if chips: + print("creating geodataframe... ") + chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) + chips_gdf["fid"] = chips_gdf.index + else: + warnings.warn("Sampler has no chips, check your inputs") + chips_gdf = GeoDataFrame() + return chips_gdf - Returns: - number of patches that will be sampled - """ - return len(self.hits) From daab71c357b4ddc49fde6096f0a0b905b0e31c0a Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Wed, 28 Aug 2024 12:00:10 +0200 Subject: [PATCH 03/10] Pass sampler tests --- tests/samplers/test_single.py | 32 +++++++++++++++++++++++++------- torchgeo/samplers/single.py | 34 +++++++++++++--------------------- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 1416368098a..34f0705a358 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -20,14 +20,29 @@ tile_to_chips, ) +from geopandas import GeoDataFrame +from shapely.geometry import box + class CustomGeoSampler(GeoSampler): def __init__(self) -> None: - pass + self.chips = self.get_chips() - def __iter__(self) -> Iterator[BoundingBox]: + def get_chips(self) -> GeoDataFrame: + chips = [] for i in range(len(self)): - yield BoundingBox(i, i, i, i, i, i) + chips.append( + { + "geometry": box(i, i, i, i), + "minx": i, + "miny": i, + "maxx": i, + "maxy": i, + "mint": i, + "maxt": i, + } + ) + return GeoDataFrame(chips, crs="3005") def __len__(self) -> int: return 2 @@ -156,7 +171,7 @@ class TestGridGeoSampler: def dataset(self) -> CustomGeoDataset: ds = CustomGeoDataset() ds.index.insert(0, (0, 100, 200, 300, 400, 500)) - ds.index.insert(1, (0, 100, 200, 300, 400, 500)) + ds.index.insert(1, (0, 100, 200, 300, 500, 600)) return ds @pytest.fixture( @@ -197,13 +212,16 @@ def test_iter(self, sampler: GridGeoSampler) -> None: assert math.isclose(query.maxx - query.minx, sampler.size[1]) assert math.isclose(query.maxy - query.miny, sampler.size[0]) - assert math.isclose( - query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint + assert ( + sampler.roi.mint + <= query.mint + <= query.maxt + <= sampler.roi.maxt ) def test_len(self, sampler: GridGeoSampler) -> None: rows, cols = tile_to_chips(sampler.roi, sampler.size, sampler.stride) - length = rows * cols * 2 # two items in dataset + length = rows * cols * 2 # two spatially but not temporally overlapping items in dataset assert len(sampler) == length def test_roi(self, dataset: CustomGeoDataset) -> None: diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index b4dc4befac7..2d85886db37 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -18,6 +18,7 @@ from shapely.geometry import box import re import pandas as pd +import warnings def _get_regex_groups_as_df(dataset, hits): """ @@ -86,7 +87,7 @@ def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None: self.dataset = dataset @abc.abstractmethod - def get_chips(self) -> GeoDataFrame: + def get_chips(self, **kwargs) -> GeoDataFrame: """Determines the way to get the extend of the chips (samples) of the dataset. Should return a GeoDataFrame with the extend of the chips with the columns geometry, minx, miny, maxx, maxy, mint, maxt, fid. Each row is a chip.""" @@ -212,7 +213,7 @@ def __init__( if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) - self.length = 0 + num_chips = 0 self.hits = [] areas = [] for hit in self.index.intersection(tuple(self.roi), objects=True): @@ -223,25 +224,25 @@ def __init__( ): if bounds.area > 0: rows, cols = tile_to_chips(bounds, self.size) - self.length += rows * cols + num_chips += rows * cols else: - self.length += 1 + num_chips += 1 self.hits.append(hit) areas.append(bounds.area) if length is not None: - self.length = length + num_chips = length + self.length = num_chips # torch.multinomial requires float probabilities > 0 self.areas = torch.tensor(areas, dtype=torch.float) if torch.sum(self.areas) == 0: self.areas += 1 - self.chips = self.get_chips() - + self.chips = self.get_chips(num_samples=num_chips) - def get_chips(self) -> GeoDataFrame: + def get_chips(self, num_samples) -> GeoDataFrame: chips = [] - for _ in tqdm(range(len(self))): + for _ in tqdm(range(num_samples)): # Choose a random tile, weighted by area idx = torch.multinomial(self.areas, 1) hit = self.hits[idx] @@ -275,14 +276,6 @@ def get_chips(self) -> GeoDataFrame: chips_gdf = GeoDataFrame() return chips_gdf - def __len__(self) -> int: - """Return the number of samples in a single epoch. - - Returns: - length of the epoch - """ - return self.length - class GridGeoSampler(GeoSampler): """Samples elements in a grid-like fashion. @@ -360,7 +353,6 @@ def __init__( def get_chips(self) -> GeoDataFrame: print("generating samples... ") optional_keys = ["tile", "date"] - self.length = 0 chips = [] for _, row in tqdm(self.df_path.iterrows(), total=len(self.df_path)): bounds = BoundingBox( @@ -443,9 +435,9 @@ def __init__( self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): - self.hits.append(hit)\ + self.hits.append(hit) - self.chips = get_chips(self) + self.chips = self.get_chips() def get_chips(self) -> GeoDataFrame: @@ -454,7 +446,7 @@ def get_chips(self) -> GeoDataFrame: generator = torch.randperm chips = [] - for idx in generator(len(self)): + for idx in generator(len(self.hits)): minx, maxx, miny, maxy, mint, maxt = self.hits[idx].bounds chip = { "geometry": box(minx, miny, maxx, maxy), From fb85941d8088ba5889f44a73fda8b48ec6fec8fe Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Wed, 28 Aug 2024 16:16:34 +0200 Subject: [PATCH 04/10] 100% test coverage for samplers --- .vscode/settings.json | 7 ++ tests/data/samplers/filtering_4x4.feather | Bin 0 -> 5490 bytes .../samplers/filtering_4x4/filtering_4x4.cpg | 1 + .../samplers/filtering_4x4/filtering_4x4.dbf | Bin 0 -> 78 bytes .../samplers/filtering_4x4/filtering_4x4.prj | 1 + .../samplers/filtering_4x4/filtering_4x4.shp | Bin 0 -> 236 bytes .../samplers/filtering_4x4/filtering_4x4.shx | Bin 0 -> 108 bytes tests/samplers/test_single.py | 116 ++++++++++++++++++ torchgeo/samplers/single.py | 80 ++++++++---- 9 files changed, 184 insertions(+), 21 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 tests/data/samplers/filtering_4x4.feather create mode 100644 tests/data/samplers/filtering_4x4/filtering_4x4.cpg create mode 100644 tests/data/samplers/filtering_4x4/filtering_4x4.dbf create mode 100644 tests/data/samplers/filtering_4x4/filtering_4x4.prj create mode 100644 tests/data/samplers/filtering_4x4/filtering_4x4.shp create mode 100644 tests/data/samplers/filtering_4x4/filtering_4x4.shx diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000000..ad7af29f625 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/tests/data/samplers/filtering_4x4.feather b/tests/data/samplers/filtering_4x4.feather new file mode 100644 index 0000000000000000000000000000000000000000..305d37e4fa6244002c3a17f577da8866f1a340c8 GIT binary patch literal 5490 zcmeHL&2HO95MC#;;{9<$d<^lp+(9+v`FLXq{YtoVXQ-lq zP@_&Sjq^d`Y;lKh1F6->`I0*pni2T?5{`vTdTAIXeH?iE)pN5~*?mv4-mx4AV2-sO zW?L&OLzSMmsp_v-RJL;3bT}eTp?;il3h=GB>Zoddx ze1$*X=IyvEi7g?!B9on7A;Q#7U_+tvfzhHQzAMvE>ZrFbkE9}PmP$;d5HVO5e7--Q zn5xv06QfAyp;1tDyGou$1R;BqonHx!(_y5@+663u@_^<^q_HWKeI=zsVqHf)P5L*J zw_cqhFGXf}UZs8rwF{$)Tpy<@(j&0n**n;UI23h=fmMHfvoMh{rIq?pvOIxHLY$yj zK2E3IWTS@fyvs$*+e!-Tn}O?Q;QNN_P#{<|LT`&#%uSXwX5i@&KN4{!IVsSOOjMtW zI5(oN5wk0L&&ECod=nR!YQ8G%{vn1>J%9ZcERDt$mK7DfO!G{OBor#-ut^+C6Nsg* zYt36QZo$;oGxn44Ba8$0HfCO(X1znvN_UhCeWO>srWqA=kweXOn<22sv+L6=U*rCL zglE;b|B>6qczifhhS?TM%PbNmgx}9Eg=KJmu4Wa$K6vM?$WuNpwVsLIdnu|QXcU{l z=&fdR4Mv|7whq#&xzWyE8jt#vw+;#2?WSkA=1rFjZeFXH(d*D0M=CF}rsonK!uKet z=v2|Mb0?0viYhG4!D)!=Vk%bUT+{Vd-Bo8irqWDRyl!ea>NOtJ{v~?sNyh$MK;4d# zP@WZ4yU0R65)53FNcz=EmgI3vGSN*{NaYBdOs7MddS)ukbPQc>z3&*}))bQqSWr2N zWE^&if>*$-WX62pPsz=sX`+|pgW+g`NV}-#un)OPP)$aYFx^Xh`{w1-fDCad%#{>d z?M|j-nYa=3uSnD#Q4j+@gFMj5vcKMk!;w%b{Rz}69`x2-az4sq$-9E|UHUk1F1e09 zq^CjeW%BAh(~&ibK5n4~UvqlUzhS`trN&s-%*XbHztbADFzW14(HYhaMRvEwJ}3Bh z25Vn+f;3EI**(Bv4`@KHa+EyKC*Cw37YJT_2bGPWRiy_M^uVY&UqAZc*)O#G8OuYw zFv>}R|AIh&FF*hgTBP?UPUD%wK3iss_wTSh0@{K87&F+dPMRzcr#h|B=fYZ$>xr3J zQrtAojQs&LjjDanU(fsKc^^IRqbE6^Qu(}(E}PkTA6@c9b{MY&~mj?Nx literal 0 HcmV?d00001 diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.cpg b/tests/data/samplers/filtering_4x4/filtering_4x4.cpg new file mode 100644 index 00000000000..57decb48120 --- /dev/null +++ b/tests/data/samplers/filtering_4x4/filtering_4x4.cpg @@ -0,0 +1 @@ +ISO-8859-1 diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.dbf b/tests/data/samplers/filtering_4x4/filtering_4x4.dbf new file mode 100644 index 0000000000000000000000000000000000000000..499d67bcec48f8473adebc8ab148bcb05ea6894d GIT binary patch literal 78 mcmZRsVV7WJU|?`$-~p1Dz|GSICg=xZaKm^|npXh<45R>p9Rtt+ literal 0 HcmV?d00001 diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.prj b/tests/data/samplers/filtering_4x4/filtering_4x4.prj new file mode 100644 index 00000000000..42fd4b91b78 --- /dev/null +++ b/tests/data/samplers/filtering_4x4/filtering_4x4.prj @@ -0,0 +1 @@ +PROJCS["NAD_1983_BC_Environment_Albers",GEOGCS["GCS_North_American_1983",DATUM["D_North_American_1983",SPHEROID["GRS_1980",6378137.0,298.257222101]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Albers"],PARAMETER["False_Easting",1000000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",-126.0],PARAMETER["Standard_Parallel_1",50.0],PARAMETER["Standard_Parallel_2",58.5],PARAMETER["Latitude_Of_Origin",45.0],UNIT["Meter",1.0]] diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.shp b/tests/data/samplers/filtering_4x4/filtering_4x4.shp new file mode 100644 index 0000000000000000000000000000000000000000..65606c26dd6675aa22232af31613c0d39433b9db GIT binary patch literal 236 zcmZQzQ0HR64$59IGcd4Xmjj9lI6$OeG){#e2}U4xAjT|^Lfq;=M!^8gUR*Rx9fAe` DP2vP( literal 0 HcmV?d00001 diff --git a/tests/data/samplers/filtering_4x4/filtering_4x4.shx b/tests/data/samplers/filtering_4x4/filtering_4x4.shx new file mode 100644 index 0000000000000000000000000000000000000000..b2028e759e5a7509214f94701bfbbb3e3bb83d69 GIT binary patch literal 108 lcmZQzQ0HR64$NLKGcd4Xmjj9lI6$OeG){#e2_qnO005390%`yN literal 0 HcmV?d00001 diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 34f0705a358..10634624227 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import os import math from collections.abc import Iterator from itertools import product @@ -20,6 +21,7 @@ tile_to_chips, ) +import geopandas as gpd from geopandas import GeoDataFrame from shapely.geometry import box @@ -58,6 +60,16 @@ def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: return {'index': query} +class CustomGeoDatasetSITS(GeoDataset): + def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None: + super().__init__() + self._crs = crs + self.res = res + self.return_as_ts=True + + def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: + return {'index': query} + class TestGeoSampler: @pytest.fixture(scope='class') def dataset(self) -> CustomGeoDataset: @@ -69,6 +81,14 @@ def dataset(self) -> CustomGeoDataset: def sampler(self) -> CustomGeoSampler: return CustomGeoSampler() + @pytest.fixture(scope='class') + def datadir(self) -> str: + return os.path.join('tests', 'data', 'samplers') + + def test_no_get_chips_implemented(self) -> None: + with pytest.raises(TypeError): + GeoSampler() + def test_iter(self, sampler: CustomGeoSampler) -> None: assert next(iter(sampler)) == BoundingBox(0, 0, 0, 0, 0, 0) @@ -79,6 +99,69 @@ def test_abstract(self, dataset: CustomGeoDataset) -> None: with pytest.raises(TypeError, match="Can't instantiate abstract class"): GeoSampler(dataset) # type: ignore[abstract] + @pytest.mark.parametrize( + "filtering_file", + [ + "filtering_4x4", + "filtering_4x4.feather", + ], + ) + def test_filtering_from_path(self, datadir, filtering_file) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler( + ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10) + ) + iterator = iter(sampler) + + assert len(sampler) == 4 + filtering_path = os.path.join(datadir, filtering_file) + sampler.filter_chips(filtering_path, "intersects", "drop") + assert len(sampler) == 3 + assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) + + def test_filtering_from_gdf(self, datadir) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler( + ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10) + ) + iterator = iter(sampler) + + # Dropping first chip + assert len(sampler) == 4 + filtering_gdf = gpd.read_file( + os.path.join(datadir, "filtering_4x4") + ) + sampler.filter_chips(filtering_gdf, "intersects", "drop") + assert len(sampler) == 3 + assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) + + # Keeping only first chip + sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) + iterator = iter(sampler) + sampler.filter_chips(filtering_gdf, "intersects", "keep") + assert len(sampler) == 1 + assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10) + + def test_set_worker_split(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler( + ds, 5, 5, units=Units.CRS, roi=BoundingBox(0, 10, 0, 10, 0, 10) + ) + assert len(sampler) == 4 + sampler.set_worker_split(total_workers=4, worker_num=1) + assert len(sampler) == 1 + + def test_save_chips(self, tmpdir_factory) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) + sampler.save(str(tmpdir_factory.mktemp("out").join("chips"))) + sampler.save(str(tmpdir_factory.mktemp("out").join("chips.feather"))) + + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( @@ -154,6 +237,15 @@ def test_weighted_sampling(self) -> None: for bbox in sampler: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) + def test_return_as_ts(self) -> None: + ds = CustomGeoDatasetSITS() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + ds.index.insert(1, (0, 10, 0, 10, 15, 20)) + sampler = RandomGeoSampler(ds, 1, 5) + for query in sampler: + assert query.mint == ds.bounds.mint == 0 + assert query.maxt == ds.bounds.maxt == 20 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( @@ -260,6 +352,30 @@ def test_float_multiple(self) -> None: assert len(sampler) == 2 assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10) assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) + + def test_dataset_has_regex(self) -> None: + ds = CustomGeoDataset() + ds.filename_regex = r'.*(?Ptest)' + ds.index.insert(0, (0, 10, 0, 10, 0, 10), "filepath_containing_key_test") + sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS) + assert "my_key" in sampler.chips.columns + + def test_dataset_has_regex_no_match(self) -> None: + ds = CustomGeoDataset() + ds.filename_regex = r'(?Ptest)' + ds.index.insert(0, (0, 10, 0, 10, 0, 10), "no_matching_key") + sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS) + assert "my_key" not in sampler.chips.columns + + def test_return_as_ts(self) -> None: + ds = CustomGeoDatasetSITS() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + ds.index.insert(1, (0, 10, 0, 10, 15, 20)) + sampler = GridGeoSampler(ds, 1, 1) + for query in sampler: + assert query.mint == ds.bounds.mint == 0 + assert query.maxt == ds.bounds.maxt == 20 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 2d85886db37..8bc6f628a6d 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -19,6 +19,31 @@ import re import pandas as pd import warnings +import geopandas as gpd +import numpy as np + +def load_file(path: str|GeoDataFrame) -> GeoDataFrame: + """ + Load a file from the given path. + + Parameters: + path (str or GeoDataFrame): The path to the file or a GeoDataFrame object. + + Returns: + GeoDataFrame: The loaded file as a GeoDataFrame. + + Raises: + None + + """ + if isinstance(path, GeoDataFrame): + return path + if path.endswith(".feather"): + print(f"Reading feather file: {path}") + return gpd.read_feather(path) + else: + print(f"Reading shapefile: {path}") + return gpd.read_file(path) def _get_regex_groups_as_df(dataset, hits): """ @@ -40,6 +65,8 @@ def _get_regex_groups_as_df(dataset, hits): match = re.match(filename_regex, str(hit.object)) if match: meta = match.groupdict() + else: + meta = {} else: meta = {} meta.update( @@ -91,14 +118,13 @@ def get_chips(self, **kwargs) -> GeoDataFrame: """Determines the way to get the extend of the chips (samples) of the dataset. Should return a GeoDataFrame with the extend of the chips with the columns geometry, minx, miny, maxx, maxy, mint, maxt, fid. Each row is a chip.""" - raise NotImplementedError def filter_chips( self, filter_by: str | GeoDataFrame, predicate: str = "intersects", - action: str = "keep", + action: str = "keep" , ) -> None: """Filter the default set of chips in the sampler down to a specific subset by specifying files supported by geopandas such as shapefiles, geodatabases or @@ -112,9 +138,28 @@ def filter_chips( """ prefilter_leng = len(self.chips) filtering_gdf = load_file(filter_by).to_crs(self.dataset.crs) - self.chips = filter_tiles( - self.chips, filtering_gdf, predicate, action - ).reset_index(drop=True) + + if action == "keep": + self.chips = self.chips.iloc[ + list( + set( + self.chips.sindex.query_bulk( + filtering_gdf.geometry, predicate=predicate + )[1] + ) + ) + ].reset_index(drop=True) + elif action == "drop": + self.chips = self.chips.drop( + index=list( + set( + self.chips.sindex.query_bulk( + filtering_gdf.geometry, predicate=predicate + )[1] + ) + ) + ).reset_index(drop=True) + self.chips.fid = self.chips.index print(f"Filter step reduced chips from {prefilter_leng} to {len(self.chips)}") assert not self.chips.empty, "No chips left after filtering!" @@ -160,7 +205,7 @@ def __len__(self) -> int: return len(self.chips) class RandomGeoSampler(GeoSampler): - """Differs from TrochGeo's RandomGeoSampler in that it can sample SITS data. + """Differs from TorchGeo's official RandomGeoSampler in that it can sample SITS data. Documentation from TorchGeo: Samples elements from a region of interest randomly. @@ -266,14 +311,10 @@ def get_chips(self, num_samples) -> GeoDataFrame: } chips.append(chip) - if chips: - print("creating geodataframe... ") - chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) - chips_gdf["fid"] = chips_gdf.index + print("creating geodataframe... ") + chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) + chips_gdf["fid"] = chips_gdf.index - else: - warnings.warn("Sampler has no chips, check your inputs") - chips_gdf = GeoDataFrame() return chips_gdf @@ -352,7 +393,7 @@ def __init__( def get_chips(self) -> GeoDataFrame: print("generating samples... ") - optional_keys = ["tile", "date"] + optional_keys = set(self.df_path.keys()) - set(['geometry', 'minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt']) chips = [] for _, row in tqdm(self.df_path.iterrows(), total=len(self.df_path)): bounds = BoundingBox( @@ -459,12 +500,9 @@ def get_chips(self) -> GeoDataFrame: } chips.append(chip) - if chips: - print("creating geodataframe... ") - chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) - chips_gdf["fid"] = chips_gdf.index - else: - warnings.warn("Sampler has no chips, check your inputs") - chips_gdf = GeoDataFrame() + print("creating geodataframe... ") + chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) + chips_gdf["fid"] = chips_gdf.index + return chips_gdf From 889735c8a4d260793075a6537f4159c19304e2fa Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Wed, 28 Aug 2024 17:00:38 +0200 Subject: [PATCH 05/10] Run pre-commit (tg ruleset) WIP --- .vscode/settings.json | 12 ++- tests/samplers/test_single.py | 80 ++++++++--------- torchgeo/datasets/geo.py | 2 +- torchgeo/samplers/single.py | 160 +++++++++++++++++----------------- 4 files changed, 120 insertions(+), 134 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index ad7af29f625..d969f962b02 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,7 +1,5 @@ -{ - "python.testing.pytestArgs": [ - "tests" - ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true -} \ No newline at end of file +{ + "python.testing.pytestArgs": ["tests"], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 10634624227..7cf54f69000 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -1,14 +1,16 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import os import math -from collections.abc import Iterator +import os from itertools import product +import geopandas as gpd import pytest from _pytest.fixtures import SubRequest +from geopandas import GeoDataFrame from rasterio.crs import CRS +from shapely.geometry import box from torch.utils.data import DataLoader from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples @@ -21,10 +23,6 @@ tile_to_chips, ) -import geopandas as gpd -from geopandas import GeoDataFrame -from shapely.geometry import box - class CustomGeoSampler(GeoSampler): def __init__(self) -> None: @@ -35,16 +33,16 @@ def get_chips(self) -> GeoDataFrame: for i in range(len(self)): chips.append( { - "geometry": box(i, i, i, i), - "minx": i, - "miny": i, - "maxx": i, - "maxy": i, - "mint": i, - "maxt": i, + 'geometry': box(i, i, i, i), + 'minx': i, + 'miny': i, + 'maxx': i, + 'maxy': i, + 'mint': i, + 'maxt': i, } ) - return GeoDataFrame(chips, crs="3005") + return GeoDataFrame(chips, crs='3005') def __len__(self) -> int: return 2 @@ -65,11 +63,12 @@ def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None: super().__init__() self._crs = crs self.res = res - self.return_as_ts=True + self.return_as_ts = True def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]: return {'index': query} + class TestGeoSampler: @pytest.fixture(scope='class') def dataset(self) -> CustomGeoDataset: @@ -85,9 +84,9 @@ def sampler(self) -> CustomGeoSampler: def datadir(self) -> str: return os.path.join('tests', 'data', 'samplers') - def test_no_get_chips_implemented(self) -> None: + def test_no_get_chips_implemented(self, dataset: CustomGeoDataset) -> None: with pytest.raises(TypeError): - GeoSampler() + GeoSampler(dataset) def test_iter(self, sampler: CustomGeoSampler) -> None: assert next(iter(sampler)) == BoundingBox(0, 0, 0, 0, 0, 0) @@ -100,13 +99,9 @@ def test_abstract(self, dataset: CustomGeoDataset) -> None: GeoSampler(dataset) # type: ignore[abstract] @pytest.mark.parametrize( - "filtering_file", - [ - "filtering_4x4", - "filtering_4x4.feather", - ], + 'filtering_file', ['filtering_4x4', 'filtering_4x4.feather'] ) - def test_filtering_from_path(self, datadir, filtering_file) -> None: + def test_filtering_from_path(self, datadir: str, filtering_file: str) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) sampler = GridGeoSampler( @@ -116,11 +111,11 @@ def test_filtering_from_path(self, datadir, filtering_file) -> None: assert len(sampler) == 4 filtering_path = os.path.join(datadir, filtering_file) - sampler.filter_chips(filtering_path, "intersects", "drop") + sampler.filter_chips(filtering_path, 'intersects', 'drop') assert len(sampler) == 3 assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) - def test_filtering_from_gdf(self, datadir) -> None: + def test_filtering_from_gdf(self, datadir: str) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) sampler = GridGeoSampler( @@ -130,17 +125,15 @@ def test_filtering_from_gdf(self, datadir) -> None: # Dropping first chip assert len(sampler) == 4 - filtering_gdf = gpd.read_file( - os.path.join(datadir, "filtering_4x4") - ) - sampler.filter_chips(filtering_gdf, "intersects", "drop") + filtering_gdf = gpd.read_file(os.path.join(datadir, 'filtering_4x4')) + sampler.filter_chips(filtering_gdf, 'intersects', 'drop') assert len(sampler) == 3 assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) # Keeping only first chip sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) iterator = iter(sampler) - sampler.filter_chips(filtering_gdf, "intersects", "keep") + sampler.filter_chips(filtering_gdf, 'intersects', 'keep') assert len(sampler) == 1 assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10) @@ -158,9 +151,8 @@ def test_save_chips(self, tmpdir_factory) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS) - sampler.save(str(tmpdir_factory.mktemp("out").join("chips"))) - sampler.save(str(tmpdir_factory.mktemp("out").join("chips.feather"))) - + sampler.save(str(tmpdir_factory.mktemp('out').join('chips'))) + sampler.save(str(tmpdir_factory.mktemp('out').join('chips.feather'))) @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) @@ -304,16 +296,13 @@ def test_iter(self, sampler: GridGeoSampler) -> None: assert math.isclose(query.maxx - query.minx, sampler.size[1]) assert math.isclose(query.maxy - query.miny, sampler.size[0]) - assert ( - sampler.roi.mint - <= query.mint - <= query.maxt - <= sampler.roi.maxt - ) + assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt def test_len(self, sampler: GridGeoSampler) -> None: rows, cols = tile_to_chips(sampler.roi, sampler.size, sampler.stride) - length = rows * cols * 2 # two spatially but not temporally overlapping items in dataset + length = ( + rows * cols * 2 + ) # two spatially but not temporally overlapping items in dataset assert len(sampler) == length def test_roi(self, dataset: CustomGeoDataset) -> None: @@ -352,20 +341,20 @@ def test_float_multiple(self) -> None: assert len(sampler) == 2 assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10) assert next(iterator) == BoundingBox(5, 10, 0, 5, 0, 10) - + def test_dataset_has_regex(self) -> None: ds = CustomGeoDataset() ds.filename_regex = r'.*(?Ptest)' - ds.index.insert(0, (0, 10, 0, 10, 0, 10), "filepath_containing_key_test") + ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'filepath_containing_key_test') sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS) - assert "my_key" in sampler.chips.columns + assert 'my_key' in sampler.chips.columns def test_dataset_has_regex_no_match(self) -> None: ds = CustomGeoDataset() ds.filename_regex = r'(?Ptest)' - ds.index.insert(0, (0, 10, 0, 10, 0, 10), "no_matching_key") + ds.index.insert(0, (0, 10, 0, 10, 0, 10), 'no_matching_key') sampler = GridGeoSampler(ds, 1, 2, units=Units.CRS) - assert "my_key" not in sampler.chips.columns + assert 'my_key' not in sampler.chips.columns def test_return_as_ts(self) -> None: ds = CustomGeoDatasetSITS() @@ -376,7 +365,6 @@ def test_return_as_ts(self) -> None: assert query.mint == ds.bounds.mint == 0 assert query.maxt == ds.bounds.maxt == 20 - @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index f0d217ea640..5d5175c63b4 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -96,7 +96,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): #: other datasets. It should not include a file extension, as the dataset may be in #: a different file format than what it was originally downloaded as. filename_glob = '*' - + # Whether to return the dataset as a Timeseries, this will add another dimension to the dataset return_as_ts = False diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 8bc6f628a6d..2924edad7f4 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -4,27 +4,28 @@ """TorchGeo samplers.""" import abc +import re +import warnings from collections.abc import Callable, Iterable, Iterator +from typing import Any +import geopandas as gpd +import numpy as np +import pandas as pd import torch +from geopandas import GeoDataFrame from rtree.index import Index, Property +from shapely.geometry import box from torch.utils.data import Sampler +from tqdm import tqdm from ..datasets import BoundingBox, GeoDataset from .constants import Units from .utils import _to_tuple, get_random_bounding_box, tile_to_chips -from geopandas import GeoDataFrame -from tqdm import tqdm -from shapely.geometry import box -import re -import pandas as pd -import warnings -import geopandas as gpd -import numpy as np -def load_file(path: str|GeoDataFrame) -> GeoDataFrame: - """ - Load a file from the given path. + +def load_file(path: str | GeoDataFrame) -> GeoDataFrame: + """Load a file from the given path. Parameters: path (str or GeoDataFrame): The path to the file or a GeoDataFrame object. @@ -38,16 +39,16 @@ def load_file(path: str|GeoDataFrame) -> GeoDataFrame: """ if isinstance(path, GeoDataFrame): return path - if path.endswith(".feather"): - print(f"Reading feather file: {path}") + if path.endswith('.feather'): + print(f'Reading feather file: {path}') return gpd.read_feather(path) else: - print(f"Reading shapefile: {path}") + print(f'Reading shapefile: {path}') return gpd.read_file(path) -def _get_regex_groups_as_df(dataset, hits): - """ - Extracts the regex metadata from a list of hits. + +def _get_regex_groups_as_df(dataset: GeoDataset, hits: list) -> pd.DataFrame: + """Extracts the regex metadata from a list of hits. Args: dataset (GeoDataset): The dataset to sample from. @@ -56,7 +57,7 @@ def _get_regex_groups_as_df(dataset, hits): Returns: pandas.DataFrame: A DataFrame containing the extracted file metadata. """ - has_filename_regex = bool(getattr(dataset, "filename_regex", None)) + has_filename_regex = hasattr(dataset, 'filename_regex') if has_filename_regex: filename_regex = re.compile(dataset.filename_regex, re.VERBOSE) file_metadata = [] @@ -71,17 +72,18 @@ def _get_regex_groups_as_df(dataset, hits): meta = {} meta.update( { - "minx": hit.bounds[0], - "maxx": hit.bounds[1], - "miny": hit.bounds[2], - "maxy": hit.bounds[3], - "mint": hit.bounds[4], - "maxt": hit.bounds[5], + 'minx': hit.bounds[0], + 'maxx': hit.bounds[1], + 'miny': hit.bounds[2], + 'maxy': hit.bounds[3], + 'mint': hit.bounds[4], + 'maxt': hit.bounds[5], } ) file_metadata.append(meta) return pd.DataFrame(file_metadata) + class GeoSampler(Sampler[BoundingBox], abc.ABC): """Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`. @@ -112,23 +114,23 @@ def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None: self.res = dataset.res self.roi = roi self.dataset = dataset - - @abc.abstractmethod - def get_chips(self, **kwargs) -> GeoDataFrame: + + @abc.abstractmethod + def get_chips(self, *args: Any, **kwargs: Any) -> GeoDataFrame: """Determines the way to get the extend of the chips (samples) of the dataset. - Should return a GeoDataFrame with the extend of the chips with the columns - geometry, minx, miny, maxx, maxy, mint, maxt, fid. Each row is a chip.""" + Should return a GeoDataFrame with the extend of the chips with the columns + geometry, minx, miny, maxx, maxy, mint, maxt, fid. Each row is a chip. + """ def filter_chips( self, filter_by: str | GeoDataFrame, - predicate: str = "intersects", - action: str = "keep" , + predicate: str = 'intersects', + action: str = 'keep', ) -> None: - """Filter the default set of chips in the sampler down to a specific subset by - specifying files supported by geopandas such as shapefiles, geodatabases or - feather files. + """Filter the default set of chips in the sampler down to a specific subset by specifying files + supported by geopandas such as shapefiles, geodatabases or feather files. Args: filter_by: The file or geodataframe for which the geometries will be used during filtering @@ -139,7 +141,7 @@ def filter_chips( prefilter_leng = len(self.chips) filtering_gdf = load_file(filter_by).to_crs(self.dataset.crs) - if action == "keep": + if action == 'keep': self.chips = self.chips.iloc[ list( set( @@ -149,7 +151,7 @@ def filter_chips( ) ) ].reset_index(drop=True) - elif action == "drop": + elif action == 'drop': self.chips = self.chips.drop( index=list( set( @@ -161,8 +163,8 @@ def filter_chips( ).reset_index(drop=True) self.chips.fid = self.chips.index - print(f"Filter step reduced chips from {prefilter_leng} to {len(self.chips)}") - assert not self.chips.empty, "No chips left after filtering!" + print(f'Filter step reduced chips from {prefilter_leng} to {len(self.chips)}') + assert not self.chips.empty, 'No chips left after filtering!' def set_worker_split(self, total_workers: int, worker_num: int) -> None: """Splits the chips in n equal parts for the number of workers and keeps the set of @@ -176,11 +178,9 @@ def set_worker_split(self, total_workers: int, worker_num: int) -> None: """ self.chips = np.array_split(self.chips, total_workers)[worker_num] - def save(self, - path: str, - driver: str = None) -> None: + def save(self, path: str, driver: str) -> None: """Save the chips as a shapefile or feather file""" - if path.endswith(".feather"): + if path.endswith('.feather'): self.chips.to_feather(path) else: self.chips.to_file(path, driver=driver) @@ -204,6 +204,7 @@ def __len__(self) -> int: """ return len(self.chips) + class RandomGeoSampler(GeoSampler): """Differs from TorchGeo's official RandomGeoSampler in that it can sample SITS data. @@ -301,19 +302,19 @@ def get_chips(self, num_samples) -> GeoDataFrame: bbox = get_random_bounding_box(bounds, self.size, self.res) minx, maxx, miny, maxy, mint, maxt = tuple(bbox) chip = { - "geometry": box(minx, miny, maxx, maxy), - "minx": minx, - "miny": miny, - "maxx": maxx, - "maxy": maxy, - "mint": mint, - "maxt": maxt, + 'geometry': box(minx, miny, maxx, maxy), + 'minx': minx, + 'miny': miny, + 'maxx': maxx, + 'maxy': maxy, + 'mint': mint, + 'maxt': maxt, } chips.append(chip) - - print("creating geodataframe... ") + + print('creating geodataframe... ') chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) - chips_gdf["fid"] = chips_gdf.index + chips_gdf['fid'] = chips_gdf.index return chips_gdf @@ -381,19 +382,20 @@ def __init__( # Filter out hits in the index that share the same extent if self.dataset.return_as_ts: self.df_path.drop_duplicates( - subset=["minx", "maxx", "miny", "maxy"], inplace=True + subset=['minx', 'maxx', 'miny', 'maxy'], inplace=True ) else: self.df_path.drop_duplicates( - subset=["minx", "maxx", "miny", "maxy", "mint", "maxt"], inplace=True + subset=['minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt'], inplace=True ) - - self.chips = self.get_chips() + self.chips = self.get_chips() def get_chips(self) -> GeoDataFrame: - print("generating samples... ") - optional_keys = set(self.df_path.keys()) - set(['geometry', 'minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt']) + print('generating samples... ') + optional_keys = set(self.df_path.keys()) - set( + ['geometry', 'minx', 'maxx', 'miny', 'maxy', 'mint', 'maxt'] + ) chips = [] for _, row in tqdm(self.df_path.iterrows(), total=len(self.df_path)): bounds = BoundingBox( @@ -419,13 +421,13 @@ def get_chips(self) -> GeoDataFrame: maxt = bounds.maxt chip = { - "geometry": box(minx, miny, maxx, maxy), - "minx": minx, - "miny": miny, - "maxx": maxx, - "maxy": maxy, - "mint": mint, - "maxt": maxt, + 'geometry': box(minx, miny, maxx, maxy), + 'minx': minx, + 'miny': miny, + 'maxx': maxx, + 'maxy': maxy, + 'mint': mint, + 'maxt': maxt, } for key in optional_keys: if key in row.keys(): @@ -434,12 +436,12 @@ def get_chips(self) -> GeoDataFrame: chips.append(chip) if chips: - print("creating geodataframe... ") + print('creating geodataframe... ') chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) - chips_gdf["fid"] = chips_gdf.index + chips_gdf['fid'] = chips_gdf.index else: - warnings.warn("Sampler has no chips, check your inputs") + warnings.warn('Sampler has no chips, check your inputs') chips_gdf = GeoDataFrame() return chips_gdf @@ -477,11 +479,10 @@ def __init__( self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): self.hits.append(hit) - + self.chips = self.get_chips() def get_chips(self) -> GeoDataFrame: - generator: Callable[[int], Iterable[int]] = range if self.shuffle: generator = torch.randperm @@ -490,19 +491,18 @@ def get_chips(self) -> GeoDataFrame: for idx in generator(len(self.hits)): minx, maxx, miny, maxy, mint, maxt = self.hits[idx].bounds chip = { - "geometry": box(minx, miny, maxx, maxy), - "minx": minx, - "miny": miny, - "maxx": maxx, - "maxy": maxy, - "mint": mint, - "maxt": maxt, + 'geometry': box(minx, miny, maxx, maxy), + 'minx': minx, + 'miny': miny, + 'maxx': maxx, + 'maxy': maxy, + 'mint': mint, + 'maxt': maxt, } chips.append(chip) - print("creating geodataframe... ") + print('creating geodataframe... ') chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) - chips_gdf["fid"] = chips_gdf.index + chips_gdf['fid'] = chips_gdf.index return chips_gdf - From 3e312bec55504db70e1a3b0dddafd7c7fd725874 Mon Sep 17 00:00:00 2001 From: stypoumic Date: Thu, 5 Sep 2024 14:57:41 +0000 Subject: [PATCH 06/10] Allow for multiple ROIs in RandomGeoSampler and filtering of hits --- torchgeo/samplers/single.py | 67 +++++++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 2924edad7f4..374478da610 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -106,10 +106,17 @@ def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None: roi = BoundingBox(*self.index.bounds) else: self.index = Index(interleaved=False, properties=Property(dimension=3)) - hits = dataset.index.intersection(tuple(roi), objects=True) - for hit in hits: - bbox = BoundingBox(*hit.bounds) & roi - self.index.insert(hit.id, tuple(bbox), hit.object) + if isinstance(roi, list): + for area in roi: + hits = dataset.index.intersection(tuple(area), objects=True) + for hit in hits: + bbox = BoundingBox(*hit.bounds) & area + self.index.insert(hit.id, tuple(bbox), hit.object) + else: + hits = dataset.index.intersection(tuple(roi), objects=True) + for hit in hits: + bbox = BoundingBox(*hit.bounds) & roi + self.index.insert(hit.id, tuple(bbox), hit.object) self.res = dataset.res self.roi = roi @@ -262,23 +269,48 @@ def __init__( num_chips = 0 self.hits = [] areas = [] - for hit in self.index.intersection(tuple(self.roi), objects=True): - bounds = BoundingBox(*hit.bounds) - if ( - bounds.maxx - bounds.minx >= self.size[1] - and bounds.maxy - bounds.miny >= self.size[0] - ): - if bounds.area > 0: - rows, cols = tile_to_chips(bounds, self.size) - num_chips += rows * cols - else: - num_chips += 1 - self.hits.append(hit) - areas.append(bounds.area) + # use only geospatial bounds for filering in case return_as_ts is set to True + filter_indices = slice(0,4) if dataset.return_as_ts else slice(0,6) + + if isinstance(self.roi, list): + for area in self.roi: + for hit in self.index.intersection(tuple(area), objects=True): + # Filter out hits in the index that share the same extent + if hit.bounds[filter_indices] not in [ht.bounds[filter_indices] for ht in self.hits]: + bounds = BoundingBox(*hit.bounds) + if ( + bounds.maxx - bounds.minx >= self.size[1] + and bounds.maxy - bounds.miny >= self.size[0] + ): + if bounds.area > 0: + rows, cols = tile_to_chips(bounds, self.size) + num_chips += rows * cols + else: + num_chips += 1 + self.hits.append(hit) + areas.append(bounds.area) + else: + for hit in self.index.intersection(tuple(self.roi), objects=True): + # Filter out hits in the index that share the same extent + if hit.bounds[filter_indices] not in [ht.bounds[filter_indices] for ht in self.hits]: + bounds = BoundingBox(*hit.bounds) + if ( + bounds.maxx - bounds.minx >= self.size[1] + and bounds.maxy - bounds.miny >= self.size[0] + ): + if bounds.area > 0: + rows, cols = tile_to_chips(bounds, self.size) + num_chips += rows * cols + else: + num_chips += 1 + self.hits.append(hit) + areas.append(bounds.area) if length is not None: num_chips = length self.length = num_chips + print(f"Unique geospatial file hits: {len(self.hits)}") + # torch.multinomial requires float probabilities > 0 self.areas = torch.tensor(areas, dtype=torch.float) if torch.sum(self.areas) == 0: @@ -309,6 +341,7 @@ def get_chips(self, num_samples) -> GeoDataFrame: 'maxy': maxy, 'mint': mint, 'maxt': maxt, + 'hit_id': hit.id } chips.append(chip) From ad748492d3bbd0eeb8aeebbd5983ddf5032c248d Mon Sep 17 00:00:00 2001 From: stypoumic Date: Thu, 5 Sep 2024 14:59:04 +0000 Subject: [PATCH 07/10] Add function for saving hits of GeoSampler --- torchgeo/samplers/single.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 374478da610..bf03caecdf1 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -191,6 +191,30 @@ def save(self, path: str, driver: str) -> None: self.chips.to_feather(path) else: self.chips.to_file(path, driver=driver) + + def save_hits(self, path: str, driver: str) -> None: + """Save the hits as a shapefile or feather file""" + bounds = [] + for hit in self.hits: + bbox = BoundingBox(*hit.bounds) + minx, maxx, miny, maxy, mint, maxt = tuple(bbox) + bound = { + 'geometry': box(minx, miny, maxx, maxy), + 'minx': minx, + 'miny': miny, + 'maxx': maxx, + 'maxy': maxy, + 'mint': mint, + 'maxt': maxt, + 'hit_id': hit.id + } + bounds.append(bound) + + bounds_gdf = GeoDataFrame(bounds, crs=self.dataset.crs) + if path.endswith('.feather'): + bounds_gdf.to_feather(path) + else: + bounds_gdf.to_file(path, driver=driver) def __iter__(self) -> Iterator[BoundingBox]: """Return the index of a dataset. From 40bb16e2179b29cb50c6a7818a06b4f3989d34f5 Mon Sep 17 00:00:00 2001 From: stypoumic Date: Tue, 10 Sep 2024 09:46:31 +0000 Subject: [PATCH 08/10] Add ROI-based random sampling instead of hit-based --- torchgeo/samplers/single.py | 101 +++++++++++++++++++++++++----------- 1 file changed, 70 insertions(+), 31 deletions(-) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index bf03caecdf1..5b7a5fad401 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -4,6 +4,7 @@ """TorchGeo samplers.""" import abc +import random import re import warnings from collections.abc import Callable, Iterable, Iterator @@ -105,19 +106,36 @@ def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None: self.index = dataset.index roi = BoundingBox(*self.index.bounds) else: + # Only keep hits unique in the spatial dimension if return_as_ts is enabled + # else keep hits that are unique in both spatial and temporal dimensions + filter_indices = slice(0, 4) if dataset.return_as_ts else slice(0, 6) + self.index = Index(interleaved=False, properties=Property(dimension=3)) if isinstance(roi, list): for area in roi: hits = dataset.index.intersection(tuple(area), objects=True) for hit in hits: bbox = BoundingBox(*hit.bounds) & area - self.index.insert(hit.id, tuple(bbox), hit.object) + # Filter hits + if tuple(bbox)[filter_indices] not in [ + tuple(item.bounds[filter_indices]) + for item in list( + self.index.intersection(tuple(area), objects=True) + ) + ]: + self.index.insert(hit.id, tuple(bbox), hit.object) else: hits = dataset.index.intersection(tuple(roi), objects=True) for hit in hits: bbox = BoundingBox(*hit.bounds) & roi - self.index.insert(hit.id, tuple(bbox), hit.object) + # Filter hits + if tuple(bbox)[filter_indices] not in [ + tuple(item.bounds[filter_indices]) + for item in list(self.index.intersection(tuple(roi), objects=True)) + ]: + self.index.insert(hit.id, tuple(bbox), hit.object) + print(f"Index Size: {self.index.get_size()}") self.res = dataset.res self.roi = roi self.dataset = dataset @@ -185,14 +203,14 @@ def set_worker_split(self, total_workers: int, worker_num: int) -> None: """ self.chips = np.array_split(self.chips, total_workers)[worker_num] - def save(self, path: str, driver: str) -> None: + def save(self, path: str, driver: str = "Shapefile") -> None: """Save the chips as a shapefile or feather file""" if path.endswith('.feather'): self.chips.to_feather(path) else: self.chips.to_file(path, driver=driver) - - def save_hits(self, path: str, driver: str) -> None: + + def save_hits(self, path: str, driver: str = "Shapefile") -> None: """Save the hits as a shapefile or feather file""" bounds = [] for hit in self.hits: @@ -209,7 +227,7 @@ def save_hits(self, path: str, driver: str) -> None: 'hit_id': hit.id } bounds.append(bound) - + bounds_gdf = GeoDataFrame(bounds, crs=self.dataset.crs) if path.endswith('.feather'): bounds_gdf.to_feather(path) @@ -341,38 +359,59 @@ def __init__( self.areas += 1 self.chips = self.get_chips(num_samples=num_chips) - + def get_chips(self, num_samples) -> GeoDataFrame: chips = [] - for _ in tqdm(range(num_samples)): - # Choose a random tile, weighted by area - idx = torch.multinomial(self.areas, 1) - hit = self.hits[idx] - hit_bounds = hit.bounds - if self.dataset.return_as_ts: - hit_bounds[-2] = self.dataset.bounds.mint - hit_bounds[-1] = self.dataset.bounds.maxt - - bounds = BoundingBox(*hit_bounds) - # Choose a random index within that tile - bbox = get_random_bounding_box(bounds, self.size, self.res) + while len(chips) < num_samples: + if isinstance(self.roi, list): + # Choose a random ROI, weighted by area + idx = torch.multinomial(torch.tensor([roi.area for roi in self.roi], dtype=torch.float), 1) + roi = self.roi[idx] + else: + roi = self.roi + + # Choose a random bounding box within the ROI + bbox = get_random_bounding_box(roi, self.size, self.res) minx, maxx, miny, maxy, mint, maxt = tuple(bbox) - chip = { - 'geometry': box(minx, miny, maxx, maxy), - 'minx': minx, - 'miny': miny, - 'maxx': maxx, - 'maxy': maxy, - 'mint': mint, - 'maxt': maxt, - 'hit_id': hit.id - } - chips.append(chip) + # Find file hits in the index for chosen bounding box + hits = [hit for hit in list(self.index.intersection(tuple(roi), objects=True)) if BoundingBox(*hit.bounds).area > 0] + + # if the bounding box has no file hit, dont include it in the samples + if len(hits) == 0: + continue + else: + # Choose a random hit + hit = random.choice(hits) + + # in case we are randomly sampling also the temporal dimension, get + # mint, maxt from the randomly chosen hit + if self.dataset.return_as_ts: + mint = self.dataset.bounds.mint + maxt = self.dataset.bounds.maxt + else: + mint = hit.bounds[-2] + maxt = hit.bounds[-1] + + chip = { + 'geometry': box(minx, miny, maxx, maxy), + 'minx': minx, + 'miny': miny, + 'maxx': maxx, + 'maxy': maxy, + 'mint': mint, + 'maxt': maxt, + 'hit_id': hit.id + } + chips.append(chip) + print('creating geodataframe... ') chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) chips_gdf['fid'] = chips_gdf.index + print("Number of Unique Hit IDs: {}".format(chips_gdf["hit_id"].unique().shape[0])) + print("Number of Unique Timestamps: {}".format(chips_gdf["mint"].unique().shape[0])) + return chips_gdf @@ -490,7 +529,7 @@ def get_chips(self) -> GeoDataFrame: if key in row.keys(): chip[key] = row[key] - chips.append(chip) + chips.append(chip) if chips: print('creating geodataframe... ') From 3d74557e66cff2658c1b0b34cffa83e1ca3a5b62 Mon Sep 17 00:00:00 2001 From: stypoumic Date: Fri, 13 Sep 2024 10:58:11 +0000 Subject: [PATCH 09/10] Sort chips based on hit_id --- torchgeo/samplers/single.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 5b7a5fad401..f8cd72b2fa6 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -240,7 +240,12 @@ def __iter__(self) -> Iterator[BoundingBox]: Returns: (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset """ + # sort chips based on hit_id + self.chips = self.chips.sort_values(by=['hit_id']) + for _, chip in self.chips.iterrows(): + print("------------------------------------") + print("Chip FID: {}".format(chip["fid"])) yield BoundingBox( chip.minx, chip.maxx, chip.miny, chip.maxy, chip.mint, chip.maxt ) From e2f19e4548a7c83e653fb8bd160a943f7b05b6c3 Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Tue, 17 Sep 2024 10:53:26 +0200 Subject: [PATCH 10/10] WIP --- torchgeo/samplers/single.py | 192 ++++++++++++++++-------------------- 1 file changed, 83 insertions(+), 109 deletions(-) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index f8cd72b2fa6..b3003a0d5bc 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -8,7 +8,7 @@ import re import warnings from collections.abc import Callable, Iterable, Iterator -from typing import Any +from typing import Any, List import geopandas as gpd import numpy as np @@ -94,52 +94,54 @@ class GeoSampler(Sampler[BoundingBox], abc.ABC): longitude, height, width, projection, coordinate system, and time. """ - def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None: + def __init__(self, + dataset: GeoDataset, + roi: BoundingBox | List[BoundingBox] | None = None + ) -> None: """Initialize a new Sampler instance. Args: dataset: dataset to index from - roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) - (defaults to the bounds of ``dataset.index``) + roi: single or multiple regions of interest to sample from (minx, maxx, + miny, maxy, mint, maxt). Defaults to the bounds of ``dataset.index`` """ + # If no roi is provided, use the entire dataset bounds if roi is None: self.index = dataset.index roi = BoundingBox(*self.index.bounds) - else: - # Only keep hits unique in the spatial dimension if return_as_ts is enabled - # else keep hits that are unique in both spatial and temporal dimensions - filter_indices = slice(0, 4) if dataset.return_as_ts else slice(0, 6) - - self.index = Index(interleaved=False, properties=Property(dimension=3)) - if isinstance(roi, list): - for area in roi: - hits = dataset.index.intersection(tuple(area), objects=True) - for hit in hits: - bbox = BoundingBox(*hit.bounds) & area - # Filter hits - if tuple(bbox)[filter_indices] not in [ - tuple(item.bounds[filter_indices]) - for item in list( - self.index.intersection(tuple(area), objects=True) - ) - ]: - self.index.insert(hit.id, tuple(bbox), hit.object) - else: - hits = dataset.index.intersection(tuple(roi), objects=True) - for hit in hits: - bbox = BoundingBox(*hit.bounds) & roi - # Filter hits - if tuple(bbox)[filter_indices] not in [ - tuple(item.bounds[filter_indices]) - for item in list(self.index.intersection(tuple(roi), objects=True)) - ]: - self.index.insert(hit.id, tuple(bbox), hit.object) + + self.rois = roi if isinstance(roi, List) else [roi] + # Only keep hits unique in the spatial dimension if return_as_ts is enabled + # else keep hits that are unique in both spatial and temporal dimensions + filter_indices = slice(0, 4) if dataset.return_as_ts else slice(0, 6) + + self.index = Index(interleaved=False, properties=Property(dimension=3)) + + for roi in self.rois: + # First find all hits that intersect with the roi + print(roi) + hits = dataset.index.intersection(tuple(roi), objects=True) + for hit in hits: + bbox = BoundingBox(*hit.bounds) & roi + # Filter out hits that share the same extent and hits with zero area + if tuple(bbox)[filter_indices] not in [ + tuple(item.bounds[filter_indices]) for item in list( + self.index.intersection(tuple(roi), objects=True) + )] and bbox.area > 0: + self.index.insert(hit.id, tuple(bbox), hit.object) print(f"Index Size: {self.index.get_size()}") self.res = dataset.res - self.roi = roi self.dataset = dataset + @staticmethod + def __save_as_gpd_or_feather(self, path: str, gdf: GeoDataFrame, driver='ESRI Shapefile') -> None: + """Save a gdf as a file supported by any geopandas driver or feather file""" + if path.endswith('.feather'): + chips.to_feather(path) + else: + chips.to_file(path, driver=driver) + @abc.abstractmethod def get_chips(self, *args: Any, **kwargs: Any) -> GeoDataFrame: """Determines the way to get the extend of the chips (samples) of the dataset. @@ -203,12 +205,9 @@ def set_worker_split(self, total_workers: int, worker_num: int) -> None: """ self.chips = np.array_split(self.chips, total_workers)[worker_num] - def save(self, path: str, driver: str = "Shapefile") -> None: + def save_chips(self, path: str, driver: str = "ESRI Shapefile") -> None: """Save the chips as a shapefile or feather file""" - if path.endswith('.feather'): - self.chips.to_feather(path) - else: - self.chips.to_file(path, driver=driver) + self.__save_as_gpd_or_feather(path, self.chips, driver) def save_hits(self, path: str, driver: str = "Shapefile") -> None: """Save the hits as a shapefile or feather file""" @@ -229,10 +228,7 @@ def save_hits(self, path: str, driver: str = "Shapefile") -> None: bounds.append(bound) bounds_gdf = GeoDataFrame(bounds, crs=self.dataset.crs) - if path.endswith('.feather'): - bounds_gdf.to_feather(path) - else: - bounds_gdf.to_file(path, driver=driver) + self.__save_as_gpd_or_feather(path, bounds_gdf, driver) def __iter__(self) -> Iterator[BoundingBox]: """Return the index of a dataset. @@ -240,12 +236,13 @@ def __iter__(self) -> Iterator[BoundingBox]: Returns: (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset """ - # sort chips based on hit_id + # sort chips based on hit_id.The idea behind this is to ensure that chips that are located + # in the same underlying hit are sampled sequentially. Together with keeping the (chunk of the) + # hit in memory by caching it, this should lead to a speed up in dataloading. TODO: support + # sorting of chunks as well. self.chips = self.chips.sort_values(by=['hit_id']) for _, chip in self.chips.iterrows(): - print("------------------------------------") - print("Chip FID: {}".format(chip["fid"])) yield BoundingBox( chip.minx, chip.maxx, chip.miny, chip.maxy, chip.mint, chip.maxt ) @@ -278,7 +275,7 @@ def __init__( dataset: GeoDataset, size: tuple[float, float] | float, length: int | None = None, - roi: BoundingBox | None = None, + roi: BoundingBox | List[BoundingBox] | None = None, units: Units = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -314,47 +311,23 @@ def __init__( self.size = (self.size[0] * self.res, self.size[1] * self.res) num_chips = 0 - self.hits = [] - areas = [] - # use only geospatial bounds for filering in case return_as_ts is set to True - filter_indices = slice(0,4) if dataset.return_as_ts else slice(0,6) - - if isinstance(self.roi, list): - for area in self.roi: - for hit in self.index.intersection(tuple(area), objects=True): - # Filter out hits in the index that share the same extent - if hit.bounds[filter_indices] not in [ht.bounds[filter_indices] for ht in self.hits]: - bounds = BoundingBox(*hit.bounds) - if ( - bounds.maxx - bounds.minx >= self.size[1] - and bounds.maxy - bounds.miny >= self.size[0] - ): - if bounds.area > 0: - rows, cols = tile_to_chips(bounds, self.size) - num_chips += rows * cols - else: - num_chips += 1 - self.hits.append(hit) - areas.append(bounds.area) - else: - for hit in self.index.intersection(tuple(self.roi), objects=True): - # Filter out hits in the index that share the same extent - if hit.bounds[filter_indices] not in [ht.bounds[filter_indices] for ht in self.hits]: - bounds = BoundingBox(*hit.bounds) - if ( - bounds.maxx - bounds.minx >= self.size[1] - and bounds.maxy - bounds.miny >= self.size[0] - ): - if bounds.area > 0: - rows, cols = tile_to_chips(bounds, self.size) - num_chips += rows * cols - else: - num_chips += 1 - self.hits.append(hit) - areas.append(bounds.area) - if length is not None: - num_chips = length - self.length = num_chips + self.areas_per_roi = [] + + for roi in self.rois: + areas = [] + for hit in self.index.intersection(tuple(roi), objects=True): + # Filter out hits that are smaller than the chip size + hit_bounds = BoundingBox(*hit.bounds) + if ( + hit_bounds.maxx - hit_bounds.minx >= self.size[1] + and hit_bounds.maxy - hit_bounds.miny >= self.size[0] + ): + rows, cols = tile_to_chips(hit_bounds, self.size) + num_chips += rows * cols + areas.append(hit_bounds.area) + self.areas_per_roi.append(areas) + + self.length = length or num_chips print(f"Unique geospatial file hits: {len(self.hits)}") @@ -363,34 +336,33 @@ def __init__( if torch.sum(self.areas) == 0: self.areas += 1 - self.chips = self.get_chips(num_samples=num_chips) + self.chips = self.get_chips(num_samples=self.length) def get_chips(self, num_samples) -> GeoDataFrame: chips = [] + print('generating samples... ') while len(chips) < num_samples: - if isinstance(self.roi, list): - # Choose a random ROI, weighted by area - idx = torch.multinomial(torch.tensor([roi.area for roi in self.roi], dtype=torch.float), 1) - roi = self.roi[idx] - else: - roi = self.roi - - # Choose a random bounding box within the ROI - bbox = get_random_bounding_box(roi, self.size, self.res) - minx, maxx, miny, maxy, mint, maxt = tuple(bbox) + # Choose a random ROI, weighted by area + roi_idx = torch.multinomial(torch.tensor([roi.area for roi in self.rois], dtype=torch.float), 1) + roi = self.rois[roi_idx] - # Find file hits in the index for chosen bounding box - hits = [hit for hit in list(self.index.intersection(tuple(roi), objects=True)) if BoundingBox(*hit.bounds).area > 0] + # Find the hits for the chosen ROI + hits = list(self.index.intersection(tuple(roi), objects=True)) - # if the bounding box has no file hit, dont include it in the samples + # if the roi has no hits, dont try to sample from it if len(hits) == 0: continue else: - # Choose a random hit - hit = random.choice(hits) + # Choose a random hit, weighted by area + hit_idx = torch.multinomial(torch.tensor(self.areas_per_roi[roi_idx], dtype=torch.float), 1) + assert len(self.areas_per_roi[roi_idx]) == len(hits) - # in case we are randomly sampling also the temporal dimension, get - # mint, maxt from the randomly chosen hit + # Choose a random bounding box within the hit + bbox = get_random_bounding_box(hits[hit_idx], self.size, self.res) + minx, maxx, miny, maxy, mint, maxt = tuple(bbox) + + # If the dataset is a SITS dataset, the mint and maxt are the same for all hits + # If not, we also sample across time. if self.dataset.return_as_ts: mint = self.dataset.bounds.mint maxt = self.dataset.bounds.maxt @@ -410,7 +382,6 @@ def get_chips(self, num_samples) -> GeoDataFrame: } chips.append(chip) - print('creating geodataframe... ') chips_gdf = GeoDataFrame(chips, crs=self.dataset.crs) chips_gdf['fid'] = chips_gdf.index @@ -440,7 +411,7 @@ def __init__( dataset: GeoDataset, size: tuple[float, float] | float, stride: tuple[float, float] | float, - roi: BoundingBox | None = None, + roi: BoundingBox | List[BoundingBox] | None = None, units: Units = Units.PIXELS, ) -> None: """Initialize a new Sampler instance. @@ -562,7 +533,10 @@ class PreChippedGeoSampler(GeoSampler): """ def __init__( - self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False + self, + dataset: GeoDataset, + roi: BoundingBox | List[BoundingBox] | None = None, + shuffle: bool = False ) -> None: """Initialize a new Sampler instance.