Skip to content

Commit

Permalink
RDN
Browse files Browse the repository at this point in the history
  • Loading branch information
wenrui-purdue committed Nov 18, 2019
1 parent b3aae1f commit a770351
Show file tree
Hide file tree
Showing 10 changed files with 1,519 additions and 1 deletion.
75 changes: 75 additions & 0 deletions Metric.log
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/scratch/gilbreth/li3120/RDB/Test_Result/Manga109_16.h5
average ssim:0.957323
average psnr:32.212730
/scratch/gilbreth/li3120/RDB/Test_Result/Set14_19.h5
average ssim:0.865827
average psnr:28.640444
/scratch/gilbreth/li3120/RDB/Test_Result/Set14_17.h5
average ssim:0.945727
average psnr:34.091388
/scratch/gilbreth/li3120/RDB/Test_Result/Manga109_17.h5
average ssim:0.983521
average psnr:36.844997
/scratch/gilbreth/li3120/RDB/Test_Result/Set14_18.h5
average ssim:0.867282
average psnr:28.860134
/scratch/gilbreth/li3120/RDB/Test_Result/B100_18.h5
average ssim:0.815428
average psnr:26.904180
/scratch/gilbreth/li3120/RDB/Test_Result/Urban100_19.h5
average ssim:0.820174
average psnr:25.668797
/scratch/gilbreth/li3120/RDB/Test_Result/Urban100_20.h5
average ssim:0.785363
average psnr:24.468557
/scratch/gilbreth/li3120/RDB/Test_Result/Urban100_18.h5
average ssim:0.828697
average psnr:25.981574
/scratch/gilbreth/li3120/RDB/Test_Result/Set5_16.h5
average ssim:0.961617
average psnr:34.722022
/scratch/gilbreth/li3120/RDB/Test_Result/Set5_18.h5
average ssim:0.940866
average psnr:32.595460
/scratch/gilbreth/li3120/RDB/Test_Result/B100_20.h5
average ssim:0.757086
average psnr:25.223548
/scratch/gilbreth/li3120/RDB/Test_Result/Manga109_19.h5
average ssim:0.921437
average psnr:29.053814
/scratch/gilbreth/li3120/RDB/Test_Result/Manga109_20.h5
average ssim:0.880391
average psnr:26.505380
/scratch/gilbreth/li3120/RDB/Test_Result/B100_16.h5
average ssim:0.866474
average psnr:28.625940
/scratch/gilbreth/li3120/RDB/Test_Result/B100_19.h5
average ssim:0.826464
average psnr:27.226383
/scratch/gilbreth/li3120/RDB/Test_Result/Set5_17.h5
average ssim:0.980510
average psnr:38.318403
/scratch/gilbreth/li3120/RDB/Test_Result/Manga109_18.h5
average ssim:0.925166
average psnr:29.275654
/scratch/gilbreth/li3120/RDB/Test_Result/Urban100_16.h5
average ssim:0.884011
average psnr:28.174515
/scratch/gilbreth/li3120/RDB/Test_Result/Set5_19.h5
average ssim:0.938823
average psnr:32.228928
/scratch/gilbreth/li3120/RDB/Test_Result/Urban100_17.h5
average ssim:0.942854
average psnr:32.116574
/scratch/gilbreth/li3120/RDB/Test_Result/Set14_16.h5
average ssim:0.901079
average psnr:30.622271
/scratch/gilbreth/li3120/RDB/Test_Result/Set14_20.h5
average ssim:0.811752
average psnr:26.307208
/scratch/gilbreth/li3120/RDB/Test_Result/Set5_20.h5
average ssim:0.885235
average psnr:28.487348
/scratch/gilbreth/li3120/RDB/Test_Result/B100_17.h5
average ssim:0.927497
average psnr:31.904156
31 changes: 31 additions & 0 deletions Metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
import h5py
import os
from glob import glob
import matplotlib.pyplot as plt
from skimage.transform import rescale
from skimage.measure import compare_psnr,compare_ssim

paths=glob('/scratch/gilbreth/li3120/RDB/Test_Result/*')

for path in paths:
print(path)
f= h5py.File(path, "r")
rec=f['rec']
gt=f['gt']
Ave_ssim=np.zeros((len(gt),1))
Ave_psnr=np.zeros((len(gt),1))
def rgb2ycbcr (img):
y = 16 + (65.481 * img[:, :, 0]) + (128.553 * img[:, :, 1]) + (24.966 * img[:, :, 2])
return y / 255
for i in range(len(gt)):
img=np.squeeze(rec[i,:,:,:])
imgg=np.squeeze(gt[i,:,:,:])
img=np.clip(img,a_min=0,a_max=1)
img=rgb2ycbcr(img)
imgg=rgb2ycbcr(imgg)
Ave_ssim[i]=compare_ssim(img,imgg)
Ave_psnr[i]=compare_psnr(img,imgg)
print('average ssim:%f'%(np.mean(Ave_ssim)))
print('average psnr:%f'%(np.mean(Ave_psnr)))
f.close()
11 changes: 11 additions & 0 deletions RDB_run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/bin/bash
module use /depot/itap/amaji/modules
module spider learning
module load learning/conda-5.1.0-py36-gpu
module load ml-toolkit-gpu/tensorflow
module load ml-toolkit-gpu/keras
module load ml-toolkit-gpu/opencv
module list
source activate /home/li3120/.conda/envs/cent7/5.1.0-py36/tfgan1
cd /home/li3120/ECE570
python /home/li3120/ECE570/model.py
44 changes: 43 additions & 1 deletion README.md
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1 +1,43 @@
# RDN-SR-Keras
# course-project-DamonLee5
# Acknowledgement
For the best of my knowledge, my project is the first RDN implementation written by Keras. However, I still need to acknowledge that https://github.com/yulunzhang/RDN and https://github.com/hengchuan/RDN-TensorFlow help me for some hyperparameter setting.

