Changes related to TF backend compatibility check for Vision Transformer on small datasets #1671

Status: Open. Wants to merge 3 commits into base: master (showing changes from 1 commit).
24 changes: 8 additions & 16 deletions examples/vision/vit_small_ds.py
@@ -5,6 +5,7 @@
Last modified: 2022/01/10
Description: Training a ViT from scratch on smaller datasets with shifted patch tokenization and locality self-attention.
Accelerator: GPU
Keras 3 Conversion initiated by: [Pavan Kumar Singh](https://github.com/pksX01)
"""
"""
## Introduction
@@ -33,26 +34,20 @@
This example implements the ideas of the paper. A large part of this
example is inspired from
[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).

_Note_: This example requires TensorFlow 2.6 or higher, as well as
[TensorFlow Addons](https://www.tensorflow.org/addons), which can be
installed using the following command:

```python
pip install -qq -U tensorflow-addons
```
"""
"""
## Setup
"""

import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import math
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_addons as tfa
import keras
import matplotlib.pyplot as plt
from tensorflow.keras import layers
from keras import layers
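
The diff above sets `KERAS_BACKEND` via `os.environ` before importing `keras`. Order matters here: in Keras 3 the backend is fixed at first import, so setting the variable afterwards has no effect. A minimal sketch of that pattern (the `select_backend` helper is hypothetical, for illustration only):

```python
import os


def select_backend(name="tensorflow"):
    """Choose the Keras 3 backend via environment variable.

    Must run before the first `import keras`; changing the variable
    after keras has been imported does not switch backends.
    """
    os.environ["KERAS_BACKEND"] = name  # e.g. "tensorflow", "jax", "torch"
    return os.environ["KERAS_BACKEND"]
```
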

# Setting seed for reproducibility
SEED = 42
@@ -354,7 +349,7 @@ def call(self, encoded_patches):
"""


class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
class MultiHeadAttentionLSA(keras.layers.MultiHeadAttention):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# The trainable temperature term. The initial value is
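
The locality self-attention (LSA) idea behind `MultiHeadAttentionLSA` is that the fixed `1/sqrt(d_k)` scaling is replaced by a learnable temperature, and the score matrix's diagonal is masked so a token cannot attend to itself. A NumPy sketch of the score computation (the `lsa_attention_scores` helper is illustrative, not part of the example's class):

```python
import numpy as np


def lsa_attention_scores(query, key, temperature, mask_diagonal=True):
    """Single-head LSA scores: learnable temperature + diagonal masking.

    `temperature` plays the role of sqrt(d_k) but is trainable in the
    real layer; here it is passed in as a plain float for clarity.
    """
    scores = query @ key.T / temperature
    if mask_diagonal:
        # Large negative value drives the diagonal to ~0 after softmax,
        # forcing each token to attend only to *other* tokens.
        np.fill_diagonal(scores, -1e9)
    # Numerically stable softmax over the key axis.
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)
```
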
@@ -498,7 +493,7 @@ def run_experiment(model):
warmup_steps=warmup_steps,
)

optimizer = tfa.optimizers.AdamW(
optimizer = keras.optimizers.AdamW(
learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
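
The swap from `tfa.optimizers.AdamW` to `keras.optimizers.AdamW` works because both implement the same decoupled weight-decay update: the decay term is applied directly to the weights rather than folded into the gradient. A NumPy sketch of one such step (the `adamw_step` helper is illustrative only; the Keras optimizer handles all of this internally):

```python
import numpy as np


def adamw_step(w, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-7, weight_decay=1e-4):
    """One AdamW update: Adam moment estimates plus decoupled decay."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad**2       # second-moment estimate
    m_hat = m / (1 - beta1**t)                  # bias correction
    v_hat = v / (1 - beta2**t)
    # Weight decay is added to the step, not to the gradient.
    w = w - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * w)
    return w, m, v
```
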

@@ -547,7 +542,4 @@ def run_experiment(model):

I would like to thank [Jarvislabs.ai](https://jarvislabs.ai/) for
generously helping with GPU credits.

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/vit_small_ds_v2)
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/vit-small-ds).
"""