From 3ed601404da1779a0e8ddd8827af838539189918 Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 14 Jan 2025 23:22:18 +0900 Subject: [PATCH 01/11] Initial commit - first couple of layers --- keras_hub/src/models/detr/detr_backbone.py | 0 keras_hub/src/models/detr/detr_layers.py | 105 +++++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 keras_hub/src/models/detr/detr_backbone.py create mode 100644 keras_hub/src/models/detr/detr_layers.py diff --git a/keras_hub/src/models/detr/detr_backbone.py b/keras_hub/src/models/detr/detr_backbone.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/detr/detr_layers.py b/keras_hub/src/models/detr/detr_layers.py new file mode 100644 index 0000000000..47ad5438fe --- /dev/null +++ b/keras_hub/src/models/detr/detr_layers.py @@ -0,0 +1,105 @@ +import math + +from keras import Layer +from keras import ops + + +class DetrFrozenBatchNormalization(Layer): + """BatchNormalization with fixed affine + batch stats. + Based on https://github.com/facebookresearch/detr/blob/master/models/backbone.py. + """ + + def __init__(self, num_features, epsilon=1e-5, **kwargs): + super().__init__(**kwargs) + self.num_features = num_features + self.epsilon = epsilon + + def build(self): + self.weight = self.add_weight( + shape=(self.num_features,), + initializer="ones", + trainable=False, + name="weight", + ) + self.bias = self.add_weight( + shape=(self.num_features,), + initializer="zeros", + trainable=False, + name="bias", + ) + self.running_mean = self.add_weight( + shape=(self.num_features,), + initializer="zeros", + trainable=False, + name="running_mean", + ) + self.running_var = self.add_weight( + shape=(self.num_features,), + initializer="ones", + trainable=False, + name="running_var", + ) + + def call(self, inputs): + weight = ops.reshape(self.weight, (1, 1, 1, -1)) + bias = ops.reshape(self.bias, (1, 1, 1, -1)) + running_mean = ops.reshape(self.running_mean, (1, 1, 1, -1)) + running_var = ops.reshape(self.running_var, (1, 1, 1, -1)) + + scale = weight * ops.rsqrt(running_var + self.epsilon) + bias = bias - running_mean * scale + return inputs * scale + bias + + def get_config(self): + config = super().get_config() + config.update( + {"num_features": self.num_features, "epsilon": self.epsilon} + ) + return config + + +class DetrSinePositionEmbedding(Layer): + def __init__( + self, embedding_dim=64, temperature=10000, normalize=False, scale=None + ): + super().__init__() + self.embedding_dim = embedding_dim + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def call(self, pixel_mask): + if pixel_mask is None: + raise ValueError("No pixel mask provided") + y_embed = ops.cumsum(pixel_mask, axis=1) + x_embed = ops.cumsum(pixel_mask, axis=2) + if self.normalize: + y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale + + dim_t = ops.arange(self.embedding_dim) + dim_t = self.temperature ** ( + 2 * ops.floor(dim_t / 2) / self.embedding_dim + ) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = ops.stack( + (ops.sin(pos_x[:, :, :, ::2]), ops.cos(pos_x[:, :, :, 1::2])), + axis=4, + ) + pos_y = ops.stack( + (ops.sin(pos_y[:, :, :, ::2]), ops.cos(pos_y[:, :, :, 1::2])), + axis=4, + ) + + pos_x = ops.flatten(pos_x, axis=3) + 
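# --- Editor's note (not part of the patch): a minimal NumPy sketch of the
# arithmetic DetrFrozenBatchNormalization performs above. The statistics are
# constants, so the layer reduces to a fixed per-channel affine transform;
# the shapes and values below are toy assumptions.
import numpy as np

def frozen_batch_norm(x, gamma, beta, running_mean, running_var, epsilon=1e-5):
    # Fold the frozen statistics into a single scale/bias, as the layer does.
    scale = gamma / np.sqrt(running_var + epsilon)
    bias = beta - running_mean * scale
    return x * scale + bias

x = np.random.rand(2, 8, 8, 4).astype("float32")  # channels-last toy input
y = frozen_batch_norm(
    x,
    gamma=np.ones(4, "float32"),
    beta=np.zeros(4, "float32"),
    running_mean=np.zeros(4, "float32"),
    running_var=np.ones(4, "float32"),
)
assert np.allclose(y, x)  # near-identity with all-ones/zeros statistics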
pos_y = ops.flatten(pos_y, axis=3) + + pos = ops.cat((pos_y, pos_x), axis=3) + pos = ops.transpose(pos, [0, 3, 1, 2]) + return pos From b103664896a6e69ea99d14ca74008cc0acede7e2 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sun, 19 Jan 2025 20:58:14 +0900 Subject: [PATCH 02/11] Add transformer layers - trimmed down --- keras_hub/src/models/detr/detr_layers.py | 577 ++++++++++++++++++++++- 1 file changed, 571 insertions(+), 6 deletions(-) diff --git a/keras_hub/src/models/detr/detr_layers.py b/keras_hub/src/models/detr/detr_layers.py index 47ad5438fe..9f4af8091f 100644 --- a/keras_hub/src/models/detr/detr_layers.py +++ b/keras_hub/src/models/detr/detr_layers.py @@ -1,6 +1,8 @@ import math from keras import Layer +from keras import activations +from keras import layers from keras import ops @@ -75,13 +77,14 @@ def __init__( def call(self, pixel_mask): if pixel_mask is None: raise ValueError("No pixel mask provided") - y_embed = ops.cumsum(pixel_mask, axis=1) - x_embed = ops.cumsum(pixel_mask, axis=2) + + y_embed = ops.cumsum(pixel_mask, axis=1, dtype="float32") + x_embed = ops.cumsum(pixel_mask, axis=2, dtype="float32") if self.normalize: y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale - dim_t = ops.arange(self.embedding_dim) + dim_t = ops.arange(self.embedding_dim, dtype="float32") dim_t = self.temperature ** ( 2 * ops.floor(dim_t / 2) / self.embedding_dim ) @@ -97,9 +100,571 @@ def call(self, pixel_mask): axis=4, ) - pos_x = ops.flatten(pos_x, axis=3) - pos_y = ops.flatten(pos_y, axis=3) + pos_x = ops.reshape( + pos_x, [pos_x.shape[0], pos_x.shape[1], pos_x.shape[2], -1] + ) + pos_y = ops.reshape( + pos_y, [pos_y.shape[0], pos_y.shape[1], pos_y.shape[2], -1] + ) - pos = ops.cat((pos_y, pos_x), axis=3) + pos = ops.concatenate((pos_y, pos_x), axis=3) pos = ops.transpose(pos, [0, 3, 1, 2]) return pos + + +class DetrTransformerEncoder(layers.Layer): + """ + Adapted from + https://github.com/tensorflow/models/blob/master/official/projects/detr/modeling/transformer.py + """ + + def __init__( + self, + num_layers=6, + num_attention_heads=8, + intermediate_size=2048, + activation="relu", + dropout_rate=0.0, + attention_dropout_rate=0.0, + use_bias=False, + norm_first=True, + norm_epsilon=1e-6, + intermediate_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self._intermediate_size = intermediate_size + self._activation = activation + self._dropout_rate = dropout_rate + self._attention_dropout_rate = attention_dropout_rate + self._use_bias = use_bias + self._norm_first = norm_first + self._norm_epsilon = norm_epsilon + self._intermediate_dropout = intermediate_dropout + + def build(self, input_shape): + self.encoder_layers = [] + for i in range(self.num_layers): + self.encoder_layers.append( + DetrTransformerEncoderBlock( + num_attention_heads=self.num_attention_heads, + inner_dim=self._intermediate_size, + inner_activation=self._activation, + output_dropout=self._dropout_rate, + attention_dropout=self._attention_dropout_rate, + use_bias=self._use_bias, + norm_first=self._norm_first, + norm_epsilon=self._norm_epsilon, + inner_dropout=self._intermediate_dropout, + ) + ) + self.output_normalization = layers.LayerNormalization( + epsilon=self._norm_epsilon, dtype="float32" + ) + super().build(input_shape) + + def get_config(self): + config = { + "num_layers": self.num_layers, + "num_attention_heads": self.num_attention_heads, + 
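# --- Editor's note (not part of the patch): the stack + reshape in the updated
# embedding above interleaves sine and cosine channels; it replaces
# `ops.flatten`, which keras.ops does not provide. Toy demonstration on a
# (1, 1, 1, 4) tensor:
from keras import ops

pos = ops.convert_to_tensor([[[[0.0, 1.0, 2.0, 3.0]]]])
interleaved = ops.stack(
    (ops.sin(pos[:, :, :, ::2]), ops.cos(pos[:, :, :, 1::2])), axis=4
)
interleaved = ops.reshape(interleaved, (1, 1, 1, -1))
# Channels are now [sin(0), cos(1), sin(2), cos(3)].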
"intermediate_size": self._intermediate_size, + "activation": self._activation, + "dropout_rate": self._dropout_rate, + "attention_dropout_rate": self._attention_dropout_rate, + "use_bias": self._use_bias, + "norm_first": self._norm_first, + "norm_epsilon": self._norm_epsilon, + "intermediate_dropout": self._intermediate_dropout, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + def call(self, encoder_inputs, attention_mask=None, pos_embed=None): + for layer_idx in range(self.num_layers): + encoder_inputs = self.encoder_layers[layer_idx]( + [encoder_inputs, attention_mask, pos_embed] + ) + + output_tensor = encoder_inputs + output_tensor = self.output_normalization(output_tensor) + + return output_tensor + + +class DetrTransformerEncoderBlock(layers.Layer): + """ + Adapted from + https://github.com/tensorflow/models/blob/master/official/projects/detr/modeling/transformer.py + """ + + def __init__( + self, + num_attention_heads, + inner_dim, + inner_activation, + output_range=None, + use_bias=True, + norm_first=False, + norm_epsilon=1e-12, + output_dropout=0.0, + attention_dropout=0.0, + inner_dropout=0.0, + attention_axes=None, + **kwargs, + ): + super().__init__(**kwargs) + + self._num_heads = num_attention_heads + self._inner_dim = inner_dim + self._inner_activation = inner_activation + self._attention_dropout = attention_dropout + self._attention_dropout_rate = attention_dropout + self._output_dropout = output_dropout + self._output_dropout_rate = output_dropout + self._output_range = output_range + self._use_bias = use_bias + self._norm_first = norm_first + self._norm_epsilon = norm_epsilon + self._inner_dropout = inner_dropout + self._attention_axes = attention_axes + + def build(self, input_shape): + einsum_equation = "abc,cd->abd" + if len(len(input_shape)) > 3: + einsum_equation = "...bc,cd->...bd" + + hidden_size = input_shape[-1] + if hidden_size % self._num_heads != 0: + raise ValueError( + "The input size (%d) is not a multiple of " + "the number of attention heads (%d)" + % (hidden_size, self._num_heads) + ) + self._attention_head_size = int(hidden_size // self._num_heads) + + self._attention_layer = layers.MultiHeadAttention( + num_heads=self._num_heads, + key_dim=self._attention_head_size, + dropout=self._attention_dropout, + use_bias=self._use_bias, + attention_axes=self._attention_axes, + name="self_attention", + ) + self._attention_dropout = layers.Dropout(rate=self._output_dropout) + self._attention_layer_norm = layers.LayerNormalization( + name="self_attention_layer_norm", + axis=-1, + epsilon=self._norm_epsilon, + dtype="float32", + ) + self._intermediate_dense = layers.EinsumDense( + einsum_equation, + output_shape=(None, self._inner_dim), + bias_axes="d", + name="intermediate", + ) + + self._intermediate_activation_layer = layers.Activation( + self._inner_activation + ) + self._inner_dropout_layer = layers.Dropout(rate=self._inner_dropout) + self._output_dense = layers.EinsumDense( + einsum_equation, + output_shape=(None, hidden_size), + bias_axes="d", + name="output", + ) + self._output_dropout = layers.Dropout(rate=self._output_dropout) + self._output_layer_norm = layers.LayerNormalization( + name="output_layer_norm", + axis=-1, + epsilon=self._norm_epsilon, + dtype="float32", + ) + + super().build(input_shape) + + def get_config(self): + config = { + "num_attention_heads": self._num_heads, + "inner_dim": self._inner_dim, + "inner_activation": self._inner_activation, + "output_dropout": 
self._output_dropout_rate, + "attention_dropout": self._attention_dropout_rate, + "output_range": self._output_range, + "use_bias": self._use_bias, + "norm_first": self._norm_first, + "norm_epsilon": self._norm_epsilon, + "inner_dropout": self._inner_dropout, + "attention_axes": self._attention_axes, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + def call(self, inputs): + input_tensor, attention_mask, pos_embed = inputs + + key_value = None + + if self._output_range: + if self._norm_first: + source_tensor = input_tensor[:, 0 : self._output_range, :] + input_tensor = self._attention_layer_norm(input_tensor) + if key_value is not None: + key_value = self._attention_layer_norm(key_value) + target_tensor = input_tensor[:, 0 : self._output_range, :] + if attention_mask is not None: + attention_mask = attention_mask[:, 0 : self._output_range, :] + else: + if self._norm_first: + source_tensor = input_tensor + input_tensor = self._attention_layer_norm(input_tensor) + if key_value is not None: + key_value = self._attention_layer_norm(key_value) + target_tensor = input_tensor + + if key_value is None: + key_value = input_tensor + attention_output = self._attention_layer( + query=target_tensor + pos_embed, + key=key_value + pos_embed, + value=key_value, + attention_mask=attention_mask, + ) + attention_output = self._attention_dropout(attention_output) + if self._norm_first: + attention_output = source_tensor + attention_output + else: + attention_output = self._attention_layer_norm( + target_tensor + attention_output + ) + if self._norm_first: + source_attention_output = attention_output + attention_output = self._output_layer_norm(attention_output) + inner_output = self._intermediate_dense(attention_output) + inner_output = self._intermediate_activation_layer(inner_output) + inner_output = self._inner_dropout_layer(inner_output) + layer_output = self._output_dense(inner_output) + layer_output = self._output_dropout(layer_output) + + if self._norm_first: + return source_attention_output + layer_output + + return self._output_layer_norm(layer_output + attention_output) + + +class DetrTransformerDecoder(layers.Layer): + """ + Adapted from + https://github.com/tensorflow/models/blob/master/official/projects/detr/modeling/transformer.py + """ + + def __init__( + self, + num_layers=6, + num_attention_heads=8, + intermediate_size=2048, + activation="relu", + dropout_rate=0.0, + attention_dropout_rate=0.0, + use_bias=False, + norm_first=True, + norm_epsilon=1e-6, + intermediate_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self._intermediate_size = intermediate_size + self._activation = activation + self._dropout_rate = dropout_rate + self._attention_dropout_rate = attention_dropout_rate + self._use_bias = use_bias + self._norm_first = norm_first + self._norm_epsilon = norm_epsilon + self._intermediate_dropout = intermediate_dropout + + def build(self, input_shape): + self.decoder_layers = [] + for i in range(self.num_layers): + self.decoder_layers.append( + DetrTransformerDecoderBlock( + num_attention_heads=self.num_attention_heads, + intermediate_size=self._intermediate_size, + intermediate_activation=self._activation, + dropout_rate=self._dropout_rate, + attention_dropout_rate=self._attention_dropout_rate, + use_bias=self._use_bias, + norm_first=self._norm_first, + norm_epsilon=self._norm_epsilon, + intermediate_dropout=self._intermediate_dropout, + 
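# --- Editor's note (not part of the patch): `norm_first` above toggles between
# pre-norm and post-norm residual blocks. The reference DETR transformer is
# post-norm, which is why DETRTransformer later in this series constructs the
# encoder and decoder with norm_first=False. Schematic of the two orderings for
# a single sub-layer (sublayer = attention or FFN); helper names are illustrative:
def post_norm_block(x, sublayer, layer_norm):
    # Residual first, then normalization (norm_first=False).
    return layer_norm(x + sublayer(x))

def pre_norm_block(x, sublayer, layer_norm):
    # Normalization first, residual added outside (norm_first=True).
    return x + sublayer(layer_norm(x))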
name=("layer_%d" % i), + ) + ) + self.output_normalization = layers.LayerNormalization( + epsilon=self._norm_epsilon, dtype="float32" + ) + super().build(input_shape) + + def get_config(self): + config = { + "num_layers": self.num_layers, + "num_attention_heads": self.num_attention_heads, + "intermediate_size": self._intermediate_size, + "activation": self._activation, + "dropout_rate": self._dropout_rate, + "attention_dropout_rate": self._attention_dropout_rate, + "use_bias": self._use_bias, + "norm_first": self._norm_first, + "norm_epsilon": self._norm_epsilon, + "intermediate_dropout": self._intermediate_dropout, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + def call( + self, + target, + memory, + self_attention_mask=None, + cross_attention_mask=None, + cache=None, + decode_loop_step=None, + return_all_decoder_outputs=False, + input_pos_embed=None, + memory_pos_embed=None, + ): + output_tensor = target + decoder_outputs = [] + for layer_idx in range(self.num_layers): + transformer_inputs = [ + output_tensor, + memory, + cross_attention_mask, + self_attention_mask, + input_pos_embed, + memory_pos_embed, + ] + # Gets the cache for decoding. + if cache is None: + output_tensor, _ = self.decoder_layers[layer_idx]( + transformer_inputs + ) + else: + cache_layer_idx = str(layer_idx) + output_tensor, cache[cache_layer_idx] = self.decoder_layers[ + layer_idx + ]( + transformer_inputs, + cache=cache[cache_layer_idx], + decode_loop_step=decode_loop_step, + ) + if return_all_decoder_outputs: + decoder_outputs.append(self.output_normalization(output_tensor)) + + if return_all_decoder_outputs: + return decoder_outputs + else: + return self.output_normalization(output_tensor) + + +class DetrTransformerDecoderBlock(layers.Layer): + """ + Adapted from + https://github.com/tensorflow/models/blob/master/official/projects/detr/modeling/transformer.py + """ + + def __init__( + self, + num_attention_heads, + intermediate_size, + intermediate_activation, + dropout_rate=0.0, + attention_dropout_rate=0.0, + use_bias=True, + norm_first=False, + norm_epsilon=1e-12, + intermediate_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.intermediate_activation = activations.get(intermediate_activation) + self.dropout_rate = dropout_rate + self.attention_dropout_rate = attention_dropout_rate + + self._use_bias = use_bias + self._norm_first = norm_first + self._norm_epsilon = norm_epsilon + self._intermediate_dropout = intermediate_dropout + + self._cross_attention_cls = layers.MultiHeadAttention + + def build(self, input_shape): + if len(input_shape) != 3: + raise ValueError( + "TransformerLayer expects a three-dimensional input of " + "shape [batch, sequence, width]." + ) + hidden_size = input_shape[2] + if hidden_size % self.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the " + "number of attention heads (%d)" + % (hidden_size, self.num_attention_heads) + ) + self.attention_head_size = int(hidden_size) // self.num_attention_heads + + # Self attention. 
+ self.self_attention = layers.MultiHeadAttention( + num_heads=self.num_attention_heads, + key_dim=self.attention_head_size, + dropout=self.attention_dropout_rate, + use_bias=self._use_bias, + name="self_attention", + ) + self.self_attention_output_dense = layers.EinsumDense( + "abc,cd->abd", + output_shape=(None, hidden_size), + bias_axes="d", + name="output", + ) + self.self_attention_dropout = layers.Dropout(rate=self.dropout_rate) + self.self_attention_layer_norm = layers.LayerNormalization( + name="self_attention_layer_norm", + axis=-1, + epsilon=self._norm_epsilon, + dtype="float32", + ) + # Encoder-decoder attention. + self.encdec_attention = self._cross_attention_cls( + num_heads=self.num_attention_heads, + key_dim=self.attention_head_size, + dropout=self.attention_dropout_rate, + output_shape=hidden_size, + use_bias=self._use_bias, + name="attention/encdec", + ) + + self.encdec_attention_dropout = layers.Dropout(rate=self.dropout_rate) + self.encdec_attention_layer_norm = layers.LayerNormalization( + name="attention/encdec_output_layer_norm", + axis=-1, + epsilon=self._norm_epsilon, + dtype="float32", + ) + + # Feed-forward projection. + self.intermediate_dense = layers.EinsumDense( + "abc,cd->abd", + output_shape=(None, self.intermediate_size), + bias_axes="d", + name="intermediate", + ) + self.intermediate_activation_layer = layers.Activation( + self.intermediate_activation + ) + self._intermediate_dropout_layer = layers.Dropout( + rate=self._intermediate_dropout + ) + self.output_dense = layers.EinsumDense( + "abc,cd->abd", + output_shape=(None, hidden_size), + bias_axes="d", + name="output", + ) + self.output_dropout = layers.Dropout(rate=self.dropout_rate) + self.output_layer_norm = layers.LayerNormalization( + name="output_layer_norm", + axis=-1, + epsilon=self._norm_epsilon, + dtype="float32", + ) + super().build(input_shape) + + def get_config(self): + config = { + "num_attention_heads": self.num_attention_heads, + "intermediate_size": self.intermediate_size, + "dropout_rate": self.dropout_rate, + "attention_dropout_rate": self.attention_dropout_rate, + "use_bias": self._use_bias, + "norm_first": self._norm_first, + "norm_epsilon": self._norm_epsilon, + "intermediate_dropout": self._intermediate_dropout, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) + + def call(self, inputs, cache=None, decode_loop_step=None): + ( + input_tensor, + memory, + attention_mask, + self_attention_mask, + input_pos_embed, + memory_pos_embed, + ) = inputs + source_tensor = input_tensor + if self._norm_first: + input_tensor = self.self_attention_layer_norm(input_tensor) + self_attention_output, cache = self.self_attention( + query=input_tensor + input_pos_embed, + key=input_tensor + input_pos_embed, + value=input_tensor, + attention_mask=self_attention_mask, + cache=cache, + decode_loop_step=decode_loop_step, + ) + self_attention_output = self.self_attention_dropout( + self_attention_output + ) + if self._norm_first: + self_attention_output = source_tensor + self_attention_output + else: + self_attention_output = self.self_attention_layer_norm( + input_tensor + self_attention_output + ) + if self._norm_first: + source_self_attention_output = self_attention_output + self_attention_output = self.encdec_attention_layer_norm( + self_attention_output + ) + cross_attn_inputs = dict( + query=self_attention_output + input_pos_embed, + key=memory + memory_pos_embed, + value=memory, + attention_mask=attention_mask, + ) + attention_output = 
self.encdec_attention(**cross_attn_inputs) + attention_output = self.encdec_attention_dropout(attention_output) + if self._norm_first: + attention_output = source_self_attention_output + attention_output + else: + attention_output = self.encdec_attention_layer_norm( + self_attention_output + attention_output + ) + if self._norm_first: + source_attention_output = attention_output + attention_output = self.output_layer_norm(attention_output) + + intermediate_output = self.intermediate_dense(attention_output) + intermediate_output = self.intermediate_activation_layer( + intermediate_output + ) + intermediate_output = self._intermediate_dropout_layer( + intermediate_output + ) + layer_output = self.output_dense(intermediate_output) + layer_output = self.output_dropout(layer_output) + if self._norm_first: + layer_output = source_attention_output + layer_output + else: + layer_output = self.output_layer_norm( + layer_output + attention_output + ) + return layer_output, cache From cef86cbbd3a3455e1c8ab3b566be7f889ce34858 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sun, 19 Jan 2025 21:00:21 +0900 Subject: [PATCH 03/11] remove cache/decoding steps --- keras_hub/src/models/detr/detr_layers.py | 27 ++++++------------------ 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/keras_hub/src/models/detr/detr_layers.py b/keras_hub/src/models/detr/detr_layers.py index 9f4af8091f..f68ab11250 100644 --- a/keras_hub/src/models/detr/detr_layers.py +++ b/keras_hub/src/models/detr/detr_layers.py @@ -431,8 +431,6 @@ def call( memory, self_attention_mask=None, cross_attention_mask=None, - cache=None, - decode_loop_step=None, return_all_decoder_outputs=False, input_pos_embed=None, memory_pos_embed=None, @@ -448,20 +446,9 @@ def call( input_pos_embed, memory_pos_embed, ] - # Gets the cache for decoding. 
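# --- Editor's note (not part of the patch): in both attention stages of the
# decoder block above, position information is added to queries and keys only,
# never to values. A standalone sketch of the cross-attention call with toy
# shapes (100 object queries attending over 49 encoder positions):
from keras import layers, ops

mha = layers.MultiHeadAttention(num_heads=8, key_dim=32)
queries = ops.zeros((2, 100, 256))      # decoder stream
query_pos = ops.zeros((2, 100, 256))    # learned object-query embedding
memory = ops.zeros((2, 49, 256))        # encoder output
memory_pos = ops.zeros((2, 49, 256))    # sine position embedding of the memory
out = mha(
    query=queries + query_pos,
    key=memory + memory_pos,
    value=memory,                       # values stay position-free
)
# `out` has shape (2, 100, 256): one updated feature vector per object query.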
- if cache is None: - output_tensor, _ = self.decoder_layers[layer_idx]( - transformer_inputs - ) - else: - cache_layer_idx = str(layer_idx) - output_tensor, cache[cache_layer_idx] = self.decoder_layers[ - layer_idx - ]( - transformer_inputs, - cache=cache[cache_layer_idx], - decode_loop_step=decode_loop_step, - ) + + output_tensor = self.decoder_layers[layer_idx](transformer_inputs) + if return_all_decoder_outputs: decoder_outputs.append(self.output_normalization(output_tensor)) @@ -600,7 +587,7 @@ def get_config(self): base_config = super().get_config() return dict(list(base_config.items()) + list(config.items())) - def call(self, inputs, cache=None, decode_loop_step=None): + def call(self, inputs): ( input_tensor, memory, @@ -612,13 +599,11 @@ def call(self, inputs, cache=None, decode_loop_step=None): source_tensor = input_tensor if self._norm_first: input_tensor = self.self_attention_layer_norm(input_tensor) - self_attention_output, cache = self.self_attention( + self_attention_output = self.self_attention( query=input_tensor + input_pos_embed, key=input_tensor + input_pos_embed, value=input_tensor, attention_mask=self_attention_mask, - cache=cache, - decode_loop_step=decode_loop_step, ) self_attention_output = self.self_attention_dropout( self_attention_output @@ -667,4 +652,4 @@ def call(self, inputs, cache=None, decode_loop_step=None): layer_output = self.output_layer_norm( layer_output + attention_output ) - return layer_output, cache + return layer_output From 8f33663155e26224706c6076346b95e32e482549 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sun, 19 Jan 2025 21:11:06 +0900 Subject: [PATCH 04/11] Convert einsumdense into dense layers, remove output range --- keras_hub/src/models/detr/detr_layers.py | 48 +++++++----------------- 1 file changed, 13 insertions(+), 35 deletions(-) diff --git a/keras_hub/src/models/detr/detr_layers.py b/keras_hub/src/models/detr/detr_layers.py index f68ab11250..6d9e6b53a8 100644 --- a/keras_hub/src/models/detr/detr_layers.py +++ b/keras_hub/src/models/detr/detr_layers.py @@ -204,7 +204,6 @@ def __init__( num_attention_heads, inner_dim, inner_activation, - output_range=None, use_bias=True, norm_first=False, norm_epsilon=1e-12, @@ -223,7 +222,6 @@ def __init__( self._attention_dropout_rate = attention_dropout self._output_dropout = output_dropout self._output_dropout_rate = output_dropout - self._output_range = output_range self._use_bias = use_bias self._norm_first = norm_first self._norm_epsilon = norm_epsilon @@ -231,10 +229,6 @@ def __init__( self._attention_axes = attention_axes def build(self, input_shape): - einsum_equation = "abc,cd->abd" - if len(len(input_shape)) > 3: - einsum_equation = "...bc,cd->...bd" - hidden_size = input_shape[-1] if hidden_size % self._num_heads != 0: raise ValueError( @@ -259,21 +253,17 @@ def build(self, input_shape): epsilon=self._norm_epsilon, dtype="float32", ) - self._intermediate_dense = layers.EinsumDense( - einsum_equation, - output_shape=(None, self._inner_dim), - bias_axes="d", + self._intermediate_dense = layers.Dense( + self._inner_dim, + activation=self._inner_activation, + use_bias=self._use_bias, name="intermediate", ) - self._intermediate_activation_layer = layers.Activation( - self._inner_activation - ) self._inner_dropout_layer = layers.Dropout(rate=self._inner_dropout) - self._output_dense = layers.EinsumDense( - einsum_equation, - output_shape=(None, hidden_size), - bias_axes="d", + self._output_dense = layers.Dense( + hidden_size, + use_bias=self._use_bias, name="output", ) 
self._output_dropout = layers.Dropout(rate=self._output_dropout) @@ -293,7 +283,6 @@ def get_config(self): "inner_activation": self._inner_activation, "output_dropout": self._output_dropout_rate, "attention_dropout": self._attention_dropout_rate, - "output_range": self._output_range, "use_bias": self._use_bias, "norm_first": self._norm_first, "norm_epsilon": self._norm_epsilon, @@ -308,22 +297,12 @@ def call(self, inputs): key_value = None - if self._output_range: - if self._norm_first: - source_tensor = input_tensor[:, 0 : self._output_range, :] - input_tensor = self._attention_layer_norm(input_tensor) - if key_value is not None: - key_value = self._attention_layer_norm(key_value) - target_tensor = input_tensor[:, 0 : self._output_range, :] - if attention_mask is not None: - attention_mask = attention_mask[:, 0 : self._output_range, :] - else: - if self._norm_first: - source_tensor = input_tensor - input_tensor = self._attention_layer_norm(input_tensor) - if key_value is not None: - key_value = self._attention_layer_norm(key_value) - target_tensor = input_tensor + if self._norm_first: + source_tensor = input_tensor + input_tensor = self._attention_layer_norm(input_tensor) + if key_value is not None: + key_value = self._attention_layer_norm(key_value) + target_tensor = input_tensor if key_value is None: key_value = input_tensor @@ -344,7 +323,6 @@ def call(self, inputs): source_attention_output = attention_output attention_output = self._output_layer_norm(attention_output) inner_output = self._intermediate_dense(attention_output) - inner_output = self._intermediate_activation_layer(inner_output) inner_output = self._inner_dropout_layer(inner_output) layer_output = self._output_dense(inner_output) layer_output = self._output_dropout(layer_output) From 1b0b97d5c0f281227b018c9d3c642368095c8d53 Mon Sep 17 00:00:00 2001 From: David Landup Date: Sun, 19 Jan 2025 21:24:00 +0900 Subject: [PATCH 05/11] Slight cleanup --- keras_hub/src/models/detr/detr_layers.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/keras_hub/src/models/detr/detr_layers.py b/keras_hub/src/models/detr/detr_layers.py index 6d9e6b53a8..566f2d4cd1 100644 --- a/keras_hub/src/models/detr/detr_layers.py +++ b/keras_hub/src/models/detr/detr_layers.py @@ -295,21 +295,15 @@ def get_config(self): def call(self, inputs): input_tensor, attention_mask, pos_embed = inputs - key_value = None - if self._norm_first: source_tensor = input_tensor input_tensor = self._attention_layer_norm(input_tensor) - if key_value is not None: - key_value = self._attention_layer_norm(key_value) target_tensor = input_tensor - if key_value is None: - key_value = input_tensor attention_output = self._attention_layer( query=target_tensor + pos_embed, - key=key_value + pos_embed, - value=key_value, + key=input_tensor + pos_embed, + value=input_tensor, attention_mask=attention_mask, ) attention_output = self._attention_dropout(attention_output) @@ -322,6 +316,7 @@ def call(self, inputs): if self._norm_first: source_attention_output = attention_output attention_output = self._output_layer_norm(attention_output) + inner_output = self._intermediate_dense(attention_output) inner_output = self._inner_dropout_layer(inner_output) layer_output = self._output_dense(inner_output) From e8cdbbc89715a62132349729e6c0879e452ca434 Mon Sep 17 00:00:00 2001 From: David Landup Date: Wed, 22 Jan 2025 19:47:32 +0900 Subject: [PATCH 06/11] Add detrtransformer --- keras_hub/src/models/detr/detr_layers.py | 89 ++++++++++++++++++++++++ 1 file changed, 89 
insertions(+) diff --git a/keras_hub/src/models/detr/detr_layers.py b/keras_hub/src/models/detr/detr_layers.py index 566f2d4cd1..432e7a9493 100644 --- a/keras_hub/src/models/detr/detr_layers.py +++ b/keras_hub/src/models/detr/detr_layers.py @@ -626,3 +626,92 @@ def call(self, inputs): layer_output + attention_output ) return layer_output + + +class DETRTransformer(Layer): + """Encoder and decoder, forming a DETRTransformer.""" + + def __init__( + self, + num_encoder_layers=6, + num_decoder_layers=6, + num_attention_heads=8, + intermediate_size=2048, + dropout_rate=0.1, + **kwargs, + ): + super().__init__(**kwargs) + self._dropout_rate = dropout_rate + self._num_encoder_layers = num_encoder_layers + self._num_decoder_layers = num_decoder_layers + self._num_attention_heads = num_attention_heads + self._intermediate_size = intermediate_size + + def build(self, input_shape=None): + if self._num_encoder_layers > 0: + self._encoder = DetrTransformerEncoder( + attention_dropout_rate=self._dropout_rate, + dropout_rate=self._dropout_rate, + intermediate_dropout=self._dropout_rate, + norm_first=False, + num_layers=self._num_encoder_layers, + num_attention_heads=self._num_attention_heads, + intermediate_size=self._intermediate_size, + ) + else: + self._encoder = None + + self._decoder = DetrTransformerDecoder( + attention_dropout_rate=self._dropout_rate, + dropout_rate=self._dropout_rate, + intermediate_dropout=self._dropout_rate, + norm_first=False, + num_layers=self._num_decoder_layers, + num_attention_heads=self._num_attention_heads, + intermediate_size=self._intermediate_size, + ) + super().build(input_shape) + + def get_config(self): + return { + "num_encoder_layers": self._num_encoder_layers, + "num_decoder_layers": self._num_decoder_layers, + "dropout_rate": self._dropout_rate, + } + + def call(self, inputs): + sources = inputs["inputs"] + targets = inputs["targets"] + pos_embed = inputs["pos_embed"] + mask = inputs["mask"] + input_shape = ops.shape(sources) + source_attention_mask = ops.tile( + ops.expand_dims(mask, axis=1), [1, input_shape[1], 1] + ) + if self._encoder is not None: + memory = self._encoder( + sources, + attention_mask=source_attention_mask, + pos_embed=pos_embed, + ) + else: + memory = sources + + target_shape = ops.shape(targets) + cross_attention_mask = ops.tile( + ops.expand_dims(mask, axis=1), [1, target_shape[1], 1] + ) + target_shape = ops.shape(targets) + + decoded = self._decoder( + ops.zeros_like(targets), + memory, + self_attention_mask=ops.ones( + (target_shape[0], target_shape[1], target_shape[1]) + ), + cross_attention_mask=cross_attention_mask, + return_all_decoder_outputs=True, + input_pos_embed=targets, + memory_pos_embed=pos_embed, + ) + return decoded From 62f900d449f8d3b514228b32f4c23bb73f684a1a Mon Sep 17 00:00:00 2001 From: David Landup Date: Wed, 22 Jan 2025 19:55:30 +0900 Subject: [PATCH 07/11] Add basic subclassing DETR model --- keras_hub/src/models/detr/detr_backbone.py | 144 +++++++++++++++++++++ 1 file changed, 144 insertions(+) diff --git a/keras_hub/src/models/detr/detr_backbone.py b/keras_hub/src/models/detr/detr_backbone.py index e69de29bb2..3643544326 100644 --- a/keras_hub/src/models/detr/detr_backbone.py +++ b/keras_hub/src/models/detr/detr_backbone.py @@ -0,0 +1,144 @@ +import math + +from keras import Model +from keras import layers +from keras import ops +from src.models.detr.detr_layers import DetrSinePositionEmbedding +from src.models.detr.detr_layers import DETRTransformer + + +class DETR(Model): + """DETR Model. 
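# --- Editor's note (not part of the patch): DETRTransformer.call above takes a
# dict of already-flattened tensors and broadcasts the (batch, seq) pixel mask
# into the 3D attention masks that MultiHeadAttention expects. Toy shapes:
from keras import ops

batch, seq, num_queries, hidden = 2, 49, 100, 256
mask = ops.ones((batch, seq))  # 1 = real pixel, 0 = padding
inputs = {
    "inputs": ops.zeros((batch, seq, hidden)),           # projected features
    "targets": ops.zeros((batch, num_queries, hidden)),  # query embeddings
    "pos_embed": ops.zeros((batch, seq, hidden)),        # sine embeddings
    "mask": mask,
}
# Encoder self-attention mask: every source position sees the unpadded ones.
source_attention_mask = ops.tile(ops.expand_dims(mask, axis=1), [1, seq, 1])
# Decoder cross-attention mask: every query sees the unpadded source positions.
cross_attention_mask = ops.tile(
    ops.expand_dims(mask, axis=1), [1, num_queries, 1]
)
# Shapes: (2, 49, 49) and (2, 100, 49) respectively.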
+ + Includes a backbone (ResNet50), query embedding, + DETRTransformer (DetrTransformerEncoder + DetrTransformerDecoder) + class and box heads. + """ + + def __init__( + self, + backbone, + backbone_endpoint_name, + num_queries, + hidden_size, + num_classes, + num_encoder_layers=6, + num_decoder_layers=6, + dropout_rate=0.1, + **kwargs, + ): + super().__init__(**kwargs) + self._num_queries = num_queries + self._hidden_size = hidden_size + self._num_classes = num_classes + self._num_encoder_layers = num_encoder_layers + self._num_decoder_layers = num_decoder_layers + self._dropout_rate = dropout_rate + if hidden_size % 2 != 0: + raise ValueError("hidden_size must be a multiple of 2.") + self._backbone = backbone + self._backbone_endpoint_name = backbone_endpoint_name + + def build(self, input_shape=None): + self._input_proj = layers.Conv2D( + self._hidden_size, 1, name="detr/conv2d" + ) + self._build_detection_decoder() + super().build(input_shape) + + def _build_detection_decoder(self): + """Builds detection decoder.""" + self._transformer = DETRTransformer( + num_encoder_layers=self._num_encoder_layers, + num_decoder_layers=self._num_decoder_layers, + dropout_rate=self._dropout_rate, + ) + self._query_embeddings = self.add_weight( + "detr/query_embeddings", + shape=[self._num_queries, self._hidden_size], + ) + sqrt_k = math.sqrt(1.0 / self._hidden_size) + self._class_embed = layers.layers.Dense( + self._num_classes, name="detr/cls_dense" + ) + self._bbox_embed = [ + layers.Dense( + self._hidden_size, activation="relu", name="detr/box_dense_0" + ), + layers.Dense( + self._hidden_size, activation="relu", name="detr/box_dense_1" + ), + layers.Dense(4, name="detr/box_dense_2"), + ] + self._sigmoid = layers.Activation("sigmoid") + + @property + def backbone(self): + return self._backbone + + def get_config(self): + return { + "backbone": self._backbone, + "backbone_endpoint_name": self._backbone_endpoint_name, + "num_queries": self._num_queries, + "hidden_size": self._hidden_size, + "num_classes": self._num_classes, + "num_encoder_layers": self._num_encoder_layers, + "num_decoder_layers": self._num_decoder_layers, + "dropout_rate": self._dropout_rate, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + def _generate_image_mask(self, inputs, target_shape): + """Generates image mask from input image.""" + mask = ops.expand_dims( + ops.cast(ops.not_equal(ops.sum(inputs, axis=-1), 0), inputs.dtype), + axis=-1, + ) + mask = tf.image.resize( + mask, target_shape, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR + ) + return mask + + def call(self, inputs, training=None): + batch_size = ops.shape(inputs)[0] + features = self._backbone(inputs)[self._backbone_endpoint_name] + shape = ops.shape(features) + mask = self._generate_image_mask(inputs, shape[1:3]) + + pos_embed = DetrSinePositionEmbedding(embedding_dim=self._hidden_size)( + mask[:, :, :, 0] + ) + pos_embed = ops.reshape(pos_embed, [batch_size, -1, self._hidden_size]) + + features = ops.reshape( + self._input_proj(features), [batch_size, -1, self._hidden_size] + ) + mask = ops.reshape(mask, [batch_size, -1]) + + decoded_list = self._transformer( + { + "inputs": features, + "targets": ops.tile( + ops.expand_dims(self._query_embeddings, axis=0), + (batch_size, 1, 1), + ), + "pos_embed": pos_embed, + "mask": mask, + } + ) + out_list = [] + for decoded in decoded_list: + decoded = ops.stack(decoded) + output_class = self._class_embed(decoded) + box_out = decoded + for layer in self._bbox_embed: + box_out = layer(box_out) + 
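# --- Editor's note (not part of the patch): the box head above is a three-layer
# MLP applied independently to each decoded query, followed by a sigmoid so the
# four outputs are normalized boxes (DETR predicts center-x, center-y, width,
# height in [0, 1]). Standalone sketch with toy shapes:
from keras import layers, ops

hidden_size, num_queries = 256, 100
decoded = ops.zeros((2, num_queries, hidden_size))  # one decoder stage's output
box_head = [
    layers.Dense(hidden_size, activation="relu"),
    layers.Dense(hidden_size, activation="relu"),
    layers.Dense(4),
]
box_out = decoded
for dense in box_head:
    box_out = dense(box_out)
boxes = ops.sigmoid(box_out)  # shape (2, 100, 4), values in [0, 1]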
output_coord = self._sigmoid(box_out) + out = {"cls_outputs": output_class, "box_outputs": output_coord} + out_list.append(out) + + return out_list From 736101280b850e7c9233954cc28d70796d858726 Mon Sep 17 00:00:00 2001 From: David Landup Date: Wed, 22 Jan 2025 20:07:24 +0900 Subject: [PATCH 08/11] Slight cleanup --- keras_hub/src/models/detr/detr_backbone.py | 30 +++++++--------------- keras_hub/src/models/detr/detr_layers.py | 8 +++--- 2 files changed, 13 insertions(+), 25 deletions(-) diff --git a/keras_hub/src/models/detr/detr_backbone.py b/keras_hub/src/models/detr/detr_backbone.py index 3643544326..4b235a1e22 100644 --- a/keras_hub/src/models/detr/detr_backbone.py +++ b/keras_hub/src/models/detr/detr_backbone.py @@ -1,5 +1,3 @@ -import math - from keras import Model from keras import layers from keras import ops @@ -18,7 +16,6 @@ class and box heads. def __init__( self, backbone, - backbone_endpoint_name, num_queries, hidden_size, num_classes, @@ -37,12 +34,9 @@ def __init__( if hidden_size % 2 != 0: raise ValueError("hidden_size must be a multiple of 2.") self._backbone = backbone - self._backbone_endpoint_name = backbone_endpoint_name def build(self, input_shape=None): - self._input_proj = layers.Conv2D( - self._hidden_size, 1, name="detr/conv2d" - ) + self._input_proj = layers.Conv2D(self._hidden_size, 1, name="conv2d") self._build_detection_decoder() super().build(input_shape) @@ -54,21 +48,18 @@ def _build_detection_decoder(self): dropout_rate=self._dropout_rate, ) self._query_embeddings = self.add_weight( - "detr/query_embeddings", shape=[self._num_queries, self._hidden_size], ) - sqrt_k = math.sqrt(1.0 / self._hidden_size) - self._class_embed = layers.layers.Dense( - self._num_classes, name="detr/cls_dense" - ) + # sqrt_k = math.sqrt(1.0 / self._hidden_size) + self._class_embed = layers.Dense(self._num_classes, name="cls_dense") self._bbox_embed = [ layers.Dense( - self._hidden_size, activation="relu", name="detr/box_dense_0" + self._hidden_size, activation="relu", name="box_dense_0" ), layers.Dense( - self._hidden_size, activation="relu", name="detr/box_dense_1" + self._hidden_size, activation="relu", name="box_dense_1" ), - layers.Dense(4, name="detr/box_dense_2"), + layers.Dense(4, name="box_dense_2"), ] self._sigmoid = layers.Activation("sigmoid") @@ -79,7 +70,6 @@ def backbone(self): def get_config(self): return { "backbone": self._backbone, - "backbone_endpoint_name": self._backbone_endpoint_name, "num_queries": self._num_queries, "hidden_size": self._hidden_size, "num_classes": self._num_classes, @@ -98,19 +88,17 @@ def _generate_image_mask(self, inputs, target_shape): ops.cast(ops.not_equal(ops.sum(inputs, axis=-1), 0), inputs.dtype), axis=-1, ) - mask = tf.image.resize( - mask, target_shape, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR - ) + mask = ops.image.resize(mask, target_shape, interpolation="nearest") return mask def call(self, inputs, training=None): batch_size = ops.shape(inputs)[0] - features = self._backbone(inputs)[self._backbone_endpoint_name] + features = self._backbone(inputs) shape = ops.shape(features) mask = self._generate_image_mask(inputs, shape[1:3]) pos_embed = DetrSinePositionEmbedding(embedding_dim=self._hidden_size)( - mask[:, :, :, 0] + pixel_mask=mask[:, :, :, 0] ) pos_embed = ops.reshape(pos_embed, [batch_size, -1, self._hidden_size]) diff --git a/keras_hub/src/models/detr/detr_layers.py b/keras_hub/src/models/detr/detr_layers.py index 432e7a9493..2db430147d 100644 --- a/keras_hub/src/models/detr/detr_layers.py +++ 
b/keras_hub/src/models/detr/detr_layers.py @@ -74,12 +74,12 @@ def __init__( scale = 2 * math.pi self.scale = scale - def call(self, pixel_mask): - if pixel_mask is None: + def call(self, inputs): + if input is None: raise ValueError("No pixel mask provided") - y_embed = ops.cumsum(pixel_mask, axis=1, dtype="float32") - x_embed = ops.cumsum(pixel_mask, axis=2, dtype="float32") + y_embed = ops.cumsum(inputs, axis=1, dtype="float32") + x_embed = ops.cumsum(inputs, axis=2, dtype="float32") if self.normalize: y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale From 0f0e87955b5e1f09261a6eefe67dc09ee0d39c84 Mon Sep 17 00:00:00 2001 From: David Landup Date: Wed, 22 Jan 2025 20:46:11 +0900 Subject: [PATCH 09/11] First wave of refactoring --- keras_hub/src/models/detr/detr_backbone.py | 106 ++++-- keras_hub/src/models/detr/detr_layers.py | 420 ++++++++++++--------- 2 files changed, 313 insertions(+), 213 deletions(-) diff --git a/keras_hub/src/models/detr/detr_backbone.py b/keras_hub/src/models/detr/detr_backbone.py index 4b235a1e22..f210459d73 100644 --- a/keras_hub/src/models/detr/detr_backbone.py +++ b/keras_hub/src/models/detr/detr_backbone.py @@ -1,8 +1,32 @@ from keras import Model from keras import layers from keras import ops -from src.models.detr.detr_layers import DetrSinePositionEmbedding from src.models.detr.detr_layers import DETRTransformer +from src.models.detr.detr_layers import position_embedding_sine + + +def _freeze_batch_norm(model): + """DETR uses "frozen" batch norm, i.e. batch normalization + with zeros and ones as the parameters, and they don't get adjusted + during training. This was done through a custom class. + + Since it's tricky to exchange all BatchNormalization layers + in an existing model with FrozenBatchNormalization, we just + make them untrainable and assign the "frozen" parameters. 
+ """ + for layer in model.layers: + if isinstance(layer, layers.BatchNormalization): + # Disable training of the layer + layer.trainable = False + # Set the layer to inference mode + layer._trainable = False + # Manually freeze weights and stats + layer.gamma.assign(ops.ones_like(layer.gamma)) + layer.beta.assign(ops.zeros_like(layer.beta)) + layer.moving_mean.assign(ops.zeros_like(layer.moving_mean)) + layer.moving_variance.assign(ops.ones_like(layer.moving_variance)) + + return model class DETR(Model): @@ -25,57 +49,57 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self._num_queries = num_queries - self._hidden_size = hidden_size - self._num_classes = num_classes - self._num_encoder_layers = num_encoder_layers - self._num_decoder_layers = num_decoder_layers - self._dropout_rate = dropout_rate + self.num_queries = num_queries + self.hidden_size = hidden_size + self.num_classes = num_classes + self.num_encoder_layers = num_encoder_layers + self.num_decoder_layers = num_decoder_layers + self.dropout_rate = dropout_rate if hidden_size % 2 != 0: raise ValueError("hidden_size must be a multiple of 2.") - self._backbone = backbone + self.backbone = backbone def build(self, input_shape=None): - self._input_proj = layers.Conv2D(self._hidden_size, 1, name="conv2d") - self._build_detection_decoder() + self.input_proj = layers.Conv2D(self.hidden_size, 1, name="conv2d") + self.build_detection_decoder() super().build(input_shape) def _build_detection_decoder(self): """Builds detection decoder.""" - self._transformer = DETRTransformer( - num_encoder_layers=self._num_encoder_layers, - num_decoder_layers=self._num_decoder_layers, - dropout_rate=self._dropout_rate, + self.transformer = DETRTransformer( + num_encoder_layers=self.num_encoder_layers, + num_decoder_layers=self.num_decoder_layers, + dropout_rate=self.dropout_rate, ) - self._query_embeddings = self.add_weight( - shape=[self._num_queries, self._hidden_size], + self.query_embeddings = self.add_weight( + shape=[self.num_queries, self.hidden_size], ) - # sqrt_k = math.sqrt(1.0 / self._hidden_size) - self._class_embed = layers.Dense(self._num_classes, name="cls_dense") - self._bbox_embed = [ + # sqrt_k = math.sqrt(1.0 / self.hidden_size) + self.class_embed = layers.Dense(self.num_classes, name="cls_dense") + self.bbox_embed = [ layers.Dense( - self._hidden_size, activation="relu", name="box_dense_0" + self.hidden_size, activation="relu", name="box_dense_0" ), layers.Dense( - self._hidden_size, activation="relu", name="box_dense_1" + self.hidden_size, activation="relu", name="box_dense_1" ), layers.Dense(4, name="box_dense_2"), ] - self._sigmoid = layers.Activation("sigmoid") + self.sigmoid = layers.Activation("sigmoid") @property def backbone(self): - return self._backbone + return self.backbone def get_config(self): return { - "backbone": self._backbone, - "num_queries": self._num_queries, - "hidden_size": self._hidden_size, - "num_classes": self._num_classes, - "num_encoder_layers": self._num_encoder_layers, - "num_decoder_layers": self._num_decoder_layers, - "dropout_rate": self._dropout_rate, + "backbone": self.backbone, + "num_queries": self.num_queries, + "hidden_size": self.hidden_size, + "num_classes": self.num_classes, + "num_encoder_layers": self.num_encoder_layers, + "num_decoder_layers": self.num_decoder_layers, + "dropout_rate": self.dropout_rate, } @classmethod @@ -93,25 +117,25 @@ def _generate_image_mask(self, inputs, target_shape): def call(self, inputs, training=None): batch_size = ops.shape(inputs)[0] - features = 
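# --- Editor's note (not part of the patch): usage sketch for the
# `_freeze_batch_norm` helper above, assuming it is imported from the module in
# this patch. The ResNet-50 from `keras.applications` is only an illustrative
# stand-in; any Keras model containing BatchNormalization layers is handled
# the same way.
import keras

backbone = keras.applications.ResNet50(include_top=False, weights=None)
backbone = _freeze_batch_norm(backbone)  # helper defined in the patch above
bn_layers = [
    layer
    for layer in backbone.layers
    if isinstance(layer, keras.layers.BatchNormalization)
]
assert all(not layer.trainable for layer in bn_layers)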
self._backbone(inputs) + features = self.backbone(inputs) shape = ops.shape(features) - mask = self._generate_image_mask(inputs, shape[1:3]) + mask = self.generate_image_mask(inputs, shape[1:3]) - pos_embed = DetrSinePositionEmbedding(embedding_dim=self._hidden_size)( - pixel_mask=mask[:, :, :, 0] + pos_embed = position_embedding_sine( + mask[:, :, :, 0], num_pos_features=self.hidden_size ) - pos_embed = ops.reshape(pos_embed, [batch_size, -1, self._hidden_size]) + pos_embed = ops.reshape(pos_embed, [batch_size, -1, self.hidden_size]) features = ops.reshape( - self._input_proj(features), [batch_size, -1, self._hidden_size] + self.input_proj(features), [batch_size, -1, self.hidden_size] ) mask = ops.reshape(mask, [batch_size, -1]) - decoded_list = self._transformer( + decoded_list = self.transformer( { "inputs": features, "targets": ops.tile( - ops.expand_dims(self._query_embeddings, axis=0), + ops.expand_dims(self.query_embeddings, axis=0), (batch_size, 1, 1), ), "pos_embed": pos_embed, @@ -121,11 +145,11 @@ def call(self, inputs, training=None): out_list = [] for decoded in decoded_list: decoded = ops.stack(decoded) - output_class = self._class_embed(decoded) + output_class = self.class_embed(decoded) box_out = decoded - for layer in self._bbox_embed: + for layer in self.bbox_embed: box_out = layer(box_out) - output_coord = self._sigmoid(box_out) + output_coord = self.sigmoid(box_out) out = {"cls_outputs": output_class, "box_outputs": output_coord} out_list.append(out) diff --git a/keras_hub/src/models/detr/detr_layers.py b/keras_hub/src/models/detr/detr_layers.py index 2db430147d..4e1a38e6a3 100644 --- a/keras_hub/src/models/detr/detr_layers.py +++ b/keras_hub/src/models/detr/detr_layers.py @@ -112,6 +112,81 @@ def call(self, inputs): return pos +# Functional version of the code based on https://github.com/tensorflow/models/blob/master/official/projects/detr/modeling/detr.py + + +def position_embedding_sine( + attention_mask, + num_pos_features=256, + temperature=10000.0, + normalize=True, + scale=2 * math.pi, +): + """Sine-based positional embeddings for 2D images. + + Args: + attention_mask: a `bool` Tensor specifying the size of the input image to + the Transformer and which elements are padded, of size [batch_size, + height, width] + num_pos_features: a `int` specifying the number of positional features, + should be equal to the hidden size of the Transformer network + temperature: a `float` specifying the temperature of the positional + embedding. Any type that is converted to a `float` can also be accepted. + normalize: a `bool` determining whether the positional embeddings should be + normalized between [0, scale] before application of the sine and cos + functions. + scale: a `float` if normalize is True specifying the scale embeddings before + application of the embedding function. + + Returns: + embeddings: a `float` tensor of the same shape as input_tensor specifying + the positional embeddings based on sine features. + """ + if num_pos_features % 2 != 0: + raise ValueError( + "Number of embedding features (num_pos_features) must be even when " + "column and row embeddings are concatenated." 
+ ) + num_pos_features = num_pos_features // 2 + + # Produce row and column embeddings based on total size of the image + # [batch_size, height, width] + row_embedding = ops.cumsum(attention_mask, 1) + col_embedding = ops.cumsum(attention_mask, 2) + + if normalize: + eps = 1e-6 + row_embedding = row_embedding / (row_embedding[:, -1:, :] + eps) * scale + col_embedding = col_embedding / (col_embedding[:, :, -1:] + eps) * scale + + dim_t = ops.arange(num_pos_features, dtype=row_embedding.dtype) + dim_t = ops.power(temperature, 2 * (dim_t // 2) / num_pos_features) + + # Creates positional embeddings for each row and column position + # [batch_size, height, width, num_pos_features] + pos_row = ops.expand_dims(row_embedding, -1) / dim_t + pos_col = ops.expand_dims(col_embedding, -1) / dim_t + pos_row = ops.stack( + [ops.sin(pos_row[:, :, :, 0::2]), ops.cos(pos_row[:, :, :, 1::2])], + axis=4, + ) + pos_col = ops.stack( + [ops.sin(pos_col[:, :, :, 0::2]), ops.cos(pos_col[:, :, :, 1::2])], + axis=4, + ) + + # final_shape = pos_row.shape.as_list()[:3] + [-1] + final_shape = ops.shape(pos_row)[:3] + (-1,) + pos_row = ops.reshape(pos_row, final_shape) + pos_col = ops.reshape(pos_col, final_shape) + output = ops.concatenate([pos_row, pos_col], -1) + + return output + + +from keras.layers import Layer + + class DetrTransformerEncoder(layers.Layer): """ Adapted from @@ -125,7 +200,7 @@ def __init__( intermediate_size=2048, activation="relu", dropout_rate=0.0, - attention_dropout_rate=0.0, + attentiondropout_rate=0.0, use_bias=False, norm_first=True, norm_epsilon=1e-6, @@ -135,14 +210,14 @@ def __init__( super().__init__(**kwargs) self.num_layers = num_layers self.num_attention_heads = num_attention_heads - self._intermediate_size = intermediate_size - self._activation = activation - self._dropout_rate = dropout_rate - self._attention_dropout_rate = attention_dropout_rate - self._use_bias = use_bias - self._norm_first = norm_first - self._norm_epsilon = norm_epsilon - self._intermediate_dropout = intermediate_dropout + self.intermediate_size = intermediate_size + self.activation = activation + self.dropout_rate = dropout_rate + self.attentiondropout_rate = attentiondropout_rate + self.use_bias = use_bias + self.norm_first = norm_first + self.norm_epsilon = norm_epsilon + self.intermediate_dropout = intermediate_dropout def build(self, input_shape): self.encoder_layers = [] @@ -150,18 +225,18 @@ def build(self, input_shape): self.encoder_layers.append( DetrTransformerEncoderBlock( num_attention_heads=self.num_attention_heads, - inner_dim=self._intermediate_size, - inner_activation=self._activation, - output_dropout=self._dropout_rate, - attention_dropout=self._attention_dropout_rate, - use_bias=self._use_bias, - norm_first=self._norm_first, - norm_epsilon=self._norm_epsilon, - inner_dropout=self._intermediate_dropout, + inner_dim=self.intermediate_size, + inner_activation=self.activation, + output_dropout=self.dropout_rate, + attention_dropout=self.attentiondropout_rate, + use_bias=self.use_bias, + norm_first=self.norm_first, + norm_epsilon=self.norm_epsilon, + inner_dropout=self.intermediate_dropout, ) ) self.output_normalization = layers.LayerNormalization( - epsilon=self._norm_epsilon, dtype="float32" + epsilon=self.norm_epsilon, dtype="float32" ) super().build(input_shape) @@ -169,14 +244,14 @@ def get_config(self): config = { "num_layers": self.num_layers, "num_attention_heads": self.num_attention_heads, - "intermediate_size": self._intermediate_size, - "activation": self._activation, - 
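# --- Editor's note (not part of the patch): shape check for the functional
# `position_embedding_sine` above. An all-ones mask means "no padding"; the
# batch/height/width values are toy assumptions.
from keras import ops

mask = ops.ones((2, 16, 16), dtype="float32")  # (batch, height, width)
pos = position_embedding_sine(mask, num_pos_features=256)
print(ops.shape(pos))  # (2, 16, 16, 256): row and column features concatenated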
"dropout_rate": self._dropout_rate, - "attention_dropout_rate": self._attention_dropout_rate, - "use_bias": self._use_bias, - "norm_first": self._norm_first, - "norm_epsilon": self._norm_epsilon, - "intermediate_dropout": self._intermediate_dropout, + "intermediate_size": self.intermediate_size, + "activation": self.activation, + "dropout_rate": self.dropout_rate, + "attentiondropout_rate": self.attentiondropout_rate, + "use_bias": self.use_bias, + "norm_first": self.norm_first, + "norm_epsilon": self.norm_epsilon, + "intermediate_dropout": self.intermediate_dropout, } base_config = super().get_config() return dict(list(base_config.items()) + list(config.items())) @@ -215,62 +290,62 @@ def __init__( ): super().__init__(**kwargs) - self._num_heads = num_attention_heads - self._inner_dim = inner_dim - self._inner_activation = inner_activation - self._attention_dropout = attention_dropout - self._attention_dropout_rate = attention_dropout - self._output_dropout = output_dropout - self._output_dropout_rate = output_dropout - self._use_bias = use_bias - self._norm_first = norm_first - self._norm_epsilon = norm_epsilon - self._inner_dropout = inner_dropout - self._attention_axes = attention_axes + self.num_heads = num_attention_heads + self.inner_dim = inner_dim + self.inner_activation = inner_activation + self.attention_dropout = attention_dropout + self.attentiondropout_rate = attention_dropout + self.output_dropout = output_dropout + self.outputdropout_rate = output_dropout + self.use_bias = use_bias + self.norm_first = norm_first + self.norm_epsilon = norm_epsilon + self.inner_dropout = inner_dropout + self.attention_axes = attention_axes def build(self, input_shape): - hidden_size = input_shape[-1] - if hidden_size % self._num_heads != 0: + hidden_size = input_shape[-1][-1] + if hidden_size % self.num_heads != 0: raise ValueError( "The input size (%d) is not a multiple of " "the number of attention heads (%d)" - % (hidden_size, self._num_heads) + % (hidden_size, self.num_heads) ) - self._attention_head_size = int(hidden_size // self._num_heads) - - self._attention_layer = layers.MultiHeadAttention( - num_heads=self._num_heads, - key_dim=self._attention_head_size, - dropout=self._attention_dropout, - use_bias=self._use_bias, - attention_axes=self._attention_axes, + self.attention_head_size = int(hidden_size // self.num_heads) + + self.attention_layer = layers.MultiHeadAttention( + num_heads=self.num_heads, + key_dim=self.attention_head_size, + dropout=self.attention_dropout, + use_bias=self.use_bias, + attention_axes=self.attention_axes, name="self_attention", ) - self._attention_dropout = layers.Dropout(rate=self._output_dropout) - self._attention_layer_norm = layers.LayerNormalization( + self.attention_dropout = layers.Dropout(rate=self.output_dropout) + self.attention_layer_norm = layers.LayerNormalization( name="self_attention_layer_norm", axis=-1, - epsilon=self._norm_epsilon, + epsilon=self.norm_epsilon, dtype="float32", ) - self._intermediate_dense = layers.Dense( - self._inner_dim, - activation=self._inner_activation, - use_bias=self._use_bias, + self.intermediate_dense = layers.Dense( + self.inner_dim, + activation=self.inner_activation, + use_bias=self.use_bias, name="intermediate", ) - self._inner_dropout_layer = layers.Dropout(rate=self._inner_dropout) - self._output_dense = layers.Dense( + self.inner_dropout_layer = layers.Dropout(rate=self.inner_dropout) + self.output_dense = layers.Dense( hidden_size, - use_bias=self._use_bias, + use_bias=self.use_bias, name="output", ) - 
self._output_dropout = layers.Dropout(rate=self._output_dropout) - self._output_layer_norm = layers.LayerNormalization( + self.output_dropout = layers.Dropout(rate=self.output_dropout) + self.output_layer_norm = layers.LayerNormalization( name="output_layer_norm", axis=-1, - epsilon=self._norm_epsilon, + epsilon=self.norm_epsilon, dtype="float32", ) @@ -278,16 +353,16 @@ def build(self, input_shape): def get_config(self): config = { - "num_attention_heads": self._num_heads, - "inner_dim": self._inner_dim, - "inner_activation": self._inner_activation, - "output_dropout": self._output_dropout_rate, - "attention_dropout": self._attention_dropout_rate, - "use_bias": self._use_bias, - "norm_first": self._norm_first, - "norm_epsilon": self._norm_epsilon, - "inner_dropout": self._inner_dropout, - "attention_axes": self._attention_axes, + "num_attention_heads": self.num_heads, + "inner_dim": self.inner_dim, + "inner_activation": self.inner_activation, + "output_dropout": self.outputdropout_rate, + "attention_dropout": self.attentiondropout_rate, + "use_bias": self.use_bias, + "norm_first": self.norm_first, + "norm_epsilon": self.norm_epsilon, + "inner_dropout": self.inner_dropout, + "attention_axes": self.attention_axes, } base_config = super().get_config() return dict(list(base_config.items()) + list(config.items())) @@ -295,37 +370,37 @@ def get_config(self): def call(self, inputs): input_tensor, attention_mask, pos_embed = inputs - if self._norm_first: + if self.norm_first: source_tensor = input_tensor - input_tensor = self._attention_layer_norm(input_tensor) + input_tensor = self.attention_layer_norm(input_tensor) target_tensor = input_tensor - attention_output = self._attention_layer( + attention_output = self.attention_layer( query=target_tensor + pos_embed, key=input_tensor + pos_embed, value=input_tensor, attention_mask=attention_mask, ) - attention_output = self._attention_dropout(attention_output) - if self._norm_first: + attention_output = self.attention_dropout(attention_output) + if self.norm_first: attention_output = source_tensor + attention_output else: - attention_output = self._attention_layer_norm( + attention_output = self.attention_layer_norm( target_tensor + attention_output ) - if self._norm_first: + if self.norm_first: source_attention_output = attention_output - attention_output = self._output_layer_norm(attention_output) + attention_output = self.output_layer_norm(attention_output) - inner_output = self._intermediate_dense(attention_output) - inner_output = self._inner_dropout_layer(inner_output) - layer_output = self._output_dense(inner_output) - layer_output = self._output_dropout(layer_output) + inner_output = self.intermediate_dense(attention_output) + inner_output = self.inner_dropout_layer(inner_output) + layer_output = self.output_dense(inner_output) + layer_output = self.output_dropout(layer_output) - if self._norm_first: + if self.norm_first: return source_attention_output + layer_output - return self._output_layer_norm(layer_output + attention_output) + return self.output_layer_norm(layer_output + attention_output) class DetrTransformerDecoder(layers.Layer): @@ -341,7 +416,7 @@ def __init__( intermediate_size=2048, activation="relu", dropout_rate=0.0, - attention_dropout_rate=0.0, + attentiondropout_rate=0.0, use_bias=False, norm_first=True, norm_epsilon=1e-6, @@ -351,14 +426,14 @@ def __init__( super().__init__(**kwargs) self.num_layers = num_layers self.num_attention_heads = num_attention_heads - self._intermediate_size = intermediate_size - self._activation = 
activation - self._dropout_rate = dropout_rate - self._attention_dropout_rate = attention_dropout_rate - self._use_bias = use_bias - self._norm_first = norm_first - self._norm_epsilon = norm_epsilon - self._intermediate_dropout = intermediate_dropout + self.intermediate_size = intermediate_size + self.activation = activation + self.dropout_rate = dropout_rate + self.attentiondropout_rate = attentiondropout_rate + self.use_bias = use_bias + self.norm_first = norm_first + self.norm_epsilon = norm_epsilon + self.intermediate_dropout = intermediate_dropout def build(self, input_shape): self.decoder_layers = [] @@ -366,19 +441,19 @@ def build(self, input_shape): self.decoder_layers.append( DetrTransformerDecoderBlock( num_attention_heads=self.num_attention_heads, - intermediate_size=self._intermediate_size, - intermediate_activation=self._activation, - dropout_rate=self._dropout_rate, - attention_dropout_rate=self._attention_dropout_rate, - use_bias=self._use_bias, - norm_first=self._norm_first, - norm_epsilon=self._norm_epsilon, - intermediate_dropout=self._intermediate_dropout, + intermediate_size=self.intermediate_size, + intermediate_activation=self.activation, + dropout_rate=self.dropout_rate, + attentiondropout_rate=self.attentiondropout_rate, + use_bias=self.use_bias, + norm_first=self.norm_first, + norm_epsilon=self.norm_epsilon, + intermediate_dropout=self.intermediate_dropout, name=("layer_%d" % i), ) ) self.output_normalization = layers.LayerNormalization( - epsilon=self._norm_epsilon, dtype="float32" + epsilon=self.norm_epsilon, dtype="float32" ) super().build(input_shape) @@ -386,14 +461,14 @@ def get_config(self): config = { "num_layers": self.num_layers, "num_attention_heads": self.num_attention_heads, - "intermediate_size": self._intermediate_size, - "activation": self._activation, - "dropout_rate": self._dropout_rate, - "attention_dropout_rate": self._attention_dropout_rate, - "use_bias": self._use_bias, - "norm_first": self._norm_first, - "norm_epsilon": self._norm_epsilon, - "intermediate_dropout": self._intermediate_dropout, + "intermediate_size": self.intermediate_size, + "activation": self.activation, + "dropout_rate": self.dropout_rate, + "attentiondropout_rate": self.attentiondropout_rate, + "use_bias": self.use_bias, + "norm_first": self.norm_first, + "norm_epsilon": self.norm_epsilon, + "intermediate_dropout": self.intermediate_dropout, } base_config = super().get_config() return dict(list(base_config.items()) + list(config.items())) @@ -443,7 +518,7 @@ def __init__( intermediate_size, intermediate_activation, dropout_rate=0.0, - attention_dropout_rate=0.0, + attentiondropout_rate=0.0, use_bias=True, norm_first=False, norm_epsilon=1e-12, @@ -455,16 +530,16 @@ def __init__( self.intermediate_size = intermediate_size self.intermediate_activation = activations.get(intermediate_activation) self.dropout_rate = dropout_rate - self.attention_dropout_rate = attention_dropout_rate - - self._use_bias = use_bias - self._norm_first = norm_first - self._norm_epsilon = norm_epsilon - self._intermediate_dropout = intermediate_dropout + self.attentiondropout_rate = attentiondropout_rate - self._cross_attention_cls = layers.MultiHeadAttention + self.use_bias = use_bias + self.norm_first = norm_first + self.norm_epsilon = norm_epsilon + self.intermediate_dropout = intermediate_dropout def build(self, input_shape): + # List of lists + input_shape = input_shape[0] if len(input_shape) != 3: raise ValueError( "TransformerLayer expects a three-dimensional input of " @@ -483,8 +558,8 @@ def 
build(self, input_shape): self.self_attention = layers.MultiHeadAttention( num_heads=self.num_attention_heads, key_dim=self.attention_head_size, - dropout=self.attention_dropout_rate, - use_bias=self._use_bias, + dropout=self.attentiondropout_rate, + use_bias=self.use_bias, name="self_attention", ) self.self_attention_output_dense = layers.EinsumDense( @@ -497,24 +572,24 @@ def build(self, input_shape): self.self_attention_layer_norm = layers.LayerNormalization( name="self_attention_layer_norm", axis=-1, - epsilon=self._norm_epsilon, + epsilon=self.norm_epsilon, dtype="float32", ) # Encoder-decoder attention. - self.encdec_attention = self._cross_attention_cls( + self.encdec_attention = layers.MultiHeadAttention( num_heads=self.num_attention_heads, key_dim=self.attention_head_size, - dropout=self.attention_dropout_rate, + dropout=self.attentiondropout_rate, output_shape=hidden_size, - use_bias=self._use_bias, - name="attention/encdec", + use_bias=self.use_bias, + name="encdec", ) self.encdec_attention_dropout = layers.Dropout(rate=self.dropout_rate) self.encdec_attention_layer_norm = layers.LayerNormalization( - name="attention/encdec_output_layer_norm", + name="encdec_output_layer_norm", axis=-1, - epsilon=self._norm_epsilon, + epsilon=self.norm_epsilon, dtype="float32", ) @@ -528,8 +603,8 @@ def build(self, input_shape): self.intermediate_activation_layer = layers.Activation( self.intermediate_activation ) - self._intermediate_dropout_layer = layers.Dropout( - rate=self._intermediate_dropout + self.intermediate_dropout_layer = layers.Dropout( + rate=self.intermediate_dropout ) self.output_dense = layers.EinsumDense( "abc,cd->abd", @@ -541,7 +616,7 @@ def build(self, input_shape): self.output_layer_norm = layers.LayerNormalization( name="output_layer_norm", axis=-1, - epsilon=self._norm_epsilon, + epsilon=self.norm_epsilon, dtype="float32", ) super().build(input_shape) @@ -551,11 +626,11 @@ def get_config(self): "num_attention_heads": self.num_attention_heads, "intermediate_size": self.intermediate_size, "dropout_rate": self.dropout_rate, - "attention_dropout_rate": self.attention_dropout_rate, - "use_bias": self._use_bias, - "norm_first": self._norm_first, - "norm_epsilon": self._norm_epsilon, - "intermediate_dropout": self._intermediate_dropout, + "attentiondropout_rate": self.attentiondropout_rate, + "use_bias": self.use_bias, + "norm_first": self.norm_first, + "norm_epsilon": self.norm_epsilon, + "intermediate_dropout": self.intermediate_dropout, } base_config = super().get_config() return dict(list(base_config.items()) + list(config.items())) @@ -570,7 +645,7 @@ def call(self, inputs): memory_pos_embed, ) = inputs source_tensor = input_tensor - if self._norm_first: + if self.norm_first: input_tensor = self.self_attention_layer_norm(input_tensor) self_attention_output = self.self_attention( query=input_tensor + input_pos_embed, @@ -581,13 +656,13 @@ def call(self, inputs): self_attention_output = self.self_attention_dropout( self_attention_output ) - if self._norm_first: + if self.norm_first: self_attention_output = source_tensor + self_attention_output else: self_attention_output = self.self_attention_layer_norm( input_tensor + self_attention_output ) - if self._norm_first: + if self.norm_first: source_self_attention_output = self_attention_output self_attention_output = self.encdec_attention_layer_norm( self_attention_output @@ -600,13 +675,13 @@ def call(self, inputs): ) attention_output = self.encdec_attention(**cross_attn_inputs) attention_output = 
self.encdec_attention_dropout(attention_output) - if self._norm_first: + if self.norm_first: attention_output = source_self_attention_output + attention_output else: attention_output = self.encdec_attention_layer_norm( self_attention_output + attention_output ) - if self._norm_first: + if self.norm_first: source_attention_output = attention_output attention_output = self.output_layer_norm(attention_output) @@ -614,12 +689,12 @@ def call(self, inputs): intermediate_output = self.intermediate_activation_layer( intermediate_output ) - intermediate_output = self._intermediate_dropout_layer( + intermediate_output = self.intermediate_dropout_layer( intermediate_output ) layer_output = self.output_dense(intermediate_output) layer_output = self.output_dropout(layer_output) - if self._norm_first: + if self.norm_first: layer_output = source_attention_output + layer_output else: layer_output = self.output_layer_norm( @@ -629,7 +704,7 @@ def call(self, inputs): class DETRTransformer(Layer): - """Encoder and decoder, forming a DETRTransformer.""" + """Encoder and Decoder of DETR.""" def __init__( self, @@ -641,42 +716,42 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self._dropout_rate = dropout_rate - self._num_encoder_layers = num_encoder_layers - self._num_decoder_layers = num_decoder_layers - self._num_attention_heads = num_attention_heads - self._intermediate_size = intermediate_size + self.dropout_rate = dropout_rate + self.num_encoder_layers = num_encoder_layers + self.num_decoder_layers = num_decoder_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size def build(self, input_shape=None): - if self._num_encoder_layers > 0: - self._encoder = DetrTransformerEncoder( - attention_dropout_rate=self._dropout_rate, - dropout_rate=self._dropout_rate, - intermediate_dropout=self._dropout_rate, + if self.num_encoder_layers > 0: + self.encoder = DetrTransformerEncoder( + attentiondropout_rate=self.dropout_rate, + dropout_rate=self.dropout_rate, + intermediate_dropout=self.dropout_rate, norm_first=False, - num_layers=self._num_encoder_layers, - num_attention_heads=self._num_attention_heads, - intermediate_size=self._intermediate_size, + num_layers=self.num_encoder_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, ) else: - self._encoder = None + self.encoder = None - self._decoder = DetrTransformerDecoder( - attention_dropout_rate=self._dropout_rate, - dropout_rate=self._dropout_rate, - intermediate_dropout=self._dropout_rate, + self.decoder = DetrTransformerDecoder( + attentiondropout_rate=self.dropout_rate, + dropout_rate=self.dropout_rate, + intermediate_dropout=self.dropout_rate, norm_first=False, - num_layers=self._num_decoder_layers, - num_attention_heads=self._num_attention_heads, - intermediate_size=self._intermediate_size, + num_layers=self.num_decoder_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, ) super().build(input_shape) def get_config(self): return { - "num_encoder_layers": self._num_encoder_layers, - "num_decoder_layers": self._num_decoder_layers, - "dropout_rate": self._dropout_rate, + "num_encoder_layers": self.num_encoder_layers, + "num_decoder_layers": self.num_decoder_layers, + "dropout_rate": self.dropout_rate, } def call(self, inputs): @@ -688,8 +763,9 @@ def call(self, inputs): source_attention_mask = ops.tile( ops.expand_dims(mask, axis=1), [1, input_shape[1], 1] ) - if self._encoder is not None: - memory = self._encoder( 
+
+        if self.encoder is not None:
+            memory = self.encoder(
                 sources,
                 attention_mask=source_attention_mask,
                 pos_embed=pos_embed,
@@ -703,7 +779,7 @@ def call(self, inputs):
         )
 
         target_shape = ops.shape(targets)
-        decoded = self._decoder(
+        decoded = self.decoder(
             ops.zeros_like(targets),
             memory,
             self_attention_mask=ops.ones(

From 1ff98ddcf89f429b458616f6b3645573dedaefaf Mon Sep 17 00:00:00 2001
From: David Landup
Date: Wed, 22 Jan 2025 20:50:11 +0900
Subject: [PATCH 10/11] Add docstrings and start porting subclass into keras
 backbone

---
 keras_hub/src/models/detr/detr_backbone.py | 31 +++++++++++++++++-----
 1 file changed, 25 insertions(+), 6 deletions(-)

diff --git a/keras_hub/src/models/detr/detr_backbone.py b/keras_hub/src/models/detr/detr_backbone.py
index f210459d73..b99955997b 100644
--- a/keras_hub/src/models/detr/detr_backbone.py
+++ b/keras_hub/src/models/detr/detr_backbone.py
@@ -1,9 +1,11 @@
-from keras import Model
 from keras import layers
 from keras import ops
 
 from src.models.detr.detr_layers import DETRTransformer
 from src.models.detr.detr_layers import position_embedding_sine
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
 
 
 def _freeze_batch_norm(model):
     """DETR uses "frozen" batch norm, i.e. batch normalization
@@ -29,12 +31,29 @@ def _freeze_batch_norm(model):
     return model
 
 
-class DETR(Model):
-    """DETR Model.
+@keras_hub_export("keras_hub.models.DETR")
+class DETR(Backbone):
+    """A Keras model implementing DETR for object detection.
+
+    This class implements the majority of the DETR architecture described
+    in [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872)
+    and is based on the [TensorFlow implementation]
+    (https://github.com/tensorflow/models/tree/master/official/projects/detr).
+
+    DETR is meant to be used with a modified ResNet50 backbone/encoder.
+
+    Args:
+        backbone: `keras.Model`. The backbone network for the model that is
+            used as a feature extractor for the DETR encoder.
+            Should be used with `keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet")`.
+        ...
+
+    Examples:
+
+    ```
+    # todo
+    ```
 
-    Includes a backbone (ResNet50), query embedding,
-    DETRTransformer (DetrTransformerEncoder + DetrTransformerDecoder)
-    class and box heads.
     """
 
     def __init__(

From d3eb561abb3dc22331b9d5d2c236ba3529c4ce34 Mon Sep 17 00:00:00 2001
From: David Landup
Date: Wed, 22 Jan 2025 21:20:12 +0900
Subject: [PATCH 11/11] More refactors into functional subclassing

---
 keras_hub/src/models/detr/detr_backbone.py | 154 +++++++++++----------
 keras_hub/src/models/detr/detr_layers.py   |  13 +-
 2 files changed, 85 insertions(+), 82 deletions(-)

diff --git a/keras_hub/src/models/detr/detr_backbone.py b/keras_hub/src/models/detr/detr_backbone.py
index b99955997b..2d4275f6f8 100644
--- a/keras_hub/src/models/detr/detr_backbone.py
+++ b/keras_hub/src/models/detr/detr_backbone.py
@@ -45,7 +45,8 @@ class DETR(Backbone):
     Args:
         backbone: `keras.Model`. The backbone network for the model that is
             used as a feature extractor for the DETR encoder.
-            Should be used with `keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet")`.
+            Should be used with
+            `keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet")`.
         ...
 
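A candidate snippet for the docstring's `Examples` placeholder (currently `# todo`), based only on the constructor arguments visible in this PR (`backbone`, `num_queries`, `hidden_size`, `num_classes`, `num_encoder_layers`, `num_decoder_layers`) and the `resnet_50_imagenet` preset already named in the docstring. It is an untested sketch, and the values shown are typical DETR settings rather than prescribed defaults:

```python
import keras
import keras_hub

# Backbone preset referenced in the docstring above.
image_encoder = keras_hub.models.ResNetBackbone.from_preset(
    "resnet_50_imagenet"
)
detr = keras_hub.models.DETR(
    backbone=image_encoder,
    num_queries=100,
    hidden_size=256,
    num_classes=91,  # e.g. COCO; adjust for your dataset
    num_encoder_layers=6,
    num_decoder_layers=6,
)

# Forward pass on a dummy batch of images.
images = keras.random.uniform((1, 512, 512, 3))
outputs = detr(images)
# `outputs` is a list with one dict per decoder layer, each holding
# "cls_outputs" and "box_outputs" tensors.
```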
    Examples:
@@ -67,7 +68,76 @@ def __init__(
         dropout_rate=0.1,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        # === Layers ===
+        inputs = layers.Input(shape=backbone.input.shape[1:])
+
+        input_proj = layers.Conv2D(hidden_size, 1, name="conv2d")
+        transformer = DETRTransformer(
+            num_encoder_layers=num_encoder_layers,
+            num_decoder_layers=num_decoder_layers,
+            dropout_rate=dropout_rate,
+        )
+        # query_embeddings = self.add_weight(
+        #     shape=[num_queries, hidden_size],
+        # )
+        # cannot call self.add_weight before super()
+        # TODO: look into how to work around this.
+        # for the time being, initialize query_embeddings
+        # as a static vector
+        query_embeddings = ops.ones([num_queries, hidden_size])
+
+        class_embed = layers.Dense(num_classes, name="cls_dense")
+        bbox_embed = [
+            layers.Dense(hidden_size, activation="relu", name="box_dense_0"),
+            layers.Dense(hidden_size, activation="relu", name="box_dense_1"),
+            layers.Dense(4, name="box_dense_2"),
+        ]
+
+        # === Functional Model ===
+        batch_size = ops.shape(inputs)[0]
+        features = backbone(inputs)
+        shape = ops.shape(features)
+        mask = self._generate_image_mask(inputs, shape[1:3])
+
+        pos_embed = position_embedding_sine(
+            mask[:, :, :, 0], num_pos_features=hidden_size
+        )
+        pos_embed = ops.reshape(pos_embed, [batch_size, -1, hidden_size])
+
+        features = ops.reshape(
+            input_proj(features), [batch_size, -1, hidden_size]
+        )
+        mask = ops.reshape(mask, [batch_size, -1])
+
+        decoded_list = transformer(
+            {
+                "inputs": features,
+                "targets": ops.tile(
+                    ops.expand_dims(query_embeddings, axis=0),
+                    (batch_size, 1, 1),
+                ),
+                "pos_embed": pos_embed,
+                "mask": mask,
+            }
+        )
+        out_list = []
+        for decoded in decoded_list:
+            decoded = ops.stack(decoded)
+            output_class = class_embed(decoded)
+            box_out = decoded
+            for layer in bbox_embed:
+                box_out = layer(box_out)
+            output_coord = layers.Activation("sigmoid")(box_out)
+            out = {"cls_outputs": output_class, "box_outputs": output_coord}
+            out_list.append(out)
+
+        super().__init__(
+            inputs=inputs,
+            outputs=out_list,
+            **kwargs,
+        )
+
+        # === Config ===
         self.num_queries = num_queries
         self.hidden_size = hidden_size
         self.num_classes = num_classes
@@ -78,38 +148,6 @@ def __init__(
             raise ValueError("hidden_size must be a multiple of 2.")
         self.backbone = backbone
 
-    def build(self, input_shape=None):
-        self.input_proj = layers.Conv2D(self.hidden_size, 1, name="conv2d")
-        self.build_detection_decoder()
-        super().build(input_shape)
-
-    def _build_detection_decoder(self):
-        """Builds detection decoder."""
-        self.transformer = DETRTransformer(
-            num_encoder_layers=self.num_encoder_layers,
-            num_decoder_layers=self.num_decoder_layers,
-            dropout_rate=self.dropout_rate,
-        )
-        self.query_embeddings = self.add_weight(
-            shape=[self.num_queries, self.hidden_size],
-        )
-        # sqrt_k = math.sqrt(1.0 / self.hidden_size)
-        self.class_embed = layers.Dense(self.num_classes, name="cls_dense")
-        self.bbox_embed = [
-            layers.Dense(
-                self.hidden_size, activation="relu", name="box_dense_0"
-            ),
-            layers.Dense(
-                self.hidden_size, activation="relu", name="box_dense_1"
-            ),
-            layers.Dense(4, name="box_dense_2"),
-        ]
-        self.sigmoid = layers.Activation("sigmoid")
-
-    @property
-    def backbone(self):
-        return self.backbone
-
     def get_config(self):
         return {
             "backbone": self.backbone,
@@ -121,10 +159,18 @@ def get_config(self):
             "dropout_rate": self.dropout_rate,
         }
 
     @classmethod
     def from_config(cls, config):
         return cls(**config)
 
+    def build(self, input_shape=None):
+        super().build(input_shape)
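One way to resolve the `add_weight` TODO in `__init__` above (the static `ops.ones` stand-in for the learned object queries) would be to keep the weight inside a small helper layer and call that layer inside the functional graph. The sketch below assumes this approach; `DetrObjectQueries` is a hypothetical name and is not part of this patch:

```python
from keras import layers
from keras import ops


class DetrObjectQueries(layers.Layer):
    """Hypothetical helper that owns the learned object queries.

    Wrapping the weight in a layer sidesteps the restriction that
    `add_weight` cannot be called on the backbone before
    `super().__init__`: the weight is created in this layer's `build`.
    """

    def __init__(self, num_queries, hidden_size, **kwargs):
        super().__init__(**kwargs)
        self.num_queries = num_queries
        self.hidden_size = hidden_size

    def build(self, input_shape):
        self.query_embeddings = self.add_weight(
            shape=(self.num_queries, self.hidden_size),
            initializer="glorot_uniform",
            name="query_embeddings",
        )

    def call(self, features):
        # Broadcast the learned queries across the batch dimension of the
        # incoming feature tensor.
        batch_size = ops.shape(features)[0]
        return ops.tile(
            ops.expand_dims(self.query_embeddings, axis=0),
            (batch_size, 1, 1),
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_queries": self.num_queries,
                "hidden_size": self.hidden_size,
            }
        )
        return config
```

The `targets` entry passed to the transformer could then be `DetrObjectQueries(num_queries, hidden_size)(features)`, making the queries trainable while keeping the functional construction.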
+
     def _generate_image_mask(self, inputs, target_shape):
         """Generates image mask from input image."""
         mask = ops.expand_dims(
@@ -133,43 +179,3 @@ def _generate_image_mask(self, inputs, target_shape):
         )
         mask = ops.image.resize(mask, target_shape, interpolation="nearest")
         return mask
-
-    def call(self, inputs, training=None):
-        batch_size = ops.shape(inputs)[0]
-        features = self.backbone(inputs)
-        shape = ops.shape(features)
-        mask = self.generate_image_mask(inputs, shape[1:3])
-
-        pos_embed = position_embedding_sine(
-            mask[:, :, :, 0], num_pos_features=self.hidden_size
-        )
-        pos_embed = ops.reshape(pos_embed, [batch_size, -1, self.hidden_size])
-
-        features = ops.reshape(
-            self.input_proj(features), [batch_size, -1, self.hidden_size]
-        )
-        mask = ops.reshape(mask, [batch_size, -1])
-
-        decoded_list = self.transformer(
-            {
-                "inputs": features,
-                "targets": ops.tile(
-                    ops.expand_dims(self.query_embeddings, axis=0),
-                    (batch_size, 1, 1),
-                ),
-                "pos_embed": pos_embed,
-                "mask": mask,
-            }
-        )
-        out_list = []
-        for decoded in decoded_list:
-            decoded = ops.stack(decoded)
-            output_class = self.class_embed(decoded)
-            box_out = decoded
-            for layer in self.bbox_embed:
-                box_out = layer(box_out)
-            output_coord = self.sigmoid(box_out)
-            out = {"cls_outputs": output_class, "box_outputs": output_coord}
-            out_list.append(out)
-
-        return out_list
diff --git a/keras_hub/src/models/detr/detr_layers.py b/keras_hub/src/models/detr/detr_layers.py
index 4e1a38e6a3..f381a10f96 100644
--- a/keras_hub/src/models/detr/detr_layers.py
+++ b/keras_hub/src/models/detr/detr_layers.py
@@ -132,11 +132,11 @@ def position_embedding_sine(
       should be equal to the hidden size of the Transformer network
     temperature: a `float` specifying the temperature of the positional
       embedding. Any type that is converted to a `float` can also be accepted.
-    normalize: a `bool` determining whether the positional embeddings should be
-      normalized between [0, scale] before application of the sine and cos
-      functions.
-    scale: a `float` if normalize is True specifying the scale embeddings before
-      application of the embedding function.
+    normalize: a `bool` determining whether the positional embeddings
+      should be normalized between [0, scale] before application
+      of the sine and cos functions.
+    scale: a `float`, used only if normalize is True, specifying the scale
+      of the embeddings before application of the embedding function.
 
   Returns:
     embeddings: a `float` tensor of the same shape as input_tensor specifying
@@ -184,9 +184,6 @@ def position_embedding_sine(
   return output
 
 
-from keras.layers import Layer
-
-
 class DetrTransformerEncoder(layers.Layer):
     """
     Adapted from