# Data preparation
## utils.py
This file generates our training datasets for three different degradations, including 5 datasets.
- Bicubic degradation with factor 2, 3, and 4.
- BD degradation with factor 3
- DN degradation with factor 3
Also this file generates 25 datasets for 5 degradations on 5 benchmark datasets, which have not been seen while training.
To generate those dataset, I wrote function to chop, rescale, filter, read, and save the image.

# DNN Model
I implement this project by Keras. I utilize a standard Keras framework with a data_loader and model file.
## data_loader.py
This file is used for load the data from the dataset for training, sampling and testing.
## model.py
This file include the entire RDN network. Also, it can test on validation set while traing and load a trained model for testing.

# Training
RDB_run.sh is used to train in cluster by qsub.

# Test Result Generation
## test.py
Since we need to calculate the metric, and online calculate takes too long, I first generate the entire reconstructed high resolution dataset for metric calculation. This result dataset include ground truth and reconstructed image.

## Metric.py
Calculate the Metric among all 25 testing datasets. Result is in the Metric.log, which is consistent with the result displayed in the term paper.

## loss_plot.ipynb
Plot loss from training log file.

## Utils_test.ipynb
Generate an entire high resolution image from reconstructed patches.

# Saved model
https://drive.google.com/open?id=1sQZO1ZZ4MIYM_vF5wOm6e9F8OiyeoF_H
- RDB_best_16.h5, BI degradation, factor 3
- RDB_best_17.h5, BI degradation, factor 2
- RDB_best_18.h5, BI degradation, factor 4
- RDB_best_19.h5, BD degradation, factor 3
- RDB_best_20.h5, DN degradation, factor 3
653 changes: 653 additions & 0 deletions Utils_test.ipynb

Large diffs are not rendered by default.

92 changes: 92 additions & 0 deletions data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import scipy
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage.io import imread
from skimage import data_dir
from skimage.transform import radon, rescale
from skimage.transform import iradon
import h5py
class DataLoader():
def __init__(self, dataset_name, img_crop=(128, 128)):
self.dataset_name = dataset_name
self.img_crop = img_crop
np.random.seed(3)
def load_sample_data(self, batch_size=1, is_testing=False,cal_mse=False,mse_path=''):
imgs_bl = []
imgs_sh = []
views=16
f=h5py.File(mse_path, "r")
label5=f['label']
input5=f['input']
for iiter,(img_bl, img_sh) in enumerate(zip(input5,label5)):
imgs_bl.append(img_bl)
imgs_sh.append(img_sh)
f.close()
imgs_bl = np.array(imgs_bl)
imgs_sh = np.array(imgs_sh)
return imgs_bl, imgs_sh

def load_data(self, batch_size=1, is_testing=False,cal_mse=False,mse_path=''):
path = '/scratch/gilbreth/li3120/dataset/DIV2K_train_HR/Train/%s.h5' % (self.dataset_name)
f= h5py.File(path, "r")
input5=f['input']
ll=len(input5)
f.close()
iteration=1000
ipath=np.random.permutation(range(ll))
ipath=ipath[:1000*batch_size]
for i in range(iteration):
batch = ipath[i*batch_size:(i+1)*batch_size]
rd=np.random.randint(4,size=batch_size)
imgs_bl = []
imgs_sh = []
f= h5py.File(path, "r")
rdind=0
for img_path in batch:

label5 =f["label"]
img_sh=label5[img_path]
img_sh=np.array(img_sh)
if rd[rdind]==0:
img_sh=np.flip(img_sh,0)
if rd[rdind]==1:
img_sh=np.flip(img_sh,1)
if rd[rdind]==2:
img_sh=np.rot90(img_sh,1,(0,1))
input5=f["input"]
img_bl=input5[img_path]
img_bl=np.array(img_bl)
if rd[rdind]==0:
img_bl=np.flip(img_bl,0)
if rd[rdind]==1:
img_bl=np.flip(img_bl,1)
if rd[rdind]==2:
img_bl=np.rot90(img_bl,1,(0,1))
imgs_bl.append(img_bl)
imgs_sh.append(img_sh)
rdind=rdind+1
f.close()
imgs_bl = np.array(imgs_bl)
imgs_sh = np.array(imgs_sh)
yield imgs_bl, imgs_sh

def load_test_data(self, batch_size=1, is_testing=False,cal_mse=False,mse_path=''):
batch_images=glob('%s/*'%mse_path)
print(batch_images[0])
imgs_bl = []
imgs_sh = []
views=16
for img_path in batch_images:
f= h5py.File(img_path, "r")
img_sh =f["gt"]
img_sh=np.expand_dims(img_sh, axis=-1)
img_bl=f["mvbp"]
img_bl=np.array(img_bl)
imgs_bl.append(img_bl)
imgs_sh.append(img_sh)
f.close()
imgs_bl = np.array(imgs_bl) / 1000.0
imgs_sh = np.array(imgs_sh) / 1000.0
return imgs_bl, imgs_sh
Loading

0 comments on commit a770351

Please sign in to comment.