Closed
Description
In the example https://keras.io/examples/vision/vit_small_ds/, the projection layer self.projection cannot be built because its input shape is not inferred automatically. Full traceback below, followed by a minimal repro sketch and the ShiftedPatchTokenization layer from the example. In the interest of getting through as many examples as possible, filing this issue so we can come back to it. No clear workaround yet (one possible avenue is sketched at the end).
Traceback (most recent call last):
File "/Users/grasskin/Desktop/keras-core/examples/keras_io/tensorflow/vision/vit_small_ds.py", line 533, in <module>
# Run experiments with the vanilla ViT
File "/Users/grasskin/Desktop/keras-core/examples/keras_io/tensorflow/vision/vit_small_ds.py", line 407, in create_vit_classifier
# Create patches.
File "/Users/grasskin/Desktop/keras-core/keras_core/utils/traceback_utils.py", line 113, in error_handler
return fn(*args, **kwargs)
File "/Users/grasskin/Desktop/keras-core/keras_core/layers/layer.py", line 634, in __call__
outputs = super().__call__(*args, **kwargs)
File "/Users/grasskin/Desktop/keras-core/keras_core/utils/traceback_utils.py", line 113, in error_handler
return fn(*args, **kwargs)
File "/Users/grasskin/Desktop/keras-core/keras_core/operations/operation.py", line 45, in __call__
return self.symbolic_call(*args, **kwargs)
File "/Users/grasskin/Desktop/keras-core/keras_core/operations/operation.py", line 50, in symbolic_call
outputs = self.compute_output_spec(*args, **kwargs)
File "/Users/grasskin/Desktop/keras-core/keras_core/layers/layer.py", line 766, in compute_output_spec
return super().compute_output_spec(*args, **kwargs)
File "/Users/grasskin/Desktop/keras-core/keras_core/operations/operation.py", line 77, in compute_output_spec
raise new_e.with_traceback(e.__traceback__) from None
File "/Users/grasskin/Desktop/keras-core/keras_core/operations/operation.py", line 67, in compute_output_spec
return backend.compute_output_spec(self.call, *args, **kwargs)
File "/Users/grasskin/Desktop/keras-core/keras_core/backend/tensorflow/core.py", line 101, in compute_output_spec
with StatelessScope():
File "/Users/grasskin/Desktop/keras-core/keras_core/backend/common/stateless_scope.py", line 60, in __exit__
initialize_all_variables()
File "/Users/grasskin/Desktop/keras-core/keras_core/backend/common/variables.py", line 358, in initialize_all_variables
v._deferred_initialize()
File "/Users/grasskin/Desktop/keras-core/keras_core/backend/common/variables.py", line 74, in _deferred_initialize
value = self._initializer(self._shape, dtype=self._dtype)
File "/Users/grasskin/Desktop/keras-core/keras_core/initializers/random_initializers.py", line 258, in __call__
fan_in, fan_out = compute_fans(shape)
File "/Users/grasskin/Desktop/keras-core/keras_core/initializers/random_initializers.py", line 588, in compute_fans
return int(fan_in), int(fan_out)
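
A shorter path to the failure than running the full script (untested sketch; assumes the example's constants and the ShiftedPatchTokenization layer copied below) is to call the layer symbolically on an Input, skipping the augmentation stage of create_vit_classifier:

import keras_core as keras

# Untested repro sketch: call the tokenizer symbolically. The projection Dense
# is built on the output of Reshape((num_patches, -1)); per the traceback, that
# last dimension appears to stay unknown, so compute_fans cannot turn it into ints.
inputs = keras.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
tokens, patches = ShiftedPatchTokenization(vanilla=True)(inputs)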
import tensorflow as tf
from keras_core import layers

# IMAGE_SIZE, PATCH_SIZE, NUM_PATCHES, PROJECTION_DIM and LAYER_NORM_EPS are
# defined earlier in the example script.


class ShiftedPatchTokenization(layers.Layer):
    def __init__(
        self,
        image_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        num_patches=NUM_PATCHES,
        projection_dim=PROJECTION_DIM,
        vanilla=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vanilla = vanilla  # Flag to switch to vanilla patch extractor
        self.image_size = image_size
        self.patch_size = patch_size
        self.half_patch = patch_size // 2
        self.flatten_patches = layers.Reshape((num_patches, -1))
        self.projection = layers.Dense(units=projection_dim)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)

    def crop_shift_pad(self, images, mode):
        # Build the diagonally shifted images
        if mode == "left-up":
            crop_height = self.half_patch
            crop_width = self.half_patch
            shift_height = 0
            shift_width = 0
        elif mode == "left-down":
            crop_height = 0
            crop_width = self.half_patch
            shift_height = self.half_patch
            shift_width = 0
        elif mode == "right-up":
            crop_height = self.half_patch
            crop_width = 0
            shift_height = 0
            shift_width = self.half_patch
        else:
            crop_height = 0
            crop_width = 0
            shift_height = self.half_patch
            shift_width = self.half_patch

        # Crop the shifted images and pad them back to the original size
        crop = tf.image.crop_to_bounding_box(
            images,
            offset_height=crop_height,
            offset_width=crop_width,
            target_height=self.image_size - self.half_patch,
            target_width=self.image_size - self.half_patch,
        )
        shift_pad = tf.image.pad_to_bounding_box(
            crop,
            offset_height=shift_height,
            offset_width=shift_width,
            target_height=self.image_size,
            target_width=self.image_size,
        )
        return shift_pad

    def call(self, images):
        if not self.vanilla:
            # Concat the shifted images with the original image
            images = tf.concat(
                [
                    images,
                    self.crop_shift_pad(images, mode="left-up"),
                    self.crop_shift_pad(images, mode="left-down"),
                    self.crop_shift_pad(images, mode="right-up"),
                    self.crop_shift_pad(images, mode="right-down"),
                ],
                axis=-1,
            )
        # Patchify the images and flatten them
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        flat_patches = self.flatten_patches(patches)
        if not self.vanilla:
            # Layer normalize the flat patches and linearly project them
            tokens = self.layer_norm(flat_patches)
            tokens = self.projection(tokens)
        else:
            # Linearly project the flat patches
            tokens = self.projection(flat_patches)
        return (tokens, patches)
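
One possible workaround (untested, and assuming the unknown dimension really does come from the -1 in Reshape): compute the flattened patch size explicitly in __init__ so the projection's input dimension is known when the kernel is built:

# Untested workaround sketch, inside __init__: give Reshape a concrete feature
# size instead of -1 so the Dense kernel shape is fully known. The non-vanilla
# path stacks the original image plus 4 shifted copies, i.e. 3 * 5 channels.
num_channels = 3 if vanilla else 3 * 5
token_dim = patch_size * patch_size * num_channels
self.flatten_patches = layers.Reshape((num_patches, token_dim))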