Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update WBsRGB.py #5

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 102 additions & 58 deletions WB_sRGB_Python/classes/WBsRGB.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,44 +16,56 @@
##########################################################################


import numpy as np
import numpy.matlib
from datetime import datetime

# import numpy.matlib
import cupy as cp
import cv2
import numpy as np


class WBsRGB:
def __init__(self, gamut_mapping=2, upgraded=0):

cp.cuda.Device(1).use()
if upgraded == 1:
self.features = np.load('models/features+.npy') # training encoded features
self.mappingFuncs = np.load('models/mappingFuncs+.npy') # mapping correction functions
self.encoderWeights = np.load('models/encoderWeights+.npy') # weight matrix for histogram encoding
self.encoderBias = np.load('models/encoderBias+.npy') # bias vector for histogram encoding
self.features = cp.load(
'/opt/instore-app/test/whitebalance/WB_sRGB_Python/models/features+.npy') # training encoded features
self.mappingFuncs = np.load(
'/opt/instore-app/test/whitebalance/WB_sRGB_Python/models/mappingFuncs+.npy') # mapping correction functions
self.encoderWeights = np.load(
'/opt/instore-app/test/whitebalance/WB_sRGB_Python/models/encoderWeights+.npy') # weight matrix for histogram encoding
self.encoderBias = np.load(
'/opt/instore-app/test/whitebalance/WB_sRGB_Python/models/encoderBias+.npy') # bias vector for histogram encoding
self.K = 75 # K value for nearest neighbor searching---for the upgraded model, we found 75 is better
else:
self.features = np.load('models/features.npy') # training encoded features
self.mappingFuncs = np.load('models/mappingFuncs.npy') # mapping correction functions
self.encoderWeights = np.load('models/encoderWeights.npy') # weight matrix for histogram encoding
self.encoderBias = np.load('models/encoderBias.npy') # bias vector for histogram encoding
self.features = cp.load(
'/opt/instore-app/test/whitebalance/WB_sRGB_Python/models/features.npy') # training encoded features
self.mappingFuncs = np.load(
'/opt/instore-app/test/whitebalance/WB_sRGB_Python/models/mappingFuncs.npy') # mapping correction functions
self.encoderWeights = np.load(
'/opt/instore-app/test/whitebalance/WB_sRGB_Python/models/encoderWeights.npy') # weight matrix for histogram encoding
self.encoderBias = np.load(
'/opt/instore-app/test/whitebalance/WB_sRGB_Python/models/encoderBias.npy') # bias vector for histogram encoding
self.K = 25 # K value for nearest neighbor searching

self.sigma = 0.25 # fall-off factor for KNN blending
self.h = 60 # histogram bin width
self.h = 60 # histogram bin width
# our results reported with gamut_mapping=2, however gamut_mapping=1 gives more compelling results with
# over-saturated examples
self.gamut_mapping = gamut_mapping #options: =1 for scaling, =2 for clipping
self.gamut_mapping = gamut_mapping # options: =1 for scaling, =2 for clipping

def encode(self, hist):
""" Generates a compacted feature of a given RGB-uv histogram tensor. """
histR_reshaped = np.reshape(np.transpose(hist[:, :, 0]),
(1, int(hist.size / 3)), order="F") # reshaped red layer of histogram
(1, int(hist.size / 3)), order="F") # reshaped red layer of histogram
histG_reshaped = np.reshape(np.transpose(hist[:, :, 1]),
(1, int(hist.size / 3)), order="F") # reshaped green layer of histogram
histB_reshaped = np.reshape(np.transpose(hist[:, :, 2]),
(1, int(hist.size / 3)), order="F") # reshaped blue layer of histogram
hist_reshaped = np.append(histR_reshaped,
[histG_reshaped, histB_reshaped]) # reshaped histogram n * 3 (n = h*h)
feature = np.dot(hist_reshaped - self.encoderBias.transpose(), self.encoderWeights) # compute compacted histogram feature
feature = np.dot(hist_reshaped - self.encoderBias.transpose(),
self.encoderWeights) # compute compacted histogram feature
return feature

