-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_processing.py
executable file
·132 lines (105 loc) · 5.46 KB
/
image_processing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from __future__ import absolute_import
from __future__ import division
import tensorflow as tf
def deprocess(image):
    """Rescale ``image`` from the [-1, 1] range to [0, 1].

    Computes ``(image + 1) / 2`` inside a ``deprocess`` name scope.
    """
    with tf.name_scope("deprocess"):
        shifted = image + 1
        return shifted / 2
def preprocess_lab(lab):
    """Split a Lab tensor into channels scaled to roughly [-1, 1].

    Args:
        lab: Lab tensor whose three channels are unstacked along axis 2
            (presumably a single ``(height, width, 3)`` image -- TODO confirm).

    Returns:
        A list ``[L, a, b]`` where L is mapped from [0, 100] to [-1, 1] and
        a/b are divided by 110. The a/b input range is only approximately
        [-110, 110], so their output is only approximately [-1, 1].
    """
    with tf.name_scope("preprocess_lab"):
        lightness, chan_a, chan_b = tf.unstack(lab, axis=2)
        # L: [0, 100] -> [-1, 1]
        lightness = lightness / 50 - 1
        # a/b: ~[-110, 110] -> ~[-1, 1]
        return [lightness, chan_a / 110, chan_b / 110]
def deprocess_lab(L_chan, a_chan, b_chan, axis=3):
    """Undo the ``preprocess_lab`` scaling and stack the channels into Lab.

    Args:
        L_chan: L channel in [-1, 1]; mapped back to [0, 100].
        a_chan: a channel scaled by 1/110; multiplied back by 110.
        b_chan: b channel scaled by 1/110; multiplied back by 110.
        axis: axis along which the three channels are stacked. The default
            of 3 matches batched channel tensors of shape (N, H, W), giving
            (N, H, W, 3). Pass ``axis=2`` for a single image's (H, W)
            channels.

    Returns:
        Tensor with the (L, a, b) channels stacked along ``axis``.
    """
    with tf.name_scope("deprocess_lab"):
        # Default is axis=3 rather than 2 because individual images are
        # preprocessed but batches are deprocessed; callers with unbatched
        # channels can override it.
        return tf.stack(
            [(L_chan + 1) / 2 * 100, a_chan * 110, b_chan * 110], axis=axis)
def augment(image, brightness):
    """Merge (a, b) colour channels with an L channel and convert to RGB.

    Args:
        image: batched tensor whose axis 3 holds exactly two channels, used
            as the scaled a and b Lab channels.
        brightness: batched tensor with a size-1 axis 3 holding the scaled
            L channel.

    Returns:
        The RGB tensor produced by ``lab_to_rgb`` after ``deprocess_lab``
        restores the Lab value ranges.
    """
    chan_a, chan_b = tf.unstack(image, axis=3)
    lightness = tf.squeeze(brightness, axis=3)
    return lab_to_rgb(deprocess_lab(lightness, chan_a, chan_b))
def check_image(image):
    """Validate an image tensor and pin its static channel dimension to 3.

    A runtime assertion checks that the last dimension equals 3. At graph
    construction time, a ``ValueError`` is raised when the static rank is
    unknown or is not 3 or 4. Finally the static last dimension is set to 3
    so callers can ``tf.unstack`` the colour channels.

    Args:
        image: tensor to validate.

    Returns:
        ``image`` passed through ``tf.identity`` under a control dependency
        on the channel assertion, with its static last dimension set to 3.

    Raises:
        ValueError: if the static rank is not 3 or 4.
    """
    channel_check = tf.assert_equal(
        tf.shape(image)[-1], 3, message="image must have 3 color channels")
    with tf.control_dependencies([channel_check]):
        image = tf.identity(image)

    rank = image.get_shape().ndims
    if rank not in (3, 4):
        raise ValueError("image must be either 3 or 4 dimensions")

    # Fix the static last dimension so the colours can be unstacked.
    static_shape = list(image.get_shape())
    static_shape[-1] = 3
    image.set_shape(static_shape)
    return image
# based on https://github.com/torch/image/blob/9f65c30167b2048ecbe8b7befdc6b2d6d12baee9/generic/image.c
def rgb_to_lab(srgb):
    """Convert an sRGB tensor to CIELAB using the D65 white point.

    Args:
        srgb: float32 tensor of rank 3 or 4 whose last dimension is 3
            (validated by ``check_image``). Values are presumably in [0, 1]
            -- TODO confirm with callers.

    Returns:
        Tensor with the same shape as ``srgb`` holding (L, a, b) per pixel.
    """
    with tf.name_scope("rgb_to_lab"):
        srgb = check_image(srgb)
        srgb_pixels = tf.reshape(srgb, [-1, 3])

        with tf.name_scope("srgb_to_xyz"):
            linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
            exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
            # Both branches are evaluated and blended with 0/1 masks, and
            # 0 * NaN is still NaN. Clamp the power branch's base to its own
            # domain: the values it actually selects are unchanged, but
            # inputs below -0.055 no longer yield NaN from a negative ** 2.4.
            exponential_base = tf.maximum(srgb_pixels, 0.04045)
            rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((exponential_base + 0.055) / 1.055) ** 2.4) * exponential_mask
            rgb_to_xyz = tf.constant([
                #    X        Y          Z
                [0.412453, 0.212671, 0.019334],  # R
                [0.357580, 0.715160, 0.119193],  # G
                [0.180423, 0.072169, 0.950227],  # B
            ])
            xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)

        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("xyz_to_cielab"):
            # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)

            # normalize for D65 white point
            xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])

            epsilon = 6/29
            linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
            exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
            # Clamp the cube-root base for the same reason as above. Slightly
            # negative XYZ (from slightly negative sRGB) would give NaN in the
            # forward pass, and x ** (1/3) at x == 0 (black pixels) has an
            # infinite derivative that turns into a NaN gradient when
            # multiplied by the zero mask.
            cube_root_base = tf.maximum(xyz_normalized_pixels, epsilon**3)
            fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (cube_root_base ** (1/3)) * exponential_mask

            # convert to lab
            fxfyfz_to_lab = tf.constant([
                #  l       a       b
                [  0.0,  500.0,    0.0],  # fx
                [116.0, -500.0,  200.0],  # fy
                [  0.0,    0.0, -200.0],  # fz
            ])
            lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])

        return tf.reshape(lab_pixels, tf.shape(srgb))
def lab_to_rgb(lab):
    """Convert a CIELAB tensor (D65 white point) to sRGB.

    Args:
        lab: float32 tensor of rank 3 or 4 whose last dimension holds
            (L, a, b) (validated by ``check_image``). L is presumably in
            [0, 100] -- TODO confirm with callers.

    Returns:
        Tensor with the same shape as ``lab`` holding sRGB values. Linear RGB
        is clipped to [0, 1] before gamma encoding.
    """
    with tf.name_scope("lab_to_rgb"):
        lab = check_image(lab)
        lab_pixels = tf.reshape(lab, [-1, 3])

        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("cielab_to_xyz"):
            # convert to fxfyfz
            lab_to_fxfyfz = tf.constant([
                #   fx         fy        fz
                [1/116.0, 1/116.0,  1/116.0],  # l
                [1/500.0,     0.0,      0.0],  # a
                [    0.0,     0.0, -1/200.0],  # b
            ])
            fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)

            # convert to xyz
            epsilon = 6/29
            linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
            exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
            xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask

            # denormalize for D65 white point
            xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])

        with tf.name_scope("xyz_to_srgb"):
            xyz_to_rgb = tf.constant([
                #     r           g          b
                [ 3.2404542, -0.9692660,  0.0556434],  # x
                [-1.5371385,  1.8760108, -0.2040259],  # y
                [-0.4985314,  0.0415560,  1.0572252],  # z
            ])
            rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
            # avoid a slightly negative number messing up the conversion
            rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
            linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
            exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
            # Both branches are evaluated and blended with 0/1 masks. At
            # rgb == 0, x ** (1/2.4) has an infinite derivative, which becomes
            # a NaN gradient once multiplied by the zero mask. Clamping the
            # base to the branch threshold leaves every selected value
            # unchanged and keeps the unselected branch finite.
            gamma_base = tf.maximum(rgb_pixels, 0.0031308)
            srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((gamma_base ** (1/2.4) * 1.055) - 0.055) * exponential_mask

        return tf.reshape(srgb_pixels, tf.shape(lab))