diff --git a/keras_core/layers/preprocessing/random_crop.py b/keras_core/layers/preprocessing/random_crop.py
index 4d1653de3..2da6240f9 100644
--- a/keras_core/layers/preprocessing/random_crop.py
+++ b/keras_core/layers/preprocessing/random_crop.py
@@ -1,14 +1,14 @@
-import numpy as np
-
 from keras_core import backend
+from keras_core import ops
 from keras_core.api_export import keras_core_export
-from keras_core.layers.layer import Layer
+from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer
+from keras_core.random.seed_generator import SeedGenerator
 from keras_core.utils import backend_utils
-from keras_core.utils.module_utils import tensorflow as tf
+from keras_core.utils import image_utils
 
 
 @keras_core_export("keras_core.layers.RandomCrop")
-class RandomCrop(Layer):
+class RandomCrop(TFDataLayer):
     """A preprocessing layer which randomly crops images during training.
 
     During training, this layer will randomly choose a location to crop images
@@ -52,41 +52,123 @@ class RandomCrop(Layer):
             `name` and `dtype`.
     """
 
-    def __init__(self, height, width, seed=None, name=None, **kwargs):
-        if not tf.available:
-            raise ImportError(
-                "Layer RandomCrop requires TensorFlow. "
-                "Install it via `pip install tensorflow`."
-            )
-
+    def __init__(
+        self, height, width, seed=None, data_format=None, name=None, **kwargs
+    ):
         super().__init__(name=name, **kwargs)
+        self.height = height
+        self.width = width
         self.seed = seed or backend.random.make_default_seed()
-        self.layer = tf.keras.layers.RandomCrop(
-            height=height,
-            width=width,
-            seed=self.seed,
-            name=name,
-        )
+        self.generator = SeedGenerator(seed)
+        self.data_format = backend.standardize_data_format(data_format)
+
+        if self.data_format == "channels_first":
+            self.height_axis = -2
+            self.width_axis = -1
+        elif self.data_format == "channels_last":
+            self.height_axis = -3
+            self.width_axis = -2
+
         self.supports_masking = False
         self.supports_jit = False
         self._convert_input_args = False
         self._allow_non_tensor_positional_args = True
 
     def call(self, inputs, training=True):
-        if not isinstance(inputs, (tf.Tensor, np.ndarray, list, tuple)):
-            inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))
-        outputs = self.layer.call(inputs, training=training)
-        if (
-            backend.backend() != "tensorflow"
-            and not backend_utils.in_tf_graph()
-        ):
-            outputs = backend.convert_to_tensor(outputs)
+        inputs = self.backend.cast(inputs, self.compute_dtype)
+        input_shape = self.backend.shape(inputs)
+        is_batched = len(input_shape) > 3
+        inputs = (
+            self.backend.numpy.expand_dims(inputs, axis=0)
+            if not is_batched
+            else inputs
+        )
+
+        h_diff = input_shape[self.height_axis] - self.height
+        w_diff = input_shape[self.width_axis] - self.width
+
+        def random_crop():
+            input_height, input_width = (
+                input_shape[self.height_axis],
+                input_shape[self.width_axis],
+            )
+
+            seed_generator = self._get_seed_generator(self.backend._backend)
+            h_start = self.backend.cast(
+                ops.random.uniform(
+                    (),
+                    0,
+                    maxval=float(input_height - self.height + 1),
+                    seed=seed_generator,
+                ),
+                "int32",
+            )
+            w_start = self.backend.cast(
+                ops.random.uniform(
+                    (),
+                    0,
+                    maxval=float(input_width - self.width + 1),
+                    seed=seed_generator,
+                ),
+                "int32",
+            )
+            if self.data_format == "channels_last":
+                return inputs[
+                    :,
+                    h_start : h_start + self.height,
+                    w_start : w_start + self.width,
+                ]
+            else:
+                return inputs[
+                    :,
+                    :,
+                    h_start : h_start + self.height,
+                    w_start : w_start + self.width,
+                ]
+
+        def resize():
+            outputs = image_utils.smart_resize(
+                inputs,
+                [self.height, self.width],
+                data_format=self.data_format,
+                backend_module=self.backend,
+            )
+            # smart_resize will always output float32, so we need to re-cast.
+            return self.backend.cast(outputs, self.compute_dtype)
+
+        outputs = self.backend.cond(
+            self.backend.numpy.all((training, h_diff >= 0, w_diff >= 0)),
+            random_crop,
+            resize,
+        )
+
+        outputs = (
+            self.backend.numpy.squeeze(outputs, axis=0)
+            if not is_batched
+            else outputs
+        )
+
+        if (
+            self.backend._backend != "tensorflow"
+            and not backend_utils.in_tf_graph()
+        ):
+            outputs = self.backend.convert_to_tensor(outputs)
         return outputs
 
-    def compute_output_shape(self, input_shape):
-        return tuple(self.layer.compute_output_shape(input_shape))
+    def compute_output_shape(self, input_shape, *args, **kwargs):
+        input_shape = list(input_shape)
+        input_shape[self.height_axis] = self.height
+        input_shape[self.width_axis] = self.width
+        return tuple(input_shape)
 
     def get_config(self):
-        config = self.layer.get_config()
-        config.update({"seed": self.seed})
+        config = super().get_config()
+        config.update(
+            {
+                "height": self.height,
+                "width": self.width,
+                "seed": self.seed,
+                "data_format": self.data_format,
+            }
+        )
         return config
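For reviewers, a minimal usage sketch of the refactored layer (not part of the diff; the example shapes, seed value, and the tf.data interaction are assumptions based on the TFDataLayer base class used above, not tests from this PR):

    import numpy as np
    import tensorflow as tf
    from keras_core import layers

    # Eager call with the active backend; channels_last is the default data_format.
    layer = layers.RandomCrop(height=8, width=8, seed=1337)
    images = np.random.uniform(size=(2, 32, 32, 3)).astype("float32")
    cropped = layer(images, training=True)  # expected shape: (2, 8, 8, 3)

    # The same layer instance should also be usable inside a tf.data pipeline,
    # which is the motivation for subclassing TFDataLayer instead of wrapping
    # tf.keras.layers.RandomCrop.
    ds = tf.data.Dataset.from_tensor_slices(images).batch(2)
    ds = ds.map(lambda x: layer(x, training=True))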