Dense layer not built when created in init - leads to undefined shape and unable to be used in call #18458

Open
grasskin opened this issue Jun 8, 2023 · 5 comments


@grasskin
Member

grasskin commented Jun 8, 2023

From the example https://keras.io/examples/vision/vit_small_ds/: when building the projection layer, `self.projection` cannot be built and its input shape cannot be inferred automatically. A full traceback is included below. In the interest of getting through as many examples as possible, I'm filing this issue so we can come back to it. The workaround is unclear.

Traceback (most recent call last):
  File "/Users/grasskin/Desktop/keras-core/examples/keras_io/tensorflow/vision/vit_small_ds.py", line 533, in <module>
    # Run experiments with the vanilla ViT
  File "/Users/grasskin/Desktop/keras-core/examples/keras_io/tensorflow/vision/vit_small_ds.py", line 407, in create_vit_classifier
    # Create patches.
  File "/Users/grasskin/Desktop/keras-core/keras_core/utils/traceback_utils.py", line 113, in error_handler
    return fn(*args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/layers/layer.py", line 634, in __call__
    outputs = super().__call__(*args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/utils/traceback_utils.py", line 113, in error_handler
    return fn(*args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/operations/operation.py", line 45, in __call__
    return self.symbolic_call(*args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/operations/operation.py", line 50, in symbolic_call
    outputs = self.compute_output_spec(*args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/layers/layer.py", line 766, in compute_output_spec
    return super().compute_output_spec(*args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/operations/operation.py", line 77, in compute_output_spec
    raise new_e.with_traceback(e.__traceback__) from None
  File "/Users/grasskin/Desktop/keras-core/keras_core/operations/operation.py", line 67, in compute_output_spec
    return backend.compute_output_spec(self.call, *args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/backend/tensorflow/core.py", line 101, in compute_output_spec
    with StatelessScope():
  File "/Users/grasskin/Desktop/keras-core/keras_core/backend/common/stateless_scope.py", line 60, in __exit__
    initialize_all_variables()
  File "/Users/grasskin/Desktop/keras-core/keras_core/backend/common/variables.py", line 358, in initialize_all_variables
    v._deferred_initialize()
  File "/Users/grasskin/Desktop/keras-core/keras_core/backend/common/variables.py", line 74, in _deferred_initialize
    value = self._initializer(self._shape, dtype=self._dtype)
  File "/Users/grasskin/Desktop/keras-core/keras_core/initializers/random_initializers.py", line 258, in __call__
    fan_in, fan_out = compute_fans(shape)
  File "/Users/grasskin/Desktop/keras-core/keras_core/initializers/random_initializers.py", line 588, in compute_fans
    return int(fan_in), int(fan_out)

class ShiftedPatchTokenization(layers.Layer):
    def __init__(
        self,
        image_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        num_patches=NUM_PATCHES,
        projection_dim=PROJECTION_DIM,
        vanilla=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vanilla = vanilla  # Flag to switch to vanilla patch extractor
        self.image_size = image_size
        self.patch_size = patch_size
        self.half_patch = patch_size // 2
        self.flatten_patches = layers.Reshape((num_patches, -1))
        self.projection = layers.Dense(units=projection_dim)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)

    def crop_shift_pad(self, images, mode):
        # Build the diagonally shifted images
        if mode == "left-up":
            crop_height = self.half_patch
            crop_width = self.half_patch
            shift_height = 0
            shift_width = 0
        elif mode == "left-down":
            crop_height = 0
            crop_width = self.half_patch
            shift_height = self.half_patch
            shift_width = 0
        elif mode == "right-up":
            crop_height = self.half_patch
            crop_width = 0
            shift_height = 0
            shift_width = self.half_patch
        else:
            crop_height = 0
            crop_width = 0
            shift_height = self.half_patch
            shift_width = self.half_patch

        # Crop the shifted images and pad them
        crop = tf.image.crop_to_bounding_box(
            images,
            offset_height=crop_height,
            offset_width=crop_width,
            target_height=self.image_size - self.half_patch,
            target_width=self.image_size - self.half_patch,
        )
        shift_pad = tf.image.pad_to_bounding_box(
            crop,
            offset_height=shift_height,
            offset_width=shift_width,
            target_height=self.image_size,
            target_width=self.image_size,
        )
        return shift_pad

    def call(self, images):
        if not self.vanilla:
            # Concat the shifted images with the original image
            images = tf.concat(
                [
                    images,
                    self.crop_shift_pad(images, mode="left-up"),
                    self.crop_shift_pad(images, mode="left-down"),
                    self.crop_shift_pad(images, mode="right-up"),
                    self.crop_shift_pad(images, mode="right-down"),
                ],
                axis=-1,
            )
        # Patchify the images and flatten it
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        flat_patches = self.flatten_patches(patches)
        if not self.vanilla:
            # Layer normalize the flat patches and linearly project it
            tokens = self.layer_norm(flat_patches)
            tokens = self.projection(tokens)
        else:
            # Linearly project the flat patches
            tokens = self.projection(flat_patches)
        return (tokens, patches)
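For reference, the shape arithmetic in `call()` can be traced by hand. The following is a minimal pure-Python sketch (the helper `tokenizer_shapes` is hypothetical, not part of Keras or the example). It assumes square NHWC images with 3 channels, and that patches are extracted with `sizes == strides == patch_size` and `"VALID"` padding, as in the code above. It shows the last dimension that `self.projection` should receive, which is what its kernel shape depends on.

```python
def tokenizer_shapes(image_size, patch_size, channels=3, vanilla=False):
    """Return the static shapes produced inside ShiftedPatchTokenization.call().

    Hypothetical helper: it only mirrors the shape arithmetic, assuming
    square NHWC images and VALID patch extraction with stride == patch_size.
    """
    # Shifted-patch mode concatenates the original image with 4 shifted
    # copies along the channel axis, so the channel count grows 5x.
    c = channels if vanilla else channels * 5
    # extract_patches with sizes == strides == patch_size and VALID padding
    # yields a (H // p, W // p) grid; each patch is flattened to p * p * c.
    grid = image_size // patch_size
    patch_dim = patch_size * patch_size * c
    patches_shape = (None, grid, grid, patch_dim)
    # Reshape((num_patches, -1)) flattens the grid into a sequence, so
    # num_patches passed to the constructor should equal grid * grid.
    num_patches = grid * grid
    flat_shape = (None, num_patches, patch_dim)
    return patches_shape, flat_shape


patches_shape, flat_shape = tokenizer_shapes(72, 6)
print(patches_shape, flat_shape)
```

With the illustrative values `image_size=72, patch_size=6` (assumed here, not necessarily the example's constants), the flattened patches come out as `(None, 144, 540)` in shifted mode and `(None, 144, 108)` in vanilla mode, so the Dense kernel would need shape `(540, projection_dim)` or `(108, projection_dim)` respectively.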
@fchollet
Collaborator

fchollet commented Jun 8, 2023

What's the error message?

@grasskin
Member Author

grasskin commented Jun 8, 2023

Traceback (most recent call last):
  File "/Users/grasskin/Desktop/keras-core/examples/keras_io/tensorflow/vision/vit_small_ds.py", line 528, in <module>
    vit = create_vit_classifier(vanilla=True)
  File "/Users/grasskin/Desktop/keras-core/examples/keras_io/tensorflow/vision/vit_small_ds.py", line 402, in create_vit_classifier
    (tokens, _) = ShiftedPatchTokenization(vanilla=vanilla)(augmented)
  File "/Users/grasskin/Desktop/keras-core/keras_core/utils/traceback_utils.py", line 113, in error_handler
    return fn(*args, **kwargs)
  File "/Users/grasskin/Desktop/keras-core/keras_core/layers/layer.py", line 569, in __call__
    self._maybe_build(call_spec)
  File "/Users/grasskin/Desktop/keras-core/keras_core/layers/layer.py", line 962, in _maybe_build
    raise ValueError(
ValueError: Layer 'shifted_patch_tokenization_2' looks like it has unbuilt state, but Keras is not able to trace the layer `call()` in order to build it automatically. You must implement the `def build(self, input_shape)` method on your layer. It should create all variables used by the layer (e.g. by calling `layer.build()` on all its children layers).

The error message is a bit generic, but after experimenting with it, the problem appears to involve the projection layer (`self.projection`): it seems to receive the wrong input shape, or one that is not fully defined. The rest of `__call__` works fine.
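The last frames of the first traceback point at the random initializer computing fan-in and fan-out from the kernel shape (`fan_in, fan_out = compute_fans(shape)`, then `return int(fan_in), int(fan_out)`). Below is a simplified pure-Python re-implementation of that kind of fan computation, written for illustration and not copied from `keras_core`. It assumes a Dense kernel shape of `(input_dim, units)`. If the input dimension was never resolved (for example, left as `None`), the final `int(...)` conversion raises a `TypeError`. That would be consistent with the failure point in the traceback, though the traceback above does not show the final exception.

```python
def compute_fans(shape):
    """Simplified fan-in / fan-out computation for a weight shape (illustrative)."""
    shape = tuple(shape)
    if len(shape) < 1:
        # Scalars: avoid errors for constants.
        fan_in = fan_out = 1
    elif len(shape) == 1:
        fan_in = fan_out = shape[0]
    elif len(shape) == 2:
        # Dense kernel: (input_dim, units).
        fan_in, fan_out = shape
    else:
        # Convolution kernel: (*spatial_dims, in_channels, out_channels).
        receptive_field = 1
        for dim in shape[:-2]:
            receptive_field *= dim
        fan_in = shape[-2] * receptive_field
        fan_out = shape[-1] * receptive_field
    # int(None) raises TypeError when a dimension is still undefined.
    return int(fan_in), int(fan_out)


print(compute_fans((540, 64)))
try:
    compute_fans((None, 64))  # kernel whose input dimension was never resolved
except TypeError as err:
    print("TypeError:", err)
```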

@cosmo3769
Contributor

@grasskin @fchollet I was also going through this example and am getting the same error. Do we have to implement a `build` method as a workaround, or is a fix in progress? 🤔
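As the error message suggests, one workaround is to build the child layers explicitly in `build(self, input_shape)` rather than relying on Keras to trace `call()`. The sketch below illustrates that pattern with hypothetical pure-Python stand-ins (`ToyDense`, `ToyTokenizer`); it is not Keras code. It assumes NHWC input and reuses the channel and patch arithmetic from `call()`. In a real Keras layer, the analogous `build` would call `self.projection.build(...)` (and, in non-vanilla mode, `self.layer_norm.build(...)`, since the layer norm sees the same flattened-patch shape) with that shape.

```python
class ToyDense:
    """Hypothetical stand-in for a Dense layer: weights are created lazily."""

    def __init__(self, units):
        self.units = units
        self.kernel_shape = None
        self.built = False

    def build(self, input_shape):
        # The kernel shape depends on the last input dimension, which is
        # only known once an input shape is supplied.
        input_dim = input_shape[-1]
        if input_dim is None:
            raise ValueError("last input dimension must be defined")
        self.kernel_shape = (input_dim, self.units)
        self.built = True


class ToyTokenizer:
    """Hypothetical parent layer that builds its child explicitly."""

    def __init__(self, patch_size, projection_dim, vanilla=False):
        self.patch_size = patch_size
        self.vanilla = vanilla
        self.projection = ToyDense(projection_dim)  # created, not yet built
        self.built = False

    def build(self, input_shape):
        # input_shape is (batch, height, width, channels).
        _, height, width, channels = input_shape
        if not self.vanilla:
            channels *= 5  # original image + 4 shifted copies
        num_patches = (height // self.patch_size) * (width // self.patch_size)
        patch_dim = self.patch_size * self.patch_size * channels
        # Build the child with the shape it will see in call(), instead of
        # relying on the framework to infer it by tracing call().
        self.projection.build((None, num_patches, patch_dim))
        self.built = True


tokenizer = ToyTokenizer(patch_size=6, projection_dim=64)
tokenizer.build((None, 72, 72, 3))
print(tokenizer.projection.kernel_shape)
```

Computing the shape explicitly means weight creation no longer depends on the framework successfully tracing `call()`.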

@fchollet fchollet transferred this issue from keras-team/keras-core Sep 22, 2023
@dhantule
Contributor

dhantule commented Feb 3, 2025

Hi @grasskin, thanks for reporting this!

The example was updated. Please try running it again and let us know if the error persists. Thanks!


This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale label Feb 18, 2025
Projects
None yet
Development

No branches or pull requests

5 participants