Skip to content

Commit 3fc1bad

Browse files
author
NoobMaster
authored
Add equalize image op (#1413)
1 parent 0df86d7 commit 3fc1bad

File tree

5 files changed

+231
-60
lines changed

5 files changed

+231
-60
lines changed

.github/CODEOWNERS

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
/tensorflow_addons/callbacks/time_stopping*.py @shun-lin
3737
/tensorflow_addons/callbacks/tqdm_progress_bar*.py @shun-lin
3838

39+
/tensorflow_addons/image/color_ops*.py @abhichou4
3940
/tensorflow_addons/image/connected_components*.py @sayoojbk
4041
/tensorflow_addons/image/cutout_ops*.py @fsx950223
4142
/tensorflow_addons/image/dense_image_warp*.py @windQAQ

tensorflow_addons/image/BUILD

+65-52
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,20 @@ py_library(
66
name = "image",
77
srcs = ([
88
"__init__.py",
9+
"color_ops.py",
10+
"compose_ops.py",
11+
"connected_components.py",
12+
"cutout_ops.py",
913
"dense_image_warp.py",
1014
"distance_transform.py",
1115
"distort_image_ops.py",
1216
"filters.py",
17+
"interpolate_spline.py",
18+
"resampler_ops.py",
19+
"sparse_image_warp.py",
1320
"transform_ops.py",
1421
"translate_ops.py",
1522
"utils.py",
16-
"sparse_image_warp.py",
17-
"interpolate_spline.py",
18-
"connected_components.py",
19-
"resampler_ops.py",
20-
"compose_ops.py",
21-
"cutout_ops.py",
2223
]),
2324
data = [
2425
":sparse_image_warp_test_data",
@@ -35,157 +36,169 @@ filegroup(
3536
)
3637

3738
py_test(
38-
name = "dense_image_warp_test",
39+
name = "color_ops_test",
3940
size = "small",
4041
srcs = [
41-
"dense_image_warp_test.py",
42+
"color_ops_test.py",
4243
],
43-
main = "dense_image_warp_test.py",
44+
main = "color_ops_test.py",
4445
deps = [
4546
":image",
4647
],
4748
)
4849

4950
py_test(
50-
name = "distance_transform_ops_test",
51-
size = "small",
51+
name = "compose_ops_test",
52+
size = "medium",
5253
srcs = [
53-
"distance_transform_test.py",
54+
"compose_ops_test.py",
5455
],
55-
main = "distance_transform_test.py",
56+
main = "compose_ops_test.py",
5657
deps = [
5758
":image",
5859
],
5960
)
6061

6162
py_test(
62-
name = "distort_image_ops_test",
63-
size = "small",
63+
name = "connected_components_test",
64+
size = "medium",
6465
srcs = [
65-
"distort_image_ops_test.py",
66+
"connected_components_test.py",
6667
],
67-
main = "distort_image_ops_test.py",
68+
main = "connected_components_test.py",
6869
deps = [
6970
":image",
7071
],
7172
)
7273

7374
py_test(
74-
name = "filters_test",
75-
size = "medium",
75+
name = "cutout_ops_test",
76+
size = "small",
7677
srcs = [
77-
"filters_test.py",
78+
"cutout_ops_test.py",
7879
],
79-
flaky = True,
80-
main = "filters_test.py",
80+
main = "cutout_ops_test.py",
8181
deps = [
8282
":image",
8383
],
8484
)
8585

8686
py_test(
87-
name = "transform_ops_test",
88-
size = "medium",
87+
name = "dense_image_warp_test",
88+
size = "small",
8989
srcs = [
90-
"transform_ops_test.py",
90+
"dense_image_warp_test.py",
9191
],
92-
main = "transform_ops_test.py",
92+
main = "dense_image_warp_test.py",
9393
deps = [
9494
":image",
9595
],
9696
)
9797

9898
py_test(
99-
name = "translate_ops_test",
100-
size = "medium",
99+
name = "distance_transform_ops_test",
100+
size = "small",
101101
srcs = [
102-
"translate_ops_test.py",
102+
"distance_transform_test.py",
103103
],
104-
main = "translate_ops_test.py",
104+
main = "distance_transform_test.py",
105105
deps = [
106106
":image",
107107
],
108108
)
109109

110110
py_test(
111-
name = "utils_test",
111+
name = "distort_image_ops_test",
112112
size = "small",
113113
srcs = [
114-
"utils_test.py",
114+
"distort_image_ops_test.py",
115115
],
116-
main = "utils_test.py",
116+
main = "distort_image_ops_test.py",
117117
deps = [
118118
":image",
119119
],
120120
)
121121

122122
py_test(
123-
name = "cutout_ops_test",
124-
size = "small",
123+
name = "filters_test",
124+
size = "medium",
125125
srcs = [
126-
"cutout_ops_test.py",
126+
"filters_test.py",
127127
],
128-
main = "cutout_ops_test.py",
128+
flaky = True,
129+
main = "filters_test.py",
129130
deps = [
130131
":image",
131132
],
132133
)
133134

134135
py_test(
135-
name = "sparse_image_warp_test",
136+
name = "interpolate_spline_test",
136137
size = "medium",
137138
srcs = [
138-
"sparse_image_warp_test.py",
139+
"interpolate_spline_test.py",
139140
],
140-
main = "sparse_image_warp_test.py",
141+
main = "interpolate_spline_test.py",
141142
deps = [
142143
":image",
143144
],
144145
)
145146

146147
py_test(
147-
name = "interpolate_spline_test",
148+
name = "resampler_ops_test",
148149
size = "medium",
149150
srcs = [
150-
"interpolate_spline_test.py",
151+
"resampler_ops_test.py",
151152
],
152-
main = "interpolate_spline_test.py",
153+
main = "resampler_ops_test.py",
153154
deps = [
154155
":image",
155156
],
156157
)
157158

158159
py_test(
159-
name = "connected_components_test",
160+
name = "sparse_image_warp_test",
160161
size = "medium",
161162
srcs = [
162-
"connected_components_test.py",
163+
"sparse_image_warp_test.py",
163164
],
164-
main = "connected_components_test.py",
165+
main = "sparse_image_warp_test.py",
165166
deps = [
166167
":image",
167168
],
168169
)
169170

170171
py_test(
171-
name = "resampler_ops_test",
172+
name = "transform_ops_test",
172173
size = "medium",
173174
srcs = [
174-
"resampler_ops_test.py",
175+
"transform_ops_test.py",
175176
],
176-
main = "resampler_ops_test.py",
177+
main = "transform_ops_test.py",
177178
deps = [
178179
":image",
179180
],
180181
)
181182

182183
py_test(
183-
name = "compose_ops_test",
184+
name = "translate_ops_test",
184185
size = "medium",
185186
srcs = [
186-
"compose_ops_test.py",
187+
"translate_ops_test.py",
187188
],
188-
main = "compose_ops_test.py",
189+
main = "translate_ops_test.py",
190+
deps = [
191+
":image",
192+
],
193+
)
194+
195+
py_test(
196+
name = "utils_test",
197+
size = "small",
198+
srcs = [
199+
"utils_test.py",
200+
],
201+
main = "utils_test.py",
189202
deps = [
190203
":image",
191204
],

tensorflow_addons/image/__init__.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,24 @@
1414
# ==============================================================================
1515
"""Additional image manipulation ops."""
1616

17+
from tensorflow_addons.image.distort_image_ops import adjust_hsv_in_yiq
18+
from tensorflow_addons.image.compose_ops import blend
19+
from tensorflow_addons.image.color_ops import equalize
1720
from tensorflow_addons.image.connected_components import connected_components
21+
from tensorflow_addons.image.cutout_ops import cutout
1822
from tensorflow_addons.image.dense_image_warp import dense_image_warp
19-
from tensorflow_addons.image.dense_image_warp import interpolate_bilinear
2023
from tensorflow_addons.image.distance_transform import euclidean_dist_transform
21-
from tensorflow_addons.image.distort_image_ops import adjust_hsv_in_yiq
22-
from tensorflow_addons.image.distort_image_ops import random_hsv_in_yiq
24+
from tensorflow_addons.image.dense_image_warp import interpolate_bilinear
25+
from tensorflow_addons.image.interpolate_spline import interpolate_spline
2326
from tensorflow_addons.image.filters import mean_filter2d
2427
from tensorflow_addons.image.filters import median_filter2d
25-
from tensorflow_addons.image.interpolate_spline import interpolate_spline
28+
from tensorflow_addons.image.cutout_ops import random_cutout
29+
from tensorflow_addons.image.distort_image_ops import random_hsv_in_yiq
2630
from tensorflow_addons.image.resampler_ops import resampler
27-
from tensorflow_addons.image.sparse_image_warp import sparse_image_warp
2831
from tensorflow_addons.image.transform_ops import rotate
29-
from tensorflow_addons.image.transform_ops import transform
3032
from tensorflow_addons.image.transform_ops import shear_x
3133
from tensorflow_addons.image.transform_ops import shear_y
34+
from tensorflow_addons.image.sparse_image_warp import sparse_image_warp
35+
from tensorflow_addons.image.transform_ops import transform
3236
from tensorflow_addons.image.translate_ops import translate
3337
from tensorflow_addons.image.translate_ops import translate_xy
34-
from tensorflow_addons.image.compose_ops import blend
35-
from tensorflow_addons.image.cutout_ops import random_cutout, cutout

tensorflow_addons/image/color_ops.py

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Color operations.
16+
equalize: Equalizes image histogram
17+
"""
18+
19+
import tensorflow as tf
20+
21+
from tensorflow_addons.utils.types import TensorLike
22+
from tensorflow_addons.image.utils import to_4D_image, from_4D_image
23+
24+
from typing import Optional
25+
from functools import partial
26+
27+
28+
def equalize_image(image: TensorLike, data_format: str = "channels_last") -> tf.Tensor:
29+
"""Implements Equalize function from PIL using TF ops."""
30+
31+
def scale_channel(image, channel):
32+
"""Scale the data in the channel to implement equalize."""
33+
image_dtype = image.dtype
34+
35+
if data_format == "channels_last":
36+
image = tf.cast(image[:, :, channel], tf.int32)
37+
elif data_format == "channels_first":
38+
image = tf.cast(image[channel], tf.int32)
39+
else:
40+
raise ValueError(
41+
"data_format can either be channels_last or channels_first"
42+
)
43+
# Compute the histogram of the image channel.
44+
histo = tf.histogram_fixed_width(image, [0, 255], nbins=256)
45+
46+
# For the purposes of computing the step, filter out the nonzeros.
47+
nonzero = tf.where(tf.not_equal(histo, 0))
48+
nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1])
49+
step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255
50+
51+
def build_lut(histo, step):
52+
# Compute the cumulative sum, shifting by step // 2
53+
# and then normalization by step.
54+
lut = (tf.cumsum(histo) + (step // 2)) // step
55+
# Shift lut, prepending with 0.
56+
lut = tf.concat([[0], lut[:-1]], 0)
57+
# Clip the counts to be in range. This is done
58+
# in the C code for image.point.
59+
return tf.clip_by_value(lut, 0, 255)
60+
61+
# If step is zero, return the original image. Otherwise, build
62+
# lut from the full histogram and step and then index from it.
63+
64+
if step == 0:
65+
result = image
66+
else:
67+
result = tf.gather(build_lut(histo, step), image)
68+
69+
return tf.cast(result, image_dtype)
70+
71+
idx = 2 if data_format == "channels_last" else 0
72+
image = tf.stack([scale_channel(image, c) for c in range(image.shape[idx])], idx)
73+
74+
return image
75+
76+
77+
def equalize(
78+
image: TensorLike, data_format: str = "channels_last", name: Optional[str] = None
79+
) -> tf.Tensor:
80+
"""Equalize image(s)
81+
82+
Args:
83+
images: A tensor of shape
84+
(num_images, num_rows, num_columns, num_channels) (NHWC), or
85+
(num_images, num_channels, num_rows, num_columns) (NCHW), or
86+
(num_rows, num_columns, num_channels) (HWC), or
87+
(num_channels, num_rows, num_columns) (HWC), or
88+
(num_rows, num_columns) (HW). The rank must be statically known (the
89+
shape is not `TensorShape(None)`).
90+
data_format: Either 'channels_first' or 'channels_last'
91+
name: The name of the op.
92+
Returns:
93+
Image(s) with the same type and shape as `images`, equalized.
94+
"""
95+
with tf.name_scope(name or "equalize"):
96+
image_dims = tf.rank(image)
97+
image = to_4D_image(image)
98+
fn = partial(equalize_image, data_format=data_format)
99+
image = tf.map_fn(fn, image)
100+
return from_4D_image(image, image_dims)

0 commit comments

Comments
 (0)