Commit 81c91a6 (0 parents): 12 changed files with 440 additions and 0 deletions.
LICENSE
The MIT License (MIT)

Copyright (c) 2017- Jiu XU
Copyright (c) 2017- Rakuten, Inc
Copyright (c) 2017- Rakuten Institute of Technology

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
# PyTorch LapSRN

Implementation of the CVPR 2017 paper ["Deep Laplacian Pyramid Networks for Fast and Accurate Super-Resolution"](http://vllab1.ucmerced.edu/~wlai24/LapSRN/) in PyTorch.

## Usage

### Training

```
usage: main.py [-h] [--batchSize BATCHSIZE] [--nEpochs NEPOCHS] [--lr LR]
               [--step STEP] [--cuda] [--resume RESUME]
               [--start-epoch START_EPOCH] [--threads THREADS]
               [--momentum MOMENTUM] [--weight-decay WEIGHT_DECAY]
               [--pretrained PRETRAINED]

PyTorch LapSRN

optional arguments:
  -h, --help            show this help message and exit
  --batchSize BATCHSIZE
                        training batch size
  --nEpochs NEPOCHS     number of epochs to train for
  --lr LR               learning rate, Default: 1e-4
  --step STEP           decay the learning rate by a factor of 10 every n
                        epochs, Default: n=100
  --cuda                use cuda?
  --resume RESUME       path to checkpoint (default: none)
  --start-epoch START_EPOCH
                        manual epoch number (useful on restarts)
  --threads THREADS     number of threads for the data loader, Default: 1
  --momentum MOMENTUM   momentum, Default: 0.9
  --weight-decay WEIGHT_DECAY, --wd WEIGHT_DECAY
                        weight decay, Default: 1e-4
  --pretrained PRETRAINED
                        path to pretrained model (default: none)
```
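An illustrative invocation using the defaults above (the HDF5 training file described under "Prepare Training dataset" must exist first):

```
python main.py --cuda --batchSize 64 --nEpochs 100 --lr 1e-4
```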
### Test

```
usage: test.py [-h] [--cuda] [--model MODEL] [--image IMAGE] [--scale SCALE]

PyTorch LapSRN Test

optional arguments:
  -h, --help     show this help message and exit
  --cuda         use cuda?
  --model MODEL  model path
  --image IMAGE  image name
  --scale SCALE  scale factor, Default: 4
```
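For example (the model path follows the naming used by `save_checkpoint` in main.py; the image name is a placeholder):

```
python test.py --cuda --model model_adam/model_epoch_100.pth --image <your_image> --scale 4
```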
We convert the Set5 test images to `.mat` format with Matlab; for PSNR numbers that match the paper, please evaluate with Matlab.
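PSNR for super-resolution is conventionally computed on the Y channel, ignoring a border of `scale` pixels; a rough Python sketch is shown below (assuming 8-bit images in [0, 255]). Matlab's `rgb2ycbcr`/`imresize` differ slightly from common Python implementations, which is why Matlab evaluation is recommended above.

```
import numpy as np

def psnr(pred, gt, shave_border=4):
    """PSNR between two Y-channel images in [0, 255], ignoring `shave_border` pixels at each edge."""
    if shave_border > 0:
        pred = pred[shave_border:-shave_border, shave_border:-shave_border]
        gt = gt[shave_border:-shave_border, shave_border:-shave_border]
    diff = pred.astype(np.float64) - gt.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")
    return 20 * np.log10(255.0 / np.sqrt(mse))
```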
### Prepare Training dataset

- We use HDF5 training samples with `data`, `label_x2`, and `label_x4` keys. The training data is generated with Matlab bicubic interpolation; please refer to [Code for Data Generation](https://github.com/twtygqyy/pytorch-vdsr/tree/master/data) for creating the training file. A sketch of the expected file layout is shown below.
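As a sketch only, the expected HDF5 layout can be written with h5py; the patch sizes and random arrays below are placeholders, and the real file is produced by the Matlab code linked above:

```
import h5py
import numpy as np

# Placeholder shapes: N single-channel LR patches with 2x and 4x HR labels.
n, size = 16, 32
data = np.random.rand(n, 1, size, size).astype(np.float32)
label_x2 = np.random.rand(n, 1, size * 2, size * 2).astype(np.float32)
label_x4 = np.random.rand(n, 1, size * 4, size * 4).astype(np.float32)

# The data/ directory must exist; main.py reads this exact path.
with h5py.File("data/lapsr_pry_x4.h5", "w") as hf:
    hf.create_dataset("data", data=data)
    hf.create_dataset("label_x2", data=label_x2)
    hf.create_dataset("label_x4", data=label_x4)
```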
### Performance

- We provide a pretrained LapSRN x4 model trained on the T91 and BSDS200 images from [SR_training_datasets](http://vllab1.ucmerced.edu/~wlai24/LapSRN/results/SR_training_datasets.zip), with data augmentation as described in the paper.
- No bias terms are used in this implementation; another difference from the paper is that Adam with a learning rate of 1e-4 is used instead of SGD.
- Performance in PSNR on Set5, Set14, and BSD100:

| DataSet/Method | LapSRN Paper | LapSRN PyTorch |
| -------------- |:------------:| --------------:|
| Set5           | 37.54        | 37.65          |
| Set14          | 28.19        | 28.27          |
| BSD100         | 27.32        | 27.36          |
### ToDos

- LapSRN x8
- Code for data generation
(5 binary files not shown.)
dataset.py
import torch.utils.data as data
import torch
import h5py


class DatasetFromHdf5(data.Dataset):
    """Training set stored in an HDF5 file with 'data', 'label_x2' and 'label_x4' keys."""

    def __init__(self, file_path):
        super(DatasetFromHdf5, self).__init__()
        hf = h5py.File(file_path, "r")
        self.data = hf.get("data")          # low-resolution input patches
        self.label_x2 = hf.get("label_x2")  # 2x high-resolution labels
        self.label_x4 = hf.get("label_x4")  # 4x high-resolution labels

    def __getitem__(self, index):
        return (torch.from_numpy(self.data[index, :, :, :]).float(),
                torch.from_numpy(self.label_x2[index, :, :, :]).float(),
                torch.from_numpy(self.label_x4[index, :, :, :]).float())

    def __len__(self):
        return self.data.shape[0]
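A minimal usage sketch, mirroring how main.py consumes this dataset (the HDF5 path is the one hard-coded in main.py):

```
from torch.utils.data import DataLoader
from dataset import DatasetFromHdf5

train_set = DatasetFromHdf5("data/lapsr_pry_x4.h5")
loader = DataLoader(dataset=train_set, num_workers=1, batch_size=64, shuffle=True)
for lr_patch, label_x2, label_x4 in loader:
    print(lr_patch.size(), label_x2.size(), label_x4.size())
    break
```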
lapsrn.py
import torch
import torch.nn as nn
import numpy as np
import math


def get_upsample_filter(size):
    """Make a 2D bilinear kernel suitable for upsampling."""
    factor = (size + 1) // 2
    if size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:size, :size]
    filt = (1 - abs(og[0] - center) / factor) * \
           (1 - abs(og[1] - center) / factor)
    return torch.from_numpy(filt).float()


class _Conv_Block(nn.Module):
    """Feature-embedding block: ten 3x3 convolutions followed by a 2x transposed-convolution upsampler."""

    def __init__(self):
        super(_Conv_Block, self).__init__()
        layers = []
        for _ in range(10):
            layers.append(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.cov_block = nn.Sequential(*layers)

    def forward(self, x):
        output = self.cov_block(x)
        return output


class Net(nn.Module):
    """Two-level Laplacian pyramid network producing 2x and 4x outputs from a single-channel input."""

    def __init__(self):
        super(Net, self).__init__()

        self.conv_input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

        # Level 1: upsample the image (convt_I1), extract features (convt_F1), predict the 2x residual (convt_R1).
        self.convt_I1 = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=4, stride=2, padding=1, bias=False)
        self.convt_R1 = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
        self.convt_F1 = self.make_layer(_Conv_Block)

        # Level 2: same structure, fed with the level-1 features and the 2x output to produce the 4x output.
        self.convt_I2 = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=4, stride=2, padding=1, bias=False)
        self.convt_R2 = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
        self.convt_F2 = self.make_layer(_Conv_Block)

        # He initialization for convolutions, bilinear initialization for transposed convolutions.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            if isinstance(m, nn.ConvTranspose2d):
                c1, c2, h, w = m.weight.data.size()
                weight = get_upsample_filter(h)
                m.weight.data = weight.view(1, 1, h, w).repeat(c1, c2, 1, 1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def make_layer(self, block):
        layers = []
        layers.append(block())
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu(self.conv_input(x))

        convt_F1 = self.convt_F1(out)
        convt_I1 = self.convt_I1(x)
        convt_R1 = self.convt_R1(convt_F1)
        HR_2x = convt_I1 + convt_R1

        convt_F2 = self.convt_F2(convt_F1)
        convt_I2 = self.convt_I2(HR_2x)
        convt_R2 = self.convt_R2(convt_F2)
        HR_4x = convt_I2 + convt_R2

        return HR_2x, HR_4x


class L1_Charbonnier_loss(nn.Module):
    """L1 Charbonnier loss: sum of sqrt(diff^2 + eps), a differentiable variant of the L1 loss."""

    def __init__(self):
        super(L1_Charbonnier_loss, self).__init__()
        self.eps = 1e-6

    def forward(self, X, Y):
        diff = torch.add(X, -Y)
        error = torch.sqrt(diff * diff + self.eps)
        loss = torch.sum(error)
        return loss
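A quick, self-contained sanity-check sketch of the network; the input shape is an assumption (any single-channel patch works), and the expected output sizes follow from the two 2x upsampling stages:

```
import torch
from torch.autograd import Variable
from lapsrn import Net, L1_Charbonnier_loss

model = Net()
x = Variable(torch.randn(1, 1, 32, 32))   # one 1-channel 32x32 patch
HR_2x, HR_4x = model(x)
print(HR_2x.size(), HR_4x.size())         # expected: (1, 1, 64, 64) and (1, 1, 128, 128)

criterion = L1_Charbonnier_loss()
print(criterion(HR_4x, Variable(torch.randn(1, 1, 128, 128))))
```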
main.py
import argparse, os
import torch
import random
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from lapsrn import Net, L1_Charbonnier_loss
from dataset import DatasetFromHdf5

# Training settings
parser = argparse.ArgumentParser(description="PyTorch LapSRN")
parser.add_argument("--batchSize", type=int, default=64, help="training batch size")
parser.add_argument("--nEpochs", type=int, default=100, help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate, Default: 1e-4")
parser.add_argument("--step", type=int, default=100, help="decay the learning rate by a factor of 10 every n epochs, Default: n=100")
parser.add_argument("--cuda", action="store_true", help="use cuda?")
parser.add_argument("--resume", default="", type=str, help="path to checkpoint (default: none)")
parser.add_argument("--start-epoch", default=1, type=int, help="manual epoch number (useful on restarts)")
parser.add_argument("--threads", type=int, default=1, help="number of threads for the data loader, Default: 1")
parser.add_argument("--momentum", default=0.9, type=float, help="momentum, Default: 0.9")
parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="weight decay, Default: 1e-4")
parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)")


def main():
    global opt, model
    opt = parser.parse_args()
    print(opt)

    cuda = opt.cuda
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    if cuda:
        torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True

    print("===> Loading datasets")
    train_set = DatasetFromHdf5("data/lapsr_pry_x4.h5")
    training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

    print("===> Building model")
    model = Net()
    criterion = L1_Charbonnier_loss()

    print("===> Setting GPU")
    if cuda:
        model = model.cuda()
        criterion = criterion.cuda()
    else:
        model = model.cpu()

    # Optionally resume from a checkpoint (restores weights and the epoch counter).
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint["epoch"] + 1
            model.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Optionally copy weights from a pretrained model (does not touch the epoch counter).
    if opt.pretrained:
        if os.path.isfile(opt.pretrained):
            print("=> loading model '{}'".format(opt.pretrained))
            weights = torch.load(opt.pretrained)
            model.load_state_dict(weights["model"].state_dict())
        else:
            print("=> no model found at '{}'".format(opt.pretrained))

    print("===> Setting Optimizer")
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    print("===> Training")
    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
        train(training_data_loader, optimizer, model, criterion, epoch)
        save_checkpoint(model, epoch)


def adjust_learning_rate(optimizer, epoch):
    """Return the initial learning rate decayed by a factor of 10 every `opt.step` epochs."""
    lr = opt.lr * (0.1 ** (epoch // opt.step))
    return lr


def train(training_data_loader, optimizer, model, criterion, epoch):
    lr = adjust_learning_rate(optimizer, epoch - 1)

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    print("epoch = {}, lr = {}".format(epoch, optimizer.param_groups[0]["lr"]))
    model.train()

    for iteration, batch in enumerate(training_data_loader, 1):
        input, label_x2, label_x4 = Variable(batch[0]), Variable(batch[1], requires_grad=False), Variable(batch[2], requires_grad=False)

        if opt.cuda:
            input = input.cuda()
            label_x2 = label_x2.cuda()
            label_x4 = label_x4.cuda()

        HR_2x, HR_4x = model(input)

        # Charbonnier loss at both pyramid levels; backpropagating the summed loss
        # is equivalent to backpropagating loss_x2 and loss_x4 separately.
        loss_x2 = criterion(HR_2x, label_x2)
        loss_x4 = criterion(HR_4x, label_x4)
        loss = loss_x2 + loss_x4

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iteration % 100 == 0:
            print("===> Epoch[{}]({}/{}): Loss: {:.10f}".format(epoch, iteration, len(training_data_loader), loss.data[0]))


def save_checkpoint(model, epoch):
    model_folder = "model_adam/"
    model_out_path = model_folder + "model_epoch_{}.pth".format(epoch)
    # The full model object is stored under the "model" key, so lapsrn.Net must be importable when loading.
    state = {"epoch": epoch, "model": model}
    if not os.path.exists(model_folder):
        os.makedirs(model_folder)

    torch.save(state, model_out_path)

    print("Checkpoint saved to {}".format(model_out_path))


if __name__ == "__main__":
    main()
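Since `save_checkpoint` stores the whole model object under the `"model"` key, a minimal inference sketch could look like the following; the checkpoint path and the fake Y-channel input are placeholders (test.py, whose contents are not shown here, is the actual evaluation script):

```
import numpy as np
import torch
from torch.autograd import Variable

# Placeholder checkpoint path; lapsrn.Net must be importable for torch.load to rebuild the model.
# On a CPU-only machine, pass map_location=lambda storage, loc: storage to torch.load.
checkpoint = torch.load("model_adam/model_epoch_100.pth")
model = checkpoint["model"].cpu()
model.eval()

# Placeholder input: the Y channel of a YCbCr image, scaled to [0, 1].
y = np.random.rand(1, 1, 32, 32).astype(np.float32)
HR_2x, HR_4x = model(Variable(torch.from_numpy(y)))
print(HR_4x.data.size())
```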
(1 more binary file not shown.)