-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from liuxiaoyu1104/main
Add top solution from HIT-IIL
- Loading branch information
Showing
73 changed files
with
12,879 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
- [The pre-trained models for track 1](https://drive.google.com/drive/folders/1GmahiWwpMsPb9Y37TEiHs-ittD7ZAF0A?usp=sharing) can be downloaded. You need to put them in the `sRGB-to-RAW-s7/ckpt` folder. | ||
- [The pre-trained models for track 2](https://drive.google.com/drive/folders/1VMFi8ombywlD60sroWJOUwv_0l_CF71N?usp=sharing) can be downloaded. You need to put them in the `sRGB-to-RAW-p20/ckpt` folder. | ||
|
57 changes: 57 additions & 0 deletions
57
aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import importlib | ||
import torch.utils.data | ||
from data.base_dataset import BaseDataset | ||
|
||
|
||
def find_dataset_using_name(dataset_name, split='train'): | ||
dataset_filename = "data." + dataset_name + "_dataset" | ||
datasetlib = importlib.import_module(dataset_filename) | ||
|
||
dataset = None | ||
target_dataset_name = dataset_name.replace('_', '') + 'dataset' | ||
for name, cls in datasetlib.__dict__.items(): | ||
if name.lower() == target_dataset_name.lower() \ | ||
and issubclass(cls, BaseDataset): | ||
dataset = cls | ||
|
||
if dataset is None: | ||
raise NotImplementedError("In %s.py, there should be a subclass of " | ||
"BaseDataset with class name that matches %s in " | ||
"lowercase." % (dataset_filename, target_dataset_name)) | ||
return dataset | ||
|
||
|
||
def create_dataset(dataset_name, split, opt): | ||
data_loader = CustomDatasetDataLoader(dataset_name, split, opt) | ||
dataset = data_loader.load_data() | ||
return dataset | ||
|
||
|
||
class CustomDatasetDataLoader(): | ||
def __init__(self, dataset_name, split, opt): | ||
self.opt = opt | ||
dataset_class = find_dataset_using_name(dataset_name, split) | ||
self.dataset = dataset_class(opt, split, dataset_name) | ||
# self.imio = self.dataset.imio | ||
print("dataset [%s(%s)] created" % (dataset_name, split)) | ||
self.dataloader = torch.utils.data.DataLoader( | ||
self.dataset, | ||
batch_size=opt.batch_size if split=='train' else 1, | ||
shuffle=opt.shuffle and split=='train', | ||
num_workers=int(opt.num_dataloader), | ||
drop_last=opt.drop_last) | ||
|
||
def load_data(self): | ||
return self | ||
|
||
def __len__(self): | ||
"""Return the number of data in the dataset""" | ||
return min(len(self.dataset), self.opt.max_dataset_size) | ||
|
||
def __iter__(self): | ||
"""Return a batch of data""" | ||
for i, data in enumerate(self.dataloader): | ||
if i * self.opt.batch_size >= self.opt.max_dataset_size: | ||
break | ||
yield data | ||
|
19 changes: 19 additions & 0 deletions
19
aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/base_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import torch.utils.data as data | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class BaseDataset(data.Dataset, ABC): | ||
def __init__(self, opt, split, dataset_name): | ||
self.opt = opt | ||
self.split = split | ||
self.root = opt.dataroot | ||
self.dataset_name = dataset_name.lower() | ||
|
||
@abstractmethod | ||
def __len__(self): | ||
return 0 | ||
|
||
@abstractmethod | ||
def __getitem__(self, index): | ||
pass | ||
|
279 changes: 279 additions & 0 deletions
279
aim22-reverseisp/teams/HIT-IIL/sRGB-to-RAW-p20/data/imlib.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,279 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
import numpy as np | ||
import os | ||
import cv2 | ||
from PIL import Image | ||
from functools import wraps | ||
import time | ||
|
||
|
||
class imlib(): | ||
""" | ||
Note that YCxCx in OpenCV and PIL are different. | ||
Therefore, be careful if a model is trained with OpenCV and tested with | ||
PIL in Y mode, and vise versa | ||
force_color = True: return a 3 channel YCxCx image | ||
For mode 'Y', if a gray image is given, repeat the channel for 3 times, | ||
and then converted to YCxCx mode. | ||
force_color = False: return a 3 channel YCxCx image or a 1 channel gray one. | ||
For mode 'Y', if a gray image is given, the gray image is directly used. | ||
""" | ||
def __init__(self, mode='RGB', fmt='CHW', lib='cv2', force_color=True): | ||
assert mode.upper() in ('RGB', 'L', 'Y', 'RAW') | ||
self.mode = mode.upper() | ||
|
||
assert fmt.upper() in ('HWC', 'CHW', 'NHWC', 'NCHW') | ||
self.fmt = 'CHW' if fmt.upper() in ('CHW', 'NCHW') else 'HWC' | ||
|
||
assert lib.lower() in ('cv2', 'pillow') | ||
self.lib = lib.lower() | ||
|
||
self.force_color = force_color | ||
|
||
self.dtype = np.uint8 | ||
|
||
self._imread = getattr(self, '_imread_%s_%s'%(self.lib, self.mode)) | ||
self._imwrite = getattr(self, '_imwrite_%s_%s'%(self.lib, self.mode)) | ||
self._trans_batch = getattr(self, '_trans_batch_%s_%s' | ||
% (self.mode, self.fmt)) | ||
self._trans_image = getattr(self, '_trans_image_%s_%s' | ||
% (self.mode, self.fmt)) | ||
self._trans_back = getattr(self, '_trans_back_%s_%s' | ||
% (self.mode, self.fmt)) | ||
|
||
def _imread_cv2_RGB(self, path): | ||
return cv2_imread(path, cv2.IMREAD_COLOR)[..., ::-1] | ||
def _imread_cv2_RAW(self, path): | ||
return cv2_imread(path, -1) | ||
def _imread_cv2_Y(self, path): | ||
if self.force_color: | ||
img = cv2_imread(path, cv2.IMREAD_COLOR) | ||
else: | ||
img = cv2_imread(path, cv2.IMREAD_ANYCOLOR) | ||
if len(img.shape) == 2: | ||
return np.expand_dims(img, 3) | ||
elif len(img.shape) == 3: | ||
return cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb) | ||
else: | ||
raise ValueError('The dimension should be either 2 or 3.') | ||
def _imread_cv2_L(self, path): | ||
return cv2_imread(path, cv2.IMREAD_GRAYSCALE) | ||
|
||
def _imread_pillow_RGB(self, path): | ||
img = Image.open(path) | ||
im = np.array(img.convert(self.mode)) | ||
img.close() | ||
return im | ||
_imread_pillow_L = _imread_pillow_RGB | ||
# WARNING: the RGB->YCbCr procedure of PIL may be different with OpenCV | ||
def _imread_pillow_Y(self, path): | ||
img = Image.open(path) | ||
if img.mode == 'RGB': | ||
im = np.array(img.convert('YCbCr')) | ||
elif img.mode == 'L': | ||
if self.force_color: | ||
im = np.array(img.convert('RGB').convert('YCbCr')) | ||
else: | ||
im = np.expand_dims(np.array(img), 3) | ||
else: | ||
img.close() | ||
raise NotImplementedError('Only support RGB and gray images now.') | ||
img.close() | ||
return im | ||
|
||
def _imwrite_cv2_RGB(self, image, path): | ||
cv2.imwrite(path, image[..., ::-1]) | ||
def _imwrite_cv2_RAW(self, image, path): | ||
pass | ||
def _imwrite_cv2_Y(self, image, path): | ||
if image.shape[2] == 1: | ||
cv2.imwrite(path, image[..., 0]) | ||
elif image.shape[2] == 3: | ||
cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_YCrCb2BGR)) | ||
else: | ||
raise ValueError('There should be 1 or 3 channels.') | ||
def _imwrite_cv2_L(self, image, path): | ||
cv2.imwrite(path, image) | ||
|
||
def _imwrite_pillow_RGB(self, image, path): | ||
Image.fromarray(image).save(path) | ||
_imwrite_pillow_L = _imwrite_pillow_RGB | ||
def _imwrite_pillow_Y(self, image, path): | ||
if image.shape[2] == 1: | ||
self._imwrite_pillow_L(np.squeeze(image, 2), path) | ||
elif image.shape[2] == 3: | ||
Image.fromarray(image, mode='YCbCr').convert('RGB').save(path) | ||
else: | ||
raise ValueError('There should be 1 or 3 channels.') | ||
|
||
def _trans_batch_RGB_HWC(self, images): | ||
return np.ascontiguousarray(images) | ||
def _trans_batch_RGB_CHW(self, images): | ||
return np.ascontiguousarray(np.transpose(images, (0, 3, 1, 2))) | ||
_trans_batch_RAW_HWC = _trans_batch_RGB_HWC | ||
_trans_batch_RAW_CHW = _trans_batch_RGB_CHW | ||
_trans_batch_Y_HWC = _trans_batch_RGB_HWC | ||
_trans_batch_Y_CHW = _trans_batch_RGB_CHW | ||
def _trans_batch_L_HWC(self, images): | ||
return np.ascontiguousarray(np.expand_dims(images, 3)) | ||
def _trans_batch_L_CHW(slef, images): | ||
return np.ascontiguousarray(np.expand_dims(images, 1)) | ||
|
||
def _trans_image_RGB_HWC(self, image): | ||
return np.ascontiguousarray(image) | ||
def _trans_image_RGB_CHW(self, image): | ||
return np.ascontiguousarray(np.transpose(image, (2, 0, 1))) | ||
_trans_image_RAW_HWC = _trans_image_RGB_HWC | ||
_trans_image_RAW_CHW = _trans_image_RGB_CHW | ||
_trans_image_Y_HWC = _trans_image_RGB_HWC | ||
_trans_image_Y_CHW = _trans_image_RGB_CHW | ||
def _trans_image_L_HWC(self, image): | ||
return np.ascontiguousarray(np.expand_dims(image, 2)) | ||
def _trans_image_L_CHW(self, image): | ||
return np.ascontiguousarray(np.expand_dims(image, 0)) | ||
|
||
def _trans_back_RGB_HWC(self, image): | ||
return image | ||
def _trans_back_RGB_CHW(self, image): | ||
return np.transpose(image, (1, 2, 0)) | ||
_trans_back_RAW_HWC = _trans_back_RGB_HWC | ||
_trans_back_RAW_CHW = _trans_back_RGB_CHW | ||
_trans_back_Y_HWC = _trans_back_RGB_HWC | ||
_trans_back_Y_CHW = _trans_back_RGB_CHW | ||
def _trans_back_L_HWC(self, image): | ||
return np.squeeze(image, 2) | ||
def _trans_back_L_CHW(self, image): | ||
return np.squeeze(image, 0) | ||
|
||
img_ext = ('png', 'PNG', 'jpg', 'JPG', 'bmp', 'BMP', 'jpeg', 'JPEG') | ||
|
||
def is_image(self, fname): | ||
return any(fname.endswith(i) for i in self.img_ext) | ||
|
||
def read(self, paths): | ||
if isinstance(paths, (list, tuple)): | ||
images = [self._imread(path) for path in paths] | ||
return self._trans_batch(np.array(images)) | ||
return self._trans_image(self._imread(paths)) | ||
|
||
def back(self, image): | ||
return self._trans_back(image) | ||
|
||
def write(self, image, path): | ||
os.makedirs(os.path.dirname(path), exist_ok=True) | ||
self._imwrite(self.back(image), path) | ||
|
||
def read_until_success(func): | ||
@wraps(func) | ||
def wrapper(*args, **kwargs): | ||
for i in range(30): | ||
try: | ||
ret = func(*args, **kwargs) | ||
if ret is None: | ||
raise OSError() | ||
else: | ||
break | ||
except OSError: | ||
print('%s OSError' % str(args)) | ||
time.sleep(1) | ||
return ret | ||
return wrapper | ||
|
||
@read_until_success | ||
def cv2_imread(*args, **kwargs): | ||
return cv2.imread(*args, **kwargs) | ||
|
||
# if __name__ == '__main__': | ||
# import matplotlib.pyplot as plt | ||
# im_rgb_chw_cv2 = imlib('rgb', fmt='chw', lib='cv2') | ||
# im_rgb_hwc_cv2 = imlib('rgb', fmt='hwc', lib='cv2') | ||
# im_rgb_chw_pil = imlib('rgb', fmt='chw', lib='pillow') | ||
# im_rgb_hwc_pil = imlib('rgb', fmt='hwc', lib='pillow') | ||
# im_y_chw_cv2 = imlib('y', fmt='chw', lib='cv2') | ||
# im_y_hwc_cv2 = imlib('y', fmt='hwc', lib='cv2') | ||
# im_y_chw_pil = imlib('y', fmt='chw', lib='pillow') | ||
# im_y_hwc_pil = imlib('y', fmt='hwc', lib='pillow') | ||
# im_l_chw_cv2 = imlib('l', fmt='chw', lib='cv2') | ||
# im_l_hwc_cv2 = imlib('l', fmt='hwc', lib='cv2') | ||
# im_l_chw_pil = imlib('l', fmt='chw', lib='pillow') | ||
# im_l_hwc_pil = imlib('l', fmt='hwc', lib='pillow') | ||
# path = 'D:/Datasets/test/000001.jpg' | ||
|
||
# img_rgb_chw_cv2 = im_rgb_chw_cv2.read(path) | ||
# print(img_rgb_chw_cv2.shape) | ||
# plt.imshow(im_rgb_chw_cv2.back(img_rgb_chw_cv2)) | ||
# plt.show() | ||
# im_rgb_chw_cv2.write(img_rgb_chw_cv2, | ||
# (path.replace('000001.jpg', 'img_rgb_chw_cv2.jpg'))) | ||
# img_rgb_hwc_cv2 = im_rgb_hwc_cv2.read(path) | ||
# print(img_rgb_hwc_cv2.shape) | ||
# plt.imshow(im_rgb_hwc_cv2.back(img_rgb_hwc_cv2)) | ||
# plt.show() | ||
# im_rgb_hwc_cv2.write(img_rgb_hwc_cv2, | ||
# (path.replace('000001.jpg', 'img_rgb_hwc_cv2.jpg'))) | ||
# img_rgb_chw_pil = im_rgb_chw_pil.read(path) | ||
# print(img_rgb_chw_pil.shape) | ||
# plt.imshow(im_rgb_chw_pil.back(img_rgb_chw_pil)) | ||
# plt.show() | ||
# im_rgb_chw_pil.write(img_rgb_chw_pil, | ||
# (path.replace('000001.jpg', 'img_rgb_chw_pil.jpg'))) | ||
# img_rgb_hwc_pil = im_rgb_hwc_pil.read(path) | ||
# print(img_rgb_hwc_pil.shape) | ||
# plt.imshow(im_rgb_hwc_pil.back(img_rgb_hwc_pil)) | ||
# plt.show() | ||
# im_rgb_hwc_pil.write(img_rgb_hwc_pil, | ||
# (path.replace('000001.jpg', 'img_rgb_hwc_pil.jpg'))) | ||
|
||
|
||
# img_y_chw_cv2 = im_y_chw_cv2.read(path) | ||
# print(img_y_chw_cv2.shape) | ||
# plt.imshow(np.squeeze(im_y_chw_cv2.back(img_y_chw_cv2))) | ||
# plt.show() | ||
# im_y_chw_cv2.write(img_y_chw_cv2, | ||
# (path.replace('000001.jpg', 'img_y_chw_cv2.jpg'))) | ||
# img_y_hwc_cv2 = im_y_hwc_cv2.read(path) | ||
# print(img_y_hwc_cv2.shape) | ||
# plt.imshow(np.squeeze(im_y_hwc_cv2.back(img_y_hwc_cv2))) | ||
# plt.show() | ||
# im_y_hwc_cv2.write(img_y_hwc_cv2, | ||
# (path.replace('000001.jpg', 'img_y_hwc_cv2.jpg'))) | ||
# img_y_chw_pil = im_y_chw_pil.read(path) | ||
# print(img_y_chw_pil.shape) | ||
# plt.imshow(np.squeeze(im_y_chw_pil.back(img_y_chw_pil))) | ||
# plt.show() | ||
# im_y_chw_pil.write(img_y_chw_pil, | ||
# (path.replace('000001.jpg', 'img_y_chw_pil.jpg'))) | ||
# img_y_hwc_pil = im_y_hwc_pil.read(path) | ||
# print(img_y_hwc_pil.shape) | ||
# plt.imshow(np.squeeze(im_y_hwc_pil.back(img_y_hwc_pil))) | ||
# plt.show() | ||
# im_y_hwc_pil.write(img_y_hwc_pil, | ||
# (path.replace('000001.jpg', 'img_y_hwc_pil.jpg'))) | ||
|
||
|
||
# img_l_chw_cv2 = im_l_chw_cv2.read(path) | ||
# print(img_l_chw_cv2.shape) | ||
# plt.imshow(im_l_chw_cv2.back(img_l_chw_cv2)) | ||
# plt.show() | ||
# im_l_chw_cv2.write(img_l_chw_cv2, | ||
# (path.replace('000001.jpg', 'img_l_chw_cv2.jpg'))) | ||
# img_l_hwc_cv2 = im_l_hwc_cv2.read(path) | ||
# print(img_l_hwc_cv2.shape) | ||
# plt.imshow(im_l_hwc_cv2.back(img_l_hwc_cv2)) | ||
# plt.show() | ||
# im_l_hwc_cv2.write(img_l_hwc_cv2, | ||
# (path.replace('000001.jpg', 'img_l_hwc_cv2.jpg'))) | ||
# img_l_chw_pil = im_l_chw_pil.read(path) | ||
# print(img_l_chw_pil.shape) | ||
# plt.imshow(im_l_chw_pil.back(img_l_chw_pil)) | ||
# plt.show() | ||
# im_l_chw_pil.write(img_l_chw_pil, | ||
# (path.replace('000001.jpg', 'img_l_chw_pil.jpg'))) | ||
# img_l_hwc_pil = im_l_hwc_pil.read(path) | ||
# print(img_l_hwc_pil.shape) | ||
# plt.imshow(im_l_hwc_pil.back(img_l_hwc_pil)) | ||
# plt.show() | ||
# im_l_hwc_pil.write(img_l_hwc_pil, | ||
# (path.replace('000001.jpg', 'img_l_hwc_pil.jpg'))) |
Oops, something went wrong.