forked from keras-team/keras-cv
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduces MaybeApply layer. (keras-team#435)
* Added MaybeApply layer. * Changed MaybeApply to override _augment method. * Added seed to maybe_apply_test random generator. * Added seed to layer in batched input test. * Fixed MaybeApply docs.
- Loading branch information
1 parent
a6ece4a
commit 2c9ae8a
Showing
4 changed files
with
243 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright 2022 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import tensorflow as tf | ||
|
||
|
||
@tf.keras.utils.register_keras_serializable(package="keras_cv") | ||
class MaybeApply(tf.keras.__internal__.layers.BaseImageAugmentationLayer): | ||
"""Apply provided layer to random elements in a batch. | ||
Args: | ||
layer: a keras `Layer` or `BaseImageAugmentationLayer`. This layer will be | ||
applied to randomly chosen samples in a batch. Layer should not modify the | ||
size of provided inputs. | ||
rate: controls the frequency of applying the layer. 1.0 means all elements in | ||
a batch will be modified. 0.0 means no elements will be modified. | ||
Defaults to 0.5. | ||
auto_vectorize: bool, whether to use tf.vectorized_map or tf.map_fn for | ||
batched input. Setting this to True might give better performance but | ||
currently doesn't work with XLA. Defaults to False. | ||
seed: integer, controls random behaviour. | ||
Example usage: | ||
``` | ||
# Let's declare an example layer that will set all image pixels to zero. | ||
zero_out = tf.keras.layers.Lambda(lambda x: {"images": 0 * x["images"]}) | ||
# Create a small batch of random, single-channel, 2x2 images: | ||
images = tf.random.stateless_uniform(shape=(5, 2, 2, 1), seed=[0, 1]) | ||
print(images[..., 0]) | ||
# <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy= | ||
# array([[[0.08216608, 0.40928006], | ||
# [0.39318466, 0.3162533 ]], | ||
# | ||
# [[0.34717774, 0.73199546], | ||
# [0.56369007, 0.9769211 ]], | ||
# | ||
# [[0.55243933, 0.13101244], | ||
# [0.2941643 , 0.5130266 ]], | ||
# | ||
# [[0.38977218, 0.80855536], | ||
# [0.6040567 , 0.10502195]], | ||
# | ||
# [[0.51828027, 0.12730157], | ||
# [0.288486 , 0.252975 ]]], dtype=float32)> | ||
# Apply the layer with 50% probability: | ||
maybe_apply = MaybeApply(layer=zero_out, rate=0.5, seed=1234) | ||
outputs = maybe_apply(images) | ||
print(outputs[..., 0]) | ||
# <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy= | ||
# array([[[0. , 0. ], | ||
# [0. , 0. ]], | ||
# | ||
# [[0.34717774, 0.73199546], | ||
# [0.56369007, 0.9769211 ]], | ||
# | ||
# [[0.55243933, 0.13101244], | ||
# [0.2941643 , 0.5130266 ]], | ||
# | ||
# [[0.38977218, 0.80855536], | ||
# [0.6040567 , 0.10502195]], | ||
# | ||
# [[0. , 0. ], | ||
# [0. , 0. ]]], dtype=float32)> | ||
# We can observe that the layer has been randomly applied to 2 out of 5 samples. | ||
``` | ||
""" | ||
|
||
def __init__(self, layer, rate=0.5, auto_vectorize=False, seed=None, **kwargs): | ||
super().__init__(seed=seed, **kwargs) | ||
|
||
if not (0 <= rate <= 1.0): | ||
raise ValueError(f"rate must be in range [0, 1]. Received rate: {rate}") | ||
|
||
self._layer = layer | ||
self._rate = rate | ||
self.auto_vectorize = auto_vectorize | ||
self.seed = seed | ||
|
||
def _augment(self, inputs): | ||
if self._random_generator.random_uniform(shape=()) > 1.0 - self._rate: | ||
return self._layer(inputs) | ||
else: | ||
return inputs | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"rate": self._rate, | ||
"layer": self._layer, | ||
"seed": self.seed, | ||
"auto_vectorize": self.auto_vectorize, | ||
} | ||
) | ||
return config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright 2022 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import tensorflow as tf | ||
from absl.testing import parameterized | ||
|
||
from keras_cv.layers.preprocessing.maybe_apply import MaybeApply | ||
|
||
|
||
class ZeroOut(tf.keras.__internal__.layers.BaseImageAugmentationLayer): | ||
"""Zero out all entries, for testing purposes.""" | ||
|
||
def __init__(self): | ||
super(ZeroOut, self).__init__() | ||
|
||
def augment_image(self, image, transformation=None): | ||
return 0 * image | ||
|
||
def augment_label(self, label, transformation=None): | ||
return 0 * label | ||
|
||
def augment_bounding_box(self, bounding_box, transformation=None): | ||
return 0 * bounding_box | ||
|
||
|
||
class MaybeApplyTest(tf.test.TestCase, parameterized.TestCase): | ||
rng = tf.random.Generator.from_seed(seed=1234) | ||
|
||
@parameterized.parameters([-0.5, 1.7]) | ||
def test_raises_error_on_invalid_rate_parameter(self, invalid_rate): | ||
with self.assertRaises(ValueError): | ||
MaybeApply(rate=invalid_rate, layer=ZeroOut()) | ||
|
||
def test_works_with_batched_input(self): | ||
batch_size = 32 | ||
dummy_inputs = self.rng.uniform(shape=(batch_size, 224, 224, 3)) | ||
layer = MaybeApply(rate=0.5, layer=ZeroOut(), seed=1234) | ||
|
||
outputs = layer(dummy_inputs) | ||
num_zero_inputs = self._num_zero_batches(dummy_inputs) | ||
num_zero_outputs = self._num_zero_batches(outputs) | ||
|
||
self.assertEqual(num_zero_inputs, 0) | ||
self.assertLess(num_zero_outputs, batch_size) | ||
self.assertGreater(num_zero_outputs, 0) | ||
|
||
@staticmethod | ||
def _num_zero_batches(images): | ||
num_batches = tf.shape(images)[0] | ||
num_non_zero_batches = tf.math.count_nonzero( | ||
tf.math.count_nonzero(images, axis=[1, 2, 3]), dtype=tf.int32 | ||
) | ||
return num_batches - num_non_zero_batches | ||
|
||
def test_inputs_unchanged_with_zero_rate(self): | ||
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
layer = MaybeApply(rate=0.0, layer=ZeroOut()) | ||
|
||
outputs = layer(dummy_inputs) | ||
|
||
self.assertAllClose(outputs, dummy_inputs) | ||
|
||
def test_all_inputs_changed_with_rate_equal_to_one(self): | ||
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
layer = MaybeApply(rate=1.0, layer=ZeroOut()) | ||
|
||
outputs = layer(dummy_inputs) | ||
|
||
self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs)) | ||
|
||
def test_works_with_single_image(self): | ||
dummy_inputs = self.rng.uniform(shape=(224, 224, 3)) | ||
layer = MaybeApply(rate=1.0, layer=ZeroOut()) | ||
|
||
outputs = layer(dummy_inputs) | ||
|
||
self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs)) | ||
|
||
def test_can_modify_label(self): | ||
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
dummy_labels = tf.ones(shape=(32, 2)) | ||
layer = MaybeApply(rate=1.0, layer=ZeroOut()) | ||
|
||
outputs = layer({"images": dummy_inputs, "labels": dummy_labels}) | ||
|
||
self.assertAllEqual(outputs["labels"], tf.zeros_like(dummy_labels)) | ||
|
||
def test_can_modify_bounding_box(self): | ||
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
dummy_boxes = tf.ones(shape=(32, 4)) | ||
layer = MaybeApply(rate=1.0, layer=ZeroOut()) | ||
|
||
outputs = layer({"images": dummy_inputs, "bounding_boxes": dummy_boxes}) | ||
|
||
self.assertAllEqual(outputs["bounding_boxes"], tf.zeros_like(dummy_boxes)) | ||
|
||
def test_works_with_native_keras_layers(self): | ||
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
zero_out = tf.keras.layers.Lambda(lambda x: {"images": 0 * x["images"]}) | ||
layer = MaybeApply(rate=1.0, layer=zero_out) | ||
|
||
outputs = layer(dummy_inputs) | ||
|
||
self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs)) | ||
|
||
def test_works_with_xla(self): | ||
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
# auto_vectorize=True will crash XLA | ||
layer = MaybeApply(rate=0.5, layer=ZeroOut(), auto_vectorize=False) | ||
|
||
@tf.function(jit_compile=True) | ||
def apply(x): | ||
return layer(x) | ||
|
||
apply(dummy_inputs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters