Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrating Train a Vision Transformer on small datasets example to Keras 3 #1991

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 37 additions & 40 deletions examples/vision/vit_small_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,19 @@
example is inspired from
[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).

_Note_: This example requires TensorFlow 2.6 or higher, as well as
[TensorFlow Addons](https://www.tensorflow.org/addons), which can be
installed using the following command:

```python
pip install -qq -U tensorflow-addons
_Note_: This example requires TensorFlow 2.6 or higher.
```
"""
"""
## Setup
"""

import math
import numpy as np
import keras
from keras import ops
from keras import layers
import tensorflow as tf
from tensorflow import keras
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from tensorflow.keras import layers

# Setting seed for reproducibiltiy
SEED = 42
Expand Down Expand Up @@ -197,17 +191,17 @@ def crop_shift_pad(self, images, mode):
shift_width = self.half_patch

# Crop the shifted images and pad them
crop = tf.image.crop_to_bounding_box(
crop = ops.image.crop_images(
images,
offset_height=crop_height,
offset_width=crop_width,
top_cropping=crop_height,
left_cropping=crop_width,
target_height=self.image_size - self.half_patch,
target_width=self.image_size - self.half_patch,
)
shift_pad = tf.image.pad_to_bounding_box(
shift_pad = ops.image.pad_images(
crop,
offset_height=shift_height,
offset_width=shift_width,
top_padding=shift_height,
left_padding=shift_width,
target_height=self.image_size,
target_width=self.image_size,
)
Expand All @@ -216,7 +210,7 @@ def crop_shift_pad(self, images, mode):
def call(self, images):
if not self.vanilla:
# Concat the shifted images with the original image
images = tf.concat(
images = ops.concatenate(
[
images,
self.crop_shift_pad(images, mode="left-up"),
Expand All @@ -227,11 +221,11 @@ def call(self, images):
axis=-1,
)
# Patchify the images and flatten it
patches = tf.image.extract_patches(
patches = ops.image.extract_patches(
images=images,
sizes=[1, self.patch_size, self.patch_size, 1],
size=(self.patch_size, self.patch_size),
strides=[1, self.patch_size, self.patch_size, 1],
rates=[1, 1, 1, 1],
dilation_rate=1,
padding="VALID",
)
flat_patches = self.flatten_patches(patches)
Expand All @@ -252,8 +246,9 @@ def call(self, images):
# Get a random image from the training dataset
# and resize the image
image = x_train[np.random.choice(range(x_train.shape[0]))]
resized_image = tf.image.resize(
tf.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)
resized_image = ops.cast(
ops.image.resize(ops.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)),
dtype="float32",
)

# Vanilla patch maker: This takes an image and divides into
Expand All @@ -267,7 +262,7 @@ def call(self, images):
for col in range(n):
plt.subplot(n, n, count)
count = count + 1
image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))
image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))
plt.imshow(image)
plt.axis("off")
plt.show()
Expand All @@ -286,7 +281,7 @@ def call(self, images):
for col in range(n):
plt.subplot(n, n, count)
count = count + 1
image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))
image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))
plt.imshow(image[..., 3 * index : 3 * index + 3])
plt.axis("off")
plt.show()
Expand All @@ -308,7 +303,7 @@ def __init__(
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
self.positions = tf.range(start=0, limit=self.num_patches, delta=1)
self.positions = ops.arange(start=0, stop=self.num_patches, step=1)

def call(self, encoded_patches):
encoded_positions = self.position_embedding(self.positions)
Expand Down Expand Up @@ -355,21 +350,21 @@ def call(self, encoded_patches):
"""


class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
class MultiHeadAttentionLSA(layers.MultiHeadAttention):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# The trainable temperature term. The initial value is
# the square root of the key dimension.
self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)
self.tau = keras.Variable(math.sqrt(float(self._key_dim)), trainable=True)

def _compute_attention(self, query, key, value, attention_mask=None, training=None):
query = tf.multiply(query, 1.0 / self.tau)
attention_scores = tf.einsum(self._dot_product_equation, key, query)
query = ops.multiply(query, 1.0 / self.tau)
attention_scores = ops.einsum(self._dot_product_equation, key, query)
attention_scores = self._masked_softmax(attention_scores, attention_mask)
attention_scores_dropout = self._dropout_layer(
attention_scores, training=training
)
attention_output = tf.einsum(
attention_output = ops.einsum(
self._combine_equation, attention_scores_dropout, value
)
return attention_output, attention_scores
Expand All @@ -382,14 +377,14 @@ def _compute_attention(self, query, key, value, attention_mask=None, training=No

def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
x = layers.Dense(units, activation=tf.nn.gelu)(x)
x = layers.Dense(units, activation="gelu")(x)
x = layers.Dropout(dropout_rate)(x)
return x


# Build the diagonal attention mask
diag_attn_mask = 1 - tf.eye(NUM_PATCHES)
diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8)
diag_attn_mask = 1 - ops.eye(NUM_PATCHES)
diag_attn_mask = ops.cast([diag_attn_mask], dtype="int8")

"""
## Build the ViT
Expand Down Expand Up @@ -457,15 +452,15 @@ def __init__(
self.total_steps = total_steps
self.warmup_learning_rate = warmup_learning_rate
self.warmup_steps = warmup_steps
self.pi = tf.constant(np.pi)
self.pi = ops.array(np.pi)

def __call__(self, step):
if self.total_steps < self.warmup_steps:
raise ValueError("Total_steps must be larger or equal to warmup_steps.")

cos_annealed_lr = tf.cos(
cos_annealed_lr = ops.cos(
self.pi
* (tf.cast(step, tf.float32) - self.warmup_steps)
* (ops.cast(step, dtype="float32") - self.warmup_steps)
/ float(self.total_steps - self.warmup_steps)
)
learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
Expand All @@ -479,11 +474,13 @@ def __call__(self, step):
slope = (
self.learning_rate_base - self.warmup_learning_rate
) / self.warmup_steps
warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
learning_rate = tf.where(
warmup_rate = (
slope * ops.cast(step, dtype="float32") + self.warmup_learning_rate
)
learning_rate = ops.where(
step < self.warmup_steps, warmup_rate, learning_rate
)
return tf.where(
return ops.where(
step > self.total_steps, 0.0, learning_rate, name="learning_rate"
)

Expand All @@ -499,7 +496,7 @@ def run_experiment(model):
warmup_steps=warmup_steps,
)

optimizer = tfa.optimizers.AdamW(
optimizer = keras.optimizers.AdamW(
learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

Expand Down