diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e09d665c3ff..d4d221e3c0d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.8.0
+ rev: v0.9.1
hooks:
- id: ruff
types_or:
diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 458699a5325..656909190ab 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -358,6 +358,11 @@ MapInWild
.. autoclass:: MapInWild
+MDAS
+^^^^
+
+.. autoclass:: MDAS
+
Million-AID
^^^^^^^^^^^
diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv
index cc68d0d05b0..5b70fb768a6 100644
--- a/docs/api/datasets/non_geo_datasets.csv
+++ b/docs/api/datasets/non_geo_datasets.csv
@@ -30,6 +30,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`LEVIR-CD+`_,CD,Google Earth,-,985,2,"1,024x1,024",0.5,RGB
`LoveDA`_,S,Google Earth,"CC-BY-NC-SA-4.0","5,987",7,"1,024x1,024",0.3,RGB
`MapInWild`_,S,"Sentinel-1/2, ESA WorldCover, NOAA VIIRS DNB","CC-BY-4.0",1018,1,1920x1920,10--463.83,"SAR, MSI, 2020_Map, avg_rad"
+`MDAS`_,S,"Sentinel-1/2,EnMAP,HySpex","CC-BY-SA-4.0",3,20,"100x120, 300x360, 1364x1636, 10000x12000, 15000x18000",0.3--30,HSI
`Million-AID`_,C,Google Earth,-,1M,51--73,,0.5--153,RGB
`MMEarth`_,"C, S","Aster, Sentinel, ERA5","CC-BY-4.0","100K--1M",,"128x128 or 64x64",10,MSI
`NASA Marine Debris`_,OD,PlanetScope,"Apache-2.0",707,1,256x256,3,RGB
diff --git a/docs/tutorials/transforms.ipynb b/docs/tutorials/transforms.ipynb
index 689b2eebd33..e148945afbb 100644
--- a/docs/tutorials/transforms.ipynb
+++ b/docs/tutorials/transforms.ipynb
@@ -707,7 +707,7 @@
"sample = dataset[idx]\n",
"rgb = sample['image'][0, 1:4]\n",
"image = T.ToPILImage()(rgb)\n",
- "print(f\"Class Label: {dataset.classes[sample['label']]}\")\n",
+ "print(f'Class Label: {dataset.classes[sample[\"label\"]]}')\n",
"image.resize((256, 256), resample=Image.BILINEAR)"
]
},
diff --git a/experiments/torchgeo/run_resisc45_experiments.py b/experiments/torchgeo/run_resisc45_experiments.py
index 6897ea12772..9ed69b03968 100755
--- a/experiments/torchgeo/run_resisc45_experiments.py
+++ b/experiments/torchgeo/run_resisc45_experiments.py
@@ -38,7 +38,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool:
for model, lr, loss, weights in itertools.product(
model_options, lr_options, loss_options, weight_options
):
- experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}"
+ experiment_name = f'{model}_{lr}_{loss}_{weights.replace("_", "-")}'
output_dir = os.path.join('output', 'resisc45_experiments')
log_dir = os.path.join(output_dir, 'logs')
diff --git a/experiments/torchgeo/run_so2sat_byol_experiments.py b/experiments/torchgeo/run_so2sat_byol_experiments.py
index 169a010cef8..4ae78601fbd 100755
--- a/experiments/torchgeo/run_so2sat_byol_experiments.py
+++ b/experiments/torchgeo/run_so2sat_byol_experiments.py
@@ -39,7 +39,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool:
for model, lr, loss, weights, bands in itertools.product(
model_options, lr_options, loss_options, weight_options, bands_options
):
- experiment_name = f"{model}_{lr}_{loss}_byol_{bands}-{weights.split('/')[-2]}"
+ experiment_name = f'{model}_{lr}_{loss}_byol_{bands}-{weights.split("/")[-2]}'
output_dir = os.path.join('output', 'so2sat_experiments')
log_dir = os.path.join(output_dir, 'logs')
diff --git a/experiments/torchgeo/run_so2sat_experiments.py b/experiments/torchgeo/run_so2sat_experiments.py
index 41e2fc04b5f..44ba5c7aaf3 100755
--- a/experiments/torchgeo/run_so2sat_experiments.py
+++ b/experiments/torchgeo/run_so2sat_experiments.py
@@ -38,7 +38,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool:
for model, lr, loss, weights in itertools.product(
model_options, lr_options, loss_options, weight_options
):
- experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}"
+ experiment_name = f'{model}_{lr}_{loss}_{weights.replace("_", "-")}'
output_dir = os.path.join('output', 'so2sat_experiments')
log_dir = os.path.join(output_dir, 'logs')
diff --git a/experiments/torchgeo/run_so2sat_seed_experiments.py b/experiments/torchgeo/run_so2sat_seed_experiments.py
index 2d2efe1e248..4f71770917a 100755
--- a/experiments/torchgeo/run_so2sat_seed_experiments.py
+++ b/experiments/torchgeo/run_so2sat_seed_experiments.py
@@ -39,7 +39,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool:
for model, lr, loss, weights, seed in itertools.product(
model_options, lr_options, loss_options, weight_options, seeds
):
- experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}_{seed}"
+ experiment_name = f'{model}_{lr}_{loss}_{weights.replace("_", "-")}_{seed}'
output_dir = os.path.join('output', 'so2sat_seed_experiments')
log_dir = os.path.join(output_dir, 'logs')
diff --git a/pyproject.toml b/pyproject.toml
index a4e125f5a31..8d7fcc89183 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -115,8 +115,8 @@ docs = [
style = [
# mypy 0.900+ required for pyproject.toml support
"mypy>=0.900",
- # ruff 0.8+ required for removal of ANN101, ANN102
- "ruff>=0.8",
+ # ruff 0.9+ required for 2025 style guide
+ "ruff>=0.9",
]
tests = [
# nbmake 1.3.3+ required for variable mocking
diff --git a/requirements/datasets.txt b/requirements/datasets.txt
index 9cfdcbe5d3c..d2be8e65a52 100644
--- a/requirements/datasets.txt
+++ b/requirements/datasets.txt
@@ -6,6 +6,6 @@ pandas[parquet]==2.2.3
pycocotools==2.0.8
pyvista==0.44.2
scikit-image==0.25.0
-scipy==1.14.1
+scipy==1.15.0
xarray==2024.11.0
netcdf4==1.7.2
diff --git a/requirements/required.txt b/requirements/required.txt
index 055d27670f6..62295bb4c71 100644
--- a/requirements/required.txt
+++ b/requirements/required.txt
@@ -1,5 +1,5 @@
# setup
-setuptools==75.6.0
+setuptools==75.8.0
# install
einops==0.8.0
@@ -10,7 +10,7 @@ lightning[pytorch-extra]==2.5.0.post0
matplotlib==3.10.0
numpy==2.2.1
pandas==2.2.3
-pillow==11.0.0
+pillow==11.1.0
pyproj==3.7.0
rasterio==1.4.3
rtree==1.3.0
diff --git a/requirements/style.txt b/requirements/style.txt
index dc9b06b247a..4b0987677c7 100644
--- a/requirements/style.txt
+++ b/requirements/style.txt
@@ -1,3 +1,3 @@
# style
mypy==1.14.1
-ruff==0.8.4
+ruff==0.9.1
diff --git a/tests/data/inria/data.py b/tests/data/inria/data.py
index 96626dea5fa..76ea3aa9c60 100755
--- a/tests/data/inria/data.py
+++ b/tests/data/inria/data.py
@@ -68,9 +68,9 @@ def generate_test_data(root: str, n_samples: int = 2) -> str:
lbl = np.random.randint(2, size=size, dtype=dtype)
timg = np.random.randint(dtype_max, size=size, dtype=dtype)
- img_path = os.path.join(img_dir, f'austin{i+1}.tif')
- lbl_path = os.path.join(lbl_dir, f'austin{i+1}.tif')
- timg_path = os.path.join(timg_dir, f'austin{i+10}.tif')
+ img_path = os.path.join(img_dir, f'austin{i + 1}.tif')
+ lbl_path = os.path.join(lbl_dir, f'austin{i + 1}.tif')
+ timg_path = os.path.join(timg_dir, f'austin{i + 10}.tif')
write_data(img_path, img, driver, crs, transform)
write_data(lbl_path, lbl, driver, crs, transform)
diff --git a/tests/data/mdas/Augsburg_data_4_publication.zip b/tests/data/mdas/Augsburg_data_4_publication.zip
new file mode 100644
index 00000000000..a4e00554127
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication.zip differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/3K_DSM_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/3K_DSM_sub_area1.tif
new file mode 100644
index 00000000000..13e7483a0be
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/3K_DSM_sub_area1.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/3K_RGB_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/3K_RGB_sub_area1.tif
new file mode 100644
index 00000000000..400085756b3
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/3K_RGB_sub_area1.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_EnMAP_10m_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_EnMAP_10m_sub_area1.tif
new file mode 100644
index 00000000000..96f0de07058
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_EnMAP_10m_sub_area1.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_EnMAP_30m_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_EnMAP_30m_sub_area1.tif
new file mode 100644
index 00000000000..30d07cd51ea
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_EnMAP_30m_sub_area1.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_Sentinel_2_10m_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_Sentinel_2_10m_sub_area1.tif
new file mode 100644
index 00000000000..5a50122b2c9
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/EeteS_Sentinel_2_10m_sub_area1.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/HySpex_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/HySpex_sub_area1.tif
new file mode 100644
index 00000000000..e0984b4f7c1
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/HySpex_sub_area1.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/Sentinel_1_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/Sentinel_1_sub_area1.tif
new file mode 100644
index 00000000000..834990c3e20
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/Sentinel_1_sub_area1.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/Sentinel_2_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/Sentinel_2_sub_area1.tif
new file mode 100644
index 00000000000..489e96791a2
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/Sentinel_2_sub_area1.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_buildings_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_buildings_sub_area1.tif
new file mode 100644
index 00000000000..0faaa47d73f
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_buildings_sub_area1.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_landuse_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_landuse_sub_area1.tif
new file mode 100644
index 00000000000..c80350418fc
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_landuse_sub_area1.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_water_sub_area1.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_water_sub_area1.tif
new file mode 100644
index 00000000000..8840f67a9aa
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_1/osm_water_sub_area1.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/3K_DSM_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/3K_DSM_sub_area2.tif
new file mode 100644
index 00000000000..313e5f9f6f4
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/3K_DSM_sub_area2.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/3K_RGB_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/3K_RGB_sub_area2.tif
new file mode 100644
index 00000000000..37cdaa7f26b
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/3K_RGB_sub_area2.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_EnMAP_10m_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_EnMAP_10m_sub_area2.tif
new file mode 100644
index 00000000000..7ac37f43f06
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_EnMAP_10m_sub_area2.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_EnMAP_30m_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_EnMAP_30m_sub_area2.tif
new file mode 100644
index 00000000000..349b28401ac
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_EnMAP_30m_sub_area2.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_Sentinel_2_10m_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_Sentinel_2_10m_sub_area2.tif
new file mode 100644
index 00000000000..fd0bcceda81
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/EeteS_Sentinel_2_10m_sub_area2.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/HySpex_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/HySpex_sub_area2.tif
new file mode 100644
index 00000000000..21e22212147
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/HySpex_sub_area2.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/Sentinel_1_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/Sentinel_1_sub_area2.tif
new file mode 100644
index 00000000000..a7ebb2b5b61
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/Sentinel_1_sub_area2.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/Sentinel_2_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/Sentinel_2_sub_area2.tif
new file mode 100644
index 00000000000..c4a390f88a8
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/Sentinel_2_sub_area2.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_buildings_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_buildings_sub_area2.tif
new file mode 100644
index 00000000000..711d872d705
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_buildings_sub_area2.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_landuse_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_landuse_sub_area2.tif
new file mode 100644
index 00000000000..8a3f1ca1692
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_landuse_sub_area2.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_water_sub_area2.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_water_sub_area2.tif
new file mode 100644
index 00000000000..9c88128b574
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_2/osm_water_sub_area2.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/3K_DSM_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/3K_DSM_sub_area3.tif
new file mode 100644
index 00000000000..e664307c2c4
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/3K_DSM_sub_area3.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/3K_RGB_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/3K_RGB_sub_area3.tif
new file mode 100644
index 00000000000..aa020dde3f1
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/3K_RGB_sub_area3.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_EnMAP_10m_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_EnMAP_10m_sub_area3.tif
new file mode 100644
index 00000000000..8f64503512d
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_EnMAP_10m_sub_area3.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_EnMAP_30m_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_EnMAP_30m_sub_area3.tif
new file mode 100644
index 00000000000..c30d66e3dd4
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_EnMAP_30m_sub_area3.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_Sentinel_2_10m_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_Sentinel_2_10m_sub_area3.tif
new file mode 100644
index 00000000000..2fd45d96c7f
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/EeteS_Sentinel_2_10m_sub_area3.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/HySpex_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/HySpex_sub_area3.tif
new file mode 100644
index 00000000000..36b6ce50d76
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/HySpex_sub_area3.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/Sentinel_1_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/Sentinel_1_sub_area3.tif
new file mode 100644
index 00000000000..5cfad510e84
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/Sentinel_1_sub_area3.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/Sentinel_2_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/Sentinel_2_sub_area3.tif
new file mode 100644
index 00000000000..3e80c170bf8
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/Sentinel_2_sub_area3.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_buildings_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_buildings_sub_area3.tif
new file mode 100644
index 00000000000..9861e5e07ee
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_buildings_sub_area3.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_landuse_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_landuse_sub_area3.tif
new file mode 100644
index 00000000000..3705837b37a
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_landuse_sub_area3.tif differ
diff --git a/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_water_sub_area3.tif b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_water_sub_area3.tif
new file mode 100644
index 00000000000..4bcb146bef6
Binary files /dev/null and b/tests/data/mdas/Augsburg_data_4_publication/sub_area_3/osm_water_sub_area3.tif differ
diff --git a/tests/data/mdas/data.py b/tests/data/mdas/data.py
new file mode 100644
index 00000000000..c82b54cb89a
--- /dev/null
+++ b/tests/data/mdas/data.py
@@ -0,0 +1,161 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import os
+import shutil
+
+import numpy as np
+import rasterio
+from rasterio.crs import CRS
+from rasterio.transform import from_origin
+
+# Set the random seed for reproducibility
+np.random.seed(0)
+
+# Define the root directory, dataset name, subareas, and modalities based on mdas.py
+root_dir = '.'
+ds_root_name = 'Augsburg_data_4_publication'
+subareas = ['sub_area_1', 'sub_area_2', 'sub_area_3']
+modalities = [
+ '3K_DSM',
+ '3K_RGB',
+ 'HySpex',
+ 'EeteS_EnMAP_10m',
+ 'EeteS_EnMAP_30m',
+ 'EeteS_Sentinel_2_10m',
+ 'Sentinel_1',
+ 'Sentinel_2',
+ 'osm_buildings',
+ 'osm_landuse',
+ 'osm_water',
+]
+
+landuse_class_codes = [
+ -2147483647, # no label
+ 7201, # forest
+ 7202, # park
+ 7203, # residential
+ 7204, # industrial
+ 7205, # farm
+ 7206, # cemetery
+ 7207, # allotments
+ 7208, # meadow
+ 7209, # commercial
+ 7210, # nature reserve
+ 7211, # recreation ground
+ 7212, # retail
+ 7213, # military
+ 7214, # quarry
+ 7215, # orchard
+ 7217, # scrub
+ 7218, # grass
+ 7219, # heath
+]
+
+# Remove existing dummy data if it exists
+dataset_path = os.path.join(root_dir, ds_root_name)
+if os.path.exists(dataset_path):
+ shutil.rmtree(dataset_path)
+
+
+def create_dummy_geotiff(
+ path: str,
+ num_bands: int = 3,
+ width: int = 32,
+ height: int = 32,
+ dtype: np.dtype = np.uint16,
+ binary: bool = False,
+ landuse: bool = False,
+) -> None:
+ """Create a dummy GeoTIFF file."""
+ crs = CRS.from_epsg(32632)
+ transform = from_origin(0, 0, 1, 1)
+
+ if binary:
+ data = np.random.randint(0, 2, size=(num_bands, height, width)).astype(dtype)
+ elif landuse:
+ num_pixels = num_bands * height * width
+ no_label_ratio = 0.1
+ num_no_label = int(no_label_ratio * num_pixels)
+ num_labels = num_pixels - num_no_label
+ landuse_values = np.random.choice(landuse_class_codes[1:], size=num_labels)
+ no_label_values = np.full(num_no_label, landuse_class_codes[0], dtype=dtype)
+ combined = np.concatenate([landuse_values, no_label_values])
+ np.random.shuffle(combined)
+ data = combined.reshape((num_bands, height, width)).astype(dtype)
+ else:
+ # Generate random data for other modalities
+ data = np.random.randint(0, 255, size=(num_bands, height, width)).astype(dtype)
+
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+
+ with rasterio.open(
+ path,
+ 'w',
+ driver='GTiff',
+ height=height,
+ width=width,
+ count=num_bands,
+ dtype=dtype,
+ crs=crs,
+ transform=transform,
+ ) as dst:
+ dst.write(data)
+
+
+# Create directory structure and dummy data
+for subarea in subareas:
+ # Format the subarea name for filenames, as in mdas.py _format_subarea method
+ parts = subarea.split('_')
+ subarea_formatted = parts[0] + '_' + parts[1] + parts[2] # e.g., 'sub_area1'
+
+ subarea_dir = os.path.join(root_dir, ds_root_name, subarea)
+
+ for modality in modalities:
+ filename = f'{modality}_{subarea_formatted}.tif'
+ file_path = os.path.join(subarea_dir, filename)
+
+ if modality in ['osm_buildings', 'osm_water']:
+ create_dummy_geotiff(file_path, num_bands=1, dtype=np.uint8, binary=True)
+ elif modality == 'osm_landuse':
+ create_dummy_geotiff(file_path, num_bands=1, dtype=np.float64, landuse=True)
+ elif modality == 'HySpex':
+ create_dummy_geotiff(file_path, num_bands=368, dtype=np.int16)
+ elif modality in ['EeteS_EnMAP_10m', 'EeteS_EnMAP_30m']:
+ create_dummy_geotiff(file_path, num_bands=242, dtype=np.uint16)
+ elif modality == 'Sentinel_1':
+ create_dummy_geotiff(file_path, num_bands=2, dtype=np.float32)
+ elif modality in ['Sentinel_2', 'EeteS_Sentinel_2_10m']:
+ create_dummy_geotiff(file_path, num_bands=13, dtype=np.uint16)
+ elif modality == '3K_DSM':
+ create_dummy_geotiff(file_path, num_bands=1, dtype=np.float32)
+ elif modality == '3K_RGB':
+ create_dummy_geotiff(file_path, num_bands=3, dtype=np.uint8)
+
+print(f'Dummy MDAS dataset created at {os.path.join(root_dir, ds_root_name)}')
+
+# Create a zip archive of the dataset directory
+zip_filename = f'{ds_root_name}.zip'
+zip_path = os.path.join(root_dir, zip_filename)
+
+shutil.make_archive(
+ base_name=os.path.splitext(zip_path)[0],
+ format='zip',
+ root_dir='.',
+ base_dir=ds_root_name,
+)
+
+
+def calculate_md5(filename: str) -> str:
+ hash_md5 = hashlib.md5()
+ with open(filename, 'rb') as f:
+ for chunk in iter(lambda: f.read(4096), b''):
+ hash_md5.update(chunk)
+ return hash_md5.hexdigest()
+
+
+checksum = calculate_md5(zip_path)
+print(f'MD5 checksum: {checksum}')
diff --git a/tests/data/seasonet/data.py b/tests/data/seasonet/data.py
index e3197ddde12..86f6210636b 100644
--- a/tests/data/seasonet/data.py
+++ b/tests/data/seasonet/data.py
@@ -63,7 +63,7 @@
os.remove(archive)
for grid, comp in zip(grids, name_comps):
- file_name = f"{comp[0]}_{''.join(comp[1:8])}_{'_'.join(comp[8:])}"
+ file_name = f'{comp[0]}_{"".join(comp[1:8])}_{"_".join(comp[8:])}'
dir = os.path.join(season, f'grid{grid}', file_name)
os.makedirs(dir)
diff --git a/tests/data/ssl4eo_benchmark_landsat/data.py b/tests/data/ssl4eo_benchmark_landsat/data.py
index 177ed7d7954..5470aacef05 100755
--- a/tests/data/ssl4eo_benchmark_landsat/data.py
+++ b/tests/data/ssl4eo_benchmark_landsat/data.py
@@ -193,7 +193,7 @@ def create_tarballs(directories: str) -> None:
# mask directory cdl
mask_keep = ['tm_toa', 'etm_sr', 'oli_sr']
mask_filenames = {
- f"ssl4eo_l_{key.split('_')[0]}_cdl": val
+ f'ssl4eo_l_{key.split("_")[0]}_cdl': val
for key, val in filenames.items()
if key in mask_keep
}
@@ -203,7 +203,7 @@ def create_tarballs(directories: str) -> None:
# mask directory nlcd
mask_filenames = {
- f"ssl4eo_l_{key.split('_')[0]}_nlcd": val
+ f'ssl4eo_l_{key.split("_")[0]}_nlcd': val
for key, val in filenames.items()
if key in mask_keep
}
diff --git a/tests/datamodules/test_digital_typhoon.py b/tests/datamodules/test_digital_typhoon.py
index 0ecd85f5ec7..dd61eb26933 100644
--- a/tests/datamodules/test_digital_typhoon.py
+++ b/tests/datamodules/test_digital_typhoon.py
@@ -57,14 +57,14 @@ def find_max_time_per_id(
# Assert that each max value in train_max_values is lower
# than in val_max_values for each key id
for id, max_value in train_max_values.items():
- assert (
- id not in val_max_values or max_value < val_max_values[id]
- ), f'Max value for id {id} in train is not lower than in validation.'
+ assert id not in val_max_values or max_value < val_max_values[id], (
+ f'Max value for id {id} in train is not lower than in validation.'
+ )
else:
train_ids = {seq['id'] for seq in train_sequences}
val_ids = {seq['id'] for seq in val_sequences}
# Assert that the intersection between train_ids and val_ids is empty
- assert (
- len(train_ids & val_ids) == 0
- ), 'Train and validation datasets have overlapping ids.'
+ assert len(train_ids & val_ids) == 0, (
+ 'Train and validation datasets have overlapping ids.'
+ )
diff --git a/tests/datasets/test_mdas.py b/tests/datasets/test_mdas.py
new file mode 100644
index 00000000000..83138c84207
--- /dev/null
+++ b/tests/datasets/test_mdas.py
@@ -0,0 +1,113 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+import shutil
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+import torch.nn as nn
+from _pytest.fixtures import SubRequest
+from pytest import MonkeyPatch
+
+from torchgeo.datasets import MDAS, DatasetNotFoundError
+
+
+class TestMDAS:
+ @pytest.fixture(
+ params=[
+ {'subareas': ['sub_area_1'], 'modalities': ['HySpex']},
+ {
+ 'subareas': ['sub_area_1', 'sub_area_2'],
+ 'modalities': ['3K_DSM', 'HySpex', 'osm_water'],
+ },
+ {
+ 'subareas': ['sub_area_2', 'sub_area_3'],
+ 'modalities': [
+ '3K_DSM',
+ '3K_RGB',
+ 'HySpex',
+ 'EeteS_EnMAP_10m',
+ 'EeteS_EnMAP_30m',
+ 'EeteS_Sentinel_2_10m',
+ 'Sentinel_2',
+ 'Sentinel_1',
+ 'osm_buildings',
+ 'osm_landuse',
+ 'osm_water',
+ ],
+ },
+ ]
+ )
+ def dataset(
+ self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
+ ) -> MDAS:
+ md5 = '99e1744ca6f19aa19a3aa23a2bbf7bef'
+ monkeypatch.setattr(MDAS, 'md5', md5)
+ url = os.path.join('tests', 'data', 'mdas', 'Augsburg_data_4_publication.zip')
+ monkeypatch.setattr(MDAS, 'url', url)
+
+ params = request.param
+ subareas = params['subareas']
+ modalities = params['modalities']
+
+ root = tmp_path
+ transforms = nn.Identity()
+
+ return MDAS(
+ root=root,
+ subareas=subareas,
+ modalities=modalities,
+ transforms=transforms,
+ download=True,
+ checksum=True,
+ )
+
+ def test_getitem(self, dataset: MDAS) -> None:
+ x = dataset[0]
+ assert isinstance(x, dict)
+ for key in dataset.modalities:
+ if key.startswith('osm'):
+ key = f'{key}_mask'
+ else:
+ key = f'{key}_image'
+ assert key in x
+
+ for key, value in x.items():
+ assert isinstance(value, torch.Tensor)
+
+ def test_len(self, dataset: MDAS) -> None:
+ assert len(dataset) == len(dataset.subareas)
+
+ def test_already_downloaded(self, dataset: MDAS) -> None:
+ MDAS(root=dataset.root)
+
+ def test_not_yet_extracted(self, tmp_path: Path) -> None:
+ filename = 'Augsburg_data_4_publication.zip'
+ dir = os.path.join('tests', 'data', 'mdas')
+ shutil.copyfile(
+ os.path.join(dir, filename), os.path.join(str(tmp_path), filename)
+ )
+ MDAS(root=str(tmp_path))
+
+ def test_invalid_subarea(self) -> None:
+ with pytest.raises(AssertionError):
+ MDAS(subareas=['foo'])
+
+ def test_invalid_modality(self) -> None:
+ with pytest.raises(AssertionError):
+ MDAS(modalities=['foo'])
+
+ def test_not_downloaded(self, tmp_path: Path) -> None:
+ with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
+ MDAS(tmp_path)
+
+ def test_plot(self, dataset: MDAS) -> None:
+ dataset.plot(dataset[0], suptitle='Test')
+ plt.close()
+
+ def test_plot_single_sample(self, dataset: MDAS) -> None:
+ dataset.plot(dataset[0], show_titles=False)
+ plt.close()
diff --git a/torchgeo/datamodules/digital_typhoon.py b/torchgeo/datamodules/digital_typhoon.py
index ce799bf3d52..9ebc1643255 100644
--- a/torchgeo/datamodules/digital_typhoon.py
+++ b/torchgeo/datamodules/digital_typhoon.py
@@ -43,9 +43,9 @@ def __init__(
"""
super().__init__(DigitalTyphoon, batch_size, num_workers, **kwargs)
- assert (
- split_by in self.valid_split_types
- ), f'Please choose from {self.valid_split_types}'
+ assert split_by in self.valid_split_types, (
+ f'Please choose from {self.valid_split_types}'
+ )
self.split_by = split_by
def _split_dataset(
diff --git a/torchgeo/datamodules/ftw.py b/torchgeo/datamodules/ftw.py
index a197a789c48..19128cdbb3d 100644
--- a/torchgeo/datamodules/ftw.py
+++ b/torchgeo/datamodules/ftw.py
@@ -44,9 +44,9 @@ def __init__(
Raises:
AssertionError: If 'countries' are specified in kwargs
"""
- assert (
- 'countries' not in kwargs
- ), "Please specify 'train_countries', 'val_countries', and 'test_countries' instead of 'countries' inside kwargs"
+ assert 'countries' not in kwargs, (
+ "Please specify 'train_countries', 'val_countries', and 'test_countries' instead of 'countries' inside kwargs"
+ )
super().__init__(FieldsOfTheWorld, batch_size, num_workers, **kwargs)
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index c9cb2989d0a..8c2e362b6c1 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -86,6 +86,7 @@
from .levircd import LEVIRCD, LEVIRCDBase, LEVIRCDPlus
from .loveda import LoveDA
from .mapinwild import MapInWild
+from .mdas import MDAS
from .millionaid import MillionAID
from .mmearth import MMEarth
from .naip import NAIP
@@ -160,6 +161,7 @@
'GBIF',
'GID15',
'LEVIRCD',
+ 'MDAS',
'NAIP',
'NCCM',
'NLCD',
diff --git a/torchgeo/datasets/agrifieldnet.py b/torchgeo/datasets/agrifieldnet.py
index 3624c1e193e..ba116377878 100644
--- a/torchgeo/datasets/agrifieldnet.py
+++ b/torchgeo/datasets/agrifieldnet.py
@@ -149,9 +149,9 @@ def __init__(
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
- assert (
- set(classes) <= self.cmap.keys()
- ), f'Only the following classes are valid: {list(self.cmap.keys())}.'
+ assert set(classes) <= self.cmap.keys(), (
+ f'Only the following classes are valid: {list(self.cmap.keys())}.'
+ )
assert 0 in classes, 'Classes must include the background class: 0'
self.paths = paths
diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py
index 38669cd6ff1..ef62ac1a280 100644
--- a/torchgeo/datasets/bigearthnet.py
+++ b/torchgeo/datasets/bigearthnet.py
@@ -565,9 +565,9 @@ def plot(
ax.imshow(image)
ax.axis('off')
if show_titles:
- title = f"Labels: {', '.join(labels)}"
+ title = f'Labels: {", ".join(labels)}'
if showing_predictions:
- title += f"\nPredictions: {', '.join(predictions)}"
+ title += f'\nPredictions: {", ".join(predictions)}'
ax.set_title(title)
if suptitle is not None:
diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py
index 70a53a4220a..2531c96dd23 100644
--- a/torchgeo/datasets/biomassters.py
+++ b/torchgeo/datasets/biomassters.py
@@ -81,14 +81,14 @@ def __init__(
"""
self.root = root
- assert (
- split in self.valid_splits
- ), f'Please choose one of the valid splits: {self.valid_splits}.'
+ assert split in self.valid_splits, (
+ f'Please choose one of the valid splits: {self.valid_splits}.'
+ )
self.split = split
- assert set(sensors).issubset(
- set(self.valid_sensors)
- ), f'Please choose a subset of valid sensors: {self.valid_sensors}.'
+ assert set(sensors).issubset(set(self.valid_sensors)), (
+ f'Please choose a subset of valid sensors: {self.valid_sensors}.'
+ )
self.sensors = sensors
self.as_time_series = as_time_series
diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py
index 0b0f6ac5b3d..2de5719beb0 100644
--- a/torchgeo/datasets/cdl.py
+++ b/torchgeo/datasets/cdl.py
@@ -248,9 +248,9 @@ def __init__(
'CDL data product only exists for the following years: '
f'{list(self.md5s.keys())}.'
)
- assert (
- set(classes) <= self.cmap.keys()
- ), f'Only the following classes are valid: {list(self.cmap.keys())}.'
+ assert set(classes) <= self.cmap.keys(), (
+ f'Only the following classes are valid: {list(self.cmap.keys())}.'
+ )
assert 0 in classes, 'Classes must include the background class: 0'
self.paths = paths
diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py
index f9db256238d..681d5026f25 100644
--- a/torchgeo/datasets/cms_mangrove_canopy.py
+++ b/torchgeo/datasets/cms_mangrove_canopy.py
@@ -204,15 +204,15 @@ def __init__(
self.checksum = checksum
assert isinstance(country, str), 'Country argument must be a str.'
- assert (
- country in self.all_countries
- ), f'You have selected an invalid country, please choose one of {self.all_countries}'
+ assert country in self.all_countries, (
+ f'You have selected an invalid country, please choose one of {self.all_countries}'
+ )
self.country = country
assert isinstance(measurement, str), 'Measurement must be a string.'
- assert (
- measurement in self.measurements
- ), f'You have entered an invalid measurement, please choose one of {self.measurements}.'
+ assert measurement in self.measurements, (
+ f'You have entered an invalid measurement, please choose one of {self.measurements}.'
+ )
self.measurement = measurement
self.filename_glob = f'**/Mangrove_{self.measurement}_{self.country}*'
diff --git a/torchgeo/datasets/digital_typhoon.py b/torchgeo/datasets/digital_typhoon.py
index 42bb4caa1bd..dfa47966440 100644
--- a/torchgeo/datasets/digital_typhoon.py
+++ b/torchgeo/datasets/digital_typhoon.py
@@ -139,9 +139,9 @@ def __init__(
self.min_feature_value = min_feature_value
self.max_feature_value = max_feature_value
- assert (
- task in self.valid_tasks
- ), f'Please choose one of {self.valid_tasks}, you provided {task}.'
+ assert task in self.valid_tasks, (
+ f'Please choose one of {self.valid_tasks}, you provided {task}.'
+ )
self.task = task
assert set(features).issubset(set(self.valid_features))
diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py
index 93b6b18e455..e3066058371 100644
--- a/torchgeo/datasets/loveda.py
+++ b/torchgeo/datasets/loveda.py
@@ -115,9 +115,9 @@ def __init__(
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.splits
- assert set(scene).intersection(
- set(self.scenes)
- ), "The possible scenes are 'rural' and/or 'urban'"
+ assert set(scene).intersection(set(self.scenes)), (
+ "The possible scenes are 'rural' and/or 'urban'"
+ )
assert len(scene) <= 2, "There are no other scenes than 'rural' or 'urban'"
self.root = root
diff --git a/torchgeo/datasets/mdas.py b/torchgeo/datasets/mdas.py
new file mode 100644
index 00000000000..25a61a72396
--- /dev/null
+++ b/torchgeo/datasets/mdas.py
@@ -0,0 +1,379 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""MDAS dataset."""
+
+import os
+from collections.abc import Callable
+from typing import Any, ClassVar
+
+import matplotlib.cm as cm
+import matplotlib.pyplot as plt
+import numpy as np
+import rasterio as rio
+import torch
+from matplotlib.colors import ListedColormap
+from matplotlib.figure import Figure
+from torch import Tensor
+
+from .errors import DatasetNotFoundError
+from .geo import NonGeoDataset
+from .utils import Path, download_and_extract_archive, extract_archive
+
+
+class MDAS(NonGeoDataset):
+ """MDAS dataset.
+
+ The `MDAS `__ multimodal dataset
+ is a comprehensive dataset for the city of Augsburg, Germany, collected on 7th May 2018.
+ It includes SAR, multispectral, hyperspectral, DSM, and GIS data,
+ providing comprehensive options for data fusion research.
+ MDAS supports applications like resolution enhancement, spectral unmixing, and land cover classification.
+
+ Dataset features:
+
+ * 3K DSM data
+ * 3K high resolution RGB images
+ * Original very high resolution HySpex airborne imagery
+ * EeteS simulated imagery with 10m GSD and EnMAP spectral bands
+ * EeteS simulated imagery with 30m GSD and EnMAP spectral bands
+ * EeteS simulated imagery with 10m GSD and Sentinel-2 spectral bands
+ * Sentinel-2 L2A product
+ * Sentinel-1 GRD product
+ * Open Street Map (OSM) labels, see `this table `__ for
+ a table of the label distribution
+
+ Dataset format:
+
+ * 3K_RGB.tif (Shape: (4, 15000, 18000)px, Data Type: uint8)
+ * 3K_dsm.tif (Shape: (1, 10000, 12000)px, Data Type: float32)
+ * HySpex.tif (Shape: (368, 1364, 1636)px, Data Type: int16)
+ * EeteS_EnMAP_2dot2m.tif (Shape: (242, 1364, 1636)px, Data Type: float32)
+ * EeteS_EnMAP_10m.tif (Shape: (242, 300, 360)px, Data Type: uint16)
+ * EeteS_EnMAP_30m.tif (Shape: (242, 100, 120)px, Data Type: uint16)
+ * EeteS_Sentinel_2_10m.tif (Shape: (4, 300, 360)px, Data Type: uint16)
+ * Sentinel_2.tif (Shape: (12, 300, 360)px, Data Type: uint16)
+ * Sentinel_1.tif (Shape: (2, 300, 360)px, Data Type: float32)
+ * osm_buildings.tif (Shape: (1, 1364, 1636)px, Data Type: uint8)
+ * osm_landuse.tif (Shape: (1, 1364, 1636)px, Data Type: float64)
+ * osm_water.tif (Shape: (1, 1364, 1636)px, Data Type: float64)
+
+ If you use this dataset in your research, please cite the following paper:
+
+ * https://essd.copernicus.org/articles/15/113/2023/
+
+ .. versionadded:: 0.7
+ """
+
+ valid_modalities = (
+ '3K_DSM',
+ '3K_RGB',
+ 'HySpex',
+ 'EeteS_EnMAP_10m',
+ 'EeteS_EnMAP_30m',
+ 'EeteS_Sentinel_2_10m',
+ 'Sentinel_2',
+ 'Sentinel_1',
+ 'osm_buildings',
+ 'osm_landuse',
+ 'osm_water',
+ )
+ landuse_class_names: ClassVar[dict[int, str]] = {
+ 0: 'no label',
+ 1: 'forest',
+ 2: 'park',
+ 3: 'residential',
+ 4: 'industrial',
+ 5: 'farm',
+ 6: 'cemetery',
+ 7: 'allotments',
+ 8: 'meadow',
+ 9: 'commercial',
+ 10: 'nature reserve',
+ 11: 'recreation ground',
+ 12: 'retail',
+ 13: 'military',
+ 14: 'quarry',
+ 15: 'orchard',
+ 16: 'scrub',
+ 17: 'grass',
+ 18: 'heath',
+ }
+
+ # https://github.com/zhu-xlab/augsburg_Multimodal_Data_Set_MDaS/blob/75c015022b5f688dfc44744f19bcf34bdce786c7/Augsburg_data_4_publication/entire_city/OSM_label/README#L14
+ landuse_mapping: ClassVar[dict[int, int]] = {
+ -2147483647: 0,
+ 7201: 1,
+ 7202: 2,
+ 7203: 3,
+ 7204: 4,
+ 7205: 5,
+ 7206: 6,
+ 7207: 7,
+ 7208: 8,
+ 7209: 9,
+ 7210: 10,
+ 7211: 11,
+ 7212: 12,
+ 7213: 13,
+ 7214: 14,
+ 7215: 15,
+ 7217: 16,
+ 7218: 17,
+ 7219: 18,
+ }
+
+ ds_root_name = 'Augsburg_data_4_publication'
+
+ zipfilename = f'{ds_root_name}.zip'
+
+ valid_subareas = ('sub_area_1', 'sub_area_2', 'sub_area_3')
+
+ url = 'https://huggingface.co/datasets/torchgeo/mdas/resolve/860226b74269f1cf1bed8ea3c03f571ae701144c/Augsburg_data_4_publication.zip'
+
+ md5 = '7b63c26e3717cb52c6ba47d215f18d5b'
+
+ enmap_rgb_band_idx: ClassVar[list[int]] = [43, 28, 10]
+ sentinel_2_rgb_band_idx: ClassVar[list[int]] = [3, 2, 1]
+ hyspex_rgb_band_idx: ClassVar[list[int]] = [100, 50, 10]
+
+ def __init__(
+ self,
+ root: Path = 'data',
+ subareas: list[str] = ['sub_area_1'],
+ modalities: list[str] = ['3K_RGB', 'HySpex', 'Sentinel_2'],
+ transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+ download: bool = False,
+ checksum: bool = False,
+ ) -> None:
+ """Initialize a new MDAS dataset instance.
+
+ Args:
+ root: Root directory where the dataset should be stored.
+ subareas: The subareas to load. Options are 'sub_area_1', 'sub_area_2', 'sub_area_3'.
+ modalities: The modalities to load. Options are '3K_DSM', '3K_RGB', 'HySpex', 'EeteS_EnMAP_10m', 'EeteS_EnMAP_30m', 'EeteS_Sentinel_2_10m', 'Sentinel-2', 'Sentinel-1', 'OSM_label'.
+ transforms: A function/transform that takes in a dictionary and returns a transformed version.
+ download: if True, download dataset and store it in the root directory
+ checksum: If True, check the integrity of the dataset after download.
+
+ Raises:
+ AssertionError: If the subareas or modalities are not valid.
+ DatasetNotFoundError: If dataset is not found and *download* is False.
+ """
+ self.root = root
+ self.download = download
+ assert all(sub in self.valid_subareas for sub in subareas), (
+ f'Subareas must be one of {self.valid_subareas}'
+ )
+ self.subareas = subareas
+ assert all(mod in self.valid_modalities for mod in modalities), (
+ f'Modalities must be one of {self.valid_modalities}'
+ )
+ self.modalities = modalities
+ self.transforms = transforms
+ self.checksum = checksum
+
+ self._verify()
+ self.files = self._load_files()
+
+ def __len__(self) -> int:
+ """Return the number of samples in the dataset.
+
+ Returns:
+ the length of the dataset
+ """
+ return len(self.files)
+
+ def _load_files(self) -> list[dict[str, str]]:
+ """Return the paths of the files in the dataset.
+
+ Returns:
+ a list of dictionaries containing the paths of the files in the dataset
+ """
+ files = []
+ for subarea in self.subareas:
+ subarea_files = {}
+ for modality in self.modalities:
+ subarea_files[modality] = os.path.join(
+ self.root,
+ self.ds_root_name,
+ subarea,
+ f'{modality}_{self._format_subarea(subarea)}.tif',
+ )
+ files.append(subarea_files)
+ return files
+
+ def _format_subarea(self, subarea: str) -> str:
+ """Format the subarea name.
+
+ Args:
+ subarea: The subarea string to format.
+
+ Returns:
+ formatted subarea string for files
+ """
+ parts = subarea.split('_')
+ return parts[0] + '_' + parts[1] + parts[2]
+
+ def _load_image(self, path: Path) -> Tensor:
+ """Load an image from a given path.
+
+ Args:
+ path: The path to the image file
+
+ Returns:
+ the loaded image as a tensor
+ """
+ with rio.open(path) as src:
+ img = src.read()
+ if img.dtype == np.uint16:
+ img = img.astype(np.int32)
+ if 'osm_landuse' in str(path):
+ img = np.vectorize(self.landuse_mapping.get)(img)
+
+ return torch.from_numpy(img)
+
+ def __getitem__(self, idx: int) -> dict[str, Tensor]:
+ """Return the dataset sample at the given index.
+
+ Args:
+ idx: The index of the sample to return
+
+ Returns:
+ a dictionary containing the data of chosen modalities
+ """
+ sample_files = self.files[idx]
+ sample: dict[str, Any] = {}
+ for modality, path in sample_files.items():
+ if 'osm' in modality:
+ sample[f'{modality}_mask'] = self._load_image(path).long()
+ else:
+ sample[f'{modality}_image'] = self._load_image(path)
+
+ if self.transforms:
+ sample = self.transforms(sample)
+
+ return sample
+
+ def _verify(self) -> None:
+ """Verify the integrity of the dataset."""
+ # check if each desired modality file exists in specified subarea
+ exists = []
+ for subarea in self.subareas:
+ for modality in self.modalities:
+ path = os.path.join(
+ self.root,
+ self.ds_root_name,
+ subarea,
+ f'{modality}_{self._format_subarea(subarea)}.tif',
+ )
+ if not os.path.exists(path):
+ exists.append(False)
+ else:
+ exists.append(True)
+ if all(exists):
+ return
+
+ # check if zip file downloaded
+ if os.path.exists(os.path.join(self.root, self.zipfilename)):
+ self._extract()
+ return
+
+ if not self.download:
+ raise DatasetNotFoundError(self)
+
+ self._download()
+
+ def _extract(self) -> None:
+ """Extract the dataset."""
+ extract_archive(os.path.join(self.root, self.zipfilename), self.root)
+
+ def _download(self) -> None:
+ """Download the dataset."""
+ download_and_extract_archive(
+ self.url,
+ self.root,
+ filename=self.zipfilename,
+ md5=self.md5 if self.checksum else None,
+ )
+
+ def plot(
+ self,
+ sample: dict[str, Tensor],
+ show_titles: bool = True,
+ suptitle: str | None = None,
+ ) -> Figure:
+ """Plot a sample from the dataset.
+
+ Args:
+ sample: A sample returned by `__getitem__`.
+ show_titles: Whether to display titles on the subplots.
+ suptitle: An optional super title for the plot.
+
+ Returns:
+ a matplotlib Figure with the rendered sample
+ """
+ ncols = len(sample)
+ fig, axs = plt.subplots(1, ncols, figsize=(5 * ncols, 5))
+
+ if ncols == 1:
+ axs = [axs]
+
+ for idx, (key, data) in enumerate(sample.items()):
+ match key:
+ case '3K_RGB_image':
+ img = data[:3].numpy().transpose(1, 2, 0) / 255.0
+ axs[idx].imshow(img)
+ case '3K_DSM_image':
+ img = data.numpy().squeeze(0)
+ axs[idx].imshow(img, cmap='gray')
+ case 'EeteS_EnMAP_10m_image' | 'EeteS_EnMAP_30m_image':
+ img = (
+ data[self.enmap_rgb_band_idx].numpy().transpose(1, 2, 0)
+ / 10000.0
+ )
+ axs[idx].imshow(img)
+ case 'EeteS_Sentinel_2_10m_image':
+ img = (
+ data[self.sentinel_2_rgb_band_idx].numpy().transpose(1, 2, 0)
+ / 10000.0
+ )
+ axs[idx].imshow(img)
+ case 'Sentinel_1_image':
+ img = data[0].numpy().clip(0, 1)
+ axs[idx].imshow(img)
+ case 'Sentinel_2_image':
+ img = (
+ data[self.sentinel_2_rgb_band_idx].numpy().transpose(1, 2, 0)
+ / 10000.0
+ )
+ axs[idx].imshow(img)
+ case 'HySpex_image':
+ img = (
+ data[self.hyspex_rgb_band_idx].numpy().transpose(1, 2, 0)
+ / 15000.0
+ )
+ axs[idx].imshow(img)
+ case 'osm_landuse_mask':
+ img = data.numpy().squeeze(0)
+ cmap = ListedColormap([cm.get_cmap('tab20')(i) for i in range(20)])
+ im = axs[idx].imshow(img, cmap=cmap)
+ cbar = plt.colorbar(im, ax=axs[idx], ticks=range(19))
+ cbar.ax.set_yticklabels(
+ [self.landuse_class_names[i] for i in range(19)]
+ )
+ case 'osm_buildings_mask':
+ img = data.numpy().squeeze(0)
+ axs[idx].imshow(img, cmap='gray')
+ case 'osm_water_mask':
+ img = data.numpy().squeeze(0)
+ axs[idx].imshow(img, cmap='Blues')
+
+ axs[idx].axis('off')
+ if show_titles:
+ axs[idx].set_title(key)
+
+ if suptitle:
+ plt.suptitle(suptitle)
+
+ return fig
diff --git a/torchgeo/datasets/mmearth.py b/torchgeo/datasets/mmearth.py
index f363276c40a..b940537d8b4 100644
--- a/torchgeo/datasets/mmearth.py
+++ b/torchgeo/datasets/mmearth.py
@@ -206,12 +206,12 @@ def __init__(
"""
lazy_import('h5py')
- assert (
- normalization_mode in self.norm_modes
- ), f'Invalid normalization mode: {normalization_mode}, please choose from {self.norm_modes}'
- assert (
- subset in self.subsets
- ), f'Invalid dataset version: {subset}, please choose from {self.subsets}'
+ assert normalization_mode in self.norm_modes, (
+ f'Invalid normalization mode: {normalization_mode}, please choose from {self.norm_modes}'
+ )
+ assert subset in self.subsets, (
+ f'Invalid dataset version: {subset}, please choose from {self.subsets}'
+ )
self._validate_modalities(modalities)
self.modalities = modalities
diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py
index 501fd6db8f9..681f0e242bc 100644
--- a/torchgeo/datasets/nlcd.py
+++ b/torchgeo/datasets/nlcd.py
@@ -167,9 +167,9 @@ def __init__(
'NLCD data product only exists for the following years: '
f'{list(self.md5s.keys())}.'
)
- assert (
- set(classes) <= self.cmap.keys()
- ), f'Only the following classes are valid: {list(self.cmap.keys())}.'
+ assert set(classes) <= self.cmap.keys(), (
+ f'Only the following classes are valid: {list(self.cmap.keys())}.'
+ )
assert 0 in classes, 'Classes must include the background class: 0'
self.paths = paths
diff --git a/torchgeo/datasets/seasonet.py b/torchgeo/datasets/seasonet.py
index 3e47a8ec491..ebbb9036374 100644
--- a/torchgeo/datasets/seasonet.py
+++ b/torchgeo/datasets/seasonet.py
@@ -450,7 +450,7 @@ def plot(
axs[ax].imshow(image)
axs[ax].axis('off')
if show_titles:
- axs[ax].set_title(f'Image {ax+1}')
+ axs[ax].set_title(f'Image {ax + 1}')
axs[ax + 1].imshow(mask, vmin=0, vmax=32, cmap=plt_cmap, interpolation='none')
axs[ax + 1].axis('off')
diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py
index 3accf32d2af..d5a68f48488 100644
--- a/torchgeo/datasets/skippd.py
+++ b/torchgeo/datasets/skippd.py
@@ -104,14 +104,14 @@ def __init__(
"""
lazy_import('h5py')
- assert (
- split in self.valid_splits
- ), f'Please choose one of these valid data splits {self.valid_splits}.'
+ assert split in self.valid_splits, (
+ f'Please choose one of these valid data splits {self.valid_splits}.'
+ )
self.split = split
- assert (
- task in self.valid_tasks
- ), f'Please choose one of these valid tasks {self.valid_tasks}.'
+ assert task in self.valid_tasks, (
+ f'Please choose one of these valid tasks {self.valid_tasks}.'
+ )
self.task = task
self.root = root
diff --git a/torchgeo/datasets/south_africa_crop_type.py b/torchgeo/datasets/south_africa_crop_type.py
index a8643873c5b..841cd7173de 100644
--- a/torchgeo/datasets/south_africa_crop_type.py
+++ b/torchgeo/datasets/south_africa_crop_type.py
@@ -131,9 +131,9 @@ def __init__(
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
- assert (
- set(classes) <= self.cmap.keys()
- ), f'Only the following classes are valid: {list(self.cmap.keys())}.'
+ assert set(classes) <= self.cmap.keys(), (
+ f'Only the following classes are valid: {list(self.cmap.keys())}.'
+ )
assert 0 in classes, 'Classes must include the background class: 0'
self.paths = paths
diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py
index 13c5a8474c4..111fe487e09 100644
--- a/torchgeo/datasets/ssl4eo_benchmark.py
+++ b/torchgeo/datasets/ssl4eo_benchmark.py
@@ -138,26 +138,26 @@ def __init__(
AssertionError: if any arguments are invalid
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
- assert (
- sensor in self.valid_sensors
- ), f'Only supports one of {self.valid_sensors}, but found {sensor}.'
+ assert sensor in self.valid_sensors, (
+ f'Only supports one of {self.valid_sensors}, but found {sensor}.'
+ )
self.sensor = sensor
- assert (
- product in self.valid_products
- ), f'Only supports one of {self.valid_products}, but found {product}.'
+ assert product in self.valid_products, (
+ f'Only supports one of {self.valid_products}, but found {product}.'
+ )
self.product = product
- assert (
- split in self.valid_splits
- ), f'Only supports one of {self.valid_splits}, but found {split}.'
+ assert split in self.valid_splits, (
+ f'Only supports one of {self.valid_splits}, but found {split}.'
+ )
self.split = split
self.cmap = self.cmaps[product]
if classes is None:
classes = list(self.cmap.keys())
- assert (
- set(classes) <= self.cmap.keys()
- ), f'Only the following classes are valid: {list(self.cmap.keys())}.'
+ assert set(classes) <= self.cmap.keys(), (
+ f'Only the following classes are valid: {list(self.cmap.keys())}.'
+ )
assert 0 in classes, 'Classes must include the background class: 0'
self.root = root
diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py
index eec9be57ab3..4d3a0b4de9f 100644
--- a/torchgeo/datasets/sustainbench_crop_yield.py
+++ b/torchgeo/datasets/sustainbench_crop_yield.py
@@ -82,14 +82,14 @@ def __init__(
is invalid
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
- assert set(countries).issubset(
- self.valid_countries
- ), f'Please choose a subset of these valid countried: {self.valid_countries}.'
+ assert set(countries).issubset(self.valid_countries), (
+ f'Please choose a subset of these valid countried: {self.valid_countries}.'
+ )
self.countries = countries
- assert (
- split in self.valid_splits
- ), f'Pleas choose one of these valid data splits {self.valid_splits}.'
+ assert split in self.valid_splits, (
+ f'Pleas choose one of these valid data splits {self.valid_splits}.'
+ )
self.split = split
self.root = root
diff --git a/torchgeo/models/croma.py b/torchgeo/models/croma.py
index 475c32fd3a9..de57a835936 100644
--- a/torchgeo/models/croma.py
+++ b/torchgeo/models/croma.py
@@ -56,9 +56,9 @@ def __init__(
"""
super().__init__()
for modality in modalities:
- assert (
- modality in self.valid_modalities
- ), f'{modality} is not a valid modality'
+ assert modality in self.valid_modalities, (
+ f'{modality} is not a valid modality'
+ )
assert image_size % 8 == 0, 'image_size must be a multiple of 8'
assert num_heads % 2 == 0, 'num_heads must be a power of 2'