Skip to content

Commit

Permalink
Fix for segmentation tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Jul 2, 2024
1 parent f6edb6c commit c40d8bc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchgeo/datamodules/spacenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def __init__(
K.RandomVerticalFlip(p=0.5),
K.RandomSharpness(p=0.5),
K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
data_keys=['image', 'mask'],
data_keys=None,
)
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.PadTo((448, 448)),
data_keys=['image', 'mask'],
data_keys=None,
)

def setup(self, stage: str) -> None:
Expand Down
8 changes: 8 additions & 0 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import torch.nn as nn
from einops import rearrange
from matplotlib.figure import Figure
from torch import Tensor
from torchmetrics import MetricCollection
Expand Down Expand Up @@ -225,6 +226,9 @@ def training_step(
Returns:
The loss tensor.
"""
if 'mask' in batch and batch['mask'].shape[1] == 1:
batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w')

x = batch['image']
y = batch['mask']
batch_size = x.shape[0]
Expand All @@ -245,6 +249,8 @@ def validation_step(
batch_idx: Integer displaying index of this batch.
dataloader_idx: Index of the current dataloader.
"""
if 'mask' in batch and batch['mask'].shape[1] == 1:
batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w')
x = batch['image']
y = batch['mask']
batch_size = x.shape[0]
Expand Down Expand Up @@ -289,6 +295,8 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
batch_idx: Integer displaying index of this batch.
dataloader_idx: Index of the current dataloader.
"""
if 'mask' in batch and batch['mask'].shape[1] == 1:
batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w')
x = batch['image']
y = batch['mask']
batch_size = x.shape[0]
Expand Down

0 comments on commit c40d8bc

Please sign in to comment.