def rgbUVhist(self, I):
Expand All @@ -70,104 +82,136 @@ def rgbUVhist(self, I):
G = II[inds, 1] # green channel
B = II[inds, 2] # blue channel
I_reshaped = np.concatenate((R, G, B), axis=0).transpose() # reshaped image (wo zero values)
I_reshaped = cp.asarray(I_reshaped)
eps = 6.4 / self.h
A = np.arange(-3.2, 3.19, eps) # dummy vector
hist = np.zeros((A.size, A.size, 3)) # histogram will be stored here
Iy = np.sqrt(np.power(I_reshaped[:, 0], 2) + np.power(I_reshaped[:, 1], 2) +
np.power(I_reshaped[:, 2], 2)) # intensity vector
A = cp.arange(-3.2, 3.19, eps) # dummy vector
hist = cp.zeros((A.size, A.size, 3)) # histogram will be stored here
Iy = cp.sqrt(cp.power(I_reshaped[:, 0], 2) + cp.power(I_reshaped[:, 1], 2) +
cp.power(I_reshaped[:, 2], 2)) # intensity vector
for i in range(3): # for each histogram layer, do
time_start_for = datetime.now()
r = [] # excluded channels will be stored here
for j in range(3): # for each color channel do
if j != i: # if current color channel does not match current histogram layer,
r.append(j) # exclude it
Iu = np.log(I_reshaped[:, i] / I_reshaped[:, r[1]]) # current color channel / the first excluded channel
Iv = np.log(I_reshaped[:, i] / I_reshaped[:, r[0]]) # current color channel / the second excluded channel
diff_u = np.abs(np.matlib.repmat(Iu, np.size(A), 1).transpose() - np.matlib.repmat(A, np.size(Iu),
1)) # differences in u space
diff_v = np.abs(np.matlib.repmat(Iv, np.size(A), 1).transpose() - np.matlib.repmat(A, np.size(Iv),
1)) # differences in v space
Iu = cp.log(I_reshaped[:, i] / I_reshaped[:, r[1]]) # current color channel / the first excluded channel
Iv = cp.log(I_reshaped[:, i] / I_reshaped[:, r[0]]) # current color channel / the second excluded channel
diff_u = cp.abs(
cp.tile(Iu, (cp.size(A), 1)).transpose() - cp.tile(A, (cp.size(Iu), 1))) # differences in u space
diff_v = cp.abs(
cp.tile(Iv, (cp.size(A), 1)).transpose() - cp.tile(A, (cp.size(Iv), 1))) # differences in v space
diff_u[diff_u >= (eps / 2)] = 0 # do not count any pixel has difference beyond the threshold in the u space
diff_u[diff_u != 0] = 1 # remaining pixels will be counted
diff_v[diff_v >= (eps / 2)] = 0 # do not count any pixel has difference beyond the threshold in the v space
diff_v[diff_v != 0] = 1 # remaining pixels will be counted
# here, we will use a matrix multiplication expression to compute eq. 4 in the main paper.
# why? because it is much faster
temp = (np.matlib.repmat(Iy, np.size(A), 1) * (diff_u).transpose()) # Iy .* diff_u' (.* element-wise mult)
hist[:, :, i] = np.dot(temp, diff_v) # initialize current histogram layer with Iy .* diff' * diff_v
norm_ = np.sum(hist[:, :, i], axis=None) # compute sum of hist for normalization
hist[:, :, i] = np.sqrt(hist[:, :, i] / norm_) # (hist/norm)^(1/2)
return hist
temp = (cp.tile(Iy, (cp.size(A), 1)) * (diff_u).transpose()) # Iy .* diff_u' (.* element-wise mult)
hist[:, :, i] = cp.dot(temp, diff_v) # initialize current histogram layer with Iy .* diff' * diff_v
norm_ = cp.sum(hist[:, :, i], axis=None) # compute sum of hist for normalization
hist[:, :, i] = cp.sqrt(hist[:, :, i] / norm_) # (hist/norm)^(1/2)

return hist

def correctImage(self, I):
""" White balance a given image I. """
time_read = datetime.now()
I = cv2.cvtColor(I, cv2.COLOR_BGR2RGB) # convert from BGR to RGB
I = im2double(I) # convert to double
feature = self.encode(self.rgbUVhist(I))
D_sq = np.einsum('ij, ij ->i', self.features, self.features)[:, None] + \
np.einsum('ij, ij ->i', feature, feature) - \
2 * self.features.dot(feature.T) # squared euclidean distances

print("Image convert : {}".format(datetime.now() - time_read))
time_start_hist = datetime.now()
I_hist = self.rgbUVhist(I)
I_hist = cp.asnumpy(I_hist)
print("Hist computation: {}".format(datetime.now() - time_start_hist))
time_start_encode = datetime.now()
feature = self.encode(I_hist)
print("Encode : {}".format(datetime.now() - time_start_encode))
time_start_dsq = datetime.now()
feature = cp.asarray(feature)
D_sq = cp.linalg.norm(self.features - feature, axis=1, keepdims=True)
D_sq = cp.asnumpy(D_sq)
print("Dist computation : {}".format(datetime.now() - time_start_dsq))
time_start_uc = datetime.now()
idH = D_sq.argpartition(self.K, axis=0)[:self.K] # get smallest K distances
mappingFuncs = np.squeeze(self.mappingFuncs[idH, :])
dH = np.sqrt(
np.take_along_axis(D_sq, idH, axis=0)) # square root nearest distances to get real euclidean distances
dH = np.take_along_axis(D_sq, idH, axis=0) # square root nearest distances to get real euclidean distances
sorted_idx = dH.argsort(axis=0) # get sorting indices
idH = np.take_along_axis(idH, sorted_idx, axis=0) # sort distance indices
dH = np.take_along_axis(dH, sorted_idx, axis=0) # sort distances
weightsH = np.exp(-(np.power(dH, 2)) /
(2 * np.power(self.sigma, 2))) # compute blending weights
weightsH = weightsH / sum(weightsH) # normalize blending weights
mf = sum(np.matlib.repmat(weightsH, 1, 33) *
mf = sum(np.tile(weightsH, (1, 33)) *
mappingFuncs, 0) # compute the mapping function
mf = mf.reshape(11, 3, order="F") # reshape it to be 9 * 3
I_corr = self.colorCorrection(I, mf) # apply it!
print("Mapping and weights : {}".format(datetime.now() - time_start_uc))
time_start_c = datetime.now()
I_corr = self.colorCorrection(I, mf) # apply it!
print("Correction : {}".format(datetime.now() - time_start_c))
return I_corr

def colorCorrection(self, input, m):
""" Applies a mapping function m to a given input image. """
sz = np.shape(input) # get size of input image
I_reshaped = np.reshape(input,(int(input.size/3),3),
order="F") # reshape input to be n*3 (n: total number of pixels)
kernel_out = kernelP(I_reshaped) # raise input image to a higher-dim space
out = np.dot(kernel_out, m) # apply m to the input image after raising it the selected higher degree
sz = np.shape(input) # get size of input image
I_reshaped = np.reshape(input, (int(input.size / 3), 3),
order="F") # reshape input to be n*3 (n: total number of pixels)
shape_mat = np.shape(I_reshaped)
I_reshaped = cp.asarray(I_reshaped)
kernel_out = kernelP(I_reshaped, shape_mat) # raise input image to a higher-dim space
m = cp.asarray(m)
out = cp.dot(kernel_out, m) # apply m to the input image after raising it the selected higher degree
if self.gamut_mapping == 1:
out = normScaling(I_reshaped, out) # scaling based on input image energy
out = normScaling(I_reshaped, out) # scaling based on input image energy
elif self.gamut_mapping == 2:
out = outOfGamutClipping(out) # clip out-of-gamut pixels
out = outOfGamutClipping(out) # clip out-of-gamut pixels
else:
raise Exception('Wrong gamut_mapping value')
out = cp.asnumpy(out)
out = out.reshape(sz[0], sz[1], sz[2], order="F") # reshape output image back to the original image shape
out = cv2.cvtColor(out.astype('float32'), cv2.COLOR_RGB2BGR)
return out


def normScaling(I, I_corr):
""" Scales each pixel based on original image energy. """
norm_I_corr = np.sqrt(np.sum(np.power(I_corr, 2), 1))
norm_I_corr = cp.sqrt(cp.sum(cp.power(I_corr, 2), 1))
inds = norm_I_corr != 0
norm_I_corr = norm_I_corr[inds]
norm_I = np.sqrt(np.sum(np.power(I[inds, :],2), 1))
I_corr[inds, :] = I_corr[inds, :]/np.tile(norm_I_corr[:, np.newaxis], 3) * \
np.tile(norm_I[:, np.newaxis], 3)
norm_I = cp.sqrt(cp.sum(cp.power(I[inds, :], 2), 1))
I_corr[inds, :] = I_corr[inds, :] / cp.tile(norm_I_corr[:, cp.newaxis], 3) * \
cp.tile(norm_I[:, cp.newaxis], 3)
return I_corr


def kernelP(I):
def kernelP(I, shape_mat):
""" Kernel function: kernel(r, g, b) -> (r,g,b,rg,rb,gb,r^2,g^2,b^2,rgb,1)
Ref: Hong, et al., "A study of digital camera colorimetric characterization
based on polynomial modeling." Color Research & Application, 2001. """
return (np.transpose((I[:,0], I[:,1], I[:,2], I[:,0] * I[:,1], I[:,0] * I[:,2],
I[:,1] * I[:,2], I[:, 0] * I[:, 0], I[:, 1] * I[:, 1],
repeat_1 = np.repeat(1, shape_mat[0])
repeat_1 = cp.asarray(repeat_1)
kernel_out = cp.stack((I[:, 0], I[:, 1], I[:, 2], I[:, 0] * I[:, 1], I[:, 0] * I[:, 2],
I[:, 1] * I[:, 2], I[:, 0] * I[:, 0], I[:, 1] * I[:, 1],
I[:, 2] * I[:, 2], I[:, 0] * I[:, 1] * I[:, 2],
repeat_1))
return (cp.transpose(kernel_out))


def kernelP_native(I):
""" Kernel function: kernel(r, g, b) -> (r,g,b,rg,rb,gb,r^2,g^2,b^2,rgb,1)
Ref: Hong, et al., "A study of digital camera colorimetric characterization
based on polynomial modeling." Color Research & Application, 2001. """
return (np.transpose((I[:, 0], I[:, 1], I[:, 2], I[:, 0] * I[:, 1], I[:, 0] * I[:, 2],
I[:, 1] * I[:, 2], I[:, 0] * I[:, 0], I[:, 1] * I[:, 1],
I[:, 2] * I[:, 2], I[:, 0] * I[:, 1] * I[:, 2],
np.repeat(1,np.shape(I)[0]))))
np.repeat(1, np.shape(I)[0]))))


def outOfGamutClipping(I):
""" Clips out-of-gamut pixels. """
I[I > 1] = 1 # any pixel is higher than 1, clip it to 1
I[I < 0] = 0 # any pixel is below 0, clip it to 0
I[I > 1] = 1 # any pixel is higher than 1, clip it to 1
I[I < 0] = 0 # any pixel is below 0, clip it to 0
return I


def im2double(im):
""" Returns a double image [0,1] of the uint8 im [0,255]. """
return cv2.normalize(im.astype('float'), None, 0.0, 1.0, cv2.NORM_MINMAX)
return cv2.normalize(im.astype('float'), None, 0.0, 1.0, cv2.NORM_MINMAX)