Skip to content

Commit

Permalink
add evaluation code
Browse files Browse the repository at this point in the history
  • Loading branch information
dwofk committed Mar 4, 2019
1 parent 3dea2a8 commit 83eeed5
Show file tree
Hide file tree
Showing 10 changed files with 1,998 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__pycache__/
*.pyc
114 changes: 114 additions & 0 deletions dataloaders/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import os
import os.path
import numpy as np
import torch.utils.data as data
import h5py
import dataloaders.transforms as transforms

def h5_loader(path):
    """Load an (rgb, depth) sample from an HDF5 file.

    Args:
        path (str): path to a .h5 file containing an 'rgb' dataset
            (stored channels-first, transposed to HWC here) and a
            'depth' dataset.

    Returns:
        tuple: (rgb, depth) numpy arrays; rgb is (H, W, 3).
    """
    # Context manager guarantees the file handle is closed even if a
    # read raises (the previous version leaked the open handle).
    with h5py.File(path, "r") as h5f:
        rgb = np.array(h5f['rgb'])
        depth = np.array(h5f['depth'])
    rgb = np.transpose(rgb, (1, 2, 0))
    return rgb, depth

# def rgb2grayscale(rgb):
# return rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114

class MyDataloader(data.Dataset):
    """Base dataset that walks a root directory of class subfolders,
    collects .h5 samples, and yields (input, depth) tensor pairs.

    Subclasses are expected to override train_transform() and
    val_transform() with real augmentations.
    """

    # Input modalities this loader understands.
    modality_names = ['rgb']

    def is_image_file(self, filename):
        """Return True if ``filename`` looks like a dataset sample (.h5)."""
        IMG_EXTENSIONS = ['.h5']
        return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

    def find_classes(self, dir):
        """Return (sorted class names, {class_name: index}) for subdirs of ``dir``."""
        classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

    def make_dataset(self, dir, class_to_idx):
        """Collect (path, class_index) pairs for every valid file under ``dir``."""
        images = []
        dir = os.path.expanduser(dir)
        for target in sorted(os.listdir(dir)):
            d = os.path.join(dir, target)
            if not os.path.isdir(d):
                continue
            for root, _, fnames in sorted(os.walk(d)):
                for fname in sorted(fnames):
                    if self.is_image_file(fname):
                        path = os.path.join(root, fname)
                        item = (path, class_to_idx[target])
                        images.append(item)
        return images

    # Shared color-jitter augmentation used by subclass train transforms.
    color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)

    def __init__(self, root, split, modality='rgb', loader=h5_loader):
        """
        Args:
            root (str): dataset root containing class subfolders of .h5 files.
            split (str): one of 'train', 'holdout', 'val'; selects the transform.
            modality (str): input modality; must be in ``modality_names``.
            loader (callable): maps a file path to a raw (rgb, depth) pair.

        Raises:
            RuntimeError: if ``split`` is unknown.
            AssertionError: if no images are found or ``modality`` is invalid.
        """
        classes, class_to_idx = self.find_classes(root)
        imgs = self.make_dataset(root, class_to_idx)
        assert len(imgs)>0, "Found 0 images in subfolders of: " + root + "\n"
        # print("Found {} images in {} folder.".format(len(imgs), split))
        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        if split == 'train':
            self.transform = self.train_transform
        elif split == 'holdout':
            # holdout shares the deterministic validation transform
            self.transform = self.val_transform
        elif split == 'val':
            self.transform = self.val_transform
        else:
            # Message now lists every split actually accepted above.
            raise (RuntimeError("Invalid dataset split: " + split + "\n"
                                "Supported dataset splits are: train, holdout, val"))
        self.loader = loader

        # ', '.join gives a readable list in the assertion message
        # (''.join ran the names together).
        assert (modality in self.modality_names), "Invalid modality split: " + modality + "\n" + \
                                "Supported dataset splits are: " + ', '.join(self.modality_names)
        self.modality = modality

    def train_transform(self, rgb, depth):
        """Augmentation hook for the training split; must be overridden."""
        raise (RuntimeError("train_transform() is not implemented. "))

    def val_transform(self, rgb, depth):
        # BUG FIX: original signature was val_transform(rgb, depth) without
        # ``self``; calling it as a bound method raised TypeError instead of
        # the intended RuntimeError.
        """Deterministic transform hook for val/holdout; must be overridden."""
        raise (RuntimeError("val_transform() is not implemented."))

    def __getraw__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (rgb, depth) the raw data.
        """
        path, target = self.imgs[index]
        rgb, depth = self.loader(path)
        return rgb, depth

    def __getitem__(self, index):
        """Return (input_tensor, depth_tensor) for the sample at ``index``."""
        rgb, depth = self.__getraw__(index)
        if self.transform is not None:
            rgb_np, depth_np = self.transform(rgb, depth)
        else:
            raise(RuntimeError("transform not defined"))

        # color normalization
        # rgb_tensor = normalize_rgb(rgb_tensor)
        # rgb_np = normalize_np(rgb_np)

        if self.modality == 'rgb':
            input_np = rgb_np

        to_tensor = transforms.ToTensor()
        input_tensor = to_tensor(input_np)
        # Ensure a channel dimension exists (e.g. for grayscale inputs).
        while input_tensor.dim() < 3:
            input_tensor = input_tensor.unsqueeze(0)
        depth_tensor = to_tensor(depth_np)
        depth_tensor = depth_tensor.unsqueeze(0)

        return input_tensor, depth_tensor

    def __len__(self):
        """Number of samples discovered under ``root``."""
        return len(self.imgs)
59 changes: 59 additions & 0 deletions dataloaders/nyu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import numpy as np
import dataloaders.transforms as transforms
from dataloaders.dataloader import MyDataloader

iheight, iwidth = 480, 640 # raw image size

class NYUDataset(MyDataloader):
    """NYU-Depth-v2 dataset: filters .h5 files per split and applies the
    NYU-specific train/val augmentation pipelines."""

    def __init__(self, root, split, modality='rgb'):
        """
        Args:
            root (str): dataset root directory.
            split (str): 'train', 'holdout', or 'val'.
            modality (str): input modality, forwarded to the base class.
        """
        # split must be set before super().__init__, which calls
        # make_dataset -> is_image_file (reads self.split).
        self.split = split
        super(NYUDataset, self).__init__(root, split, modality)
        self.output_size = (224, 224)

    def is_image_file(self, filename):
        """Split-aware file filter: two fixed files per scene are held out
        from training and form the 'holdout' split."""
        if self.split == 'train':
            return (filename.endswith('.h5') and \
                    '00001.h5' not in filename and '00201.h5' not in filename)
        elif self.split == 'holdout':
            return ('00001.h5' in filename or '00201.h5' in filename)
        elif self.split == 'val':
            return (filename.endswith('.h5'))
        else:
            # BUG FIX: original referenced bare ``split`` (NameError) instead
            # of self.split; message now lists all supported splits.
            raise (RuntimeError("Invalid dataset split: " + self.split + "\n"
                                "Supported dataset splits are: train, holdout, val"))

    def train_transform(self, rgb, depth):
        """Random scale / rotate / crop / flip augmentation; returns
        (rgb_np in [0,1] float, depth_np scaled by 1/s)."""
        s = np.random.uniform(1.0, 1.5)  # random scaling
        depth_np = depth / s  # scaling the image enlarges content, so depth shrinks
        angle = np.random.uniform(-5.0, 5.0)  # random rotation degrees
        do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip

        # perform 1st step of data augmentation
        transform = transforms.Compose([
            transforms.Resize(250.0 / iheight),  # this is for computational efficiency, since rotation can be slow
            transforms.Rotate(angle),
            transforms.Resize(s),
            transforms.CenterCrop((228, 304)),
            transforms.HorizontalFlip(do_flip),
            transforms.Resize(self.output_size),
        ])
        rgb_np = transform(rgb)
        rgb_np = self.color_jitter(rgb_np)  # random color jittering
        # np.asfarray was removed in NumPy 2.0; asarray with float dtype
        # is the equivalent supported spelling.
        rgb_np = np.asarray(rgb_np, dtype=float) / 255
        depth_np = transform(depth_np)

        return rgb_np, depth_np

    def val_transform(self, rgb, depth):
        """Deterministic resize + center-crop used for val/holdout."""
        depth_np = depth
        transform = transforms.Compose([
            transforms.Resize(250.0 / iheight),
            transforms.CenterCrop((228, 304)),
            transforms.Resize(self.output_size),
        ])
        rgb_np = transform(rgb)
        # See train_transform: asarray replaces the removed np.asfarray.
        rgb_np = np.asarray(rgb_np, dtype=float) / 255
        depth_np = transform(depth_np)

        return rgb_np, depth_np
Loading

0 comments on commit 83eeed5

Please sign in to comment.