laploss.py
import numpy as np
import torch
from torch import nn
import torch.nn.functional as fnn


def build_gauss_kernel(size=5, sigma=1.0, n_channels=1, cuda=False):
    if size % 2 != 1:
        raise ValueError("kernel size must be odd")
    grid = np.float32(np.mgrid[0:size, 0:size].T)
    gaussian = lambda x: np.exp((x - size // 2) ** 2 / (-2 * sigma ** 2))
    # a 2D isotropic Gaussian is the product of the per-axis 1D Gaussians
    kernel = np.prod(gaussian(grid), axis=2)
    kernel /= np.sum(kernel)
    # repeat the same kernel across the channel dimension
    kernel = np.tile(kernel, (n_channels, 1, 1))
    # conv weight shape is (out_channels, in_channels / groups, h, w);
    # for a depthwise (per-channel) convolution the groups dimension is 1
    kernel = torch.from_numpy(kernel).unsqueeze(1)
    if cuda:
        kernel = kernel.cuda()
    return kernel
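
# Example (illustrative): a 3-channel kernel for RGB inputs; each channel holds
# the same normalized 5x5 Gaussian, so each slice sums to ~1.
# >>> k = build_gauss_kernel(size=5, sigma=2.0, n_channels=3)
# >>> k.shape
# torch.Size([3, 1, 5, 5])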


def conv_gauss(img, kernel):
    """Convolve img with a Gaussian kernel built by build_gauss_kernel."""
    n_channels, _, kh, kw = kernel.shape
    # F.pad takes (left, right, top, bottom); replicate-padding keeps the
    # output the same spatial size as the input
    img = fnn.pad(img, (kw // 2, kw // 2, kh // 2, kh // 2), mode='replicate')
    return fnn.conv2d(img, kernel, groups=n_channels)
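
# Example (illustrative): spatial dimensions are preserved
# >>> k = build_gauss_kernel(n_channels=3)
# >>> conv_gauss(torch.rand(1, 3, 32, 32), k).shape
# torch.Size([1, 3, 32, 32])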


def laplacian_pyramid(img, kernel, max_levels=5):
    current = img
    pyr = []
    for level in range(max_levels):
        filtered = conv_gauss(current, kernel)
        # band-pass residual: what the Gaussian blur removed at this scale
        diff = current - filtered
        pyr.append(diff)
        # downsample the low-pass image for the next level
        current = fnn.avg_pool2d(filtered, 2)
    # the remaining low-resolution image completes the pyramid
    pyr.append(current)
    return pyr
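
# Note (illustrative shapes): with max_levels=5 and a 64x64 input, the returned
# list holds six tensors: band-pass residuals at 64, 32, 16, 8 and 4 pixels,
# plus the final 2x2 low-pass image; spatial size halves at every level.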


class LapLoss(nn.Module):
    def __init__(self, max_levels=5, k_size=5, sigma=2.0):
        super(LapLoss, self).__init__()
        self.max_levels = max_levels
        self.k_size = k_size
        self.sigma = sigma
        self._gauss_kernel = None
        self.L1_loss = nn.L1Loss(reduction='sum')

    def forward(self, input, target):
        # rebuild the kernel lazily whenever the channel count changes
        if self._gauss_kernel is None or self._gauss_kernel.shape[0] != input.shape[1]:
            self._gauss_kernel = build_gauss_kernel(
                size=self.k_size, sigma=self.sigma,
                n_channels=input.shape[1], cuda=input.is_cuda
            )
        pyr_input = laplacian_pyramid(input, self._gauss_kernel, self.max_levels)
        pyr_target = laplacian_pyramid(target, self._gauss_kernel, self.max_levels)
        # sum the L1 distances between corresponding pyramid levels
        return sum(self.L1_loss(a, b) for a, b in zip(pyr_input, pyr_target))


class LapMap(nn.Module):
    def __init__(self, max_levels=5, k_size=5, sigma=2.0):
        super(LapMap, self).__init__()
        self.max_levels = max_levels
        self.k_size = k_size
        self.sigma = sigma
        self._gauss_kernel = None

    def forward(self, input):
        # rebuild the kernel lazily whenever the channel count changes
        if self._gauss_kernel is None or self._gauss_kernel.shape[0] != input.shape[1]:
            self._gauss_kernel = build_gauss_kernel(
                size=self.k_size, sigma=self.sigma,
                n_channels=input.shape[1], cuda=input.is_cuda
            )
        # return the per-level detail maps themselves rather than a scalar loss
        return laplacian_pyramid(input, self._gauss_kernel, self.max_levels)
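

# Minimal smoke test (a sketch, not part of the original file): random RGB
# batches stand in for real images; prints the multi-scale L1 loss and the
# pyramid shapes produced by LapMap.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.rand(2, 3, 64, 64)
    y = torch.rand(2, 3, 64, 64)

    loss_fn = LapLoss(max_levels=3)
    print("LapLoss:", loss_fn(x, y).item())

    pyr = LapMap(max_levels=3)(x)
    print("pyramid shapes:", [tuple(t.shape) for t in pyr])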