feat(layers): Migrate RandomChoice and RandomApply layers from keras-cv to keras #20752

Open · wants to merge 8 commits into base: master
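For context, here is a minimal usage sketch of the migrated `RandomApply` layer as defined in this diff; the `keras.layers.RandomApply` path follows the `@keras_export` decorator below, while the `zero_out` helper and the input shapes are illustrative assumptions, not part of the PR:

```python
import tensorflow as tf

import keras

# Illustrative inner layer: zeroes out whatever tensor it receives.
zero_out = keras.layers.Lambda(lambda x: tf.zeros_like(x))

# Wrap it so it is applied to each sample with probability `rate`
# (export path taken from the @keras_export decorator in this PR).
random_apply = keras.layers.RandomApply(layer=zero_out, rate=0.5, seed=1234)

images = tf.random.uniform(shape=(8, 32, 32, 3))
outputs = random_apply(images)  # each sample is zeroed with probability 0.5
```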
6 changes: 6 additions & 0 deletions keras/src/layers/__init__.py
@@ -148,6 +148,12 @@
from keras.src.layers.preprocessing.image_preprocessing.solarization import (
Solarization,
)
from keras.src.layers.preprocessing.image_preprocessing.random_choice import (
RandomChoice,
)
from keras.src.layers.preprocessing.image_preprocessing.random_apply import (
RandomApply,
)
from keras.src.layers.preprocessing.index_lookup import IndexLookup
from keras.src.layers.preprocessing.integer_lookup import IntegerLookup
from keras.src.layers.preprocessing.mel_spectrogram import MelSpectrogram
keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py
@@ -5,15 +5,20 @@
densify_bounding_boxes,
)
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer

import tensorflow as tf

class BaseImagePreprocessingLayer(TFDataLayer):
_USE_BASE_FACTOR = True
_FACTOR_BOUNDS = (-1, 1)

def __init__(
self, factor=None, bounding_box_format=None, data_format=None, **kwargs
self, factor=None, bounding_box_format=None, data_format=None, seed=None, **kwargs
):
        super().__init__(**kwargs)
        self.seed = seed
        if seed is None:
            self.random_generator = (
                tf.random.Generator.from_non_deterministic_state()
            )
        else:
            self.random_generator = tf.random.Generator.from_seed(seed)
self.bounding_box_format = bounding_box_format
self.data_format = backend_config.standardize_data_format(data_format)
192 changes: 192 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/random_apply.py
@@ -0,0 +1,192 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501
    BaseImagePreprocessingLayer,
)


@keras_export("keras.layers.RandomApply")
class RandomApply(BaseImagePreprocessingLayer):
"""Apply provided layer to random elements in a batch.

Args:
        layer: a keras `Layer` or `BaseImagePreprocessingLayer`. This layer
            will be applied to randomly chosen samples in a batch. The layer
            should not modify the shape of its inputs.
        rate: controls the frequency with which the layer is applied. `1.0`
            means every sample in a batch is modified and `0.0` means none
            are. Defaults to `0.5`.
        batchwise: (Optional) bool, whether to pass entire batches to the
            underlying layer. When set to `True`, a single random draw
            determines whether the whole batch is passed to the underlying
            layer, as shown at the end of the example below.
        auto_vectorize: bool, whether to use `tf.vectorized_map` or
            `tf.map_fn` for batched input. Setting this to `True` might give
            better performance, but it currently doesn't work with XLA.
            Defaults to `False`.
        seed: integer, used to seed the random number generator.

Example:
```
    # Let's declare an example layer that will zero out all image pixels.
    zero_out = keras.layers.Lambda(lambda x: 0 * x)

# Create a small batch of random, single-channel, 2x2 images:
images = tf.random.stateless_uniform(shape=(5, 2, 2, 1), seed=[0, 1])
print(images[..., 0])
# <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy=
# array([[[0.08216608, 0.40928006],
# [0.39318466, 0.3162533 ]],
#
# [[0.34717774, 0.73199546],
# [0.56369007, 0.9769211 ]],
#
# [[0.55243933, 0.13101244],
# [0.2941643 , 0.5130266 ]],
#
# [[0.38977218, 0.80855536],
# [0.6040567 , 0.10502195]],
#
# [[0.51828027, 0.12730157],
# [0.288486 , 0.252975 ]]], dtype=float32)>

# Apply the layer with 50% probability:
random_apply = RandomApply(layer=zero_out, rate=0.5, seed=1234)
outputs = random_apply(images)
print(outputs[..., 0])
# <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy=
# array([[[0. , 0. ],
# [0. , 0. ]],
#
# [[0.34717774, 0.73199546],
# [0.56369007, 0.9769211 ]],
#
# [[0.55243933, 0.13101244],
# [0.2941643 , 0.5130266 ]],
#
# [[0.38977218, 0.80855536],
# [0.6040567 , 0.10502195]],
#
# [[0. , 0. ],
# [0. , 0. ]]], dtype=float32)>

    # We can observe that the layer has been randomly applied to 2 out of
    # the 5 samples.
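    # With `batchwise=True`, a single random draw decides whether the
    # underlying layer is applied to the whole batch, so either all 5
    # images are zeroed out or none are (illustrative extension of the
    # example above):
    batchwise_apply = RandomApply(layer=zero_out, rate=0.5, batchwise=True)
    batchwise_outputs = batchwise_apply(images)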
```
"""
def __init__(
self,
layer,
rate=0.5,
batchwise=False,
auto_vectorize=False,
seed=None,
**kwargs,
):
super().__init__(seed=seed, **kwargs)
        if not (0 <= rate <= 1.0):
            raise ValueError(
                f"rate must be in range [0, 1]. Received rate: {rate}"
            )
self._layer = layer
self._rate = rate
self.auto_vectorize = auto_vectorize
self.batchwise = batchwise
self.seed = seed
self.built = True

def _get_should_augment(self, inputs):
input_shape = tf.shape(inputs)

if self.batchwise:
return self._rate > tf.random.uniform(shape=(), seed=self.seed)

batch_size = input_shape[0]
random_values = tf.random.uniform(shape=(batch_size,), seed=self.seed)
should_augment = random_values < self._rate

ndims = tf.rank(inputs)
broadcast_shape = tf.concat(
[input_shape[:1], tf.ones(ndims - 1, dtype=tf.int32)],
axis=0
)
return tf.reshape(should_augment, broadcast_shape)

def _augment_single(self, inputs):
random_value = tf.random.uniform(shape=(), seed=self.seed)
should_augment = random_value < self._rate

def apply_layer():
return self._layer(inputs)

def return_inputs():
return inputs

return tf.cond(should_augment, apply_layer, return_inputs)

def _augment_batch(self, inputs):
should_augment = self._get_should_augment(inputs)
augmented = self._layer(inputs)
return tf.where(should_augment, augmented, inputs)

def call(self, inputs):
        if isinstance(inputs, dict):
            return {
                key: self._call_single(input_tensor)
                for key, input_tensor in inputs.items()
            }
        else:
            return self._call_single(inputs)

    def _call_single(self, inputs):
        # Rank-4 inputs are a batch of images and get per-sample draws;
        # anything else (single images, labels, etc.) gets a single
        # accept/reject draw for the whole tensor.
        is_batch = tf.equal(tf.rank(inputs), 4)

        def augment_single():
            return self._augment_single(inputs)

        def augment_batch():
            return self._augment_batch(inputs)

        return tf.cond(is_batch, augment_batch, augment_single)

def transform_images(self, images, transformation=None, training=True):
if not training:
return images
return self.call(images)

def transform_labels(self, labels, transformation=None, training=True):
if not training:
return labels
return self.call(labels)

    def transform_bounding_boxes(
        self, bboxes, transformation=None, training=True
    ):
if not training:
return bboxes
return self.call(bboxes)

    def transform_segmentation_masks(
        self, masks, transformation=None, training=True
    ):
if not training:
return masks
return self.call(masks)

def get_config(self):
config = super().get_config()
config.update({
"rate": self._rate,
"layer": self._layer,
"seed": self.seed,
"batchwise": self.batchwise,
"auto_vectorize": self.auto_vectorize,
})
return config
@@ -0,0 +1,124 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import tensorflow as tf
from absl.testing import parameterized

from keras.src import layers
from keras.src import ops
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.testing import TestCase


class ZeroOut(BaseImagePreprocessingLayer):
"""Layer that zeros out tensors."""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.built = True

def call(self, inputs):
return tf.zeros_like(inputs)

def transform_images(self, images, transformation=None, training=True):
return tf.zeros_like(images)

    def transform_segmentation_masks(
        self, masks, transformation=None, training=True
    ):
        return tf.zeros_like(masks)

    def transform_bounding_boxes(
        self, bboxes, transformation=None, training=True
    ):
        return tf.zeros_like(bboxes)

def transform_labels(self, labels, transformation=None, training=True):
return tf.zeros_like(labels)

def get_config(self):
return super().get_config()


class RandomApplyTest(TestCase):
rng = tf.random.Generator.from_seed(seed=1234)

@parameterized.parameters([-0.5, 1.7])
def test_raises_error_on_invalid_rate_parameter(self, invalid_rate):
with self.assertRaises(ValueError):
layers.RandomApply(rate=invalid_rate, layer=ZeroOut())

def test_works_with_batched_input(self):
batch_size = 32
dummy_inputs = self.rng.uniform(shape=(batch_size, 224, 224, 3))
layer = layers.RandomApply(rate=0.5, layer=ZeroOut(), seed=1234)

outputs = ops.convert_to_numpy(layer(dummy_inputs))
num_zero_inputs = self._num_zero_batches(dummy_inputs)
num_zero_outputs = self._num_zero_batches(outputs)

self.assertEqual(num_zero_inputs, 0)
self.assertLess(num_zero_outputs, batch_size)
self.assertGreater(num_zero_outputs, 0)

def test_works_with_batchwise_layers(self):
batch_size = 32
dummy_inputs = self.rng.uniform(shape=(batch_size, 224, 224, 3))
        random_flip_layer = layers.RandomFlip(
            "vertical", data_format="channels_last", seed=42
        )
layer = layers.RandomApply(random_flip_layer, rate=0.5, batchwise=True)
outputs = layer(dummy_inputs)
self.assertEqual(outputs.shape, dummy_inputs.shape)

@staticmethod
def _num_zero_batches(images):
num_batches = tf.shape(images)[0]
num_non_zero_batches = tf.math.count_nonzero(
tf.math.count_nonzero(images, axis=[1, 2, 3]), dtype=tf.int32
)
return num_batches - num_non_zero_batches

def test_inputs_unchanged_with_zero_rate(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
layer = layers.RandomApply(rate=0.0, layer=ZeroOut())

outputs = layer(dummy_inputs)

self.assertAllClose(outputs, dummy_inputs)

def test_all_inputs_changed_with_rate_equal_to_one(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
layer = layers.RandomApply(rate=1.0, layer=ZeroOut())
outputs = layer(dummy_inputs)
        self.assertAllClose(outputs, tf.zeros_like(dummy_inputs))

def test_works_with_single_image(self):
dummy_inputs = self.rng.uniform(shape=(224, 224, 3))
layer = layers.RandomApply(rate=1.0, layer=ZeroOut())
outputs = layer(dummy_inputs)
        self.assertAllClose(outputs, tf.zeros_like(dummy_inputs))

def test_can_modify_label(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
dummy_labels = tf.ones(shape=(32, 2))
layer = layers.RandomApply(rate=1.0, layer=ZeroOut())
outputs = layer({"images": dummy_inputs, "labels": dummy_labels})
        self.assertAllClose(outputs["labels"], tf.zeros_like(dummy_labels))

def test_works_with_xla(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
        layer = layers.RandomApply(
            rate=0.5, layer=ZeroOut(), auto_vectorize=False
        )

@tf.function(jit_compile=True)
def apply(x):
return layer(x)

        outputs = apply(dummy_inputs)
        self.assertEqual(outputs.shape, dummy_inputs.shape)