Migrating Train a Vision Transformer on small datasets example to Keras 3 #1991
Conversation
Thanks for the PR! I think we could go further (see the sketch after this list):

- Use `keras.ops.image.extract_patches` instead of `tf.image.extract_patches`
- Use `keras.ops.array` instead of `tf.constant`
- Use `keras.ops.image.crop_images` instead of `tf.image.crop_to_bounding_box`
- Use `keras.ops.image.pad_images` instead of `tf.image.pad_to_bounding_box`

After this I believe the example should be backend-agnostic.
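A minimal sketch of what those four substitutions look like in isolation. The shapes and values below are illustrative, not taken from the example:

```python
import keras
from keras import ops

images = keras.random.uniform((2, 72, 72, 3))  # hypothetical image batch

# keras.ops.image.extract_patches instead of tf.image.extract_patches
patches = ops.image.extract_patches(images, size=6, strides=6)

# keras.ops.array instead of tf.constant
shifts = ops.array([[1.0, 1.0], [-1.0, -1.0]])

# keras.ops.image.crop_images instead of tf.image.crop_to_bounding_box
cropped = ops.image.crop_images(
    images, top_cropping=2, left_cropping=2, target_height=70, target_width=70
)

# keras.ops.image.pad_images instead of tf.image.pad_to_bounding_box
padded = ops.image.pad_images(
    cropped, top_padding=2, left_padding=2, target_height=72, target_width=72
)
```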
examples/vision/vit_small_ds.py (Outdated)

@@ -355,21 +354,21 @@ def call(self, encoded_patches):
     """

-class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
+class MultiHeadAttentionLSA(layers.MultiHeadAttention):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         # The trainable temperature term. The initial value is
         # the square root of the key dimension.
         self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)
Please replace this with keras.Variable()
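For reference, a minimal self-contained sketch of the suggested change, mirroring the class and initial value shown in the diff above:

```python
import math
import keras
from keras import layers


class MultiHeadAttentionLSA(layers.MultiHeadAttention):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # keras.Variable works across backends, unlike tf.Variable.
        # Initial value: the square root of the key dimension.
        self.tau = keras.Variable(math.sqrt(float(self._key_dim)), trainable=True)
```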
Done!
@fchollet All changes have been added, and it is completely backend-agnostic now; tested with all three backends.
Looks great, thanks for the update! Please add the generated files.
LGTM, thank you! 👍
This PR migrates the Train a Vision Transformer on small datasets example to Keras 3.0 [TF-Only Example].
Here is the notebook link:
https://colab.research.google.com/drive/1ugp-3Zkkev9RNfuboWTFpS5202hhUcFv?usp=sharing
The following is the Git diff for the changed files:
"""
"""
Setup
"""
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
import math
import numpy as np
+import keras
+from keras import ops
+from keras import layers
import tensorflow as tf
-from tensorflow import keras
-import tensorflow_addons as tfa
import matplotlib.pyplot as plt
-from tensorflow.keras import layers
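One note on the new setup block: `KERAS_BACKEND` only takes effect if it is set before `keras` is first imported, which is why the `os.environ` line comes first. Since the example is now backend-agnostic, swapping the value is all it takes to run it elsewhere:

```python
import os

# Choose the backend before keras is imported; any of the three work once
# the example is backend-agnostic.
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow", "torch"
import keras
```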
Setting seed for reproducibility
SEED = 42
@@ -216,7 +214,7 @@ class ShiftedPatchTokenization(layers.Layer):
def call(self, images):
if not self.vanilla:
# Concat the shifted images with the original image
@@ -252,8 +250,9 @@ class ShiftedPatchTokenization(layers.Layer):
Get a random image from the training dataset
and resize the image
image = x_train[np.random.choice(range(x_train.shape[0]))]
-resized_image = tf.image.resize(
+resized_image = ops.cast(
)
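The resize hunk is truncated above. A plausible backend-agnostic version, assuming it follows the usual pattern of `ops.image.resize` wrapped in an `ops.cast` (names and sizes illustrative):

```python
import numpy as np
from keras import ops

IMAGE_SIZE = 72  # illustrative target size
image = np.random.rand(32, 32, 3).astype("float32")  # stand-in for an x_train image

resized_image = ops.cast(
    ops.image.resize(ops.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)),
    dtype="float32",
)
```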
Vanilla patch maker: this takes an image and divides it into patches
@@ -267,7 +266,7 @@ for row in range(n):
for col in range(n):
plt.subplot(n, n, count)
count = count + 1
plt.show()
@@ -286,7 +285,7 @@ for index, name in enumerate(shifted_images):
for col in range(n):
plt.subplot(n, n, count)
count = count + 1
@@ -308,7 +307,7 @@ class PatchEncoder(layers.Layer):
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, encoded_patches):
encoded_positions = self.position_embedding(self.positions)
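The `self.positions` tensor that feeds `position_embedding` is elided in this hunk; one backend-agnostic way to build it (a sketch, not necessarily the PR's exact code):

```python
from keras import ops

num_patches = 144  # illustrative
# One learned position index per patch, with a leading batch dimension.
positions = ops.expand_dims(ops.arange(start=0, stop=num_patches, step=1), axis=0)
```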
@@ -355,7 +354,7 @@ at a later stage.
"""
-class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
+class MultiHeadAttentionLSA(layers.MultiHeadAttention):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
# The trainable temperature term. The initial value is
@@ -363,13 +362,13 @@ class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)
@@ -382,14 +381,14 @@ class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
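The body of `mlp` is elided above; a typical ViT feed-forward block matching that signature looks like this (a sketch under that assumption):

```python
import keras
from keras import layers


def mlp(x, hidden_units, dropout_rate):
    # Dense -> GELU -> Dropout for each hidden width in turn.
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
```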
Build the diagonal attention mask
-diag_attn_mask = 1 - tf.eye(NUM_PATCHES)
-diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8)
+diag_attn_mask = 1 - ops.eye(NUM_PATCHES)
+diag_attn_mask = ops.cast([diag_attn_mask], dtype="int8")
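To see what that change produces: `ops.eye` builds the identity matrix, so subtracting it from 1 zeroes the diagonal, which stops each patch from attending to itself (illustrative size):

```python
from keras import ops

NUM_PATCHES = 4  # illustrative; the example derives this from the image size
diag_attn_mask = 1 - ops.eye(NUM_PATCHES)                  # zeros on the diagonal
diag_attn_mask = ops.cast([diag_attn_mask], dtype="int8")  # shape (1, 4, 4)
```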
"""
Build the ViT
@@ -463,9 +462,9 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
if self.total_steps < self.warmup_steps:
raise ValueError("Total_steps must be larger or equal to warmup_steps.")
@@ -479,11 +478,13 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
slope = (
self.learning_rate_base - self.warmup_learning_rate
) / self.warmup_steps
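The warmup arithmetic above is plain linear interpolation: the rate climbs from `warmup_learning_rate` to `learning_rate_base` over `warmup_steps`. A standalone sketch with illustrative values, not the example's constants:

```python
warmup_learning_rate = 0.0  # illustrative
learning_rate_base = 1e-3   # illustrative
warmup_steps = 100          # illustrative

slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
step = 50
warmup_rate = slope * step + warmup_learning_rate  # 5e-4 halfway through warmup
```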
@@ -499,7 +500,7 @@ def run_experiment(model):
warmup_steps=warmup_steps,
)
learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
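With `tensorflow_addons` dropped from the imports, the `AdamW` used here presumably comes from Keras itself. A sketch of that optimizer construction, with illustrative values for the elided constants:

```python
import keras

LEARNING_RATE = 0.001  # illustrative
WEIGHT_DECAY = 0.0001  # illustrative

optimizer = keras.optimizers.AdamW(
    learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
```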