Skip to content

Commit c40d8bc

Browse files
committed
Fix for segmentation tests
1 parent f6edb6c commit c40d8bc

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

torchgeo/datamodules/spacenet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ def __init__(
5353
K.RandomVerticalFlip(p=0.5),
5454
K.RandomSharpness(p=0.5),
5555
K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
56-
data_keys=['image', 'mask'],
56+
data_keys=None,
5757
)
5858
self.aug = K.AugmentationSequential(
5959
K.Normalize(mean=self.mean, std=self.std),
6060
K.PadTo((448, 448)),
61-
data_keys=['image', 'mask'],
61+
data_keys=None,
6262
)
6363

6464
def setup(self, stage: str) -> None:

torchgeo/trainers/segmentation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import matplotlib.pyplot as plt
1010
import segmentation_models_pytorch as smp
1111
import torch.nn as nn
12+
from einops import rearrange
1213
from matplotlib.figure import Figure
1314
from torch import Tensor
1415
from torchmetrics import MetricCollection
@@ -225,6 +226,9 @@ def training_step(
225226
Returns:
226227
The loss tensor.
227228
"""
229+
if 'mask' in batch and batch['mask'].shape[1] == 1:
230+
batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w')
231+
228232
x = batch['image']
229233
y = batch['mask']
230234
batch_size = x.shape[0]
@@ -245,6 +249,8 @@ def validation_step(
245249
batch_idx: Integer displaying index of this batch.
246250
dataloader_idx: Index of the current dataloader.
247251
"""
252+
if 'mask' in batch and batch['mask'].shape[1] == 1:
253+
batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w')
248254
x = batch['image']
249255
y = batch['mask']
250256
batch_size = x.shape[0]
@@ -289,6 +295,8 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
289295
batch_idx: Integer displaying index of this batch.
290296
dataloader_idx: Index of the current dataloader.
291297
"""
298+
if 'mask' in batch and batch['mask'].shape[1] == 1:
299+
batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w')
292300
x = batch['image']
293301
y = batch['mask']
294302
batch_size = x.shape[0]

0 commit comments

Comments
 (0)