diff --git a/tensorflow_addons/image/color_ops.py b/tensorflow_addons/image/color_ops.py index dc40f66611..dd9d4b47a6 100644 --- a/tensorflow_addons/image/color_ops.py +++ b/tensorflow_addons/image/color_ops.py @@ -14,12 +14,14 @@ # ============================================================================== """Color operations. equalize: Equalizes image histogram + sharpness: Sharpen image """ import tensorflow as tf -from tensorflow_addons.utils.types import TensorLike +from tensorflow_addons.utils.types import TensorLike, Number from tensorflow_addons.image.utils import to_4D_image, from_4D_image +from tensorflow_addons.image.compose_ops import blend from typing import Optional from functools import partial @@ -84,7 +86,7 @@ def equalize( (num_images, num_rows, num_columns, num_channels) (NHWC), or (num_images, num_channels, num_rows, num_columns) (NCHW), or (num_rows, num_columns, num_channels) (HWC), or - (num_channels, num_rows, num_columns) (HWC), or + (num_channels, num_rows, num_columns) (CHW), or (num_rows, num_columns) (HW). The rank must be statically known (the shape is not `TensorShape(None)`). data_format: Either 'channels_first' or 'channels_last' @@ -98,3 +100,55 @@ def equalize( fn = partial(equalize_image, data_format=data_format) image = tf.map_fn(fn, image) return from_4D_image(image, image_dims) + + +def sharpness_image(image: TensorLike, factor: Number) -> tf.Tensor: + """Implements Sharpness function from PIL using TF ops.""" + orig_image = image + image_dtype = image.dtype + # Make image 4D for conv operation. + image = tf.expand_dims(image, 0) + # SMOOTH PIL Kernel. + image = tf.cast(image, tf.float32) + kernel = ( + tf.constant( + [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, shape=[3, 3, 1, 1] + ) + / 13.0 + ) + # Tile across channel dimension. + kernel = tf.tile(kernel, [1, 1, 3, 1]) + strides = [1, 1, 1, 1] + degenerate = tf.nn.depthwise_conv2d( + image, kernel, strides, padding="VALID", dilations=[1, 1] + ) + degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) + degenerate = tf.squeeze(tf.cast(degenerate, image_dtype), [0]) + + # For the borders of the resulting image, fill in the values of the + # original image. + mask = tf.ones_like(degenerate) + padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]]) + padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]]) + result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image) + # Blend the final result. + blended = blend(result, orig_image, factor) + return tf.cast(blended, image_dtype) + + +def sharpness(image: TensorLike, factor: Number) -> tf.Tensor: + """Change sharpness of image(s) + + Args: + images: A tensor of shape + (num_images, num_rows, num_columns, num_channels) (NHWC), or + (num_rows, num_columns, num_channels) (HWC) + factor: A floating point value or Tensor above 0.0. + Returns: + Image(s) with the same type and shape as `images`, sharper. + """ + image_dims = tf.rank(image) + image = to_4D_image(image) + fn = partial(sharpness_image, factor=factor) + image = tf.map_fn(fn, image) + return from_4D_image(image, image_dims) diff --git a/tensorflow_addons/image/tests/color_ops_test.py b/tensorflow_addons/image/tests/color_ops_test.py index 780066c364..8484f97548 100644 --- a/tensorflow_addons/image/tests/color_ops_test.py +++ b/tensorflow_addons/image/tests/color_ops_test.py @@ -19,7 +19,7 @@ import numpy as np from tensorflow_addons.image import color_ops -from PIL import Image, ImageOps +from PIL import Image, ImageOps, ImageEnhance _DTYPES = { np.uint8, @@ -53,3 +53,24 @@ def test_equalize_channel_first(shape): image = tf.ones(shape=shape, dtype=tf.uint8) equalized = color_ops.equalize(image, "channels_first") np.testing.assert_equal(equalized.numpy(), image.numpy()) + + +@pytest.mark.parametrize("dtype", _DTYPES) +@pytest.mark.parametrize("shape", [(5, 5, 3), (10, 5, 5, 3)]) +def test_sharpness_dtype_shape(dtype, shape): + image = np.ones(shape=shape, dtype=dtype) + sharp = color_ops.sharpness(tf.constant(image), 0).numpy() + np.testing.assert_equal(sharp, image) + assert sharp.dtype == image.dtype + + +@pytest.mark.parametrize("factor", [0, 0.25, 0.5, 0.75, 1]) +def test_sharpness_with_PIL(factor): + np.random.seed(0) + image = np.random.randint(low=0, high=255, size=(10, 5, 5, 3), dtype=np.uint8) + sharpened = np.stack( + [ImageEnhance.Sharpness(Image.fromarray(i)).enhance(factor) for i in image] + ) + np.testing.assert_allclose( + color_ops.sharpness(tf.constant(image), factor).numpy(), sharpened, atol=1 + )