Skip to content

Commit

Permalink
Fix for new datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Nov 5, 2024
1 parent 9c74986 commit e34e120
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 15 deletions.
11 changes: 6 additions & 5 deletions torchgeo/datamodules/caffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch

from ..datasets import CaFFe
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


Expand Down Expand Up @@ -40,16 +39,18 @@ def __init__(

self.size = size

self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.Resize(size),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
)

self.aug = AugmentationSequential(
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.Resize(size),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
)
10 changes: 5 additions & 5 deletions torchgeo/datamodules/ftw.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch

from ..datasets import FieldsOfTheWorld
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


Expand Down Expand Up @@ -55,16 +54,17 @@ def __init__(
self.val_countries = val_countries
self.test_countries = test_countries

self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.RandomRotation(p=0.5, degrees=90),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
K.RandomSharpness(p=0.5),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
)
self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask']
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True
)

def setup(self, stage: str) -> None:
Expand Down
10 changes: 6 additions & 4 deletions torchgeo/datamodules/geonrw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from torch.utils.data import Subset

from ..datasets import GeoNRW
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule
from .utils import group_shuffle_split

Expand Down Expand Up @@ -38,14 +37,17 @@ def __init__(
"""
super().__init__(GeoNRW, batch_size, num_workers, **kwargs)

self.train_aug = AugmentationSequential(
self.train_aug = K.AugmentationSequential(
K.Resize(size),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=['image', 'mask'],
data_keys=None,
keepdim=True,
)

self.aug = AugmentationSequential(K.Resize(size), data_keys=['image', 'mask'])
self.aug = K.AugmentationSequential(
K.Resize(size), data_keys=None, keepdim=True
)

self.size = size

Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/geonrw.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:

# rename to torchgeo standard keys
sample['image'] = sample.pop('rgb').float()
sample['mask'] = sample.pop('seg').long()
sample['mask'] = sample.pop('seg').long().squeeze(0)

if self.transforms:
sample = self.transforms(sample)
Expand Down

0 comments on commit e34e120

Please sign in to comment.