diff --git a/aim22-reverseisp/teams/HIT-IIL/README.md b/aim22-reverseisp/teams/HIT-IIL/README.md
new file mode 100644
index 0000000..de719e5
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/README.md
@@ -0,0 +1,3 @@
+- Download [the pre-trained models for track 1](https://drive.google.com/drive/folders/1GmahiWwpMsPb9Y37TEiHs-ittD7ZAF0A?usp=sharing) and put them in the `sRGB-to-RAW-s7/ckpt` folder.
+- Download [the pre-trained models for track 2](https://drive.google.com/drive/folders/1VMFi8ombywlD60sroWJOUwv_0l_CF71N?usp=sharing) and put them in the `sRGB-to-RAW-p20/ckpt` folder.
+
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/__init__.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/__init__.py
new file mode 100644
index 0000000..4559e6f
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/__init__.py
@@ -0,0 +1,57 @@
+import importlib
+import torch.utils.data
+from data.base_dataset import BaseDataset
+
+
+def find_dataset_using_name(dataset_name, split='train'):
+    dataset_filename = "data." + dataset_name + "_dataset"
+    datasetlib = importlib.import_module(dataset_filename)
+
+    dataset = None
+    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+    for name, cls in datasetlib.__dict__.items():
+        if name.lower() == target_dataset_name.lower() \
+           and issubclass(cls, BaseDataset):
+            dataset = cls
+
+    if dataset is None:
+        raise NotImplementedError("In %s.py, there should be a subclass of "
+                                  "BaseDataset with class name that matches %s in "
+                                  "lowercase." % (dataset_filename, target_dataset_name))
+    return dataset
+
+
+def create_dataset(dataset_name, split, opt):
+    data_loader = CustomDatasetDataLoader(dataset_name, split, opt)
+    dataset = data_loader.load_data()
+    return dataset
+
+
+class CustomDatasetDataLoader():
+    def __init__(self, dataset_name, split, opt):
+        self.opt = opt
+        dataset_class = find_dataset_using_name(dataset_name, split)
+        self.dataset = dataset_class(opt, split, dataset_name)
+#        self.imio = self.dataset.imio
+        print("dataset [%s(%s)] created" % (dataset_name, split))
+        self.dataloader = torch.utils.data.DataLoader(
+            self.dataset,
+            batch_size=opt.batch_size if split=='train' else 1,
+            shuffle=opt.shuffle and split=='train',
+            num_workers=int(opt.num_dataloader),
+            drop_last=opt.drop_last)
+
+    def load_data(self):
+        return self
+
+    def __len__(self):
+        """Return the number of items in the dataset."""
+        return min(len(self.dataset), self.opt.max_dataset_size)
+
+    def __iter__(self):
+        """Yield batches of data, stopping after max_dataset_size items."""
+        for i, data in enumerate(self.dataloader):
+            if i * self.opt.batch_size >= self.opt.max_dataset_size:
+                break
+            yield data
+
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/base_dataset.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/base_dataset.py
new file mode 100644
index 0000000..eadad20
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/base_dataset.py
@@ -0,0 +1,19 @@
+import torch.utils.data as data
+from abc import ABC, abstractmethod
+
+
+class BaseDataset(data.Dataset, ABC):
+    def __init__(self, opt, split, dataset_name):
+        self.opt = opt
+        self.split = split
+        self.root = opt.dataroot
+        self.dataset_name = dataset_name.lower()
+
+    @abstractmethod
+    def __len__(self):
+        return 0
+
+    @abstractmethod
+    def __getitem__(self, index):
+        pass
+
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/imlib.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/imlib.py
new file mode 100644
index 0000000..b06539c
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/imlib.py
@@ -0,0 +1,279 @@
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import os
+import cv2
+from PIL import Image
+from functools import wraps
+import time
+
+
+class imlib():
+    """
+    Note that YCxCx in OpenCV and PIL are different, so be careful if a
+    model is trained with OpenCV and tested with PIL in Y mode, and
+    vice versa.
+
+    force_color = True: return a 3-channel YCxCx image.
+        For mode 'Y', if a gray image is given, the channel is repeated 3
+        times and the result is converted to YCxCx.
+    force_color = False: return a 3-channel YCxCx image or a 1-channel gray one.
+        For mode 'Y', if a gray image is given, the gray image is used directly.
+    """
+    def __init__(self, mode='RGB', fmt='CHW', lib='cv2', force_color=True):
+        assert mode.upper() in ('RGB', 'L', 'Y', 'RAW')
+        self.mode = mode.upper()
+
+        assert fmt.upper() in ('HWC', 'CHW', 'NHWC', 'NCHW')
+        self.fmt = 'CHW' if fmt.upper() in ('CHW', 'NCHW') else 'HWC'
+
+        assert lib.lower() in ('cv2', 'pillow')
+        self.lib = lib.lower()
+
+        self.force_color = force_color
+
+        self.dtype = np.uint8
+
+        self._imread = getattr(self, '_imread_%s_%s'%(self.lib, self.mode))
+        self._imwrite = getattr(self, '_imwrite_%s_%s'%(self.lib, self.mode))
+        self._trans_batch = getattr(self, '_trans_batch_%s_%s'
+                                    % (self.mode, self.fmt))
+        self._trans_image = getattr(self, '_trans_image_%s_%s'
+                                    % (self.mode, self.fmt))
+        self._trans_back = getattr(self, '_trans_back_%s_%s'
+                                   % (self.mode, self.fmt))
+
+    def _imread_cv2_RGB(self, path):
+        return cv2_imread(path, cv2.IMREAD_COLOR)[..., ::-1]
+    def _imread_cv2_RAW(self, path):
+        return cv2_imread(path, -1)
+    def _imread_cv2_Y(self, path):
+        if self.force_color:
+            img = cv2_imread(path, cv2.IMREAD_COLOR)
+        else:
+            img = cv2_imread(path, cv2.IMREAD_ANYCOLOR)
+        if len(img.shape) == 2:
+            # the new channel axis of an HWC image is 2
+            return np.expand_dims(img, 2)
+        elif len(img.shape) == 3:
+            return cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
+        else:
+            raise ValueError('The dimension should be either 2 or 3.')
+    def _imread_cv2_L(self, path):
+        return cv2_imread(path, cv2.IMREAD_GRAYSCALE)
+
+    def _imread_pillow_RGB(self, path):
+        img = Image.open(path)
+        im = np.array(img.convert(self.mode))
+        img.close()
+        return im
+    _imread_pillow_L = _imread_pillow_RGB
+    # WARNING: the RGB->YCbCr conversion in PIL may differ from OpenCV's
+    def _imread_pillow_Y(self, path):
+        img = Image.open(path)
+        if img.mode == 'RGB':
+            im = np.array(img.convert('YCbCr'))
+        elif img.mode == 'L':
+            if self.force_color:
+                im = np.array(img.convert('RGB').convert('YCbCr'))
+            else:
+                im = np.expand_dims(np.array(img), 2)
+        else:
+            img.close()
+            raise NotImplementedError('Only RGB and gray images are supported now.')
+        img.close()
+        return im
+
+    def _imwrite_cv2_RGB(self, image, path):
+        cv2.imwrite(path, image[..., ::-1])
+    def _imwrite_cv2_RAW(self, image, path):
+        pass
+    def _imwrite_cv2_Y(self, image, path):
+        if image.shape[2] == 1:
+            cv2.imwrite(path, image[..., 0])
+        elif image.shape[2] == 3:
+            cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_YCrCb2BGR))
+        else:
+            raise ValueError('There should be 1 or 3 channels.')
+    def _imwrite_cv2_L(self, image, path):
+        cv2.imwrite(path, image)
+
+    def _imwrite_pillow_RGB(self, image, path):
+        Image.fromarray(image).save(path)
+    _imwrite_pillow_L = _imwrite_pillow_RGB
+    def _imwrite_pillow_Y(self, image, path):
+        if image.shape[2] == 1:
+            self._imwrite_pillow_L(np.squeeze(image, 2), path)
+        elif image.shape[2] == 3:
+            Image.fromarray(image,
mode='YCbCr').convert('RGB').save(path) + else: + raise ValueError('There should be 1 or 3 channels.') + + def _trans_batch_RGB_HWC(self, images): + return np.ascontiguousarray(images) + def _trans_batch_RGB_CHW(self, images): + return np.ascontiguousarray(np.transpose(images, (0, 3, 1, 2))) + _trans_batch_RAW_HWC = _trans_batch_RGB_HWC + _trans_batch_RAW_CHW = _trans_batch_RGB_CHW + _trans_batch_Y_HWC = _trans_batch_RGB_HWC + _trans_batch_Y_CHW = _trans_batch_RGB_CHW + def _trans_batch_L_HWC(self, images): + return np.ascontiguousarray(np.expand_dims(images, 3)) + def _trans_batch_L_CHW(slef, images): + return np.ascontiguousarray(np.expand_dims(images, 1)) + + def _trans_image_RGB_HWC(self, image): + return np.ascontiguousarray(image) + def _trans_image_RGB_CHW(self, image): + return np.ascontiguousarray(np.transpose(image, (2, 0, 1))) + _trans_image_RAW_HWC = _trans_image_RGB_HWC + _trans_image_RAW_CHW = _trans_image_RGB_CHW + _trans_image_Y_HWC = _trans_image_RGB_HWC + _trans_image_Y_CHW = _trans_image_RGB_CHW + def _trans_image_L_HWC(self, image): + return np.ascontiguousarray(np.expand_dims(image, 2)) + def _trans_image_L_CHW(self, image): + return np.ascontiguousarray(np.expand_dims(image, 0)) + + def _trans_back_RGB_HWC(self, image): + return image + def _trans_back_RGB_CHW(self, image): + return np.transpose(image, (1, 2, 0)) + _trans_back_RAW_HWC = _trans_back_RGB_HWC + _trans_back_RAW_CHW = _trans_back_RGB_CHW + _trans_back_Y_HWC = _trans_back_RGB_HWC + _trans_back_Y_CHW = _trans_back_RGB_CHW + def _trans_back_L_HWC(self, image): + return np.squeeze(image, 2) + def _trans_back_L_CHW(self, image): + return np.squeeze(image, 0) + + img_ext = ('png', 'PNG', 'jpg', 'JPG', 'bmp', 'BMP', 'jpeg', 'JPEG') + + def is_image(self, fname): + return any(fname.endswith(i) for i in self.img_ext) + + def read(self, paths): + if isinstance(paths, (list, tuple)): + images = [self._imread(path) for path in paths] + return self._trans_batch(np.array(images)) + return self._trans_image(self._imread(paths)) + + def back(self, image): + return self._trans_back(image) + + def write(self, image, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + self._imwrite(self.back(image), path) + +def read_until_success(func): + @wraps(func) + def wrapper(*args, **kwargs): + for i in range(30): + try: + ret = func(*args, **kwargs) + if ret is None: + raise OSError() + else: + break + except OSError: + print('%s OSError' % str(args)) + time.sleep(1) + return ret + return wrapper + +@read_until_success +def cv2_imread(*args, **kwargs): + return cv2.imread(*args, **kwargs) + +# if __name__ == '__main__': +# import matplotlib.pyplot as plt +# im_rgb_chw_cv2 = imlib('rgb', fmt='chw', lib='cv2') +# im_rgb_hwc_cv2 = imlib('rgb', fmt='hwc', lib='cv2') +# im_rgb_chw_pil = imlib('rgb', fmt='chw', lib='pillow') +# im_rgb_hwc_pil = imlib('rgb', fmt='hwc', lib='pillow') +# im_y_chw_cv2 = imlib('y', fmt='chw', lib='cv2') +# im_y_hwc_cv2 = imlib('y', fmt='hwc', lib='cv2') +# im_y_chw_pil = imlib('y', fmt='chw', lib='pillow') +# im_y_hwc_pil = imlib('y', fmt='hwc', lib='pillow') +# im_l_chw_cv2 = imlib('l', fmt='chw', lib='cv2') +# im_l_hwc_cv2 = imlib('l', fmt='hwc', lib='cv2') +# im_l_chw_pil = imlib('l', fmt='chw', lib='pillow') +# im_l_hwc_pil = imlib('l', fmt='hwc', lib='pillow') +# path = 'D:/Datasets/test/000001.jpg' + +# img_rgb_chw_cv2 = im_rgb_chw_cv2.read(path) +# print(img_rgb_chw_cv2.shape) +# plt.imshow(im_rgb_chw_cv2.back(img_rgb_chw_cv2)) +# plt.show() +# im_rgb_chw_cv2.write(img_rgb_chw_cv2, +# 
(path.replace('000001.jpg', 'img_rgb_chw_cv2.jpg'))) +# img_rgb_hwc_cv2 = im_rgb_hwc_cv2.read(path) +# print(img_rgb_hwc_cv2.shape) +# plt.imshow(im_rgb_hwc_cv2.back(img_rgb_hwc_cv2)) +# plt.show() +# im_rgb_hwc_cv2.write(img_rgb_hwc_cv2, +# (path.replace('000001.jpg', 'img_rgb_hwc_cv2.jpg'))) +# img_rgb_chw_pil = im_rgb_chw_pil.read(path) +# print(img_rgb_chw_pil.shape) +# plt.imshow(im_rgb_chw_pil.back(img_rgb_chw_pil)) +# plt.show() +# im_rgb_chw_pil.write(img_rgb_chw_pil, +# (path.replace('000001.jpg', 'img_rgb_chw_pil.jpg'))) +# img_rgb_hwc_pil = im_rgb_hwc_pil.read(path) +# print(img_rgb_hwc_pil.shape) +# plt.imshow(im_rgb_hwc_pil.back(img_rgb_hwc_pil)) +# plt.show() +# im_rgb_hwc_pil.write(img_rgb_hwc_pil, +# (path.replace('000001.jpg', 'img_rgb_hwc_pil.jpg'))) + + +# img_y_chw_cv2 = im_y_chw_cv2.read(path) +# print(img_y_chw_cv2.shape) +# plt.imshow(np.squeeze(im_y_chw_cv2.back(img_y_chw_cv2))) +# plt.show() +# im_y_chw_cv2.write(img_y_chw_cv2, +# (path.replace('000001.jpg', 'img_y_chw_cv2.jpg'))) +# img_y_hwc_cv2 = im_y_hwc_cv2.read(path) +# print(img_y_hwc_cv2.shape) +# plt.imshow(np.squeeze(im_y_hwc_cv2.back(img_y_hwc_cv2))) +# plt.show() +# im_y_hwc_cv2.write(img_y_hwc_cv2, +# (path.replace('000001.jpg', 'img_y_hwc_cv2.jpg'))) +# img_y_chw_pil = im_y_chw_pil.read(path) +# print(img_y_chw_pil.shape) +# plt.imshow(np.squeeze(im_y_chw_pil.back(img_y_chw_pil))) +# plt.show() +# im_y_chw_pil.write(img_y_chw_pil, +# (path.replace('000001.jpg', 'img_y_chw_pil.jpg'))) +# img_y_hwc_pil = im_y_hwc_pil.read(path) +# print(img_y_hwc_pil.shape) +# plt.imshow(np.squeeze(im_y_hwc_pil.back(img_y_hwc_pil))) +# plt.show() +# im_y_hwc_pil.write(img_y_hwc_pil, +# (path.replace('000001.jpg', 'img_y_hwc_pil.jpg'))) + + +# img_l_chw_cv2 = im_l_chw_cv2.read(path) +# print(img_l_chw_cv2.shape) +# plt.imshow(im_l_chw_cv2.back(img_l_chw_cv2)) +# plt.show() +# im_l_chw_cv2.write(img_l_chw_cv2, +# (path.replace('000001.jpg', 'img_l_chw_cv2.jpg'))) +# img_l_hwc_cv2 = im_l_hwc_cv2.read(path) +# print(img_l_hwc_cv2.shape) +# plt.imshow(im_l_hwc_cv2.back(img_l_hwc_cv2)) +# plt.show() +# im_l_hwc_cv2.write(img_l_hwc_cv2, +# (path.replace('000001.jpg', 'img_l_hwc_cv2.jpg'))) +# img_l_chw_pil = im_l_chw_pil.read(path) +# print(img_l_chw_pil.shape) +# plt.imshow(im_l_chw_pil.back(img_l_chw_pil)) +# plt.show() +# im_l_chw_pil.write(img_l_chw_pil, +# (path.replace('000001.jpg', 'img_l_chw_pil.jpg'))) +# img_l_hwc_pil = im_l_hwc_pil.read(path) +# print(img_l_hwc_pil.shape) +# plt.imshow(im_l_hwc_pil.back(img_l_hwc_pil)) +# plt.show() +# im_l_hwc_pil.write(img_l_hwc_pil, +# (path.replace('000001.jpg', 'img_l_hwc_pil.jpg'))) diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/p20_dataset.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/p20_dataset.py new file mode 100644 index 0000000..def4478 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/p20_dataset.py @@ -0,0 +1,113 @@ +import numpy as np +import os +from data.base_dataset import BaseDataset +from .imlib import imlib +from multiprocessing.dummy import Pool +from tqdm import tqdm +from util.util import augment, remove_black_level, get_coord +from util.util import extract_bayer_channels, get_raw_demosaic, load_img +import glob +# from skimage.exposure import match_histograms + + +# Zurich RAW to RGB (ZRR) dataset +class P20Dataset(BaseDataset): + def __init__(self, opt, split='train', dataset_name='ZRR'): + super(P20Dataset, self).__init__(opt, split, dataset_name) + + + self.batch_size = opt.batch_size + + if split == 
'train': + self.root_dir = os.path.join(self.root,'train'); + self.train_raws = sorted(glob.glob(os.path.join( self.root_dir, '*.npy'))) + self.train_rgbs = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))) + self.names = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))); + self._getitem = self._getitem_train + + elif split == 'val': + self.root_dir = os.path.join(self.root, 'test') + self.test_raws = sorted(glob.glob(os.path.join( self.root_dir, '*.npy'))) + self.test_rgbs = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))) + self.names = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))); + self._getitem = self._getitem_val + + elif split == 'test': + self.root_dir = os.path.join(self.root, 'test') + self.test_rgbs = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))) + self.names = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))); + self._getitem = self._getitem_test + + else: + raise ValueError + + self.len_data = len(self.names) + + + def __getitem__(self, index): + return self._getitem(index) + + def __len__(self): + return self.len_data + + def _getitem_train(self, idx): + + raw = np.load(self.train_raws[idx], encoding='bytes', allow_pickle=True); + raw_combined, raw_demosaic = self._process_raw(raw) + dslr_image = load_img(self.train_rgbs[idx]) + dslr_image = np.ascontiguousarray(dslr_image.transpose((2, 0, 1))) + + raw_combined, raw_demosaic, dslr_image = augment( + raw_combined, raw_demosaic, dslr_image) + + return {'raw': raw_combined, + 'raw_demosaic': raw_demosaic, + 'dslr': dslr_image, + 'fname': self.names[idx]} + + def _getitem_val(self, idx): + raw = np.load(self.test_raws[idx], encoding='bytes', allow_pickle=True); + raw_combined, raw_demosaic = self._process_raw(raw) + dslr_image = load_img(self.test_rgbs[idx]) + dslr_image = dslr_image.transpose((2, 0, 1)) + + return {'raw': raw_combined, + 'raw_demosaic': raw_demosaic, + 'dslr': dslr_image, + 'fname': self.names[idx]} + + def _getitem_test(self, idx): + dslr_image = load_img(self.test_rgbs[idx]) + dslr_image = dslr_image.transpose((2, 0, 1)) + + return { + 'dslr': dslr_image, + 'fname': self.names[idx]} + + + def _process_raw(self, raw): + raw = remove_black_level(raw,white_lv=4*255) + raw_combined = extract_bayer_channels(raw) + raw_demosaic = get_raw_demosaic(raw) + return raw_combined, raw_demosaic + +def iter_obj(num, objs): + for i in range(num): + yield (i, objs) + +def imreader(arg): + # Due to the memory (32 GB) limitation, here we only preload the raw images. + # If you have enough memory, you can also modify the code to preload the sRGB images to speed up the training process. + i, obj = arg + for _ in range(3): + try: + obj.raw_images[i] = obj.raw_imio.read(os.path.join(obj.raw_dir, obj.names[i] + '.png')) + failed = False + break + except: + failed = True + if failed: print('%s fails!' 
% obj.names[i]) + + +if __name__ == '__main__': + pass diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/p20patch_dataset.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/p20patch_dataset.py new file mode 100644 index 0000000..d2641a0 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/p20patch_dataset.py @@ -0,0 +1,115 @@ +import numpy as np +import os +from data.base_dataset import BaseDataset +from util.util import augment, remove_black_level +from util.util import extract_bayer_channels, get_raw_demosaic, load_img +import glob +import random + + +# Zurich RAW to RGB (ZRR) dataset +class P20patchDataset(BaseDataset): + def __init__(self, opt, split='train', dataset_name='ZRR'): + super(P20patchDataset, self).__init__(opt, split, dataset_name) + + + self.batch_size = opt.batch_size + + if split == 'train': + self.root_dir = os.path.join(self.root,'train_full'); + self.train_raws = sorted(glob.glob(os.path.join( self.root_dir, '*.npy'))) + self.train_rgbs = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))) + self.train_raws_file=[] + self.train_rgbs_file=[] + self.patch = opt.patch_size + for seq_path in self.train_raws: + seq = np.load(seq_path, encoding='bytes', allow_pickle=True); + self.train_raws_file.append(seq) + + for seq_path in self.train_rgbs: + seq = load_img(seq_path) + self.train_rgbs_file.append(seq) + + self.names = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))); + self._getitem = self._getitem_train + self.len_data = len(self.names)*48 + + elif split == 'val': + self.root_dir = os.path.join(self.root, 'test_full') + self.patch = opt.patch_size + self.test_raws = sorted(glob.glob(os.path.join( self.root_dir, '*.npy'))) + self.test_rgbs = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))) + self.names = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))); + self._getitem = self._getitem_val + self.len_data = len(self.names) + + elif split == 'test': + self.root_dir = os.path.join(self.root) + self.test_rgbs = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))) + self.names = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))); + self._getitem = self._getitem_test + self.len_data = len(self.names) + + else: + raise ValueError + + + + + def __getitem__(self, index): + return self._getitem(index) + + def __len__(self): + return self.len_data + + def _getitem_train(self, idx): + idx = idx % (self.len_data//48) + H,W,C = self.train_raws_file[idx].shape + crop_h = random.randrange(0,H - self.patch ,2) + crop_w = random.randrange(0,W - self.patch ,2) + raw = self.train_raws_file[idx][crop_h:crop_h+self.patch , crop_w:crop_w+self.patch,:] + dslr_image = self.train_rgbs_file[idx][2*crop_h:2*crop_h+2*self.patch, 2*crop_w:2*crop_w+2*self.patch,:] + raw_combined, raw_demosaic = self._process_raw(raw) + dslr_image = dslr_image.transpose((2, 0, 1)) + raw_combined, raw_demosaic, dslr_image = augment( + raw_combined, raw_demosaic, dslr_image) + + return {'raw': raw_combined, + 'raw_demosaic': raw_demosaic, + 'dslr': dslr_image, + 'fname': self.names[idx]} + + def _getitem_val(self, idx): + raw_init = np.load(self.test_raws[idx], encoding='bytes', allow_pickle=True); + raw = raw_init[0:0+self.patch , 0:0+self.patch,:] + raw_combined, raw_demosaic = self._process_raw(raw) + dslr_image_init = load_img(self.test_rgbs[idx]) + dslr_image = dslr_image_init[0:0+2*self.patch , 0:0 + 2*self.patch,:] + dslr_image = dslr_image.transpose((2, 0, 1)) + + return {'raw': raw_combined, + 'raw_demosaic': raw_demosaic, + 'dslr': 
dslr_image,
+                'fname': self.names[idx]}
+
+    def _getitem_test(self, idx):
+        dslr_image = load_img(self.test_rgbs[idx])
+        dslr_image = dslr_image.transpose((2, 0, 1))
+
+        return {
+            'dslr': dslr_image,
+            'fname': self.names[idx]}
+
+
+    def _process_raw(self, raw):
+        raw = remove_black_level(raw, white_lv=4*255)
+        raw_combined = extract_bayer_channels(raw)
+        raw_demosaic = get_raw_demosaic(raw)
+        return raw_combined, raw_demosaic
+
+
+
+if __name__ == '__main__':
+    pass
+
+
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/__init__.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/__init__.py
new file mode 100644
index 0000000..77aa38a
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/__init__.py
@@ -0,0 +1,47 @@
+import importlib
+from models.base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+    """Import the module "models/[model_name]_model.py".
+
+    In the file, the class called [ModelName]Model() will
+    be instantiated. It has to be a subclass of BaseModel,
+    and it is case-insensitive.
+    """
+    model_filename = "models." + model_name + "_model"
+    modellib = importlib.import_module(model_filename)
+    model = None
+    target_model_name = model_name.replace('_', '') + 'model'
+    for name, cls in modellib.__dict__.items():
+        if name.lower() == target_model_name.lower() \
+           and issubclass(cls, BaseModel):
+            model = cls
+
+    if model is None:
+        raise NotImplementedError("In %s.py, there should be a subclass of "
+                                  "BaseModel with class name that matches %s in "
+                                  "lowercase." % (model_filename, target_model_name))
+
+    return model
+
+
+def get_option_setter(model_name):
+    model_class = find_model_using_name(model_name)
+    return model_class.modify_commandline_options
+
+
+def create_model(opt):
+    """Create a model given the option.
+
+    This function wraps the model class found by find_model_using_name().
+    This is the main interface between this package and 'train.py'/'test.py'
+
+    Example:
+        >>> from models import create_model
+        >>> model = create_model(opt)
+    """
+    model = find_model_using_name(opt.model)
+    instance = model(opt)
+    print("model [%s] was created" % type(instance).__name__)
+    return instance
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/arch_util.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/arch_util.py
new file mode 100644
index 0000000..8e005cf
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/arch_util.py
@@ -0,0 +1,350 @@
+# ------------------------------------------------------------------------
+# Copyright (c) 2022 megvii-model. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Modified from BasicSR (https://github.com/xinntao/BasicSR)
+# Copyright 2018-2020 BasicSR Authors
+# ------------------------------------------------------------------------
+import math
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+
+
+# try:
+#     from basicsr.models.ops.dcn import (ModulatedDeformConvPack,
+#                                         modulated_deform_conv)
+# except ImportError:
+#     # print('Cannot import dcn. Ignore this warning if dcn is not used. '
+#     #       'Otherwise install BasicSR with compiling dcn.')
+#
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+    """Initialize network weights.
+
+    Args:
+        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+ scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' + 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +def flow_warp(x, + flow, + interp_mode='bilinear', + padding_mode='zeros', + align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. 
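+
+    Example (a minimal usage sketch; with zero flow and align_corners=True
+    the warp reduces to an identity, so the output matches the input):
+        >>> x = torch.randn(1, 8, 32, 32)
+        >>> flow = torch.zeros(1, 32, 32, 2)
+        >>> out = flow_warp(x, flow)  # equals x up to interpolation error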
+ """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid( + torch.arange(0, h).type_as(x), + torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample( + x, + vgrid_scaled, + mode=interp_mode, + padding_mode=padding_mode, + align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, + size_type, + sizes, + interp_mode='bilinear', + align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError( + f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, + size=(output_h, output_w), + mode=interp_mode, + align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +# class DCNv2Pack(ModulatedDeformConvPack): +# """Modulated deformable conv for deformable alignment. +# +# Different from the official DCNv2Pack, which generates offsets and masks +# from the preceding features, this DCNv2Pack takes another different +# features to generate offsets and masks. +# +# Ref: +# Delving Deep into Deformable Alignment in Video Super-Resolution. 
+# """ +# +# def forward(self, x, feat): +# out = self.conv_offset(feat) +# o1, o2, mask = torch.chunk(out, 3, dim=1) +# offset = torch.cat((o1, o2), dim=1) +# mask = torch.sigmoid(mask) +# +# offset_absmean = torch.mean(torch.abs(offset)) +# if offset_absmean > 50: +# logger = get_root_logger() +# logger.warning( +# f'Offset abs mean is {offset_absmean}, larger than 50.') +# +# return modulated_deform_conv(x, offset, mask, self.weight, self.bias, +# self.stride, self.padding, self.dilation, +# self.groups, self.deformable_groups) + + +class LayerNormFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( + dim=0), None + +class LayerNorm2d(nn.Module): + + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter('weight', nn.Parameter(torch.ones(channels))) + self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + +# handle multiple input +class MySequential(nn.Sequential): + def forward(self, *inputs): + for module in self._modules.values(): + if type(inputs) == tuple: + inputs = module(*inputs) + else: + inputs = module(inputs) + return inputs + +import time +def measure_inference_speed(model, data, max_iter=200, log_interval=50): + model.eval() + + # the first several iterations may be very slow so skip them + num_warmup = 5 + pure_inf_time = 0 + fps = 0 + + # benchmark with 2000 image and take the average + for i in range(max_iter): + + torch.cuda.synchronize() + start_time = time.perf_counter() + + with torch.no_grad(): + model(*data) + + torch.cuda.synchronize() + elapsed = time.perf_counter() - start_time + + if i >= num_warmup: + pure_inf_time += elapsed + if (i + 1) % log_interval == 0: + fps = (i + 1 - num_warmup) / pure_inf_time + print( + f'Done image [{i + 1:<3}/ {max_iter}], ' + f'fps: {fps:.1f} img / s, ' + f'times per image: {1000 / fps:.1f} ms / img', + flush=True) + + if (i + 1) == max_iter: + fps = (i + 1 - num_warmup) / pure_inf_time + print( + f'Overall fps: {fps:.1f} img / s, ' + f'times per image: {1000 / fps:.1f} ms / img', + flush=True) + break + return fps \ No newline at end of file diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/base_model.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/base_model.py new file mode 100644 index 0000000..aceb65a --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/base_model.py @@ -0,0 +1,293 @@ +import os +import torch +from collections import OrderedDict +from abc import ABC, abstractmethod +from . 
import networks +import torch +from util.util import torch_save +import math +import torch.nn.functional as F + +def calc_psnr(sr, hr, range=1.): + # shave = 2 + with torch.no_grad(): + diff = (sr - hr) / range + mse = torch.pow(diff, 2) + mse= torch.mean(mse,dim=1,keepdim=True) + return (-10 * torch.log10(mse)) + +class BaseModel(ABC): + def __init__(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.scale = opt.scale + + if len(self.gpu_ids) > 0: + self.device = torch.device('cuda', self.gpu_ids[0]) + else: + self.device = torch.device('cpu') + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.optimizers = [] + self.optimizer_names = [] + self.image_paths = [] + self.metric = 0 # used for learning rate policy 'plateau' + self.start_epoch = 0 + + self.backwarp_tenGrid = {} + self.backwarp_tenPartial = {} + + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + @abstractmethod + def set_input(self, input): + pass + + @abstractmethod + def forward(self): + pass + + @abstractmethod + def optimize_parameters(self): + pass + + def setup(self, opt=None): + opt = opt if opt is not None else self.opt + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) \ + for optimizer in self.optimizers] + for scheduler in self.schedulers: + scheduler.last_epoch = opt.load_iter + if opt.load_iter > 0 or opt.load_path != '': + load_suffix = opt.load_iter + self.load_networks(load_suffix) + if opt.load_optimizers: + self.load_optimizers(opt.load_iter) + + self.print_networks(opt.verbose) + + def eval(self): + for name in self.model_names: + net = getattr(self, 'net' + name) + net.eval() + + def train(self): + for name in self.model_names: + net = getattr(self, 'net' + name) + net.train() + + def test(self): + with torch.no_grad(): + self.forward() + + def get_image_paths(self): + return self.image_paths + + def update_learning_rate(self): + for i, scheduler in enumerate(self.schedulers): + if scheduler.__class__.__name__ == 'ReduceLROnPlateau': + scheduler.step(self.metric) + else: + scheduler.step() + print('lr of %s = %.7f' % ( + self.optimizer_names[i], scheduler.get_last_lr()[0])) + + def get_current_visuals(self): + visual_ret = OrderedDict() + for name in self.visual_names: + if 'xy' in name or 'coord' in name: + visual_ret[name] = getattr(self, name).detach() + else: + visual_ret[name] = torch.clamp( + getattr(self, name).detach(), 0., 1.) 
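+        # 'xy'/'coord' entries hold coordinates rather than image data, so
+        # they are returned unclamped; image tensors are clamped to [0, 1].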
+ return visual_ret + + def get_current_losses(self): + errors_ret = OrderedDict() + for name in self.loss_names: + errors_ret[name] = float(getattr(self, 'loss_' + name)) + return errors_ret + + def save_networks(self, epoch): + for name in self.model_names: + save_filename = '%s_model_%d.pth' % (name, epoch) + save_path = os.path.join(self.save_dir, save_filename) + net = getattr(self, 'net' + name) + if self.device.type == 'cuda': + state = {'state_dict': net.module.cpu().state_dict()} + torch_save(state, save_path) + net.to(self.device) + else: + state = {'state_dict': net.state_dict()} + torch_save(state, save_path) + self.save_optimizers(epoch) + + def load_networks(self, epoch): +# self.model_names.append('GCMModel') + for name in self.model_names: #[0:1]: + # if name is 'Discriminator': + # continue + load_filename = '%s_model_%d.pth' % (name, epoch) +# if name=='GCMModel': +# load_filename = '%s_model_%d.pth' % (name, 1) + if self.opt.load_path != '': + load_path = self.opt.load_path + else: + load_path = os.path.join(self.save_dir, load_filename) + print(name,load_path) + net = getattr(self, 'net' + name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + state_dict = torch.load(load_path, map_location=self.device) + print('loading the model from %s' % (load_path)) + if hasattr(state_dict, '_metadata'): + del state_dict._metadata + + net_state = net.state_dict() + is_loaded = {n:False for n in net_state.keys()} + for name, param in state_dict['state_dict'].items(): + if name in net_state: + try: + net_state[name].copy_(param) + is_loaded[name] = True + except Exception: + print('While copying the parameter named [%s], ' + 'whose dimensions in the model are %s and ' + 'whose dimensions in the checkpoint are %s.' + % (name, list(net_state[name].shape), + list(param.shape))) + raise RuntimeError + else: + print('Saved parameter named [%s] is skipped' % name) + mark = True + for name in is_loaded: + if not is_loaded[name]: + print('Parameter named [%s] is randomly initialized' % name) + mark = False + if mark: + print('All parameters are initialized using [%s]' % load_path) + + self.start_epoch = epoch + + def save_optimizers(self, epoch): + assert len(self.optimizers) == len(self.optimizer_names) + for id, optimizer in enumerate(self.optimizers): + save_filename = self.optimizer_names[id] + state = {'name': save_filename, + 'epoch': epoch, + 'state_dict': optimizer.state_dict()} + save_path = os.path.join(self.save_dir, save_filename+'.pth') + torch_save(state, save_path) + + def load_optimizers(self, epoch): + assert len(self.optimizers) == len(self.optimizer_names) + for id, optimizer in enumerate(self.optimizer_names): + load_filename = self.optimizer_names[id] + load_path = os.path.join(self.save_dir, load_filename+'.pth') + print('loading the optimizer from %s' % load_path) + state_dict = torch.load(load_path) + print(state_dict['epoch']) + assert optimizer == state_dict['name'] + assert epoch == state_dict['epoch'] + self.optimizers[id].load_state_dict(state_dict['state_dict']) + + def print_networks(self, verbose): + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' + % (name, num_params / 1e6)) + print('-----------------------------------------------') + + def set_requires_grad(self, nets, 
requires_grad=False): + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + def estimate(self, tenFirst, tenSecond, net): + assert(tenFirst.shape[3] == tenSecond.shape[3]) + assert(tenFirst.shape[2] == tenSecond.shape[2]) + intWidth = tenFirst.shape[3] + intHeight = tenFirst.shape[2] + # tenPreprocessedFirst = tenFirst.view(1, 3, intHeight, intWidth) + # tenPreprocessedSecond = tenSecond.view(1, 3, intHeight, intWidth) + + intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0)) + intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0)) + + tenPreprocessedFirst = F.interpolate(input=tenFirst, + size=(intPreprocessedHeight, intPreprocessedWidth), + mode='bilinear', align_corners=False) + tenPreprocessedSecond = F.interpolate(input=tenSecond, + size=(intPreprocessedHeight, intPreprocessedWidth), + mode='bilinear', align_corners=False) + + tenFlow = 20.0 * F.interpolate( + input=net(tenPreprocessedFirst, tenPreprocessedSecond), + size=(intHeight, intWidth), mode='bilinear', align_corners=False) + + tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) + tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) + + return tenFlow[:, :, :, :] + + def backwarp(self, tenInput, tenFlow): + index = str(tenFlow.shape) + str(tenInput.device) + if index not in self.backwarp_tenGrid: + tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), + tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1) + tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), + tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3]) + self.backwarp_tenGrid[index] = torch.cat([tenHor, tenVer], 1).to(tenInput.device) + + if index not in self.backwarp_tenPartial: + self.backwarp_tenPartial[index] = tenFlow.new_ones([ + tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3]]) + + tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) + tenInput = torch.cat([tenInput, self.backwarp_tenPartial[index]], 1) + + tenOutput = F.grid_sample(input=tenInput, + grid=(self.backwarp_tenGrid[index] + tenFlow).permute(0, 2, 3, 1), + mode='bilinear', padding_mode='zeros', align_corners=False) + + return tenOutput + + def get_backwarp(self, tenFirst, tenSecond,raw, net, flow=None): + if flow is None: + flow = self.get_flow(tenFirst, tenSecond, net) + + flow_raw = F.interpolate(flow, scale_factor=0.5)/2. 
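+        # The RAW tensor is at half the sRGB resolution, so the flow is
+        # spatially downsampled by 2 and its magnitude halved before it is
+        # used to warp the RAW branch below.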
+ tenoutput = self.backwarp(tenSecond, flow) + rgb_tenMask = tenoutput[:, -1:, :, :] + rgb_tenMask[rgb_tenMask > 0.999] = 1.0 + rgb_tenMask[rgb_tenMask < 1.0] = 0.0 + + rawoutput = self.backwarp(raw, flow_raw) + raw_tenMask = rawoutput[:, -1:, :, :] + raw_tenMask[raw_tenMask > 0.999] = 1.0 + raw_tenMask[raw_tenMask < 1.0] = 0.0 + d=tenoutput[:, :-1, :, :] * rgb_tenMask + return tenoutput[:, :-1, :, :] * rgb_tenMask, rgb_tenMask,rawoutput[:, :-1, :, :] * raw_tenMask, raw_tenMask + + + def get_flow(self, tenFirst, tenSecond, net): + with torch.no_grad(): + net.eval() + flow = self.estimate(tenFirst, tenSecond, net) + return flow \ No newline at end of file diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/blocks.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/blocks.py new file mode 100644 index 0000000..fb924f7 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/blocks.py @@ -0,0 +1,281 @@ +""" +Code copy from container source code: +https://github.com/allenai/container/blob/main/models.py +""" +import os +import torch +import torch.nn as nn +from functools import partial +import math +from timm.models.vision_transformer import VisionTransformer, _cfg +from timm.models.registry import register_model +from timm.models.layers import trunc_normal_, DropPath, to_2tuple + +# ResMLP's normalization +class Aff(nn.Module): + def __init__(self, dim): + super().__init__() + # learnable + self.alpha = nn.Parameter(torch.ones([1, 1, dim])) + self.beta = nn.Parameter(torch.zeros([1, 1, dim])) + + def forward(self, x): + x = x * self.alpha + self.beta + return x + +# Color Normalization +class Aff_channel(nn.Module): + def __init__(self, dim, channel_first = True): + super().__init__() + # learnable + self.alpha = nn.Parameter(torch.ones([1, 1, dim])) + self.beta = nn.Parameter(torch.zeros([1, 1, dim])) + self.color = nn.Parameter(torch.eye(dim)) + self.channel_first = channel_first + + def forward(self, x): + if self.channel_first: + x1 = torch.tensordot(x, self.color, dims=[[-1], [-1]]) + x2 = x1 * self.alpha + self.beta + else: + x1 = x * self.alpha + self.beta + x2 = torch.tensordot(x1, self.color, dims=[[-1], [-1]]) + return x2 + +class Mlp(nn.Module): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class CMlp(nn.Module): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class CBlock_ln(nn.Module): + def 
__init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=Aff_channel, init_values=1e-4): + super().__init__() + self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) + #self.norm1 = Aff_channel(dim) + self.norm1 = norm_layer(dim) + self.conv1 = nn.Conv2d(dim, dim, 1) + self.conv2 = nn.Conv2d(dim, dim, 1) + self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + #self.norm2 = Aff_channel(dim) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.gamma_1 = nn.Parameter(init_values * torch.ones((1, dim, 1, 1)), requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((1, dim, 1, 1)), requires_grad=True) + self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.pos_embed(x) + B, C, H, W = x.shape + #print(x.shape) + norm_x = x.flatten(2).transpose(1, 2) + #print(norm_x.shape) + norm_x = self.norm1(norm_x) + norm_x = norm_x.view(B, H, W, C).permute(0, 3, 1, 2) + + + x = x + self.drop_path(self.gamma_1*self.conv2(self.attn(self.conv1(norm_x)))) + norm_x = x.flatten(2).transpose(1, 2) + norm_x = self.norm2(norm_x) + norm_x = norm_x.view(B, H, W, C).permute(0, 3, 1, 2) + x = x + self.drop_path(self.gamma_2*self.mlp(norm_x)) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + #print(x.shape) + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x): + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +## Layer_norm, Aff_norm, Aff_channel_norm +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads=2, window_size=8, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=Aff_channel): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + + self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) + #self.norm1 = norm_layer(dim) + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + #self.norm2 = norm_layer(dim) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.pos_embed(x) + B, C, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x.transpose(1, 2).reshape(B, C, H, W) + + return x + + +if __name__ == "__main__": + os.environ['CUDA_VISIBLE_DEVICES']='1' + cb_blovk = CBlock_ln(dim = 16) + x = torch.Tensor(1, 16, 400, 600) + swin = SwinTransformerBlock(dim=16, num_heads=4) + x = cb_blovk(x) + print(x.shape) diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/global_net.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/global_net.py new file mode 100644 index 0000000..d512732 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/global_net.py @@ -0,0 +1,129 @@ +import imp +import torch +import torch.nn as nn +from timm.models.layers import trunc_normal_, DropPath, to_2tuple +import os +from models.blocks import Mlp + + +class query_Attention(nn.Module): + def __init__(self, dim, num_heads=2, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.q = nn.Parameter(torch.ones((1, 17, dim)), requires_grad=True) + self.k = nn.Linear(dim, dim, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + q = self.q.expand(B, -1, -1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, 17, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class query_SABlock(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) + self.norm1 = norm_layer(dim) + self.attn = query_Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, 
proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.pos_embed(x)
+        x = x.flatten(2).transpose(1, 2)
+        # the query attention maps the N input tokens to 17 learned query
+        # tokens, so no residual connection is possible here
+        x = self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class conv_embedding(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(conv_embedding, self).__init__()
+        self.proj = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+            nn.BatchNorm2d(out_channels // 2),
+            nn.GELU(),
+            # nn.Conv2d(out_channels // 2, out_channels // 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
+            # nn.BatchNorm2d(out_channels // 2),
+            # nn.GELU(),
+            nn.Conv2d(out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
+            nn.BatchNorm2d(out_channels),
+        )
+
+    def forward(self, x):
+        x = self.proj(x)
+        return x
+
+
+class Global_pred(nn.Module):
+    def __init__(self, in_channels=3, out_channels=64, num_heads=4, type='exp'):
+        super(Global_pred, self).__init__()
+        if type == 'exp':
+            self.gamma_base = nn.Parameter(torch.ones((1)), requires_grad=False)  # False in exposure correction
+        else:
+            self.gamma_base = nn.Parameter(torch.ones((1)), requires_grad=True)
+        self.color_base = nn.Parameter(torch.eye((4)), requires_grad=True)  # basic color matrix
+        # main blocks
+        self.conv_large = conv_embedding(in_channels, out_channels)
+        self.generator = query_SABlock(dim=out_channels, num_heads=num_heads)
+        self.gamma_linear = nn.Linear(out_channels, 1)
+        self.color_linear = nn.Linear(out_channels, 1)
+
+        self.apply(self._init_weights)
+
+        for name, p in self.named_parameters():
+            if name == 'generator.attn.v.weight':
+                nn.init.constant_(p, 0)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+
+    def forward(self, x):
+        x = self.conv_large(x)
+        x = self.generator(x)
+        gamma, color = x[:, 0].unsqueeze(1), x[:, 1:]
+        gamma = self.gamma_linear(gamma).squeeze(-1) + self.gamma_base
+        color = self.color_linear(color).squeeze(-1).view(-1, 4, 4) + self.color_base
+        return gamma, color
+
+if __name__ == "__main__":
+    os.environ['CUDA_VISIBLE_DEVICES'] = '3'
+    # net = Local_pred_new().cuda()
+    img = torch.Tensor(8, 3, 400, 600)
+    global_net = Global_pred()
+    gamma, color = global_net(img)
+    print(gamma.shape, color.shape)
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/local_arch.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/local_arch.py
new file mode 100644
index 0000000..bc459c6
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/local_arch.py
@@ -0,0 +1,104 @@
+# ------------------------------------------------------------------------
+# Copyright (c) 2022 megvii-model. All Rights Reserved.
+# ------------------------------------------------------------------------ + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +class AvgPool2d(nn.Module): + def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None): + super().__init__() + self.kernel_size = kernel_size + self.base_size = base_size + self.auto_pad = auto_pad + + # only used for fast implementation + self.fast_imp = fast_imp + self.rs = [5, 4, 3, 2, 1] + self.max_r1 = self.rs[0] + self.max_r2 = self.rs[0] + self.train_size = train_size + + def extra_repr(self) -> str: + return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format( + self.kernel_size, self.base_size, self.kernel_size, self.fast_imp + ) + + def forward(self, x): + if self.kernel_size is None and self.base_size: + train_size = self.train_size + if isinstance(self.base_size, int): + self.base_size = (self.base_size, self.base_size) + self.kernel_size = list(self.base_size) + self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2] + self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1] + + # only used for fast implementation + self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2]) + self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1]) + + if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1): + return F.adaptive_avg_pool2d(x, 1) + + if self.fast_imp: # Non-equivalent implementation but faster + h, w = x.shape[2:] + if self.kernel_size[0] >= h and self.kernel_size[1] >= w: + out = F.adaptive_avg_pool2d(x, 1) + else: + r1 = [r for r in self.rs if h % r == 0][0] + r2 = [r for r in self.rs if w % r == 0][0] + # reduction_constraint + r1 = min(self.max_r1, r1) + r2 = min(self.max_r2, r2) + s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2) + n, c, h, w = s.shape + k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2) + out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2) + out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2)) + else: + n, c, h, w = x.shape + s = x.cumsum(dim=-1).cumsum_(dim=-2) + s = torch.nn.functional.pad(s, (1, 0, 1, 0)) # pad 0 for convenience + k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1]) + s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:] + out = s4 + s1 - s2 - s3 + out = out / (k1 * k2) + + if self.auto_pad: + n, c, h, w = x.shape + _h, _w = out.shape[2:] + # print(x.shape, self.kernel_size) + pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2) + out = torch.nn.functional.pad(out, pad2d, mode='replicate') + + return out + +def replace_layers(model, base_size, train_size, fast_imp, **kwargs): + for n, m in model.named_children(): + if len(list(m.children())) > 0: + ## compound module, go inside it + replace_layers(m, base_size, train_size, fast_imp, **kwargs) + + if isinstance(m, nn.AdaptiveAvgPool2d): + pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size) + assert m.output_size == 1 + setattr(model, n, pool) + + +''' +ref. 
+@article{chu2021tlsc, + title={Revisiting Global Statistics Aggregation for Improving Image Restoration}, + author={Chu, Xiaojie and Chen, Liangyu and and Chen, Chengpeng and Lu, Xin}, + journal={arXiv preprint arXiv:2112.04491}, + year={2021} +} +''' +class Local_Base(): + def convert(self, *args, train_size, **kwargs): + replace_layers(self, *args, train_size=train_size, **kwargs) + imgs = torch.rand(train_size) + with torch.no_grad(): + self.forward(imgs) diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/losses.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/losses.py new file mode 100644 index 0000000..6e33447 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/losses.py @@ -0,0 +1,345 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp +from torch.nn import L1Loss, MSELoss +import numpy as np + +class CannyNet(nn.Module): + def __init__(self): + super(CannyNet, self).__init__() + self.pad = nn.ReflectionPad2d(1) + self.conv1 = nn.Conv2d(4, 4, 3, padding=(0, 0), bias=False) + def forward(self, x): + b,c,h,w = x.size() + x = self.conv1(self.pad(x)) + return x + +class Canny(nn.Module): + def __init__(self): + super(Canny, self).__init__() + self.net = CannyNet().cuda() + self.conv_rgb_core_original = [ + [[0, 0, 0], [0, 1, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 1, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 1, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 1, 0], [0, 0, 0] + ]] + self.conv_rgb_core_sobel = [ + [[-1, -1, -1], [-1, 8, -1], [-1, -1, -1], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [-1, -1, -1], [-1, 8, -1], [-1, -1, -1], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [-1, -1, -1], [-1, 8, -1], [-1, -1, -1], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [-1, -1, -1], [-1, 8, -1], [-1, -1, -1], + ]] + self.conv_rgb_core_sobel_vertical = [ + [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [-1, 0, 1], [-2, 0, 2], [-1, 0, 1], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [-1, 0, 1], [-2, 0, 2], [-1, 0, 1], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [-1, 0, 1], [-2, 0, 2], [-1, 0, 1], + ]] + self.conv_rgb_core_sobel_horizontal = [ + [[1, 2, 1], [0, 0, 0], [-1, -2, -1], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [1, 2, 1], [0, 0, 0], [-1, -2, -1], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], 
[0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [1, 2, 1], [0, 0, 0], [-1, -2, -1], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [1, 2, 1], [0, 0, 0], [-1, -2, -1], + ]] + + def sobel(self, net, kernel): + sobel_kernel = np.array(kernel, dtype='float32') + sobel_kernel = sobel_kernel.reshape((4, 4, 3, 3)) + net.conv1.weight.data = torch.from_numpy(sobel_kernel).cuda() + + def forward(self, x): + # x = x*2-1 #to [-1,1] + # self.sobel(self.net, self.conv_rgb_core_sobel) + # out = self.net(x).detach() + self.sobel(self.net, self.conv_rgb_core_sobel_vertical) + out_v = self.net(x).detach() + self.sobel(self.net, self.conv_rgb_core_sobel_horizontal) + out_h = self.net(x).detach() + out = torch.sqrt(torch.square(out_h)+torch.square(out_v)) + # out = torch.abs((out+1)/2.) + return out + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp( + -(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) \ + for x in range(window_size)]) + return gauss / gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand( + channel, 1, window_size, window_size).contiguous()) + return window + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d( + img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d( + img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d( + img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ + ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + +def ssim(img1, img2, window_size=11, size_average=True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) + +class SSIMLoss(torch.nn.Module): + def __init__(self, window_size=11, size_average=True): + super(SSIMLoss, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and \ + self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return _ssim(img1, img2, window, self.window_size, + channel, self.size_average) + +CONTENT_LAYER = 'relu_16' +cfgs = { + 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 
512, 512, 512, 'M', 512, 512, 512, 'M'], + 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + +def make_layers(cfg, batch_norm=False): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + +class VGG(nn.Module): + def __init__(self, num_classes=1000): + super(VGG, self).__init__() + self.features = make_layers(cfgs['E'], batch_norm=False) + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, num_classes), + ) + self.load_state_dict(torch.load('./ckpt/vgg19.pth')) + + def forward(self, x): + x = self.features(x) + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x + +def vgg_19(): + vgg_19 = VGG().features + model = nn.Sequential() + + i = 0 + for layer in vgg_19.children(): + if isinstance(layer, nn.Conv2d): + i += 1 + name = 'conv_{}'.format(i) + elif isinstance(layer, nn.ReLU): + name = 'relu_{}'.format(i) + layer = nn.ReLU(inplace=False) + elif isinstance(layer, nn.MaxPool2d): + name = 'pool_{}'.format(i) + elif isinstance(layer, nn.BatchNorm2d): + name = 'bn_{}'.format(i) + else: + raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__)) + + model.add_module(name, layer) + if name == CONTENT_LAYER: + break + + for param in model.parameters(): + param.requires_grad = False + + for param in vgg_19.parameters(): + param.requires_grad = False + + return model + +def normalize_batch(batch): + mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) + std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) + return (batch - mean) / std + +class VGGLoss(torch.nn.Module): + def __init__(self): + super(VGGLoss, self).__init__() + self.VGG_19 = vgg_19() + self.L1_loss = torch.nn.L1Loss() + + def forward(self, img1, img2): + img1 = F.interpolate(img1, scale_factor=0.5, mode="bilinear") + img2 = F.interpolate(img2, scale_factor=0.5, mode="bilinear") + img1_vgg = self.VGG_19(normalize_batch(img1)) + img2_vgg = self.VGG_19(normalize_batch(img2)) + loss_vgg = self.L1_loss(img1_vgg, img2_vgg) + return loss_vgg + + +class FFTLoss(nn.Module): + def __init__(self): + super().__init__() + self.canny = Canny() + self.criterion = torch.nn.L1Loss() +# self.loss_weight = loss_weight + + def forward(self, pred, target): + Edge = self.canny(target) + # pred_fft = torch.fft.fft2(pred, dim=(-2, -1)) + # target_fft = torch.fft.fft2(target, dim=(-2, -1)) + return self.criterion(pred*Edge, target*Edge) + +class GANLoss(nn.Module): + def __init__(self, gan_mode='lsgan', target_real_label=1.0, target_fake_label=0.0): + super(GANLoss, self).__init__() + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + self.gan_mode = gan_mode + if gan_mode == 'lsgan': + self.loss = nn.MSELoss() + elif gan_mode == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif gan_mode in ['wgangp']: + self.loss = None + else: + raise NotImplementedError('gan mode %s not implemented' % gan_mode) + + def get_target_tensor(self, prediction, target_is_real): + 
if target_is_real: + target_tensor = self.real_label + else: + target_tensor = self.fake_label + return target_tensor.expand_as(prediction) + + def __call__(self, prediction, target_is_real): + if self.gan_mode in ['lsgan', 'vanilla']: + target_tensor = self.get_target_tensor(prediction, target_is_real) + loss = self.loss(prediction, target_tensor) + elif self.gan_mode == 'wgangp': + if target_is_real: + loss = -prediction.mean() + else: + loss = prediction.mean() + return loss diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/modules.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/modules.py new file mode 100644 index 0000000..0f9ce53 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/modules.py @@ -0,0 +1,388 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import split_feature, compute_same_pad + + +def gaussian_p(mean, logs, x): + """ + lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) } + k = 1 (Independent) + Var = logs ** 2 + """ + c = math.log(2 * math.pi) + return -0.5 * (logs * 2.0 + ((x - mean) ** 2) / torch.exp(logs * 2.0) + c) + + +def gaussian_likelihood(mean, logs, x): + p = gaussian_p(mean, logs, x) + return torch.sum(p, dim=[1, 2, 3]) + + +def gaussian_sample(mean, logs, temperature=1): + # Sample from Gaussian with temperature + z = torch.normal(mean, torch.exp(logs) * temperature) + + return z + + +def squeeze2d(input, factor): + if factor == 1: + return input + + B, C, H, W = input.size() + + assert H % factor == 0 and W % factor == 0, "H or W modulo factor is not 0" + + x = input.view(B, C, H // factor, factor, W // factor, factor) + x = x.permute(0, 1, 3, 5, 2, 4).contiguous() + x = x.view(B, C * factor * factor, H // factor, W // factor) + + return x + + +def unsqueeze2d(input, factor): + if factor == 1: + return input + + factor2 = factor ** 2 + + B, C, H, W = input.size() + + assert C % (factor2) == 0, "C module factor squared is not 0" + + x = input.view(B, C // factor2, factor, factor, H, W) + x = x.permute(0, 1, 4, 2, 5, 3).contiguous() + x = x.view(B, C // (factor2), H * factor, W * factor) + + return x + + +class _ActNorm(nn.Module): + """ + Activation Normalization + Initialize the bias and scale with a given minibatch, + so that the output per-channel have zero mean and unit variance for that. + + After initialization, `bias` and `logs` will be trained as parameters. 
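+
+    The initialization is data-dependent: the first training forward pass
+    computes the per-channel mean and std of the batch (as in Glow's ActNorm).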
+ """ + + def __init__(self, num_features, scale=1.0): + super().__init__() + # register mean and scale + size = [1, num_features, 1, 1] + self.bias = nn.Parameter(torch.zeros(*size)) + self.logs = nn.Parameter(torch.zeros(*size)) + self.num_features = num_features + self.scale = scale + self.inited = False + + def initialize_parameters(self, input): + if not self.training: + raise ValueError("In Eval mode, but ActNorm not inited") + + with torch.no_grad(): + bias = -torch.mean(input.clone(), dim=[0, 2, 3], keepdim=True) + vars = torch.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True) + logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) + + self.bias.data.copy_(bias.data) + self.logs.data.copy_(logs.data) + + self.inited = True + + def _center(self, input, reverse=False): + if reverse: + return input - self.bias + else: + return input + self.bias + + def _scale(self, input, logdet=None, reverse=False): + + if reverse: + input = input * torch.exp(-self.logs) + else: + input = input * torch.exp(self.logs) + + if logdet is not None: + """ + logs is log_std of `mean of channels` + so we need to multiply by number of pixels + """ + b, c, h, w = input.shape + + dlogdet = torch.sum(self.logs) * h * w + + if reverse: + dlogdet *= -1 + + logdet = logdet + dlogdet + + return input, logdet + + def forward(self, input, logdet=None, reverse=False): + self._check_input_dim(input) + + if not self.inited: + self.initialize_parameters(input) + + if reverse: + input, logdet = self._scale(input, logdet, reverse) + input = self._center(input, reverse) + else: + input = self._center(input, reverse) + input, logdet = self._scale(input, logdet, reverse) + + return input, logdet + + +class ActNorm2d(_ActNorm): + def __init__(self, num_features, scale=1.0): + super().__init__(num_features, scale) + + def _check_input_dim(self, input): + assert len(input.size()) == 4 + assert input.size(1) == self.num_features, ( + "[ActNorm]: input should be in shape as `BCHW`," + " channels should be {} rather than {}".format( + self.num_features, input.size() + ) + ) + + +class LinearZeros(nn.Module): + def __init__(self, in_channels, out_channels, logscale_factor=3): + super().__init__() + + self.linear = nn.Linear(in_channels, out_channels) + self.linear.weight.data.zero_() + self.linear.bias.data.zero_() + + self.logscale_factor = logscale_factor + + self.logs = nn.Parameter(torch.zeros(out_channels)) + + def forward(self, input): + output = self.linear(input) + return output * torch.exp(self.logs * self.logscale_factor) + + +class Conv2d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding="same", + do_actnorm=True, + weight_std=0.05, + ): + super().__init__() + + if padding == "same": + padding = compute_same_pad(kernel_size, stride) + elif padding == "valid": + padding = 0 + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias=(not do_actnorm), + ) + + # init weight with std + self.conv.weight.data.normal_(mean=0.0, std=weight_std) + + if not do_actnorm: + self.conv.bias.data.zero_() + else: + self.actnorm = ActNorm2d(out_channels) + + self.do_actnorm = do_actnorm + + def forward(self, input): + x = self.conv(input) + if self.do_actnorm: + x, _ = self.actnorm(x) + return x + + +class Conv2dZeros(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding="same", + logscale_factor=3, + ): + super().__init__() + + if padding == "same": + 
padding = compute_same_pad(kernel_size, stride) + elif padding == "valid": + padding = 0 + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) + + self.conv.weight.data.zero_() + self.conv.bias.data.zero_() + + self.logscale_factor = logscale_factor + self.logs = nn.Parameter(torch.zeros(out_channels, 1, 1)) + + def forward(self, input): + output = self.conv(input) + return output * torch.exp(self.logs * self.logscale_factor) + + +class Permute2d(nn.Module): + def __init__(self, num_channels, shuffle): + super().__init__() + self.num_channels = num_channels + self.indices = torch.arange(self.num_channels - 1, -1, -1, dtype=torch.long) + self.indices_inverse = torch.zeros((self.num_channels), dtype=torch.long) + + for i in range(self.num_channels): + self.indices_inverse[self.indices[i]] = i + + if shuffle: + self.reset_indices() + + def reset_indices(self): + shuffle_idx = torch.randperm(self.indices.shape[0]) + self.indices = self.indices[shuffle_idx] + + for i in range(self.num_channels): + self.indices_inverse[self.indices[i]] = i + + def forward(self, input, reverse=False): + assert len(input.size()) == 4 + + if not reverse: + input = input[:, self.indices, :, :] + return input + else: + return input[:, self.indices_inverse, :, :] + + +class Split2d(nn.Module): + def __init__(self, num_channels): + super().__init__() + self.conv = Conv2dZeros(num_channels // 2, num_channels) + + def split2d_prior(self, z): + h = self.conv(z) + return split_feature(h, "cross") + + def forward(self, input, logdet=0.0, reverse=False, temperature=None): + if reverse: + z1 = input + mean, logs = self.split2d_prior(z1) + z2 = gaussian_sample(mean, logs, temperature) + z = torch.cat((z1, z2), dim=1) + return z, logdet + else: + z1, z2 = split_feature(input, "split") + mean, logs = self.split2d_prior(z1) + logdet = gaussian_likelihood(mean, logs, z2) + logdet + return z1, logdet + + +class SqueezeLayer(nn.Module): + def __init__(self, factor): + super().__init__() + self.factor = factor + + def forward(self, input, logdet=None, reverse=False): + if reverse: + output = unsqueeze2d(input, self.factor) + else: + output = squeeze2d(input, self.factor) + + return output, logdet + + +class InvertibleConv1x1(nn.Module): + def __init__(self, num_channels, LU_decomposed): + super().__init__() + w_shape = [num_channels, num_channels] + w_init = torch.qr(torch.randn(*w_shape))[0] + + if not LU_decomposed: + self.weight = nn.Parameter(torch.Tensor(w_init)) + else: + p, lower, upper = torch.lu_unpack(*torch.lu(w_init)) + s = torch.diag(upper) + sign_s = torch.sign(s) + log_s = torch.log(torch.abs(s)) + upper = torch.triu(upper, 1) + l_mask = torch.tril(torch.ones(w_shape), -1) + eye = torch.eye(*w_shape) + + self.register_buffer("p", p) + self.register_buffer("sign_s", sign_s) + self.lower = nn.Parameter(lower) + self.log_s = nn.Parameter(log_s) + self.upper = nn.Parameter(upper) + self.l_mask = l_mask + self.eye = eye + + self.w_shape = w_shape + self.LU_decomposed = LU_decomposed + + def get_weight(self, input, reverse): + b, c, h, w = input.shape + + if not self.LU_decomposed: + dlogdet = torch.slogdet(self.weight)[1] * h * w + if reverse: + weight = torch.inverse(self.weight) + else: + weight = self.weight + else: + self.l_mask = self.l_mask.to(input.device) + self.eye = self.eye.to(input.device) + + lower = self.lower * self.l_mask + self.eye + + u = self.upper * self.l_mask.transpose(0, 1).contiguous() + u += torch.diag(self.sign_s * torch.exp(self.log_s)) + + dlogdet = 
torch.sum(self.log_s) * h * w
+
+            if reverse:
+                u_inv = torch.inverse(u)
+                l_inv = torch.inverse(lower)
+                p_inv = torch.inverse(self.p)
+
+                weight = torch.matmul(u_inv, torch.matmul(l_inv, p_inv))
+            else:
+                weight = torch.matmul(self.p, torch.matmul(lower, u))
+
+        return weight.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet
+
+    def forward(self, input, logdet=None, reverse=False):
+        """
+        log-det = log|abs(|W|)| * pixels
+        """
+        weight, dlogdet = self.get_weight(input, reverse)
+
+        if not reverse:
+            z = F.conv2d(input, weight)
+            if logdet is not None:
+                logdet = logdet + dlogdet
+            return z, logdet
+        else:
+            z = F.conv2d(input, weight)
+            if logdet is not None:
+                logdet = logdet - dlogdet
+            return z, logdet
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/mwcnn_model.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/mwcnn_model.py
new file mode 100644
index 0000000..fe0e8d5
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/mwcnn_model.py
@@ -0,0 +1,118 @@
+import torch
+import torch.nn as nn
+from models import networks as N
+
+class MWRCAN(nn.Module):
+    def __init__(self):
+        super(MWRCAN, self).__init__()
+        c1 = 64
+        c2 = 128
+        c3 = 128
+        n_b = 20
+        # in_channels of each DWT must match the incoming feature width
+        # (the default of 64 would not): the head sees the 4-channel packed RAW.
+        self.head = N.DWTForward(in_channels=4)
+
+        self.down1 = N.seq(
+            nn.Conv2d(4 * 4, c1, 3, 1, 1),
+            nn.PReLU(),
+            N.RCAGroup(in_channels=c1, out_channels=c1, nb=n_b)
+        )
+
+        self.down2 = N.seq(
+            N.DWTForward(in_channels=c1),
+            nn.Conv2d(c1 * 4, c2, 3, 1, 1),
+            nn.PReLU(),
+            N.RCAGroup(in_channels=c2, out_channels=c2, nb=n_b)
+        )
+
+        self.down3 = N.seq(
+            N.DWTForward(in_channels=c2),
+            nn.Conv2d(c2 * 4, c3, 3, 1, 1),
+            nn.PReLU()
+        )
+
+        self.middle = N.seq(
+            N.RCAGroup(in_channels=c3, out_channels=c3, nb=n_b),
+            N.RCAGroup(in_channels=c3, out_channels=c3, nb=n_b)
+        )
+
+        self.up1 = N.seq(
+            nn.Conv2d(c3, c2 * 4, 3, 1, 1),
+            nn.PReLU(),
+            N.DWTInverse(in_channels=c2 * 4)
+        )
+
+        self.up2 = N.seq(
+            N.RCAGroup(in_channels=c2, out_channels=c2, nb=n_b),
+            nn.Conv2d(c2, c1 * 4, 3, 1, 1),
+            nn.PReLU(),
+            N.DWTInverse(in_channels=c1 * 4)
+        )
+
+        self.up3 = N.seq(
+            N.RCAGroup(in_channels=c1, out_channels=c1, nb=n_b),
+            nn.Conv2d(c1, 16, 3, 1, 1)
+        )
+
+        self.tail = N.seq(
+            N.DWTInverse(in_channels=16),
+            nn.Conv2d(4, 12, 3, 1, 1),
+            nn.PixelShuffle(upscale_factor=2)
+        )
+
+    def forward(self, x, c=None):
+        c0 = x
+        c1 = self.head(c0)
+        c2 = self.down1(c1)
+        c3 = self.down2(c2)
+        c4 = self.down3(c3)
+        m = self.middle(c4)
+        c5 = self.up1(m) + c3
+        c6 = self.up2(c5) + c2
+        c7 = self.up3(c6) + c1
+        out = self.tail(c7)
+
+        return out
+
+class Discriminator(nn.Module):
+    """Defines a PatchGAN discriminator"""
+    def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
+        """Construct a PatchGAN discriminator
+        Parameters:
+            input_nc (int)  -- the number of channels in input images
+            ndf (int)       -- the number of filters in the last conv layer
+            n_layers (int)  -- the number of conv layers in the discriminator
+            norm_layer      -- normalization layer
+        """
+        super(Discriminator, self).__init__()
+        use_bias = False
+
+        kw = 4
+        padw = 1
+        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+        nf_mult = 1
+        nf_mult_prev = 1
+        for n in range(1, n_layers):  # gradually increase the number of filters
+            nf_mult_prev = nf_mult
+            nf_mult = min(2 ** n, 8)
+            sequence += [
+                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+                norm_layer(ndf * nf_mult),
+                nn.LeakyReLU(0.2, True)
+            ]
+
+        nf_mult_prev = nf_mult
+        nf_mult = min(2 ** n_layers, 8)
+        sequence += [
+            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+            norm_layer(ndf * nf_mult),
+            nn.LeakyReLU(0.2, True)
+        ]
+
+        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
+        self.model = nn.Sequential(*sequence)
+
+    def forward(self, input):
+        """Standard forward."""
+        return self.model(input)
\ No newline at end of file
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/naf_model.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/naf_model.py
new file mode 100644
index 0000000..c00bc9e
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/naf_model.py
@@ -0,0 +1,317 @@
+import torch
+from .base_model import BaseModel
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from . import losses as L
+from util.util import get_coord
+import numpy as np
+from models.arch_util import LayerNorm2d
+from models.local_arch import Local_Base
+from . import net as N
+
+
+def demosaic(raw):
+    """Simple demosaicing to visualize packed RAW images.
+    Inputs:
+     - raw: (b, 4, h, w) packed RAW tensor, normalized to [0..1] as float32
+    Returns:
+     - Simple averaged-green demosaiced image of shape (b, 3, h*2, w*2)
+    """
+
+    assert raw.shape[1] == 4
+    shape = raw.shape
+
+    blue = raw[:, 0:1, :, :]
+    green_red = raw[:, 1:2, :, :]
+    red = raw[:, 2:3, :, :]
+    green_blue = raw[:, 3:, :, :]
+    avg_green = (green_red + green_blue) / 2
+    image = torch.cat((red, avg_green, blue), dim=1)
+    image = F.interpolate(input=image, size=(shape[2] * 2, shape[3] * 2),
+                          mode='bilinear', align_corners=True)
+    return image
+
+def gamma_compression(image):
+    """Converts from linear to gamma space."""
+    return torch.clamp(image, 1e-8, 1.0) ** (1.0 / 2.2)
+
+def tonemap(image):
+    """Simple S-curved global tonemap"""
+    return (3 * (image ** 2)) - (2 * (image ** 3))
+
+def ISP(raw):
+    raw = demosaic(raw)
+    raw = gamma_compression(raw)
+    raw = tonemap(raw)
+    raw = torch.clamp(raw, 0.0, 1.0)
+    return raw
+
+
+def pixel_unshuffle(input, downscale_factor):
+    '''
+    input: (b, c, k*h, k*w)
+    downscale_factor: k
+    output: (b, k*k*c, h, w)
+    '''
+    c = input.shape[1]
+
+    kernel = torch.zeros(size=[downscale_factor * downscale_factor * c,
+                               1, downscale_factor, downscale_factor],
+                         device=input.device)
+    for y in range(downscale_factor):
+        for x in range(downscale_factor):
+            kernel[x + y * downscale_factor::downscale_factor * downscale_factor, 0, y, x] = 1
+    return F.conv2d(input, kernel, stride=downscale_factor, groups=c)
+
+class PixelUnshuffle(nn.Module):
+    def __init__(self, downscale_factor):
+        super(PixelUnshuffle, self).__init__()
+        self.downscale_factor = downscale_factor
+
+    def forward(self, input):
+        return pixel_unshuffle(input, self.downscale_factor)
+
+def inverse_gamma(image):
+    """Converts from gamma back to linear space."""
+    return torch.clamp(image, 1e-8, 1.0) ** (2.2)
+
+def inverse_tonemap(image):
+    image = torch.clamp(image, 0., 1.)
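+    # Closed-form inverse of the smoothstep tonemap y = 3x^2 - 2x^3, via the
+    # trigonometric cubic-root identity.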
+    image = 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0)
+    return image
+
+class nafModel(BaseModel):
+    @staticmethod
+    def modify_commandline_options(parser, is_train=True):
+        return parser
+
+    def __init__(self, opt):
+        super(nafModel, self).__init__(opt)
+
+        self.opt = opt
+
+        self.loss_names = ['GCMModel_L1', 'NAFISPNet_L1', 'NAFISPNet_SSIM', 'NAFISPNet_VGG', 'Total']
+
+        if self.isTrain:
+            self.visual_names = ['data_out', 'data_raw_demosaic', 'data_dslr', 'GCMModel_out', 'GCMModel_out_warp', 'rgb_mask', 'raw_warp', 'raw_mask', 'data_raw', 'data_dslr_mask', 'data_out_mask']
+        else:
+            self.visual_names = ['data_out', 'data_dslr']
+
+        self.model_names = ['NAFISPNet', 'GCMModel']
+        self.optimizer_names = ['NAFISPNet_optimizer_%s' % opt.optimizer,
+                                'GCMModel_optimizer_%s' % opt.optimizer]
+
+        # NAFISPNet takes architecture kwargs (not opt); the defaults are used here
+        isp = NAFISPNet()
+        self.netNAFISPNet = N.init_net(isp, opt.init_type, opt.init_gain, opt.gpu_ids)
+
+        self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
+
+        gcm = GCMModel(opt)
+        self.netGCMModel = N.init_net(gcm, opt.init_type, opt.init_gain, opt.gpu_ids)
+
+        if self.isTrain:
+            from pwc import pwc_net
+            pwcnet = pwc_net.PWCNET()
+            self.netPWCNET = N.init_net(pwcnet, opt.init_type, opt.init_gain, opt.gpu_ids)
+            self.set_requires_grad(self.netPWCNET, requires_grad=False)
+
+        if self.isTrain:
+            self.optimizer_NAFISPNet = optim.AdamW(self.netNAFISPNet.parameters(),
+                                                   lr=opt.lr,
+                                                   betas=(opt.beta1, opt.beta2),
+                                                   weight_decay=opt.weight_decay)
+            self.optimizer_GCMModel = optim.AdamW(self.netGCMModel.parameters(),
+                                                  lr=opt.lr,
+                                                  betas=(opt.beta1, opt.beta2),
+                                                  weight_decay=opt.weight_decay)
+            self.optimizers = [self.optimizer_NAFISPNet, self.optimizer_GCMModel]
+
+            self.criterionL1 = N.init_net(L.L1Loss(), gpu_ids=opt.gpu_ids)
+            self.criterionSSIM = N.init_net(L.SSIMLoss(), gpu_ids=opt.gpu_ids)
+            self.criterionVGG = N.init_net(L.VGGLoss(), gpu_ids=opt.gpu_ids)
+
+    def set_input(self, input):
+        if self.isTrain:
+            self.data_raw = input['raw'].to(self.device)
+            self.data_raw_demosaic = input['raw_demosaic'].to(self.device)
+        self.data_dslr = input['dslr'].to(self.device)
+        self.image_paths = input['fname']
+
+    def forward(self):
+        if self.isTrain:
+            self.GCMModel_out = self.netGCMModel(self.data_raw_demosaic, self.data_dslr)
+            self.GCMModel_out_warp, self.rgb_mask, self.raw_warp, self.raw_mask = \
+                self.get_backwarp(self.data_dslr, self.GCMModel_out, self.data_raw, self.netPWCNET)
+
+        self.data_dslr_pool = self.pool(self.data_dslr)
+        self.data_out = self.netNAFISPNet(self.data_dslr_pool)
+
+        if self.isTrain:
+            self.data_dslr_mask = self.data_dslr * self.rgb_mask
+            self.data_out_mask = self.data_out * self.raw_mask
+            self.raw_warp_rgb = ISP(self.raw_warp).to(self.device)
+            self.data_out_mask_rgb = ISP(self.data_out_mask).to(self.device)
+
+    def backward(self):
+        self.loss_GCMModel_L1 = self.criterionL1(self.GCMModel_out_warp, self.data_dslr_mask).mean()
+        self.loss_NAFISPNet_L1 = self.criterionL1(self.data_out_mask, self.raw_warp).mean()
+        self.loss_NAFISPNet_SSIM = 1 - self.criterionSSIM(self.data_out_mask, self.raw_warp).mean()
+        self.loss_NAFISPNet_VGG = self.criterionVGG(self.data_out_mask_rgb, self.raw_warp_rgb).mean()
+        self.loss_Total = self.loss_GCMModel_L1 + self.loss_NAFISPNet_L1 + \
+            self.loss_NAFISPNet_VGG * 0.4 + self.loss_NAFISPNet_SSIM * 0.3
+        self.loss_Total.backward()
+
+    def optimize_parameters(self):
+        self.forward()
+        self.optimizer_NAFISPNet.zero_grad()
+        self.optimizer_GCMModel.zero_grad()
+        self.backward()
+        self.optimizer_NAFISPNet.step()
+        self.optimizer_GCMModel.step()
+
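+# Illustrative sanity check (added for exposition; `_isp_round_trip_check` is
+# not part of the original pipeline and is never called): the
+# inverse_tonemap/inverse_gamma pair above should undo tonemap/gamma_compression
+# up to clamping and floating-point error.
+def _isp_round_trip_check():
+    x = torch.rand(1, 4, 8, 8).clamp(1e-8, 1.0)
+    y = inverse_gamma(inverse_tonemap(tonemap(gamma_compression(x))))
+    assert torch.allclose(x, y, atol=1e-3)
+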
+class GCMModel(nn.Module): + def __init__(self, opt): + super(GCMModel, self).__init__() + self.opt = opt + self.ch_1 = 32 + self.ch_2 = 64 + + guide_input_channels = 6 + align_input_channels = 3 + + self.guide_net = N.seq( + N.conv(guide_input_channels, self.ch_1, 7, stride=2, padding=0, mode='CR'), + N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'), + nn.AdaptiveAvgPool2d(1), + N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C') + ) + + self.align_head = N.conv(align_input_channels, self.ch_2, 1, padding=0, mode='CR') + + self.align_base = N.seq( + N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCR') + ) + self.align_tail = N.seq( + N.conv(self.ch_2, 3, 1, padding=0, mode='C') + ) + + def forward(self, demosaic_raw, dslr): + demosaic_raw = torch.pow(demosaic_raw, 1/2.2) + + guide_input = torch.cat((demosaic_raw, dslr), 1) + base_input = demosaic_raw + + guide = self.guide_net(guide_input) + + out = self.align_head(base_input) + out = guide * out + out + out = self.align_base(out) + out = self.align_tail(out) + demosaic_raw + + return out + + + +class NAFISPNet(nn.Module): + + def __init__(self, img_channel=3, width=64, middle_blk_num=12, enc_blk_nums=[2, 2, 4, 8], dec_blk_nums=[2, 2, 2, 2]): + super().__init__() + + self.intro = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, + bias=True) + self.intro_1 = nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, + bias=True) + self.ending = nn.Sequential( + #nn.Conv2d(in_channels=width, out_channels=width//4, kernel_size=3, padding=1, stride=1, groups=1, + # bias=True), + #PixelUnshuffle(downscale_factor=2), + nn.Conv2d(in_channels=width, out_channels=4, kernel_size=3, padding=1, stride=1, groups=1, + bias=True), + ) + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + self.middle_blks = nn.ModuleList() + self.ups = nn.ModuleList() + self.downs = nn.ModuleList() + + chan = width + for num in enc_blk_nums: + self.encoders.append( + nn.Sequential( + *[N.NAFBlock(chan) for _ in range(num)] + ) + ) + self.downs.append( + nn.Conv2d(chan, 2*chan, 2, 2) + ) + chan = chan * 2 + + self.middle_blks = \ + nn.Sequential( + *[N.NAFBlock(chan) for _ in range(middle_blk_num)] + ) + + for num in dec_blk_nums: + self.ups.append( + nn.Sequential( + nn.Conv2d(chan, chan * 2, 1, bias=False), + nn.PixelShuffle(2) + ) + ) + chan = chan // 2 + self.decoders.append( + nn.Sequential( + *[N.NAFBlock(chan) for _ in range(num)] + ) + ) + + self.padder_size = 2 ** len(self.encoders) + self.skip_conv = nn.Conv2d(width, width, kernel_size=1, bias=True) + + def forward(self, inp): + B, C, H, W = inp.shape + inp = self.check_image_size(inp) + inp = inverse_gamma(inverse_tonemap(inp)) + x_input = self.intro(inp) + x = self.intro_1(x_input) + + encs = [] + + for encoder, down in zip(self.encoders, self.downs): + x = encoder(x) + encs.append(x) + x = down(x) + + x = self.middle_blks(x) + + for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): + x = up(x) + x = x + enc_skip + x = decoder(x) + + # print(x.shape,self.skip_conv(x_input).shape) + x = x + self.skip_conv(x_input) + x = self.ending(x) + # x = x + inp + + return x[:, :, :H, :W] + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size + mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) + return 
x + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/net.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/net.py new file mode 100644 index 0000000..acb4495 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/net.py @@ -0,0 +1,534 @@ +from re import A +import torch +import torch.nn as nn +from torch.nn import init +from torch.optim import lr_scheduler +from collections import OrderedDict +from models.arch_util import LayerNorm2d +import numbers +from einops import rearrange +import torch.nn.functional as F + +train_size=(1,3,504, 504) +class AvgPool2d(nn.Module): + def __init__(self, kernel_size=None, base_size=[288,288], auto_pad=True, fast_imp=False): + super().__init__() + self.kernel_size = kernel_size + self.base_size = base_size + self.auto_pad = auto_pad + + # only used for fast implementation + self.fast_imp = fast_imp + self.rs = [5,4,3,2,1] + self.max_r1 = self.rs[0] + self.max_r2 = self.rs[0] + + def extra_repr(self) -> str: + return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format( + self.kernel_size, self.base_size, self.kernel_size, self.fast_imp + ) + + def forward(self, x): + if self.kernel_size is None and self.base_size: + if isinstance(self.base_size, int): + self.base_size = (self.base_size, self.base_size) + self.kernel_size = list(self.base_size) + self.kernel_size[0] = x.shape[2]*self.base_size[0]//train_size[-2] + self.kernel_size[1] = x.shape[3]*self.base_size[1]//train_size[-1] + + # only used for fast implementation + self.max_r1 = max(1, self.rs[0]*x.shape[2]//train_size[-2]) + self.max_r2 = max(1, self.rs[0]*x.shape[3]//train_size[-1]) + + if self.fast_imp: # Non-equivalent implementation but faster + h, w = x.shape[2:] + if self.kernel_size[0]>=h and self.kernel_size[1]>=w: + out = F.adaptive_avg_pool2d(x,1) + else: + r1 = [r for r in self.rs if h%r==0][0] + r2 = [r for r in self.rs if w%r==0][0] + # reduction_constraint + r1 = min(self.max_r1, r1) + r2 = min(self.max_r2, r2) + s = x[:,:,::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2) + n, c, h, w = s.shape + k1, k2 = min(h-1, self.kernel_size[0]//r1), min(w-1, self.kernel_size[1]//r2) + out = (s[:,:,:-k1,:-k2]-s[:,:,:-k1,k2:]-s[:,:,k1:,:-k2]+s[:,:,k1:,k2:])/(k1*k2) + out = torch.nn.functional.interpolate(out, scale_factor=(r1,r2)) + else: + n, c, h, w = x.shape + s = x.cumsum(dim=-1).cumsum(dim=-2) + s = torch.nn.functional.pad(s, (1,0,1,0)) # pad 0 for convenience + k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1]) + s1, s2, s3, s4 = s[:,:,:-k1,:-k2],s[:,:,:-k1,k2:], s[:,:,k1:,:-k2], s[:,:,k1:,k2:] + out = s4+s1-s2-s3 + out = out / (k1*k2) + + if self.auto_pad: + n, c, h, w = x.shape + _h, _w = out.shape[2:] + # print(x.shape, self.kernel_size) + pad2d = ((w - _w)//2, (w - _w + 1)//2, (h - _h) // 2, (h - _h + 1) // 2) + out = torch.nn.functional.pad(out, pad2d, mode='replicate') + + return out + +def get_scheduler(optimizer, opt): + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + return 1 - max(0, epoch-opt.niter) / max(1, float(opt.niter_decay)) + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, + step_size=opt.lr_decay_iters, + gamma=0.5) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, + mode='min', + factor=0.2, + threshold=0.01, + patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, + T_max=opt.niter, + eta_min=1e-5) + else: + return 
NotImplementedError('lr [%s] is not implemented', opt.lr_policy) + return scheduler + +def init_weights(net, init_type='normal', init_gain=0.02): + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 \ + or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + elif init_type == 'uniform': + init.uniform_(m.weight.data, b=init_gain) + else: + raise NotImplementedError('[%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + +def init_net(net, init_type='default', init_gain=0.02, gpu_ids=[]): + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs + if init_type != 'default' and init_type is not None: + init_weights(net, init_type, init_gain=init_gain) + return net + + +''' +# =================================== +# Advanced nn.Sequential +# reform nn.Sequentials and nn.Modules +# to a single nn.Sequential +# =================================== +''' + +def seq(*args): + if len(args) == 1: + args = args[0] + if isinstance(args, nn.Module): + return args + modules = OrderedDict() + if isinstance(args, OrderedDict): + for k, v in args.items(): + modules[k] = seq(v) + return nn.Sequential(modules) + assert isinstance(args, (list, tuple)) + return nn.Sequential(*[seq(i) for i in args]) + +''' +# =================================== +# Useful blocks +# -------------------------------- +# conv (+ normaliation + relu) +# concat +# sum +# resblock (ResBlock) +# resdenseblock (ResidualDenseBlock_5C) +# resinresdenseblock (RRDB) +# =================================== +''' + +# ------------------------------------------------------- +# return nn.Sequantial of (Conv + BN + ReLU) +# ------------------------------------------------------- +def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, + output_padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', mode='C'): + L = [] + for t in mode: + if t == 'C': + L.append(nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode)) + elif t == 'X': + assert in_channels == out_channels + L.append(nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=bias, + padding_mode=padding_mode)) + elif t == 'T': + L.append(nn.ConvTranspose2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + padding_mode=padding_mode)) + elif t == 'B': + L.append(nn.BatchNorm2d(out_channels)) + elif t == 'I': + L.append(nn.InstanceNorm2d(out_channels, 
affine=True)) + elif t == 'i': + L.append(nn.InstanceNorm2d(out_channels)) + elif t == 'R': + L.append(nn.ReLU(inplace=True)) + elif t == 'r': + L.append(nn.ReLU(inplace=False)) + elif t == 'S': + L.append(nn.Sigmoid()) + elif t == 'P': + L.append(nn.PReLU()) + elif t == 'L': + L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=True)) + elif t == 'l': + L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=False)) + elif t == '2': + L.append(nn.PixelShuffle(upscale_factor=2)) + elif t == '3': + L.append(nn.PixelShuffle(upscale_factor=3)) + elif t == '4': + L.append(nn.PixelShuffle(upscale_factor=4)) + elif t == 'U': + L.append(nn.Upsample(scale_factor=2, mode='nearest')) + elif t == 'u': + L.append(nn.Upsample(scale_factor=3, mode='nearest')) + elif t == 'M': + L.append(nn.MaxPool2d(kernel_size=kernel_size, + stride=stride, + padding=0)) + elif t == 'A': + L.append(nn.AvgPool2d(kernel_size=kernel_size, + stride=stride, + padding=0)) + else: + raise NotImplementedError('Undefined type: '.format(t)) + return seq(*L) + + +class DWTForward(nn.Conv2d): + def __init__(self, in_channels=64): + super(DWTForward, self).__init__(in_channels, in_channels*4, 2, 2, + groups=in_channels, bias=False) + weight = torch.tensor([[[[0.5, 0.5], [ 0.5, 0.5]]], + [[[0.5, 0.5], [-0.5, -0.5]]], + [[[0.5, -0.5], [ 0.5, -0.5]]], + [[[0.5, -0.5], [-0.5, 0.5]]]], + dtype=torch.get_default_dtype() + ).repeat(in_channels, 1, 1, 1)# / 2 + self.weight.data.copy_(weight) + self.requires_grad_(False) + + +class DWTInverse(nn.ConvTranspose2d): + def __init__(self, in_channels=64): + super(DWTInverse, self).__init__(in_channels, in_channels//4, 2, 2, + groups=in_channels//4, bias=False) + weight = torch.tensor([[[[0.5, 0.5], [ 0.5, 0.5]]], + [[[0.5, 0.5], [-0.5, -0.5]]], + [[[0.5, -0.5], [ 0.5, -0.5]]], + [[[0.5, -0.5], [-0.5, 0.5]]]], + dtype=torch.get_default_dtype() + ).repeat(in_channels//4, 1, 1, 1)# * 2 + self.weight.data.copy_(weight) + self.requires_grad_(False) + + + + +# ------------------------------------------------------- +# Res Block: x + conv(relu(conv(x))) +# ------------------------------------------------------- +class ResBlock(nn.Module): + def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1, bias=True, mode='CRC'): + super(ResBlock, self).__init__() + + assert in_channels == out_channels + if mode[0] in ['R','L']: + mode = mode[0].lower() + mode[1:] + + self.res = conv(in_channels, out_channels, kernel_size, + stride, padding=padding, bias=bias, mode=mode) + + def forward(self, x): + res = self.res(x) + return x + res + + + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 +########################################################################## +## NAF Block +class NAFBlock(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, + bias=True) + self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + # Simplified Channel Attention + self.sca = nn.Sequential( + AvgPool2d(), + nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + ) + + # SimpleGate + 
self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + x = inp + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + + +def to_4d(x, h, w): + return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma + 1e-5) * self.weight + + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type == 'BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim * ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, + groups=hidden_features * 2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + 
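+# Illustrative usage (added for exposition; `_gdfn_shape_check` is hypothetical
+# and never called): GDFN is shape-preserving, which is why TransformerBlock
+# below can add its output residually.
+def _gdfn_shape_check():
+    x = torch.rand(2, 48, 32, 32)
+    ffn = FeedForward(dim=48, ffn_expansion_factor=2.66, bias=False)
+    assert ffn(x).shape == x.shape
+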
+########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + + +########################################################################## +#TransformerBlock +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + + return x + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/networks.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/networks.py new file mode 100644 index 0000000..e85f228 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/networks.py @@ -0,0 +1,471 @@ +from re import A +import torch +import torch.nn as nn +from torch.nn import init +from torch.optim import lr_scheduler +from collections import OrderedDict +from models.arch_util import LayerNorm2d +import numbers +from einops import rearrange +import torch.nn.functional as F + + 
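+# NOTE: this module largely mirrors net.py; one visible difference is that the
+# NAFBlock below keeps nn.AdaptiveAvgPool2d(1) in its channel attention, whereas
+# net.py substitutes the TLSC-style AvgPool2d for test-time local statistics.
+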
+def get_scheduler(optimizer, opt): + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + return 1 - max(0, epoch-opt.niter) / max(1, float(opt.niter_decay)) + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, + step_size=opt.lr_decay_iters, + gamma=0.5) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, + mode='min', + factor=0.2, + threshold=0.01, + patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, + T_max=opt.niter, + eta_min=1e-5) + else: + return NotImplementedError('lr [%s] is not implemented', opt.lr_policy) + return scheduler + +def init_weights(net, init_type='normal', init_gain=0.02): + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 \ + or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + elif init_type == 'uniform': + init.uniform_(m.weight.data, b=init_gain) + else: + raise NotImplementedError('[%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + +def init_net(net, init_type='default', init_gain=0.02, gpu_ids=[]): + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs + if init_type != 'default' and init_type is not None: + init_weights(net, init_type, init_gain=init_gain) + return net + + +''' +# =================================== +# Advanced nn.Sequential +# reform nn.Sequentials and nn.Modules +# to a single nn.Sequential +# =================================== +''' + +def seq(*args): + if len(args) == 1: + args = args[0] + if isinstance(args, nn.Module): + return args + modules = OrderedDict() + if isinstance(args, OrderedDict): + for k, v in args.items(): + modules[k] = seq(v) + return nn.Sequential(modules) + assert isinstance(args, (list, tuple)) + return nn.Sequential(*[seq(i) for i in args]) + +''' +# =================================== +# Useful blocks +# -------------------------------- +# conv (+ normaliation + relu) +# concat +# sum +# resblock (ResBlock) +# resdenseblock (ResidualDenseBlock_5C) +# resinresdenseblock (RRDB) +# =================================== +''' + +# ------------------------------------------------------- +# return nn.Sequantial of (Conv + BN + ReLU) +# ------------------------------------------------------- +def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, + output_padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', mode='C'): + L = [] + for t in mode: + if t == 'C': + L.append(nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode)) + elif t == 'X': + assert in_channels 
== out_channels + L.append(nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=bias, + padding_mode=padding_mode)) + elif t == 'T': + L.append(nn.ConvTranspose2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + padding_mode=padding_mode)) + elif t == 'B': + L.append(nn.BatchNorm2d(out_channels)) + elif t == 'I': + L.append(nn.InstanceNorm2d(out_channels, affine=True)) + elif t == 'i': + L.append(nn.InstanceNorm2d(out_channels)) + elif t == 'R': + L.append(nn.ReLU(inplace=True)) + elif t == 'r': + L.append(nn.ReLU(inplace=False)) + elif t == 'S': + L.append(nn.Sigmoid()) + elif t == 'P': + L.append(nn.PReLU()) + elif t == 'L': + L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=True)) + elif t == 'l': + L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=False)) + elif t == '2': + L.append(nn.PixelShuffle(upscale_factor=2)) + elif t == '3': + L.append(nn.PixelShuffle(upscale_factor=3)) + elif t == '4': + L.append(nn.PixelShuffle(upscale_factor=4)) + elif t == 'U': + L.append(nn.Upsample(scale_factor=2, mode='nearest')) + elif t == 'u': + L.append(nn.Upsample(scale_factor=3, mode='nearest')) + elif t == 'M': + L.append(nn.MaxPool2d(kernel_size=kernel_size, + stride=stride, + padding=0)) + elif t == 'A': + L.append(nn.AvgPool2d(kernel_size=kernel_size, + stride=stride, + padding=0)) + else: + raise NotImplementedError('Undefined type: {}'.format(t)) + return seq(*L) + + +class DWTForward(nn.Conv2d): + def __init__(self, in_channels=64): + super(DWTForward, self).__init__(in_channels, in_channels*4, 2, 2, + groups=in_channels, bias=False) + weight = torch.tensor([[[[0.5, 0.5], [ 0.5, 0.5]]], + [[[0.5, 0.5], [-0.5, -0.5]]], + [[[0.5, -0.5], [ 0.5, -0.5]]], + [[[0.5, -0.5], [-0.5, 0.5]]]], + dtype=torch.get_default_dtype() + ).repeat(in_channels, 1, 1, 1)# / 2 + self.weight.data.copy_(weight) + self.requires_grad_(False) + + +class DWTInverse(nn.ConvTranspose2d): + def __init__(self, in_channels=64): + super(DWTInverse, self).__init__(in_channels, in_channels//4, 2, 2, + groups=in_channels//4, bias=False) + weight = torch.tensor([[[[0.5, 0.5], [ 0.5, 0.5]]], + [[[0.5, 0.5], [-0.5, -0.5]]], + [[[0.5, -0.5], [ 0.5, -0.5]]], + [[[0.5, -0.5], [-0.5, 0.5]]]], + dtype=torch.get_default_dtype() + ).repeat(in_channels//4, 1, 1, 1)# * 2 + self.weight.data.copy_(weight) + self.requires_grad_(False) + + + + +# ------------------------------------------------------- +# Res Block: x + conv(relu(conv(x))) +# ------------------------------------------------------- +class ResBlock(nn.Module): + def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1, bias=True, mode='CRC'): + super(ResBlock, self).__init__() + + assert in_channels == out_channels + if mode[0] in ['R','L']: + mode = mode[0].lower() + mode[1:] + + self.res = conv(in_channels, out_channels, kernel_size, + stride, padding=padding, bias=bias, mode=mode) + + def forward(self, x): + res = self.res(x) + return x + res + + + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 +########################################################################## +## NAF Block +class NAFBlock(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c *
DW_Expand + self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, + bias=True) + self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + # Simplified Channel Attention + self.sca = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, + groups=1, bias=True), + ) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + x = inp + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + + +def to_4d(x, h, w): + return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma + 1e-5) * self.weight + + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type == 'BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class 
FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim * ffn_expansion_factor) + + self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, + groups=hidden_features * 2, bias=bias) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + + +########################################################################## +#TransformerBlock +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + + return x + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) + + def forward(self, x): + x = self.proj(x) + + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), + 
nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/utils.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/utils.py new file mode 100644 index 0000000..d1bef31 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/models/utils.py @@ -0,0 +1,52 @@ +import math +import torch + + +def compute_same_pad(kernel_size, stride): + if isinstance(kernel_size, int): + kernel_size = [kernel_size] + + if isinstance(stride, int): + stride = [stride] + + assert len(stride) == len( + kernel_size + ), "Pass kernel size and stride both as int, or both as equal length iterable" + + return [((k - 1) * s + 1) // 2 for k, s in zip(kernel_size, stride)] + + +def uniform_binning_correction(x, n_bits=8): + """Replaces x^i with q^i(x) = U(x, x + 1.0 / 256.0). + + Args: + x: 4-D Tensor of shape (NCHW) + n_bits: optional. + Returns: + x: x ~ U(x, x + 1.0 / 256) + objective: Equivalent to -q(x)*log(q(x)). + """ + b, c, h, w = x.size() + n_bins = 2 ** n_bits + chw = c * h * w + x += torch.zeros_like(x).uniform_(0, 1.0 / n_bins) + + objective = -math.log(n_bins) * chw * torch.ones(b, device=x.device) + return x, objective + + +def split_feature(tensor, type="split"): + """ + type = ["split", "cross"] + """ + C = tensor.size(1) + if type == "split": + # return tensor[:, : C // 2, ...], tensor[:, C // 2 :, ...] + return tensor[:, :1, ...], tensor[:,1:, ...] + elif type == "cross": + # return tensor[:, 0::2, ...], tensor[:, 1::2, ...] + return tensor[:, 0::2, ...], tensor[:, 1::2, ...] + + + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/options/__init__.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/options/__init__.py new file mode 100644 index 0000000..17ab6b7 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/options/__init__.py @@ -0,0 +1 @@ +"""This package options includes option modules.""" diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/options/base_options.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/options/base_options.py new file mode 100644 index 0000000..21ff814 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/options/base_options.py @@ -0,0 +1,214 @@ +import argparse +import os +import re +from util import util +import torch +import models +import time + +def str2bool(v): + return v.lower() in ('yes', 'y', 'true', 't', '1') + +inf = float('inf') + +class BaseOptions(): + def __init__(self): + """Reset the class; indicates the class hasn't been initailized""" + self.initialized = False + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + # data parameters + parser.add_argument('--dataroot', type=str, default='') + parser.add_argument('--dataset_name', type=str, default=['eth'], nargs='+') + parser.add_argument('--max_dataset_size', type=int, default=inf) + parser.add_argument('--scale', type=int, default=4, help='Super-resolution scale.') + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--patch_size', type=int, default=224) + parser.add_argument('--shuffle', type=str2bool, default=True) + parser.add_argument('-j', '--num_dataloader', default=4, type=int) + parser.add_argument('--drop_last', type=str2bool, default=True) + + # device parameters + parser.add_argument('--gpu_ids', type=str, default='all', + help='Separate the GPU ids by `,`, using all GPUs by default. 
' + 'e.g., `--gpu_ids 0`, `--gpu_ids 2,3`, `--gpu_ids -1`(CPU)') + parser.add_argument('--checkpoints_dir', type=str, default='./ckpt') + parser.add_argument('-v', '--verbose', type=str2bool, default=True) + parser.add_argument('--suffix', default='', type=str) + + # model parameters + parser.add_argument('--name', type=str, required=True, + help='Name of the folder to save models and logs.') + parser.add_argument('--model', type=str, required=True) + parser.add_argument('--load_path', type=str, default='', + help='Will load pre-trained model if load_path is set') + parser.add_argument('--load_iter', type=int, default=[0], nargs='+', + help='Load parameters if > 0 and load_path is not set. ' + 'Set the value of `last_epoch`') + parser.add_argument('--gcm_coord', type=str2bool, default=True) + parser.add_argument('--pre_ispnet_coord', type=str2bool, default=True) + parser.add_argument('--chop', type=str2bool, default=False) + + # training parameters + parser.add_argument('--init_type', type=str, default='default', + choices=['default', 'normal', 'xavier', + 'kaiming', 'orthogonal', 'uniform'], + help='`default` means using PyTorch default init functions.') + parser.add_argument('--init_gain', type=float, default=0.02) + # parser.add_argument('--loss', type=str, default='L1', + # help='choose from [L1, MSE, SSIM, VGG, PSNR]') + parser.add_argument('--optimizer', type=str, default='Adam', + choices=['Adam', 'SGD', 'RMSprop']) + parser.add_argument('--niter', type=int, default=1000) + parser.add_argument('--niter_decay', type=int, default=0) + parser.add_argument('--lr_policy', type=str, default='step') + parser.add_argument('--lr_decay_iters', type=int, default=200) + parser.add_argument('--lr', type=float, default=0.0001) + + # Optimizer + parser.add_argument('--load_optimizers', type=str2bool, default=False, + help='Loading optimizer parameters for continuing training.') + parser.add_argument('--weight_decay', type=float, default=0) + # Adam + parser.add_argument('--beta1', type=float, default=0.9) + parser.add_argument('--beta2', type=float, default=0.999) + # SGD & RMSprop + parser.add_argument('--momentum', type=float, default=0) + # RMSprop + parser.add_argument('--alpha', type=float, default=0.99) + + # visualization parameters + parser.add_argument('--print_freq', type=int, default=100) + parser.add_argument('--test_every', type=int, default=1000) + parser.add_argument('--save_epoch_freq', type=int, default=1) + parser.add_argument('--calc_metrics', type=str2bool, default=False) + parser.add_argument('--save_imgs', type=str2bool, default=False) + parser.add_argument('--visual_full_imgs', type=str2bool, default=False) + + # p20 test parameters + parser.add_argument('--TLC', type=str2bool, default=True) + parser.add_argument('--save_path', type=str, default='./submission') + parser.add_argument('--test_patch', type=int, default=248) + + self.initialized = True + return parser + + def gather_options(self): + """Initialize our parser with basic options (only once). + Add additional model-specific and dataset-specific options. + These options are defined in the corresponding functions + in the model and dataset classes. 
+ """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class= + argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, _ = parser.parse_known_args() + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + opt, _ = parser.parse_known_args() # parse again with new defaults + + # save and return the parser + self.parser = parser + return parser.parse_args() + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). + It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, 'opt_%s.txt' + % ('train' if self.isTrain else 'test')) + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + + def parse(self): + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + opt.serial_batches = not opt.shuffle + + if self.isTrain and (opt.load_iter != [0] or opt.load_path != '') \ + and not opt.load_optimizers: + util.prompt('You are loading a checkpoint and continuing training, ' + 'and no optimizer parameters are loaded. Please make ' + 'sure that the hyper parameters are correctly set.', 80) + time.sleep(3) + + opt.model = opt.model.lower() + opt.name = opt.name.lower() + + scale_patch = {2: 96, 3: 144, 4: 192} + if opt.patch_size is None: + opt.patch_size = scale_patch[opt.scale] + + if opt.name.startswith(opt.checkpoints_dir): + opt.name = opt.name.replace(opt.checkpoints_dir+'/', '') + if opt.name.endswith('/'): + opt.name = opt.name[:-1] + + if len(opt.dataset_name) == 1: + opt.dataset_name = opt.dataset_name[0] + + if len(opt.load_iter) == 1: + opt.load_iter = opt.load_iter[0] + + # process opt.suffix + if opt.suffix != '': + suffix = ('_' + opt.suffix.format(**vars(opt))) + opt.name = opt.name + suffix + + self.print_options(opt) + + # set gpu ids + cuda_device_count = torch.cuda.device_count() + if opt.gpu_ids == 'all': + # GT 710 (3.5), GT 610 (2.1) + gpu_ids = [i for i in range(cuda_device_count)] + else: + p = re.compile('[^-0-9]+') + gpu_ids = [int(i) for i in re.split(p, opt.gpu_ids) if int(i) >= 0] + opt.gpu_ids = [i for i in gpu_ids \ + if torch.cuda.get_device_capability(i) >= (4,0)] + + if len(opt.gpu_ids) == 0 and len(gpu_ids) > 0: + opt.gpu_ids = gpu_ids + util.prompt('You\'re using GPUs with computing capability < 4') + elif len(opt.gpu_ids) != len(gpu_ids): + util.prompt('GPUs(computing capability < 4) have been disabled') + + if len(opt.gpu_ids) > 0: + assert torch.cuda.is_available(), 'No cuda available !!!' 
+ torch.cuda.set_device(opt.gpu_ids[0]) + print('The GPUs you are using:') + for gpu_id in opt.gpu_ids: + print(' %2d *%s* with capability %d.%d' % ( + gpu_id, + torch.cuda.get_device_name(gpu_id), + *torch.cuda.get_device_capability(gpu_id))) + else: + util.prompt('You are using CPU mode') + + self.opt = opt + return self.opt diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/options/test_options.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/options/test_options.py new file mode 100644 index 0000000..26e404d --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/options/test_options.py @@ -0,0 +1,8 @@ +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + self.isTrain = False + return parser diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/options/train_options.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/options/train_options.py new file mode 100644 index 0000000..c44ea20 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/options/train_options.py @@ -0,0 +1,8 @@ +from .base_options import BaseOptions + + +class TrainOptions(BaseOptions): + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + self.isTrain = True + return parser diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/pwc/correlation/correlation.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/pwc/correlation/correlation.py new file mode 100644 index 0000000..c9c97e3 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/pwc/correlation/correlation.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Correlation_rearrange = ''' + extern "C" __global__ void kernel_Correlation_rearrange( + const int n, + const float* input, + float* output + ) { + int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; + + if (intIndex >= n) { + return; + } + + int intSample = blockIdx.z; + int intChannel = blockIdx.y; + + float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; + + __syncthreads(); + + int intPaddedY = (intIndex / SIZE_3(input)) + 4; + int intPaddedX = (intIndex % SIZE_3(input)) + 4; + int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; + + output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; + } +''' + +kernel_Correlation_updateOutput = ''' + extern "C" __global__ void kernel_Correlation_updateOutput( + const int n, + const float* rbot0, + const float* rbot1, + float* top + ) { + extern __shared__ char patch_data_char[]; + + float *patch_data = (float *)patch_data_char; + + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = blockIdx.x + 4; + int y1 = blockIdx.y + 4; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + + __syncthreads(); + + __shared__ float sum[32]; + + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { 
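+ // Each of the 81 output channels corresponds to one displacement (s2o, s2p) + // in a 9x9 search window around the center pixel; the 32 threads of the block + // accumulate partial dot products over the feature channels, and thread 0 + // reduces them into the cost volume entry below.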
+ sum[ch_off] = 0; + + int s2o = top_channel % 9 - 4; + int s2p = top_channel / 9 - 4; + + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + + __syncthreads(); + + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +''' + +kernel_Correlation_updateGradFirst = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradFirst( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradFirst); // channels + int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos + int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + + // Same here: + int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) + int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) + + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + // Get rbot1 data: + int s2o = o; + int s2p = p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradFirst); + const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); + gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; + } } +''' + +kernel_Correlation_updateGradSecond = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradSecond( + const int n, + const int intSample, + const float* 
rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradSecond); // channels + int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos + int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + float sum = 0; + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + int s2o = o; + int s2p = p; + + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + + // Same here: + int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) + int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) + + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot0tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradSecond); + const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); + gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; + } } +''' + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + # end + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + # end + + return strKernel +# end + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) +# end + +class 
_FunctionCorrelation(torch.autograd.Function): + @staticmethod + def forward(self, first, second): + rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) + rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) + + self.save_for_backward(first, second, rbot0, rbot1) + + assert(first.is_contiguous() == True) + assert(second.is_contiguous() == True) + + output = first.new_zeros([ first.shape[0], 81, first.shape[2], first.shape[3] ]) + + if first.is_cuda == True: + n = first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': first, + 'output': rbot0 + }))( + grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]), + block=tuple([ 16, 1, 1 ]), + args=[ n, first.data_ptr(), rbot0.data_ptr() ] + ) + + n = second.shape[2] * second.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': second, + 'output': rbot1 + }))( + grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]), + block=tuple([ 16, 1, 1 ]), + args=[ n, second.data_ptr(), rbot1.data_ptr() ] + ) + + n = output.shape[1] * output.shape[2] * output.shape[3] + cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'top': output + }))( + grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]), + block=tuple([ 32, 1, 1 ]), + shared_mem=first.shape[1] * 4, + args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ] + ) + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return output + # end + + @staticmethod + def backward(self, gradOutput): + first, second, rbot0, rbot1 = self.saved_tensors + + assert(gradOutput.is_contiguous() == True) + + gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None + gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None + + if first.is_cuda == True: + if gradFirst is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': gradFirst, + 'gradSecond': None + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ] + ) + # end + # end + + if gradSecond is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': None, + 'gradSecond': gradSecond + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ] + ) + # end + # end + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return gradFirst, gradSecond + # end +# end + +def FunctionCorrelation(tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, 
tenSecond) +# end + +class ModuleCorrelation(torch.nn.Module): + def __init__(self): + super(ModuleCorrelation, self).__init__() + # end + + def forward(self, tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) + # end +# end \ No newline at end of file diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/pwc/pwc_net.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/pwc/pwc_net.py new file mode 100644 index 0000000..add26c4 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/pwc/pwc_net.py @@ -0,0 +1,251 @@ +#-*- encoding: UTF-8 -*- + +import torch +import sys +from functools import partial +import pickle + +try: + from pwc.correlation import correlation # the custom cost volume layer +except: + sys.path.insert(0, './correlation'); import correlation # you should consider upgrading python + + +# Borrow the code of the optical flow network (PWC-Net) from https://github.com/sniklaus/pytorch-pwc/ +class PWCNET(torch.nn.Module): + def __init__(self): + super(PWCNET, self).__init__() + + class Extractor(torch.nn.Module): + def __init__(self): + super(Extractor, self).__init__() + + self.netOne = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netTwo = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netThr = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netFou = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netFiv = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netSix = 
torch.nn.Sequential( + torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + def forward(self, tenInput): + tenOne = self.netOne(tenInput) + tenTwo = self.netTwo(tenOne) + tenThr = self.netThr(tenTwo) + tenFou = self.netFou(tenThr) + tenFiv = self.netFiv(tenFou) + tenSix = self.netSix(tenFiv) + + return [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ] + + class Decoder(torch.nn.Module): + def __init__(self, intLevel): + super(Decoder, self).__init__() + + intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, + 81 + 128 + 2 + 2, 81, None ][intLevel + 1] + intCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, + 81 + 128 + 2 + 2, 81, None ][intLevel + 0] + + self.backwarp_tenGrid = {} + self.backwarp_tenPartial = {} + + if intLevel < 6: self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, + out_channels=2, kernel_size=4, stride=2, padding=1) + if intLevel < 6: self.netUpfeat = torch.nn.ConvTranspose2d( + in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, + kernel_size=4, stride=2, padding=1) + if intLevel < 6: self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1] + + self.netOne = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, + kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netTwo = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, + kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netThr = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, + kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netFou = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, + kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netFiv = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, + kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netSix = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, + kernel_size=3, stride=1, padding=1) + ) + + def forward(self, tenFirst, tenSecond, objPrevious): + tenFlow = None + tenFeat = None + + if objPrevious is None: + tenFlow = None + tenFeat = None + tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation( + tenFirst=tenFirst, tenSecond=tenSecond), negative_slope=0.1, inplace=False) + tenFeat = torch.cat([ tenVolume ], 1) + + elif objPrevious is not None: + tenFlow = self.netUpflow(objPrevious['tenFlow']) + tenFeat = self.netUpfeat(objPrevious['tenFeat']) + + tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation( + tenFirst=tenFirst, tenSecond=self.backwarp(tenInput=tenSecond, + tenFlow=tenFlow * self.fltBackwarp)), negative_slope=0.1, inplace=False) + + tenFeat = torch.cat([ tenVolume, tenFirst, tenFlow, 
tenFeat ], 1) + + tenFeat = torch.cat([ self.netOne(tenFeat), tenFeat ], 1) + tenFeat = torch.cat([ self.netTwo(tenFeat), tenFeat ], 1) + tenFeat = torch.cat([ self.netThr(tenFeat), tenFeat ], 1) + tenFeat = torch.cat([ self.netFou(tenFeat), tenFeat ], 1) + tenFeat = torch.cat([ self.netFiv(tenFeat), tenFeat ], 1) + + tenFlow = self.netSix(tenFeat) + + return { + 'tenFlow': tenFlow, + 'tenFeat': tenFeat + } + + def backwarp(self, tenInput, tenFlow): + index = str(tenFlow.shape) + str(tenInput.device) + if index not in self.backwarp_tenGrid: + tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), + tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1) + tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), + tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3]) + self.backwarp_tenGrid[index] = torch.cat([ tenHor, tenVer ], 1).to(tenInput.device) + + if index not in self.backwarp_tenPartial: + self.backwarp_tenPartial[index] = tenFlow.new_ones([ tenFlow.shape[0], + 1, tenFlow.shape[2], tenFlow.shape[3] ]) + + tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) + tenInput = torch.cat([ tenInput, self.backwarp_tenPartial[index] ], 1) + + tenOutput = torch.nn.functional.grid_sample(input=tenInput, + grid=(self.backwarp_tenGrid[index] + tenFlow).permute(0, 2, 3, 1), + mode='bilinear', padding_mode='zeros', align_corners=False) + + tenMask = tenOutput[:, -1:, :, :]; tenMask[tenMask > 0.999] = 1.0; tenMask[tenMask < 1.0] = 0.0 + + return tenOutput[:, :-1, :, :] * tenMask + + class Refiner(torch.nn.Module): + def __init__(self): + super(Refiner, self).__init__() + self.netMain = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, + out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1) + ) + + def forward(self, tenInput): + return self.netMain(tenInput) + + self.netExtractor = Extractor() + + self.netTwo = Decoder(2) + self.netThr = Decoder(3) + self.netFou = Decoder(4) + self.netFiv = Decoder(5) + self.netSix = Decoder(6) + + self.netRefiner = Refiner() + + pickle.load = partial(pickle.load, encoding="latin1") + pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1") + + self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight + in torch.load('./ckpt/pwc-net.pth', map_location=lambda storage, + loc: storage, pickle_module=pickle).items() }) + + def forward(self, tenFirst, 
tenSecond): + tenFirst = self.netExtractor(tenFirst) + tenSecond = self.netExtractor(tenSecond) + objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None) + objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate) + objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate) + objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate) + objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate) + return objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat']) + + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/test_full.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/test_full.py new file mode 100644 index 0000000..237242a --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/test_full.py @@ -0,0 +1,109 @@ +import os +import torch +from options.test_options import TestOptions +from data import create_dataset +from models import create_model +from util.visualizer import Visualizer +from tqdm import tqdm +from util.util import calc_psnr as calc_psnr +import time +import numpy as np +from collections import OrderedDict as odict +from copy import deepcopy +from util.util import pack_rggb_channels +from os.path import join +from tensorboardX import SummaryWriter +import cv2 +from util.util import pack_rggb_channels +import glob + +def save_rgb (img, filename): + if np.max(img) <= 1: + img = img * 255 + + img = img.astype(np.float32) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + cv2.imwrite(filename, img) + +def log(log_file, str, also_print=True): + with open(log_file, 'a+') as F: + F.write(str) + if also_print: + print(str, end='') + +def _open_img(img_p,ratio): + return np.load(img_p, allow_pickle=True).astype(float) / ratio + + +if __name__ == '__main__': + opt = TestOptions().parse() + + if not isinstance(opt.load_iter, list): + load_iters = [opt.load_iter] + else: + load_iters = deepcopy(opt.load_iter) + + if not isinstance(opt.dataset_name, list): + dataset_names = [opt.dataset_name] + else: + dataset_names = deepcopy(opt.dataset_name) + datasets = odict() + for dataset_name in dataset_names: + if opt.visual_full_imgs: + dataset = create_dataset(dataset_name, 'visual', opt) + else: + dataset = create_dataset(dataset_name, 'test', opt) + datasets[dataset_name] = tqdm(dataset) + + + + for load_iter in load_iters: + opt.load_iter = load_iter + model = create_model(opt) + model.setup(opt) + model.eval() + + for dataset_name in dataset_names: + opt.dataset_name = dataset_name + tqdm_val = datasets[dataset_name] + dataset_test = tqdm_val.iterable + dataset_size_test = len(dataset_test) + + print('='*80) + print(dataset_name + ' dataset') + tqdm_val.reset() + + + for i, data in enumerate(tqdm_val): + + model.set_input(data) + model.test() + res = model.get_current_visuals() + recon_raw = res['data_out'][0].detach().permute(1, 2, 0).numpy() + + ratio = 1020 + folder_dir = opt.save_path + PS = opt.test_patch + os.makedirs(folder_dir, exist_ok=True) + + H,W,C= recon_raw.shape + pic_i=0 + avg_ps=0 + for rr in np.arange(0, H - PS + 1, PS): + for cc in np.arange(0, W - PS + 1, PS): + + raw_patch = recon_raw[rr:rr + PS, cc:cc + PS,:] + pic_index=data['fname'][0].split('/')[-1].split('_')[-1].split('.')[0]+'_'+str(pic_i)+'.npy' + raw_patch=pack_rggb_channels(raw_patch) + raw_patch = (raw_patch * ratio).astype(np.uint16) + save_dir = '%s/%s' % (folder_dir, pic_index) + os.makedirs(folder_dir, exist_ok=True) + np.save(save_dir, raw_patch) + pic_i+=1 + + + for dataset in datasets: + datasets[dataset].close() + + diff --git 
a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/test_p20.sh b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/test_p20.sh new file mode 100644 index 0000000..b376ea4 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/test_p20.sh @@ -0,0 +1,14 @@ +#!/bin/bash +echo "Start to test the model...." + +name="p20" +dataroot="/home/work/ssd1/hagongda/lxy/data_p20/test_full" +save_path='./submission' + + +python test_full.py \ +--model naf --name $name --dataset_name p20patch --pre_ispnet_coord False --gcm_coord False \ +--load_iter 82 --batch_size 1 --gpu_ids -1 --save_imgs True --calc_metrics True --visual_full_imgs False -j 3 \ +--dataroot $dataroot --save_path $save_path + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/train.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/train.py new file mode 100644 index 0000000..9c8684c --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/train.py @@ -0,0 +1,126 @@ +import time +import torch +from options.train_options import TrainOptions +from data import create_dataset +from models import create_model +from util.visualizer import Visualizer +from tqdm import tqdm +import numpy as np +import math +import sys +import torch.multiprocessing as mp + +from util.util import calc_psnr as calc_psnr +from util.AISP_utils import demosaic, postprocess_raw, plot_pair +from util.util import pack_rggb_channels +import os +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" +import cv2 + + +# from skimage.exposure import match_histograms + +def save_rgb(img, filename): + if np.max(img) <= 1: + img = img * 255 + + img = img.astype(np.float32) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + cv2.imwrite(filename, img) + + +if __name__ == '__main__': + opt = TrainOptions().parse() + dataset_train = create_dataset(opt.dataset_name, 'train', opt) + dataset_size_train = len(dataset_train) + print('The number of training images = %d' % dataset_size_train) + dataset_val = create_dataset(opt.dataset_name, 'val', opt) + dataset_size_val = len(dataset_val) + print('The number of val images = %d' % dataset_size_val) + + model = create_model(opt) + model.setup(opt) + visualizer = Visualizer(opt) + total_iters = ((model.start_epoch * (dataset_size_train // opt.batch_size)) \ + // opt.print_freq) * opt.print_freq + + for epoch in range(model.start_epoch + 1, opt.niter + opt.niter_decay + 1): + # training + epoch_start_time = time.time() + epoch_iter = 0 + model.train() + + iter_data_time = iter_start_time = time.time() + for i, data in enumerate(dataset_train): + if total_iters % opt.print_freq == 0: + t_data = time.time() - iter_data_time + total_iters += 1 + epoch_iter += 1 + model.set_input(data) + model.optimize_parameters() + res = model.get_current_visuals() + + if opt.save_imgs: + + psnr_train = calc_psnr(data['raw'], res['data_out'].detach().cpu()) + print(data['fname'][0], psnr_train) + res = model.get_current_visuals() + folder_dir = './ckpt/%s/output_train' % (opt.name); + os.makedirs(folder_dir, exist_ok=True) + + save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_dslr') + dslr = res['data_dslr'][0].cpu().permute(1, 2, 0).numpy(); + save_rgb(dslr, save_dir) + + save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_GCMModel_out_warp_all') + dslr = res['GCMModel_out_warp'][0].cpu().permute(1, 2, 0).numpy(); + save_rgb(dslr, save_dir) + + save_dir = '%s/%s.jpg' % (folder_dir, 
os.path.basename(data['fname'][0]).split('.')[0] + '_GCMModel_out_warp_test') + dslr = res['mask_test'][0].cpu().permute(1, 2, 0).numpy(); + save_rgb(dslr, save_dir) + + + if total_iters % opt.print_freq == 0: + losses = model.get_current_losses() + t_comp = (time.time() - iter_start_time) + visualizer.print_current_losses( + epoch, epoch_iter, losses, t_comp, t_data, total_iters) + iter_start_time = time.time() + + iter_data_time = time.time() + if epoch % opt.save_epoch_freq == 0: + print('saving the model at the end of epoch %d, iters %d' + % (epoch, total_iters)) + model.save_networks(epoch) + + print('End of epoch %d / %d \t Time Taken: %.3f sec' + % (epoch, opt.niter + opt.niter_decay, + time.time() - epoch_start_time)) + model.update_learning_rate() + + # val + if opt.calc_metrics: + model.eval() + val_iter_time = time.time() + tqdm_val = tqdm(dataset_val) + psnr = [0.0] * dataset_size_val + time_val = 0 + for i, data in enumerate(tqdm_val): + model.set_input(data) + time_val_start = time.time() + with torch.no_grad(): + model.test() + time_val += time.time() - time_val_start + res = model.get_current_visuals() + psnr[i] = calc_psnr(res['data_raw'].detach().cpu(), res['data_out'].detach().cpu()) + visualizer.print_psnr(epoch, opt.niter + opt.niter_decay, time_val, np.mean(psnr)) + + sys.stdout.flush() + + + + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/train_p20.sh b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/train_p20.sh new file mode 100644 index 0000000..11104f2 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/train_p20.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +echo "Start to train the model...." + +name="p20" +dataroot="/home/work/ssd1/hagongda/lxy/data_p20" + +build_dir="./ckpt/"$name + +if [ ! -d "$build_dir" ]; then + mkdir $build_dir +fi + +LOG=./ckpt/$name/`date +%Y-%m-%d-%H-%M-%S`.txt + + +python train.py \ + --dataset_name p20patch --model s7naf2 --name $name --gcm_coord False \ + --pre_ispnet_coord False --niter 200 --lr_policy cosine --save_imgs False \ + --batch_size 4 --print_freq 300 --calc_metrics True --lr 3e-4 -j 8 \ + --weight_decay 0.001 --patch_size 704 --load_iter 99 --load_optimizers True \ + --dataroot $dataroot | tee $LOG + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/AISP_utils.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/AISP_utils.py new file mode 100644 index 0000000..783d2dd --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/AISP_utils.py @@ -0,0 +1,191 @@ +import cv2 +import numpy as np +import rawpy +import matplotlib.pyplot as plt +import imageio + + +def extract_bayer_channels(raw): + + ch_B = raw[1::2, 1::2] + ch_Gb = raw[0::2, 1::2] + ch_R = raw[0::2, 0::2] + ch_Gr = raw[1::2, 0::2] + + return ch_R, ch_Gr, ch_B, ch_Gb + +def load_rawpy (raw_file): + raw = rawpy.imread(raw_file) + raw_image = raw.raw_image + return raw_image + +def load_img (filename, debug=False, norm=True, resize=None): + img = cv2.imread(filename) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if norm: + img = img / 255. 
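+ # 8-bit intensities are now in [0, 1]; the cast below keeps the array float32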
+ img = img.astype(np.float32) + if debug: + print (img.shape, img.dtype, img.min(), img.max()) + + if resize: + img = cv2.resize(img, (resize[0], resize[1]), interpolation = cv2.INTER_AREA) + + return img + +def save_rgb (img, filename): + if np.max(img) <= 1: + img = img * 255 + + img = img.astype(np.float32) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + cv2.imwrite(filename, img) + +def load_raw_png(raw, debug=False): + ''' + Load RAW images from the ZurichRAW2RGB Dataset + Reference: https://github.com/aiff22/PyNET-PyTorch/blob/master/dng_to_png.py + by Andrey Ignatov. + + inputs: + - raw: filename to the raw image saved as '.png' + returns: + - RAW_norm: normalized float32 4-channel raw image with bayer pattern RGGB. + ''' + + assert '.png' in raw + raw = np.asarray(imageio.imread((raw))) + ch_R, ch_Gr, ch_B, ch_Gb = extract_bayer_channels (raw) + + RAW_combined = np.dstack((ch_R, ch_Gr, ch_Gb, ch_B)) + RAW_norm = RAW_combined.astype(np.float32) / (4 * 255) + RAW_norm = np.clip(RAW_norm, 0, 1) + + if debug: + print (RAW_norm.shape, RAW_norm.dtype, RAW_norm.min(), RAW_norm.max()) + + # raw as (h,w,1) in RGBG domain! do not use + raw_unpack = raw.astype(np.float32) / (4 * 255) + raw_unpack = np.expand_dims(raw_unpack, axis=-1) + + return RAW_norm + +# def load_raw(raw, max_val=2**10): +# raw = np.load (raw)/ max_val +# return raw.astype(np.float32) + + +########## RAW image manipulation + +def unpack_raw(im): + """ + Unpack RAW image from (h,w,4) to (h*2 , w*2, 1) + """ + h,w,chan = im.shape + H, W = h*2, w*2 + img2 = np.zeros((H,W)) + img2[0:H:2,0:W:2]=im[:,:,0] + img2[0:H:2,1:W:2]=im[:,:,1] + img2[1:H:2,0:W:2]=im[:,:,2] + img2[1:H:2,1:W:2]=im[:,:,3] + img2 = np.squeeze(img2) + img2 = np.expand_dims(img2, axis=-1) + return img2 + +def pack_raw(im): + """ + Pack RAW image from (h,w,1) to (h/2 , w/2, 4) + """ + img_shape = im.shape + H = img_shape[0] + W = img_shape[1] + ## R G G B + out = np.concatenate((im[0:H:2,0:W:2,:], + im[0:H:2,1:W:2,:], + im[1:H:2,0:W:2,:], + im[1:H:2,1:W:2,:]), axis=2) + return out + + + +########## VISUALIZATION + +def demosaic (raw): + """Simple demosaicing to visualize RAW images + Inputs: + - raw: (h,w,4) RAW RGGB image normalized [0..1] as float32 + Returns: + - Simple Avg. 
Green Demosaiced RAW image with shape (h*2, w*2, 3)
+    """
+
+    assert raw.shape[-1] == 4
+    shape = raw.shape
+
+    red = raw[:,:,0]
+    green_red = raw[:,:,1]
+    green_blue = raw[:,:,2]
+    blue = raw[:,:,3]
+    avg_green = (green_red + green_blue) / 2
+    image = np.stack((red, avg_green, blue), axis=-1)
+    image = cv2.resize(image, (shape[1]*2, shape[0]*2))
+    return image
+
+
+def mosaic(rgb):
+    """Extracts RGGB Bayer planes from an RGB image."""
+
+    assert rgb.shape[-1] == 3
+    shape = rgb.shape
+
+    red = rgb[0::2, 0::2, 0]
+    green_red = rgb[0::2, 1::2, 1]
+    green_blue = rgb[1::2, 0::2, 1]
+    blue = rgb[1::2, 1::2, 2]
+
+    image = np.stack((red, green_red, green_blue, blue), axis=-1)
+    return image
+
+
+def gamma_compression(image):
+    """Converts from linear to gamma space."""
+    return np.maximum(image, 1e-8) ** (1.0 / 2.2)
+
+def tonemap(image):
+    """Simple S-curved global tonemap"""
+    return (3*(image**2)) - (2*(image**3))
+
+def postprocess_raw(raw):
+    """Simple post-processing to visualize demosaiced RAW images
+    Input: (h,w,3) RAW image normalized
+    Output: (h,w,3) post-processed RAW image
+    """
+    raw = gamma_compression(raw)
+    raw = tonemap(raw)
+    raw = np.clip(raw, 0, 1)
+    return raw
+
+def plot_pair(rgb, raw, t1='RGB', t2='RAW', axis='off'):
+
+    fig = plt.figure(figsize=(12, 6), dpi=80)
+    plt.subplot(1,2,1)
+    plt.title(t1)
+    plt.axis(axis)
+    plt.imshow(rgb)
+
+    plt.subplot(1,2,2)
+    plt.title(t2)
+    plt.axis(axis)
+    plt.imshow(raw)
+    plt.show()
+
+########## METRICS
+
+def PSNR(y_true, y_pred):
+    mse = np.mean((y_true - y_pred) ** 2)
+    if mse == 0:
+        return np.inf
+
+    max_pixel = np.max(y_true)
+    psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
+    return psnr
\ No newline at end of file
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/__init__.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/__init__.py
new file mode 100644
index 0000000..e2f595e
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/__init__.py
@@ -0,0 +1 @@
+"""This package includes a miscellaneous collection of helper functions."""
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/myssim.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/myssim.py
new file mode 100644
index 0000000..eb5d9db
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/myssim.py
@@ -0,0 +1,371 @@
+from __future__ import division, absolute_import, print_function
+
+
+
+import numpy as np
+# from numpy.lib.arraypad import _validate_lengths
+from scipy.ndimage import uniform_filter, gaussian_filter
+
+dtype_range = {np.bool_: (False, True),
+               np.bool8: (False, True),
+               np.uint8: (0, 255),
+               np.uint16: (0, 65535),
+               np.uint32: (0, 2**32 - 1),
+               np.uint64: (0, 2**64 - 1),
+               np.int8: (-128, 127),
+               np.int16: (-32768, 32767),
+               np.int32: (-2**31, 2**31 - 1),
+               np.int64: (-2**63, 2**63 - 1),
+               np.float16: (-1, 1),
+               np.float32: (-1, 1),
+               np.float64: (-1, 1)}
+
+def _normalize_shape(ndarray, shape, cast_to_int=True):
+    """
+    Private function which does some checks and normalizes the possibly
+    much simpler representations of 'pad_width', 'stat_length',
+    'constant_values', 'end_values'.
+
+    Parameters
+    ----------
+    narray : ndarray
+        Input ndarray
+    shape : {sequence, array_like, float, int}, optional
+        The width of padding (pad_width), the number of elements on the
+        edge of the narray used for statistics (stat_length), the constant
+        value(s) to use when filling padded regions (constant_values), or the
+        endpoint target(s) for linear ramps (end_values).
+        ((before_1, after_1), ...
(before_N, after_N)) unique number of + elements for each axis where `N` is rank of `narray`. + ((before, after),) yields same before and after constants for each + axis. + (constant,) or val is a shortcut for before = after = constant for + all axes. + cast_to_int : bool, optional + Controls if values in ``shape`` will be rounded and cast to int + before being returned. + + Returns + ------- + normalized_shape : tuple of tuples + val => ((val, val), (val, val), ...) + [[val1, val2], [val3, val4], ...] => ((val1, val2), (val3, val4), ...) + ((val1, val2), (val3, val4), ...) => no change + [[val1, val2], ] => ((val1, val2), (val1, val2), ...) + ((val1, val2), ) => ((val1, val2), (val1, val2), ...) + [[val , ], ] => ((val, val), (val, val), ...) + ((val , ), ) => ((val, val), (val, val), ...) + + """ + ndims = ndarray.ndim + + # Shortcut shape=None + if shape is None: + return ((None, None), ) * ndims + + # Convert any input `info` to a NumPy array + shape_arr = np.asarray(shape) + + try: + shape_arr = np.broadcast_to(shape_arr, (ndims, 2)) + except ValueError: + fmt = "Unable to create correctly shaped tuple from %s" + raise ValueError(fmt % (shape,)) + + # Cast if necessary + if cast_to_int is True: + shape_arr = np.round(shape_arr).astype(int) + + # Convert list of lists to tuple of tuples + return tuple(tuple(axis) for axis in shape_arr.tolist()) + + +def _validate_lengths(narray, number_elements): + """ + Private function which does some checks and reformats pad_width and + stat_length using _normalize_shape. + + Parameters + ---------- + narray : ndarray + Input ndarray + number_elements : {sequence, int}, optional + The width of padding (pad_width) or the number of elements on the edge + of the narray used for statistics (stat_length). + ((before_1, after_1), ... (before_N, after_N)) unique number of + elements for each axis. + ((before, after),) yields same before and after constants for each + axis. + (constant,) or int is a shortcut for before = after = constant for all + axes. + + Returns + ------- + _validate_lengths : tuple of tuples + int => ((int, int), (int, int), ...) + [[int1, int2], [int3, int4], ...] => ((int1, int2), (int3, int4), ...) + ((int1, int2), (int3, int4), ...) => no change + [[int1, int2], ] => ((int1, int2), (int1, int2), ...) + ((int1, int2), ) => ((int1, int2), (int1, int2), ...) + [[int , ], ] => ((int, int), (int, int), ...) + ((int , ), ) => ((int, int), (int, int), ...) + + """ + normshp = _normalize_shape(narray, number_elements) + for i in normshp: + chk = [1 if x is None else x for x in i] + chk = [1 if x >= 0 else -1 for x in chk] + if (chk[0] < 0) or (chk[1] < 0): + fmt = "%s cannot contain negative values." + raise ValueError(fmt % (number_elements,)) + return normshp + + + +def crop(ar, crop_width, copy=False, order='K'): + """Crop array `ar` by `crop_width` along each dimension. + Parameters + ---------- + ar : array-like of rank N + Input array. + crop_width : {sequence, int} + Number of values to remove from the edges of each axis. + ``((before_1, after_1),`` ... ``(before_N, after_N))`` specifies + unique crop widths at the start and end of each axis. + ``((before, after),)`` specifies a fixed start and end crop + for every axis. + ``(n,)`` or ``n`` for integer ``n`` is a shortcut for + before = after = ``n`` for all axes. + copy : bool, optional + If `True`, ensure the returned array is a contiguous copy. Normally, + a crop operation will return a discontiguous view of the underlying + input array. 
+ order : {'C', 'F', 'A', 'K'}, optional + If ``copy==True``, control the memory layout of the copy. See + ``np.copy``. + Returns + ------- + cropped : array + The cropped array. If ``copy=False`` (default), this is a sliced + view of the input array. + """ + ar = np.array(ar, copy=False) + crops = _validate_lengths(ar, crop_width) + slices = [slice(a, ar.shape[i] - b) for i, (a, b) in enumerate(crops)] + if copy: + cropped = np.array(ar[slices], order=order, copy=True) + else: + cropped = ar[slices] + return cropped + +def compare_ssim(X, Y, win_size=None, gradient=False, + data_range=1, multichannel=False, gaussian_weights=False, + full=False, dynamic_range=None, **kwargs): + """Compute the mean structural similarity index between two images. + Parameters + ---------- + X, Y : ndarray + Image. Any dimensionality. + win_size : int or None + The side-length of the sliding window used in comparison. Must be an + odd value. If `gaussian_weights` is True, this is ignored and the + window size will depend on `sigma`. + gradient : bool, optional + If True, also return the gradient. + data_range : int, optional + The data range of the input image (distance between minimum and + maximum possible values). By default, this is estimated from the image + data-type. + multichannel : bool, optional + If True, treat the last dimension of the array as channels. Similarity + calculations are done independently for each channel then averaged. + gaussian_weights : bool, optional + If True, each patch has its mean and variance spatially weighted by a + normalized Gaussian kernel of width sigma=1.5. + full : bool, optional + If True, return the full structural similarity image instead of the + mean value. + Other Parameters + ---------------- + use_sample_covariance : bool + if True, normalize covariances by N-1 rather than, N where N is the + number of pixels within the sliding window. + K1 : float + algorithm parameter, K1 (small constant, see [1]_) + K2 : float + algorithm parameter, K2 (small constant, see [1]_) + sigma : float + sigma for the Gaussian when `gaussian_weights` is True. + Returns + ------- + mssim : float + The mean structural similarity over the image. + grad : ndarray + The gradient of the structural similarity index between X and Y [2]_. + This is only returned if `gradient` is set to True. + S : ndarray + The full SSIM image. This is only returned if `full` is set to True. + Notes + ----- + To match the implementation of Wang et. al. [1]_, set `gaussian_weights` + to True, `sigma` to 1.5, and `use_sample_covariance` to False. + References + ---------- + .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. + (2004). Image quality assessment: From error visibility to + structural similarity. IEEE Transactions on Image Processing, + 13, 600-612. + https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf, + DOI:10.1.1.11.2477 + .. [2] Avanaki, A. N. (2009). Exact global histogram specification + optimized for structural similarity. Optical Review, 16, 613-621. + http://arxiv.org/abs/0901.0065, + DOI:10.1007/s10043-009-0119-z + """ + if not X.dtype == Y.dtype: + raise ValueError('Input images must have the same dtype.') + + if not X.shape == Y.shape: + raise ValueError('Input images must have the same dimensions.') + + if dynamic_range is not None: + #warn('`dynamic_range` has been deprecated in favor of ' + # '`data_range`. 
The `dynamic_range` keyword argument ' + # 'will be removed in v0.14', skimage_deprecation) + data_range = dynamic_range + + if multichannel: + # loop over channels + args = dict(win_size=win_size, + gradient=gradient, + data_range=data_range, + multichannel=False, + gaussian_weights=gaussian_weights, + full=full) + args.update(kwargs) + nch = X.shape[-1] + mssim = np.empty(nch) + if gradient: + G = np.empty(X.shape) + if full: + S = np.empty(X.shape) + for ch in range(nch): + ch_result = compare_ssim(X[..., ch], Y[..., ch], **args) + if gradient and full: + mssim[..., ch], G[..., ch], S[..., ch] = ch_result + elif gradient: + mssim[..., ch], G[..., ch] = ch_result + elif full: + mssim[..., ch], S[..., ch] = ch_result + else: + mssim[..., ch] = ch_result + mssim = mssim.mean() + if gradient and full: + return mssim, G, S + elif gradient: + return mssim, G + elif full: + return mssim, S + else: + return mssim + + K1 = kwargs.pop('K1', 0.01) + K2 = kwargs.pop('K2', 0.03) + sigma = kwargs.pop('sigma', 1.5) + if K1 < 0: + raise ValueError("K1 must be positive") + if K2 < 0: + raise ValueError("K2 must be positive") + if sigma < 0: + raise ValueError("sigma must be positive") + use_sample_covariance = kwargs.pop('use_sample_covariance', True) + + if win_size is None: + if gaussian_weights: + win_size = 11 # 11 to match Wang et. al. 2004 + else: + win_size = 7 # backwards compatibility + + if np.any((np.asarray(X.shape) - win_size) < 0): + raise ValueError( + "win_size exceeds image extent. If the input is a multichannel " + "(color) image, set multichannel=True.") + + if not (win_size % 2 == 1): + raise ValueError('Window size must be odd.') + + if data_range is None: + dmin, dmax = dtype_range[X.dtype.type] + data_range = dmax - dmin + + ndim = X.ndim + + if gaussian_weights: + # sigma = 1.5 to approximately match filter in Wang et. al. 2004 + # this ends up giving a 13-tap rather than 11-tap Gaussian + filter_func = gaussian_filter + filter_args = {'sigma': sigma} + + else: + filter_func = uniform_filter + filter_args = {'size': win_size} + + # ndimage filters need floating point data + X = X.astype(np.float64) + Y = Y.astype(np.float64) + + NP = win_size ** ndim + + # filter has already normalized by NP + if use_sample_covariance: + cov_norm = NP / (NP - 1) # sample covariance + else: + cov_norm = 1.0 # population covariance to match Wang et. al. 2004 + + # compute (weighted) means + ux = filter_func(X, **filter_args) + uy = filter_func(Y, **filter_args) + + # compute (weighted) variances and covariances + uxx = filter_func(X * X, **filter_args) + uyy = filter_func(Y * Y, **filter_args) + uxy = filter_func(X * Y, **filter_args) + vx = cov_norm * (uxx - ux * ux) + vy = cov_norm * (uyy - uy * uy) + vxy = cov_norm * (uxy - ux * uy) + + R = data_range + C1 = (K1 * R) ** 2 + C2 = (K2 * R) ** 2 + + A1, A2, B1, B2 = ((2 * ux * uy + C1, + 2 * vxy + C2, + ux ** 2 + uy ** 2 + C1, + vx + vy + C2)) + D = B1 * B2 + S = (A1 * A2) / D + + # to avoid edge effects will ignore filter radius strip around edges + pad = (win_size - 1) // 2 + + # compute (weighted) mean of ssim + mssim = crop(S, pad).mean() + + if gradient: + # The following is Eqs. 7-8 of Avanaki 2009. 
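+        # Here A1 = 2*ux*uy + C1 and A2 = 2*vxy + C2 are the SSIM numerator
+        # terms, B1 = ux**2 + uy**2 + C1 and B2 = vx + vy + C2 the
+        # denominators, with D = B1*B2 and S = (A1*A2)/D. Each filter_func
+        # call locally averages a weight map before it multiplies X or Y, and
+        # the final 2/X.size factor comes from differentiating the mean SSIM
+        # over all pixels.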
+            grad = filter_func(A1 / D, **filter_args) * X
+            grad += filter_func(-S / B2, **filter_args) * Y
+            grad += filter_func((ux * (A2 - A1) - uy * (B2 - B1) * S) / D,
+                                **filter_args)
+            grad *= (2 / X.size)
+
+            if full:
+                return mssim, grad, S
+            else:
+                return mssim, grad
+        else:
+            if full:
+                return mssim, S
+            else:
+                return mssim
+
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/util.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/util.py
new file mode 100644
index 0000000..fbab32e
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/util.py
@@ -0,0 +1,237 @@
+"""This module contains simple helper functions """
+from __future__ import print_function
+import torch
+import numpy as np
+from PIL import Image
+import os
+import time
+from functools import wraps
+import torch
+import random
+import numpy as np
+import cv2
+import torch
+import colour_demosaicing
+import glob
+
+# Decorator that retries the wrapped function up to 600 times, sleeping 1 s between attempts.
+# It wraps func itself; a drawback is that func's own error output is not visible meanwhile.
+def loop_until_success(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        for i in range(600):
+            try:
+                ret = func(*args, **kwargs)
+                break
+            except OSError:
+                time.sleep(1)
+        return ret
+    return wrapper
+
+# Retry-wrapped versions of print and torch.save, built with the decorator above
+@loop_until_success
+def loop_print(*args, **kwargs):
+    print(*args, **kwargs)
+
+@loop_until_success
+def torch_save(*args, **kwargs):
+    torch.save(*args, **kwargs)
+
+def calc_psnr(sr, hr, range=1.):
+    # shave = 2
+    with torch.no_grad():
+        diff = (sr - hr) / range
+        # diff = diff[:, :, shave:-shave, shave:-shave]
+        mse = torch.pow(diff, 2).mean()
+        return (-10 * torch.log10(mse)).item()
+
+
+def diagnose_network(net, name='network'):
+    """Calculate and print the mean of average absolute gradients
+
+    Parameters:
+        net (torch network) -- Torch network
+        name (str) -- the name of the network
+    """
+    mean = 0.0
+    count = 0
+    for param in net.parameters():
+        if param.grad is not None:
+            mean += torch.mean(torch.abs(param.grad.data))
+            count += 1
+    if count > 0:
+        mean = mean / count
+    print(name)
+    print(mean)
+
+def print_numpy(x, val=True, shp=True):
+    """Print the mean, min, max, median, std, and size of a numpy array
+
+    Parameters:
+        val (bool) -- if print the values of the numpy array
+        shp (bool) -- if print the shape of the numpy array
+    """
+    x = x.astype(np.float64)
+    if shp:
+        print('shape,', x.shape)
+    if val:
+        x = x.flatten()
+        print('mean = %3.3f, min = %3.3f, max = %3.3f, mid = %3.3f, std=%3.3f'
+              % (np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+def mkdirs(paths):
+    """create empty directories if they don't exist
+
+    Parameters:
+        paths (str list) -- a list of directory paths
+    """
+    if isinstance(paths, list) and not isinstance(paths, str):
+        for path in paths:
+            mkdir(path)
+    else:
+        mkdir(paths)
+
+
+
+
+def get_coord(H, W, x=448/3968, y=448/2976):
+    x_coord = np.linspace(-x + (x / W), x - (x / W), W)
+    x_coord = np.expand_dims(x_coord, axis=0)
+    x_coord = np.tile(x_coord, (H, 1))
+    x_coord = np.expand_dims(x_coord, axis=0)
+
+    y_coord = np.linspace(-y + (y / H), y - (y / H), H)
+    y_coord = np.expand_dims(y_coord, axis=1)
+    y_coord = np.tile(y_coord, (1, W))
+    y_coord = np.expand_dims(y_coord, axis=0)
+
+    coord = np.ascontiguousarray(np.concatenate([x_coord, y_coord]))
+    coord = np.float32(coord)
+
+    return coord
+
+
+def mkdir(path):
+    """create a single empty directory if it didn't exist
+
+    Parameters:
+        path (str) -- a single directory path
+    """
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+def prompt(s, width=66):
+    print('='*(width+4))
+
ss = s.split('\n') + if len(ss) == 1 and len(s) <= width: + print('= ' + s.center(width) + ' =') + else: + for s in ss: + for i in split_str(s, width): + print('= ' + i.ljust(width) + ' =') + print('='*(width+4)) + +def split_str(s, width): + ss = [] + while len(s) > width: + idx = s.rfind(' ', 0, width+1) + if idx > width >> 1: + ss.append(s[:idx]) + s = s[idx+1:] + else: + ss.append(s[:width]) + s = s[width:] + if s.strip() != '': + ss.append(s) + return ss + + +def load_img(filename, debug=False, norm=True, resize=False): + img = cv2.imread(filename) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if norm: + img = (img) / 255. + img = img.astype(np.float32) + if debug: + print(img.shape, img.dtype, img.min(), img.max()) + + if resize: + img = cv2.resize(img, (resize[0], resize[1]), interpolation=cv2.INTER_AREA) + + return img + +def augment_func(img, hflip, vflip, rot90): # CxHxW + if hflip: img = img[:, :, ::-1] + if vflip: img = img[:, ::-1, :] + if rot90: img = img.transpose(0, 2, 1) + return np.ascontiguousarray(img) + +def augment(*imgs): # CxHxW + hflip = random.random() < 0.5 + vflip = random.random() < 0.5 + rot90 = random.random() < 0.5 + return (augment_func(img, hflip, vflip, rot90) for img in imgs) + +def remove_black_level(img, black_lv=0, white_lv=2**10): + img = np.maximum(img.astype(np.float32)-black_lv, 0) / (white_lv-black_lv) + return img + + + +def extract_bayer_channels(raw): # HxWx4 + ch_R = raw[:,:,0] + ch_Gb = raw[:,:,1] + ch_Gr = raw[:,:,2] + ch_B = raw[:,:,3] + raw_combined = np.dstack((ch_B, ch_Gb, ch_R, ch_Gr)); + # raw_combined = raw + raw_combined = np.ascontiguousarray(raw_combined.transpose((2, 0, 1))); + return raw_combined # 4xHxW + +def extract_bayer_channels_rggb(raw): # HxWx4 + raw_combined = np.ascontiguousarray(raw.transpose((2, 0, 1))); + return raw_combined # 4xHxW + +def pack_rggb_channels(raw): # HxWx4 + ch_B = raw[:,:,0] + ch_Gb = raw[:,:,1] + ch_R = raw[:,:,2] + ch_Gr = raw[:,:,3] + raw_combined = np.dstack((ch_R, ch_Gb, ch_Gr, ch_B)); + raw_combined = np.ascontiguousarray(raw_combined); + return raw_combined # HxWx4 + +def RGGB2Bayer(im):# H//2xW//2x4 + # convert RGGB stacked image to one channel Bayer + bayer = np.zeros((im.shape[0] * 2, im.shape[1] * 2)) + bayer[0::2, 0::2] = im[:, :, 0] + bayer[0::2, 1::2] = im[:, :, 1] + bayer[1::2, 0::2] = im[:, :, 2] + bayer[1::2, 1::2] = im[:, :, 3] + return bayer# HxWx1 + +def get_raw_demosaic(raw, pattern='RGGB'): # HxW + raw=RGGB2Bayer(raw) + raw_demosaic = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, pattern=pattern) + raw_demosaic = np.ascontiguousarray(raw_demosaic.astype(np.float32).transpose((2, 0, 1))) +# raw_demosaic = np.ascontiguousarray(raw_demosaic.astype(np.float32)) + return raw_demosaic # 3xHxW + + +def read_wb(txtfile, key): + wb = np.zeros((1,4)) + with open(txtfile) as f: + for l in f: + if key in l: + for i in range(wb.shape[0]): + nextline = next(f) + try: + wb[i,:] = nextline.split() + except: + print("WB error XXXXXXX") + print(txtfile) + wb = wb.astype(np.float32) + return wb + + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/visualizer.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/visualizer.py new file mode 100644 index 0000000..0b65393 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/util/visualizer.py @@ -0,0 +1,62 @@ +import numpy as np +from os.path import join +from tensorboardX import SummaryWriter +from matplotlib import pyplot as plt +from io import BytesIO +from PIL import Image +from functools import partial 
+from functools import wraps +import time + +def write_until_success(func): + @wraps(func) + def wrapper(*args, **kwargs): + for i in range(30): + try: + ret = func(*args, **kwargs) + break + except OSError: + print('%s OSError' % str(args)) + time.sleep(1) + return ret + return wrapper + +class Visualizer(): + def __init__(self, opt): + self.opt = opt + if opt.isTrain: + self.name = opt.name + self.save_dir = join(opt.checkpoints_dir, opt.name, 'log') + self.writer = SummaryWriter(logdir=join(self.save_dir)) + else: + self.name = '%s_%s_%d' % ( + opt.name, opt.dataset_name, opt.load_iter) + self.save_dir = join(opt.checkpoints_dir, opt.name) + if opt.save_imgs: + self.writer = SummaryWriter(logdir=join( + self.save_dir, 'ckpts', self.name)) + + @write_until_success + def display_current_results(self, phase, visuals, iters): + for k, v in visuals.items(): + v = v.cpu() + self.writer.add_image('%s/%s'%(phase, k), v[0]/255, iters) + self.writer.flush() + + @write_until_success + def print_current_losses(self, epoch, iters, losses, + t_comp, t_data, total_iters): + message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' \ + % (epoch, iters, t_comp, t_data) + for k, v in losses.items(): + message += '%s: %.4e ' % (k, v) + self.writer.add_scalar('loss/%s'%k, v, total_iters) + print(message) + + @write_until_success + def print_psnr(self, epoch, total_epoch, time_val, mean_psnr): + self.writer.add_scalar('val/psnr', mean_psnr, epoch) + print('End of epoch %d / %d (Val) \t Time Taken: %.3f s \t PSNR: %f' + % (epoch, total_epoch, time_val, mean_psnr)) + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/__init__.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/__init__.py new file mode 100644 index 0000000..4559e6f --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/__init__.py @@ -0,0 +1,57 @@ +import importlib +import torch.utils.data +from data.base_dataset import BaseDataset + + +def find_dataset_using_name(dataset_name, split='train'): + dataset_filename = "data." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() \ + and issubclass(cls, BaseDataset): + dataset = cls + + if dataset is None: + raise NotImplementedError("In %s.py, there should be a subclass of " + "BaseDataset with class name that matches %s in " + "lowercase." 
% (dataset_filename, target_dataset_name))
+    return dataset
+
+
+def create_dataset(dataset_name, split, opt):
+    data_loader = CustomDatasetDataLoader(dataset_name, split, opt)
+    dataset = data_loader.load_data()
+    return dataset
+
+
+class CustomDatasetDataLoader():
+    def __init__(self, dataset_name, split, opt):
+        self.opt = opt
+        dataset_class = find_dataset_using_name(dataset_name, split)
+        self.dataset = dataset_class(opt, split, dataset_name)
+#        self.imio = self.dataset.imio
+        print("dataset [%s(%s)] created" % (dataset_name, split))
+        self.dataloader = torch.utils.data.DataLoader(
+            self.dataset,
+            batch_size=opt.batch_size if split=='train' else 1,
+            shuffle=opt.shuffle and split=='train',
+            num_workers=int(opt.num_dataloader),
+            drop_last=opt.drop_last)
+
+    def load_data(self):
+        return self
+
+    def __len__(self):
+        """Return the number of data in the dataset"""
+        return min(len(self.dataset), self.opt.max_dataset_size)
+
+    def __iter__(self):
+        """Return a batch of data"""
+        for i, data in enumerate(self.dataloader):
+            if i * self.opt.batch_size >= self.opt.max_dataset_size:
+                break
+            yield data
+
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/base_dataset.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/base_dataset.py
new file mode 100644
index 0000000..eadad20
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/base_dataset.py
@@ -0,0 +1,19 @@
+import torch.utils.data as data
+from abc import ABC, abstractmethod
+
+
+class BaseDataset(data.Dataset, ABC):
+    def __init__(self, opt, split, dataset_name):
+        self.opt = opt
+        self.split = split
+        self.root = opt.dataroot
+        self.dataset_name = dataset_name.lower()
+
+    @abstractmethod
+    def __len__(self):
+        return 0
+
+    @abstractmethod
+    def __getitem__(self, index):
+        pass
+
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/imlib.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/imlib.py
new file mode 100644
index 0000000..b06539c
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/imlib.py
@@ -0,0 +1,279 @@
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import os
+import cv2
+from PIL import Image
+from functools import wraps
+import time
+
+
+class imlib():
+    """
+    Note that YCxCx in OpenCV and PIL are different.
+    Therefore, be careful if a model is trained with OpenCV and tested with
+    PIL in Y mode, and vice versa.
+
+    force_color = True: return a 3 channel YCxCx image
+        For mode 'Y', if a gray image is given, repeat the channel 3 times,
+        and then convert to YCxCx mode.
+    force_color = False: return a 3 channel YCxCx image or a 1 channel gray one.
+        For mode 'Y', if a gray image is given, the gray image is directly used.
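+
+    A minimal usage sketch (illustrative; the file paths are placeholders):
+        im = imlib('RGB', fmt='CHW', lib='cv2')
+        img = im.read('example.png')   # uint8 ndarray in CHW layout
+        im.write(img, 'out.png')       # transposed back to HWC before writing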
+ """ + def __init__(self, mode='RGB', fmt='CHW', lib='cv2', force_color=True): + assert mode.upper() in ('RGB', 'L', 'Y', 'RAW') + self.mode = mode.upper() + + assert fmt.upper() in ('HWC', 'CHW', 'NHWC', 'NCHW') + self.fmt = 'CHW' if fmt.upper() in ('CHW', 'NCHW') else 'HWC' + + assert lib.lower() in ('cv2', 'pillow') + self.lib = lib.lower() + + self.force_color = force_color + + self.dtype = np.uint8 + + self._imread = getattr(self, '_imread_%s_%s'%(self.lib, self.mode)) + self._imwrite = getattr(self, '_imwrite_%s_%s'%(self.lib, self.mode)) + self._trans_batch = getattr(self, '_trans_batch_%s_%s' + % (self.mode, self.fmt)) + self._trans_image = getattr(self, '_trans_image_%s_%s' + % (self.mode, self.fmt)) + self._trans_back = getattr(self, '_trans_back_%s_%s' + % (self.mode, self.fmt)) + + def _imread_cv2_RGB(self, path): + return cv2_imread(path, cv2.IMREAD_COLOR)[..., ::-1] + def _imread_cv2_RAW(self, path): + return cv2_imread(path, -1) + def _imread_cv2_Y(self, path): + if self.force_color: + img = cv2_imread(path, cv2.IMREAD_COLOR) + else: + img = cv2_imread(path, cv2.IMREAD_ANYCOLOR) + if len(img.shape) == 2: + return np.expand_dims(img, 3) + elif len(img.shape) == 3: + return cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb) + else: + raise ValueError('The dimension should be either 2 or 3.') + def _imread_cv2_L(self, path): + return cv2_imread(path, cv2.IMREAD_GRAYSCALE) + + def _imread_pillow_RGB(self, path): + img = Image.open(path) + im = np.array(img.convert(self.mode)) + img.close() + return im + _imread_pillow_L = _imread_pillow_RGB + # WARNING: the RGB->YCbCr procedure of PIL may be different with OpenCV + def _imread_pillow_Y(self, path): + img = Image.open(path) + if img.mode == 'RGB': + im = np.array(img.convert('YCbCr')) + elif img.mode == 'L': + if self.force_color: + im = np.array(img.convert('RGB').convert('YCbCr')) + else: + im = np.expand_dims(np.array(img), 3) + else: + img.close() + raise NotImplementedError('Only support RGB and gray images now.') + img.close() + return im + + def _imwrite_cv2_RGB(self, image, path): + cv2.imwrite(path, image[..., ::-1]) + def _imwrite_cv2_RAW(self, image, path): + pass + def _imwrite_cv2_Y(self, image, path): + if image.shape[2] == 1: + cv2.imwrite(path, image[..., 0]) + elif image.shape[2] == 3: + cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_YCrCb2BGR)) + else: + raise ValueError('There should be 1 or 3 channels.') + def _imwrite_cv2_L(self, image, path): + cv2.imwrite(path, image) + + def _imwrite_pillow_RGB(self, image, path): + Image.fromarray(image).save(path) + _imwrite_pillow_L = _imwrite_pillow_RGB + def _imwrite_pillow_Y(self, image, path): + if image.shape[2] == 1: + self._imwrite_pillow_L(np.squeeze(image, 2), path) + elif image.shape[2] == 3: + Image.fromarray(image, mode='YCbCr').convert('RGB').save(path) + else: + raise ValueError('There should be 1 or 3 channels.') + + def _trans_batch_RGB_HWC(self, images): + return np.ascontiguousarray(images) + def _trans_batch_RGB_CHW(self, images): + return np.ascontiguousarray(np.transpose(images, (0, 3, 1, 2))) + _trans_batch_RAW_HWC = _trans_batch_RGB_HWC + _trans_batch_RAW_CHW = _trans_batch_RGB_CHW + _trans_batch_Y_HWC = _trans_batch_RGB_HWC + _trans_batch_Y_CHW = _trans_batch_RGB_CHW + def _trans_batch_L_HWC(self, images): + return np.ascontiguousarray(np.expand_dims(images, 3)) + def _trans_batch_L_CHW(slef, images): + return np.ascontiguousarray(np.expand_dims(images, 1)) + + def _trans_image_RGB_HWC(self, image): + return np.ascontiguousarray(image) + def 
_trans_image_RGB_CHW(self, image): + return np.ascontiguousarray(np.transpose(image, (2, 0, 1))) + _trans_image_RAW_HWC = _trans_image_RGB_HWC + _trans_image_RAW_CHW = _trans_image_RGB_CHW + _trans_image_Y_HWC = _trans_image_RGB_HWC + _trans_image_Y_CHW = _trans_image_RGB_CHW + def _trans_image_L_HWC(self, image): + return np.ascontiguousarray(np.expand_dims(image, 2)) + def _trans_image_L_CHW(self, image): + return np.ascontiguousarray(np.expand_dims(image, 0)) + + def _trans_back_RGB_HWC(self, image): + return image + def _trans_back_RGB_CHW(self, image): + return np.transpose(image, (1, 2, 0)) + _trans_back_RAW_HWC = _trans_back_RGB_HWC + _trans_back_RAW_CHW = _trans_back_RGB_CHW + _trans_back_Y_HWC = _trans_back_RGB_HWC + _trans_back_Y_CHW = _trans_back_RGB_CHW + def _trans_back_L_HWC(self, image): + return np.squeeze(image, 2) + def _trans_back_L_CHW(self, image): + return np.squeeze(image, 0) + + img_ext = ('png', 'PNG', 'jpg', 'JPG', 'bmp', 'BMP', 'jpeg', 'JPEG') + + def is_image(self, fname): + return any(fname.endswith(i) for i in self.img_ext) + + def read(self, paths): + if isinstance(paths, (list, tuple)): + images = [self._imread(path) for path in paths] + return self._trans_batch(np.array(images)) + return self._trans_image(self._imread(paths)) + + def back(self, image): + return self._trans_back(image) + + def write(self, image, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + self._imwrite(self.back(image), path) + +def read_until_success(func): + @wraps(func) + def wrapper(*args, **kwargs): + for i in range(30): + try: + ret = func(*args, **kwargs) + if ret is None: + raise OSError() + else: + break + except OSError: + print('%s OSError' % str(args)) + time.sleep(1) + return ret + return wrapper + +@read_until_success +def cv2_imread(*args, **kwargs): + return cv2.imread(*args, **kwargs) + +# if __name__ == '__main__': +# import matplotlib.pyplot as plt +# im_rgb_chw_cv2 = imlib('rgb', fmt='chw', lib='cv2') +# im_rgb_hwc_cv2 = imlib('rgb', fmt='hwc', lib='cv2') +# im_rgb_chw_pil = imlib('rgb', fmt='chw', lib='pillow') +# im_rgb_hwc_pil = imlib('rgb', fmt='hwc', lib='pillow') +# im_y_chw_cv2 = imlib('y', fmt='chw', lib='cv2') +# im_y_hwc_cv2 = imlib('y', fmt='hwc', lib='cv2') +# im_y_chw_pil = imlib('y', fmt='chw', lib='pillow') +# im_y_hwc_pil = imlib('y', fmt='hwc', lib='pillow') +# im_l_chw_cv2 = imlib('l', fmt='chw', lib='cv2') +# im_l_hwc_cv2 = imlib('l', fmt='hwc', lib='cv2') +# im_l_chw_pil = imlib('l', fmt='chw', lib='pillow') +# im_l_hwc_pil = imlib('l', fmt='hwc', lib='pillow') +# path = 'D:/Datasets/test/000001.jpg' + +# img_rgb_chw_cv2 = im_rgb_chw_cv2.read(path) +# print(img_rgb_chw_cv2.shape) +# plt.imshow(im_rgb_chw_cv2.back(img_rgb_chw_cv2)) +# plt.show() +# im_rgb_chw_cv2.write(img_rgb_chw_cv2, +# (path.replace('000001.jpg', 'img_rgb_chw_cv2.jpg'))) +# img_rgb_hwc_cv2 = im_rgb_hwc_cv2.read(path) +# print(img_rgb_hwc_cv2.shape) +# plt.imshow(im_rgb_hwc_cv2.back(img_rgb_hwc_cv2)) +# plt.show() +# im_rgb_hwc_cv2.write(img_rgb_hwc_cv2, +# (path.replace('000001.jpg', 'img_rgb_hwc_cv2.jpg'))) +# img_rgb_chw_pil = im_rgb_chw_pil.read(path) +# print(img_rgb_chw_pil.shape) +# plt.imshow(im_rgb_chw_pil.back(img_rgb_chw_pil)) +# plt.show() +# im_rgb_chw_pil.write(img_rgb_chw_pil, +# (path.replace('000001.jpg', 'img_rgb_chw_pil.jpg'))) +# img_rgb_hwc_pil = im_rgb_hwc_pil.read(path) +# print(img_rgb_hwc_pil.shape) +# plt.imshow(im_rgb_hwc_pil.back(img_rgb_hwc_pil)) +# plt.show() +# im_rgb_hwc_pil.write(img_rgb_hwc_pil, +# (path.replace('000001.jpg', 
'img_rgb_hwc_pil.jpg'))) + + +# img_y_chw_cv2 = im_y_chw_cv2.read(path) +# print(img_y_chw_cv2.shape) +# plt.imshow(np.squeeze(im_y_chw_cv2.back(img_y_chw_cv2))) +# plt.show() +# im_y_chw_cv2.write(img_y_chw_cv2, +# (path.replace('000001.jpg', 'img_y_chw_cv2.jpg'))) +# img_y_hwc_cv2 = im_y_hwc_cv2.read(path) +# print(img_y_hwc_cv2.shape) +# plt.imshow(np.squeeze(im_y_hwc_cv2.back(img_y_hwc_cv2))) +# plt.show() +# im_y_hwc_cv2.write(img_y_hwc_cv2, +# (path.replace('000001.jpg', 'img_y_hwc_cv2.jpg'))) +# img_y_chw_pil = im_y_chw_pil.read(path) +# print(img_y_chw_pil.shape) +# plt.imshow(np.squeeze(im_y_chw_pil.back(img_y_chw_pil))) +# plt.show() +# im_y_chw_pil.write(img_y_chw_pil, +# (path.replace('000001.jpg', 'img_y_chw_pil.jpg'))) +# img_y_hwc_pil = im_y_hwc_pil.read(path) +# print(img_y_hwc_pil.shape) +# plt.imshow(np.squeeze(im_y_hwc_pil.back(img_y_hwc_pil))) +# plt.show() +# im_y_hwc_pil.write(img_y_hwc_pil, +# (path.replace('000001.jpg', 'img_y_hwc_pil.jpg'))) + + +# img_l_chw_cv2 = im_l_chw_cv2.read(path) +# print(img_l_chw_cv2.shape) +# plt.imshow(im_l_chw_cv2.back(img_l_chw_cv2)) +# plt.show() +# im_l_chw_cv2.write(img_l_chw_cv2, +# (path.replace('000001.jpg', 'img_l_chw_cv2.jpg'))) +# img_l_hwc_cv2 = im_l_hwc_cv2.read(path) +# print(img_l_hwc_cv2.shape) +# plt.imshow(im_l_hwc_cv2.back(img_l_hwc_cv2)) +# plt.show() +# im_l_hwc_cv2.write(img_l_hwc_cv2, +# (path.replace('000001.jpg', 'img_l_hwc_cv2.jpg'))) +# img_l_chw_pil = im_l_chw_pil.read(path) +# print(img_l_chw_pil.shape) +# plt.imshow(im_l_chw_pil.back(img_l_chw_pil)) +# plt.show() +# im_l_chw_pil.write(img_l_chw_pil, +# (path.replace('000001.jpg', 'img_l_chw_pil.jpg'))) +# img_l_hwc_pil = im_l_hwc_pil.read(path) +# print(img_l_hwc_pil.shape) +# plt.imshow(im_l_hwc_pil.back(img_l_hwc_pil)) +# plt.show() +# im_l_hwc_pil.write(img_l_hwc_pil, +# (path.replace('000001.jpg', 'img_l_hwc_pil.jpg'))) diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/s7align_dataset.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/s7align_dataset.py new file mode 100644 index 0000000..8eb2dad --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/s7align_dataset.py @@ -0,0 +1,95 @@ +import numpy as np +import os +from data.base_dataset import BaseDataset +from util.util import augment, remove_black_level, get_coord +from util.util import extract_bayer_channels, get_raw_demosaic, load_img +import glob + + +class S7alignDataset(BaseDataset): + def __init__(self, opt, split='train', dataset_name='ZRR'): + super(S7alignDataset, self).__init__(opt, split, dataset_name) + + + self.batch_size = opt.batch_size + + if split == 'train': + self.root_dir = os.path.join(self.root,'train'); + self.train_raws = sorted(glob.glob(os.path.join(self.root_dir, '*.npy'))) + self.train_rgbs = sorted(glob.glob(os.path.join(self.root_dir, '*.jpg'))) + self.names = sorted(glob.glob(os.path.join(self.root_dir, '*.jpg'))); + self._getitem = self._getitem_train + self.len_data = len(self.names) + + elif split == 'val': + self.root_dir = os.path.join(self.root, 'val') + self.test_raws = sorted(glob.glob(os.path.join(self.root_dir, '*.npy'))) + self.test_rgbs = sorted(glob.glob(os.path.join(self.root_dir, '*.jpg'))) + self.names = sorted(glob.glob(os.path.join(self.root_dir, '*.jpg'))); + self._getitem = self._getitem_val + self.len_data = len(self.names) + + elif split == 'test': + self.root_dir = os.path.join(self.root) + self.test_rgbs = sorted(glob.glob(os.path.join(self.root_dir, '*.jpg'))) + self.names = 
sorted(glob.glob(os.path.join(self.root_dir, '*.jpg'))); + self._getitem = self._getitem_test + self.len_data = len(self.names) + + else: + raise ValueError + + + + + def __getitem__(self, index): + return self._getitem(index) + + def __len__(self): + return self.len_data + + def _getitem_train(self, idx): + raw = np.load(self.train_raws[idx], encoding='bytes', allow_pickle=True); + raw_combined, raw_demosaic = self._process_raw(raw) + dslr_image = load_img(self.train_rgbs[idx]) + dslr_image = np.ascontiguousarray(dslr_image.transpose((2, 0, 1))) + raw_combined, raw_demosaic, dslr_image = augment( + raw_combined, raw_demosaic, dslr_image) + + return {'raw': raw_combined, + 'raw_demosaic': raw_demosaic, + 'dslr': dslr_image, + 'fname': self.names[idx]} + + def _getitem_val(self, idx): + raw = np.load(self.test_raws[idx], encoding='bytes', allow_pickle=True); + raw_combined, raw_demosaic = self._process_raw(raw) + dslr_image = load_img(self.test_rgbs[idx]) + dslr_image = np.ascontiguousarray(dslr_image.transpose((2, 0, 1))) + + return {'raw': raw_combined, + 'raw_demosaic': raw_demosaic, + 'dslr': dslr_image, + 'fname': self.names[idx]} + + def _getitem_test(self, idx): + dslr_image = load_img(self.test_rgbs[idx]) + dslr_image = dslr_image.transpose((2, 0, 1)) + + return { + 'dslr': dslr_image, + 'fname': self.names[idx]} + + + def _process_raw(self, raw): + raw = remove_black_level(raw) + raw_combined = extract_bayer_channels(raw) + raw_demosaic = get_raw_demosaic(raw) + return raw_combined, raw_demosaic + + + +if __name__ == '__main__': + pass + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/s7alignpatch_dataset.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/s7alignpatch_dataset.py new file mode 100644 index 0000000..9d7a699 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/data/s7alignpatch_dataset.py @@ -0,0 +1,116 @@ +import numpy as np +import os +from data.base_dataset import BaseDataset +from util.util import augment, remove_black_level +from util.util import extract_bayer_channels, get_raw_demosaic, load_img +import glob +import random + + +# Zurich RAW to RGB (ZRR) dataset +class S7alignpatchDataset(BaseDataset): + def __init__(self, opt, split='train', dataset_name='ZRR'): + super(S7alignpatchDataset, self).__init__(opt, split, dataset_name) + + + self.batch_size = opt.batch_size + + if split == 'train': + self.root_dir = os.path.join(self.root,'train_full'); + self.train_raws = sorted(glob.glob(os.path.join( self.root_dir, '*.npy'))) + self.train_rgbs = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))) + self.train_raws_file=[] + self.train_rgbs_file=[] + self.patch = opt.patch_size + for seq_path in self.train_raws: + seq = np.load(seq_path, encoding='bytes', allow_pickle=True); + self.train_raws_file.append(seq) + + for seq_path in self.train_rgbs: + seq = load_img(seq_path) + self.train_rgbs_file.append(seq) + + self.names = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))); + self._getitem = self._getitem_train + self.len_data = len(self.names)*48 + + elif split == 'val': + self.root_dir = os.path.join(self.root, 'val_full') + self.patch = opt.patch_size + self.test_raws = sorted(glob.glob(os.path.join( self.root_dir, '*.npy'))) + self.test_rgbs = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))) + self.names = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))); + self._getitem = self._getitem_val + self.len_data = len(self.names) + + elif split == 'test': + self.root_dir = os.path.join(self.root, 
'val_full') + self.test_rgbs = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))) + self.names = sorted(glob.glob(os.path.join( self.root_dir, '*.jpg'))); + self._getitem = self._getitem_test + self.len_data = len(self.names) + + else: + raise ValueError + + + + + def __getitem__(self, index): + return self._getitem(index) + + def __len__(self): + return self.len_data + + def _getitem_train(self, idx): + idx = idx % (self.len_data//48) + H,W,C = self.train_raws_file[idx].shape + crop_h = random.randrange(0,H - self.patch ,2) + crop_w = random.randrange(0,W - self.patch ,2) + raw = self.train_raws_file[idx][crop_h:crop_h+self.patch , crop_w:crop_w+self.patch,:] + dslr_image = self.train_rgbs_file[idx][2*crop_h:2*crop_h+2*self.patch, 2*crop_w:2*crop_w+2*self.patch,:] + raw_combined, raw_demosaic = self._process_raw(raw) + dslr_image = dslr_image.transpose((2, 0, 1)) + raw_combined, raw_demosaic, dslr_image = augment( + raw_combined, raw_demosaic, dslr_image) + + return {'raw': raw_combined, + 'raw_demosaic': raw_demosaic, + 'dslr': dslr_image, + 'fname': self.names[idx]} + + def _getitem_val(self, idx): + raw_init = np.load(self.test_raws[idx], encoding='bytes', allow_pickle=True); + raw = raw_init[0:0+self.patch , 0:0+self.patch,:] + raw_combined, raw_demosaic = self._process_raw(raw) + dslr_image_init = load_img(self.test_rgbs[idx]) + dslr_image = dslr_image_init[0:0+2*self.patch , 0:0 + 2*self.patch,:] + dslr_image = dslr_image.transpose((2, 0, 1)) + + return {'raw': raw_combined, + 'raw_demosaic': raw_demosaic, + 'dslr': dslr_image, + 'fname': self.names[idx]} + + def _getitem_test(self, idx): + dslr_image = load_img(self.test_rgbs[idx]) + dslr_image = dslr_image.transpose((2, 0, 1)) + + return { + 'dslr': dslr_image, + 'fname': self.names[idx]} + + + def _process_raw(self, raw): + raw = remove_black_level(raw) + raw_combined = extract_bayer_channels(raw) + raw_demosaic = get_raw_demosaic(raw) + return raw_combined, raw_demosaic + + + + +if __name__ == '__main__': + pass + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/.DS_Store b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/.DS_Store new file mode 100644 index 0000000..63e6db6 Binary files /dev/null and b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/.DS_Store differ diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/__init__.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/__init__.py new file mode 100644 index 0000000..77aa38a --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/__init__.py @@ -0,0 +1,47 @@ +import importlib +from models.base_model import BaseModel + + +def find_model_using_name(model_name): + """Import the module "models/[model_name]_model.py". + + In the file, the class called DatasetNameModel() will + be instantiated. It has to be a subclass of BaseModel, + and it is case-insensitive. + """ + model_filename = "models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + model = None + target_model_name = model_name.replace('_', '') + 'model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() \ + and issubclass(cls, BaseModel): + model = cls + + if model is None: + raise NotImplementedError("In %s.py, there should be a subclass of " + "BaseModel with class name that matches %s in " + "lowercase." 
% (model_filename, target_model_name))
+
+    return model
+
+
+def get_option_setter(model_name):
+    model_class = find_model_using_name(model_name)
+    return model_class.modify_commandline_options
+
+
+def create_model(opt):
+    """Create a model given the option.
+
+    This function wraps the model class given by opt.model.
+    This is the main interface between this package and 'train.py'/'test.py'
+
+    Example:
+        >>> from models import create_model
+        >>> model = create_model(opt)
+    """
+    model = find_model_using_name(opt.model)
+    instance = model(opt)
+    print("model [%s] was created" % type(instance).__name__)
+    return instance
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/arch_util.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/arch_util.py
new file mode 100644
index 0000000..8e005cf
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/arch_util.py
@@ -0,0 +1,350 @@
+# ------------------------------------------------------------------------
+# Copyright (c) 2022 megvii-model. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Modified from BasicSR (https://github.com/xinntao/BasicSR)
+# Copyright 2018-2020 BasicSR Authors
+# ------------------------------------------------------------------------
+import math
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+
+
+# try:
+#     from basicsr.models.ops.dcn import (ModulatedDeformConvPack,
+#                                         modulated_deform_conv)
+# except ImportError:
+#     # print('Cannot import dcn. Ignore this warning if dcn is not used. '
+#     #       'Otherwise install BasicSR with compiling dcn.')
+#
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+    """Initialize network weights.
+
+    Args:
+        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+        scale (float): Scale initialized weights, especially for residual
+            blocks. Default: 1.
+        bias_fill (float): The value to fill bias. Default: 0
+        kwargs (dict): Other arguments for initialization function.
+    """
+    if not isinstance(module_list, list):
+        module_list = [module_list]
+    for module in module_list:
+        for m in module.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, nn.Linear):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, _BatchNorm):
+                init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+    """Make layers by stacking the same blocks.
+
+    Args:
+        basic_block (nn.module): nn.module class for basic block.
+        num_basic_block (int): number of blocks.
+
+    Returns:
+        nn.Sequential: Stacked blocks in nn.Sequential.
+    """
+    layers = []
+    for _ in range(num_basic_block):
+        layers.append(basic_block(**kwarg))
+    return nn.Sequential(*layers)
+
+
+class ResidualBlockNoBN(nn.Module):
+    """Residual block without BN.
+
+    It has a style of:
+        ---Conv-ReLU-Conv-+-
+         |________________|
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+            Default: 64.
+        res_scale (float): Residual scale. Default: 1.
+        pytorch_init (bool): If set to True, use pytorch default init,
+            otherwise, use default_init_weights. Default: False.
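+
+    Example (illustrative only):
+        >>> block = ResidualBlockNoBN(num_feat=64)
+        >>> out = block(torch.randn(1, 64, 32, 32))  # output keeps the input shape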
+ """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' + 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +def flow_warp(x, + flow, + interp_mode='bilinear', + padding_mode='zeros', + align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. + """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid( + torch.arange(0, h).type_as(x), + torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample( + x, + vgrid_scaled, + mode=interp_mode, + padding_mode=padding_mode, + align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, + size_type, + sizes, + interp_mode='bilinear', + align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. 
+ """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError( + f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, + size=(output_h, output_w), + mode=interp_mode, + align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +# class DCNv2Pack(ModulatedDeformConvPack): +# """Modulated deformable conv for deformable alignment. +# +# Different from the official DCNv2Pack, which generates offsets and masks +# from the preceding features, this DCNv2Pack takes another different +# features to generate offsets and masks. +# +# Ref: +# Delving Deep into Deformable Alignment in Video Super-Resolution. +# """ +# +# def forward(self, x, feat): +# out = self.conv_offset(feat) +# o1, o2, mask = torch.chunk(out, 3, dim=1) +# offset = torch.cat((o1, o2), dim=1) +# mask = torch.sigmoid(mask) +# +# offset_absmean = torch.mean(torch.abs(offset)) +# if offset_absmean > 50: +# logger = get_root_logger() +# logger.warning( +# f'Offset abs mean is {offset_absmean}, larger than 50.') +# +# return modulated_deform_conv(x, offset, mask, self.weight, self.bias, +# self.stride, self.padding, self.dilation, +# self.groups, self.deformable_groups) + + +class LayerNormFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1. 
/ torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( + dim=0), None + +class LayerNorm2d(nn.Module): + + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter('weight', nn.Parameter(torch.ones(channels))) + self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + +# handle multiple input +class MySequential(nn.Sequential): + def forward(self, *inputs): + for module in self._modules.values(): + if type(inputs) == tuple: + inputs = module(*inputs) + else: + inputs = module(inputs) + return inputs + +import time +def measure_inference_speed(model, data, max_iter=200, log_interval=50): + model.eval() + + # the first several iterations may be very slow so skip them + num_warmup = 5 + pure_inf_time = 0 + fps = 0 + + # benchmark with 2000 image and take the average + for i in range(max_iter): + + torch.cuda.synchronize() + start_time = time.perf_counter() + + with torch.no_grad(): + model(*data) + + torch.cuda.synchronize() + elapsed = time.perf_counter() - start_time + + if i >= num_warmup: + pure_inf_time += elapsed + if (i + 1) % log_interval == 0: + fps = (i + 1 - num_warmup) / pure_inf_time + print( + f'Done image [{i + 1:<3}/ {max_iter}], ' + f'fps: {fps:.1f} img / s, ' + f'times per image: {1000 / fps:.1f} ms / img', + flush=True) + + if (i + 1) == max_iter: + fps = (i + 1 - num_warmup) / pure_inf_time + print( + f'Overall fps: {fps:.1f} img / s, ' + f'times per image: {1000 / fps:.1f} ms / img', + flush=True) + break + return fps \ No newline at end of file diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/base_model.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/base_model.py new file mode 100644 index 0000000..18dddaa --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/base_model.py @@ -0,0 +1,384 @@ +import os +import torch +from collections import OrderedDict +from abc import ABC, abstractmethod +from . 
import networks +import torch +from util.util import torch_save +import math +import torch.nn.functional as F + +def calc_psnr(sr, hr, range=1.): + # shave = 2 + with torch.no_grad(): + diff = (sr - hr) / range + mse = torch.pow(diff, 2) + mse= torch.mean(mse,dim=1,keepdim=True) + return (-10 * torch.log10(mse)) + +class BaseModel(ABC): + def __init__(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.scale = opt.scale + + if len(self.gpu_ids) > 0: + self.device = torch.device('cuda', self.gpu_ids[0]) + else: + self.device = torch.device('cpu') + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.optimizers = [] + self.optimizer_names = [] + self.image_paths = [] + self.metric = 0 # used for learning rate policy 'plateau' + self.start_epoch = 0 + + self.backwarp_tenGrid = {} + self.backwarp_tenPartial = {} + + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + @abstractmethod + def set_input(self, input): + pass + + @abstractmethod + def forward(self): + pass + + @abstractmethod + def optimize_parameters(self): + pass + + def setup(self, opt=None): + opt = opt if opt is not None else self.opt + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) \ + for optimizer in self.optimizers] + for scheduler in self.schedulers: + scheduler.last_epoch = opt.load_iter + if opt.load_iter > 0 or opt.load_path != '': + load_suffix = opt.load_iter + self.load_networks(load_suffix) + if opt.load_optimizers: + self.load_optimizers(opt.load_iter) + + self.print_networks(opt.verbose) + + def eval(self): + for name in self.model_names: + net = getattr(self, 'net' + name) + net.eval() + + def train(self): + for name in self.model_names: + net = getattr(self, 'net' + name) + net.train() + + def test(self): + with torch.no_grad(): + self.forward() + + def get_image_paths(self): + return self.image_paths + + def update_learning_rate(self): + for i, scheduler in enumerate(self.schedulers): + if scheduler.__class__.__name__ == 'ReduceLROnPlateau': + scheduler.step(self.metric) + else: + scheduler.step() + print('lr of %s = %.7f' % ( + self.optimizer_names[i], scheduler.get_last_lr()[0])) + + def get_current_visuals(self): + visual_ret = OrderedDict() + for name in self.visual_names: + if 'xy' in name or 'coord' in name: + visual_ret[name] = getattr(self, name).detach() + else: + visual_ret[name] = torch.clamp( + getattr(self, name).detach(), 0., 1.) 
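+        # image tensors were clamped to the displayable [0, 1] range above;
+        # coordinate maps ('xy'/'coord') pass through unclamped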
+ return visual_ret + + def get_current_losses(self): + errors_ret = OrderedDict() + for name in self.loss_names: + errors_ret[name] = float(getattr(self, 'loss_' + name)) + return errors_ret + + def save_networks(self, epoch): + for name in self.model_names: + save_filename = '%s_model_%d.pth' % (name, epoch) + save_path = os.path.join(self.save_dir, save_filename) + net = getattr(self, 'net' + name) + if self.device.type == 'cuda': + state = {'state_dict': net.module.cpu().state_dict()} + torch_save(state, save_path) + net.to(self.device) + else: + state = {'state_dict': net.state_dict()} + torch_save(state, save_path) + self.save_optimizers(epoch) + + def load_networks(self, epoch): +# self.model_names.append('GCMModel') + for name in self.model_names: #[0:1]: + # if name is 'Discriminator': + # continue + load_filename = '%s_model_%d.pth' % (name, epoch) +# if name=='GCMModel': +# load_filename = '%s_model_%d.pth' % (name, 1) + if self.opt.load_path != '': + load_path = self.opt.load_path + else: + load_path = os.path.join(self.save_dir, load_filename) + print(name,load_path) + net = getattr(self, 'net' + name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + state_dict = torch.load(load_path, map_location=self.device) + print('loading the model from %s' % (load_path)) + if hasattr(state_dict, '_metadata'): + del state_dict._metadata + + net_state = net.state_dict() + is_loaded = {n:False for n in net_state.keys()} + for name, param in state_dict['state_dict'].items(): + if name in net_state: + try: + net_state[name].copy_(param) + is_loaded[name] = True + except Exception: + print('While copying the parameter named [%s], ' + 'whose dimensions in the model are %s and ' + 'whose dimensions in the checkpoint are %s.' + % (name, list(net_state[name].shape), + list(param.shape))) + raise RuntimeError + else: + print('Saved parameter named [%s] is skipped' % name) + mark = True + for name in is_loaded: + if not is_loaded[name]: + print('Parameter named [%s] is randomly initialized' % name) + mark = False + if mark: + print('All parameters are initialized using [%s]' % load_path) + + self.start_epoch = epoch + + def save_optimizers(self, epoch): + assert len(self.optimizers) == len(self.optimizer_names) + for id, optimizer in enumerate(self.optimizers): + save_filename = self.optimizer_names[id] + state = {'name': save_filename, + 'epoch': epoch, + 'state_dict': optimizer.state_dict()} + save_path = os.path.join(self.save_dir, save_filename+'.pth') + torch_save(state, save_path) + + def load_optimizers(self, epoch): + assert len(self.optimizers) == len(self.optimizer_names) + for id, optimizer in enumerate(self.optimizer_names): + load_filename = self.optimizer_names[id] + load_path = os.path.join(self.save_dir, load_filename+'.pth') + print('loading the optimizer from %s' % load_path) + state_dict = torch.load(load_path) + print(state_dict['epoch']) + assert optimizer == state_dict['name'] + assert epoch == state_dict['epoch'] + self.optimizers[id].load_state_dict(state_dict['state_dict']) + + def print_networks(self, verbose): + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' + % (name, num_params / 1e6)) + print('-----------------------------------------------') + + def set_requires_grad(self, nets, 
requires_grad=False): + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + def estimate(self, tenFirst, tenSecond, net): + assert(tenFirst.shape[3] == tenSecond.shape[3]) + assert(tenFirst.shape[2] == tenSecond.shape[2]) + intWidth = tenFirst.shape[3] + intHeight = tenFirst.shape[2] + # tenPreprocessedFirst = tenFirst.view(1, 3, intHeight, intWidth) + # tenPreprocessedSecond = tenSecond.view(1, 3, intHeight, intWidth) + + intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0)) + intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0)) + + tenPreprocessedFirst = F.interpolate(input=tenFirst, + size=(intPreprocessedHeight, intPreprocessedWidth), + mode='bilinear', align_corners=False) + tenPreprocessedSecond = F.interpolate(input=tenSecond, + size=(intPreprocessedHeight, intPreprocessedWidth), + mode='bilinear', align_corners=False) + + tenFlow = 20.0 * F.interpolate( + input=net(tenPreprocessedFirst, tenPreprocessedSecond), + size=(intHeight, intWidth), mode='bilinear', align_corners=False) + + tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) + tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) + + return tenFlow[:, :, :, :] + + def backwarp(self, tenInput, tenFlow): + index = str(tenFlow.shape) + str(tenInput.device) + if index not in self.backwarp_tenGrid: + tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), + tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1) + tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), + tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3]) + self.backwarp_tenGrid[index] = torch.cat([tenHor, tenVer], 1).to(tenInput.device) + + if index not in self.backwarp_tenPartial: + self.backwarp_tenPartial[index] = tenFlow.new_ones([ + tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3]]) + + tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) + tenInput = torch.cat([tenInput, self.backwarp_tenPartial[index]], 1) + + tenOutput = F.grid_sample(input=tenInput, + grid=(self.backwarp_tenGrid[index] + tenFlow).permute(0, 2, 3, 1), + mode='bilinear', padding_mode='zeros', align_corners=False) + + return tenOutput + + def get_backwarp(self, tenFirst, tenSecond,raw, net, flow=None): + if flow is None: + flow = self.get_flow(tenFirst, tenSecond, net) + + flow_raw = F.interpolate(flow, scale_factor=0.5)/2. 
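+        # The packed RAW frame is half the sRGB resolution, so the flow field
+        # computed above is resized to half size and its vectors (measured in
+        # pixels) are halved as well: e.g. a 10-pixel displacement in sRGB
+        # corresponds to a 5-pixel displacement on the RAW grid.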
+        tenoutput = self.backwarp(tenSecond, flow)
+        rgb_tenMask = tenoutput[:, -1:, :, :]
+        rgb_tenMask[rgb_tenMask > 0.999] = 1.0
+        rgb_tenMask[rgb_tenMask < 1.0] = 0.0
+
+        rawoutput = self.backwarp(raw, flow_raw)
+        raw_tenMask = rawoutput[:, -1:, :, :]
+        raw_tenMask[raw_tenMask > 0.999] = 1.0
+        raw_tenMask[raw_tenMask < 1.0] = 0.0
+        return tenoutput[:, :-1, :, :] * rgb_tenMask, rgb_tenMask, rawoutput[:, :-1, :, :] * raw_tenMask, raw_tenMask
+
+    def get_backwarp_down(self, tenFirst, tenSecond, raw, net, flow=None):
+        if flow is None:
+            flow = self.get_flow(tenFirst, tenSecond, net)
+
+        tenoutput = self.backwarp(tenSecond, flow)
+        rgb_tenMask = tenoutput[:, -1:, :, :]
+        rgb_tenMask[rgb_tenMask > 0.999] = 1.0
+        rgb_tenMask[rgb_tenMask < 1.0] = 0.0
+
+        # same as get_backwarp, but the RAW frame is warped with the
+        # full-resolution flow instead of the half-scale one
+        rawoutput = self.backwarp(raw, flow)
+        raw_tenMask = rawoutput[:, -1:, :, :]
+        raw_tenMask[raw_tenMask > 0.999] = 1.0
+        raw_tenMask[raw_tenMask < 1.0] = 0.0
+        return tenoutput[:, :-1, :, :] * rgb_tenMask, rgb_tenMask, rawoutput[:, :-1, :, :] * raw_tenMask, raw_tenMask
+
+    def get_backwarp_nogcm(self, tenFirst, tenSecond, raw, net, flow=None):
+        if flow is None:
+            flow = self.get_flow(tenFirst, tenSecond, net)
+
+        tenoutput = self.backwarp(raw, flow)
+        rgb_tenMask = tenoutput[:, -1:, :, :]
+        rgb_tenMask[rgb_tenMask > 0.999] = 1.0
+        rgb_tenMask[rgb_tenMask < 1.0] = 0.0
+
+        return tenoutput[:, :-1, :, :] * rgb_tenMask, rgb_tenMask
+
+    def get_backwarp_isp(self, tenFirst, tenSecond, net, flow=None):
+        if flow is None:
+            flow = self.get_flow(tenFirst, tenSecond, net)
+
+        tenoutput = self.backwarp(tenSecond, flow)
+        tenMask = tenoutput[:, -1:, :, :]
+        tenMask[tenMask > 0.999] = 1.0
+        tenMask[tenMask < 1.0] = 0.0
+        return tenoutput[:, :-1, :, :] * tenMask, tenMask
+
+    def length_sq(self, x):
+        # squared flow magnitude per pixel, summed over the channel dimension
+        # (assumed definition; used by the forward-backward occlusion test below)
+        return torch.sum(torch.square(x), dim=1, keepdim=True)
+
+    def get_backwarp_fb(self, tenFirst, tenSecond, raw, net, flow=None):
+        # compute the forward flow
+        flow_fw = self.get_flow(tenFirst, tenSecond, net)
+
+        flow_raw = F.interpolate(flow_fw, scale_factor=0.5) / 2.
+        tenoutput = self.backwarp(tenSecond, flow_fw)
+        rgb_tenMask = tenoutput[:, -1:, :, :]
+        rgb_tenMask[rgb_tenMask > 0.999] = 1.0
+        rgb_tenMask[rgb_tenMask < 1.0] = 0.0
+
+        rawoutput = self.backwarp(raw, flow_raw)
+        raw_tenMask = rawoutput[:, -1:, :, :]
+        raw_tenMask[raw_tenMask > 0.999] = 1.0
+        raw_tenMask[raw_tenMask < 1.0] = 0.0
+
+        # compute the backward flow (the warped image carries an extra
+        # validity channel, which is dropped before flow estimation)
+        flow_bw = self.get_flow(tenFirst, tenoutput[:, :-1, :, :], net)
+        flow_fw_warped = self.backwarp(flow_fw, flow_fw)
+        flow_diff_fw = flow_bw + flow_fw_warped[:, :-1, :, :]
+        mag_sq_bw = self.length_sq(flow_bw) + self.length_sq(flow_fw_warped[:, :-1, :, :])
+        occ_thresh_bw = 0.01 * mag_sq_bw + 0.5
+        # mark a pixel as non-occluded where the forward and backward flows agree
+        rgb_flowMask = flow_fw_warped[:, -1:, :, :]
+        rgb_flowMask[self.length_sq(flow_diff_fw) <= occ_thresh_bw] = 1.0
+        rgb_flowMask[self.length_sq(flow_diff_fw) > occ_thresh_bw] = 0.0
+
+        return tenoutput[:, :-1, :, :] * rgb_tenMask * rgb_flowMask, rgb_tenMask, rawoutput[:, :-1, :, :] * raw_tenMask, raw_tenMask, flow_raw, flow_bw
+
+    def get_backwarp_all(self, tenFirst, tenSecond, raw, net, flow=None):
+        if flow is None:
+            flow = self.get_flow(tenFirst, tenSecond, net)
+        flow_fw = flow
+
+        flow_raw = F.interpolate(flow_fw, scale_factor=0.5) / 2.
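+        # backwarp() appends a channel of ones to its input before warping;
+        # after warping, that channel reads ~1 only where the sample came
+        # entirely from inside the source image. Thresholding it at 0.999
+        # therefore yields a hard validity mask, and multiplying by the mask
+        # zeroes out-of-frame pixels.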
+ tenoutput = self.backwarp(tenSecond, flow_fw) + rgb_tenMask = tenoutput[:, -1:, :, :] + rgb_tenMask[rgb_tenMask > 0.999] = 1.0 + rgb_tenMask[rgb_tenMask < 1.0] = 0.0 + + rawoutput = self.backwarp(raw, flow_raw) + raw_tenMask = rawoutput[:, -1:, :, :] + raw_tenMask[raw_tenMask > 0.999] = 1.0 + raw_tenMask[raw_tenMask < 1.0] = 0.0 + + flow_bw = self.get_flow( tenSecond,tenFirst, net) + tenSecond_wrap = self.backwarp(tenoutput[:, :-1, :, :], flow_bw) + rgb_mask_consis = tenSecond_wrap[:, -1:, :, :] + + wrap_psnr=calc_psnr(tenSecond_wrap[:, :-1, :, :],tenSecond) + rgb_mask_consis[wrap_psnr<=30]=0.0 + rgb_mask_consis[wrap_psnr>30]=1.0 + raw_mask_consis=rgb_mask_consis[:,:,0::2,0::2] + + return tenoutput[:, :-1, :, :] * rgb_tenMask * rgb_mask_consis, rgb_tenMask* rgb_mask_consis,rawoutput[:, :-1, :, :] * raw_tenMask* raw_mask_consis, raw_tenMask* raw_mask_consis + + def get_flow(self, tenFirst, tenSecond, net): + with torch.no_grad(): + net.eval() + flow = self.estimate(tenFirst, tenSecond, net) + return flow \ No newline at end of file diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/local_arch.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/local_arch.py new file mode 100644 index 0000000..bc459c6 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/local_arch.py @@ -0,0 +1,104 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +class AvgPool2d(nn.Module): + def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None): + super().__init__() + self.kernel_size = kernel_size + self.base_size = base_size + self.auto_pad = auto_pad + + # only used for fast implementation + self.fast_imp = fast_imp + self.rs = [5, 4, 3, 2, 1] + self.max_r1 = self.rs[0] + self.max_r2 = self.rs[0] + self.train_size = train_size + + def extra_repr(self) -> str: + return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format( + self.kernel_size, self.base_size, self.kernel_size, self.fast_imp + ) + + def forward(self, x): + if self.kernel_size is None and self.base_size: + train_size = self.train_size + if isinstance(self.base_size, int): + self.base_size = (self.base_size, self.base_size) + self.kernel_size = list(self.base_size) + self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2] + self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1] + + # only used for fast implementation + self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2]) + self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1]) + + if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1): + return F.adaptive_avg_pool2d(x, 1) + + if self.fast_imp: # Non-equivalent implementation but faster + h, w = x.shape[2:] + if self.kernel_size[0] >= h and self.kernel_size[1] >= w: + out = F.adaptive_avg_pool2d(x, 1) + else: + r1 = [r for r in self.rs if h % r == 0][0] + r2 = [r for r in self.rs if w % r == 0][0] + # reduction_constraint + r1 = min(self.max_r1, r1) + r2 = min(self.max_r2, r2) + s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2) + n, c, h, w = s.shape + k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2) + out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * 
k2) + out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2)) + else: + n, c, h, w = x.shape + s = x.cumsum(dim=-1).cumsum_(dim=-2) + s = torch.nn.functional.pad(s, (1, 0, 1, 0)) # pad 0 for convenience + k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1]) + s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:] + out = s4 + s1 - s2 - s3 + out = out / (k1 * k2) + + if self.auto_pad: + n, c, h, w = x.shape + _h, _w = out.shape[2:] + # print(x.shape, self.kernel_size) + pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2) + out = torch.nn.functional.pad(out, pad2d, mode='replicate') + + return out + +def replace_layers(model, base_size, train_size, fast_imp, **kwargs): + for n, m in model.named_children(): + if len(list(m.children())) > 0: + ## compound module, go inside it + replace_layers(m, base_size, train_size, fast_imp, **kwargs) + + if isinstance(m, nn.AdaptiveAvgPool2d): + pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size) + assert m.output_size == 1 + setattr(model, n, pool) + + +''' +ref. +@article{chu2021tlsc, + title={Revisiting Global Statistics Aggregation for Improving Image Restoration}, + author={Chu, Xiaojie and Chen, Liangyu and and Chen, Chengpeng and Lu, Xin}, + journal={arXiv preprint arXiv:2112.04491}, + year={2021} +} +''' +class Local_Base(): + def convert(self, *args, train_size, **kwargs): + replace_layers(self, *args, train_size=train_size, **kwargs) + imgs = torch.rand(train_size) + with torch.no_grad(): + self.forward(imgs) diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/losses.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/losses.py new file mode 100644 index 0000000..6e33447 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/losses.py @@ -0,0 +1,345 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp +from torch.nn import L1Loss, MSELoss +import numpy as np + +class CannyNet(nn.Module): + def __init__(self): + super(CannyNet, self).__init__() + self.pad = nn.ReflectionPad2d(1) + self.conv1 = nn.Conv2d(4, 4, 3, padding=(0, 0), bias=False) + def forward(self, x): + b,c,h,w = x.size() + x = self.conv1(self.pad(x)) + return x + +class Canny(nn.Module): + def __init__(self): + super(Canny, self).__init__() + self.net = CannyNet().cuda() + self.conv_rgb_core_original = [ + [[0, 0, 0], [0, 1, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 1, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 1, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 1, 0], [0, 0, 0] + ]] + self.conv_rgb_core_sobel = [ + [[-1, -1, -1], [-1, 8, -1], [-1, -1, -1], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [-1, -1, -1], [-1, 8, -1], [-1, -1, -1], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [-1, -1, -1], [-1, 8, -1], [-1, -1, -1], + [0, 0, 0], 
[0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [-1, -1, -1], [-1, 8, -1], [-1, -1, -1], + ]] + self.conv_rgb_core_sobel_vertical = [ + [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [-1, 0, 1], [-2, 0, 2], [-1, 0, 1], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [-1, 0, 1], [-2, 0, 2], [-1, 0, 1], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [-1, 0, 1], [-2, 0, 2], [-1, 0, 1], + ]] + self.conv_rgb_core_sobel_horizontal = [ + [[1, 2, 1], [0, 0, 0], [-1, -2, -1], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [1, 2, 1], [0, 0, 0], [-1, -2, -1], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0] + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [1, 2, 1], [0, 0, 0], [-1, -2, -1], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + ], + [[0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [0, 0, 0], [0, 0, 0], [0, 0, 0], + [1, 2, 1], [0, 0, 0], [-1, -2, -1], + ]] + + def sobel(self, net, kernel): + sobel_kernel = np.array(kernel, dtype='float32') + sobel_kernel = sobel_kernel.reshape((4, 4, 3, 3)) + net.conv1.weight.data = torch.from_numpy(sobel_kernel).cuda() + + def forward(self, x): + # x = x*2-1 #to [-1,1] + # self.sobel(self.net, self.conv_rgb_core_sobel) + # out = self.net(x).detach() + self.sobel(self.net, self.conv_rgb_core_sobel_vertical) + out_v = self.net(x).detach() + self.sobel(self.net, self.conv_rgb_core_sobel_horizontal) + out_h = self.net(x).detach() + out = torch.sqrt(torch.square(out_h)+torch.square(out_v)) + # out = torch.abs((out+1)/2.) 
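+        # Despite the class name, this implements only the Sobel stage of a
+        # Canny detector: fixed per-channel 3x3 kernels produce horizontal and
+        # vertical gradients, combined as sqrt(Gx^2 + Gy^2); there is no
+        # non-maximum suppression or hysteresis thresholding here.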
+ return out + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp( + -(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) \ + for x in range(window_size)]) + return gauss / gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand( + channel, 1, window_size, window_size).contiguous()) + return window + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d( + img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d( + img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d( + img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ + ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + +def ssim(img1, img2, window_size=11, size_average=True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) + +class SSIMLoss(torch.nn.Module): + def __init__(self, window_size=11, size_average=True): + super(SSIMLoss, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and \ + self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return _ssim(img1, img2, window, self.window_size, + channel, self.size_average) + +CONTENT_LAYER = 'relu_16' +cfgs = { + 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + +def make_layers(cfg, batch_norm=False): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + +class VGG(nn.Module): + def __init__(self, num_classes=1000): + super(VGG, self).__init__() + self.features = make_layers(cfgs['E'], batch_norm=False) + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(), + 
nn.Linear(4096, num_classes), + ) + self.load_state_dict(torch.load('./ckpt/vgg19.pth')) + + def forward(self, x): + x = self.features(x) + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x + +def vgg_19(): + vgg_19 = VGG().features + model = nn.Sequential() + + i = 0 + for layer in vgg_19.children(): + if isinstance(layer, nn.Conv2d): + i += 1 + name = 'conv_{}'.format(i) + elif isinstance(layer, nn.ReLU): + name = 'relu_{}'.format(i) + layer = nn.ReLU(inplace=False) + elif isinstance(layer, nn.MaxPool2d): + name = 'pool_{}'.format(i) + elif isinstance(layer, nn.BatchNorm2d): + name = 'bn_{}'.format(i) + else: + raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__)) + + model.add_module(name, layer) + if name == CONTENT_LAYER: + break + + for param in model.parameters(): + param.requires_grad = False + + for param in vgg_19.parameters(): + param.requires_grad = False + + return model + +def normalize_batch(batch): + mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) + std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) + return (batch - mean) / std + +class VGGLoss(torch.nn.Module): + def __init__(self): + super(VGGLoss, self).__init__() + self.VGG_19 = vgg_19() + self.L1_loss = torch.nn.L1Loss() + + def forward(self, img1, img2): + img1 = F.interpolate(img1, scale_factor=0.5, mode="bilinear") + img2 = F.interpolate(img2, scale_factor=0.5, mode="bilinear") + img1_vgg = self.VGG_19(normalize_batch(img1)) + img2_vgg = self.VGG_19(normalize_batch(img2)) + loss_vgg = self.L1_loss(img1_vgg, img2_vgg) + return loss_vgg + + +class FFTLoss(nn.Module): + def __init__(self): + super().__init__() + self.canny = Canny() + self.criterion = torch.nn.L1Loss() +# self.loss_weight = loss_weight + + def forward(self, pred, target): + Edge = self.canny(target) + # pred_fft = torch.fft.fft2(pred, dim=(-2, -1)) + # target_fft = torch.fft.fft2(target, dim=(-2, -1)) + return self.criterion(pred*Edge, target*Edge) + +class GANLoss(nn.Module): + def __init__(self, gan_mode='lsgan', target_real_label=1.0, target_fake_label=0.0): + super(GANLoss, self).__init__() + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + self.gan_mode = gan_mode + if gan_mode == 'lsgan': + self.loss = nn.MSELoss() + elif gan_mode == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif gan_mode in ['wgangp']: + self.loss = None + else: + raise NotImplementedError('gan mode %s not implemented' % gan_mode) + + def get_target_tensor(self, prediction, target_is_real): + if target_is_real: + target_tensor = self.real_label + else: + target_tensor = self.fake_label + return target_tensor.expand_as(prediction) + + def __call__(self, prediction, target_is_real): + if self.gan_mode in ['lsgan', 'vanilla']: + target_tensor = self.get_target_tensor(prediction, target_is_real) + loss = self.loss(prediction, target_tensor) + elif self.gan_mode == 'wgangp': + if target_is_real: + loss = -prediction.mean() + else: + loss = prediction.mean() + return loss diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/modules.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/modules.py new file mode 100644 index 0000000..0f9ce53 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/modules.py @@ -0,0 +1,388 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import split_feature, compute_same_pad + + +def 
gaussian_p(mean, logs, x): + """ + lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) } + k = 1 (Independent) + Var = logs ** 2 + """ + c = math.log(2 * math.pi) + return -0.5 * (logs * 2.0 + ((x - mean) ** 2) / torch.exp(logs * 2.0) + c) + + +def gaussian_likelihood(mean, logs, x): + p = gaussian_p(mean, logs, x) + return torch.sum(p, dim=[1, 2, 3]) + + +def gaussian_sample(mean, logs, temperature=1): + # Sample from Gaussian with temperature + z = torch.normal(mean, torch.exp(logs) * temperature) + + return z + + +def squeeze2d(input, factor): + if factor == 1: + return input + + B, C, H, W = input.size() + + assert H % factor == 0 and W % factor == 0, "H or W modulo factor is not 0" + + x = input.view(B, C, H // factor, factor, W // factor, factor) + x = x.permute(0, 1, 3, 5, 2, 4).contiguous() + x = x.view(B, C * factor * factor, H // factor, W // factor) + + return x + + +def unsqueeze2d(input, factor): + if factor == 1: + return input + + factor2 = factor ** 2 + + B, C, H, W = input.size() + + assert C % (factor2) == 0, "C module factor squared is not 0" + + x = input.view(B, C // factor2, factor, factor, H, W) + x = x.permute(0, 1, 4, 2, 5, 3).contiguous() + x = x.view(B, C // (factor2), H * factor, W * factor) + + return x + + +class _ActNorm(nn.Module): + """ + Activation Normalization + Initialize the bias and scale with a given minibatch, + so that the output per-channel have zero mean and unit variance for that. + + After initialization, `bias` and `logs` will be trained as parameters. + """ + + def __init__(self, num_features, scale=1.0): + super().__init__() + # register mean and scale + size = [1, num_features, 1, 1] + self.bias = nn.Parameter(torch.zeros(*size)) + self.logs = nn.Parameter(torch.zeros(*size)) + self.num_features = num_features + self.scale = scale + self.inited = False + + def initialize_parameters(self, input): + if not self.training: + raise ValueError("In Eval mode, but ActNorm not inited") + + with torch.no_grad(): + bias = -torch.mean(input.clone(), dim=[0, 2, 3], keepdim=True) + vars = torch.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True) + logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) + + self.bias.data.copy_(bias.data) + self.logs.data.copy_(logs.data) + + self.inited = True + + def _center(self, input, reverse=False): + if reverse: + return input - self.bias + else: + return input + self.bias + + def _scale(self, input, logdet=None, reverse=False): + + if reverse: + input = input * torch.exp(-self.logs) + else: + input = input * torch.exp(self.logs) + + if logdet is not None: + """ + logs is log_std of `mean of channels` + so we need to multiply by number of pixels + """ + b, c, h, w = input.shape + + dlogdet = torch.sum(self.logs) * h * w + + if reverse: + dlogdet *= -1 + + logdet = logdet + dlogdet + + return input, logdet + + def forward(self, input, logdet=None, reverse=False): + self._check_input_dim(input) + + if not self.inited: + self.initialize_parameters(input) + + if reverse: + input, logdet = self._scale(input, logdet, reverse) + input = self._center(input, reverse) + else: + input = self._center(input, reverse) + input, logdet = self._scale(input, logdet, reverse) + + return input, logdet + + +class ActNorm2d(_ActNorm): + def __init__(self, num_features, scale=1.0): + super().__init__(num_features, scale) + + def _check_input_dim(self, input): + assert len(input.size()) == 4 + assert input.size(1) == self.num_features, ( + "[ActNorm]: input should be in shape as `BCHW`," + " channels should be {} 
rather than {}".format( + self.num_features, input.size() + ) + ) + + +class LinearZeros(nn.Module): + def __init__(self, in_channels, out_channels, logscale_factor=3): + super().__init__() + + self.linear = nn.Linear(in_channels, out_channels) + self.linear.weight.data.zero_() + self.linear.bias.data.zero_() + + self.logscale_factor = logscale_factor + + self.logs = nn.Parameter(torch.zeros(out_channels)) + + def forward(self, input): + output = self.linear(input) + return output * torch.exp(self.logs * self.logscale_factor) + + +class Conv2d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding="same", + do_actnorm=True, + weight_std=0.05, + ): + super().__init__() + + if padding == "same": + padding = compute_same_pad(kernel_size, stride) + elif padding == "valid": + padding = 0 + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias=(not do_actnorm), + ) + + # init weight with std + self.conv.weight.data.normal_(mean=0.0, std=weight_std) + + if not do_actnorm: + self.conv.bias.data.zero_() + else: + self.actnorm = ActNorm2d(out_channels) + + self.do_actnorm = do_actnorm + + def forward(self, input): + x = self.conv(input) + if self.do_actnorm: + x, _ = self.actnorm(x) + return x + + +class Conv2dZeros(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding="same", + logscale_factor=3, + ): + super().__init__() + + if padding == "same": + padding = compute_same_pad(kernel_size, stride) + elif padding == "valid": + padding = 0 + + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) + + self.conv.weight.data.zero_() + self.conv.bias.data.zero_() + + self.logscale_factor = logscale_factor + self.logs = nn.Parameter(torch.zeros(out_channels, 1, 1)) + + def forward(self, input): + output = self.conv(input) + return output * torch.exp(self.logs * self.logscale_factor) + + +class Permute2d(nn.Module): + def __init__(self, num_channels, shuffle): + super().__init__() + self.num_channels = num_channels + self.indices = torch.arange(self.num_channels - 1, -1, -1, dtype=torch.long) + self.indices_inverse = torch.zeros((self.num_channels), dtype=torch.long) + + for i in range(self.num_channels): + self.indices_inverse[self.indices[i]] = i + + if shuffle: + self.reset_indices() + + def reset_indices(self): + shuffle_idx = torch.randperm(self.indices.shape[0]) + self.indices = self.indices[shuffle_idx] + + for i in range(self.num_channels): + self.indices_inverse[self.indices[i]] = i + + def forward(self, input, reverse=False): + assert len(input.size()) == 4 + + if not reverse: + input = input[:, self.indices, :, :] + return input + else: + return input[:, self.indices_inverse, :, :] + + +class Split2d(nn.Module): + def __init__(self, num_channels): + super().__init__() + self.conv = Conv2dZeros(num_channels // 2, num_channels) + + def split2d_prior(self, z): + h = self.conv(z) + return split_feature(h, "cross") + + def forward(self, input, logdet=0.0, reverse=False, temperature=None): + if reverse: + z1 = input + mean, logs = self.split2d_prior(z1) + z2 = gaussian_sample(mean, logs, temperature) + z = torch.cat((z1, z2), dim=1) + return z, logdet + else: + z1, z2 = split_feature(input, "split") + mean, logs = self.split2d_prior(z1) + logdet = gaussian_likelihood(mean, logs, z2) + logdet + return z1, logdet + + +class SqueezeLayer(nn.Module): + def __init__(self, factor): + super().__init__() + self.factor = 
factor + + def forward(self, input, logdet=None, reverse=False): + if reverse: + output = unsqueeze2d(input, self.factor) + else: + output = squeeze2d(input, self.factor) + + return output, logdet + + +class InvertibleConv1x1(nn.Module): + def __init__(self, num_channels, LU_decomposed): + super().__init__() + w_shape = [num_channels, num_channels] + w_init = torch.qr(torch.randn(*w_shape))[0] + + if not LU_decomposed: + self.weight = nn.Parameter(torch.Tensor(w_init)) + else: + p, lower, upper = torch.lu_unpack(*torch.lu(w_init)) + s = torch.diag(upper) + sign_s = torch.sign(s) + log_s = torch.log(torch.abs(s)) + upper = torch.triu(upper, 1) + l_mask = torch.tril(torch.ones(w_shape), -1) + eye = torch.eye(*w_shape) + + self.register_buffer("p", p) + self.register_buffer("sign_s", sign_s) + self.lower = nn.Parameter(lower) + self.log_s = nn.Parameter(log_s) + self.upper = nn.Parameter(upper) + self.l_mask = l_mask + self.eye = eye + + self.w_shape = w_shape + self.LU_decomposed = LU_decomposed + + def get_weight(self, input, reverse): + b, c, h, w = input.shape + + if not self.LU_decomposed: + dlogdet = torch.slogdet(self.weight)[1] * h * w + if reverse: + weight = torch.inverse(self.weight) + else: + weight = self.weight + else: + self.l_mask = self.l_mask.to(input.device) + self.eye = self.eye.to(input.device) + + lower = self.lower * self.l_mask + self.eye + + u = self.upper * self.l_mask.transpose(0, 1).contiguous() + u += torch.diag(self.sign_s * torch.exp(self.log_s)) + + dlogdet = torch.sum(self.log_s) * h * w + + if reverse: + print(u) + u_inv = torch.inverse(u) + l_inv = torch.inverse(lower) + p_inv = torch.inverse(self.p) + + weight = torch.matmul(u_inv, torch.matmul(l_inv, p_inv)) + else: + weight = torch.matmul(self.p, torch.matmul(lower, u)) + + return weight.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet + + def forward(self, input, logdet=None, reverse=False): + """ + log-det = log|abs(|W|)| * pixels + """ + weight, dlogdet = self.get_weight(input, reverse) + + if not reverse: + z = F.conv2d(input, weight) + if logdet is not None: + logdet = logdet + dlogdet + return z, logdet + else: + z = F.conv2d(input, weight) + if logdet is not None: + logdet = logdet - dlogdet + return z, logdet diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/mwcnn_model.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/mwcnn_model.py new file mode 100644 index 0000000..fe0e8d5 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/mwcnn_model.py @@ -0,0 +1,118 @@ +import torch +import networks as N +import torch.nn as nn +import math +import torch.optim as optim + +class MWRCAN(nn.Module): + def __init__(self): + super(MWRCAN, self).__init__() + c1 = 64 + c2 = 128 + c3 = 128 + n_b = 20 + self.head = N.DWTForward() + + self.down1 = N.seq( + nn.Conv2d(4 * 4, c1, 3, 1, 1), + nn.PReLU(), + N.RCAGroup(in_channels=c1, out_channels=c1, nb=n_b) + ) + + self.down2 = N.seq( + N.DWTForward(), + nn.Conv2d(c1 * 4, c2, 3, 1, 1), + nn.PReLU(), + N.RCAGroup(in_channels=c2, out_channels=c2, nb=n_b) + ) + + self.down3 = N.seq( + N.DWTForward(), + nn.Conv2d(c2 * 4, c3, 3, 1, 1), + nn.PReLU() + ) + + self.middle = N.seq( + N.RCAGroup(in_channels=c3, out_channels=c3, nb=n_b), + N.RCAGroup(in_channels=c3, out_channels=c3, nb=n_b) + ) + + self.up1 = N.seq( + nn.Conv2d(c3, c2 * 4, 3, 1, 1), + nn.PReLU(), + N.DWTInverse() + ) + + self.up2 = N.seq( + N.RCAGroup(in_channels=c2, out_channels=c2, nb=n_b), + nn.Conv2d(c2, c1 * 4, 3, 1, 1), + nn.PReLU(), + N.DWTInverse() + ) + + 
self.up3 = N.seq( + N.RCAGroup(in_channels=c1, out_channels=c1, nb=n_b), + nn.Conv2d(c1, 16, 3, 1, 1) + ) + + self.tail = N.seq( + N.DWTInverse(), + nn.Conv2d(4, 12, 3, 1, 1), + nn.PixelShuffle(upscale_factor=2) + ) + + def forward(self, x, c=None): + c0 = x + c1 = self.head(c0) + c2 = self.down1(c1) + c3 = self.down2(c2) + c4 = self.down3(c3) + m = self.middle(c4) + c5 = self.up1(m) + c3 + c6 = self.up2(c5) + c2 + c7 = self.up3(c6) + c1 + out = self.tail(c7) + + return out + +class Discriminator(nn.Module): + """Defines a PatchGAN discriminator""" + def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(Discriminator, self).__init__() + use_bias = False + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.model = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.model(input) \ No newline at end of file diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/net.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/net.py new file mode 100644 index 0000000..93221c9 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/net.py @@ -0,0 +1,684 @@ +import torch +import torch.nn as nn +from torch.nn import init +from torch.optim import lr_scheduler +from collections import OrderedDict +from models.arch_util import LayerNorm2d +import numbers +from einops import rearrange +import torch.nn.functional as F + +train_size=(1,3,504, 504) +class AvgPool2d(nn.Module): + def __init__(self, kernel_size=None, base_size=[490 ,490], auto_pad=True, fast_imp=False): + super().__init__() + self.kernel_size = kernel_size + self.base_size = base_size + self.auto_pad = auto_pad + + # only used for fast implementation + self.fast_imp = fast_imp + self.rs = [5,4,3,2,1] + self.max_r1 = self.rs[0] + self.max_r2 = self.rs[0] + + def extra_repr(self) -> str: + return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format( + self.kernel_size, self.base_size, self.kernel_size, self.fast_imp + ) + + def forward(self, x): + if self.kernel_size is None and self.base_size: + if isinstance(self.base_size, int): + self.base_size = (self.base_size, self.base_size) + self.kernel_size = list(self.base_size) + self.kernel_size[0] = x.shape[2]*self.base_size[0]//train_size[-2] + self.kernel_size[1] = x.shape[3]*self.base_size[1]//train_size[-1] + + # only used for fast implementation + self.max_r1 = max(1, self.rs[0]*x.shape[2]//train_size[-2]) + self.max_r2 = max(1, 
self.rs[0]*x.shape[3]//train_size[-1]) + + if self.fast_imp: # Non-equivalent implementation but faster + h, w = x.shape[2:] + if self.kernel_size[0]>=h and self.kernel_size[1]>=w: + out = F.adaptive_avg_pool2d(x,1) + else: + r1 = [r for r in self.rs if h%r==0][0] + r2 = [r for r in self.rs if w%r==0][0] + # reduction_constraint + r1 = min(self.max_r1, r1) + r2 = min(self.max_r2, r2) + s = x[:,:,::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2) + n, c, h, w = s.shape + k1, k2 = min(h-1, self.kernel_size[0]//r1), min(w-1, self.kernel_size[1]//r2) + out = (s[:,:,:-k1,:-k2]-s[:,:,:-k1,k2:]-s[:,:,k1:,:-k2]+s[:,:,k1:,k2:])/(k1*k2) + out = torch.nn.functional.interpolate(out, scale_factor=(r1,r2)) + else: + n, c, h, w = x.shape + s = x.cumsum(dim=-1).cumsum(dim=-2) + s = torch.nn.functional.pad(s, (1,0,1,0)) # pad 0 for convenience + k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1]) + s1, s2, s3, s4 = s[:,:,:-k1,:-k2],s[:,:,:-k1,k2:], s[:,:,k1:,:-k2], s[:,:,k1:,k2:] + out = s4+s1-s2-s3 + out = out / (k1*k2) + + if self.auto_pad: + n, c, h, w = x.shape + _h, _w = out.shape[2:] + # print(x.shape, self.kernel_size) + pad2d = ((w - _w)//2, (w - _w + 1)//2, (h - _h) // 2, (h - _h + 1) // 2) + out = torch.nn.functional.pad(out, pad2d, mode='replicate') + + return out + + +def get_scheduler(optimizer, opt): + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + return 1 - max(0, epoch-opt.niter) / max(1, float(opt.niter_decay)) + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, + step_size=opt.lr_decay_iters, + gamma=0.5) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, + mode='min', + factor=0.2, + threshold=0.01, + patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, + T_max=opt.niter, + eta_min=1e-5) + else: + return NotImplementedError('lr [%s] is not implemented', opt.lr_policy) + return scheduler + +def init_weights(net, init_type='normal', init_gain=0.02): + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 \ + or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + elif init_type == 'uniform': + init.uniform_(m.weight.data, b=init_gain) + else: + raise NotImplementedError('[%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + +def init_net(net, init_type='default', init_gain=0.02, gpu_ids=[]): + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs + if init_type != 'default' and init_type is not None: + init_weights(net, init_type, init_gain=init_gain) + return net + + +''' +# =================================== +# Advanced nn.Sequential +# reform nn.Sequentials and nn.Modules +# to a single nn.Sequential +# 
=================================== +''' + +def seq(*args): + if len(args) == 1: + args = args[0] + if isinstance(args, nn.Module): + return args + modules = OrderedDict() + if isinstance(args, OrderedDict): + for k, v in args.items(): + modules[k] = seq(v) + return nn.Sequential(modules) + assert isinstance(args, (list, tuple)) + return nn.Sequential(*[seq(i) for i in args]) + +''' +# =================================== +# Useful blocks +# -------------------------------- +# conv (+ normaliation + relu) +# concat +# sum +# resblock (ResBlock) +# resdenseblock (ResidualDenseBlock_5C) +# resinresdenseblock (RRDB) +# =================================== +''' + +# ------------------------------------------------------- +# return nn.Sequantial of (Conv + BN + ReLU) +# ------------------------------------------------------- +def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, + output_padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', mode='C'): + L = [] + for t in mode: + if t == 'C': + L.append(nn.utils.spectral_norm(nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode))) + elif t == 'X': + assert in_channels == out_channels + L.append(nn.utils.spectral_norm(nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=bias, + padding_mode=padding_mode))) + elif t == 'T': + L.append(nn.utils.spectral_norm(nn.ConvTranspose2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + padding_mode=padding_mode))) + elif t == 'B': + L.append(nn.BatchNorm2d(out_channels)) + elif t == 'I': + L.append(nn.InstanceNorm2d(out_channels, affine=True)) + elif t == 'i': + L.append(nn.InstanceNorm2d(out_channels)) + elif t == 'R': + L.append(nn.ReLU(inplace=True)) + elif t == 'r': + L.append(nn.ReLU(inplace=False)) + elif t == 'S': + L.append(nn.Sigmoid()) + elif t == 'P': + L.append(nn.PReLU()) + elif t == 'L': + L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=True)) + elif t == 'l': + L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=False)) + elif t == '2': + L.append(nn.PixelShuffle(upscale_factor=2)) + elif t == '3': + L.append(nn.PixelShuffle(upscale_factor=3)) + elif t == '4': + L.append(nn.PixelShuffle(upscale_factor=4)) + elif t == 'U': + L.append(nn.Upsample(scale_factor=2, mode='nearest')) + elif t == 'u': + L.append(nn.Upsample(scale_factor=3, mode='nearest')) + elif t == 'M': + L.append(nn.MaxPool2d(kernel_size=kernel_size, + stride=stride, + padding=0)) + elif t == 'A': + L.append(nn.AvgPool2d(kernel_size=kernel_size, + stride=stride, + padding=0)) + else: + raise NotImplementedError('Undefined type: '.format(t)) + return seq(*L) + + +class DWTForward(nn.Conv2d): + def __init__(self, in_channels=64): + super(DWTForward, self).__init__(in_channels, in_channels*4, 2, 2, + groups=in_channels, bias=False) + weight = torch.tensor([[[[0.5, 0.5], [ 0.5, 0.5]]], + [[[0.5, 0.5], [-0.5, -0.5]]], + [[[0.5, -0.5], [ 0.5, -0.5]]], + [[[0.5, -0.5], [-0.5, 0.5]]]], + dtype=torch.get_default_dtype() + ).repeat(in_channels, 1, 1, 1)# / 2 + self.weight.data.copy_(weight) + self.requires_grad_(False) + + +class DWTInverse(nn.ConvTranspose2d): + def __init__(self, 
in_channels=64): + super(DWTInverse, self).__init__(in_channels, in_channels//4, 2, 2, + groups=in_channels//4, bias=False) + weight = torch.tensor([[[[0.5, 0.5], [ 0.5, 0.5]]], + [[[0.5, 0.5], [-0.5, -0.5]]], + [[[0.5, -0.5], [ 0.5, -0.5]]], + [[[0.5, -0.5], [-0.5, 0.5]]]], + dtype=torch.get_default_dtype() + ).repeat(in_channels//4, 1, 1, 1)# * 2 + self.weight.data.copy_(weight) + self.requires_grad_(False) + + +# ------------------------------------------------------- +# Channel Attention (CA) Layer +# ------------------------------------------------------- +class CALayer(nn.Module): + def __init__(self, channel=64, reduction=16): + super(CALayer, self).__init__() + + self.avg_pool = AvgPool2d() + self.conv_du = nn.Sequential( + nn.utils.spectral_norm(nn.Conv2d(channel, channel//reduction, 1, padding=0, bias=True)), + nn.ReLU(inplace=True), + nn.utils.spectral_norm(nn.Conv2d(channel//reduction, channel, 1, padding=0, bias=True)), + nn.Sigmoid() + ) + + def forward(self, x): + # print(x.shape) + y = self.avg_pool(x) + # print(y.shape) + # print(1/0) + y = self.conv_du(y) + return x * y + + +# ------------------------------------------------------- +# Res Block: x + conv(relu(conv(x))) +# ------------------------------------------------------- +class ResBlock(nn.Module): + def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1, bias=True, mode='CRC'): + super(ResBlock, self).__init__() + + assert in_channels == out_channels + if mode[0] in ['R','L']: + mode = mode[0].lower() + mode[1:] + + self.res = conv(in_channels, out_channels, kernel_size=7,padding=3,stride=1, bias=bias, mode=mode) + + def forward(self, x): + res = self.res(x) + return x + res + + +# ------------------------------------------------------- +# Residual Channel Attention Block (RCAB) +# ------------------------------------------------------- +class RCABlock(nn.Module): + def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1, bias=True, mode='CRC', reduction=16): + super(RCABlock, self).__init__() + assert in_channels == out_channels + if mode[0] in ['R','L']: + mode = mode[0].lower() + mode[1:] + + self.res = conv(in_channels, out_channels, kernel_size, + stride, padding, bias=bias, mode=mode) + self.ca = CALayer(out_channels, reduction) + + def forward(self, x): + res = self.res(x) + res = self.ca(res) + return res + x + + +# ------------------------------------------------------- +# Residual Channel Attention Group (RG) +# ------------------------------------------------------- +class RCAGroup(nn.Module): + def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1, bias=True, mode='CRC', reduction=16, nb=12): + super(RCAGroup, self).__init__() + assert in_channels == out_channels + if mode[0] in ['R','L']: + mode = mode[0].lower() + mode[1:] + + RG = [RCABlock(in_channels, out_channels, kernel_size, stride, padding, + bias, mode, reduction) for _ in range(nb)] + # RG = [ResBlock(in_channels, out_channels, kernel_size, stride, padding, + # bias, mode) for _ in range(nb)] + RG.append(conv(out_channels, out_channels,kernel_size, stride, padding, mode='C')) + + self.rg = nn.Sequential(*RG) + + def forward(self, x): + res = self.rg(x) + return res + x + +class RNAFGroup(nn.Module): + def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1, bias=True, mode='CRC', reduction=16, nb=12): + super(RNAFGroup, self).__init__() + assert in_channels == out_channels + + + RG = [NAFBlock(in_channels) for 
_ in range(nb)] + + RG.append(conv(out_channels, out_channels, mode='C')) + + self.rg = nn.Sequential(*RG) + + def forward(self, x): + res = self.rg(x) + return res + x + +class RNAFBaselineGroup(nn.Module): + def __init__(self, in_channels=64, out_channels=64, nb=12): + super(RNAFBaselineGroup, self).__init__() + assert in_channels == out_channels + + + RG = [NAFBaselineBlock(in_channels) for _ in range(nb)] + + RG.append(conv(out_channels, out_channels, mode='C')) + + self.rg = nn.Sequential(*RG) + + def forward(self, x): + res = self.rg(x) + return res + x + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + +class NAFBlock(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.utils.spectral_norm(nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + self.conv2 = nn.utils.spectral_norm(nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, + bias=True)) + self.conv3 = nn.utils.spectral_norm(nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + + # Simplified Channel Attention + self.sca = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.utils.spectral_norm(nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, + groups=1, bias=True)), + ) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.utils.spectral_norm(nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + self.conv5 = nn.utils.spectral_norm(nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. 
else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + + + + + +class NAFBaselineBlock(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.utils.spectral_norm(nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + self.conv2 = nn.utils.spectral_norm(nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, + bias=True)) + self.conv3 = nn.utils.spectral_norm(nn.Conv2d(in_channels=dw_channel , out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + + # Simplified Channel Attention + self.sca = CALayer(dw_channel , reduction=16) + + # SimpleGate + # self.sg = + + ffn_channel = FFN_Expand * c + self.conv4 = nn.utils.spectral_norm(nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + self.conv5 = nn.utils.spectral_norm(nn.Conv2d(in_channels=ffn_channel , out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. 
else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = F.gelu(x) + x = self.sca(x) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = F.gelu(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + + +def to_4d(x, h, w): + return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma + 1e-5) * self.weight + + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type == 'BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim * ffn_expansion_factor) + + self.project_in = nn.utils.spectral_norm(nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)) + + self.dwconv = nn.utils.spectral_norm(nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, + groups=hidden_features * 2, bias=bias)) + + self.project_out = nn.utils.spectral_norm(nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.utils.spectral_norm(nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)) + self.qkv_dwconv = 
nn.utils.spectral_norm(nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)) + self.project_out = nn.utils.spectral_norm(nn.Conv2d(dim, dim, kernel_size=1, bias=bias)) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + + return x + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.utils.spectral_norm(nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)) + + def forward(self, x): + x = self.proj(x) + + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False)), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/networks.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/networks.py new file mode 100644 index 0000000..dc73cf9 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/networks.py @@ -0,0 +1,617 @@ +import torch +import torch.nn as nn +from torch.nn import init +from torch.optim import lr_scheduler +from collections import OrderedDict +from models.arch_util import LayerNorm2d +import numbers +from einops import rearrange +import torch.nn.functional as F + + +def get_scheduler(optimizer, opt): + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + return 1 - max(0, epoch-opt.niter) / max(1, float(opt.niter_decay)) + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, + step_size=opt.lr_decay_iters, + gamma=0.5) + elif opt.lr_policy == 'plateau': + scheduler = 
lr_scheduler.ReduceLROnPlateau(optimizer, + mode='min', + factor=0.2, + threshold=0.01, + patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, + T_max=opt.niter, + eta_min=1e-5) + else: + raise NotImplementedError('lr policy [%s] is not implemented' % opt.lr_policy) + return scheduler + +def init_weights(net, init_type='normal', init_gain=0.02): + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 \ + or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + elif init_type == 'uniform': + init.uniform_(m.weight.data, b=init_gain) + else: + raise NotImplementedError('[%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + +def init_net(net, init_type='default', init_gain=0.02, gpu_ids=[]): + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs + if init_type != 'default' and init_type is not None: + init_weights(net, init_type, init_gain=init_gain) + return net + + +''' +# =================================== +# Advanced nn.Sequential +# reform nn.Sequentials and nn.Modules +# to a single nn.Sequential +# =================================== +''' + +def seq(*args): + if len(args) == 1: + args = args[0] + if isinstance(args, nn.Module): + return args + modules = OrderedDict() + if isinstance(args, OrderedDict): + for k, v in args.items(): + modules[k] = seq(v) + return nn.Sequential(modules) + assert isinstance(args, (list, tuple)) + return nn.Sequential(*[seq(i) for i in args]) + +''' +# =================================== +# Useful blocks +# -------------------------------- +# conv (+ normalization + relu) +# concat +# sum +# resblock (ResBlock) +# resdenseblock (ResidualDenseBlock_5C) +# resinresdenseblock (RRDB) +# =================================== +''' + +# ------------------------------------------------------- +# return nn.Sequential of (Conv + BN + ReLU) +# ------------------------------------------------------- +def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, + output_padding=0, dilation=1, groups=1, bias=True, + padding_mode='zeros', mode='C'): + L = [] + for t in mode: + if t == 'C': + L.append(nn.utils.spectral_norm(nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode))) + elif t == 'X': + assert in_channels == out_channels + L.append(nn.utils.spectral_norm(nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=bias, + padding_mode=padding_mode))) + elif t == 'T': + L.append(nn.utils.spectral_norm(nn.ConvTranspose2d(in_channels=in_channels, +
out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + dilation=dilation, + padding_mode=padding_mode))) + elif t == 'B': + L.append(nn.BatchNorm2d(out_channels)) + elif t == 'I': + L.append(nn.InstanceNorm2d(out_channels, affine=True)) + elif t == 'i': + L.append(nn.InstanceNorm2d(out_channels)) + elif t == 'R': + L.append(nn.ReLU(inplace=True)) + elif t == 'r': + L.append(nn.ReLU(inplace=False)) + elif t == 'S': + L.append(nn.Sigmoid()) + elif t == 'P': + L.append(nn.PReLU()) + elif t == 'L': + L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=True)) + elif t == 'l': + L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=False)) + elif t == '2': + L.append(nn.PixelShuffle(upscale_factor=2)) + elif t == '3': + L.append(nn.PixelShuffle(upscale_factor=3)) + elif t == '4': + L.append(nn.PixelShuffle(upscale_factor=4)) + elif t == 'U': + L.append(nn.Upsample(scale_factor=2, mode='nearest')) + elif t == 'u': + L.append(nn.Upsample(scale_factor=3, mode='nearest')) + elif t == 'M': + L.append(nn.MaxPool2d(kernel_size=kernel_size, + stride=stride, + padding=0)) + elif t == 'A': + L.append(nn.AvgPool2d(kernel_size=kernel_size, + stride=stride, + padding=0)) + else: + raise NotImplementedError('Undefined type: {}'.format(t)) + return seq(*L) + + +class DWTForward(nn.Conv2d): + def __init__(self, in_channels=64): + super(DWTForward, self).__init__(in_channels, in_channels*4, 2, 2, + groups=in_channels, bias=False) + weight = torch.tensor([[[[0.5, 0.5], [ 0.5, 0.5]]], + [[[0.5, 0.5], [-0.5, -0.5]]], + [[[0.5, -0.5], [ 0.5, -0.5]]], + [[[0.5, -0.5], [-0.5, 0.5]]]], + dtype=torch.get_default_dtype() + ).repeat(in_channels, 1, 1, 1)# / 2 + self.weight.data.copy_(weight) + self.requires_grad_(False) + + +class DWTInverse(nn.ConvTranspose2d): + def __init__(self, in_channels=64): + super(DWTInverse, self).__init__(in_channels, in_channels//4, 2, 2, + groups=in_channels//4, bias=False) + weight = torch.tensor([[[[0.5, 0.5], [ 0.5, 0.5]]], + [[[0.5, 0.5], [-0.5, -0.5]]], + [[[0.5, -0.5], [ 0.5, -0.5]]], + [[[0.5, -0.5], [-0.5, 0.5]]]], + dtype=torch.get_default_dtype() + ).repeat(in_channels//4, 1, 1, 1)# * 2 + self.weight.data.copy_(weight) + self.requires_grad_(False) + + +# ------------------------------------------------------- +# Channel Attention (CA) Layer +# ------------------------------------------------------- +class CALayer(nn.Module): + def __init__(self, channel=64, reduction=16): + super(CALayer, self).__init__() + + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv_du = nn.Sequential( + nn.utils.spectral_norm(nn.Conv2d(channel, channel//reduction, 1, padding=0, bias=True)), + nn.ReLU(inplace=True), + nn.utils.spectral_norm(nn.Conv2d(channel//reduction, channel, 1, padding=0, bias=True)), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + + +# ------------------------------------------------------- +# Res Block: x + conv(relu(conv(x))) +# ------------------------------------------------------- +class ResBlock(nn.Module): + def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1, bias=True, mode='CRC'): + super(ResBlock, self).__init__() + + assert in_channels == out_channels + if mode[0] in ['R','L']: + mode = mode[0].lower() + mode[1:] + + # NOTE: the kernel_size argument is overridden by a fixed 7x7 kernel here + self.res = conv(in_channels, out_channels, kernel_size=7, padding=3, stride=1, bias=bias, mode=mode) + + def forward(self, x): + res = self.res(x) + return x + res + +
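+# Editor's sketch of the `conv` mode-string builder above (an illustration, not part of the original file): +# each character of `mode` appends one layer, e.g. +# conv(64, 64, mode='CRC') -> spectral-normalized Conv2d, ReLU, spectral-normalized Conv2d +# conv(64, 256, mode='C2') -> 3x3 conv to 256 channels, then nn.PixelShuffle(2), +# i.e. 64 channels at twice the spatial resolution + +# 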
------------------------------------------------------- +# Residual Channel Attention Block (RCAB) +# ------------------------------------------------------- +class RCABlock(nn.Module): + def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1, bias=True, mode='CRC', reduction=16): + super(RCABlock, self).__init__() + assert in_channels == out_channels + if mode[0] in ['R','L']: + mode = mode[0].lower() + mode[1:] + + self.res = conv(in_channels, out_channels, kernel_size, + stride, padding, bias=bias, mode=mode) + self.ca = CALayer(out_channels, reduction) + + def forward(self, x): + res = self.res(x) + res = self.ca(res) + return res + x + + +# ------------------------------------------------------- +# Residual Channel Attention Group (RG) +# ------------------------------------------------------- +class RCAGroup(nn.Module): + def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1, bias=True, mode='CRC', reduction=16, nb=12): + super(RCAGroup, self).__init__() + assert in_channels == out_channels + if mode[0] in ['R','L']: + mode = mode[0].lower() + mode[1:] + + RG = [RCABlock(in_channels, out_channels, kernel_size, stride, padding, + bias, mode, reduction) for _ in range(nb)] + # RG = [ResBlock(in_channels, out_channels, kernel_size, stride, padding, + # bias, mode) for _ in range(nb)] + RG.append(conv(out_channels, out_channels,kernel_size, stride, padding, mode='C')) + + self.rg = nn.Sequential(*RG) + + def forward(self, x): + res = self.rg(x) + return res + x + +class RNAFGroup(nn.Module): + def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1, bias=True, mode='CRC', reduction=16, nb=12): + super(RNAFGroup, self).__init__() + assert in_channels == out_channels + + + RG = [NAFBlock(in_channels) for _ in range(nb)] + + RG.append(conv(out_channels, out_channels, mode='C')) + + self.rg = nn.Sequential(*RG) + + def forward(self, x): + res = self.rg(x) + return res + x + +class RNAFBaselineGroup(nn.Module): + def __init__(self, in_channels=64, out_channels=64, nb=12): + super(RNAFBaselineGroup, self).__init__() + assert in_channels == out_channels + + + RG = [NAFBaselineBlock(in_channels) for _ in range(nb)] + + RG.append(conv(out_channels, out_channels, mode='C')) + + self.rg = nn.Sequential(*RG) + + def forward(self, x): + res = self.rg(x) + return res + x + +class SimpleGate(nn.Module): + def forward(self, x): + x1, x2 = x.chunk(2, dim=1) + return x1 * x2 + +class NAFBlock(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.utils.spectral_norm(nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + self.conv2 = nn.utils.spectral_norm(nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, + bias=True)) + self.conv3 = nn.utils.spectral_norm(nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + + # Simplified Channel Attention + self.sca = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.utils.spectral_norm(nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, + groups=1, bias=True)), + ) + + # SimpleGate + self.sg = SimpleGate() + + ffn_channel = FFN_Expand * c + self.conv4 = nn.utils.spectral_norm(nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, 
padding=0, stride=1, groups=1, bias=True)) + self.conv5 = nn.utils.spectral_norm(nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.sg(x) + x = x * self.sca(x) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = self.sg(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + + + + + +class NAFBaselineBlock(nn.Module): + def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): + super().__init__() + dw_channel = c * DW_Expand + self.conv1 = nn.utils.spectral_norm(nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + self.conv2 = nn.utils.spectral_norm(nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, + bias=True)) + self.conv3 = nn.utils.spectral_norm(nn.Conv2d(in_channels=dw_channel , out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + + # Simplified Channel Attention + self.sca = CALayer(dw_channel , reduction=16) + + # SimpleGate + # self.sg = + + ffn_channel = FFN_Expand * c + self.conv4 = nn.utils.spectral_norm(nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + self.conv5 = nn.utils.spectral_norm(nn.Conv2d(in_channels=ffn_channel , out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)) + + self.norm1 = LayerNorm2d(c) + self.norm2 = LayerNorm2d(c) + + self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() + self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. 
else nn.Identity() + + self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) + + def forward(self, inp): + x = inp + + x = self.norm1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = F.gelu(x) + x = self.sca(x) + x = self.conv3(x) + + x = self.dropout1(x) + + y = inp + x * self.beta + + x = self.conv4(self.norm2(y)) + x = F.gelu(x) + x = self.conv5(x) + + x = self.dropout2(x) + + return y + x * self.gamma + +########################################################################## +## Layer Norm + +def to_3d(x): + return rearrange(x, 'b c h w -> b (h w) c') + + +def to_4d(x, h, w): + return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + + +class BiasFree_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(BiasFree_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + sigma = x.var(-1, keepdim=True, unbiased=False) + return x / torch.sqrt(sigma + 1e-5) * self.weight + + +class WithBias_LayerNorm(nn.Module): + def __init__(self, normalized_shape): + super(WithBias_LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + normalized_shape = torch.Size(normalized_shape) + + assert len(normalized_shape) == 1 + + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.normalized_shape = normalized_shape + + def forward(self, x): + mu = x.mean(-1, keepdim=True) + sigma = x.var(-1, keepdim=True, unbiased=False) + return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias + + +class LayerNorm(nn.Module): + def __init__(self, dim, LayerNorm_type): + super(LayerNorm, self).__init__() + if LayerNorm_type == 'BiasFree': + self.body = BiasFree_LayerNorm(dim) + else: + self.body = WithBias_LayerNorm(dim) + + def forward(self, x): + h, w = x.shape[-2:] + return to_4d(self.body(to_3d(x)), h, w) + + +########################################################################## +## Gated-Dconv Feed-Forward Network (GDFN) +class FeedForward(nn.Module): + def __init__(self, dim, ffn_expansion_factor, bias): + super(FeedForward, self).__init__() + + hidden_features = int(dim * ffn_expansion_factor) + + self.project_in = nn.utils.spectral_norm(nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)) + + self.dwconv = nn.utils.spectral_norm(nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, + groups=hidden_features * 2, bias=bias)) + + self.project_out = nn.utils.spectral_norm(nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + +########################################################################## +## Multi-DConv Head Transposed Self-Attention (MDTA) +class Attention(nn.Module): + def __init__(self, dim, num_heads, bias): + super(Attention, self).__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.utils.spectral_norm(nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)) + self.qkv_dwconv = 
nn.utils.spectral_norm(nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)) + self.project_out = nn.utils.spectral_norm(nn.Conv2d(dim, dim, kernel_size=1, bias=bias)) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + + out = (attn @ v) + + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + + out = self.project_out(out) + return out + + +########################################################################## +class TransformerBlock(nn.Module): + def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): + super(TransformerBlock, self).__init__() + + self.norm1 = LayerNorm(dim, LayerNorm_type) + self.attn = Attention(dim, num_heads, bias) + self.norm2 = LayerNorm(dim, LayerNorm_type) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + + return x + + +########################################################################## +## Overlapped image patch embedding with 3x3 Conv +class OverlapPatchEmbed(nn.Module): + def __init__(self, in_c=3, embed_dim=48, bias=False): + super(OverlapPatchEmbed, self).__init__() + + self.proj = nn.utils.spectral_norm(nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)) + + def forward(self, x): + x = self.proj(x) + + return x + + +########################################################################## +## Resizing modules +class Downsample(nn.Module): + def __init__(self, n_feat): + super(Downsample, self).__init__() + + self.body = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)), + nn.PixelUnshuffle(2)) + + def forward(self, x): + return self.body(x) + + +class Upsample(nn.Module): + def __init__(self, n_feat): + super(Upsample, self).__init__() + + self.body = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False)), + nn.PixelShuffle(2)) + + def forward(self, x): + return self.body(x) diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/s7smooth_model.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/s7smooth_model.py new file mode 100644 index 0000000..852c2ba --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/s7smooth_model.py @@ -0,0 +1,291 @@ +import torch +from .base_model import BaseModel +from . import networks as N +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from . import losses as L +from pwc import pwc_net +from util.util import get_coord +import numpy as np + + + +def demosaic(raw): + """Simple demosaicing to visualize RAW images + Inputs: + - raw: (N, 4, h, w) packed RAW image normalized [0..1] as float32, + with channel order (B, Gr, R, Gb) as unpacked below + Returns: + - Simple Avg. 
Green Demosaiced RAW image with shape (N, 3, h*2, w*2) + """ + + assert raw.shape[1] == 4 + shape = raw.shape + + blue = raw[:,0:1,:,:] + green_red = raw[:,1:2,:,:] + red = raw[:,2:3,:,:] + green_blue = raw[:,3:,:,:] + avg_green = (green_red + green_blue) / 2 + image = torch.cat((red, avg_green, blue), dim=1) + image = F.interpolate(input=image, size=(shape[2]*2, shape[3]*2), + mode='bilinear', align_corners=True) + return image + +def gamma_compression(image): + """Converts from linear to gamma space.""" + return torch.clamp(image, 1e-8, 1.0) ** (1.0 / 2.2) + +def tonemap(image): + """Simple S-curved global tonemap""" + return (3*(image**2)) - (2*(image**3)) + +def ISP(raw): + raw = demosaic(raw) + raw = gamma_compression(raw) + raw = tonemap(raw) + raw = torch.clamp(raw, 0.0, 1.0) + return raw + + +def pixel_unshuffle(input, downscale_factor): + ''' + input: batchSize * c * k*w * k*h + downscale_factor: k + batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h + ''' + c = input.shape[1] + + kernel = torch.zeros(size=[downscale_factor * downscale_factor * c, + 1, downscale_factor, downscale_factor], + device=input.device) + for y in range(downscale_factor): + for x in range(downscale_factor): + kernel[x + y * downscale_factor::downscale_factor*downscale_factor, 0, y, x] = 1 + return F.conv2d(input, kernel, stride=downscale_factor, groups=c) + +class PixelUnshuffle(nn.Module): + def __init__(self, downscale_factor): + super(PixelUnshuffle, self).__init__() + self.downscale_factor = downscale_factor + def forward(self, input): + ''' + input: batchSize * c * k*w * k*h + downscale_factor: k + batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h + ''' + + return pixel_unshuffle(input, self.downscale_factor) + +class S7smoothModel(BaseModel): + @staticmethod + def modify_commandline_options(parser, is_train=True): + return parser + + def __init__(self, opt): + super(S7smoothModel, self).__init__(opt) + + self.opt = opt + self.loss_names = ['GCMModel_L1', 'LiteISPNet_L1', 'LiteISPNet_SSIM', 'LiteISPNet_VGG', 'Total'] + + if self.isTrain: + self.visual_names = ['data_out', 'data_raw_demosaic', 'data_dslr', 'GCMModel_out', 'GCMModel_out_warp', 'rgb_mask', 'raw_warp', 'raw_mask', 'data_raw', 'data_dslr_mask', 'data_out_mask'] + else: + self.visual_names = ['data_out', 'data_dslr'] + + self.model_names = ['LiteISPNet', 'GCMModel'] + self.optimizer_names = ['LiteISPNet_optimizer_%s' % opt.optimizer, + 'GCMModel_optimizer_%s' % opt.optimizer] + + isp = LiteISPNet(opt) + self.netLiteISPNet = N.init_net(isp, opt.init_type, opt.init_gain, opt.gpu_ids) + + gcm = GCMModel(opt) + self.netGCMModel = N.init_net(gcm, opt.init_type, opt.init_gain, opt.gpu_ids) + + pwcnet = pwc_net.PWCNET() + self.netPWCNET = N.init_net(pwcnet, opt.init_type, opt.init_gain, opt.gpu_ids) + self.set_requires_grad(self.netPWCNET, requires_grad=False) + + if self.isTrain: + self.optimizer_LiteISPNet = optim.AdamW(self.netLiteISPNet.parameters(), + lr=opt.lr, + betas=(opt.beta1, opt.beta2), + weight_decay=opt.weight_decay) + self.optimizer_GCMModel = optim.AdamW(self.netGCMModel.parameters(), + lr=opt.lr, + betas=(opt.beta1, opt.beta2), + weight_decay=opt.weight_decay) + self.optimizers = [self.optimizer_LiteISPNet, self.optimizer_GCMModel] + + self.criterionL1 = N.init_net(L.L1Loss(), gpu_ids=opt.gpu_ids) + self.criterionSSIM = N.init_net(L.SSIMLoss(), gpu_ids=opt.gpu_ids) + self.criterionVGG = N.init_net(L.VGGLoss(), gpu_ids=opt.gpu_ids) + + + def set_input(self, input): + if self.isTrain: + self.data_raw = input['raw'].to(self.device) + 
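# Editor's note: `raw` is the packed RAW target; `raw_demosaic` below is its +# bilinear preview fed to GCMModel for color alignment, and `dslr` is the +# sRGB input that LiteISPNet maps back to RAW. +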
self.data_raw_demosaic = input['raw_demosaic'].to(self.device) + self.data_dslr = input['dslr'].to(self.device) + self.image_paths = input['fname'] + + def forward(self): + if self.isTrain: + self.GCMModel_out = self.netGCMModel(self.data_raw_demosaic, self.data_dslr) + self.GCMModel_out_warp, self.rgb_mask, self.raw_warp, self.raw_mask = \ + self.get_backwarp(self.data_dslr, self.GCMModel_out, self.data_raw, self.netPWCNET) + + self.data_out = self.netLiteISPNet(self.data_dslr) + + if self.isTrain: + self.data_dslr_mask = self.data_dslr * self.rgb_mask + self.data_out_mask = self.data_out * self.raw_mask + self.raw_warp_rgb = ISP(self.raw_warp).to(self.device) + self.data_out_mask_rgb = ISP(self.data_out_mask).to(self.device) + + def backward(self): + self.loss_GCMModel_L1 = self.criterionL1(self.GCMModel_out_warp, self.data_dslr_mask).mean() + self.loss_LiteISPNet_L1 = self.criterionL1(self.data_out_mask, self.raw_warp).mean() + self.loss_LiteISPNet_SSIM = 1 - self.criterionSSIM(self.data_out_mask, self.raw_warp).mean() + self.loss_LiteISPNet_VGG = self.criterionVGG(self.data_out_mask_rgb, self.raw_warp_rgb).mean() + self.loss_Total = self.loss_GCMModel_L1 + self.loss_LiteISPNet_L1 + \ + self.loss_LiteISPNet_VGG * 0.4 + self.loss_LiteISPNet_SSIM * 0.1 + self.loss_Total.backward() + + def optimize_parameters(self): + self.forward() + self.optimizer_LiteISPNet.zero_grad() + self.optimizer_GCMModel.zero_grad() + self.backward() + self.optimizer_LiteISPNet.step() + self.optimizer_GCMModel.step() + +class GCMModel(nn.Module): + def __init__(self, opt): + super(GCMModel, self).__init__() + self.opt = opt + self.ch_1 = 32 + self.ch_2 = 64 + + guide_input_channels = 6 + align_input_channels = 3 + + self.guide_net = N.seq( + N.conv(guide_input_channels, self.ch_1, 7, stride=2, padding=0, mode='CR'), + N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'), + nn.AdaptiveAvgPool2d(1), + N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C') + ) + + self.align_head = N.conv(align_input_channels, self.ch_2, 1, padding=0, mode='CR') + + self.align_base = N.seq( + N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCR') + ) + self.align_tail = N.seq( + N.conv(self.ch_2, 3, 1, padding=0, mode='C') + ) + + def forward(self, demosaic_raw, dslr): + demosaic_raw = torch.pow(demosaic_raw, 1/2.2) + + guide_input = torch.cat((demosaic_raw, dslr), 1) + base_input = demosaic_raw + + guide = self.guide_net(guide_input) + + out = self.align_head(base_input) + out = guide * out + out + out = self.align_base(out) + out = self.align_tail(out) + demosaic_raw + + return out + +class LiteISPNet(nn.Module): + def __init__(self, opt): + super(LiteISPNet, self).__init__() + self.opt = opt + ch_1 = 64 + ch_2 = 128 + ch_3 = 128 + n_blocks = 4 + + self.head = N.seq( + N.conv(3, ch_1, mode='C') + ) # shape: (N, ch_1, H, W) + + + + self.down1 = N.seq( + N.conv(ch_1, ch_1, mode='C'), + N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks), + N.conv(ch_1, ch_1, mode='C'), + N.DWTForward(ch_1) + ) # shape: (N, ch_1*4, H/2, W/2) + + self.down2 = N.seq( + N.conv(ch_1*4, ch_1, mode='C'), + N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks), + N.DWTForward(ch_1) + ) # shape: (N, ch_1*4, H/4, W/4) + + self.down3 = N.seq( + N.conv(ch_1*4, ch_2, mode='C'), + N.RCAGroup(in_channels=ch_2, out_channels=ch_2, nb=n_blocks), + N.DWTForward(ch_2) + ) # shape: (N, ch_2*4, H/8, W/8) + + self.middle = N.seq( + N.conv(ch_2*4, ch_3, mode='C'), + N.RCAGroup(in_channels=ch_3, 
out_channels=ch_3, nb=n_blocks), + N.RCAGroup(in_channels=ch_3, out_channels=ch_3, nb=n_blocks), + N.conv(ch_3, ch_2*4, mode='C') + ) # shape: (N, ch_2*4, H/8, W/8) + + self.up3 = N.seq( + N.DWTInverse(ch_2*4), + N.RCAGroup(in_channels=ch_2, out_channels=ch_2, nb=n_blocks), + N.conv(ch_2, ch_1*4, mode='C') + ) # shape: (N, ch_1*4, H/4, W/4) + + self.up2 = N.seq( + N.DWTInverse(ch_1*4), + N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks), + N.conv(ch_1, ch_1*4, mode='C') + ) # shape: (N, ch_1*4, H/2, W/2) + + self.up1 = N.seq( + N.DWTInverse(ch_1*4), + N.RCAGroup(in_channels=ch_1, out_channels=ch_1, nb=n_blocks), + N.conv(ch_1, ch_1, mode='C') + ) # shape: (N, ch_1, H, W) + + self.tail = N.seq( + N.conv(ch_1, ch_1//4, mode='C'), + PixelUnshuffle(downscale_factor=2), + N.conv(ch_1, 4, mode='C') + ) # shape: (N, 4, H/2, W/2) for an (N, 3, H, W) sRGB input + + def forward(self, raw): + # input = raw + raw = torch.clamp(raw, 0., 1.) + # invert tonemap(): solve 3x^2 - 2x^3 = raw for x in [0, 1] + raw = 0.5 - torch.sin(torch.asin(1.0 - 2.0 * raw) / 3.0) + raw = torch.clamp(raw, 1e-8, 1.) + # invert gamma_compression() + input = torch.pow(raw, 2.2) + h = self.head(input) + + + d1 = self.down1(h) + d2 = self.down2(d1) + d3 = self.down3(d2) + m = self.middle(d3) + d3 + u3 = self.up3(m) + d2 + u2 = self.up2(u3) + d1 + u1 = self.up1(u2) + h + out = self.tail(u1) + + return out + + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/utils.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/utils.py new file mode 100644 index 0000000..d1bef31 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/models/utils.py @@ -0,0 +1,52 @@ +import math +import torch + + +def compute_same_pad(kernel_size, stride): + if isinstance(kernel_size, int): + kernel_size = [kernel_size] + + if isinstance(stride, int): + stride = [stride] + + assert len(stride) == len( + kernel_size + ), "Pass kernel size and stride both as int, or both as equal length iterable" + + return [((k - 1) * s + 1) // 2 for k, s in zip(kernel_size, stride)] + + +def uniform_binning_correction(x, n_bits=8): + """Replaces x^i with q^i(x) = U(x, x + 1.0 / 256.0). + + Args: + x: 4-D Tensor of shape (NCHW) + n_bits: optional. + Returns: + x: x ~ U(x, x + 1.0 / 256) + objective: Equivalent to -q(x)*log(q(x)). + """ + b, c, h, w = x.size() + n_bins = 2 ** n_bits + chw = c * h * w + x += torch.zeros_like(x).uniform_(0, 1.0 / n_bins) + + objective = -math.log(n_bins) * chw * torch.ones(b, device=x.device) + return x, objective + + +def split_feature(tensor, type="split"): + """ + type = ["split", "cross"] + """ + C = tensor.size(1) + if type == "split": + # NOTE: deliberately splits 1 vs. C-1 channels; the half-split variant is kept commented out + # return tensor[:, : C // 2, ...], tensor[:, C // 2 :, ...] + return tensor[:, :1, ...], tensor[:, 1:, ...] + elif type == "cross": + # return tensor[:, 0::2, ...], tensor[:, 1::2, ...] + return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 
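+ +# Editor's usage sketch for the helpers above (comments only; the tensors and +# shapes are illustrative assumptions, not part of the original file): +# x = torch.rand(2, 4, 8, 8) # (N, C, H, W) batch +# x_noisy, objective = uniform_binning_correction(x) # adds U(0, 1/256) dequantization noise +# assert objective.shape == (2,) # one -log(256)*C*H*W term per sample +# first, rest = split_feature(x, "split") # (N, 1, H, W) and (N, C-1, H, W)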
+ + + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/options/__init__.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/options/__init__.py new file mode 100644 index 0000000..17ab6b7 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/options/__init__.py @@ -0,0 +1 @@ +"""This package includes option modules.""" diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/options/base_options.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/options/base_options.py new file mode 100644 index 0000000..403b35f --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/options/base_options.py @@ -0,0 +1,214 @@ +import argparse +import os +import re +from util import util +import torch +import models +import time + +def str2bool(v): + return v.lower() in ('yes', 'y', 'true', 't', '1') + +inf = float('inf') + +class BaseOptions(): + def __init__(self): + """Reset the class; indicates the class hasn't been initialized""" + self.initialized = False + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + # data parameters + parser.add_argument('--dataroot', type=str, default='') + parser.add_argument('--dataset_name', type=str, default=['eth'], nargs='+') + parser.add_argument('--max_dataset_size', type=int, default=inf) + parser.add_argument('--scale', type=int, default=4, help='Super-resolution scale.') + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--patch_size', type=int, default=224) + parser.add_argument('--shuffle', type=str2bool, default=True) + parser.add_argument('-j', '--num_dataloader', default=4, type=int) + parser.add_argument('--drop_last', type=str2bool, default=True) + + # device parameters + parser.add_argument('--gpu_ids', type=str, default='all', + help='Separate the GPU ids by `,`, using all GPUs by default. ' + 'e.g., `--gpu_ids 0`, `--gpu_ids 2,3`, `--gpu_ids -1`(CPU)') + parser.add_argument('--checkpoints_dir', type=str, default='./ckpt') + parser.add_argument('-v', '--verbose', type=str2bool, default=True) + parser.add_argument('--suffix', default='', type=str) + + # model parameters + parser.add_argument('--name', type=str, required=True, + help='Name of the folder to save models and logs.') + parser.add_argument('--model', type=str, required=True) + parser.add_argument('--load_path', type=str, default='', + help='Will load pre-trained model if load_path is set') + parser.add_argument('--load_iter', type=int, default=[0], nargs='+', + help='Load parameters if > 0 and load_path is not set. 
' + 'Set the value of `last_epoch`') + parser.add_argument('--gcm_coord', type=str2bool, default=True) + parser.add_argument('--pre_ispnet_coord', type=str2bool, default=True) + parser.add_argument('--chop', type=str2bool, default=False) + + # training parameters + parser.add_argument('--init_type', type=str, default='default', + choices=['default', 'normal', 'xavier', + 'kaiming', 'orthogonal', 'uniform'], + help='`default` means using PyTorch default init functions.') + parser.add_argument('--init_gain', type=float, default=0.02) + # parser.add_argument('--loss', type=str, default='L1', + # help='choose from [L1, MSE, SSIM, VGG, PSNR]') + parser.add_argument('--optimizer', type=str, default='Adam', + choices=['Adam', 'SGD', 'RMSprop']) + parser.add_argument('--niter', type=int, default=1000) + parser.add_argument('--niter_decay', type=int, default=0) + parser.add_argument('--lr_policy', type=str, default='step') + parser.add_argument('--lr_decay_iters', type=int, default=200) + parser.add_argument('--lr', type=float, default=0.0001) + + # Optimizer + parser.add_argument('--load_optimizers', type=str2bool, default=False, + help='Loading optimizer parameters for continuing training.') + parser.add_argument('--weight_decay', type=float, default=0) + # Adam + parser.add_argument('--beta1', type=float, default=0.9) + parser.add_argument('--beta2', type=float, default=0.999) + # SGD & RMSprop + parser.add_argument('--momentum', type=float, default=0) + # RMSprop + parser.add_argument('--alpha', type=float, default=0.99) + + # visualization parameters + parser.add_argument('--print_freq', type=int, default=100) + parser.add_argument('--test_every', type=int, default=1000) + parser.add_argument('--save_epoch_freq', type=int, default=1) + parser.add_argument('--calc_metrics', type=str2bool, default=False) + parser.add_argument('--save_imgs', type=str2bool, default=False) + parser.add_argument('--visual_full_imgs', type=str2bool, default=False) + + # test parameters + parser.add_argument('--save_path', type=str, default='./submission') + parser.add_argument('--test_patch', type=int, default=252) + parser.add_argument('--TLC', type=str2bool, default=False) + + self.initialized = True + return parser + + def gather_options(self): + """Initialize our parser with basic options (only once). + Add additional model-specific and dataset-specific options. + These options are defined in the function <modify_commandline_options> + in model and dataset classes. + """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class= + argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, _ = parser.parse_known_args() + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + opt, _ = parser.parse_known_args() # parse again with new defaults + + # save and return the parser + self.parser = parser + return parser.parse_args() + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values (if different). 
+ It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, 'opt_%s.txt' + % ('train' if self.isTrain else 'test')) + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + + def parse(self): + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + opt.serial_batches = not opt.shuffle + + if self.isTrain and (opt.load_iter != [0] or opt.load_path != '') \ + and not opt.load_optimizers: + util.prompt('You are loading a checkpoint and continuing training, ' + 'and no optimizer parameters are loaded. Please make ' + 'sure that the hyper parameters are correctly set.', 80) + time.sleep(3) + + opt.model = opt.model.lower() + opt.name = opt.name.lower() + + scale_patch = {2: 96, 3: 144, 4: 192} + if opt.patch_size is None: + opt.patch_size = scale_patch[opt.scale] + + if opt.name.startswith(opt.checkpoints_dir): + opt.name = opt.name.replace(opt.checkpoints_dir+'/', '') + if opt.name.endswith('/'): + opt.name = opt.name[:-1] + + if len(opt.dataset_name) == 1: + opt.dataset_name = opt.dataset_name[0] + + if len(opt.load_iter) == 1: + opt.load_iter = opt.load_iter[0] + + # process opt.suffix + if opt.suffix != '': + suffix = ('_' + opt.suffix.format(**vars(opt))) + opt.name = opt.name + suffix + + self.print_options(opt) + + # set gpu ids + cuda_device_count = torch.cuda.device_count() + if opt.gpu_ids == 'all': + # GT 710 (3.5), GT 610 (2.1) + gpu_ids = [i for i in range(cuda_device_count)] + else: + p = re.compile('[^-0-9]+') + gpu_ids = [int(i) for i in re.split(p, opt.gpu_ids) if int(i) >= 0] + opt.gpu_ids = [i for i in gpu_ids \ + if torch.cuda.get_device_capability(i) >= (4,0)] + + if len(opt.gpu_ids) == 0 and len(gpu_ids) > 0: + opt.gpu_ids = gpu_ids + util.prompt('You\'re using GPUs with computing capability < 4') + elif len(opt.gpu_ids) != len(gpu_ids): + util.prompt('GPUs(computing capability < 4) have been disabled') + + if len(opt.gpu_ids) > 0: + assert torch.cuda.is_available(), 'No cuda available !!!' 
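+ # Editor's note: bind the process to the first selected GPU here; init_net() + # later wraps each network in DataParallel over all of opt.gpu_ids.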
+ torch.cuda.set_device(opt.gpu_ids[0]) + print('The GPUs you are using:') + for gpu_id in opt.gpu_ids: + print(' %2d *%s* with capability %d.%d' % ( + gpu_id, + torch.cuda.get_device_name(gpu_id), + *torch.cuda.get_device_capability(gpu_id))) + else: + util.prompt('You are using CPU mode') + + self.opt = opt + return self.opt diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/options/test_options.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/options/test_options.py new file mode 100644 index 0000000..26e404d --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/options/test_options.py @@ -0,0 +1,8 @@ +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + self.isTrain = False + return parser diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/options/train_options.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/options/train_options.py new file mode 100644 index 0000000..c44ea20 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/options/train_options.py @@ -0,0 +1,8 @@ +from .base_options import BaseOptions + + +class TrainOptions(BaseOptions): + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + self.isTrain = True + return parser diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/pwc/correlation/correlation.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/pwc/correlation/correlation.py new file mode 100644 index 0000000..c9c97e3 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/pwc/correlation/correlation.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Correlation_rearrange = ''' + extern "C" __global__ void kernel_Correlation_rearrange( + const int n, + const float* input, + float* output + ) { + int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; + + if (intIndex >= n) { + return; + } + + int intSample = blockIdx.z; + int intChannel = blockIdx.y; + + float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; + + __syncthreads(); + + int intPaddedY = (intIndex / SIZE_3(input)) + 4; + int intPaddedX = (intIndex % SIZE_3(input)) + 4; + int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; + + output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; + } +''' + +kernel_Correlation_updateOutput = ''' + extern "C" __global__ void kernel_Correlation_updateOutput( + const int n, + const float* rbot0, + const float* rbot1, + float* top + ) { + extern __shared__ char patch_data_char[]; + + float *patch_data = (float *)patch_data_char; + + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = blockIdx.x + 4; + int y1 = blockIdx.y + 4; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + + __syncthreads(); + + __shared__ float sum[32]; + + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { + 
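// Editor's note: each of the SIZE_1(top) = 81 output channels corresponds to + // one displacement (s2o, s2p) in the 9x9 search window, decoded below from + // the channel index. +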
sum[ch_off] = 0; + + int s2o = top_channel % 9 - 4; + int s2p = top_channel / 9 - 4; + + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + + __syncthreads(); + + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +''' + +kernel_Correlation_updateGradFirst = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradFirst( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradFirst); // channels + int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos + int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + + // Same here: + int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) + int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) + + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + // Get rbot1 data: + int s2o = o; + int s2p = p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradFirst); + const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); + gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; + } } +''' + +kernel_Correlation_updateGradSecond = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradSecond( + const int n, + const int intSample, + const float* rbot0, 
+ const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradSecond); // channels + int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos + int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + float sum = 0; + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + int s2o = o; + int s2p = p; + + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + + // Same here: + int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) + int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) + + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot0tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradSecond); + const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); + gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; + } } +''' + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + # end + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + # end + + return strKernel +# end + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) +# end + +class 
_FunctionCorrelation(torch.autograd.Function): + @staticmethod + def forward(self, first, second): + rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) + rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) + + self.save_for_backward(first, second, rbot0, rbot1) + + assert(first.is_contiguous() == True) + assert(second.is_contiguous() == True) + + output = first.new_zeros([ first.shape[0], 81, first.shape[2], first.shape[3] ]) + + if first.is_cuda == True: + n = first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': first, + 'output': rbot0 + }))( + grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]), + block=tuple([ 16, 1, 1 ]), + args=[ n, first.data_ptr(), rbot0.data_ptr() ] + ) + + n = second.shape[2] * second.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': second, + 'output': rbot1 + }))( + grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]), + block=tuple([ 16, 1, 1 ]), + args=[ n, second.data_ptr(), rbot1.data_ptr() ] + ) + + n = output.shape[1] * output.shape[2] * output.shape[3] + cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'top': output + }))( + grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]), + block=tuple([ 32, 1, 1 ]), + shared_mem=first.shape[1] * 4, + args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ] + ) + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return output + # end + + @staticmethod + def backward(self, gradOutput): + first, second, rbot0, rbot1 = self.saved_tensors + + assert(gradOutput.is_contiguous() == True) + + gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None + gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None + + if first.is_cuda == True: + if gradFirst is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': gradFirst, + 'gradSecond': None + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ] + ) + # end + # end + + if gradSecond is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': None, + 'gradSecond': gradSecond + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ] + ) + # end + # end + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return gradFirst, gradSecond + # end +# end + +def FunctionCorrelation(tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, 
tenSecond) +# end + +class ModuleCorrelation(torch.nn.Module): + def __init__(self): + super(ModuleCorrelation, self).__init__() + # end + + def forward(self, tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) + # end +# end \ No newline at end of file diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/pwc/pwc_net.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/pwc/pwc_net.py new file mode 100644 index 0000000..add26c4 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/pwc/pwc_net.py @@ -0,0 +1,251 @@ +#-*- encoding: UTF-8 -*- + +import torch +import sys +from functools import partial +import pickle + +try: + from pwc.correlation import correlation # the custom cost volume layer +except: + sys.path.insert(0, './correlation'); import correlation # you should consider upgrading python + + +# Borrow the code of the optical flow network (PWC-Net) from https://github.com/sniklaus/pytorch-pwc/ +class PWCNET(torch.nn.Module): + def __init__(self): + super(PWCNET, self).__init__() + + class Extractor(torch.nn.Module): + def __init__(self): + super(Extractor, self).__init__() + + self.netOne = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netTwo = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netThr = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netFou = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netFiv = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netSix = 
torch.nn.Sequential( + torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + def forward(self, tenInput): + tenOne = self.netOne(tenInput) + tenTwo = self.netTwo(tenOne) + tenThr = self.netThr(tenTwo) + tenFou = self.netFou(tenThr) + tenFiv = self.netFiv(tenFou) + tenSix = self.netSix(tenFiv) + + return [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ] + + class Decoder(torch.nn.Module): + def __init__(self, intLevel): + super(Decoder, self).__init__() + + intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, + 81 + 128 + 2 + 2, 81, None ][intLevel + 1] + intCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, + 81 + 128 + 2 + 2, 81, None ][intLevel + 0] + + self.backwarp_tenGrid = {} + self.backwarp_tenPartial = {} + + if intLevel < 6: self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, + out_channels=2, kernel_size=4, stride=2, padding=1) + if intLevel < 6: self.netUpfeat = torch.nn.ConvTranspose2d( + in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, + kernel_size=4, stride=2, padding=1) + if intLevel < 6: self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1] + + self.netOne = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, + kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netTwo = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, + kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netThr = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, + kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netFou = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, + kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netFiv = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, + kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netSix = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, + kernel_size=3, stride=1, padding=1) + ) + + def forward(self, tenFirst, tenSecond, objPrevious): + tenFlow = None + tenFeat = None + + if objPrevious is None: + tenFlow = None + tenFeat = None + tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation( + tenFirst=tenFirst, tenSecond=tenSecond), negative_slope=0.1, inplace=False) + tenFeat = torch.cat([ tenVolume ], 1) + + elif objPrevious is not None: + tenFlow = self.netUpflow(objPrevious['tenFlow']) + tenFeat = self.netUpfeat(objPrevious['tenFeat']) + + tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation( + tenFirst=tenFirst, tenSecond=self.backwarp(tenInput=tenSecond, + tenFlow=tenFlow * self.fltBackwarp)), negative_slope=0.1, inplace=False) + + tenFeat = torch.cat([ tenVolume, tenFirst, tenFlow, 
tenFeat ], 1) + + tenFeat = torch.cat([ self.netOne(tenFeat), tenFeat ], 1) + tenFeat = torch.cat([ self.netTwo(tenFeat), tenFeat ], 1) + tenFeat = torch.cat([ self.netThr(tenFeat), tenFeat ], 1) + tenFeat = torch.cat([ self.netFou(tenFeat), tenFeat ], 1) + tenFeat = torch.cat([ self.netFiv(tenFeat), tenFeat ], 1) + + tenFlow = self.netSix(tenFeat) + + return { + 'tenFlow': tenFlow, + 'tenFeat': tenFeat + } + + def backwarp(self, tenInput, tenFlow): + index = str(tenFlow.shape) + str(tenInput.device) + if index not in self.backwarp_tenGrid: + tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), + tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1) + tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), + tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3]) + self.backwarp_tenGrid[index] = torch.cat([ tenHor, tenVer ], 1).to(tenInput.device) + + if index not in self.backwarp_tenPartial: + self.backwarp_tenPartial[index] = tenFlow.new_ones([ tenFlow.shape[0], + 1, tenFlow.shape[2], tenFlow.shape[3] ]) + + tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) + tenInput = torch.cat([ tenInput, self.backwarp_tenPartial[index] ], 1) + + tenOutput = torch.nn.functional.grid_sample(input=tenInput, + grid=(self.backwarp_tenGrid[index] + tenFlow).permute(0, 2, 3, 1), + mode='bilinear', padding_mode='zeros', align_corners=False) + + tenMask = tenOutput[:, -1:, :, :]; tenMask[tenMask > 0.999] = 1.0; tenMask[tenMask < 1.0] = 0.0 + + return tenOutput[:, :-1, :, :] * tenMask + + class Refiner(torch.nn.Module): + def __init__(self): + super(Refiner, self).__init__() + self.netMain = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, + out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1) + ) + + def forward(self, tenInput): + return self.netMain(tenInput) + + self.netExtractor = Extractor() + + self.netTwo = Decoder(2) + self.netThr = Decoder(3) + self.netFou = Decoder(4) + self.netFiv = Decoder(5) + self.netSix = Decoder(6) + + self.netRefiner = Refiner() + + pickle.load = partial(pickle.load, encoding="latin1") + pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1") + + self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight + in torch.load('./ckpt/pwc-net.pth', map_location=lambda storage, + loc: storage, pickle_module=pickle).items() }) + + def forward(self, tenFirst, 
tenSecond):
+        tenFirst = self.netExtractor(tenFirst)
+        tenSecond = self.netExtractor(tenSecond)
+        objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None)
+        objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate)
+        objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate)
+        objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate)
+        objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate)
+        return objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat'])
+
+
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/requirements.txt b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/requirements.txt
new file mode 100644
index 0000000..836a2e7
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/requirements.txt
@@ -0,0 +1,54 @@
+absl-py==1.2.0
+cachetools==5.2.0
+certifi
+charset-normalizer==2.1.0
+colour-demosaicing==0.1.6
+colour-science==0.3.16
+cupy-cuda102==10.6.0
+cycler==0.11.0
+einops==0.4.1
+fastrlock==0.8
+fonttools==4.34.4
+google-auth==2.9.1
+google-auth-oauthlib==0.4.6
+grpcio==1.47.0
+idna==3.3
+imageio==2.19.5
+importlib-metadata==4.12.0
+kiwisolver==1.4.4
+Markdown==3.4.1
+matplotlib==3.5.2
+networkx==2.6.3
+numpy==1.21.6
+oauthlib==3.2.0
+opencv-python==4.6.0.66
+packaging==21.3
+pandas==1.1.5
+Pillow==9.2.0
+protobuf==3.19.4
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pyparsing==3.0.9
+python-dateutil==2.8.2
+pytz==2022.1
+PyWavelets==1.3.0
+rawpy==0.17.1
+requests==2.28.1
+requests-oauthlib==1.3.1
+rsa==4.8
+scikit-image==0.19.3
+scipy==1.7.3
+six==1.16.0
+tensorboard==2.9.1
+tensorboard-data-server==0.6.1
+tensorboard-plugin-wit==1.8.1
+tensorboardX==2.5.1
+tifffile==2021.11.2
+timm==0.6.5
+torch==1.8.1
+torchvision==0.9.1
+tqdm==4.64.0
+typing_extensions==4.3.0
+urllib3==1.26.10
+Werkzeug==2.1.2
+zipp==3.8.1
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/test_full.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/test_full.py
new file mode 100644
index 0000000..a764771
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/test_full.py
@@ -0,0 +1,114 @@
+import os
+import torch
+from options.test_options import TestOptions
+from data import create_dataset
+from models import create_model
+from util.visualizer import Visualizer
+from tqdm import tqdm
+from util.util import calc_psnr as calc_psnr
+import time
+import numpy as np
+from collections import OrderedDict as odict
+from copy import deepcopy
+from util.AISP_utils import demosaic, postprocess_raw, plot_pair
+from util.util import pack_rggb_channels
+from os.path import join
+from tensorboardX import SummaryWriter
+import cv2
+import glob
+
+def save_rgb(img, filename):
+    if np.max(img) <= 1:
+        img = img * 255
+
+    img = img.astype(np.float32)
+    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+    cv2.imwrite(filename, img)
+
+def log(log_file, str, also_print=True):
+    with open(log_file, 'a+') as F:
+        F.write(str)
+    if also_print:
+        print(str, end='')
+
+def _open_img(img_p, ratio):
+    return np.load(img_p, allow_pickle=True).astype(float) / ratio
+
+
+if __name__ == '__main__':
+    opt = TestOptions().parse()
+    if opt.TLC:
+        opt.model = opt.model+'1'
+
+    if not isinstance(opt.load_iter, list):
+        load_iters = [opt.load_iter]
+    else:
+        load_iters = deepcopy(opt.load_iter)
+
+    if not isinstance(opt.dataset_name, list):
+        dataset_names = [opt.dataset_name]
+    else:
+        dataset_names = deepcopy(opt.dataset_name)
+    datasets = odict()
+    for
dataset_name in dataset_names: + if opt.visual_full_imgs: + dataset = create_dataset(dataset_name, 'visual', opt) + else: + dataset = create_dataset(dataset_name, 'test', opt) + datasets[dataset_name] = tqdm(dataset) + + + + for load_iter in load_iters: + opt.load_iter = load_iter + model = create_model(opt) + model.setup(opt) + model.eval() + + for dataset_name in dataset_names: + opt.dataset_name = dataset_name + tqdm_val = datasets[dataset_name] + dataset_test = tqdm_val.iterable + dataset_size_test = len(dataset_test) + + print('='*80) + print(dataset_name + ' dataset') + tqdm_val.reset() + + + time_val = 0 + for i, data in enumerate(tqdm_val): + + + model.set_input(data) + model.test() + res = model.get_current_visuals() + recon_raw = res['data_out'][0].detach().permute(1, 2, 0).numpy() + + ratio = 1024 + folder_dir = opt.save_path + PS = opt.test_patch + os.makedirs(folder_dir, exist_ok=True) + H,W,C= recon_raw.shape + pic_i=0 + avg_ps=0 + for rr in np.arange(0, H - PS + 1, PS): + for cc in np.arange(0, W - PS + 1, PS): + + raw_patch = recon_raw[rr:rr + PS, cc:cc + PS,:] + pic_index=data['fname'][0].split('/')[-1].split('_')[-1].split('.')[0]+'_'+str(pic_i)+'.npy' + raw_patch=pack_rggb_channels(raw_patch) + raw_patch = (raw_patch * ratio).astype(np.uint16) + save_dir = '%s/%s' % (folder_dir, pic_index) + os.makedirs(folder_dir, exist_ok=True) + np.save(save_dir, raw_patch) + + pic_i+=1 + + + for dataset in datasets: + datasets[dataset].close() + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/test_full.sh b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/test_full.sh new file mode 100644 index 0000000..0cfea5b --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/test_full.sh @@ -0,0 +1,13 @@ +#!/bin/bash +echo "Start to test the model...." + +name="s7_1000" +dataroot="/mnt/disk10T/AIM2022/data-s7-full/test2" +save_path='./submission' + + +python test_full.py \ +--model s7smooth --name $name --dataset_name s7align --pre_ispnet_coord False --gcm_coord False \ +--load_iter 484 --batch_size 1 --gpu_ids -1 --save_imgs True --calc_metrics True --visual_full_imgs False -j 3 \ +--dataroot $dataroot --save_path $save_path + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/test_full_Ensembles.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/test_full_Ensembles.py new file mode 100644 index 0000000..6979fbf --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/test_full_Ensembles.py @@ -0,0 +1,202 @@ +import os +import torch +from options.test_options import TestOptions +from data import create_dataset +from models import create_model +from util.visualizer import Visualizer +from tqdm import tqdm +from util.util import calc_psnr as calc_psnr +import time +import numpy as np +from collections import OrderedDict as odict +from copy import deepcopy +from util.AISP_utils import demosaic,postprocess_raw, plot_pair +from util.util import pack_rggb_channels +from os.path import join +from tensorboardX import SummaryWriter +import cv2 +from util.util import pack_rggb_channels +import glob + +def data_augmentation(image, mode): + ''' + Performs data augmentation of the input image + Input: + image: a cv2 (OpenCV) image + mode: int. 
Choice of transformation to apply to the image
+            0 - no transformation
+            1 - flip up and down
+            2 - rotate counterclockwise 90 degrees
+            3 - rotate 90 degrees and flip up and down
+            4 - rotate 180 degrees
+            5 - rotate 180 degrees and flip
+            6 - rotate 270 degrees
+            7 - rotate 270 degrees and flip
+    '''
+    if mode == 0:
+        # original
+        out = image
+    elif mode == 1:
+        # flip up and down
+        out = np.flipud(image)
+    elif mode == 2:
+        # rotate counterclockwise 90 degrees
+        out = np.rot90(image)
+    elif mode == 3:
+        # rotate 90 degrees and flip up and down
+        out = np.rot90(image)
+        out = np.flipud(out)
+    elif mode == 4:
+        # rotate 180 degrees
+        out = np.rot90(image, k=2)
+    elif mode == 5:
+        # rotate 180 degrees and flip
+        out = np.rot90(image, k=2)
+        out = np.flipud(out)
+    elif mode == 6:
+        # rotate 270 degrees
+        out = np.rot90(image, k=3)
+    elif mode == 7:
+        # rotate 270 degrees and flip
+        out = np.rot90(image, k=3)
+        out = np.flipud(out)
+    else:
+        raise Exception('Invalid choice of image transformation')
+
+    return out
+
+def inverse_data_augmentation(image, mode):
+    '''
+    Performs inverse data augmentation of the input image
+    '''
+    if mode == 0:
+        # original
+        out = image
+    elif mode == 1:
+        out = np.flipud(image)
+    elif mode == 2:
+        out = np.rot90(image, axes=(1, 0))
+    elif mode == 3:
+        out = np.flipud(image)
+        out = np.rot90(out, axes=(1, 0))
+    elif mode == 4:
+        out = np.rot90(image, k=2, axes=(1, 0))
+    elif mode == 5:
+        out = np.flipud(image)
+        out = np.rot90(out, k=2, axes=(1, 0))
+    elif mode == 6:
+        out = np.rot90(image, k=3, axes=(1, 0))
+    elif mode == 7:
+        # rotate 270 degrees and flip
+        out = np.flipud(image)
+        out = np.rot90(out, k=3, axes=(1, 0))
+    else:
+        raise Exception('Invalid choice of image transformation')
+
+    return out
+
+def save_rgb(img, filename):
+    if np.max(img) <= 1:
+        img = img * 255
+
+    img = img.astype(np.float32)
+    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+    cv2.imwrite(filename, img)
+
+def log(log_file, str, also_print=True):
+    with open(log_file, 'a+') as F:
+        F.write(str)
+    if also_print:
+        print(str, end='')
+
+def _open_img(img_p, ratio):
+    return np.load(img_p, allow_pickle=True).astype(float) / ratio
+
+
+if __name__ == '__main__':
+    opt = TestOptions().parse()
+
+    if not isinstance(opt.load_iter, list):
+        load_iters = [opt.load_iter]
+    else:
+        load_iters = deepcopy(opt.load_iter)
+
+    if not isinstance(opt.dataset_name, list):
+        dataset_names = [opt.dataset_name]
+    else:
+        dataset_names = deepcopy(opt.dataset_name)
+    datasets = odict()
+    for dataset_name in dataset_names:
+        if opt.visual_full_imgs:
+            dataset = create_dataset(dataset_name, 'visual', opt)
+        else:
+            dataset = create_dataset(dataset_name, 'test', opt)
+        datasets[dataset_name] = tqdm(dataset)
+
+
+    for load_iter in load_iters:
+        opt.load_iter = load_iter
+        model = create_model(opt)
+        model.setup(opt)
+        model.eval()
+
+        for dataset_name in dataset_names:
+            opt.dataset_name = dataset_name
+            tqdm_val = datasets[dataset_name]
+            dataset_test = tqdm_val.iterable
+            dataset_size_test = len(dataset_test)
+
+            print('='*80)
+            print(dataset_name + ' dataset')
+            tqdm_val.reset()
+
+            psnr = [0.0] * dataset_size_test*48
+
+            time_val = 0
+            for i, data in enumerate(tqdm_val):
+
+                inp = data['dslr'][0]
+                C, H, W = inp.shape
+                recon_raw = np.zeros((H//2, W//2, 4), dtype=np.float32)
+                for flag in range(8):
+                    pch_noisy_flag = np.ascontiguousarray(data_augmentation(inp.cpu().numpy().transpose(1, 2, 0), flag))
+                    pch_noisy_flag = torch.from_numpy(pch_noisy_flag.transpose((2, 0, 1))[np.newaxis,])
+
+                    data['dslr'] = pch_noisy_flag
+                    model.set_input(data)
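+                    # data_augmentation's 8 modes are the symmetries of the
+                    # dihedral group D4 (identity, three rotations, and their
+                    # flipped versions), and inverse_data_augmentation undoes
+                    # each mode exactly; an illustrative sanity check:
+                    #   x = np.random.rand(8, 8, 3)
+                    #   for m in range(8):
+                    #       y = inverse_data_augmentation(data_augmentation(x, m), m)
+                    #       assert np.allclose(x, y)
+                    # Averaging the 8 inverse-mapped predictions below is the
+                    # usual geometric self-ensemble at test time.
+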
model.test() + res = model.get_current_visuals() + raw_out = res['data_out'][0].detach().cpu().permute(1, 2, 0).numpy() + recon_raw += inverse_data_augmentation(raw_out, flag) + recon_raw = recon_raw / 8 + + + ratio = 1024 + folder_dir = opt.save_path + PS = opt.test_patch + + os.makedirs(folder_dir, exist_ok=True) + + H,W,C= recon_raw.shape + pic_i=0 + avg_ps=0 + for rr in np.arange(0, H - PS + 1, PS): + for cc in np.arange(0, W - PS + 1, PS): + + raw_patch = recon_raw[rr:rr + PS, cc:cc + PS,:] + pic_index=data['fname'][0].split('/')[-1].split('_')[-1].split('.')[0]+'_'+str(pic_i)+'.npy' + + raw_patch=pack_rggb_channels(raw_patch) + + raw_patch = (raw_patch * ratio).astype(np.uint16) + save_dir = '%s/%s' % (folder_dir, pic_index) + os.makedirs(folder_dir, exist_ok=True) + np.save(save_dir, raw_patch) + pic_i+=1 + + + for dataset in datasets: + datasets[dataset].close() + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/train.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/train.py new file mode 100644 index 0000000..c5c22be --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/train.py @@ -0,0 +1,186 @@ +import time +import torch +from options.train_options import TrainOptions +from data import create_dataset +from models import create_model +from util.visualizer import Visualizer +from tqdm import tqdm +import numpy as np +import math +import sys +import torch.multiprocessing as mp + +from util.util import calc_psnr as calc_psnr +from util.AISP_utils import demosaic, postprocess_raw, plot_pair +from util.util import pack_rggb_channels +import os +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "6,7" +import cv2 + + +# from skimage.exposure import match_histograms + +def save_rgb(img, filename): + if np.max(img) <= 1: + img = img * 255 + + img = img.astype(np.float32) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + cv2.imwrite(filename, img) + + +if __name__ == '__main__': + opt = TrainOptions().parse() + dataset_train = create_dataset(opt.dataset_name, 'train', opt) + dataset_size_train = len(dataset_train) + print('The number of training images = %d' % dataset_size_train) + dataset_val = create_dataset(opt.dataset_name, 'val', opt) + dataset_size_val = len(dataset_val) + print('The number of val images = %d' % dataset_size_val) + + model = create_model(opt) + model.setup(opt) + visualizer = Visualizer(opt) + total_iters = ((model.start_epoch * (dataset_size_train // opt.batch_size)) \ + // opt.print_freq) * opt.print_freq + + for epoch in range(model.start_epoch + 1, opt.niter + opt.niter_decay + 1): + # training + epoch_start_time = time.time() + epoch_iter = 0 + model.train() + + iter_data_time = iter_start_time = time.time() + for i, data in enumerate(dataset_train): + if total_iters % opt.print_freq == 0: + t_data = time.time() - iter_data_time + total_iters += 1 + epoch_iter += 1 + model.set_input(data) + model.optimize_parameters() + res = model.get_current_visuals() + + if opt.save_imgs: + + psnr_train = calc_psnr(data['raw'], res['data_out'].detach().cpu()) + print(data['fname'][0], psnr_train) + res = model.get_current_visuals() + folder_dir = './ckpt/%s/output_train' % (opt.name); + os.makedirs(folder_dir, exist_ok=True) + + # save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_data_out') + # raw = res['data_out'][0].cpu().permute(1, 2, 0).numpy() + # raw = pack_rggb_channels(raw) + # raw_dm = postprocess_raw(demosaic(raw)) + # save_rgb(raw_dm, save_dir) + + + # save_dir = 
'%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_raw') + # raw = res['data_raw'][0].cpu().permute(1, 2, 0).numpy() + # raw = pack_rggb_channels(raw) + # raw_dm = postprocess_raw(demosaic(raw)) + # save_rgb(raw_dm, save_dir) + + + # save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_raw_fusion') + # raw = res['raw_fusion'][0].cpu().permute(1, 2, 0).numpy() + # raw = pack_rggb_channels(raw) + # raw_dm = postprocess_raw(demosaic(raw)) + # save_rgb(raw_dm, save_dir) + + # save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_raw_warp') + # raw = res['raw_warp'][0].cpu().permute(1, 2, 0).numpy() + # raw = pack_rggb_channels(raw) + # raw_dm = postprocess_raw(demosaic(raw)) + # save_rgb(raw_dm, save_dir) + + # save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_data_out_mask') + # raw = res['data_out_mask'][0].cpu().permute(1, 2, 0).numpy() + # raw = pack_rggb_channels(raw) + # raw_dm = postprocess_raw(demosaic(raw)) + # save_rgb(raw_dm, save_dir) + + # save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_data_raw_mask') + # raw = res['data_raw_mask'][0].cpu().permute(1, 2, 0).numpy() + # raw = pack_rggb_channels(raw) + # raw_dm = postprocess_raw(demosaic(raw)) + # save_rgb(raw_dm, save_dir) + + save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_dslr') + dslr = res['data_dslr'][0].cpu().permute(1, 2, 0).numpy(); + save_rgb(dslr, save_dir) + + # save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_GCMModel_out_fusion') + # dslr = res['GCMModel_out_fusion'][0].cpu().permute(1, 2, 0).numpy(); + # save_rgb(dslr, save_dir) + + # save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_data_dslr_mask') + # dslr = res['data_dslr_mask'][0].cpu().permute(1, 2, 0).numpy(); + # save_rgb(dslr, save_dir) + + save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_GCMModel_out_warp_all') + dslr = res['GCMModel_out_warp'][0].cpu().permute(1, 2, 0).numpy(); + save_rgb(dslr, save_dir) + + save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_GCMModel_out_warp_test') + dslr = res['mask_test'][0].cpu().permute(1, 2, 0).numpy(); + save_rgb(dslr, save_dir) + + # save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_GCMModel_out_mask') + # dslr = res['GCMModel_out_mask'][0].cpu().permute(1, 2, 0).numpy(); + # save_rgb(dslr, save_dir) + + # save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_rgb_best_mask') + # dslr = res['rgb_best_mask'][0].cpu().permute(1, 2, 0).numpy(); + # save_rgb(dslr, save_dir) + + # save_dir = '%s/%s.jpg' % (folder_dir, os.path.basename(data['fname'][0]).split('.')[0] + '_raw_best_mask') + # dslr = res['raw_best_mask'][0].cpu().permute(1, 2, 0).numpy(); + # save_rgb(dslr, save_dir) + + if i > 2: + print(1 / 0) + + if total_iters % opt.print_freq == 0: + losses = model.get_current_losses() + t_comp = (time.time() - iter_start_time) + visualizer.print_current_losses( + epoch, epoch_iter, losses, t_comp, t_data, total_iters) + iter_start_time = time.time() + + iter_data_time = time.time() + if epoch % opt.save_epoch_freq == 0: + print('saving the model at the end of epoch %d, iters %d' + % (epoch, total_iters)) + model.save_networks(epoch) + + print('End of 
epoch %d / %d \t Time Taken: %.3f sec' + % (epoch, opt.niter + opt.niter_decay, + time.time() - epoch_start_time)) + model.update_learning_rate() + + # val + if opt.calc_metrics: + model.eval() + val_iter_time = time.time() + tqdm_val = tqdm(dataset_val) + psnr = [0.0] * dataset_size_val + time_val = 0 + for i, data in enumerate(tqdm_val): + model.set_input(data) + time_val_start = time.time() + with torch.no_grad(): + model.test() + time_val += time.time() - time_val_start + res = model.get_current_visuals() + psnr[i] = calc_psnr(res['data_raw'].detach().cpu(), res['data_out'].detach().cpu()) + visualizer.print_psnr(epoch, opt.niter + opt.niter_decay, time_val, np.mean(psnr)) + + sys.stdout.flush() + + + + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/train_s7.sh b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/train_s7.sh new file mode 100644 index 0000000..8f9454b --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/train_s7.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +echo "Start to train the model...." + +name="s7_3000" +dataroot="/mnt/disk10T/AIM2022/data-s7-before" + +build_dir="./ckpt/"$name + +if [ ! -d "$build_dir" ]; then + mkdir $build_dir +fi + +LOG=./ckpt/$name/`date +%Y-%m-%d-%H-%M-%S`.txt + +# You can set "--model zrrganjoint" to train LiteISPGAN. + +python train.py \ + --dataset_name s7align --model s7smooth --name $name --gcm_coord False \ + --pre_ispnet_coord False --niter 1800 --lr_policy cosine --save_imgs False \ + --batch_size 6 --print_freq 300 --calc_metrics True --lr 3e-4 -j 24 \ + --weight_decay 0.01 \ + --dataroot $dataroot | tee $LOG + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/AISP_utils.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/AISP_utils.py new file mode 100644 index 0000000..783d2dd --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/AISP_utils.py @@ -0,0 +1,191 @@ +import cv2 +import numpy as np +import rawpy +import matplotlib.pyplot as plt +import imageio + + +def extract_bayer_channels(raw): + + ch_B = raw[1::2, 1::2] + ch_Gb = raw[0::2, 1::2] + ch_R = raw[0::2, 0::2] + ch_Gr = raw[1::2, 0::2] + + return ch_R, ch_Gr, ch_B, ch_Gb + +def load_rawpy (raw_file): + raw = rawpy.imread(raw_file) + raw_image = raw.raw_image + return raw_image + +def load_img (filename, debug=False, norm=True, resize=None): + img = cv2.imread(filename) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if norm: + img = img / 255. + img = img.astype(np.float32) + if debug: + print (img.shape, img.dtype, img.min(), img.max()) + + if resize: + img = cv2.resize(img, (resize[0], resize[1]), interpolation = cv2.INTER_AREA) + + return img + +def save_rgb (img, filename): + if np.max(img) <= 1: + img = img * 255 + + img = img.astype(np.float32) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + cv2.imwrite(filename, img) + +def load_raw_png(raw, debug=False): + ''' + Load RAW images from the ZurichRAW2RGB Dataset + Reference: https://github.com/aiff22/PyNET-PyTorch/blob/master/dng_to_png.py + by Andrey Ignatov. + + inputs: + - raw: filename to the raw image saved as '.png' + returns: + - RAW_norm: normalized float32 4-channel raw image with bayer pattern RGGB. 
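+    Note: extract_bayer_channels stacks the planes in (R, Gr, Gb, B)
+    order, and dividing by 4*255 = 1020 maps the (nominally 10-bit)
+    PNG values into roughly [0, 1] before clipping.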
+ ''' + + assert '.png' in raw + raw = np.asarray(imageio.imread((raw))) + ch_R, ch_Gr, ch_B, ch_Gb = extract_bayer_channels (raw) + + RAW_combined = np.dstack((ch_R, ch_Gr, ch_Gb, ch_B)) + RAW_norm = RAW_combined.astype(np.float32) / (4 * 255) + RAW_norm = np.clip(RAW_norm, 0, 1) + + if debug: + print (RAW_norm.shape, RAW_norm.dtype, RAW_norm.min(), RAW_norm.max()) + + # raw as (h,w,1) in RGBG domain! do not use + raw_unpack = raw.astype(np.float32) / (4 * 255) + raw_unpack = np.expand_dims(raw_unpack, axis=-1) + + return RAW_norm + +# def load_raw(raw, max_val=2**10): +# raw = np.load (raw)/ max_val +# return raw.astype(np.float32) + + +########## RAW image manipulation + +def unpack_raw(im): + """ + Unpack RAW image from (h,w,4) to (h*2 , w*2, 1) + """ + h,w,chan = im.shape + H, W = h*2, w*2 + img2 = np.zeros((H,W)) + img2[0:H:2,0:W:2]=im[:,:,0] + img2[0:H:2,1:W:2]=im[:,:,1] + img2[1:H:2,0:W:2]=im[:,:,2] + img2[1:H:2,1:W:2]=im[:,:,3] + img2 = np.squeeze(img2) + img2 = np.expand_dims(img2, axis=-1) + return img2 + +def pack_raw(im): + """ + Pack RAW image from (h,w,1) to (h/2 , w/2, 4) + """ + img_shape = im.shape + H = img_shape[0] + W = img_shape[1] + ## R G G B + out = np.concatenate((im[0:H:2,0:W:2,:], + im[0:H:2,1:W:2,:], + im[1:H:2,0:W:2,:], + im[1:H:2,1:W:2,:]), axis=2) + return out + + + +########## VISUALIZATION + +def demosaic (raw): + """Simple demosaicing to visualize RAW images + Inputs: + - raw: (h,w,4) RAW RGGB image normalized [0..1] as float32 + Returns: + - Simple Avg. Green Demosaiced RAW image with shape (h*2, w*2, 3) + """ + + assert raw.shape[-1] == 4 + shape = raw.shape + + red = raw[:,:,0] + green_red = raw[:,:,1] + green_blue = raw[:,:,2] + blue = raw[:,:,3] + avg_green = (green_red + green_blue) / 2 + image = np.stack((red, avg_green, blue), axis=-1) + image = cv2.resize(image, (shape[1]*2, shape[0]*2)) + return image + + +def mosaic(rgb): + """Extracts RGGB Bayer planes from an RGB image.""" + + assert rgb.shape[-1] == 3 + shape = rgb.shape + + red = rgb[0::2, 0::2, 0] + green_red = rgb[0::2, 1::2, 1] + green_blue = rgb[1::2, 0::2, 1] + blue = rgb[1::2, 1::2, 2] + + image = np.stack((red, green_red, green_blue, blue), axis=-1) + return image + + +def gamma_compression(image): + """Converts from linear to gamma space.""" + return np.maximum(image, 1e-8) ** (1.0 / 2.2) + +def tonemap(image): + """Simple S-curved global tonemap""" + return (3*(image**2)) - (2*(image**3)) + +def postprocess_raw(raw): + """Simple post-processing to visualize demosaic RAW imgaes + Input: (h,w,3) RAW image normalized + Output: (h,w,3) post-processed RAW image + """ + raw = gamma_compression(raw) + raw = tonemap(raw) + raw = np.clip(raw, 0, 1) + return raw + +def plot_pair (rgb, raw, t1='RGB', t2='RAW', axis='off'): + + fig = plt.figure(figsize=(12, 6), dpi=80) + plt.subplot(1,2,1) + plt.title(t1) + plt.axis(axis) + plt.imshow(rgb) + + plt.subplot(1,2,2) + plt.title(t2) + plt.axis(axis) + plt.imshow(raw) + plt.show() + +########## METRICS + +def PSNR(y_true, y_pred): + mse = np.mean((y_true - y_pred) ** 2) + if(mse == 0): + return np.inf + + max_pixel = np.max(y_true) + psnr = 20 * np.log10(max_pixel / np.sqrt(mse)) + return psnr \ No newline at end of file diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/GaussianBlur.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/GaussianBlur.py new file mode 100644 index 0000000..5a6ba5b --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/GaussianBlur.py @@ -0,0 +1,49 @@ +""" +## CycleISP: Real Image 
Restoration Via Improved Data Synthesis
+## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao
+## CVPR 2020
+## https://arxiv.org/abs/2003.07761
+"""
+
+import torch
+import torch.nn as nn
+import math
+import numpy as np
+
+def get_gaussian_kernel(kernel_size=21, sigma=5, channels=4):
+    #if not kernel_size: kernel_size = int(2*np.ceil(2*sigma)+1)
+    #print("Kernel is: ",kernel_size)
+    #print("Sigma is: ",sigma)
+    padding = kernel_size//2
+    # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
+    x_coord = torch.arange(kernel_size)
+    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
+    y_grid = x_grid.t()
+    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
+
+    mean = (kernel_size - 1)/2.
+    variance = sigma**2.
+
+    # Calculate the 2-dimensional gaussian kernel which is
+    # the product of two gaussian distributions for two different
+    # variables (in this case called x and y)
+    gaussian_kernel = (1./(2.*math.pi*variance)) *\
+        torch.exp(
+            -torch.sum((xy_grid - mean)**2., dim=-1) /\
+            (2*variance)
+        )
+
+    # Make sure the sum of values in the gaussian kernel equals 1.
+    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
+
+    # Reshape to 2d depthwise convolutional weight
+    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
+    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
+
+    gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels,
+                                kernel_size=kernel_size, groups=channels, bias=False)
+
+    gaussian_filter.weight.data = gaussian_kernel
+    gaussian_filter.weight.requires_grad = False
+
+    return gaussian_filter, padding
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/JPEG.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/JPEG.py
new file mode 100644
index 0000000..8997ee9
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/JPEG.py
@@ -0,0 +1,43 @@
+
+
+import torch
+import torch.nn as nn
+
+from .JPEG_utils import diff_round, quality_to_factor, Quantization
+from .compression import compress_jpeg
+from .decompression import decompress_jpeg
+
+
+class DiffJPEG(nn.Module):
+    def __init__(self, differentiable=True, quality=75):
+        ''' Initialize the DiffJPEG layer
+        Inputs:
+            differentiable(bool): If true uses a custom differentiable
+                rounding function, if false uses standard torch.round
+            quality(float): Quality factor for jpeg compression scheme.
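+        Example (illustrative sketch):
+            jpeg = DiffJPEG(differentiable=True, quality=75)
+            x = torch.rand(1, 3, 64, 64)  # RGB batch in [0, 1]
+            y = jpeg(x)                   # same shape, JPEG round-tripped
+        Height and width should be multiples of 16 so the 8x8 DCT blocks
+        and the 2x chroma subsampling tile evenly; quality=75 maps to
+        factor 0.5 via quality_to_factor.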
+ ''' + super(DiffJPEG, self).__init__() + if differentiable: + rounding = diff_round + # rounding = Quantization() + else: + rounding = torch.round + factor = quality_to_factor(quality) + self.compress = compress_jpeg(rounding=rounding, factor=factor) + # self.decompress = decompress_jpeg(height, width, rounding=rounding, + # factor=factor) + self.decompress = decompress_jpeg(rounding=rounding, factor=factor) + + def forward(self, x): + ''' + ''' + org_height = x.shape[2] + org_width = x.shape[3] + y, cb, cr = self.compress(x) + + recovered = self.decompress(y, cb, cr, org_height, org_width) + return recovered + + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/JPEG_utils.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/JPEG_utils.py new file mode 100644 index 0000000..e2ebd9b --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/JPEG_utils.py @@ -0,0 +1,75 @@ +# Standard libraries +import numpy as np +# PyTorch +import torch +import torch.nn as nn +import math + +y_table = np.array( + [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, + 55], [14, 13, 16, 24, 40, 57, 69, 56], + [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, + 77], [24, 35, 55, 64, 81, 104, 113, 92], + [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]], + dtype=np.float32).T + +y_table = nn.Parameter(torch.from_numpy(y_table)) +# +c_table = np.empty((8, 8), dtype=np.float32) +c_table.fill(99) +c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], + [24, 26, 56, 99], [47, 66, 99, 99]]).T +c_table = nn.Parameter(torch.from_numpy(c_table)) + + +def diff_round_back(x): + """ Differentiable rounding function + Input: + x(tensor) + Output: + x(tensor) + """ + return torch.round(x) + (x - torch.round(x))**3 + + + +def diff_round(input_tensor): + test = 0 + for n in range(1, 10): + test += math.pow(-1, n+1) / n * torch.sin(2 * math.pi * n * input_tensor) + final_tensor = input_tensor - 1 / math.pi * test + return final_tensor + + +class Quant(torch.autograd.Function): + + @staticmethod + def forward(ctx, input): + input = torch.clamp(input, 0, 1) + output = (input * 255.).round() / 255. + return output + + @staticmethod + def backward(ctx, grad_output): + return grad_output + +class Quantization(nn.Module): + def __init__(self): + super(Quantization, self).__init__() + + def forward(self, input): + return Quant.apply(input) + + +def quality_to_factor(quality): + """ Calculate factor corresponding to quality + Input: + quality(float): Quality for jpeg compression + Output: + factor(float): Compression factor + """ + if quality < 50: + quality = 5000. / quality + else: + quality = 200. - quality*2 + return quality / 100. \ No newline at end of file diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/__init__.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/__init__.py new file mode 100644 index 0000000..e2f595e --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/__init__.py @@ -0,0 +1 @@ +"""This package includes a miscellaneous collection of helper functions.""" diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/compression.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/compression.py new file mode 100644 index 0000000..3ae22f8 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/compression.py @@ -0,0 +1,185 @@ +# Standard libraries +import itertools +import numpy as np +# PyTorch +import torch +import torch.nn as nn +# Local +from . 
import JPEG_utils
+
+
+class rgb_to_ycbcr_jpeg(nn.Module):
+    """ Converts RGB image to YCbCr
+    Input:
+        image(tensor): batch x 3 x height x width
+    Output:
+        result(tensor): batch x height x width x 3
+    """
+    def __init__(self):
+        super(rgb_to_ycbcr_jpeg, self).__init__()
+        matrix = np.array(
+            [[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5],
+             [0.5, -0.418688, -0.081312]], dtype=np.float32).T
+        self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
+        #
+        self.matrix = nn.Parameter(torch.from_numpy(matrix))
+
+    def forward(self, image):
+        image = image.permute(0, 2, 3, 1)
+        result = torch.tensordot(image, self.matrix, dims=1) + self.shift
+        # result = torch.from_numpy(result)
+        result.view(image.shape)
+        return result
+
+
+class chroma_subsampling(nn.Module):
+    """ Chroma subsampling on CbCr channels
+    Input:
+        image(tensor): batch x height x width x 3
+    Output:
+        y(tensor): batch x height x width
+        cb(tensor): batch x height/2 x width/2
+        cr(tensor): batch x height/2 x width/2
+    """
+    def __init__(self):
+        super(chroma_subsampling, self).__init__()
+
+    def forward(self, image):
+        image_2 = image.permute(0, 3, 1, 2).clone()
+        avg_pool = nn.AvgPool2d(kernel_size=2, stride=(2, 2),
+                                count_include_pad=False)
+        cb = avg_pool(image_2[:, 1, :, :].unsqueeze(1))
+        cr = avg_pool(image_2[:, 2, :, :].unsqueeze(1))
+        cb = cb.permute(0, 2, 3, 1)
+        cr = cr.permute(0, 2, 3, 1)
+        return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
+
+
+class block_splitting(nn.Module):
+    """ Splitting image into patches
+    Input:
+        image(tensor): batch x height x width
+    Output:
+        patch(tensor): batch x h*w/64 x 8 x 8
+    """
+    def __init__(self):
+        super(block_splitting, self).__init__()
+        self.k = 8
+
+    def forward(self, image):
+        height, width = image.shape[1:3]
+        # print(height, width)
+        batch_size = image.shape[0]
+        # print(image.shape)
+        image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
+        image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
+        return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
+
+
+class dct_8x8(nn.Module):
+    """ Discrete Cosine Transformation
+    Input:
+        image(tensor): batch x height x width
+    Output:
+        dcp(tensor): batch x height x width
+    """
+    def __init__(self):
+        super(dct_8x8, self).__init__()
+        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
+        for x, y, u, v in itertools.product(range(8), repeat=4):
+            tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos(
+                (2 * y + 1) * v * np.pi / 16)
+        alpha = np.array([1.
/ np.sqrt(2)] + [1] * 7) + # + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float() ) + + def forward(self, image): + image = image - 128 + result = self.scale * torch.tensordot(image, self.tensor, dims=2) + result.view(image.shape) + return result + + +class y_quantize(nn.Module): + """ JPEG Quantization for Y channel + Input: + image(tensor): batch x height x width + rounding(function): rounding function to use + factor(float): Degree of compression + Output: + image(tensor): batch x height x width + """ + def __init__(self, rounding, factor=1): + super(y_quantize, self).__init__() + self.rounding = rounding + self.factor = factor + self.y_table = JPEG_utils.y_table + + def forward(self, image): + image = image.float() / (self.y_table * self.factor) + image = self.rounding(image) + return image + + +class c_quantize(nn.Module): + """ JPEG Quantization for CrCb channels + Input: + image(tensor): batch x height x width + rounding(function): rounding function to use + factor(float): Degree of compression + Output: + image(tensor): batch x height x width + """ + def __init__(self, rounding, factor=1): + super(c_quantize, self).__init__() + self.rounding = rounding + self.factor = factor + self.c_table = JPEG_utils.c_table + + def forward(self, image): + image = image.float() / (self.c_table * self.factor) + image = self.rounding(image) + return image + + +class compress_jpeg(nn.Module): + """ Full JPEG compression algortihm + Input: + imgs(tensor): batch x 3 x height x width + rounding(function): rounding function to use + factor(float): Compression factor + Ouput: + compressed(dict(tensor)): batch x h*w/64 x 8 x 8 + """ + def __init__(self, rounding=torch.round, factor=1): + super(compress_jpeg, self).__init__() + self.l1 = nn.Sequential( + rgb_to_ycbcr_jpeg(), + # comment this line if no subsampling + chroma_subsampling() + ) + self.l2 = nn.Sequential( + block_splitting(), + dct_8x8() + ) + self.c_quantize = c_quantize(rounding=rounding, factor=factor) + self.y_quantize = y_quantize(rounding=rounding, factor=factor) + + def forward(self, image): + y, cb, cr = self.l1(image*255) # modify + + # y, cb, cr = result[:,:,:,0], result[:,:,:,1], result[:,:,:,2] + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + comp = self.l2(components[k]) + # print(comp.shape) + if k in ('cb', 'cr'): + comp = self.c_quantize(comp) + else: + comp = self.y_quantize(comp) + + components[k] = comp + + return components['y'], components['cb'], components['cr'] \ No newline at end of file diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/decompression.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/decompression.py new file mode 100644 index 0000000..b73ff96 --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/decompression.py @@ -0,0 +1,190 @@ +# Standard libraries +import itertools +import numpy as np +# PyTorch +import torch +import torch.nn as nn +# Local +from . 
import JPEG_utils as utils
+
+
+class y_dequantize(nn.Module):
+    """ Dequantize Y channel
+    Inputs:
+        image(tensor): batch x height x width
+        factor(float): compression factor
+    Outputs:
+        image(tensor): batch x height x width
+    """
+    def __init__(self, factor=1):
+        super(y_dequantize, self).__init__()
+        self.y_table = utils.y_table
+        self.factor = factor
+
+    def forward(self, image):
+        return image * (self.y_table * self.factor)
+
+
+class c_dequantize(nn.Module):
+    """ Dequantize CbCr channel
+    Inputs:
+        image(tensor): batch x height x width
+        factor(float): compression factor
+    Outputs:
+        image(tensor): batch x height x width
+    """
+    def __init__(self, factor=1):
+        super(c_dequantize, self).__init__()
+        self.factor = factor
+        self.c_table = utils.c_table
+
+    def forward(self, image):
+        return image * (self.c_table * self.factor)
+
+
+class idct_8x8(nn.Module):
+    """ Inverse Discrete Cosine Transformation
+    Input:
+        dcp(tensor): batch x height x width
+    Output:
+        image(tensor): batch x height x width
+    """
+    def __init__(self):
+        super(idct_8x8, self).__init__()
+        alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+        self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
+        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
+        for x, y, u, v in itertools.product(range(8), repeat=4):
+            tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos(
+                (2 * v + 1) * y * np.pi / 16)
+        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
+
+    def forward(self, image):
+
+        image = image * self.alpha
+        result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
+        result.view(image.shape)
+        return result
+
+
+class block_merging(nn.Module):
+    """ Merge patches into an image
+    Inputs:
+        patches(tensor) batch x height*width/64, height x width
+        height(int)
+        width(int)
+    Output:
+        image(tensor): batch x height x width
+    """
+    def __init__(self):
+        super(block_merging, self).__init__()
+
+    def forward(self, patches, height, width):
+        k = 8
+        batch_size = patches.shape[0]
+        # print(patches.shape) # (1,1024,8,8)
+        image_reshaped = patches.view(batch_size, height//k, width//k, k, k)
+        image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
+        return image_transposed.contiguous().view(batch_size, height, width)
+
+
+class chroma_upsampling(nn.Module):
+    """ Upsample chroma layers
+    Input:
+        y(tensor): y channel image
+        cb(tensor): cb channel
+        cr(tensor): cr channel
+    Output:
+        image(tensor): batch x height x width x 3
+    """
+    def __init__(self):
+        super(chroma_upsampling, self).__init__()
+
+    def forward(self, y, cb, cr):
+        def repeat(x, k=2):
+            height, width = x.shape[1:3]
+            x = x.unsqueeze(-1)
+            x = x.repeat(1, 1, k, k)
+            x = x.view(-1, height * k, width * k)
+            return x
+
+        cb = repeat(cb)
+        cr = repeat(cr)
+
+        return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
+
+
+class ycbcr_to_rgb_jpeg(nn.Module):
+    """ Converts YCbCr image to RGB JPEG
+    Input:
+        image(tensor): batch x height x width x 3
+    Output:
+        result(tensor): batch x 3 x height x width
+    """
+    def __init__(self):
+        super(ycbcr_to_rgb_jpeg, self).__init__()
+
+        matrix = np.array(
+            [[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]],
+            dtype=np.float32).T
+        self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
+        self.matrix = nn.Parameter(torch.from_numpy(matrix))
+
+    def forward(self, image):
+        result = torch.tensordot(image + self.shift, self.matrix, dims=1)
+        #result = torch.from_numpy(result)
+        result.view(image.shape)
+        return result.permute(0, 3, 1, 2)
+
+
+class
decompress_jpeg(nn.Module): + """ Full JPEG decompression algortihm + Input: + compressed(dict(tensor)): batch x h*w/64 x 8 x 8 + rounding(function): rounding function to use + factor(float): Compression factor + Ouput: + image(tensor): batch x 3 x height x width + """ + # def __init__(self, height, width, rounding=torch.round, factor=1): + def __init__(self, rounding=torch.round, factor=1): + super(decompress_jpeg, self).__init__() + self.c_dequantize = c_dequantize(factor=factor) + self.y_dequantize = y_dequantize(factor=factor) + self.idct = idct_8x8() + self.merging = block_merging() + # comment this line if no subsampling + self.chroma = chroma_upsampling() + self.colors = ycbcr_to_rgb_jpeg() + + # self.height, self.width = height, width + + def forward(self, y, cb, cr, height, width): + components = {'y': y, 'cb': cb, 'cr': cr} + # height = y.shape[0] + # width = y.shape[1] + self.height = height + self.width = width + for k in components.keys(): + if k in ('cb', 'cr'): + comp = self.c_dequantize(components[k]) + # comment this line if no subsampling + height, width = int(self.height/2), int(self.width/2) + # height, width = int(self.height), int(self.width) + + else: + comp = self.y_dequantize(components[k]) + # comment this line if no subsampling + height, width = self.height, self.width + comp = self.idct(comp) + components[k] = self.merging(comp, height, width) + # + # comment this line if no subsampling + image = self.chroma(components['y'], components['cb'], components['cr']) + # image = torch.cat([components['y'].unsqueeze(3), components['cb'].unsqueeze(3), components['cr'].unsqueeze(3)], dim=3) + image = self.colors(image) + + image = torch.min(255*torch.ones_like(image), + torch.max(torch.zeros_like(image), image)) + return image/255 + diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/myssim.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/myssim.py new file mode 100644 index 0000000..eb5d9db --- /dev/null +++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/myssim.py @@ -0,0 +1,371 @@ +from __future__ import division, absolute_import, print_function + + + +import numpy as np +# from numpy.lib.arraypad import _validate_lengths +from scipy.ndimage import uniform_filter, gaussian_filter + +dtype_range = {np.bool_: (False, True), + np.bool8: (False, True), + np.uint8: (0, 255), + np.uint16: (0, 65535), + np.uint32: (0, 2**32 - 1), + np.uint64: (0, 2**64 - 1), + np.int8: (-128, 127), + np.int16: (-32768, 32767), + np.int32: (-2**31, 2**31 - 1), + np.int64: (-2**63, 2**63 - 1), + np.float16: (-1, 1), + np.float32: (-1, 1), + np.float64: (-1, 1)} + +def _normalize_shape(ndarray, shape, cast_to_int=True): + """ + Private function which does some checks and normalizes the possibly + much simpler representations of 'pad_width', 'stat_length', + 'constant_values', 'end_values'. + + Parameters + ---------- + narray : ndarray + Input ndarray + shape : {sequence, array_like, float, int}, optional + The width of padding (pad_width), the number of elements on the + edge of the narray used for statistics (stat_length), the constant + value(s) to use when filling padded regions (constant_values), or the + endpoint target(s) for linear ramps (end_values). + ((before_1, after_1), ... (before_N, after_N)) unique number of + elements for each axis where `N` is rank of `narray`. + ((before, after),) yields same before and after constants for each + axis. + (constant,) or val is a shortcut for before = after = constant for + all axes. 
+ cast_to_int : bool, optional + Controls if values in ``shape`` will be rounded and cast to int + before being returned. + + Returns + ------- + normalized_shape : tuple of tuples + val => ((val, val), (val, val), ...) + [[val1, val2], [val3, val4], ...] => ((val1, val2), (val3, val4), ...) + ((val1, val2), (val3, val4), ...) => no change + [[val1, val2], ] => ((val1, val2), (val1, val2), ...) + ((val1, val2), ) => ((val1, val2), (val1, val2), ...) + [[val , ], ] => ((val, val), (val, val), ...) + ((val , ), ) => ((val, val), (val, val), ...) + + """ + ndims = ndarray.ndim + + # Shortcut shape=None + if shape is None: + return ((None, None), ) * ndims + + # Convert any input `info` to a NumPy array + shape_arr = np.asarray(shape) + + try: + shape_arr = np.broadcast_to(shape_arr, (ndims, 2)) + except ValueError: + fmt = "Unable to create correctly shaped tuple from %s" + raise ValueError(fmt % (shape,)) + + # Cast if necessary + if cast_to_int is True: + shape_arr = np.round(shape_arr).astype(int) + + # Convert list of lists to tuple of tuples + return tuple(tuple(axis) for axis in shape_arr.tolist()) + + +def _validate_lengths(narray, number_elements): + """ + Private function which does some checks and reformats pad_width and + stat_length using _normalize_shape. + + Parameters + ---------- + narray : ndarray + Input ndarray + number_elements : {sequence, int}, optional + The width of padding (pad_width) or the number of elements on the edge + of the narray used for statistics (stat_length). + ((before_1, after_1), ... (before_N, after_N)) unique number of + elements for each axis. + ((before, after),) yields same before and after constants for each + axis. + (constant,) or int is a shortcut for before = after = constant for all + axes. + + Returns + ------- + _validate_lengths : tuple of tuples + int => ((int, int), (int, int), ...) + [[int1, int2], [int3, int4], ...] => ((int1, int2), (int3, int4), ...) + ((int1, int2), (int3, int4), ...) => no change + [[int1, int2], ] => ((int1, int2), (int1, int2), ...) + ((int1, int2), ) => ((int1, int2), (int1, int2), ...) + [[int , ], ] => ((int, int), (int, int), ...) + ((int , ), ) => ((int, int), (int, int), ...) + + """ + normshp = _normalize_shape(narray, number_elements) + for i in normshp: + chk = [1 if x is None else x for x in i] + chk = [1 if x >= 0 else -1 for x in chk] + if (chk[0] < 0) or (chk[1] < 0): + fmt = "%s cannot contain negative values." + raise ValueError(fmt % (number_elements,)) + return normshp + + + +def crop(ar, crop_width, copy=False, order='K'): + """Crop array `ar` by `crop_width` along each dimension. + Parameters + ---------- + ar : array-like of rank N + Input array. + crop_width : {sequence, int} + Number of values to remove from the edges of each axis. + ``((before_1, after_1),`` ... ``(before_N, after_N))`` specifies + unique crop widths at the start and end of each axis. + ``((before, after),)`` specifies a fixed start and end crop + for every axis. + ``(n,)`` or ``n`` for integer ``n`` is a shortcut for + before = after = ``n`` for all axes. + copy : bool, optional + If `True`, ensure the returned array is a contiguous copy. Normally, + a crop operation will return a discontiguous view of the underlying + input array. + order : {'C', 'F', 'A', 'K'}, optional + If ``copy==True``, control the memory layout of the copy. See + ``np.copy``. + Returns + ------- + cropped : array + The cropped array. If ``copy=False`` (default), this is a sliced + view of the input array. 
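+    Example (illustrative):
+        >>> a = np.arange(16).reshape(4, 4)
+        >>> crop(a, 1).shape
+        (2, 2)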
+ """ + ar = np.array(ar, copy=False) + crops = _validate_lengths(ar, crop_width) + slices = [slice(a, ar.shape[i] - b) for i, (a, b) in enumerate(crops)] + if copy: + cropped = np.array(ar[slices], order=order, copy=True) + else: + cropped = ar[slices] + return cropped + +def compare_ssim(X, Y, win_size=None, gradient=False, + data_range=1, multichannel=False, gaussian_weights=False, + full=False, dynamic_range=None, **kwargs): + """Compute the mean structural similarity index between two images. + Parameters + ---------- + X, Y : ndarray + Image. Any dimensionality. + win_size : int or None + The side-length of the sliding window used in comparison. Must be an + odd value. If `gaussian_weights` is True, this is ignored and the + window size will depend on `sigma`. + gradient : bool, optional + If True, also return the gradient. + data_range : int, optional + The data range of the input image (distance between minimum and + maximum possible values). By default, this is estimated from the image + data-type. + multichannel : bool, optional + If True, treat the last dimension of the array as channels. Similarity + calculations are done independently for each channel then averaged. + gaussian_weights : bool, optional + If True, each patch has its mean and variance spatially weighted by a + normalized Gaussian kernel of width sigma=1.5. + full : bool, optional + If True, return the full structural similarity image instead of the + mean value. + Other Parameters + ---------------- + use_sample_covariance : bool + if True, normalize covariances by N-1 rather than, N where N is the + number of pixels within the sliding window. + K1 : float + algorithm parameter, K1 (small constant, see [1]_) + K2 : float + algorithm parameter, K2 (small constant, see [1]_) + sigma : float + sigma for the Gaussian when `gaussian_weights` is True. + Returns + ------- + mssim : float + The mean structural similarity over the image. + grad : ndarray + The gradient of the structural similarity index between X and Y [2]_. + This is only returned if `gradient` is set to True. + S : ndarray + The full SSIM image. This is only returned if `full` is set to True. + Notes + ----- + To match the implementation of Wang et. al. [1]_, set `gaussian_weights` + to True, `sigma` to 1.5, and `use_sample_covariance` to False. + References + ---------- + .. [1] Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. + (2004). Image quality assessment: From error visibility to + structural similarity. IEEE Transactions on Image Processing, + 13, 600-612. + https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf, + DOI:10.1.1.11.2477 + .. [2] Avanaki, A. N. (2009). Exact global histogram specification + optimized for structural similarity. Optical Review, 16, 613-621. + http://arxiv.org/abs/0901.0065, + DOI:10.1007/s10043-009-0119-z + """ + if not X.dtype == Y.dtype: + raise ValueError('Input images must have the same dtype.') + + if not X.shape == Y.shape: + raise ValueError('Input images must have the same dimensions.') + + if dynamic_range is not None: + #warn('`dynamic_range` has been deprecated in favor of ' + # '`data_range`. 
The `dynamic_range` keyword argument ' + # 'will be removed in v0.14', skimage_deprecation) + data_range = dynamic_range + + if multichannel: + # loop over channels + args = dict(win_size=win_size, + gradient=gradient, + data_range=data_range, + multichannel=False, + gaussian_weights=gaussian_weights, + full=full) + args.update(kwargs) + nch = X.shape[-1] + mssim = np.empty(nch) + if gradient: + G = np.empty(X.shape) + if full: + S = np.empty(X.shape) + for ch in range(nch): + ch_result = compare_ssim(X[..., ch], Y[..., ch], **args) + if gradient and full: + mssim[..., ch], G[..., ch], S[..., ch] = ch_result + elif gradient: + mssim[..., ch], G[..., ch] = ch_result + elif full: + mssim[..., ch], S[..., ch] = ch_result + else: + mssim[..., ch] = ch_result + mssim = mssim.mean() + if gradient and full: + return mssim, G, S + elif gradient: + return mssim, G + elif full: + return mssim, S + else: + return mssim + + K1 = kwargs.pop('K1', 0.01) + K2 = kwargs.pop('K2', 0.03) + sigma = kwargs.pop('sigma', 1.5) + if K1 < 0: + raise ValueError("K1 must be positive") + if K2 < 0: + raise ValueError("K2 must be positive") + if sigma < 0: + raise ValueError("sigma must be positive") + use_sample_covariance = kwargs.pop('use_sample_covariance', True) + + if win_size is None: + if gaussian_weights: + win_size = 11 # 11 to match Wang et. al. 2004 + else: + win_size = 7 # backwards compatibility + + if np.any((np.asarray(X.shape) - win_size) < 0): + raise ValueError( + "win_size exceeds image extent. If the input is a multichannel " + "(color) image, set multichannel=True.") + + if not (win_size % 2 == 1): + raise ValueError('Window size must be odd.') + + if data_range is None: + dmin, dmax = dtype_range[X.dtype.type] + data_range = dmax - dmin + + ndim = X.ndim + + if gaussian_weights: + # sigma = 1.5 to approximately match filter in Wang et. al. 2004 + # this ends up giving a 13-tap rather than 11-tap Gaussian + filter_func = gaussian_filter + filter_args = {'sigma': sigma} + + else: + filter_func = uniform_filter + filter_args = {'size': win_size} + + # ndimage filters need floating point data + X = X.astype(np.float64) + Y = Y.astype(np.float64) + + NP = win_size ** ndim + + # filter has already normalized by NP + if use_sample_covariance: + cov_norm = NP / (NP - 1) # sample covariance + else: + cov_norm = 1.0 # population covariance to match Wang et. al. 2004 + + # compute (weighted) means + ux = filter_func(X, **filter_args) + uy = filter_func(Y, **filter_args) + + # compute (weighted) variances and covariances + uxx = filter_func(X * X, **filter_args) + uyy = filter_func(Y * Y, **filter_args) + uxy = filter_func(X * Y, **filter_args) + vx = cov_norm * (uxx - ux * ux) + vy = cov_norm * (uyy - uy * uy) + vxy = cov_norm * (uxy - ux * uy) + + R = data_range + C1 = (K1 * R) ** 2 + C2 = (K2 * R) ** 2 + + A1, A2, B1, B2 = ((2 * ux * uy + C1, + 2 * vxy + C2, + ux ** 2 + uy ** 2 + C1, + vx + vy + C2)) + D = B1 * B2 + S = (A1 * A2) / D + + # to avoid edge effects will ignore filter radius strip around edges + pad = (win_size - 1) // 2 + + # compute (weighted) mean of ssim + mssim = crop(S, pad).mean() + + if gradient: + # The following is Eqs. 7-8 of Avanaki 2009. 
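+        # Roughly: the first two filtered terms weight A1/D against X and
+        # -S/B2 against Y, the third corrects for the dependence of the
+        # local means ux, uy on the inputs, and the 2 / X.size factor comes
+        # from differentiating the mean SSIM over all pixels.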
+
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/util.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/util.py
new file mode 100644
index 0000000..d666227
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/util.py
@@ -0,0 +1,264 @@
+"""This module contains simple helper functions."""
+from __future__ import print_function
+import glob
+import os
+import random
+import time
+from functools import wraps
+
+import colour_demosaicing
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+
+
+# Decorator: retry the wrapped function up to 600 times, sleeping 1 second
+# between attempts. It wraps func itself; the downside is that func's own
+# signature and docstring are hidden from the caller.
+def loop_until_success(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        for i in range(600):
+            try:
+                ret = func(*args, **kwargs)
+                break
+            except OSError:
+                time.sleep(1)
+        else:
+            raise OSError('%s failed after 600 attempts' % func.__name__)
+        return ret
+    return wrapper
+
+# Retry-wrapped versions of print and torch.save.
+@loop_until_success
+def loop_print(*args, **kwargs):
+    print(*args, **kwargs)
+
+@loop_until_success
+def torch_save(*args, **kwargs):
+    torch.save(*args, **kwargs)
+
+def calc_psnr(sr, hr, range=1.):
+    # note: `range` shadows the builtin; kept for compatibility with callers
+    # shave = 2
+    with torch.no_grad():
+        diff = (sr - hr) / range
+        # diff = diff[:, :, shave:-shave, shave:-shave]
+        mse = torch.pow(diff, 2).mean()
+        return (-10 * torch.log10(mse)).item()
+
+
+def diagnose_network(net, name='network'):
+    """Calculate and print the mean of the average absolute gradients
+
+    Parameters:
+        net (torch network) -- Torch network
+        name (str)          -- the name of the network
+    """
+    mean = 0.0
+    count = 0
+    for param in net.parameters():
+        if param.grad is not None:
+            mean += torch.mean(torch.abs(param.grad.data))
+            count += 1
+    if count > 0:
+        mean = mean / count
+    print(name)
+    print(mean)
+
+def print_numpy(x, val=True, shp=True):
+    """Print the mean, min, max, median, std, and size of a numpy array
+
+    Parameters:
+        val (bool) -- whether to print the values of the numpy array
+        shp (bool) -- whether to print the shape of the numpy array
+    """
+    x = x.astype(np.float64)
+    if shp:
+        print('shape,', x.shape)
+    if val:
+        x = x.flatten()
+        print('mean = %3.3f, min = %3.3f, max = %3.3f, mid = %3.3f, std=%3.3f'
+              % (np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+def mkdirs(paths):
+    """create directories if they don't exist
+
+    Parameters:
+        paths (str list) -- a list of directory paths
+    """
+    if isinstance(paths, list) and not isinstance(paths, str):
+        for path in paths:
+            mkdir(path)
+    else:
+        mkdir(paths)
+
+
+def get_coord(H, W, x=448/3968, y=448/2976):
+    # Build a 2xHxW map of normalized (x, y) coordinates in [-x, x] and [-y, y].
+    x_coord = np.linspace(-x + (x / W), x - (x / W), W)
+    x_coord = np.expand_dims(x_coord, axis=0)
+    x_coord = np.tile(x_coord, (H, 1))
+    x_coord = np.expand_dims(x_coord, axis=0)
+
+    y_coord = np.linspace(-y + (y / H), y - (y / H), H)
+    y_coord = np.expand_dims(y_coord, axis=1)
+    y_coord = np.tile(y_coord, (1, W))
+    y_coord = np.expand_dims(y_coord, axis=0)
+
+    coord = np.ascontiguousarray(np.concatenate([x_coord, y_coord]))
+    coord = np.float32(coord)
+
+    return coord
+
+
+def mkdir(path):
+    """create a single directory if it doesn't exist
+
+    Parameters:
+        path (str) -- a single directory path
+    """
+    if not os.path.exists(path):
+        os.makedirs(path)
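+
+# Added usage sketch (illustration only, not part of the original file):
+# `get_coord` yields a 2xHxW float32 map of normalized pixel coordinates,
+# which can be concatenated to a CxHxW image as positional channels.
+def _demo_get_coord():
+    coord = get_coord(H=4, W=6)
+    assert coord.shape == (2, 4, 6)
+    assert coord.dtype == np.float32
+    return coord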
+
+def prompt(s, width=66):
+    print('=' * (width + 4))
+    ss = s.split('\n')
+    if len(ss) == 1 and len(s) <= width:
+        print('= ' + s.center(width) + ' =')
+    else:
+        for s in ss:
+            for i in split_str(s, width):
+                print('= ' + i.ljust(width) + ' =')
+    print('=' * (width + 4))
+
+def split_str(s, width):
+    ss = []
+    while len(s) > width:
+        idx = s.rfind(' ', 0, width + 1)
+        if idx > width >> 1:
+            ss.append(s[:idx])
+            s = s[idx + 1:]
+        else:
+            ss.append(s[:width])
+            s = s[width:]
+    if s.strip() != '':
+        ss.append(s)
+    return ss
+
+
+def load_img(filename, debug=False, norm=True, resize=False):
+    img = cv2.imread(filename)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    if norm:
+        img = img / 255.
+        img = img.astype(np.float32)
+    if debug:
+        print(img.shape, img.dtype, img.min(), img.max())
+
+    if resize:
+        img = cv2.resize(img, (resize[0], resize[1]), interpolation=cv2.INTER_AREA)
+
+    return img
+
+def augment_func(img, hflip, vflip, rot90):  # CxHxW
+    if hflip:
+        img = img[:, :, ::-1]
+    if vflip:
+        img = img[:, ::-1, :]
+    if rot90:
+        img = img.transpose(0, 2, 1)
+    return np.ascontiguousarray(img)
+
+def augment(*imgs):  # CxHxW
+    hflip = random.random() < 0.5
+    vflip = random.random() < 0.5
+    rot90 = random.random() < 0.5
+    return (augment_func(img, hflip, vflip, rot90) for img in imgs)
+
+def remove_black_level(img, black_lv=0, white_lv=2**10):
+    img = np.maximum(img.astype(np.float32) - black_lv, 0) / (white_lv - black_lv)
+    return img
+
+# def remove_black_level(img, black_lv=63, white_lv=4*255):
+#     img = np.maximum(img.astype(np.float32)-black_lv, 0) / (white_lv-black_lv)
+#     return img
+
+def gamma_correction(img, r=1/2.2):
+    img = np.maximum(img, 0)
+    img = np.power(img, r)
+    return img
+
+def extract_bayer_channels(raw):  # HxWx4 -> 4xHxW, reordered to (B, Gb, R, Gr)
+    ch_R = raw[:, :, 0]
+    ch_Gb = raw[:, :, 1]
+    ch_Gr = raw[:, :, 2]
+    ch_B = raw[:, :, 3]
+    raw_combined = np.dstack((ch_B, ch_Gb, ch_R, ch_Gr))
+    raw_combined = np.ascontiguousarray(raw_combined.transpose((2, 0, 1)))
+    return raw_combined  # 4xHxW
+
+def extract_bayer_channels_rggb(raw):  # HxWx4 -> 4xHxW
+    raw_combined = np.ascontiguousarray(raw.transpose((2, 0, 1)))
+    return raw_combined  # 4xHxW
+
+def pack_rggb_channels(raw):  # HxWx4 (B, Gb, R, Gr) -> HxWx4 (R, Gb, Gr, B)
+    ch_B = raw[:, :, 0]
+    ch_Gb = raw[:, :, 1]
+    ch_R = raw[:, :, 2]
+    ch_Gr = raw[:, :, 3]
+    raw_combined = np.dstack((ch_R, ch_Gb, ch_Gr, ch_B))
+    raw_combined = np.ascontiguousarray(raw_combined)
+    return raw_combined  # HxWx4
+
+def RGGB2Bayer(im):  # (H//2)x(W//2)x4 packed RGGB
+    # convert an RGGB-stacked image to a one-channel Bayer mosaic
+    bayer = np.zeros((im.shape[0] * 2, im.shape[1] * 2))
+    bayer[0::2, 0::2] = im[:, :, 0]
+    bayer[0::2, 1::2] = im[:, :, 1]
+    bayer[1::2, 0::2] = im[:, :, 2]
+    bayer[1::2, 1::2] = im[:, :, 3]
+    return bayer  # HxW
+
+def get_raw_demosaic(raw, pattern='RGGB'):  # (H//2)x(W//2)x4 -> 3xHxW
+    raw = RGGB2Bayer(raw)
+    raw_demosaic = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, pattern=pattern)
+    raw_demosaic = np.ascontiguousarray(raw_demosaic.astype(np.float32).transpose((2, 0, 1)))
+    return raw_demosaic  # 3xHxW
+
+def demosaic(raw):  # (H//2)x(W//2)x4 -> (H//2)x(W//2)x3
+    assert raw.shape[-1] == 4
+    red = raw[:, :, 0]
+    green_red = raw[:, :, 1]
+    green_blue = raw[:, :, 2]
+    blue = raw[:, :, 3]
+    avg_green = (green_red + green_blue) / 2
+    image = np.stack((red, avg_green, blue), axis=-1)
+    return image
+
+
+def get_raw_demosaic_nogcm(raw, pattern='RGGB'):  # (H//2)x(W//2)x4 -> 3x(H//2)x(W//2)
+    raw_demosaic = demosaic(raw)
+    raw_demosaic = np.ascontiguousarray(raw_demosaic.astype(np.float32).transpose((2, 0, 1)))
+    return raw_demosaic
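+
+# Added usage sketch (illustration only, not part of the original file):
+# from a packed (H/2, W/2, 4) RGGB stack to RGB via both demosaicing paths.
+def _demo_demosaic():
+    raw4 = np.random.rand(8, 8, 4).astype(np.float32)
+    rgb_full = get_raw_demosaic(raw4)        # 3x16x16, bilinear demosaicing
+    rgb_half = get_raw_demosaic_nogcm(raw4)  # 3x8x8, greens averaged per site
+    return rgb_full, rgb_half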
+
+def read_wb(txtfile, key):
+    wb = np.zeros((1, 4))
+    with open(txtfile) as f:
+        for l in f:
+            if key in l:
+                for i in range(wb.shape[0]):
+                    nextline = next(f)
+                    try:
+                        wb[i, :] = nextline.split()
+                    except ValueError:
+                        print('WB parsing error in %s' % txtfile)
+    wb = wb.astype(np.float32)
+    return wb
+
diff --git a/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/visualizer.py b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/visualizer.py
new file mode 100644
index 0000000..0b65393
--- /dev/null
+++ b/aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-s7/util/visualizer.py
@@ -0,0 +1,62 @@
+import time
+from functools import wraps
+from os.path import join
+
+from tensorboardX import SummaryWriter
+
+# Decorator: retry a TensorBoard write up to 30 times on OSError, sleeping
+# 1 second between attempts, and re-raise if it never succeeds.
+def write_until_success(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        for i in range(30):
+            try:
+                ret = func(*args, **kwargs)
+                break
+            except OSError:
+                print('%s OSError' % str(args))
+                time.sleep(1)
+        else:
+            raise OSError('%s failed after 30 attempts' % func.__name__)
+        return ret
+    return wrapper
+
+
+class Visualizer():
+    def __init__(self, opt):
+        self.opt = opt
+        if opt.isTrain:
+            self.name = opt.name
+            self.save_dir = join(opt.checkpoints_dir, opt.name, 'log')
+            self.writer = SummaryWriter(logdir=self.save_dir)
+        else:
+            self.name = '%s_%s_%d' % (
+                opt.name, opt.dataset_name, opt.load_iter)
+            self.save_dir = join(opt.checkpoints_dir, opt.name)
+            if opt.save_imgs:
+                self.writer = SummaryWriter(logdir=join(
+                    self.save_dir, 'ckpts', self.name))
+
+    @write_until_success
+    def display_current_results(self, phase, visuals, iters):
+        for k, v in visuals.items():
+            v = v.cpu()
+            self.writer.add_image('%s/%s' % (phase, k), v[0] / 255, iters)
+        self.writer.flush()
+
+    @write_until_success
+    def print_current_losses(self, epoch, iters, losses,
+                             t_comp, t_data, total_iters):
+        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' \
+            % (epoch, iters, t_comp, t_data)
+        for k, v in losses.items():
+            message += '%s: %.4e ' % (k, v)
+            self.writer.add_scalar('loss/%s' % k, v, total_iters)
+        print(message)
+
+    @write_until_success
+    def print_psnr(self, epoch, total_epoch, time_val, mean_psnr):
+        self.writer.add_scalar('val/psnr', mean_psnr, epoch)
+        print('End of epoch %d / %d (Val) \t Time Taken: %.3f s \t PSNR: %f'
+              % (epoch, total_epoch, time_val, mean_psnr))
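+
+# Added usage sketch (illustration only, not part of the original file):
+# typical calls from a training loop, assuming `opt` is the parsed option
+# namespace used throughout this repo (the variable names are assumptions):
+# vis = Visualizer(opt)
+# vis.display_current_results('train', visuals, total_iters)
+# vis.print_current_losses(epoch, iters, losses, t_comp, t_data, total_iters)
+# vis.print_psnr(epoch, total_epochs, time_val, mean_psnr)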