diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e09d665c3ff..d4d221e3c0d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.0 + rev: v0.9.1 hooks: - id: ruff types_or: diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 458699a5325..656909190ab 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -358,6 +358,11 @@ MapInWild .. autoclass:: MapInWild +MDAS +^^^^ + +.. autoclass:: MDAS + Million-AID ^^^^^^^^^^^ diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv index cc68d0d05b0..5b70fb768a6 100644 --- a/docs/api/datasets/non_geo_datasets.csv +++ b/docs/api/datasets/non_geo_datasets.csv @@ -30,6 +30,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `LEVIR-CD+`_,CD,Google Earth,-,985,2,"1,024x1,024",0.5,RGB `LoveDA`_,S,Google Earth,"CC-BY-NC-SA-4.0","5,987",7,"1,024x1,024",0.3,RGB `MapInWild`_,S,"Sentinel-1/2, ESA WorldCover, NOAA VIIRS DNB","CC-BY-4.0",1018,1,1920x1920,10--463.83,"SAR, MSI, 2020_Map, avg_rad" +`MDAS`_,S,"Sentinel-1/2,EnMAP,HySpex","CC-BY-SA-4.0",3,20,"100x120, 300x360, 1364x1636, 10000x12000, 15000x18000",0.3--30,HSI `Million-AID`_,C,Google Earth,-,1M,51--73,,0.5--153,RGB `MMEarth`_,"C, S","Aster, Sentinel, ERA5","CC-BY-4.0","100K--1M",,"128x128 or 64x64",10,MSI `NASA Marine Debris`_,OD,PlanetScope,"Apache-2.0",707,1,256x256,3,RGB diff --git a/docs/tutorials/transforms.ipynb b/docs/tutorials/transforms.ipynb index 689b2eebd33..e148945afbb 100644 --- a/docs/tutorials/transforms.ipynb +++ b/docs/tutorials/transforms.ipynb @@ -707,7 +707,7 @@ "sample = dataset[idx]\n", "rgb = sample['image'][0, 1:4]\n", "image = T.ToPILImage()(rgb)\n", - "print(f\"Class Label: {dataset.classes[sample['label']]}\")\n", + "print(f'Class Label: {dataset.classes[sample[\"label\"]]}')\n", "image.resize((256, 256), resample=Image.BILINEAR)" ] }, diff --git a/experiments/torchgeo/run_resisc45_experiments.py b/experiments/torchgeo/run_resisc45_experiments.py index 6897ea12772..9ed69b03968 100755 --- a/experiments/torchgeo/run_resisc45_experiments.py +++ b/experiments/torchgeo/run_resisc45_experiments.py @@ -38,7 +38,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: for model, lr, loss, weights in itertools.product( model_options, lr_options, loss_options, weight_options ): - experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}" + experiment_name = f'{model}_{lr}_{loss}_{weights.replace("_", "-")}' output_dir = os.path.join('output', 'resisc45_experiments') log_dir = os.path.join(output_dir, 'logs') diff --git a/experiments/torchgeo/run_so2sat_byol_experiments.py b/experiments/torchgeo/run_so2sat_byol_experiments.py index 169a010cef8..4ae78601fbd 100755 --- a/experiments/torchgeo/run_so2sat_byol_experiments.py +++ b/experiments/torchgeo/run_so2sat_byol_experiments.py @@ -39,7 +39,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: for model, lr, loss, weights, bands in itertools.product( model_options, lr_options, loss_options, weight_options, bands_options ): - experiment_name = f"{model}_{lr}_{loss}_byol_{bands}-{weights.split('/')[-2]}" + experiment_name = f'{model}_{lr}_{loss}_byol_{bands}-{weights.split("/")[-2]}' output_dir = os.path.join('output', 'so2sat_experiments') log_dir = os.path.join(output_dir, 'logs') diff --git a/experiments/torchgeo/run_so2sat_experiments.py b/experiments/torchgeo/run_so2sat_experiments.py index 41e2fc04b5f..44ba5c7aaf3 100755 --- a/experiments/torchgeo/run_so2sat_experiments.py +++ b/experiments/torchgeo/run_so2sat_experiments.py @@ -38,7 +38,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: for model, lr, loss, weights in itertools.product( model_options, lr_options, loss_options, weight_options ): - experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}" + experiment_name = f'{model}_{lr}_{loss}_{weights.replace("_", "-")}' output_dir = os.path.join('output', 'so2sat_experiments') log_dir = os.path.join(output_dir, 'logs') diff --git a/experiments/torchgeo/run_so2sat_seed_experiments.py b/experiments/torchgeo/run_so2sat_seed_experiments.py index 2d2efe1e248..4f71770917a 100755 --- a/experiments/torchgeo/run_so2sat_seed_experiments.py +++ b/experiments/torchgeo/run_so2sat_seed_experiments.py @@ -39,7 +39,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool: for model, lr, loss, weights, seed in itertools.product( model_options, lr_options, loss_options, weight_options, seeds ): - experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}_{seed}" + experiment_name = f'{model}_{lr}_{loss}_{weights.replace("_", "-")}_{seed}' output_dir = os.path.join('output', 'so2sat_seed_experiments') log_dir = os.path.join(output_dir, 'logs') diff --git a/pyproject.toml b/pyproject.toml index a4e125f5a31..8d7fcc89183 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,8 +115,8 @@ docs = [ style = [ # mypy 0.900+ required for pyproject.toml support "mypy>=0.900", - # ruff 0.8+ required for removal of ANN101, ANN102 - "ruff>=0.8", + # ruff 0.9+ required for 2025 style guide + "ruff>=0.9", ] tests = [ # nbmake 1.3.3+ required for variable mocking diff --git a/requirements/datasets.txt b/requirements/datasets.txt index 9cfdcbe5d3c..d2be8e65a52 100644 --- a/requirements/datasets.txt +++ b/requirements/datasets.txt @@ -6,6 +6,6 @@ pandas[parquet]==2.2.3 pycocotools==2.0.8 pyvista==0.44.2 scikit-image==0.25.0 -scipy==1.14.1 +scipy==1.15.0 xarray==2024.11.0 netcdf4==1.7.2 diff --git a/requirements/required.txt b/requirements/required.txt index 055d27670f6..62295bb4c71 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -1,5 +1,5 @@ # setup -setuptools==75.6.0 +setuptools==75.8.0 # install einops==0.8.0 @@ -10,7 +10,7 @@ lightning[pytorch-extra]==2.5.0.post0 matplotlib==3.10.0 numpy==2.2.1 pandas==2.2.3 -pillow==11.0.0 +pillow==11.1.0 pyproj==3.7.0 rasterio==1.4.3 rtree==1.3.0 diff --git a/requirements/style.txt b/requirements/style.txt index dc9b06b247a..4b0987677c7 100644 --- a/requirements/style.txt +++ b/requirements/style.txt @@ -1,3 +1,3 @@ # style mypy==1.14.1 -ruff==0.8.4 +ruff==0.9.1 diff --git a/tests/data/inria/data.py b/tests/data/inria/data.py index 96626dea5fa..76ea3aa9c60 100755 --- a/tests/data/inria/data.py +++ b/tests/data/inria/data.py @@ -68,9 +68,9 @@ def generate_test_data(root: str, n_samples: int = 2) -> str: lbl = np.random.randint(2, size=size, dtype=dtype) timg = np.random.randint(dtype_max, size=size, dtype=dtype) - img_path = os.path.join(img_dir, f'austin{i+1}.tif') - lbl_path = os.path.join(lbl_dir, f'austin{i+1}.tif') - timg_path = os.path.join(timg_dir, f'austin{i+10}.tif') + img_path = os.path.join(img_dir, f'austin{i + 1}.tif') + lbl_path = os.path.join(lbl_dir, f'austin{i + 1}.tif') + timg_path = os.path.join(timg_dir, f'austin{i + 10}.tif') write_data(img_path, img, driver, crs, transform) write_data(lbl_path, lbl, driver, crs, transform) diff --git a/tests/data/mdas/Augsburg_data_4_publication.zip b/tests/data/mdas/Augsburg_data_4_publication.zip new file mode 100644 index 00000000000..a4e00554127 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication.zip differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/3K_DSM_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/3K_DSM_sub_area1.tif new file mode 100644 index 00000000000..13e7483a0be Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/3K_DSM_sub_area1.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/3K_RGB_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/3K_RGB_sub_area1.tif new file mode 100644 index 00000000000..400085756b3 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/3K_RGB_sub_area1.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_EnMAP_10m_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_EnMAP_10m_sub_area1.tif new file mode 100644 index 00000000000..96f0de07058 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_EnMAP_10m_sub_area1.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_EnMAP_30m_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_EnMAP_30m_sub_area1.tif new file mode 100644 index 00000000000..30d07cd51ea Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_EnMAP_30m_sub_area1.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_Sentinel_2_10m_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_Sentinel_2_10m_sub_area1.tif new file mode 100644 index 00000000000..5a50122b2c9 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_Sentinel_2_10m_sub_area1.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/HySpex_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/HySpex_sub_area1.tif new file mode 100644 index 00000000000..e0984b4f7c1 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/HySpex_sub_area1.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/Sentinel_1_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/Sentinel_1_sub_area1.tif new file mode 100644 index 00000000000..834990c3e20 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/Sentinel_1_sub_area1.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/Sentinel_2_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/Sentinel_2_sub_area1.tif new file mode 100644 index 00000000000..489e96791a2 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/Sentinel_2_sub_area1.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_buildings_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_buildings_sub_area1.tif new file mode 100644 index 00000000000..0faaa47d73f Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_buildings_sub_area1.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_landuse_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_landuse_sub_area1.tif new file mode 100644 index 00000000000..c80350418fc Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_landuse_sub_area1.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_water_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_water_sub_area1.tif new file mode 100644 index 00000000000..8840f67a9aa Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_water_sub_area1.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/3K_DSM_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/3K_DSM_sub_area2.tif new file mode 100644 index 00000000000..313e5f9f6f4 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/3K_DSM_sub_area2.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/3K_RGB_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/3K_RGB_sub_area2.tif new file mode 100644 index 00000000000..37cdaa7f26b Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/3K_RGB_sub_area2.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_EnMAP_10m_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_EnMAP_10m_sub_area2.tif new file mode 100644 index 00000000000..7ac37f43f06 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_EnMAP_10m_sub_area2.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_EnMAP_30m_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_EnMAP_30m_sub_area2.tif new file mode 100644 index 00000000000..349b28401ac Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_EnMAP_30m_sub_area2.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_Sentinel_2_10m_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_Sentinel_2_10m_sub_area2.tif new file mode 100644 index 00000000000..fd0bcceda81 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_Sentinel_2_10m_sub_area2.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/HySpex_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/HySpex_sub_area2.tif new file mode 100644 index 00000000000..21e22212147 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/HySpex_sub_area2.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/Sentinel_1_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/Sentinel_1_sub_area2.tif new file mode 100644 index 00000000000..a7ebb2b5b61 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/Sentinel_1_sub_area2.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/Sentinel_2_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/Sentinel_2_sub_area2.tif new file mode 100644 index 00000000000..c4a390f88a8 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/Sentinel_2_sub_area2.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_buildings_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_buildings_sub_area2.tif new file mode 100644 index 00000000000..711d872d705 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_buildings_sub_area2.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_landuse_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_landuse_sub_area2.tif new file mode 100644 index 00000000000..8a3f1ca1692 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_landuse_sub_area2.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_water_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_water_sub_area2.tif new file mode 100644 index 00000000000..9c88128b574 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_water_sub_area2.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/3K_DSM_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/3K_DSM_sub_area3.tif new file mode 100644 index 00000000000..e664307c2c4 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/3K_DSM_sub_area3.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/3K_RGB_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/3K_RGB_sub_area3.tif new file mode 100644 index 00000000000..aa020dde3f1 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/3K_RGB_sub_area3.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_EnMAP_10m_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_EnMAP_10m_sub_area3.tif new file mode 100644 index 00000000000..8f64503512d Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_EnMAP_10m_sub_area3.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_EnMAP_30m_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_EnMAP_30m_sub_area3.tif new file mode 100644 index 00000000000..c30d66e3dd4 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_EnMAP_30m_sub_area3.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_Sentinel_2_10m_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_Sentinel_2_10m_sub_area3.tif new file mode 100644 index 00000000000..2fd45d96c7f Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_Sentinel_2_10m_sub_area3.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/HySpex_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/HySpex_sub_area3.tif new file mode 100644 index 00000000000..36b6ce50d76 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/HySpex_sub_area3.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/Sentinel_1_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/Sentinel_1_sub_area3.tif new file mode 100644 index 00000000000..5cfad510e84 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/Sentinel_1_sub_area3.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/Sentinel_2_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/Sentinel_2_sub_area3.tif new file mode 100644 index 00000000000..3e80c170bf8 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/Sentinel_2_sub_area3.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_buildings_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_buildings_sub_area3.tif new file mode 100644 index 00000000000..9861e5e07ee Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_buildings_sub_area3.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_landuse_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_landuse_sub_area3.tif new file mode 100644 index 00000000000..3705837b37a Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_landuse_sub_area3.tif differ diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_water_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_water_sub_area3.tif new file mode 100644 index 00000000000..4bcb146bef6 Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_water_sub_area3.tif differ diff --git a/tests/data/mdas/data.py b/tests/data/mdas/data.py new file mode 100644 index 00000000000..c82b54cb89a --- /dev/null +++ b/tests/data/mdas/data.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import shutil + +import numpy as np +import rasterio +from rasterio.crs import CRS +from rasterio.transform import from_origin + +# Set the random seed for reproducibility +np.random.seed(0) + +# Define the root directory, dataset name, subareas, and modalities based on mdas.py +root_dir = '.' +ds_root_name = 'Augsburg_data_4_publication' +subareas = ['sub_area_1', 'sub_area_2', 'sub_area_3'] +modalities = [ + '3K_DSM', + '3K_RGB', + 'HySpex', + 'EeteS_EnMAP_10m', + 'EeteS_EnMAP_30m', + 'EeteS_Sentinel_2_10m', + 'Sentinel_1', + 'Sentinel_2', + 'osm_buildings', + 'osm_landuse', + 'osm_water', +] + +landuse_class_codes = [ + -2147483647, # no label + 7201, # forest + 7202, # park + 7203, # residential + 7204, # industrial + 7205, # farm + 7206, # cemetery + 7207, # allotments + 7208, # meadow + 7209, # commercial + 7210, # nature reserve + 7211, # recreation ground + 7212, # retail + 7213, # military + 7214, # quarry + 7215, # orchard + 7217, # scrub + 7218, # grass + 7219, # heath +] + +# Remove existing dummy data if it exists +dataset_path = os.path.join(root_dir, ds_root_name) +if os.path.exists(dataset_path): + shutil.rmtree(dataset_path) + + +def create_dummy_geotiff( + path: str, + num_bands: int = 3, + width: int = 32, + height: int = 32, + dtype: np.dtype = np.uint16, + binary: bool = False, + landuse: bool = False, +) -> None: + """Create a dummy GeoTIFF file.""" + crs = CRS.from_epsg(32632) + transform = from_origin(0, 0, 1, 1) + + if binary: + data = np.random.randint(0, 2, size=(num_bands, height, width)).astype(dtype) + elif landuse: + num_pixels = num_bands * height * width + no_label_ratio = 0.1 + num_no_label = int(no_label_ratio * num_pixels) + num_labels = num_pixels - num_no_label + landuse_values = np.random.choice(landuse_class_codes[1:], size=num_labels) + no_label_values = np.full(num_no_label, landuse_class_codes[0], dtype=dtype) + combined = np.concatenate([landuse_values, no_label_values]) + np.random.shuffle(combined) + data = combined.reshape((num_bands, height, width)).astype(dtype) + else: + # Generate random data for other modalities + data = np.random.randint(0, 255, size=(num_bands, height, width)).astype(dtype) + + os.makedirs(os.path.dirname(path), exist_ok=True) + + with rasterio.open( + path, + 'w', + driver='GTiff', + height=height, + width=width, + count=num_bands, + dtype=dtype, + crs=crs, + transform=transform, + ) as dst: + dst.write(data) + + +# Create directory structure and dummy data +for subarea in subareas: + # Format the subarea name for filenames, as in mdas.py _format_subarea method + parts = subarea.split('_') + subarea_formatted = parts[0] + '_' + parts[1] + parts[2] # e.g., 'sub_area1' + + subarea_dir = os.path.join(root_dir, ds_root_name, subarea) + + for modality in modalities: + filename = f'{modality}_{subarea_formatted}.tif' + file_path = os.path.join(subarea_dir, filename) + + if modality in ['osm_buildings', 'osm_water']: + create_dummy_geotiff(file_path, num_bands=1, dtype=np.uint8, binary=True) + elif modality == 'osm_landuse': + create_dummy_geotiff(file_path, num_bands=1, dtype=np.float64, landuse=True) + elif modality == 'HySpex': + create_dummy_geotiff(file_path, num_bands=368, dtype=np.int16) + elif modality in ['EeteS_EnMAP_10m', 'EeteS_EnMAP_30m']: + create_dummy_geotiff(file_path, num_bands=242, dtype=np.uint16) + elif modality == 'Sentinel_1': + create_dummy_geotiff(file_path, num_bands=2, dtype=np.float32) + elif modality in ['Sentinel_2', 'EeteS_Sentinel_2_10m']: + create_dummy_geotiff(file_path, num_bands=13, dtype=np.uint16) + elif modality == '3K_DSM': + create_dummy_geotiff(file_path, num_bands=1, dtype=np.float32) + elif modality == '3K_RGB': + create_dummy_geotiff(file_path, num_bands=3, dtype=np.uint8) + +print(f'Dummy MDAS dataset created at {os.path.join(root_dir, ds_root_name)}') + +# Create a zip archive of the dataset directory +zip_filename = f'{ds_root_name}.zip' +zip_path = os.path.join(root_dir, zip_filename) + +shutil.make_archive( + base_name=os.path.splitext(zip_path)[0], + format='zip', + root_dir='.', + base_dir=ds_root_name, +) + + +def calculate_md5(filename: str) -> str: + hash_md5 = hashlib.md5() + with open(filename, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b''): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +checksum = calculate_md5(zip_path) +print(f'MD5 checksum: {checksum}') diff --git a/tests/data/seasonet/data.py b/tests/data/seasonet/data.py index e3197ddde12..86f6210636b 100644 --- a/tests/data/seasonet/data.py +++ b/tests/data/seasonet/data.py @@ -63,7 +63,7 @@ os.remove(archive) for grid, comp in zip(grids, name_comps): - file_name = f"{comp[0]}_{''.join(comp[1:8])}_{'_'.join(comp[8:])}" + file_name = f'{comp[0]}_{"".join(comp[1:8])}_{"_".join(comp[8:])}' dir = os.path.join(season, f'grid{grid}', file_name) os.makedirs(dir) diff --git a/tests/data/ssl4eo_benchmark_landsat/data.py b/tests/data/ssl4eo_benchmark_landsat/data.py index 177ed7d7954..5470aacef05 100755 --- a/tests/data/ssl4eo_benchmark_landsat/data.py +++ b/tests/data/ssl4eo_benchmark_landsat/data.py @@ -193,7 +193,7 @@ def create_tarballs(directories: str) -> None: # mask directory cdl mask_keep = ['tm_toa', 'etm_sr', 'oli_sr'] mask_filenames = { - f"ssl4eo_l_{key.split('_')[0]}_cdl": val + f'ssl4eo_l_{key.split("_")[0]}_cdl': val for key, val in filenames.items() if key in mask_keep } @@ -203,7 +203,7 @@ def create_tarballs(directories: str) -> None: # mask directory nlcd mask_filenames = { - f"ssl4eo_l_{key.split('_')[0]}_nlcd": val + f'ssl4eo_l_{key.split("_")[0]}_nlcd': val for key, val in filenames.items() if key in mask_keep } diff --git a/tests/datamodules/test_digital_typhoon.py b/tests/datamodules/test_digital_typhoon.py index 0ecd85f5ec7..dd61eb26933 100644 --- a/tests/datamodules/test_digital_typhoon.py +++ b/tests/datamodules/test_digital_typhoon.py @@ -57,14 +57,14 @@ def find_max_time_per_id( # Assert that each max value in train_max_values is lower # than in val_max_values for each key id for id, max_value in train_max_values.items(): - assert ( - id not in val_max_values or max_value < val_max_values[id] - ), f'Max value for id {id} in train is not lower than in validation.' + assert id not in val_max_values or max_value < val_max_values[id], ( + f'Max value for id {id} in train is not lower than in validation.' + ) else: train_ids = {seq['id'] for seq in train_sequences} val_ids = {seq['id'] for seq in val_sequences} # Assert that the intersection between train_ids and val_ids is empty - assert ( - len(train_ids & val_ids) == 0 - ), 'Train and validation datasets have overlapping ids.' + assert len(train_ids & val_ids) == 0, ( + 'Train and validation datasets have overlapping ids.' + ) diff --git a/tests/datasets/test_mdas.py b/tests/datasets/test_mdas.py new file mode 100644 index 00000000000..83138c84207 --- /dev/null +++ b/tests/datasets/test_mdas.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch + +from torchgeo.datasets import MDAS, DatasetNotFoundError + + +class TestMDAS: + @pytest.fixture( + params=[ + {'subareas': ['sub_area_1'], 'modalities': ['HySpex']}, + { + 'subareas': ['sub_area_1', 'sub_area_2'], + 'modalities': ['3K_DSM', 'HySpex', 'osm_water'], + }, + { + 'subareas': ['sub_area_2', 'sub_area_3'], + 'modalities': [ + '3K_DSM', + '3K_RGB', + 'HySpex', + 'EeteS_EnMAP_10m', + 'EeteS_EnMAP_30m', + 'EeteS_Sentinel_2_10m', + 'Sentinel_2', + 'Sentinel_1', + 'osm_buildings', + 'osm_landuse', + 'osm_water', + ], + }, + ] + ) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> MDAS: + md5 = '99e1744ca6f19aa19a3aa23a2bbf7bef' + monkeypatch.setattr(MDAS, 'md5', md5) + url = os.path.join('tests', 'data', 'mdas', 'Augsburg_data_4_publication.zip') + monkeypatch.setattr(MDAS, 'url', url) + + params = request.param + subareas = params['subareas'] + modalities = params['modalities'] + + root = tmp_path + transforms = nn.Identity() + + return MDAS( + root=root, + subareas=subareas, + modalities=modalities, + transforms=transforms, + download=True, + checksum=True, + ) + + def test_getitem(self, dataset: MDAS) -> None: + x = dataset[0] + assert isinstance(x, dict) + for key in dataset.modalities: + if key.startswith('osm'): + key = f'{key}_mask' + else: + key = f'{key}_image' + assert key in x + + for key, value in x.items(): + assert isinstance(value, torch.Tensor) + + def test_len(self, dataset: MDAS) -> None: + assert len(dataset) == len(dataset.subareas) + + def test_already_downloaded(self, dataset: MDAS) -> None: + MDAS(root=dataset.root) + + def test_not_yet_extracted(self, tmp_path: Path) -> None: + filename = 'Augsburg_data_4_publication.zip' + dir = os.path.join('tests', 'data', 'mdas') + shutil.copyfile( + os.path.join(dir, filename), os.path.join(str(tmp_path), filename) + ) + MDAS(root=str(tmp_path)) + + def test_invalid_subarea(self) -> None: + with pytest.raises(AssertionError): + MDAS(subareas=['foo']) + + def test_invalid_modality(self) -> None: + with pytest.raises(AssertionError): + MDAS(modalities=['foo']) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + MDAS(tmp_path) + + def test_plot(self, dataset: MDAS) -> None: + dataset.plot(dataset[0], suptitle='Test') + plt.close() + + def test_plot_single_sample(self, dataset: MDAS) -> None: + dataset.plot(dataset[0], show_titles=False) + plt.close() diff --git a/torchgeo/datamodules/digital_typhoon.py b/torchgeo/datamodules/digital_typhoon.py index ce799bf3d52..9ebc1643255 100644 --- a/torchgeo/datamodules/digital_typhoon.py +++ b/torchgeo/datamodules/digital_typhoon.py @@ -43,9 +43,9 @@ def __init__( """ super().__init__(DigitalTyphoon, batch_size, num_workers, **kwargs) - assert ( - split_by in self.valid_split_types - ), f'Please choose from {self.valid_split_types}' + assert split_by in self.valid_split_types, ( + f'Please choose from {self.valid_split_types}' + ) self.split_by = split_by def _split_dataset( diff --git a/torchgeo/datamodules/ftw.py b/torchgeo/datamodules/ftw.py index a197a789c48..19128cdbb3d 100644 --- a/torchgeo/datamodules/ftw.py +++ b/torchgeo/datamodules/ftw.py @@ -44,9 +44,9 @@ def __init__( Raises: AssertionError: If 'countries' are specified in kwargs """ - assert ( - 'countries' not in kwargs - ), "Please specify 'train_countries', 'val_countries', and 'test_countries' instead of 'countries' inside kwargs" + assert 'countries' not in kwargs, ( + "Please specify 'train_countries', 'val_countries', and 'test_countries' instead of 'countries' inside kwargs" + ) super().__init__(FieldsOfTheWorld, batch_size, num_workers, **kwargs) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index c9cb2989d0a..8c2e362b6c1 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -86,6 +86,7 @@ from .levircd import LEVIRCD, LEVIRCDBase, LEVIRCDPlus from .loveda import LoveDA from .mapinwild import MapInWild +from .mdas import MDAS from .millionaid import MillionAID from .mmearth import MMEarth from .naip import NAIP @@ -160,6 +161,7 @@ 'GBIF', 'GID15', 'LEVIRCD', + 'MDAS', 'NAIP', 'NCCM', 'NLCD', diff --git a/torchgeo/datasets/agrifieldnet.py b/torchgeo/datasets/agrifieldnet.py index 3624c1e193e..ba116377878 100644 --- a/torchgeo/datasets/agrifieldnet.py +++ b/torchgeo/datasets/agrifieldnet.py @@ -149,9 +149,9 @@ def __init__( Raises: DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert ( - set(classes) <= self.cmap.keys() - ), f'Only the following classes are valid: {list(self.cmap.keys())}.' + assert set(classes) <= self.cmap.keys(), ( + f'Only the following classes are valid: {list(self.cmap.keys())}.' + ) assert 0 in classes, 'Classes must include the background class: 0' self.paths = paths diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 38669cd6ff1..ef62ac1a280 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -565,9 +565,9 @@ def plot( ax.imshow(image) ax.axis('off') if show_titles: - title = f"Labels: {', '.join(labels)}" + title = f'Labels: {", ".join(labels)}' if showing_predictions: - title += f"\nPredictions: {', '.join(predictions)}" + title += f'\nPredictions: {", ".join(predictions)}' ax.set_title(title) if suptitle is not None: diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py index 70a53a4220a..2531c96dd23 100644 --- a/torchgeo/datasets/biomassters.py +++ b/torchgeo/datasets/biomassters.py @@ -81,14 +81,14 @@ def __init__( """ self.root = root - assert ( - split in self.valid_splits - ), f'Please choose one of the valid splits: {self.valid_splits}.' + assert split in self.valid_splits, ( + f'Please choose one of the valid splits: {self.valid_splits}.' + ) self.split = split - assert set(sensors).issubset( - set(self.valid_sensors) - ), f'Please choose a subset of valid sensors: {self.valid_sensors}.' + assert set(sensors).issubset(set(self.valid_sensors)), ( + f'Please choose a subset of valid sensors: {self.valid_sensors}.' + ) self.sensors = sensors self.as_time_series = as_time_series diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index 0b0f6ac5b3d..2de5719beb0 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -248,9 +248,9 @@ def __init__( 'CDL data product only exists for the following years: ' f'{list(self.md5s.keys())}.' ) - assert ( - set(classes) <= self.cmap.keys() - ), f'Only the following classes are valid: {list(self.cmap.keys())}.' + assert set(classes) <= self.cmap.keys(), ( + f'Only the following classes are valid: {list(self.cmap.keys())}.' + ) assert 0 in classes, 'Classes must include the background class: 0' self.paths = paths diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py index f9db256238d..681d5026f25 100644 --- a/torchgeo/datasets/cms_mangrove_canopy.py +++ b/torchgeo/datasets/cms_mangrove_canopy.py @@ -204,15 +204,15 @@ def __init__( self.checksum = checksum assert isinstance(country, str), 'Country argument must be a str.' - assert ( - country in self.all_countries - ), f'You have selected an invalid country, please choose one of {self.all_countries}' + assert country in self.all_countries, ( + f'You have selected an invalid country, please choose one of {self.all_countries}' + ) self.country = country assert isinstance(measurement, str), 'Measurement must be a string.' - assert ( - measurement in self.measurements - ), f'You have entered an invalid measurement, please choose one of {self.measurements}.' + assert measurement in self.measurements, ( + f'You have entered an invalid measurement, please choose one of {self.measurements}.' + ) self.measurement = measurement self.filename_glob = f'**/Mangrove_{self.measurement}_{self.country}*' diff --git a/torchgeo/datasets/digital_typhoon.py b/torchgeo/datasets/digital_typhoon.py index 42bb4caa1bd..dfa47966440 100644 --- a/torchgeo/datasets/digital_typhoon.py +++ b/torchgeo/datasets/digital_typhoon.py @@ -139,9 +139,9 @@ def __init__( self.min_feature_value = min_feature_value self.max_feature_value = max_feature_value - assert ( - task in self.valid_tasks - ), f'Please choose one of {self.valid_tasks}, you provided {task}.' + assert task in self.valid_tasks, ( + f'Please choose one of {self.valid_tasks}, you provided {task}.' + ) self.task = task assert set(features).issubset(set(self.valid_features)) diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py index 93b6b18e455..e3066058371 100644 --- a/torchgeo/datasets/loveda.py +++ b/torchgeo/datasets/loveda.py @@ -115,9 +115,9 @@ def __init__( DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.splits - assert set(scene).intersection( - set(self.scenes) - ), "The possible scenes are 'rural' and/or 'urban'" + assert set(scene).intersection(set(self.scenes)), ( + "The possible scenes are 'rural' and/or 'urban'" + ) assert len(scene) <= 2, "There are no other scenes than 'rural' or 'urban'" self.root = root diff --git a/torchgeo/datasets/mdas.py b/torchgeo/datasets/mdas.py new file mode 100644 index 00000000000..25a61a72396 --- /dev/null +++ b/torchgeo/datasets/mdas.py @@ -0,0 +1,379 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""MDAS dataset.""" + +import os +from collections.abc import Callable +from typing import Any, ClassVar + +import matplotlib.cm as cm +import matplotlib.pyplot as plt +import numpy as np +import rasterio as rio +import torch +from matplotlib.colors import ListedColormap +from matplotlib.figure import Figure +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, download_and_extract_archive, extract_archive + + +class MDAS(NonGeoDataset): + """MDAS dataset. + + The `MDAS `__ multimodal dataset + is a comprehensive dataset for the city of Augsburg, Germany, collected on 7th May 2018. + It includes SAR, multispectral, hyperspectral, DSM, and GIS data, + providing comprehensive options for data fusion research. + MDAS supports applications like resolution enhancement, spectral unmixing, and land cover classification. + + Dataset features: + + * 3K DSM data + * 3K high resolution RGB images + * Original very high resolution HySpex airborne imagery + * EeteS simulated imagery with 10m GSD and EnMAP spectral bands + * EeteS simulated imagery with 30m GSD and EnMAP spectral bands + * EeteS simulated imagery with 10m GSD and Sentinel-2 spectral bands + * Sentinel-2 L2A product + * Sentinel-1 GRD product + * Open Street Map (OSM) labels, see `this table `__ for + a table of the label distribution + + Dataset format: + + * 3K_RGB.tif (Shape: (4, 15000, 18000)px, Data Type: uint8) + * 3K_dsm.tif (Shape: (1, 10000, 12000)px, Data Type: float32) + * HySpex.tif (Shape: (368, 1364, 1636)px, Data Type: int16) + * EeteS_EnMAP_2dot2m.tif (Shape: (242, 1364, 1636)px, Data Type: float32) + * EeteS_EnMAP_10m.tif (Shape: (242, 300, 360)px, Data Type: uint16) + * EeteS_EnMAP_30m.tif (Shape: (242, 100, 120)px, Data Type: uint16) + * EeteS_Sentinel_2_10m.tif (Shape: (4, 300, 360)px, Data Type: uint16) + * Sentinel_2.tif (Shape: (12, 300, 360)px, Data Type: uint16) + * Sentinel_1.tif (Shape: (2, 300, 360)px, Data Type: float32) + * osm_buildings.tif (Shape: (1, 1364, 1636)px, Data Type: uint8) + * osm_landuse.tif (Shape: (1, 1364, 1636)px, Data Type: float64) + * osm_water.tif (Shape: (1, 1364, 1636)px, Data Type: float64) + + If you use this dataset in your research, please cite the following paper: + + * https://essd.copernicus.org/articles/15/113/2023/ + + .. versionadded:: 0.7 + """ + + valid_modalities = ( + '3K_DSM', + '3K_RGB', + 'HySpex', + 'EeteS_EnMAP_10m', + 'EeteS_EnMAP_30m', + 'EeteS_Sentinel_2_10m', + 'Sentinel_2', + 'Sentinel_1', + 'osm_buildings', + 'osm_landuse', + 'osm_water', + ) + landuse_class_names: ClassVar[dict[int, str]] = { + 0: 'no label', + 1: 'forest', + 2: 'park', + 3: 'residential', + 4: 'industrial', + 5: 'farm', + 6: 'cemetery', + 7: 'allotments', + 8: 'meadow', + 9: 'commercial', + 10: 'nature reserve', + 11: 'recreation ground', + 12: 'retail', + 13: 'military', + 14: 'quarry', + 15: 'orchard', + 16: 'scrub', + 17: 'grass', + 18: 'heath', + } + + # https://github.com/zhu-xlab/augsburg_Multimodal_Data_Set_MDaS/blob/75c015022b5f688dfc44744f19bcf34bdce786c7/Augsburg_data_4_publication/entire_city/OSM_label/README#L14 + landuse_mapping: ClassVar[dict[int, int]] = { + -2147483647: 0, + 7201: 1, + 7202: 2, + 7203: 3, + 7204: 4, + 7205: 5, + 7206: 6, + 7207: 7, + 7208: 8, + 7209: 9, + 7210: 10, + 7211: 11, + 7212: 12, + 7213: 13, + 7214: 14, + 7215: 15, + 7217: 16, + 7218: 17, + 7219: 18, + } + + ds_root_name = 'Augsburg_data_4_publication' + + zipfilename = f'{ds_root_name}.zip' + + valid_subareas = ('sub_area_1', 'sub_area_2', 'sub_area_3') + + url = 'https://huggingface.co/datasets/torchgeo/mdas/resolve/860226b74269f1cf1bed8ea3c03f571ae701144c/Augsburg_data_4_publication.zip' + + md5 = '7b63c26e3717cb52c6ba47d215f18d5b' + + enmap_rgb_band_idx: ClassVar[list[int]] = [43, 28, 10] + sentinel_2_rgb_band_idx: ClassVar[list[int]] = [3, 2, 1] + hyspex_rgb_band_idx: ClassVar[list[int]] = [100, 50, 10] + + def __init__( + self, + root: Path = 'data', + subareas: list[str] = ['sub_area_1'], + modalities: list[str] = ['3K_RGB', 'HySpex', 'Sentinel_2'], + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new MDAS dataset instance. + + Args: + root: Root directory where the dataset should be stored. + subareas: The subareas to load. Options are 'sub_area_1', 'sub_area_2', 'sub_area_3'. + modalities: The modalities to load. Options are '3K_DSM', '3K_RGB', 'HySpex', 'EeteS_EnMAP_10m', 'EeteS_EnMAP_30m', 'EeteS_Sentinel_2_10m', 'Sentinel-2', 'Sentinel-1', 'OSM_label'. + transforms: A function/transform that takes in a dictionary and returns a transformed version. + download: if True, download dataset and store it in the root directory + checksum: If True, check the integrity of the dataset after download. + + Raises: + AssertionError: If the subareas or modalities are not valid. + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + self.root = root + self.download = download + assert all(sub in self.valid_subareas for sub in subareas), ( + f'Subareas must be one of {self.valid_subareas}' + ) + self.subareas = subareas + assert all(mod in self.valid_modalities for mod in modalities), ( + f'Modalities must be one of {self.valid_modalities}' + ) + self.modalities = modalities + self.transforms = transforms + self.checksum = checksum + + self._verify() + self.files = self._load_files() + + def __len__(self) -> int: + """Return the number of samples in the dataset. + + Returns: + the length of the dataset + """ + return len(self.files) + + def _load_files(self) -> list[dict[str, str]]: + """Return the paths of the files in the dataset. + + Returns: + a list of dictionaries containing the paths of the files in the dataset + """ + files = [] + for subarea in self.subareas: + subarea_files = {} + for modality in self.modalities: + subarea_files[modality] = os.path.join( + self.root, + self.ds_root_name, + subarea, + f'{modality}_{self._format_subarea(subarea)}.tif', + ) + files.append(subarea_files) + return files + + def _format_subarea(self, subarea: str) -> str: + """Format the subarea name. + + Args: + subarea: The subarea string to format. + + Returns: + formatted subarea string for files + """ + parts = subarea.split('_') + return parts[0] + '_' + parts[1] + parts[2] + + def _load_image(self, path: Path) -> Tensor: + """Load an image from a given path. + + Args: + path: The path to the image file + + Returns: + the loaded image as a tensor + """ + with rio.open(path) as src: + img = src.read() + if img.dtype == np.uint16: + img = img.astype(np.int32) + if 'osm_landuse' in str(path): + img = np.vectorize(self.landuse_mapping.get)(img) + + return torch.from_numpy(img) + + def __getitem__(self, idx: int) -> dict[str, Tensor]: + """Return the dataset sample at the given index. + + Args: + idx: The index of the sample to return + + Returns: + a dictionary containing the data of chosen modalities + """ + sample_files = self.files[idx] + sample: dict[str, Any] = {} + for modality, path in sample_files.items(): + if 'osm' in modality: + sample[f'{modality}_mask'] = self._load_image(path).long() + else: + sample[f'{modality}_image'] = self._load_image(path) + + if self.transforms: + sample = self.transforms(sample) + + return sample + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # check if each desired modality file exists in specified subarea + exists = [] + for subarea in self.subareas: + for modality in self.modalities: + path = os.path.join( + self.root, + self.ds_root_name, + subarea, + f'{modality}_{self._format_subarea(subarea)}.tif', + ) + if not os.path.exists(path): + exists.append(False) + else: + exists.append(True) + if all(exists): + return + + # check if zip file downloaded + if os.path.exists(os.path.join(self.root, self.zipfilename)): + self._extract() + return + + if not self.download: + raise DatasetNotFoundError(self) + + self._download() + + def _extract(self) -> None: + """Extract the dataset.""" + extract_archive(os.path.join(self.root, self.zipfilename), self.root) + + def _download(self) -> None: + """Download the dataset.""" + download_and_extract_archive( + self.url, + self.root, + filename=self.zipfilename, + md5=self.md5 if self.checksum else None, + ) + + def plot( + self, + sample: dict[str, Tensor], + show_titles: bool = True, + suptitle: str | None = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: A sample returned by `__getitem__`. + show_titles: Whether to display titles on the subplots. + suptitle: An optional super title for the plot. + + Returns: + a matplotlib Figure with the rendered sample + """ + ncols = len(sample) + fig, axs = plt.subplots(1, ncols, figsize=(5 * ncols, 5)) + + if ncols == 1: + axs = [axs] + + for idx, (key, data) in enumerate(sample.items()): + match key: + case '3K_RGB_image': + img = data[:3].numpy().transpose(1, 2, 0) / 255.0 + axs[idx].imshow(img) + case '3K_DSM_image': + img = data.numpy().squeeze(0) + axs[idx].imshow(img, cmap='gray') + case 'EeteS_EnMAP_10m_image' | 'EeteS_EnMAP_30m_image': + img = ( + data[self.enmap_rgb_band_idx].numpy().transpose(1, 2, 0) + / 10000.0 + ) + axs[idx].imshow(img) + case 'EeteS_Sentinel_2_10m_image': + img = ( + data[self.sentinel_2_rgb_band_idx].numpy().transpose(1, 2, 0) + / 10000.0 + ) + axs[idx].imshow(img) + case 'Sentinel_1_image': + img = data[0].numpy().clip(0, 1) + axs[idx].imshow(img) + case 'Sentinel_2_image': + img = ( + data[self.sentinel_2_rgb_band_idx].numpy().transpose(1, 2, 0) + / 10000.0 + ) + axs[idx].imshow(img) + case 'HySpex_image': + img = ( + data[self.hyspex_rgb_band_idx].numpy().transpose(1, 2, 0) + / 15000.0 + ) + axs[idx].imshow(img) + case 'osm_landuse_mask': + img = data.numpy().squeeze(0) + cmap = ListedColormap([cm.get_cmap('tab20')(i) for i in range(20)]) + im = axs[idx].imshow(img, cmap=cmap) + cbar = plt.colorbar(im, ax=axs[idx], ticks=range(19)) + cbar.ax.set_yticklabels( + [self.landuse_class_names[i] for i in range(19)] + ) + case 'osm_buildings_mask': + img = data.numpy().squeeze(0) + axs[idx].imshow(img, cmap='gray') + case 'osm_water_mask': + img = data.numpy().squeeze(0) + axs[idx].imshow(img, cmap='Blues') + + axs[idx].axis('off') + if show_titles: + axs[idx].set_title(key) + + if suptitle: + plt.suptitle(suptitle) + + return fig diff --git a/torchgeo/datasets/mmearth.py b/torchgeo/datasets/mmearth.py index f363276c40a..b940537d8b4 100644 --- a/torchgeo/datasets/mmearth.py +++ b/torchgeo/datasets/mmearth.py @@ -206,12 +206,12 @@ def __init__( """ lazy_import('h5py') - assert ( - normalization_mode in self.norm_modes - ), f'Invalid normalization mode: {normalization_mode}, please choose from {self.norm_modes}' - assert ( - subset in self.subsets - ), f'Invalid dataset version: {subset}, please choose from {self.subsets}' + assert normalization_mode in self.norm_modes, ( + f'Invalid normalization mode: {normalization_mode}, please choose from {self.norm_modes}' + ) + assert subset in self.subsets, ( + f'Invalid dataset version: {subset}, please choose from {self.subsets}' + ) self._validate_modalities(modalities) self.modalities = modalities diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py index 501fd6db8f9..681f0e242bc 100644 --- a/torchgeo/datasets/nlcd.py +++ b/torchgeo/datasets/nlcd.py @@ -167,9 +167,9 @@ def __init__( 'NLCD data product only exists for the following years: ' f'{list(self.md5s.keys())}.' ) - assert ( - set(classes) <= self.cmap.keys() - ), f'Only the following classes are valid: {list(self.cmap.keys())}.' + assert set(classes) <= self.cmap.keys(), ( + f'Only the following classes are valid: {list(self.cmap.keys())}.' + ) assert 0 in classes, 'Classes must include the background class: 0' self.paths = paths diff --git a/torchgeo/datasets/seasonet.py b/torchgeo/datasets/seasonet.py index 3e47a8ec491..ebbb9036374 100644 --- a/torchgeo/datasets/seasonet.py +++ b/torchgeo/datasets/seasonet.py @@ -450,7 +450,7 @@ def plot( axs[ax].imshow(image) axs[ax].axis('off') if show_titles: - axs[ax].set_title(f'Image {ax+1}') + axs[ax].set_title(f'Image {ax + 1}') axs[ax + 1].imshow(mask, vmin=0, vmax=32, cmap=plt_cmap, interpolation='none') axs[ax + 1].axis('off') diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index 3accf32d2af..d5a68f48488 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -104,14 +104,14 @@ def __init__( """ lazy_import('h5py') - assert ( - split in self.valid_splits - ), f'Please choose one of these valid data splits {self.valid_splits}.' + assert split in self.valid_splits, ( + f'Please choose one of these valid data splits {self.valid_splits}.' + ) self.split = split - assert ( - task in self.valid_tasks - ), f'Please choose one of these valid tasks {self.valid_tasks}.' + assert task in self.valid_tasks, ( + f'Please choose one of these valid tasks {self.valid_tasks}.' + ) self.task = task self.root = root diff --git a/torchgeo/datasets/south_africa_crop_type.py b/torchgeo/datasets/south_africa_crop_type.py index a8643873c5b..841cd7173de 100644 --- a/torchgeo/datasets/south_africa_crop_type.py +++ b/torchgeo/datasets/south_africa_crop_type.py @@ -131,9 +131,9 @@ def __init__( Raises: DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert ( - set(classes) <= self.cmap.keys() - ), f'Only the following classes are valid: {list(self.cmap.keys())}.' + assert set(classes) <= self.cmap.keys(), ( + f'Only the following classes are valid: {list(self.cmap.keys())}.' + ) assert 0 in classes, 'Classes must include the background class: 0' self.paths = paths diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py index 13c5a8474c4..111fe487e09 100644 --- a/torchgeo/datasets/ssl4eo_benchmark.py +++ b/torchgeo/datasets/ssl4eo_benchmark.py @@ -138,26 +138,26 @@ def __init__( AssertionError: if any arguments are invalid DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert ( - sensor in self.valid_sensors - ), f'Only supports one of {self.valid_sensors}, but found {sensor}.' + assert sensor in self.valid_sensors, ( + f'Only supports one of {self.valid_sensors}, but found {sensor}.' + ) self.sensor = sensor - assert ( - product in self.valid_products - ), f'Only supports one of {self.valid_products}, but found {product}.' + assert product in self.valid_products, ( + f'Only supports one of {self.valid_products}, but found {product}.' + ) self.product = product - assert ( - split in self.valid_splits - ), f'Only supports one of {self.valid_splits}, but found {split}.' + assert split in self.valid_splits, ( + f'Only supports one of {self.valid_splits}, but found {split}.' + ) self.split = split self.cmap = self.cmaps[product] if classes is None: classes = list(self.cmap.keys()) - assert ( - set(classes) <= self.cmap.keys() - ), f'Only the following classes are valid: {list(self.cmap.keys())}.' + assert set(classes) <= self.cmap.keys(), ( + f'Only the following classes are valid: {list(self.cmap.keys())}.' + ) assert 0 in classes, 'Classes must include the background class: 0' self.root = root diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py index eec9be57ab3..4d3a0b4de9f 100644 --- a/torchgeo/datasets/sustainbench_crop_yield.py +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -82,14 +82,14 @@ def __init__( is invalid DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert set(countries).issubset( - self.valid_countries - ), f'Please choose a subset of these valid countried: {self.valid_countries}.' + assert set(countries).issubset(self.valid_countries), ( + f'Please choose a subset of these valid countried: {self.valid_countries}.' + ) self.countries = countries - assert ( - split in self.valid_splits - ), f'Pleas choose one of these valid data splits {self.valid_splits}.' + assert split in self.valid_splits, ( + f'Pleas choose one of these valid data splits {self.valid_splits}.' + ) self.split = split self.root = root diff --git a/torchgeo/models/croma.py b/torchgeo/models/croma.py index 475c32fd3a9..de57a835936 100644 --- a/torchgeo/models/croma.py +++ b/torchgeo/models/croma.py @@ -56,9 +56,9 @@ def __init__( """ super().__init__() for modality in modalities: - assert ( - modality in self.valid_modalities - ), f'{modality} is not a valid modality' + assert modality in self.valid_modalities, ( + f'{modality} is not a valid modality' + ) assert image_size % 8 == 0, 'image_size must be a multiple of 8' assert num_heads % 2 == 0, 'num_heads must be a power of 2'