Skip to content

Dense layer not built when created in init - leads to undefined shape and unable to be used in call #18458

Closed
@grasskin

Description

@grasskin

In the example https://keras.io/examples/vision/vit_small_ds/, the projection layer `self.projection` cannot be built and cannot have its shape inferred automatically. A full traceback is included below. In the interest of getting through as many examples as possible, I am filing this issue so we can come back to it. The workaround is unclear.

Traceback (most recent call last):
  File "/Users/grasskin/Desktop/keras-core/examples/keras_io/tensorflow/vision/vit_small_ds.py", line 533, in <module>
    # Run experiments with the vanilla ViT
  File "/Users/grasskin/Desktop/keras-core/examples/keras_io/tensorflow/vision/vit_small_ds.py", line 407, in create_vit_classifier
    # Create patches.
  File "/Users/grasskin/Desktop/keras-core/keras_core/utils/traceback_utils.py", line 113, in error_handler
    return fn(*args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/layers/layer.py", line 634, in __call__
    outputs = super().__call__(*args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/utils/traceback_utils.py", line 113, in error_handler
    return fn(*args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/operations/operation.py", line 45, in __call__
    return self.symbolic_call(*args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/operations/operation.py", line 50, in symbolic_call
    outputs = self.compute_output_spec(*args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/layers/layer.py", line 766, in compute_output_spec
    return super().compute_output_spec(*args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/operations/operation.py", line 77, in compute_output_spec
    raise new_e.with_traceback(e.__traceback__) from None
  File "/Users/grasskin/Desktop/keras-core/keras_core/operations/operation.py", line 67, in compute_output_spec
    return backend.compute_output_spec(self.call, *args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/backend/tensorflow/core.py", line 101, in compute_output_spec
    with StatelessScope():
  File "/Users/grasskin/Desktop/keras-core/keras_core/backend/common/stateless_scope.py", line 60, in __exit__
    initialize_all_variables()
  File "/Users/grasskin/Desktop/keras-core/keras_core/backend/common/variables.py", line 358, in initialize_all_variables
    v._deferred_initialize()
  File "/Users/grasskin/Desktop/keras-core/keras_core/backend/common/variables.py", line 74, in _deferred_initialize
    value = self._initializer(self._shape, dtype=self._dtype)
  File "/Users/grasskin/Desktop/keras-core/keras_core/initializers/random_initializers.py", line 258, in __call__
    fan_in, fan_out = compute_fans(shape)
  File "/Users/grasskin/Desktop/keras-core/keras_core/initializers/random_initializers.py", line 588, in compute_fans
    return int(fan_in), int(fan_out)
class ShiftedPatchTokenization(layers.Layer):
    """Split images into patches and linearly project them into tokens.

    Unless `vanilla` is set, the input is first concatenated along the last
    axis with four diagonally shifted copies of itself, and the flattened
    patches are layer-normalized before the projection. With `vanilla=True`
    the patches are extracted from the unmodified input and projected
    directly.

    Args:
        image_size: Side length used as the target height and width when
            cropping and padding the shifted copies.
        patch_size: Side length of each extracted patch; also the stride
            between patches. Half of it (floor division) is the shift
            amount.
        num_patches: First dimension of the `Reshape` target applied to the
            extracted patches.
        projection_dim: Number of units of the projecting `Dense` layer.
        vanilla: When true, skip the shifting and the layer normalization.
        **kwargs: Forwarded to the base layer constructor.
    """

    def __init__(
        self,
        image_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        num_patches=NUM_PATCHES,
        projection_dim=PROJECTION_DIM,
        vanilla=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Plain configuration values.
        self.vanilla = vanilla  # True -> behave as a plain patch extractor
        self.image_size = image_size
        self.patch_size = patch_size
        self.half_patch = patch_size // 2

        # Sub-layers used by `call`.
        self.flatten_patches = layers.Reshape((num_patches, -1))
        self.projection = layers.Dense(units=projection_dim)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)

    def crop_shift_pad(self, images, mode):
        """Return `images` cropped by half a patch and padded back to size.

        Args:
            images: Image tensor handed to `tf.image.crop_to_bounding_box`.
            mode: One of "left-up", "left-down", "right-up"; any other value
                is treated as "right-down".

        Returns:
            The cropped images padded back to `image_size` x `image_size`.
        """
        half = self.half_patch

        # mode -> (crop offset height, crop offset width,
        #          pad offset height, pad offset width)
        offsets_by_mode = {
            "left-up": (half, half, 0, 0),
            "left-down": (0, half, half, 0),
            "right-up": (half, 0, 0, half),
        }
        # Anything not listed above falls through to the "right-down" offsets.
        crop_top, crop_left, pad_top, pad_left = offsets_by_mode.get(
            mode, (0, 0, half, half)
        )

        # The cropped window is half a patch smaller than the image per side.
        cropped_side = self.image_size - half
        cropped = tf.image.crop_to_bounding_box(
            images,
            offset_height=crop_top,
            offset_width=crop_left,
            target_height=cropped_side,
            target_width=cropped_side,
        )
        # Pad back up to the full image size at the shifted position.
        return tf.image.pad_to_bounding_box(
            cropped,
            offset_height=pad_top,
            offset_width=pad_left,
            target_height=self.image_size,
            target_width=self.image_size,
        )

    def call(self, images):
        """Tokenize `images`.

        Returns:
            A `(tokens, patches)` tuple: the projected flattened patches and
            the raw output of `tf.image.extract_patches`.
        """
        if not self.vanilla:
            # Stack the original image with its four diagonal shifts along
            # the last axis.
            shift_modes = ("left-up", "left-down", "right-up", "right-down")
            shifted = [
                self.crop_shift_pad(images, mode=shift_mode)
                for shift_mode in shift_modes
            ]
            images = tf.concat([images, *shifted], axis=-1)

        # Non-overlapping windows: the stride equals the patch size.
        window = [1, self.patch_size, self.patch_size, 1]
        patches = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=list(window),
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        flat_patches = self.flatten_patches(patches)

        # Only the shifted variant normalizes before the linear projection.
        if self.vanilla:
            projection_input = flat_patches
        else:
            projection_input = self.layer_norm(flat_patches)
        tokens = self.projection(projection_input)
        return (tokens, patches)

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions