data.py

import fnmatch
import os
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset


class RandomScaleCrop(object):
    """
    Randomly crop a region at one of the given scales and resize it back to
    the original resolution, rescaling depth values to match.

    Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34.
    """

    def __init__(self, scale=[1.0, 1.2, 1.5]):
        self.scale = scale

    def __call__(self, img, label, depth, normal):
        height, width = img.shape[-2:]
        # Pick a random scale factor and the corresponding crop size.
        sc = self.scale[random.randint(0, len(self.scale) - 1)]
        h, w = int(height / sc), int(width / sc)
        # Sample a random top-left corner for the crop window.
        i = random.randint(0, height - h)
        j = random.randint(0, width - w)
        # Bilinear resampling for continuous signals (image, normals).
        img_ = F.interpolate(
            img[None, :, i : i + h, j : j + w],
            size=(height, width),
            mode="bilinear",
            align_corners=True,
        ).squeeze(0)
        # Nearest-neighbour resampling keeps the label map as valid class ids.
        label_ = (
            F.interpolate(
                label[None, None, i : i + h, j : j + w],
                size=(height, width),
                mode="nearest",
            )
            .squeeze(0)
            .squeeze(0)
        )
        depth_ = F.interpolate(
            depth[None, :, i : i + h, j : j + w], size=(height, width), mode="nearest"
        ).squeeze(0)
        normal_ = F.interpolate(
            normal[None, :, i : i + h, j : j + w],
            size=(height, width),
            mode="bilinear",
            align_corners=True,
        ).squeeze(0)
        # Cropping 1/sc of the frame and resizing back zooms in by sc, so
        # metric depth is divided by sc to stay consistent with the image.
        return img_, label_, depth_ / sc, normal_
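

def _demo_random_scale_crop():
    # A minimal sanity-check sketch, assuming NYUv2-style inputs (3xHxW image,
    # HxW label map, 1xHxW depth, 3xHxW normals); the 288x384 resolution and
    # 13-class label range are illustrative, not taken from the loader.
    img = torch.rand(3, 288, 384)
    label = torch.randint(0, 13, (288, 384)).float()
    depth = torch.rand(1, 288, 384)
    normal = torch.rand(3, 288, 384)
    img_, label_, depth_, normal_ = RandomScaleCrop()(img, label, depth, normal)
    # The transform preserves spatial shapes; only content and depth scale change.
    assert img_.shape == img.shape and label_.shape == label.shape
    assert depth_.shape == depth.shape and normal_.shape == normal.shape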


class RandomScaleCropCityScapes(object):
    """
    Variant of RandomScaleCrop for CityScapes, which has no surface normals.

    Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34.
    """

    def __init__(self, scale=[1.0, 1.2, 1.5]):
        self.scale = scale

    def __call__(self, img, label, depth):
        height, width = img.shape[-2:]
        sc = self.scale[random.randint(0, len(self.scale) - 1)]
        h, w = int(height / sc), int(width / sc)
        i = random.randint(0, height - h)
        j = random.randint(0, width - w)
        img_ = F.interpolate(img[None, :, i:i + h, j:j + w], size=(height, width), mode="bilinear", align_corners=True).squeeze(0)
        label_ = F.interpolate(label[None, None, i:i + h, j:j + w], size=(height, width), mode="nearest").squeeze(0).squeeze(0)
        depth_ = F.interpolate(depth[None, :, i:i + h, j:j + w], size=(height, width), mode="nearest").squeeze(0)
        # Depth is rescaled by the crop factor, as in RandomScaleCrop above.
        return img_, label_, depth_ / sc


class CityScapes(Dataset):
    """
    We could further improve performance with the NYUv2-style data
    augmentation defined in:
    [1] PAD-Net: Multi-Tasks Guided Prediction-and-Distillation Network for
        Simultaneous Depth Estimation and Scene Parsing
    [2] Pattern-Affinitive Propagation Across Depth, Surface Normal and
        Semantic Segmentation
    [3] MTI-Net: Multi-Scale Task Interaction Networks for Multi-Task Learning

    1. Random scale at a ratio selected from 1.0, 1.2, and 1.5.
    2. Random horizontal flip.

    Please note that neither the baselines nor MTAN applied data augmentation
    in their original papers.
    """

    def __init__(self, root, train=True, augmentation=False):
        self.train = train
        self.root = os.path.expanduser(root)
        self.augmentation = augmentation
        # Point at the train or val split of the pre-processed data
        # (use the expanded root so that paths like '~/data' work).
        if train:
            self.data_path = self.root + '/train'
        else:
            self.data_path = self.root + '/val'
        # One sample per pre-processed .npy image file.
        self.data_len = len(fnmatch.filter(os.listdir(self.data_path + '/image'), '*.npy'))

    def __getitem__(self, index):
        # Load one sample from the pre-processed npy files; image and depth are
        # stored HxWxC, so the channel axis is moved to the front (CxHxW).
        image = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/image/{:d}.npy'.format(index)), -1, 0))
        semantic = torch.from_numpy(np.load(self.data_path + '/label_7/{:d}.npy'.format(index)))
        depth = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/depth/{:d}.npy'.format(index)), -1, 0))
        # Apply data augmentation if required.
        if self.augmentation:
            image, semantic, depth = RandomScaleCropCityScapes()(image, semantic, depth)
            # Random horizontal flip: dim 2 is width for the CxHxW tensors,
            # dim 1 is width for the HxW semantic map.
            if torch.rand(1) < 0.5:
                image = torch.flip(image, dims=[2])
                semantic = torch.flip(semantic, dims=[1])
                depth = torch.flip(depth, dims=[2])
        return image.float(), semantic.float(), depth.float()

    def __len__(self):
        return self.data_len
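

# A minimal usage sketch, assuming the pre-processed layout this loader expects
# (root/train/image/{i}.npy, root/train/label_7/{i}.npy, root/train/depth/{i}.npy);
# the './cityscapes' path and batch size are illustrative.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    train_set = CityScapes(root="./cityscapes", train=True, augmentation=True)
    train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=2)
    for image, semantic, depth in train_loader:
        # image: [B, 3, H, W], semantic: [B, H, W], depth: [B, 1, H, W]
        print(image.shape, semantic.shape, depth.shape)
        break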