Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrating Train a Vision Transformer on small datasets example to Keras 3 #1991

Merged
merged 3 commits into from
Nov 27, 2024

Conversation

sitamgithub-MSIT
Copy link
Contributor

This PR changes the Train a Vision Transformer on small datasets example to keras 3.0 [TF-Only Example].

For example, here is the notebook link provided:
https://colab.research.google.com/drive/1ugp-3Zkkev9RNfuboWTFpS5202hhUcFv?usp=sharing

The following describes the Git difference for the changed files:

Changes:
diff --git a/examples/vision/vit_small_ds.py b/examples/vision/vit_small_ds.py
index 40ef2c52..658068dc 100644
--- a/examples/vision/vit_small_ds.py
+++ b/examples/vision/vit_small_ds.py
@@ -35,25 +35,23 @@ This example implements the ideas of the paper. A large part of this
 example is inspired from
 [Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).
 
-_Note_: This example requires TensorFlow 2.6 or higher, as well as
-[TensorFlow Addons](https://www.tensorflow.org/addons), which can be
-installed using the following command:
-
-```python
-pip install -qq -U tensorflow-addons
+_Note_: This example requires TensorFlow 2.6 or higher.

"""
"""

Setup

"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"

import math
import numpy as np
+import keras
+from keras import ops
+from keras import layers
import tensorflow as tf
-from tensorflow import keras
-import tensorflow_addons as tfa
import matplotlib.pyplot as plt
-from tensorflow.keras import layers

Setting seed for reproducibiltiy

SEED = 42
@@ -216,7 +214,7 @@ class ShiftedPatchTokenization(layers.Layer):
def call(self, images):
if not self.vanilla:
# Concat the shifted images with the original image

  •        images = tf.concat(
    
  •        images = ops.concatenate(
               [
                   images,
                   self.crop_shift_pad(images, mode="left-up"),
    

@@ -252,8 +250,9 @@ class ShiftedPatchTokenization(layers.Layer):

Get a random image from the training dataset

and resize the image

image = x_train[np.random.choice(range(x_train.shape[0]))]
-resized_image = tf.image.resize(

  • tf.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)
    +resized_image = ops.cast(
  • ops.image.resize(ops.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)),
  • dtype="float32",
    )

Vanilla patch maker: This takes an image and divides into

@@ -267,7 +266,7 @@ for row in range(n):
for col in range(n):
plt.subplot(n, n, count)
count = count + 1

  •    image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))
    
  •    image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))
       plt.imshow(image)
       plt.axis("off")
    

plt.show()
@@ -286,7 +285,7 @@ for index, name in enumerate(shifted_images):
for col in range(n):
plt.subplot(n, n, count)
count = count + 1

  •        image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))
    
  •        image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))
           plt.imshow(image[..., 3 * index : 3 * index + 3])
           plt.axis("off")
    
    plt.show()
    @@ -308,7 +307,7 @@ class PatchEncoder(layers.Layer):
    self.position_embedding = layers.Embedding(
    input_dim=num_patches, output_dim=projection_dim
    )
  •    self.positions = tf.range(start=0, limit=self.num_patches, delta=1)
    
  •    self.positions = ops.arange(start=0, stop=self.num_patches, step=1)
    

    def call(self, encoded_patches):
    encoded_positions = self.position_embedding(self.positions)
    @@ -355,7 +354,7 @@ at a later stage.
    """

-class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
+class MultiHeadAttentionLSA(layers.MultiHeadAttention):
def init(self, **kwargs):
super().init(**kwargs)
# The trainable temperature term. The initial value is
@@ -363,13 +362,13 @@ class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)

 def _compute_attention(self, query, key, value, attention_mask=None, training=None):
  •    query = tf.multiply(query, 1.0 / self.tau)
    
  •    attention_scores = tf.einsum(self._dot_product_equation, key, query)
    
  •    query = ops.multiply(query, 1.0 / self.tau)
    
  •    attention_scores = ops.einsum(self._dot_product_equation, key, query)
       attention_scores = self._masked_softmax(attention_scores, attention_mask)
       attention_scores_dropout = self._dropout_layer(
           attention_scores, training=training
       )
    
  •    attention_output = tf.einsum(
    
  •    attention_output = ops.einsum(
           self._combine_equation, attention_scores_dropout, value
       )
       return attention_output, attention_scores
    

@@ -382,14 +381,14 @@ class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):

def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:

  •    x = layers.Dense(units, activation=tf.nn.gelu)(x)
    
  •    x = layers.Dense(units, activation="gelu")(x)
       x = layers.Dropout(dropout_rate)(x)
    
    return x

Build the diagonal attention mask

-diag_attn_mask = 1 - tf.eye(NUM_PATCHES)
-diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8)
+diag_attn_mask = 1 - ops.eye(NUM_PATCHES)
+diag_attn_mask = ops.cast([diag_attn_mask], dtype="int8")

"""

Build the ViT

@@ -463,9 +462,9 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
if self.total_steps < self.warmup_steps:
raise ValueError("Total_steps must be larger or equal to warmup_steps.")

  •    cos_annealed_lr = tf.cos(
    
  •    cos_annealed_lr = ops.cos(
           self.pi
    
  •        * (tf.cast(step, tf.float32) - self.warmup_steps)
    
  •        * (ops.cast(step, dtype="float32") - self.warmup_steps)
           / float(self.total_steps - self.warmup_steps)
       )
       learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
    

@@ -479,11 +478,13 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
slope = (
self.learning_rate_base - self.warmup_learning_rate
) / self.warmup_steps

  •        warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
    
  •        learning_rate = tf.where(
    
  •        warmup_rate = (
    
  •            slope * ops.cast(step, dtype="float32") + self.warmup_learning_rate
    
  •        )
    
  •        learning_rate = ops.where(
               step < self.warmup_steps, warmup_rate, learning_rate
           )
    
  •    return tf.where(
    
  •    return ops.where(
           step > self.total_steps, 0.0, learning_rate, name="learning_rate"
       )
    

@@ -499,7 +500,7 @@ def run_experiment(model):
warmup_steps=warmup_steps,
)

  • optimizer = tfa.optimizers.AdamW(
  • optimizer = keras.optimizers.AdamW(
    learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )

(END)

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! I think we could go further:

  • Use keras.ops.image.extract_patches instead of tf.image.extract_patches
  • Use keras.ops.array instead of tf.constant
  • Use keras.ops.image.crop_images instead of tf.image.crop_to_bounding_box
  • Use keras.ops.image.pad_images instead of tf.image.pad_to_bounding_box

After this I believe the example should be backend-agnostic.

@@ -355,21 +354,21 @@ def call(self, encoded_patches):
"""


class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
class MultiHeadAttentionLSA(layers.MultiHeadAttention):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# The trainable temperature term. The initial value is
# the square root of the key dimension.
self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please replace this with keras.Variable()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

@sitamgithub-MSIT
Copy link
Contributor Author

@fchollet All changes have been added, and it is completely backend agnostic now; tested with all three backends.

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks for the update! Please add the generated files.

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you! 👍

@fchollet fchollet merged commit 8a55f0a into keras-team:master Nov 27, 2024
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants