From 402b712eac25f5583c83b6d87485365965ddb297 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 27 Aug 2024 12:37:37 +0200 Subject: [PATCH 1/2] DOFA: fix bug in patch embedding --- torchgeo/models/dofa.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchgeo/models/dofa.py b/torchgeo/models/dofa.py index e82fbd574bd..e362a85ada5 100644 --- a/torchgeo/models/dofa.py +++ b/torchgeo/models/dofa.py @@ -203,8 +203,10 @@ def forward(self, x: Tensor, wavelengths: Tensor) -> tuple[Tensor, Tensor]: weight, bias = self.weight_generator(waves) # 3x3x3 dynamic_weight = weight.view( - self.embed_dim, inplanes, self.kernel_size, self.kernel_size - ) # 3xoutdx16x16 + inplanes, self.kernel_size, self.kernel_size, self.embed_dim + ) + dynamic_weight = dynamic_weight.permute([3, 0, 1, 2]) + if bias is not None: bias = bias.view([self.embed_dim]) * self.scaler From 763870d24ada90c5135e2d6832c3ec65977b678b Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sun, 1 Sep 2024 11:25:49 +0200 Subject: [PATCH 2/2] Update checkpoints --- docs/api/weights/agnostic.csv | 4 ++-- torchgeo/models/dofa.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/api/weights/agnostic.csv b/docs/api/weights/agnostic.csv index 5937da65018..e06f47e2dc8 100644 --- a/docs/api/weights/agnostic.csv +++ b/docs/api/weights/agnostic.csv @@ -1,5 +1,5 @@ Weight,Source,Citation,License,Spatial,Temporal,Spectral,m-bigearthnet,m-forestnet,m-brick-kiln,m-pv4ger,m-so2sat,m-eurosat,m-pv4ger-seg,m-nz-cattle,m-NeonTree,m-cashew-plant,m-SA-crop,m-chesapeake -DOFABase16_Weights.DOFA_MAE,`link `__,`link `__,CC-BY-4.0,implicit,-,explicit,63.8,45.3,94.7,96.9,52.1,92.2,94.7,81.6,58.6,48.3,31.3,65.4 -DOFALarge16_Weights.DOFA_MAE,`link `__,`link `__,CC-BY-4.0,implicit,-,explicit,64.4,47.4,95.1,97.3,59.3,93.8,95.0,81.7,59.1,53.8,32.1,66.3 +DOFABase16_Weights.DOFA_MAE,`link `__,`link `__,CC-BY-4.0,implicit,-,explicit,65.7,50.9,95.8,96.9,55.1,93.9,94.5,81.4,58.8,51.5,33.0,65.3 +DOFALarge16_Weights.DOFA_MAE,`link `__,`link `__,CC-BY-4.0,implicit,-,explicit,67.5,54.6,96.9,97.3,60.1,97.1,95.0,81.8,59.4,56.9,32.1,66.3 ResNet50_Weights.FMOW_RGB_GASSL,`link `__,`link `__,-,implicit,-,-,,,,,,,,,,,, ScaleMAE_ViTLarge16_Weights.FMOW_RGB_SCALEMAE,`link `__,`link `__,CC-BY-NC-4.0,explicit,-,-,,,,,,,,,,, diff --git a/torchgeo/models/dofa.py b/torchgeo/models/dofa.py index e362a85ada5..cdbd8242873 100644 --- a/torchgeo/models/dofa.py +++ b/torchgeo/models/dofa.py @@ -386,7 +386,7 @@ class DOFABase16_Weights(WeightsEnum): # type: ignore[misc] """ DOFA_MAE = Weights( - url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_base_patch16_224-7cc0f413.pth', + url='https://hf.co/torchgeo/dofa/resolve/b8db318b64a90b9e085ec04ba8851233c5893666/dofa_base_patch16_224-a0275954.pth', transforms=_dofa_transforms, meta={ 'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k', @@ -405,7 +405,7 @@ class DOFALarge16_Weights(WeightsEnum): # type: ignore[misc] """ DOFA_MAE = Weights( - url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_large_patch16_224-fbd47fa9.pth', + url='https://hf.co/torchgeo/dofa/resolve/b8db318b64a90b9e085ec04ba8851233c5893666/dofa_large_patch16_224-0ff904d3.pth', transforms=_dofa_transforms, meta={ 'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k',