Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized speed of Python implementation #11

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 42 additions & 29 deletions WB_sRGB_Python/classes/WBsRGB.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@


import numpy as np
import numpy.matlib
import cv2


Expand All @@ -41,6 +40,10 @@ def __init__(self, gamut_mapping=2, upgraded=0):
# our results reported with gamut_mapping=2, however gamut_mapping=1
# gives more compelling results with over-saturated examples.
self.gamut_mapping = gamut_mapping # options: 1 scaling, 2 clipping
# precompute the norm of all features for later use
self.features_norm = np.einsum('ij, ij ->i', self.features,
self.features)[:, None]


def encode(self, hist):
""" Generates a compacted feature of a given RGB-uv histogram tensor."""
Expand All @@ -64,19 +67,16 @@ def rgb_uv_hist(self, I):
newH = int(np.floor(sz[0] * factor))
newW = int(np.floor(sz[1] * factor))
I = cv2.resize(I, (newW, newH), interpolation=cv2.INTER_NEAREST)
I_reshaped = I[(I > 0).all(axis=2)]
eps = 6.4 / self.h
hist = np.zeros((self.h, self.h, 3)) # histogram will be stored here
Iy = np.linalg.norm(I_reshaped, axis=1) # intensity vector
I_reshaped = I.reshape(-1,3).T.copy() # reshaped and transposed
I_reshaped = I_reshaped[:,(I_reshaped>0).all(0)].copy()
hist = np.zeros((self.h, self.h, 3), dtype=np.float32) # histogram will be stored here
Iy = np.linalg.norm(I_reshaped, axis=0) # intensity vector
I_reshaped_log = np.log(I_reshaped)
for i in range(3): # for each histogram layer, do
r = [] # excluded channels will be stored here
for j in range(3): # for each color channel do
if j != i:
r.append(j)
Iu = np.log(I_reshaped[:, i] / I_reshaped[:, r[1]])
Iv = np.log(I_reshaped[:, i] / I_reshaped[:, r[0]])
hist[:, :, i], _, _ = np.histogram2d(
Iu, Iv, bins=self.h, range=((-3.2 - eps / 2, 3.2 - eps / 2),) * 2, weights=Iy)
r = [j for j in range(3) if i!=j] # excluded channels
Iu = I_reshaped_log[i] - I_reshaped_log[r[1]]
Iv = I_reshaped_log[i] - I_reshaped_log[r[0]]
hist[:, :, i] = hist2d(Iv, Iu, Iy, (-3.2, 3.2), self.h)
norm_ = hist[:, :, i].sum()
hist[:, :, i] = np.sqrt(hist[:, :, i] / norm_) # (hist/norm)^(1/2)
return hist
Expand All @@ -92,30 +92,26 @@ def correctImage(self, I):
# feature_diff = self.features - feature
# D_sq = np.einsum('ij,ij->i', feature_diff, feature_diff)[:, None]
# ```
D_sq = np.einsum(
'ij, ij ->i', self.features, self.features)[:, None] + np.einsum(
D_sq = self.features_norm + np.einsum(
'ij, ij ->i', feature, feature) - 2 * self.features.dot(feature.T)

# get smallest K distances
idH = D_sq.argpartition(self.K, axis=0)[:self.K]
mappingFuncs = np.squeeze(self.mappingFuncs[idH, :])
dH = np.sqrt(
np.take_along_axis(D_sq, idH, axis=0))
dH = np.sqrt(np.take_along_axis(D_sq, idH, axis=0))
weightsH = np.exp(-(np.power(dH, 2)) /
(2 * np.power(self.sigma, 2))) # compute weights
weightsH = weightsH / sum(weightsH) # normalize blending weights
mf = sum(np.matlib.repmat(weightsH, 1, 33) *
mappingFuncs, 0) # compute the mapping function
mf = weightsH.T.dot(mappingFuncs) # compute the mapping function
mf = mf.reshape(11, 3, order="F") # reshape it to be 9 * 3
I_corr = self.colorCorrection(I, mf) # apply it!
return I_corr

def colorCorrection(self, input, m):
""" Applies a mapping function m to a given input image. """
sz = np.shape(input) # get size of input image
I_reshaped = np.reshape(input, (int(input.size / 3), 3), order="F")
I_reshaped = np.reshape(input, (-1, 3)).T # transposed for speed
kernel_out = kernelP(I_reshaped)
out = np.dot(kernel_out, m)
out = m.T.dot(kernel_out).T
if self.gamut_mapping == 1:
# scaling based on input image energy
out = normScaling(I_reshaped, out)
Expand All @@ -125,8 +121,8 @@ def colorCorrection(self, input, m):
else:
raise Exception('Wrong gamut_mapping value')
# reshape output image back to the original image shape
out = out.reshape(sz[0], sz[1], sz[2], order="F")
out = out.astype('float32')[..., ::-1] # convert from BGR to RGB
out = out.reshape(sz)
out = out[..., ::-1] # convert from BGR to RGB
return out


Expand All @@ -146,10 +142,16 @@ def kernelP(rgb):
Ref: Hong, et al., "A study of digital camera colorimetric
characterization based on polynomial modeling." Color Research &
Application, 2001. """
r, g, b = np.split(rgb, 3, axis=1)
return np.concatenate(
[rgb, r * g, r * b, g * b, rgb ** 2, r * g * b, np.ones_like(r)], axis=1)

r, g, b = (rgb[0], rgb[1], rgb[2])
out = np.empty((11, rgb.shape[1]), dtype=rgb.dtype)
out[:3, :] = rgb
out[3, :] = r*g
out[4, :] = r*b
out[5, :] = g*b
out[6:9, :] = rgb*rgb
out[9, :] = r*g*b
out[10, :] = np.ones_like(r)
return out

def outOfGamutClipping(I):
""" Clips out-of-gamut pixels. """
Expand All @@ -160,4 +162,15 @@ def outOfGamutClipping(I):

def im2double(im):
""" Returns a double image [0,1] of the uint8 im [0,255]. """
return cv2.normalize(im.astype('float'), None, 0.0, 1.0, cv2.NORM_MINMAX)
return cv2.normalize(im, None, 0.0, 1.0, cv2.NORM_MINMAX, cv2.CV_32F)


def hist2d(x, y, weight, limits, bins):
""" Computes a 2D histogram of values using only numpy"""
eps = (limits[1]-limits[0]) / bins
lower_lim = limits[0]-eps/2
y = np.floor((y-lower_lim)/eps).astype(np.int16)
x = np.floor((x-lower_lim)/eps).astype(np.int16)
valid = (0<=x)*(x<bins)*(0<=y)*(y<bins)
hist = np.bincount(y[valid]*bins+x[valid], weight[valid], bins**2)
return hist.reshape(bins, bins)