Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(layers): Migrate RandomChoice and RandomApply layers from keras-cv to keras #20752

6 changes: 6 additions & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,15 @@
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
RandAugment,
)
from keras.src.layers.preprocessing.image_preprocessing.random_apply import (
RandomApply,
)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
from keras.src.layers.preprocessing.image_preprocessing.random_choice import (
RandomChoice,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (
RandomColorDegeneration,
)
Expand Down
6 changes: 6 additions & 0 deletions keras/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,15 @@
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
RandAugment,
)
from keras.src.layers.preprocessing.image_preprocessing.random_apply import (
RandomApply,
)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
from keras.src.layers.preprocessing.image_preprocessing.random_choice import (
RandomChoice,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (
RandomColorDegeneration,
)
Expand Down
6 changes: 6 additions & 0 deletions keras/src/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,15 @@
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
RandAugment,
)
from keras.src.layers.preprocessing.image_preprocessing.random_apply import (
RandomApply,
)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
from keras.src.layers.preprocessing.image_preprocessing.random_choice import (
RandomChoice,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (
RandomColorDegeneration,
)
Expand Down
182 changes: 182 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/random_apply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.random.seed_generator import SeedGenerator


@keras_export("keras.layers.RandomApply")
class RandomApply(BaseImagePreprocessingLayer):
"""Preprocessing layer to randomly apply a specified layer during training.

This layer randomly applies a given transformation layer to inputs based on
the `rate` parameter. It is useful for stochastic data augmentation to
improve model robustness. At inference time, the output is identical to
the input. Call the layer with `training=True` to enable random application.

Args:
layer: A `keras.Layer` to apply. The layer must not modify input shape.
rate: Float between 0.0 and 1.0, representing the probability of
applying the layer. Defaults to 0.5.
batchwise: Boolean. If `True`, the decision to apply the layer is made
for the entire batch. If `False`, it is made independently for each
input. Defaults to `False`.
seed: Optional integer to ensure reproducibility.

Inputs: A tensor or dictionary of tensors. The input type must be compatible
with the wrapped layer.

Output: A tensor or dictionary of tensors matching the input structure, with
the transformation layer randomly applied according to a specified rate.
"""

def __init__(
self,
layer,
rate=0.5,
batchwise=False,
seed=None,
**kwargs,
):
super().__init__(**kwargs)
if not (0 <= rate <= 1.0):
raise ValueError(
f"rate must be in range [0, 1]. Received rate: {rate}"
)
self._layer = layer
self._rate = rate
self.batchwise = batchwise
self.seed = seed
self.generator = SeedGenerator(seed)
self.built = True

def _should_augment(self, batch_size=None):
if batch_size is None:
return (
self.backend.random.uniform(
shape=(),
seed=self._get_seed_generator(self.backend._backend),
)
> 1.0 - self._rate
)
else:
return (
self.backend.random.uniform(
shape=(batch_size,),
seed=self._get_seed_generator(self.backend._backend),
)
> 1.0 - self._rate
)

def _batch_augment(self, inputs):
if self.batchwise:
if self._should_augment():
return self._layer(inputs)
return inputs

batch_size = ops.shape(inputs)[0]
should_augment = self._should_augment(batch_size)
should_augment = ops.reshape(should_augment, (-1, 1, 1, 1))

augmented = self._layer(inputs)
return self.backend.numpy.where(should_augment, augmented, inputs)

def transform_images(self, images, transformation, training=True):
if not training or transformation is None:
return images

should_augment = transformation["should_augment"]
should_augment = ops.reshape(should_augment, (-1, 1, 1, 1))

if hasattr(self._layer, "get_random_transformation"):
layer_transform = self._layer.get_random_transformation(
images, training=training
)
augmented = self._layer.transform_images(
images, layer_transform, training=training
)
else:
augmented = self._layer(images)

return self.backend.numpy.where(should_augment, augmented, images)

def call(self, inputs, training=True):
if not training:
return inputs
if isinstance(inputs, dict):
result = {}
for key, value in inputs.items():
result[key] = self._batch_augment(value)
return result

return self._batch_augment(inputs)

def transform_labels(self, labels, transformation, training=True):
if not training or transformation is None:
return labels

should_augment = transformation["should_augment"]
should_augment = ops.reshape(should_augment, (-1, 1))

if hasattr(self._layer, "transform_labels"):
layer_transform = self._layer.get_random_transformation(
labels, training=training
)
augmented = self._layer.transform_labels(
labels, layer_transform, training=training
)
else:
augmented = self._layer(labels)

return self.backend.numpy.where(should_augment, augmented, labels)

def transform_bounding_boxes(self, bboxes, transformation, training=True):
if not training or transformation is None:
return bboxes

should_augment = transformation["should_augment"]

if hasattr(self._layer, "transform_bounding_boxes"):
layer_transform = self._layer.get_random_transformation(
bboxes, training=training
)
augmented = self._layer.transform_bounding_boxes(
bboxes, layer_transform, training=training
)
else:
augmented = self._layer(bboxes)

return self.backend.numpy.where(should_augment, augmented, bboxes)

def transform_segmentation_masks(
self, masks, transformation, training=True
):
if not training or transformation is None:
return masks

should_augment = transformation["should_augment"]

if hasattr(self._layer, "transform_segmentation_masks"):
layer_transform = self._layer.get_random_transformation(
masks, training=training
)
augmented = self._layer.transform_segmentation_masks(
masks, layer_transform, training=training
)
else:
augmented = self._layer(masks)

return self.backend.numpy.where(should_augment, augmented, masks)

def get_config(self):
config = super().get_config()
config.update(
{
"rate": self._rate,
"layer": self._layer,
"seed": self.seed,
"batchwise": self.batchwise,
}
)
return config
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import pytest
from absl.testing import parameterized

import keras.src.random as random
from keras.src import layers
from keras.src import ops
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.testing import TestCase


class ZeroOut(BaseImagePreprocessingLayer):
"""Layer that zeros out tensors."""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.built = True

def call(self, inputs):
return ops.zeros_like(inputs)

def transform_images(self, images, transformation=None, training=True):
return ops.zeros_like(images)

def transform_segmentation_masks(
self, masks, transformation=None, training=True
):
return ops.zeros_like(masks)

def transform_bounding_boxes(
self, bboxes, transformation=None, training=True
):
return ops.zeros_like(bboxes)

def transform_labels(self, labels, transformation=None, training=True):
return ops.zeros_like(labels)

def get_config(self):
return super().get_config()


# Utility function to count number of all-zero batches in the input.
def _num_zero_batches(images):
num_batches = ops.shape(images)[0]
flattened = ops.reshape(images, (num_batches, -1))
any_nonzero = ops.any(ops.not_equal(flattened, 0), axis=1)
num_non_zero_batches = ops.sum(ops.cast(any_nonzero, dtype="int32"))
return num_batches - num_non_zero_batches


class RandomApplyTest(TestCase):
@parameterized.parameters([-0.5, 1.7])
def test_raises_error_on_invalid_rate_parameter(self, invalid_rate):
with self.assertRaises(ValueError):
layers.RandomApply(rate=invalid_rate, layer=ZeroOut())

def test_works_with_batched_input(self):
batch_size = 32
dummy_inputs = random.uniform(shape=(batch_size, 224, 224, 3))
layer = layers.RandomApply(rate=0.5, layer=ZeroOut(), seed=1234)

outputs = layer(dummy_inputs)
num_zero_inputs = _num_zero_batches(dummy_inputs)
num_zero_outputs = _num_zero_batches(outputs)

self.assertEqual(num_zero_inputs, 0)
self.assertLess(num_zero_outputs, batch_size)
self.assertGreater(num_zero_outputs, 0)

def test_works_with_batchwise_layers(self):
batch_size = 32
dummy_inputs = random.uniform(shape=(batch_size, 224, 224, 3))
random_flip_layer = layers.RandomFlip(
"vertical", data_format="channels_last", seed=42
)
layer = layers.RandomApply(random_flip_layer, rate=0.5, batchwise=True)
outputs = layer(dummy_inputs)
self.assertEqual(outputs.shape, dummy_inputs.shape)

def test_inputs_unchanged_with_zero_rate(self):
dummy_inputs = random.uniform(shape=(32, 224, 224, 3))
layer = layers.RandomApply(rate=0.0, layer=ZeroOut())

outputs = layer(dummy_inputs)
self.assertAllClose(outputs, dummy_inputs)

def test_all_inputs_changed_with_rate_equal_to_one(self):
dummy_inputs = random.uniform(shape=(32, 224, 224, 3))
layer = layers.RandomApply(rate=1.0, layer=ZeroOut())
outputs = layer(dummy_inputs)
self.assertTrue(
ops.all(ops.equal(outputs, ops.zeros_like(dummy_inputs)))
)

def test_works_with_single_image(self):
dummy_inputs = random.uniform(shape=(224, 224, 3))
layer = layers.RandomApply(rate=1.0, layer=ZeroOut())
outputs = layer(dummy_inputs)
self.assertTrue(
ops.all(ops.equal(outputs, ops.zeros_like(dummy_inputs)))
)

def test_can_modify_label(self):
dummy_inputs = random.uniform(shape=(32, 224, 224, 3))
dummy_labels = ops.ones(shape=(32, 2))
layer = layers.RandomApply(rate=1.0, layer=ZeroOut())
outputs = layer({"images": dummy_inputs, "labels": dummy_labels})
self.assertTrue(
ops.all(ops.equal(outputs["labels"], ops.zeros_like(dummy_labels)))
)

@pytest.mark.skipif(
ops.backend.backend() != "tensorflow",
reason="XLA compilation is only supported with TensorFlow backend",
)
def test_works_with_xla(self):
dummy_inputs = random.uniform(shape=(32, 224, 224, 3))
layer = layers.RandomApply(rate=0.5, layer=ZeroOut())

def apply(x):
return layer(x)

outputs = apply(dummy_inputs)
apply(outputs)
Loading
Loading