diff --git a/keras/api/_tf_keras/keras/layers/__init__.py b/keras/api/_tf_keras/keras/layers/__init__.py index 6ba8ce783084..dc207be0df72 100644 --- a/keras/api/_tf_keras/keras/layers/__init__.py +++ b/keras/api/_tf_keras/keras/layers/__init__.py @@ -155,9 +155,15 @@ from keras.src.layers.preprocessing.image_preprocessing.rand_augment import ( RandAugment, ) +from keras.src.layers.preprocessing.image_preprocessing.random_apply import ( + RandomApply, +) from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) +from keras.src.layers.preprocessing.image_preprocessing.random_choice import ( + RandomChoice, +) from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import ( RandomColorDegeneration, ) diff --git a/keras/api/layers/__init__.py b/keras/api/layers/__init__.py index 3aa267859c77..538b279662eb 100644 --- a/keras/api/layers/__init__.py +++ b/keras/api/layers/__init__.py @@ -155,9 +155,15 @@ from keras.src.layers.preprocessing.image_preprocessing.rand_augment import ( RandAugment, ) +from keras.src.layers.preprocessing.image_preprocessing.random_apply import ( + RandomApply, +) from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) +from keras.src.layers.preprocessing.image_preprocessing.random_choice import ( + RandomChoice, +) from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import ( RandomColorDegeneration, ) diff --git a/keras/src/layers/__init__.py b/keras/src/layers/__init__.py index 59f241cbaf23..257b78374560 100644 --- a/keras/src/layers/__init__.py +++ b/keras/src/layers/__init__.py @@ -99,9 +99,15 @@ from keras.src.layers.preprocessing.image_preprocessing.rand_augment import ( RandAugment, ) +from keras.src.layers.preprocessing.image_preprocessing.random_apply import ( + RandomApply, +) from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) +from keras.src.layers.preprocessing.image_preprocessing.random_choice import ( + RandomChoice, +) from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import ( RandomColorDegeneration, ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_apply.py b/keras/src/layers/preprocessing/image_preprocessing/random_apply.py new file mode 100644 index 000000000000..71f5c1fb26e9 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_apply.py @@ -0,0 +1,182 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random.seed_generator import SeedGenerator + + +@keras_export("keras.layers.RandomApply") +class RandomApply(BaseImagePreprocessingLayer): + """Preprocessing layer to randomly apply a specified layer during training. + + This layer randomly applies a given transformation layer to inputs based on + the `rate` parameter. It is useful for stochastic data augmentation to + improve model robustness. At inference time, the output is identical to + the input. Call the layer with `training=True` to enable random application. + + Args: + layer: A `keras.Layer` to apply. The layer must not modify input shape. + rate: Float between 0.0 and 1.0, representing the probability of + applying the layer. Defaults to 0.5. + batchwise: Boolean. If `True`, the decision to apply the layer is made + for the entire batch. If `False`, it is made independently for each + input. Defaults to `False`. + seed: Optional integer to ensure reproducibility. + + Inputs: A tensor or dictionary of tensors. The input type must be compatible + with the wrapped layer. + + Output: A tensor or dictionary of tensors matching the input structure, with + the transformation layer randomly applied according to a specified rate. + """ + + def __init__( + self, + layer, + rate=0.5, + batchwise=False, + seed=None, + **kwargs, + ): + super().__init__(**kwargs) + if not (0 <= rate <= 1.0): + raise ValueError( + f"rate must be in range [0, 1]. Received rate: {rate}" + ) + self._layer = layer + self._rate = rate + self.batchwise = batchwise + self.seed = seed + self.generator = SeedGenerator(seed) + self.built = True + + def _should_augment(self, batch_size=None): + if batch_size is None: + return ( + self.backend.random.uniform( + shape=(), + seed=self._get_seed_generator(self.backend._backend), + ) + > 1.0 - self._rate + ) + else: + return ( + self.backend.random.uniform( + shape=(batch_size,), + seed=self._get_seed_generator(self.backend._backend), + ) + > 1.0 - self._rate + ) + + def _batch_augment(self, inputs): + if self.batchwise: + if self._should_augment(): + return self._layer(inputs) + return inputs + + batch_size = ops.shape(inputs)[0] + should_augment = self._should_augment(batch_size) + should_augment = ops.reshape(should_augment, (-1, 1, 1, 1)) + + augmented = self._layer(inputs) + return self.backend.numpy.where(should_augment, augmented, inputs) + + def transform_images(self, images, transformation, training=True): + if not training or transformation is None: + return images + + should_augment = transformation["should_augment"] + should_augment = ops.reshape(should_augment, (-1, 1, 1, 1)) + + if hasattr(self._layer, "get_random_transformation"): + layer_transform = self._layer.get_random_transformation( + images, training=training + ) + augmented = self._layer.transform_images( + images, layer_transform, training=training + ) + else: + augmented = self._layer(images) + + return self.backend.numpy.where(should_augment, augmented, images) + + def call(self, inputs, training=True): + if not training: + return inputs + if isinstance(inputs, dict): + result = {} + for key, value in inputs.items(): + result[key] = self._batch_augment(value) + return result + + return self._batch_augment(inputs) + + def transform_labels(self, labels, transformation, training=True): + if not training or transformation is None: + return labels + + should_augment = transformation["should_augment"] + should_augment = ops.reshape(should_augment, (-1, 1)) + + if hasattr(self._layer, "transform_labels"): + layer_transform = self._layer.get_random_transformation( + labels, training=training + ) + augmented = self._layer.transform_labels( + labels, layer_transform, training=training + ) + else: + augmented = self._layer(labels) + + return self.backend.numpy.where(should_augment, augmented, labels) + + def transform_bounding_boxes(self, bboxes, transformation, training=True): + if not training or transformation is None: + return bboxes + + should_augment = transformation["should_augment"] + + if hasattr(self._layer, "transform_bounding_boxes"): + layer_transform = self._layer.get_random_transformation( + bboxes, training=training + ) + augmented = self._layer.transform_bounding_boxes( + bboxes, layer_transform, training=training + ) + else: + augmented = self._layer(bboxes) + + return self.backend.numpy.where(should_augment, augmented, bboxes) + + def transform_segmentation_masks( + self, masks, transformation, training=True + ): + if not training or transformation is None: + return masks + + should_augment = transformation["should_augment"] + + if hasattr(self._layer, "transform_segmentation_masks"): + layer_transform = self._layer.get_random_transformation( + masks, training=training + ) + augmented = self._layer.transform_segmentation_masks( + masks, layer_transform, training=training + ) + else: + augmented = self._layer(masks) + + return self.backend.numpy.where(should_augment, augmented, masks) + + def get_config(self): + config = super().get_config() + config.update( + { + "rate": self._rate, + "layer": self._layer, + "seed": self.seed, + "batchwise": self.batchwise, + } + ) + return config diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_apply_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_apply_test.py new file mode 100644 index 000000000000..7a980a30850a --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_apply_test.py @@ -0,0 +1,125 @@ +import pytest +from absl.testing import parameterized + +import keras.src.random as random +from keras.src import layers +from keras.src import ops +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.testing import TestCase + + +class ZeroOut(BaseImagePreprocessingLayer): + """Layer that zeros out tensors.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.built = True + + def call(self, inputs): + return ops.zeros_like(inputs) + + def transform_images(self, images, transformation=None, training=True): + return ops.zeros_like(images) + + def transform_segmentation_masks( + self, masks, transformation=None, training=True + ): + return ops.zeros_like(masks) + + def transform_bounding_boxes( + self, bboxes, transformation=None, training=True + ): + return ops.zeros_like(bboxes) + + def transform_labels(self, labels, transformation=None, training=True): + return ops.zeros_like(labels) + + def get_config(self): + return super().get_config() + + +# Utility function to count number of all-zero batches in the input. +def _num_zero_batches(images): + num_batches = ops.shape(images)[0] + flattened = ops.reshape(images, (num_batches, -1)) + any_nonzero = ops.any(ops.not_equal(flattened, 0), axis=1) + num_non_zero_batches = ops.sum(ops.cast(any_nonzero, dtype="int32")) + return num_batches - num_non_zero_batches + + +class RandomApplyTest(TestCase): + @parameterized.parameters([-0.5, 1.7]) + def test_raises_error_on_invalid_rate_parameter(self, invalid_rate): + with self.assertRaises(ValueError): + layers.RandomApply(rate=invalid_rate, layer=ZeroOut()) + + def test_works_with_batched_input(self): + batch_size = 32 + dummy_inputs = random.uniform(shape=(batch_size, 224, 224, 3)) + layer = layers.RandomApply(rate=0.5, layer=ZeroOut(), seed=1234) + + outputs = layer(dummy_inputs) + num_zero_inputs = _num_zero_batches(dummy_inputs) + num_zero_outputs = _num_zero_batches(outputs) + + self.assertEqual(num_zero_inputs, 0) + self.assertLess(num_zero_outputs, batch_size) + self.assertGreater(num_zero_outputs, 0) + + def test_works_with_batchwise_layers(self): + batch_size = 32 + dummy_inputs = random.uniform(shape=(batch_size, 224, 224, 3)) + random_flip_layer = layers.RandomFlip( + "vertical", data_format="channels_last", seed=42 + ) + layer = layers.RandomApply(random_flip_layer, rate=0.5, batchwise=True) + outputs = layer(dummy_inputs) + self.assertEqual(outputs.shape, dummy_inputs.shape) + + def test_inputs_unchanged_with_zero_rate(self): + dummy_inputs = random.uniform(shape=(32, 224, 224, 3)) + layer = layers.RandomApply(rate=0.0, layer=ZeroOut()) + + outputs = layer(dummy_inputs) + self.assertAllClose(outputs, dummy_inputs) + + def test_all_inputs_changed_with_rate_equal_to_one(self): + dummy_inputs = random.uniform(shape=(32, 224, 224, 3)) + layer = layers.RandomApply(rate=1.0, layer=ZeroOut()) + outputs = layer(dummy_inputs) + self.assertTrue( + ops.all(ops.equal(outputs, ops.zeros_like(dummy_inputs))) + ) + + def test_works_with_single_image(self): + dummy_inputs = random.uniform(shape=(224, 224, 3)) + layer = layers.RandomApply(rate=1.0, layer=ZeroOut()) + outputs = layer(dummy_inputs) + self.assertTrue( + ops.all(ops.equal(outputs, ops.zeros_like(dummy_inputs))) + ) + + def test_can_modify_label(self): + dummy_inputs = random.uniform(shape=(32, 224, 224, 3)) + dummy_labels = ops.ones(shape=(32, 2)) + layer = layers.RandomApply(rate=1.0, layer=ZeroOut()) + outputs = layer({"images": dummy_inputs, "labels": dummy_labels}) + self.assertTrue( + ops.all(ops.equal(outputs["labels"], ops.zeros_like(dummy_labels))) + ) + + @pytest.mark.skipif( + ops.backend.backend() != "tensorflow", + reason="XLA compilation is only supported with TensorFlow backend", + ) + def test_works_with_xla(self): + dummy_inputs = random.uniform(shape=(32, 224, 224, 3)) + layer = layers.RandomApply(rate=0.5, layer=ZeroOut()) + + def apply(x): + return layer(x) + + outputs = apply(dummy_inputs) + apply(outputs) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_choice.py b/keras/src/layers/preprocessing/image_preprocessing/random_choice.py new file mode 100644 index 000000000000..43801e4f92d9 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_choice.py @@ -0,0 +1,129 @@ +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random.seed_generator import SeedGenerator + + +@keras_export("keras.layers.RandomChoice") +class RandomChoice(BaseImagePreprocessingLayer): + """A preprocessing layer that randomly applies a layer from a list. + + Useful for creating randomized data augmentation pipelines. During + training, for each input (or batch of inputs), it randomly selects one layer + from the provided list and applies it to the input. This allows for diverse + augmentations to be applied dynamically. + + Args: + layers: A list of `keras.Layers` instances. Each layer should subclass + `BaseImagePreprocessingLayer`. During augmentation, one layer is + randomly selected and applied to the input. + batchwise: Boolean, whether to apply the same randomly selected layer to + the entire batch of inputs. If `True`, the entire batch is passed to + a single layer. If `False`, each input in the batch is processed by + an independently selected layer. Defaults to `False`. + seed: Integer to seed random number generator for reproducibility. + Defaults to `None`. + + Call Arguments: + inputs: Single image tensor (rank 3), batch of image tensors (rank 4), + or a dictionary of tensors. The input is augmented by one randomly + selected layer from the `layers` list. + + Returns: + Augmented inputs, with the same shape and structure as the input. + """ + + def __init__( + self, + layers, + batchwise=False, + seed=None, + **kwargs, + ): + super().__init__(**kwargs) + self.layers = layers + self.batchwise = batchwise + self.seed = seed + self.generator = SeedGenerator(seed) + self._convert_input_args = False + self._allow_non_tensor_positional_args = True + + def _curry_call_layer(self, inputs, training=True): + if not training: + return None + + input_shape = self.backend.shape(inputs) + + if self.batchwise: + selected_op = ops.floor( + self.backend.random.uniform( + shape=(), + minval=0, + maxval=len(self.layers), + dtype="float32", + seed=self._get_seed_generator(self.backend._backend), + ) + ) + else: + batch_size = input_shape[0] + selected_op = ops.floor( + self.backend.random.uniform( + shape=(batch_size,), + minval=0, + maxval=len(self.layers), + dtype="float32", + seed=self._get_seed_generator(self.backend._backend), + ) + ) + + ndims = len(input_shape) + ones = [1] * (ndims - 1) + broadcast_shape = tuple([batch_size] + ones) + selected_op = self.backend.numpy.reshape( + selected_op, broadcast_shape + ) + + return selected_op + + def _call_single(self, inputs): + selected_op = self._curry_call_layer(inputs, training=True) + output = self.backend.cast(inputs, self.compute_dtype) + + for i, layer in enumerate(self.layers): + condition = ops.equal(selected_op, float(i)) + if hasattr(layer, "get_random_transformation"): + layer_transform = layer.get_random_transformation( + inputs, + training=True, + seed=self._get_seed_generator(self.backend._backend), + ) + augmented = layer.transform_images( + inputs, layer_transform, training=True + ) + else: + augmented = layer(inputs) + output = self.backend.numpy.where(condition, augmented, output) + + return output + + def call(self, inputs): + if isinstance(inputs, dict): + return { + key: self._call_single(input_tensor) + for key, input_tensor in inputs.items() + } + else: + return self._call_single(inputs) + + def get_config(self): + config = super().get_config() + config.update( + { + "layers": self.layers, + "batchwise": self.batchwise, + "seed": self.seed, + } + ) + return config diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_choice_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_choice_test.py new file mode 100644 index 000000000000..70ac25afef42 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_choice_test.py @@ -0,0 +1,113 @@ +import keras.src.random as random +from keras.src import layers +from keras.src import ops +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.testing import TestCase + + +class AddOneToInputs(BaseImagePreprocessingLayer): + """Add 1 to all image values, for testing purposes.""" + + def __init__(self): + super().__init__() + self.call_counter = 0 + + def call(self, inputs): + self.call_counter += 1 + return inputs + 1 + + def transform_images(self, images, transformation=None, training=True): + return images + 1 + + def transform_labels(self, labels, transformation=None, training=True): + return labels + 1 + + def transform_bounding_boxes( + self, bboxes, transformation=None, training=True + ): + return bboxes + 1 + + def transform_segmentation_masks( + self, masks, transformation=None, training=True + ): + return masks + 1 + + +class RandomChoiceTest(TestCase): + def test_calls_layer_augmentation_per_image(self): + layer = AddOneToInputs() + pipeline = layers.RandomChoice(layers=[layer]) + xs = random.uniform( + shape=(2, 5, 5, 3), minval=0, maxval=100, dtype="float32" + ) + os = pipeline(xs) + + self.assertAllClose(xs + 1, os) + + def test_calls_layer_augmentation_eager(self): + layer = AddOneToInputs() + pipeline = layers.RandomChoice(layers=[layer]) + + def call_pipeline(xs): + return pipeline(xs) + + xs = random.uniform( + shape=(2, 5, 5, 3), minval=0, maxval=100, dtype="float32" + ) + os = call_pipeline(xs) + + self.assertAllClose(xs + 1, os) + + def test_batchwise(self): + layer = AddOneToInputs() + pipeline = layers.RandomChoice(layers=[layer], batchwise=True) + xs = random.uniform( + shape=(4, 5, 5, 3), minval=0, maxval=100, dtype="float32" + ) + os = pipeline(xs) + + self.assertAllClose(xs + 1, os) + ops.all(ops.equal(layer.call_counter, 1)) + + def test_works_with_random_flip(self): + pipeline = layers.RandomChoice( + layers=[ + layers.RandomFlip( + "vertical", data_format="channels_last", seed=42 + ) + ], + batchwise=True, + ) + xs = random.uniform( + shape=(4, 5, 5, 3), minval=0, maxval=100, dtype="float32" + ) + pipeline(xs) + + def test_calls_layer_augmentation_single_image(self): + layer = AddOneToInputs() + pipeline = layers.RandomChoice(layers=[layer]) + xs = random.uniform( + shape=(5, 5, 3), minval=0, maxval=100, dtype="float32" + ) + os = pipeline(xs) + + self.assertAllClose(xs + 1, os) + + def test_calls_choose_one_layer_augmentation(self): + batch_size = 10 + pipeline = layers.RandomChoice( + layers=[AddOneToInputs(), AddOneToInputs()] + ) + xs = random.uniform( + shape=(batch_size, 5, 5, 3), minval=0, maxval=100, dtype="float32" + ) + os = pipeline(xs) + + self.assertAllClose(xs + 1, os) + + total_calls = ( + pipeline.layers[0].call_counter + pipeline.layers[1].call_counter + ) + ops.all(ops.equal(total_calls, batch_size))