From 83eeed5f52912dc91221f25e690ce231abca2f4d Mon Sep 17 00:00:00 2001 From: Diana Wofk Date: Mon, 4 Mar 2019 01:14:51 -0500 Subject: [PATCH] add evaluation code --- .gitignore | 2 + dataloaders/dataloader.py | 114 ++++++ dataloaders/nyu.py | 59 +++ dataloaders/transforms.py | 622 +++++++++++++++++++++++++++++ imagenet/__init__.py | 0 imagenet/mobilenet.py | 79 ++++ main.py | 130 ++++++ metrics.py | 95 +++++ models.py | 814 ++++++++++++++++++++++++++++++++++++++ utils.py | 83 ++++ 10 files changed, 1998 insertions(+) create mode 100644 .gitignore create mode 100644 dataloaders/dataloader.py create mode 100644 dataloaders/nyu.py create mode 100644 dataloaders/transforms.py create mode 100644 imagenet/__init__.py create mode 100644 imagenet/mobilenet.py create mode 100644 main.py create mode 100644 metrics.py create mode 100644 models.py create mode 100644 utils.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7a60b85 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +*.pyc diff --git a/dataloaders/dataloader.py b/dataloaders/dataloader.py new file mode 100644 index 0000000..98f1b61 --- /dev/null +++ b/dataloaders/dataloader.py @@ -0,0 +1,114 @@ +import os +import os.path +import numpy as np +import torch.utils.data as data +import h5py +import dataloaders.transforms as transforms + +def h5_loader(path): + h5f = h5py.File(path, "r") + rgb = np.array(h5f['rgb']) + rgb = np.transpose(rgb, (1, 2, 0)) + depth = np.array(h5f['depth']) + return rgb, depth + +# def rgb2grayscale(rgb): +# return rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114 + +class MyDataloader(data.Dataset): + modality_names = ['rgb'] + + def is_image_file(self, filename): + IMG_EXTENSIONS = ['.h5'] + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + def find_classes(self, dir): + classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] + classes.sort() + class_to_idx = {classes[i]: i for i in range(len(classes))} + return classes, class_to_idx + + def make_dataset(self, dir, class_to_idx): + images = [] + dir = os.path.expanduser(dir) + for target in sorted(os.listdir(dir)): + d = os.path.join(dir, target) + if not os.path.isdir(d): + continue + for root, _, fnames in sorted(os.walk(d)): + for fname in sorted(fnames): + if self.is_image_file(fname): + path = os.path.join(root, fname) + item = (path, class_to_idx[target]) + images.append(item) + return images + + color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4) + + def __init__(self, root, split, modality='rgb', loader=h5_loader): + classes, class_to_idx = self.find_classes(root) + imgs = self.make_dataset(root, class_to_idx) + assert len(imgs)>0, "Found 0 images in subfolders of: " + root + "\n" + # print("Found {} images in {} folder.".format(len(imgs), split)) + self.root = root + self.imgs = imgs + self.classes = classes + self.class_to_idx = class_to_idx + if split == 'train': + self.transform = self.train_transform + elif split == 'holdout': + self.transform = self.val_transform + elif split == 'val': + self.transform = self.val_transform + else: + raise (RuntimeError("Invalid dataset split: " + split + "\n" + "Supported dataset splits are: train, val")) + self.loader = loader + + assert (modality in self.modality_names), "Invalid modality split: " + modality + "\n" + \ + "Supported dataset splits are: " + ''.join(self.modality_names) + self.modality = modality + + def train_transform(self, rgb, depth): + raise (RuntimeError("train_transform() is not implemented. 
")) + + def val_transform(rgb, depth): + raise (RuntimeError("val_transform() is not implemented.")) + + def __getraw__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (rgb, depth) the raw data. + """ + path, target = self.imgs[index] + rgb, depth = self.loader(path) + return rgb, depth + + def __getitem__(self, index): + rgb, depth = self.__getraw__(index) + if self.transform is not None: + rgb_np, depth_np = self.transform(rgb, depth) + else: + raise(RuntimeError("transform not defined")) + + # color normalization + # rgb_tensor = normalize_rgb(rgb_tensor) + # rgb_np = normalize_np(rgb_np) + + if self.modality == 'rgb': + input_np = rgb_np + + to_tensor = transforms.ToTensor() + input_tensor = to_tensor(input_np) + while input_tensor.dim() < 3: + input_tensor = input_tensor.unsqueeze(0) + depth_tensor = to_tensor(depth_np) + depth_tensor = depth_tensor.unsqueeze(0) + + return input_tensor, depth_tensor + + def __len__(self): + return len(self.imgs) diff --git a/dataloaders/nyu.py b/dataloaders/nyu.py new file mode 100644 index 0000000..04650b4 --- /dev/null +++ b/dataloaders/nyu.py @@ -0,0 +1,59 @@ +import numpy as np +import dataloaders.transforms as transforms +from dataloaders.dataloader import MyDataloader + +iheight, iwidth = 480, 640 # raw image size + +class NYUDataset(MyDataloader): + def __init__(self, root, split, modality='rgb'): + self.split = split + super(NYUDataset, self).__init__(root, split, modality) + self.output_size = (224, 224) + + def is_image_file(self, filename): + # IMG_EXTENSIONS = ['.h5'] + if self.split == 'train': + return (filename.endswith('.h5') and \ + '00001.h5' not in filename and '00201.h5' not in filename) + elif self.split == 'holdout': + return ('00001.h5' in filename or '00201.h5' in filename) + elif self.split == 'val': + return (filename.endswith('.h5')) + else: + raise (RuntimeError("Invalid dataset split: " + split + "\n" + "Supported dataset splits are: train, val")) + + def train_transform(self, rgb, depth): + s = np.random.uniform(1.0, 1.5) # random scaling + depth_np = depth / s + angle = np.random.uniform(-5.0, 5.0) # random rotation degrees + do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip + + # perform 1st step of data augmentation + transform = transforms.Compose([ + transforms.Resize(250.0 / iheight), # this is for computational efficiency, since rotation can be slow + transforms.Rotate(angle), + transforms.Resize(s), + transforms.CenterCrop((228, 304)), + transforms.HorizontalFlip(do_flip), + transforms.Resize(self.output_size), + ]) + rgb_np = transform(rgb) + rgb_np = self.color_jitter(rgb_np) # random color jittering + rgb_np = np.asfarray(rgb_np, dtype='float') / 255 + depth_np = transform(depth_np) + + return rgb_np, depth_np + + def val_transform(self, rgb, depth): + depth_np = depth + transform = transforms.Compose([ + transforms.Resize(250.0 / iheight), + transforms.CenterCrop((228, 304)), + transforms.Resize(self.output_size), + ]) + rgb_np = transform(rgb) + rgb_np = np.asfarray(rgb_np, dtype='float') / 255 + depth_np = transform(depth_np) + + return rgb_np, depth_np diff --git a/dataloaders/transforms.py b/dataloaders/transforms.py new file mode 100644 index 0000000..fe80f38 --- /dev/null +++ b/dataloaders/transforms.py @@ -0,0 +1,622 @@ +from __future__ import division +import torch +import math +import random + +from PIL import Image, ImageOps, ImageEnhance +try: + import accimage +except ImportError: + accimage = None + +import numpy as np +import numbers +import types +import 
collections +import warnings + +import scipy.ndimage.interpolation as itpl +import scipy.misc as misc + + +def _is_numpy_image(img): + return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) + +def _is_pil_image(img): + if accimage is not None: + return isinstance(img, (Image.Image, accimage.Image)) + else: + return isinstance(img, Image.Image) + +def _is_tensor_image(img): + return torch.is_tensor(img) and img.ndimension() == 3 + +def adjust_brightness(img, brightness_factor): + """Adjust brightness of an Image. + + Args: + img (PIL Image): PIL Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns: + PIL Image: Brightness adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness_factor) + return img + + +def adjust_contrast(img, contrast_factor): + """Adjust contrast of an Image. + + Args: + img (PIL Image): PIL Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. + + Returns: + PIL Image: Contrast adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast_factor) + return img + + +def adjust_saturation(img, saturation_factor): + """Adjust color saturation of an image. + + Args: + img (PIL Image): PIL Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + + Returns: + PIL Image: Saturation adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Color(img) + img = enhancer.enhance(saturation_factor) + return img + + +def adjust_hue(img, hue_factor): + """Adjust hue of an image. + + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + See https://en.wikipedia.org/wiki/Hue for more details on Hue. + + Args: + img (PIL Image): PIL Image to be adjusted. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + + Returns: + PIL Image: Hue adjusted image. + """ + if not(-0.5 <= hue_factor <= 0.5): + raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) + + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + + input_mode = img.mode + if input_mode in {'L', '1', 'I', 'F'}: + return img + + h, s, v = img.convert('HSV').split() + + np_h = np.array(h, dtype=np.uint8) + # uint8 addition take cares of rotation across boundaries + with np.errstate(over='ignore'): + np_h += np.uint8(hue_factor * 255) + h = Image.fromarray(np_h, 'L') + + img = Image.merge('HSV', (h, s, v)).convert(input_mode) + return img + + +def adjust_gamma(img, gamma, gain=1): + """Perform gamma correction on an image. + + Also known as Power Law Transform. Intensities in RGB mode are adjusted + based on the following equation: + + I_out = 255 * gain * ((I_in / 255) ** gamma) + + See https://en.wikipedia.org/wiki/Gamma_correction for more details. + + Args: + img (PIL Image): PIL Image to be adjusted. + gamma (float): Non negative real number. gamma larger than 1 make the + shadows darker, while gamma smaller than 1 make dark regions + lighter. + gain (float): The constant multiplier. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if gamma < 0: + raise ValueError('Gamma should be a non-negative real number') + + input_mode = img.mode + img = img.convert('RGB') + + np_img = np.array(img, dtype=np.float32) + np_img = 255 * gain * ((np_img / 255) ** gamma) + np_img = np.uint8(np.clip(np_img, 0, 255)) + + img = Image.fromarray(np_img, 'RGB').convert(input_mode) + return img + + +class Compose(object): + """Composes several transforms together. + + Args: + transforms (list of ``Transform`` objects): list of transforms to compose. + + Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img): + for t in self.transforms: + img = t(img) + return img + + +class ToTensor(object): + """Convert a ``numpy.ndarray`` to tensor. + + Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). + """ + + def __call__(self, img): + """Convert a ``numpy.ndarray`` to tensor. + + Args: + img (numpy.ndarray): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + if not(_is_numpy_image(img)): + raise TypeError('img should be ndarray. Got {}'.format(type(img))) + + if isinstance(img, np.ndarray): + # handle numpy array + if img.ndim == 3: + img = torch.from_numpy(img.transpose((2, 0, 1)).copy()) + elif img.ndim == 2: + img = torch.from_numpy(img.copy()) + else: + raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) + + # backward compatibility + # return img.float().div(255) + return img.float() + + +class NormalizeNumpyArray(object): + """Normalize a ``numpy.ndarray`` with mean and standard deviation. + Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform + will normalize each channel of the input ``numpy.ndarray`` i.e. + ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` + + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, img): + """ + Args: + img (numpy.ndarray): Image of size (H, W, C) to be normalized. + + Returns: + Tensor: Normalized image. + """ + if not(_is_numpy_image(img)): + raise TypeError('img should be ndarray. 
Got {}'.format(type(img)))
+        # TODO: make efficient
+        # print(img.shape)
+        for i in range(3):
+            img[:,:,i] = (img[:,:,i] - self.mean[i]) / self.std[i]
+        return img
+
+class NormalizeTensor(object):
+    """Normalize a tensor image with mean and standard deviation.
+    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
+    will normalize each channel of the input ``torch.*Tensor`` i.e.
+    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
+
+    Args:
+        mean (sequence): Sequence of means for each channel.
+        std (sequence): Sequence of standard deviations for each channel.
+    """
+
+    def __init__(self, mean, std):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, tensor):
+        """
+        Args:
+            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+
+        Returns:
+            Tensor: Normalized Tensor image.
+        """
+        if not _is_tensor_image(tensor):
+            raise TypeError('tensor is not a torch image.')
+        # TODO: make efficient
+        for t, m, s in zip(tensor, self.mean, self.std):
+            t.sub_(m).div_(s)
+        return tensor
+
+class Rotate(object):
+    """Rotates the given ``numpy.ndarray``.
+
+    Args:
+        angle (float): The rotation angle in degrees.
+    """
+
+    def __init__(self, angle):
+        self.angle = angle
+
+    def __call__(self, img):
+        """
+        Args:
+            img (numpy.ndarray (C x H x W)): Image to be rotated.
+
+        Returns:
+            img (numpy.ndarray (C x H x W)): Rotated image.
+        """
+
+        # order=0 means nearest-neighbor type interpolation
+        return itpl.rotate(img, self.angle, reshape=False, prefilter=False, order=0)
+
+
+class Resize(object):
+    """Resize the given ``numpy.ndarray`` to the given size.
+    Args:
+        size (sequence or int): Desired output size. If size is a sequence like
+            (h, w), output size will be matched to this. If size is an int,
+            smaller edge of the image will be matched to this number.
+            i.e, if height > width, then image will be rescaled to
+            (size * height / width, size)
+        interpolation (str, optional): Desired interpolation mode. Default is
+            ``'nearest'``.
+    """
+
+    def __init__(self, size, interpolation='nearest'):
+        assert isinstance(size, int) or isinstance(size, float) or \
+            (isinstance(size, collections.Iterable) and len(size) == 2)
+        self.size = size
+        self.interpolation = interpolation
+
+    def __call__(self, img):
+        """
+        Args:
+            img (numpy.ndarray): Image to be scaled.
+        Returns:
+            numpy.ndarray: Rescaled image.
+        """
+        if img.ndim == 3:
+            return misc.imresize(img, self.size, self.interpolation)
+        elif img.ndim == 2:
+            return misc.imresize(img, self.size, self.interpolation, 'F')
+        else:
+            raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
+
+
+class CenterCrop(object):
+    """Crops the given ``numpy.ndarray`` at the center.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
+    """
+
+    def __init__(self, size):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+
+    @staticmethod
+    def get_params(img, output_size):
+        """Get parameters for ``crop`` for center crop.
+
+        Args:
+            img (numpy.ndarray (C x H x W)): Image to be cropped.
+            output_size (tuple): Expected output size of the crop.
+
+        Returns:
+            tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
+ """ + h = img.shape[0] + w = img.shape[1] + th, tw = output_size + i = int(round((h - th) / 2.)) + j = int(round((w - tw) / 2.)) + + # # randomized cropping + # i = np.random.randint(i-3, i+4) + # j = np.random.randint(j-3, j+4) + + return i, j, th, tw + + def __call__(self, img): + """ + Args: + img (numpy.ndarray (C x H x W)): Image to be cropped. + + Returns: + img (numpy.ndarray (C x H x W)): Cropped image. + """ + i, j, h, w = self.get_params(img, self.size) + + """ + i: Upper pixel coordinate. + j: Left pixel coordinate. + h: Height of the cropped image. + w: Width of the cropped image. + """ + if not(_is_numpy_image(img)): + raise TypeError('img should be ndarray. Got {}'.format(type(img))) + if img.ndim == 3: + return img[i:i+h, j:j+w, :] + elif img.ndim == 2: + return img[i:i + h, j:j + w] + else: + raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) + +class BottomCrop(object): + """Crops the given ``numpy.ndarray`` at the bottom. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + @staticmethod + def get_params(img, output_size): + """Get parameters for ``crop`` for bottom crop. + + Args: + img (numpy.ndarray (C x H x W)): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for bottom crop. + """ + h = img.shape[0] + w = img.shape[1] + th, tw = output_size + i = h - th + j = int(round((w - tw) / 2.)) + + # randomized left and right cropping + # i = np.random.randint(i-3, i+4) + # j = np.random.randint(j-1, j+1) + + return i, j, th, tw + + def __call__(self, img): + """ + Args: + img (numpy.ndarray (C x H x W)): Image to be cropped. + + Returns: + img (numpy.ndarray (C x H x W)): Cropped image. + """ + i, j, h, w = self.get_params(img, self.size) + + """ + i: Upper pixel coordinate. + j: Left pixel coordinate. + h: Height of the cropped image. + w: Width of the cropped image. + """ + if not(_is_numpy_image(img)): + raise TypeError('img should be ndarray. Got {}'.format(type(img))) + if img.ndim == 3: + return img[i:i+h, j:j+w, :] + elif img.ndim == 2: + return img[i:i + h, j:j + w] + else: + raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) + +class Lambda(object): + """Apply a user-defined lambda as a transform. + + Args: + lambd (function): Lambda/function to be used for transform. + """ + + def __init__(self, lambd): + assert isinstance(lambd, types.LambdaType) + self.lambd = lambd + + def __call__(self, img): + return self.lambd(img) + + +class HorizontalFlip(object): + """Horizontally flip the given ``numpy.ndarray``. + + Args: + do_flip (boolean): whether or not do horizontal flip. + + """ + + def __init__(self, do_flip): + self.do_flip = do_flip + + def __call__(self, img): + """ + Args: + img (numpy.ndarray (C x H x W)): Image to be flipped. + + Returns: + img (numpy.ndarray (C x H x W)): flipped image. + """ + if not(_is_numpy_image(img)): + raise TypeError('img should be ndarray. Got {}'.format(type(img))) + + if self.do_flip: + return np.fliplr(img) + else: + return img + + +class ColorJitter(object): + """Randomly change the brightness, contrast and saturation of an image. + + Args: + brightness (float): How much to jitter brightness. 
brightness_factor + is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. + contrast (float): How much to jitter contrast. contrast_factor + is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. + saturation (float): How much to jitter saturation. saturation_factor + is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. + hue(float): How much to jitter hue. hue_factor is chosen uniformly from + [-hue, hue]. Should be >=0 and <= 0.5. + """ + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + + @staticmethod + def get_params(brightness, contrast, saturation, hue): + """Get a randomized transform to be applied on image. + + Arguments are same as that of __init__. + + Returns: + Transform which randomly adjusts brightness, contrast and + saturation in a random order. + """ + transforms = [] + if brightness > 0: + brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) + transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor))) + + if contrast > 0: + contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) + transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor))) + + if saturation > 0: + saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) + transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor))) + + if hue > 0: + hue_factor = np.random.uniform(-hue, hue) + transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor))) + + np.random.shuffle(transforms) + transform = Compose(transforms) + + return transform + + def __call__(self, img): + """ + Args: + img (numpy.ndarray (C x H x W)): Input image. + + Returns: + img (numpy.ndarray (C x H x W)): Color jittered image. + """ + if not(_is_numpy_image(img)): + raise TypeError('img should be ndarray. Got {}'.format(type(img))) + + pil = Image.fromarray(img) + transform = self.get_params(self.brightness, self.contrast, + self.saturation, self.hue) + return np.array(transform(pil)) + +class Crop(object): + """Crops the given PIL Image to a rectangular region based on a given + 4-tuple defining the left, upper pixel coordinated, hight and width size. + + Args: + a tuple: (upper pixel coordinate, left pixel coordinate, hight, width)-tuple + """ + + def __init__(self, i, j, h, w): + """ + i: Upper pixel coordinate. + j: Left pixel coordinate. + h: Height of the cropped image. + w: Width of the cropped image. + """ + self.i = i + self.j = j + self.h = h + self.w = w + + def __call__(self, img): + """ + Args: + img (numpy.ndarray (C x H x W)): Image to be cropped. + Returns: + img (numpy.ndarray (C x H x W)): Cropped image. + """ + + i, j, h, w = self.i, self.j, self.h, self.w + + if not(_is_numpy_image(img)): + raise TypeError('img should be ndarray. Got {}'.format(type(img))) + if img.ndim == 3: + return img[i:i + h, j:j + w, :] + elif img.ndim == 2: + return img[i:i + h, j:j + w] + else: + raise RuntimeError( + 'img should be ndarray with 2 or 3 dimensions. 
Got {}'.format(img.ndim)) + + def __repr__(self): + return self.__class__.__name__ + '(i={0},j={1},h={2},w={3})'.format( + self.i, self.j, self.h, self.w) diff --git a/imagenet/__init__.py b/imagenet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/imagenet/mobilenet.py b/imagenet/mobilenet.py new file mode 100644 index 0000000..ad43333 --- /dev/null +++ b/imagenet/mobilenet.py @@ -0,0 +1,79 @@ +import os +import shutil +import time + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data + +class MobileNet(nn.Module): + def __init__(self, relu6=True): + super(MobileNet, self).__init__() + + def relu(relu6): + if relu6: + return nn.ReLU6(inplace=True) + else: + return nn.ReLU(inplace=True) + + def conv_bn(inp, oup, stride, relu6): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + relu(relu6), + ) + + def conv_dw(inp, oup, stride, relu6): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + relu(relu6), + + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + relu(relu6), + ) + + self.model = nn.Sequential( + conv_bn( 3, 32, 2, relu6), + conv_dw( 32, 64, 1, relu6), + conv_dw( 64, 128, 2, relu6), + conv_dw(128, 128, 1, relu6), + conv_dw(128, 256, 2, relu6), + conv_dw(256, 256, 1, relu6), + conv_dw(256, 512, 2, relu6), + conv_dw(512, 512, 1, relu6), + conv_dw(512, 512, 1, relu6), + conv_dw(512, 512, 1, relu6), + conv_dw(512, 512, 1, relu6), + conv_dw(512, 512, 1, relu6), + conv_dw(512, 1024, 2, relu6), + conv_dw(1024, 1024, 1, relu6), + nn.AvgPool2d(7), + ) + self.fc = nn.Linear(1024, 1000) + + def forward(self, x): + x = self.model(x) + x = x.view(-1, 1024) + x = self.fc(x) + return x + +def main(): + import torchvision.models + model = MobileNet(relu6=True) + model = torch.nn.DataParallel(model).cuda() + model_filename = os.path.join('results', 'imagenet.arch=mobilenet.lr=0.1.bs=256', 'model_best.pth.tar') + if os.path.isfile(model_filename): + print("=> loading Imagenet pretrained model '{}'".format(model_filename)) + checkpoint = torch.load(model_filename) + epoch = checkpoint['epoch'] + best_prec1 = checkpoint['best_prec1'] + model.load_state_dict(checkpoint['state_dict']) + print("=> loaded Imagenet pretrained model '{}' (epoch {}). best_prec1={}".format(model_filename, epoch, best_prec1)) + +if __name__ == '__main__': + main() diff --git a/main.py b/main.py new file mode 100644 index 0000000..79cc67e --- /dev/null +++ b/main.py @@ -0,0 +1,130 @@ +import os +import time +import csv +import numpy as np + +import torch +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +cudnn.benchmark = True + +import models +from metrics import AverageMeter, Result +import utils + +args = utils.parse_command() +print(args) +os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu # Set the GPU. 
+ +fieldnames = ['rmse', 'mae', 'delta1', 'absrel', + 'lg10', 'mse', 'delta2', 'delta3', 'data_time', 'gpu_time'] +best_fieldnames = ['best_epoch'] + fieldnames +best_result = Result() +best_result.set_to_worst() + +def main(): + global args, best_result, output_directory, train_csv, test_csv + + # Data loading code + print("=> creating data loaders...") + valdir = os.path.join('..', 'data', args.data, 'val') + + if args.data == 'nyudepthv2': + from dataloaders.nyu import NYUDataset + val_dataset = NYUDataset(valdir, split='val', modality=args.modality) + else: + raise RuntimeError('Dataset not found.') + + # set batch size to be 1 for validation + val_loader = torch.utils.data.DataLoader(val_dataset, + batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True) + print("=> data loaders created.") + + # evaluation mode + if args.evaluate: + assert os.path.isfile(args.evaluate), \ + "=> no model found at '{}'".format(args.evaluate) + print("=> loading model '{}'".format(args.evaluate)) + checkpoint = torch.load(args.evaluate) + if type(checkpoint) is dict: + args.start_epoch = checkpoint['epoch'] + best_result = checkpoint['best_result'] + model = checkpoint['model'] + print("=> loaded best model (epoch {})".format(checkpoint['epoch'])) + else: + model = checkpoint + args.start_epoch = 0 + output_directory = os.path.dirname(args.evaluate) + validate(val_loader, model, args.start_epoch, write_to_file=False) + return + + +def validate(val_loader, model, epoch, write_to_file=True): + average_meter = AverageMeter() + model.eval() # switch to evaluate mode + end = time.time() + for i, (input, target) in enumerate(val_loader): + input, target = input.cuda(), target.cuda() + # torch.cuda.synchronize() + data_time = time.time() - end + + # compute output + end = time.time() + with torch.no_grad(): + pred = model(input) + # torch.cuda.synchronize() + gpu_time = time.time() - end + + # measure accuracy and record loss + result = Result() + result.evaluate(pred.data, target.data) + average_meter.update(result, gpu_time, data_time, input.size(0)) + end = time.time() + + # save 8 images for visualization + skip = 50 + + if args.modality == 'rgb': + rgb = input + + if i == 0: + img_merge = utils.merge_into_row(rgb, target, pred) + elif (i < 8*skip) and (i % skip == 0): + row = utils.merge_into_row(rgb, target, pred) + img_merge = utils.add_row(img_merge, row) + elif i == 8*skip: + filename = output_directory + '/comparison_' + str(epoch) + '.png' + utils.save_image(img_merge, filename) + + if (i+1) % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t' + 'RMSE={result.rmse:.2f}({average.rmse:.2f}) ' + 'MAE={result.mae:.2f}({average.mae:.2f}) ' + 'Delta1={result.delta1:.3f}({average.delta1:.3f}) ' + 'REL={result.absrel:.3f}({average.absrel:.3f}) ' + 'Lg10={result.lg10:.3f}({average.lg10:.3f}) '.format( + i+1, len(val_loader), gpu_time=gpu_time, result=result, average=average_meter.average())) + + avg = average_meter.average() + + print('\n*\n' + 'RMSE={average.rmse:.3f}\n' + 'MAE={average.mae:.3f}\n' + 'Delta1={average.delta1:.3f}\n' + 'REL={average.absrel:.3f}\n' + 'Lg10={average.lg10:.3f}\n' + 't_GPU={time:.3f}\n'.format( + average=avg, time=avg.gpu_time)) + + if write_to_file: + with open(test_csv, 'a') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writerow({'mse': avg.mse, 'rmse': avg.rmse, 'absrel': avg.absrel, 'lg10': avg.lg10, + 'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': 
avg.delta3, + 'data_time': avg.data_time, 'gpu_time': avg.gpu_time}) + return avg, img_merge + +if __name__ == '__main__': + main() diff --git a/metrics.py b/metrics.py new file mode 100644 index 0000000..7884040 --- /dev/null +++ b/metrics.py @@ -0,0 +1,95 @@ +import torch +import math +import numpy as np + +def log10(x): + """Convert a new tensor with the base-10 logarithm of the elements of x. """ + return torch.log(x) / math.log(10) + +class Result(object): + def __init__(self): + self.irmse, self.imae = 0, 0 + self.mse, self.rmse, self.mae = 0, 0, 0 + self.absrel, self.lg10 = 0, 0 + self.delta1, self.delta2, self.delta3 = 0, 0, 0 + self.data_time, self.gpu_time = 0, 0 + + def set_to_worst(self): + self.irmse, self.imae = np.inf, np.inf + self.mse, self.rmse, self.mae = np.inf, np.inf, np.inf + self.absrel, self.lg10 = np.inf, np.inf + self.delta1, self.delta2, self.delta3 = 0, 0, 0 + self.data_time, self.gpu_time = 0, 0 + + def update(self, irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3, gpu_time, data_time): + self.irmse, self.imae = irmse, imae + self.mse, self.rmse, self.mae = mse, rmse, mae + self.absrel, self.lg10 = absrel, lg10 + self.delta1, self.delta2, self.delta3 = delta1, delta2, delta3 + self.data_time, self.gpu_time = data_time, gpu_time + + def evaluate(self, output, target): + valid_mask = ((target>0) + (output>0)) > 0 + + output = 1e3 * output[valid_mask] + target = 1e3 * target[valid_mask] + abs_diff = (output - target).abs() + + self.mse = float((torch.pow(abs_diff, 2)).mean()) + self.rmse = math.sqrt(self.mse) + self.mae = float(abs_diff.mean()) + self.lg10 = float((log10(output) - log10(target)).abs().mean()) + self.absrel = float((abs_diff / target).mean()) + + maxRatio = torch.max(output / target, target / output) + self.delta1 = float((maxRatio < 1.25).float().mean()) + self.delta2 = float((maxRatio < 1.25 ** 2).float().mean()) + self.delta3 = float((maxRatio < 1.25 ** 3).float().mean()) + self.data_time = 0 + self.gpu_time = 0 + + inv_output = 1 / output + inv_target = 1 / target + abs_inv_diff = (inv_output - inv_target).abs() + self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean()) + self.imae = float(abs_inv_diff.mean()) + + +class AverageMeter(object): + def __init__(self): + self.reset() + + def reset(self): + self.count = 0.0 + + self.sum_irmse, self.sum_imae = 0, 0 + self.sum_mse, self.sum_rmse, self.sum_mae = 0, 0, 0 + self.sum_absrel, self.sum_lg10 = 0, 0 + self.sum_delta1, self.sum_delta2, self.sum_delta3 = 0, 0, 0 + self.sum_data_time, self.sum_gpu_time = 0, 0 + + def update(self, result, gpu_time, data_time, n=1): + self.count += n + + self.sum_irmse += n*result.irmse + self.sum_imae += n*result.imae + self.sum_mse += n*result.mse + self.sum_rmse += n*result.rmse + self.sum_mae += n*result.mae + self.sum_absrel += n*result.absrel + self.sum_lg10 += n*result.lg10 + self.sum_delta1 += n*result.delta1 + self.sum_delta2 += n*result.delta2 + self.sum_delta3 += n*result.delta3 + self.sum_data_time += n*data_time + self.sum_gpu_time += n*gpu_time + + def average(self): + avg = Result() + avg.update( + self.sum_irmse / self.count, self.sum_imae / self.count, + self.sum_mse / self.count, self.sum_rmse / self.count, self.sum_mae / self.count, + self.sum_absrel / self.count, self.sum_lg10 / self.count, + self.sum_delta1 / self.count, self.sum_delta2 / self.count, self.sum_delta3 / self.count, + self.sum_gpu_time / self.count, self.sum_data_time / self.count) + return avg \ No newline at end of file diff --git a/models.py b/models.py new 
file mode 100644 index 0000000..67673d3 --- /dev/null +++ b/models.py @@ -0,0 +1,814 @@ +import os +import torch +import torch.nn as nn +import torchvision.models +import collections +import math +import torch.nn.functional as F +import imagenet.mobilenet + +class Identity(nn.Module): + # a dummy identity module + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return x + +class Unpool(nn.Module): + # Unpool: 2*2 unpooling with zero padding + def __init__(self, stride=2): + super(Unpool, self).__init__() + + self.stride = stride + + # create kernel [1, 0; 0, 0] + self.mask = torch.zeros(1, 1, stride, stride) + self.mask[:,:,0,0] = 1 + + def forward(self, x): + assert x.dim() == 4 + num_channels = x.size(1) + return F.conv_transpose2d(x, + self.mask.detach().type_as(x).expand(num_channels, 1, -1, -1), + stride=self.stride, groups=num_channels) + +def weights_init(m): + # Initialize kernel weights with Gaussian distributions + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.ConvTranspose2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + +def conv(in_channels, out_channels, kernel_size): + padding = (kernel_size-1) // 2 + assert 2*padding == kernel_size-1, "parameters incorrect. kernel={}, padding={}".format(kernel_size, padding) + return nn.Sequential( + nn.Conv2d(in_channels,out_channels,kernel_size,stride=1,padding=padding,bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + +def depthwise(in_channels, kernel_size): + padding = (kernel_size-1) // 2 + assert 2*padding == kernel_size-1, "parameters incorrect. 
kernel={}, padding={}".format(kernel_size, padding) + return nn.Sequential( + nn.Conv2d(in_channels,in_channels,kernel_size,stride=1,padding=padding,bias=False,groups=in_channels), + nn.BatchNorm2d(in_channels), + nn.ReLU(inplace=True), + ) + +def pointwise(in_channels, out_channels): + return nn.Sequential( + nn.Conv2d(in_channels,out_channels,1,1,0,bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + +def convt(in_channels, out_channels, kernel_size): + stride = 2 + padding = (kernel_size - 1) // 2 + output_padding = kernel_size % 2 + assert -2 - 2*padding + kernel_size + output_padding == 0, "deconv parameters incorrect" + return nn.Sequential( + nn.ConvTranspose2d(in_channels,out_channels,kernel_size, + stride,padding,output_padding,bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + +def convt_dw(channels, kernel_size): + stride = 2 + padding = (kernel_size - 1) // 2 + output_padding = kernel_size % 2 + assert -2 - 2*padding + kernel_size + output_padding == 0, "deconv parameters incorrect" + return nn.Sequential( + nn.ConvTranspose2d(channels,channels,kernel_size, + stride,padding,output_padding,bias=False,groups=channels), + nn.BatchNorm2d(channels), + nn.ReLU(inplace=True), + ) + +def upconv(in_channels, out_channels): + return nn.Sequential( + Unpool(2), + nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=1,padding=2,bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + +class upproj(nn.Module): + # UpProj module has two branches, with a Unpool at the start and a ReLu at the end + # upper branch: 5*5 conv -> batchnorm -> ReLU -> 3*3 conv -> batchnorm + # bottom branch: 5*5 conv -> batchnorm + + def __init__(self, in_channels, out_channels): + super(upproj, self).__init__() + self.unpool = Unpool(2) + self.branch1 = nn.Sequential( + nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=1,padding=2,bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=False), + nn.BatchNorm2d(out_channels), + ) + self.branch2 = nn.Sequential( + nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=1,padding=2,bias=False), + nn.BatchNorm2d(out_channels), + ) + + def forward(self, x): + x = self.unpool(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + return F.relu(x1 + x2) + +class Decoder(nn.Module): + names = ['deconv{}{}'.format(i,dw) for i in range(3,10,2) for dw in ['', 'dw']] + names.append("upconv") + names.append("upproj") + for i in range(3,10,2): + for dw in ['', 'dw']: + names.append("nnconv{}{}".format(i, dw)) + names.append("blconv{}{}".format(i, dw)) + names.append("shuffle{}{}".format(i, dw)) + +class DeConv(nn.Module): + + def __init__(self, kernel_size, dw): + super(DeConv, self).__init__() + if dw: + self.convt1 = nn.Sequential( + convt_dw(1024, kernel_size), + pointwise(1024, 512)) + self.convt2 = nn.Sequential( + convt_dw(512, kernel_size), + pointwise(512, 256)) + self.convt3 = nn.Sequential( + convt_dw(256, kernel_size), + pointwise(256, 128)) + self.convt4 = nn.Sequential( + convt_dw(128, kernel_size), + pointwise(128, 64)) + self.convt5 = nn.Sequential( + convt_dw(64, kernel_size), + pointwise(64, 32)) + else: + self.convt1 = convt(1024, 512, kernel_size) + self.convt2 = convt(512, 256, kernel_size) + self.convt3 = convt(256, 128, kernel_size) + self.convt4 = convt(128, 64, kernel_size) + self.convt5 = convt(64, 32, kernel_size) + self.convf = pointwise(32, 1) + + def forward(self, x): + x = self.convt1(x) + x = 
self.convt2(x) + x = self.convt3(x) + x = self.convt4(x) + x = self.convt5(x) + x = self.convf(x) + return x + + +class UpConv(nn.Module): + + def __init__(self): + super(UpConv, self).__init__() + self.upconv1 = upconv(1024, 512) + self.upconv2 = upconv(512, 256) + self.upconv3 = upconv(256, 128) + self.upconv4 = upconv(128, 64) + self.upconv5 = upconv(64, 32) + self.convf = pointwise(32, 1) + + def forward(self, x): + x = self.upconv1(x) + x = self.upconv2(x) + x = self.upconv3(x) + x = self.upconv4(x) + x = self.upconv5(x) + x = self.convf(x) + return x + +class UpProj(nn.Module): + # UpProj decoder consists of 4 upproj modules with decreasing number of channels and increasing feature map size + + def __init__(self): + super(UpProj, self).__init__() + self.upproj1 = upproj(1024, 512) + self.upproj2 = upproj(512, 256) + self.upproj3 = upproj(256, 128) + self.upproj4 = upproj(128, 64) + self.upproj5 = upproj(64, 32) + self.convf = pointwise(32, 1) + + def forward(self, x): + x = self.upproj1(x) + x = self.upproj2(x) + x = self.upproj3(x) + x = self.upproj4(x) + x = self.upproj5(x) + x = self.convf(x) + return x + +class NNConv(nn.Module): + + def __init__(self, kernel_size, dw): + super(NNConv, self).__init__() + if dw: + self.conv1 = nn.Sequential( + depthwise(1024, kernel_size), + pointwise(1024, 512)) + self.conv2 = nn.Sequential( + depthwise(512, kernel_size), + pointwise(512, 256)) + self.conv3 = nn.Sequential( + depthwise(256, kernel_size), + pointwise(256, 128)) + self.conv4 = nn.Sequential( + depthwise(128, kernel_size), + pointwise(128, 64)) + self.conv5 = nn.Sequential( + depthwise(64, kernel_size), + pointwise(64, 32)) + self.conv6 = pointwise(32, 1) + else: + self.conv1 = conv(1024, 512, kernel_size) + self.conv2 = conv(512, 256, kernel_size) + self.conv3 = conv(256, 128, kernel_size) + self.conv4 = conv(128, 64, kernel_size) + self.conv5 = conv(64, 32, kernel_size) + self.conv6 = pointwise(32, 1) + + def forward(self, x): + x = self.conv1(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + + x = self.conv2(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + + x = self.conv3(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + + x = self.conv4(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + + x = self.conv5(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + + x = self.conv6(x) + return x + +class BLConv(NNConv): + + def __init__(self, kernel_size, dw): + super(BLConv, self).__init__(kernel_size, dw) + + def forward(self, x): + x = self.conv1(x) + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + + x = self.conv2(x) + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + + x = self.conv3(x) + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + + x = self.conv4(x) + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + + x = self.conv5(x) + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + + x = self.conv6(x) + return x + +class ShuffleConv(nn.Module): + + def __init__(self, kernel_size, dw): + super(ShuffleConv, self).__init__() + if dw: + self.conv1 = nn.Sequential( + depthwise(256, kernel_size), + pointwise(256, 256)) + self.conv2 = nn.Sequential( + depthwise(64, kernel_size), + pointwise(64, 64)) + self.conv3 = nn.Sequential( + depthwise(16, kernel_size), + pointwise(16, 16)) + self.conv4 = nn.Sequential( + depthwise(4, kernel_size), + pointwise(4, 4)) + else: + self.conv1 = conv(256, 256, kernel_size) + 
self.conv2 = conv(64, 64, kernel_size) + self.conv3 = conv(16, 16, kernel_size) + self.conv4 = conv(4, 4, kernel_size) + + def forward(self, x): + x = F.pixel_shuffle(x, 2) + x = self.conv1(x) + + x = F.pixel_shuffle(x, 2) + x = self.conv2(x) + + x = F.pixel_shuffle(x, 2) + x = self.conv3(x) + + x = F.pixel_shuffle(x, 2) + x = self.conv4(x) + + x = F.pixel_shuffle(x, 2) + return x + +def choose_decoder(decoder): + depthwise = ('dw' in decoder) + if decoder[:6] == 'deconv': + assert len(decoder)==7 or (len(decoder)==9 and 'dw' in decoder) + kernel_size = int(decoder[6]) + model = DeConv(kernel_size, depthwise) + elif decoder == "upproj": + model = UpProj() + elif decoder == "upconv": + model = UpConv() + elif decoder[:7] == 'shuffle': + assert len(decoder)==8 or (len(decoder)==10 and 'dw' in decoder) + kernel_size = int(decoder[7]) + model = ShuffleConv(kernel_size, depthwise) + elif decoder[:6] == 'nnconv': + assert len(decoder)==7 or (len(decoder)==9 and 'dw' in decoder) + kernel_size = int(decoder[6]) + model = NNConv(kernel_size, depthwise) + elif decoder[:6] == 'blconv': + assert len(decoder)==7 or (len(decoder)==9 and 'dw' in decoder) + kernel_size = int(decoder[6]) + model = BLConv(kernel_size, depthwise) + else: + assert False, "invalid option for decoder: {}".format(decoder) + model.apply(weights_init) + return model + + +class ResNet(nn.Module): + def __init__(self, layers, decoder, output_size, in_channels=3, pretrained=True): + + if layers not in [18, 34, 50, 101, 152]: + raise RuntimeError('Only 18, 34, 50, 101, and 152 layer model are defined for ResNet. Got {}'.format(layers)) + + super(ResNet, self).__init__() + self.output_size = output_size + pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained) + if not pretrained: + pretrained_model.apply(weights_init) + + if in_channels == 3: + self.conv1 = pretrained_model._modules['conv1'] + self.bn1 = pretrained_model._modules['bn1'] + else: + self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + weights_init(self.conv1) + weights_init(self.bn1) + + self.relu = pretrained_model._modules['relu'] + self.maxpool = pretrained_model._modules['maxpool'] + self.layer1 = pretrained_model._modules['layer1'] + self.layer2 = pretrained_model._modules['layer2'] + self.layer3 = pretrained_model._modules['layer3'] + self.layer4 = pretrained_model._modules['layer4'] + + # clear memory + del pretrained_model + + # define number of intermediate channels + if layers <= 34: + num_channels = 512 + elif layers >= 50: + num_channels = 2048 + self.conv2 = nn.Conv2d(num_channels, 1024, 1) + weights_init(self.conv2) + self.decoder = choose_decoder(decoder) + + def forward(self, x): + # resnet + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.conv2(x) + + # decoder + x = self.decoder(x) + + return x + +class MobileNet(nn.Module): + def __init__(self, decoder, output_size, in_channels=3, pretrained=True): + + super(MobileNet, self).__init__() + self.output_size = output_size + mobilenet = imagenet.mobilenet.MobileNet() + if pretrained: + pretrained_path = os.path.join('imagenet', 'results', 'imagenet.arch=mobilenet.lr=0.1.bs=256', 'model_best.pth.tar') + checkpoint = torch.load(pretrained_path) + state_dict = checkpoint['state_dict'] + + from collections import OrderedDict + new_state_dict = OrderedDict() + for k, v in 
state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + mobilenet.load_state_dict(new_state_dict) + else: + mobilenet.apply(weights_init) + + if in_channels == 3: + self.mobilenet = nn.Sequential(*(mobilenet.model[i] for i in range(14))) + else: + def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + self.mobilenet = nn.Sequential( + conv_bn(in_channels, 32, 2), + *(mobilenet.model[i] for i in range(1,14)) + ) + + self.decoder = choose_decoder(decoder) + + def forward(self, x): + x = self.mobilenet(x) + x = self.decoder(x) + return x + +class ResNetSkipAdd(nn.Module): + def __init__(self, layers, output_size, in_channels=3, pretrained=True): + + if layers not in [18, 34, 50, 101, 152]: + raise RuntimeError('Only 18, 34, 50, 101, and 152 layer model are defined for ResNet. Got {}'.format(layers)) + + super(ResNetSkipAdd, self).__init__() + self.output_size = output_size + pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained) + if not pretrained: + pretrained_model.apply(weights_init) + + if in_channels == 3: + self.conv1 = pretrained_model._modules['conv1'] + self.bn1 = pretrained_model._modules['bn1'] + else: + self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + weights_init(self.conv1) + weights_init(self.bn1) + + self.relu = pretrained_model._modules['relu'] + self.maxpool = pretrained_model._modules['maxpool'] + self.layer1 = pretrained_model._modules['layer1'] + self.layer2 = pretrained_model._modules['layer2'] + self.layer3 = pretrained_model._modules['layer3'] + self.layer4 = pretrained_model._modules['layer4'] + + # clear memory + del pretrained_model + + # define number of intermediate channels + if layers <= 34: + num_channels = 512 + elif layers >= 50: + num_channels = 2048 + self.conv2 = nn.Conv2d(num_channels, 1024, 1) + weights_init(self.conv2) + + kernel_size = 5 + self.decode_conv1 = conv(1024, 512, kernel_size) + self.decode_conv2 = conv(512, 256, kernel_size) + self.decode_conv3 = conv(256, 128, kernel_size) + self.decode_conv4 = conv(128, 64, kernel_size) + self.decode_conv5 = conv(64, 32, kernel_size) + self.decode_conv6 = pointwise(32, 1) + weights_init(self.decode_conv1) + weights_init(self.decode_conv2) + weights_init(self.decode_conv3) + weights_init(self.decode_conv4) + weights_init(self.decode_conv5) + weights_init(self.decode_conv6) + + def forward(self, x): + # resnet + x = self.conv1(x) + x = self.bn1(x) + x1 = self.relu(x) + # print("x1", x1.size()) + x2 = self.maxpool(x1) + # print("x2", x2.size()) + x3 = self.layer1(x2) + # print("x3", x3.size()) + x4 = self.layer2(x3) + # print("x4", x4.size()) + x5 = self.layer3(x4) + # print("x5", x5.size()) + x6 = self.layer4(x5) + # print("x6", x6.size()) + x7 = self.conv2(x6) + + # decoder + y10 = self.decode_conv1(x7) + # print("y10", y10.size()) + y9 = F.interpolate(y10 + x6, scale_factor=2, mode='nearest') + # print("y9", y9.size()) + y8 = self.decode_conv2(y9) + # print("y8", y8.size()) + y7 = F.interpolate(y8 + x5, scale_factor=2, mode='nearest') + # print("y7", y7.size()) + y6 = self.decode_conv3(y7) + # print("y6", y6.size()) + y5 = F.interpolate(y6 + x4, scale_factor=2, mode='nearest') + # print("y5", y5.size()) + y4 = self.decode_conv4(y5) + # print("y4", y4.size()) + y3 = F.interpolate(y4 + x3, scale_factor=2, mode='nearest') + # print("y3", y3.size()) + y2 = 
self.decode_conv5(y3 + x1) + # print("y2", y2.size()) + y1 = F.interpolate(y2, scale_factor=2, mode='nearest') + # print("y1", y1.size()) + y = self.decode_conv6(y1) + + return y + +class ResNetSkipConcat(nn.Module): + def __init__(self, layers, output_size, in_channels=3, pretrained=True): + + if layers not in [18, 34, 50, 101, 152]: + raise RuntimeError('Only 18, 34, 50, 101, and 152 layer model are defined for ResNet. Got {}'.format(layers)) + + super(ResNetSkipConcat, self).__init__() + self.output_size = output_size + pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained) + if not pretrained: + pretrained_model.apply(weights_init) + + if in_channels == 3: + self.conv1 = pretrained_model._modules['conv1'] + self.bn1 = pretrained_model._modules['bn1'] + else: + self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + weights_init(self.conv1) + weights_init(self.bn1) + + self.relu = pretrained_model._modules['relu'] + self.maxpool = pretrained_model._modules['maxpool'] + self.layer1 = pretrained_model._modules['layer1'] + self.layer2 = pretrained_model._modules['layer2'] + self.layer3 = pretrained_model._modules['layer3'] + self.layer4 = pretrained_model._modules['layer4'] + + # clear memory + del pretrained_model + + # define number of intermediate channels + if layers <= 34: + num_channels = 512 + elif layers >= 50: + num_channels = 2048 + self.conv2 = nn.Conv2d(num_channels, 1024, 1) + weights_init(self.conv2) + + kernel_size = 5 + self.decode_conv1 = conv(1024, 512, kernel_size) + self.decode_conv2 = conv(768, 256, kernel_size) + self.decode_conv3 = conv(384, 128, kernel_size) + self.decode_conv4 = conv(192, 64, kernel_size) + self.decode_conv5 = conv(128, 32, kernel_size) + self.decode_conv6 = pointwise(32, 1) + weights_init(self.decode_conv1) + weights_init(self.decode_conv2) + weights_init(self.decode_conv3) + weights_init(self.decode_conv4) + weights_init(self.decode_conv5) + weights_init(self.decode_conv6) + + def forward(self, x): + # resnet + x = self.conv1(x) + x = self.bn1(x) + x1 = self.relu(x) + # print("x1", x1.size()) + x2 = self.maxpool(x1) + # print("x2", x2.size()) + x3 = self.layer1(x2) + # print("x3", x3.size()) + x4 = self.layer2(x3) + # print("x4", x4.size()) + x5 = self.layer3(x4) + # print("x5", x5.size()) + x6 = self.layer4(x5) + # print("x6", x6.size()) + x7 = self.conv2(x6) + + # decoder + y10 = self.decode_conv1(x7) + # print("y10", y10.size()) + y9 = F.interpolate(y10, scale_factor=2, mode='nearest') + # print("y9", y9.size()) + y8 = self.decode_conv2(torch.cat((y9, x5), 1)) + # print("y8", y8.size()) + y7 = F.interpolate(y8, scale_factor=2, mode='nearest') + # print("y7", y7.size()) + y6 = self.decode_conv3(torch.cat((y7, x4), 1)) + # print("y6", y6.size()) + y5 = F.interpolate(y6, scale_factor=2, mode='nearest') + # print("y5", y5.size()) + y4 = self.decode_conv4(torch.cat((y5, x3), 1)) + # print("y4", y4.size()) + y3 = F.interpolate(y4, scale_factor=2, mode='nearest') + # print("y3", y3.size()) + y2 = self.decode_conv5(torch.cat((y3, x1), 1)) + # print("y2", y2.size()) + y1 = F.interpolate(y2, scale_factor=2, mode='nearest') + # print("y1", y1.size()) + y = self.decode_conv6(y1) + + return y + +class MobileNetSkipAdd(nn.Module): + def __init__(self, output_size, pretrained=True): + + super(MobileNetSkipAdd, self).__init__() + self.output_size = output_size + mobilenet = imagenet.mobilenet.MobileNet() + if pretrained: + pretrained_path = 
os.path.join('imagenet', 'results', 'imagenet.arch=mobilenet.lr=0.1.bs=256', 'model_best.pth.tar') + checkpoint = torch.load(pretrained_path) + state_dict = checkpoint['state_dict'] + + from collections import OrderedDict + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + mobilenet.load_state_dict(new_state_dict) + else: + mobilenet.apply(weights_init) + + for i in range(14): + setattr( self, 'conv{}'.format(i), mobilenet.model[i]) + + kernel_size = 5 + # self.decode_conv1 = conv(1024, 512, kernel_size) + # self.decode_conv2 = conv(512, 256, kernel_size) + # self.decode_conv3 = conv(256, 128, kernel_size) + # self.decode_conv4 = conv(128, 64, kernel_size) + # self.decode_conv5 = conv(64, 32, kernel_size) + self.decode_conv1 = nn.Sequential( + depthwise(1024, kernel_size), + pointwise(1024, 512)) + self.decode_conv2 = nn.Sequential( + depthwise(512, kernel_size), + pointwise(512, 256)) + self.decode_conv3 = nn.Sequential( + depthwise(256, kernel_size), + pointwise(256, 128)) + self.decode_conv4 = nn.Sequential( + depthwise(128, kernel_size), + pointwise(128, 64)) + self.decode_conv5 = nn.Sequential( + depthwise(64, kernel_size), + pointwise(64, 32)) + self.decode_conv6 = pointwise(32, 1) + weights_init(self.decode_conv1) + weights_init(self.decode_conv2) + weights_init(self.decode_conv3) + weights_init(self.decode_conv4) + weights_init(self.decode_conv5) + weights_init(self.decode_conv6) + + def forward(self, x): + # skip connections: dec4: enc1 + # dec 3: enc2 or enc3 + # dec 2: enc4 or enc5 + for i in range(14): + layer = getattr(self, 'conv{}'.format(i)) + x = layer(x) + # print("{}: {}".format(i, x.size())) + if i==1: + x1 = x + elif i==3: + x2 = x + elif i==5: + x3 = x + for i in range(1,6): + layer = getattr(self, 'decode_conv{}'.format(i)) + x = layer(x) + x = F.interpolate(x, scale_factor=2, mode='nearest') + if i==4: + x = x + x1 + elif i==3: + x = x + x2 + elif i==2: + x = x + x3 + # print("{}: {}".format(i, x.size())) + x = self.decode_conv6(x) + return x + +class MobileNetSkipConcat(nn.Module): + def __init__(self, output_size, pretrained=True): + + super(MobileNetSkipConcat, self).__init__() + self.output_size = output_size + mobilenet = imagenet.mobilenet.MobileNet() + if pretrained: + pretrained_path = os.path.join('imagenet', 'results', 'imagenet.arch=mobilenet.lr=0.1.bs=256', 'model_best.pth.tar') + checkpoint = torch.load(pretrained_path) + state_dict = checkpoint['state_dict'] + + from collections import OrderedDict + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + mobilenet.load_state_dict(new_state_dict) + else: + mobilenet.apply(weights_init) + + for i in range(14): + setattr( self, 'conv{}'.format(i), mobilenet.model[i]) + + kernel_size = 5 + # self.decode_conv1 = conv(1024, 512, kernel_size) + # self.decode_conv2 = conv(512, 256, kernel_size) + # self.decode_conv3 = conv(256, 128, kernel_size) + # self.decode_conv4 = conv(128, 64, kernel_size) + # self.decode_conv5 = conv(64, 32, kernel_size) + self.decode_conv1 = nn.Sequential( + depthwise(1024, kernel_size), + pointwise(1024, 512)) + self.decode_conv2 = nn.Sequential( + depthwise(512, kernel_size), + pointwise(512, 256)) + self.decode_conv3 = nn.Sequential( + depthwise(512, kernel_size), + pointwise(512, 128)) + self.decode_conv4 = nn.Sequential( + depthwise(256, kernel_size), + pointwise(256, 64)) + self.decode_conv5 = nn.Sequential( + depthwise(128, 
kernel_size), + pointwise(128, 32)) + self.decode_conv6 = pointwise(32, 1) + weights_init(self.decode_conv1) + weights_init(self.decode_conv2) + weights_init(self.decode_conv3) + weights_init(self.decode_conv4) + weights_init(self.decode_conv5) + weights_init(self.decode_conv6) + + def forward(self, x): + # skip connections: dec4: enc1 + # dec 3: enc2 or enc3 + # dec 2: enc4 or enc5 + for i in range(14): + layer = getattr(self, 'conv{}'.format(i)) + x = layer(x) + # print("{}: {}".format(i, x.size())) + if i==1: + x1 = x + elif i==3: + x2 = x + elif i==5: + x3 = x + for i in range(1,6): + layer = getattr(self, 'decode_conv{}'.format(i)) + # print("{}a: {}".format(i, x.size())) + x = layer(x) + # print("{}b: {}".format(i, x.size())) + x = F.interpolate(x, scale_factor=2, mode='nearest') + if i==4: + x = torch.cat((x, x1), 1) + elif i==3: + x = torch.cat((x, x2), 1) + elif i==2: + x = torch.cat((x, x3), 1) + # print("{}c: {}".format(i, x.size())) + x = self.decode_conv6(x) + return x diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..beadb9e --- /dev/null +++ b/utils.py @@ -0,0 +1,83 @@ +import os +import torch +import shutil +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image +import math + +cmap = plt.cm.viridis + + +def parse_command(): + data_names = ['nyudepthv2'] + + from dataloaders.dataloader import MyDataloader + modality_names = MyDataloader.modality_names + + import argparse + parser = argparse.ArgumentParser(description='FastDepth') + parser.add_argument('--data', metavar='DATA', default='nyudepthv2', + choices=data_names, + help='dataset: ' + ' | '.join(data_names) + ' (default: nyudepthv2)') + parser.add_argument('--modality', '-m', metavar='MODALITY', default='rgb', choices=modality_names, + help='modality: ' + ' | '.join(modality_names) + ' (default: rgb)') + parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', + help='number of data loading workers (default: 16)') + parser.add_argument('--print-freq', '-p', default=50, type=int, + metavar='N', help='print frequency (default: 50)') + parser.add_argument('-e', '--evaluate', default='', type=str, metavar='PATH',) + parser.add_argument('--gpu', default='0', type=str, metavar='N', help="gpu id") + parser.set_defaults(cuda=True) + + args = parser.parse_args() + return args + + +def colored_depthmap(depth, d_min=None, d_max=None): + if d_min is None: + d_min = np.min(depth) + if d_max is None: + d_max = np.max(depth) + depth_relative = (depth - d_min) / (d_max - d_min) + return 255 * cmap(depth_relative)[:,:,:3] # H, W, C + + +def merge_into_row(input, depth_target, depth_pred): + rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0)) # H, W, C + depth_target_cpu = np.squeeze(depth_target.cpu().numpy()) + depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy()) + + d_min = min(np.min(depth_target_cpu), np.min(depth_pred_cpu)) + d_max = max(np.max(depth_target_cpu), np.max(depth_pred_cpu)) + depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max) + depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max) + img_merge = np.hstack([rgb, depth_target_col, depth_pred_col]) + + return img_merge + + +def merge_into_row_with_gt(input, depth_input, depth_target, depth_pred): + rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0)) # H, W, C + depth_input_cpu = np.squeeze(depth_input.cpu().numpy()) + depth_target_cpu = np.squeeze(depth_target.cpu().numpy()) + depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy()) + + d_min = 
min(np.min(depth_input_cpu), np.min(depth_target_cpu), np.min(depth_pred_cpu)) + d_max = max(np.max(depth_input_cpu), np.max(depth_target_cpu), np.max(depth_pred_cpu)) + depth_input_col = colored_depthmap(depth_input_cpu, d_min, d_max) + depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max) + depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max) + + img_merge = np.hstack([rgb, depth_input_col, depth_target_col, depth_pred_col]) + + return img_merge + + +def add_row(img_merge, row): + return np.vstack([img_merge, row]) + + +def save_image(img_merge, filename): + img_merge = Image.fromarray(img_merge.astype('uint8')) + img_merge.save(filename)
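
For reference, once this patch is applied the evaluation entry point can be invoked along the following lines. The checkpoint filename below is a placeholder (not a file added by this patch), and the NYU Depth v2 h5 data is expected under ../data/nyudepthv2/val relative to the repository, as hard-coded in main.py:

    python3 main.py --evaluate results/model_best.pth.tar --data nyudepthv2 --modality rgb --gpu 0

The same loop can also be driven directly from Python. This is a minimal sketch using only the classes added above; the checkpoint and data paths are assumptions:

    import torch
    from dataloaders.nyu import NYUDataset
    from metrics import AverageMeter, Result

    # build the validation set and a batch-size-1 loader, as main.py does
    val_set = NYUDataset('../data/nyudepthv2/val', split='val', modality='rgb')
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=False)

    # a saved checkpoint is either a dict with a 'model' entry or the model itself
    checkpoint = torch.load('results/model_best.pth.tar')  # placeholder path
    model = checkpoint['model'] if isinstance(checkpoint, dict) else checkpoint
    model = model.cuda().eval()

    meter = AverageMeter()
    with torch.no_grad():
        for rgb, depth in val_loader:
            rgb, depth = rgb.cuda(), depth.cuda()
            pred = model(rgb)
            result = Result()
            result.evaluate(pred.data, depth.data)  # fills rmse, mae, absrel, delta1, ...
            meter.update(result, gpu_time=0, data_time=0, n=rgb.size(0))

    avg = meter.average()
    print('RMSE={:.3f}  MAE={:.3f}  delta1={:.3f}'.format(avg.rmse, avg.mae, avg.delta1))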