Showing 10 changed files with 1,998 additions and 0 deletions.
@@ -0,0 +1,2 @@
__pycache__/
*.pyc
@@ -0,0 +1,114 @@
import os
import os.path
import numpy as np
import torch.utils.data as data
import h5py
import dataloaders.transforms as transforms


def h5_loader(path):
    h5f = h5py.File(path, "r")
    rgb = np.array(h5f['rgb'])
    rgb = np.transpose(rgb, (1, 2, 0))
    depth = np.array(h5f['depth'])
    return rgb, depth


# def rgb2grayscale(rgb):
#     return rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114


class MyDataloader(data.Dataset):
    modality_names = ['rgb']

    def is_image_file(self, filename):
        IMG_EXTENSIONS = ['.h5']
        return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

    def find_classes(self, dir):
        classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

    def make_dataset(self, dir, class_to_idx):
        images = []
        dir = os.path.expanduser(dir)
        for target in sorted(os.listdir(dir)):
            d = os.path.join(dir, target)
            if not os.path.isdir(d):
                continue
            for root, _, fnames in sorted(os.walk(d)):
                for fname in sorted(fnames):
                    if self.is_image_file(fname):
                        path = os.path.join(root, fname)
                        item = (path, class_to_idx[target])
                        images.append(item)
        return images

    color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)

    def __init__(self, root, split, modality='rgb', loader=h5_loader):
        classes, class_to_idx = self.find_classes(root)
        imgs = self.make_dataset(root, class_to_idx)
        assert len(imgs) > 0, "Found 0 images in subfolders of: " + root + "\n"
        # print("Found {} images in {} folder.".format(len(imgs), split))
        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        if split == 'train':
            self.transform = self.train_transform
        elif split == 'holdout':
            self.transform = self.val_transform
        elif split == 'val':
            self.transform = self.val_transform
        else:
raise (RuntimeError("Invalid dataset split: " + split + "\n" | ||
"Supported dataset splits are: train, val")) | ||
self.loader = loader | ||
|
||
assert (modality in self.modality_names), "Invalid modality split: " + modality + "\n" + \ | ||
"Supported dataset splits are: " + ''.join(self.modality_names) | ||
self.modality = modality | ||
|
||
def train_transform(self, rgb, depth): | ||
raise (RuntimeError("train_transform() is not implemented. ")) | ||
|
||
def val_transform(rgb, depth): | ||
raise (RuntimeError("val_transform() is not implemented.")) | ||
|
||
    def __getraw__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (rgb, depth) the raw data.
        """
        path, target = self.imgs[index]
        rgb, depth = self.loader(path)
        return rgb, depth

    def __getitem__(self, index):
        rgb, depth = self.__getraw__(index)
        if self.transform is not None:
            rgb_np, depth_np = self.transform(rgb, depth)
        else:
            raise (RuntimeError("transform not defined"))

        # color normalization
        # rgb_tensor = normalize_rgb(rgb_tensor)
        # rgb_np = normalize_np(rgb_np)

        if self.modality == 'rgb':
            input_np = rgb_np

        to_tensor = transforms.ToTensor()
        input_tensor = to_tensor(input_np)
        while input_tensor.dim() < 3:
            input_tensor = input_tensor.unsqueeze(0)
        depth_tensor = to_tensor(depth_np)
        depth_tensor = depth_tensor.unsqueeze(0)

        return input_tensor, depth_tensor

    def __len__(self):
        return len(self.imgs)
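A minimal sketch (not part of this commit) of reading one sample directly with h5_loader; the .h5 path is hypothetical, and the shapes assume the 480x640 NYU Depth v2 export that the 'rgb' and 'depth' datasets above suggest:

from dataloaders.dataloader import h5_loader  # module path implied by the import in the NYU dataset file below

rgb, depth = h5_loader('data/nyudepthv2/train/bedroom_0001/00001.h5')  # hypothetical path
print(rgb.shape)    # expected (480, 640, 3) after the (1, 2, 0) transpose
print(depth.shape)  # expected (480, 640)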
@@ -0,0 +1,59 @@
import numpy as np
import dataloaders.transforms as transforms
from dataloaders.dataloader import MyDataloader

iheight, iwidth = 480, 640  # raw image size


class NYUDataset(MyDataloader):
    def __init__(self, root, split, modality='rgb'):
        self.split = split
        super(NYUDataset, self).__init__(root, split, modality)
        self.output_size = (224, 224)

    def is_image_file(self, filename):
        # IMG_EXTENSIONS = ['.h5']
        if self.split == 'train':
            return (filename.endswith('.h5') and
                    '00001.h5' not in filename and '00201.h5' not in filename)
        elif self.split == 'holdout':
            return ('00001.h5' in filename or '00201.h5' in filename)
        elif self.split == 'val':
            return filename.endswith('.h5')
        else:
raise (RuntimeError("Invalid dataset split: " + split + "\n" | ||
"Supported dataset splits are: train, val")) | ||
|
||
    def train_transform(self, rgb, depth):
        s = np.random.uniform(1.0, 1.5)  # random scaling
        depth_np = depth / s
        angle = np.random.uniform(-5.0, 5.0)  # random rotation degrees
        do_flip = np.random.uniform(0.0, 1.0) < 0.5  # random horizontal flip

        # perform 1st step of data augmentation
        transform = transforms.Compose([
            transforms.Resize(250.0 / iheight),  # this is for computational efficiency, since rotation can be slow
            transforms.Rotate(angle),
            transforms.Resize(s),
            transforms.CenterCrop((228, 304)),
            transforms.HorizontalFlip(do_flip),
            transforms.Resize(self.output_size),
        ])
        rgb_np = transform(rgb)
        rgb_np = self.color_jitter(rgb_np)  # random color jittering
        rgb_np = np.asfarray(rgb_np, dtype='float') / 255
        depth_np = transform(depth_np)

        return rgb_np, depth_np

    def val_transform(self, rgb, depth):
        depth_np = depth
        transform = transforms.Compose([
            transforms.Resize(250.0 / iheight),
            transforms.CenterCrop((228, 304)),
            transforms.Resize(self.output_size),
        ])
        rgb_np = transform(rgb)
        rgb_np = np.asfarray(rgb_np, dtype='float') / 255
        depth_np = transform(depth_np)

        return rgb_np, depth_np
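A minimal usage sketch (not part of this commit), assuming this file is importable as dataloaders.nyu, that the .h5 files sit under a hypothetical ./data/nyudepthv2/train/<scene>/ layout of the kind make_dataset walks, and that transforms.ToTensor yields CHW tensors:

import torch.utils.data
from dataloaders.nyu import NYUDataset  # assumed module path for this file

train_set = NYUDataset('./data/nyudepthv2/train', split='train', modality='rgb')
train_loader = torch.utils.data.DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)

for input_tensor, depth_tensor in train_loader:
    # expected shapes with output_size=(224, 224): input (8, 3, 224, 224), depth (8, 1, 224, 224)
    break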