From 8a55f0a98f1e67af09cfdbdc73f552157180a881 Mon Sep 17 00:00:00 2001
From: Sitam Meur <103279526+sitamgithub-MSIT@users.noreply.github.com>
Date: Wed, 27 Nov 2024 15:39:37 +0530
Subject: [PATCH] Migrating Train a Vision Transformer on small datasets
example to Keras 3 (#1991)
* migrate vit small dataset example to keras3
* requested changes added and backend agnostic done
* other generated files are added too
---
examples/vision/ipynb/vit_small_ds.ipynb | 95 +++++++++++-------------
examples/vision/md/vit_small_ds.md | 79 ++++++++++----------
examples/vision/vit_small_ds.py | 80 ++++++++++----------
3 files changed, 121 insertions(+), 133 deletions(-)
diff --git a/examples/vision/ipynb/vit_small_ds.ipynb b/examples/vision/ipynb/vit_small_ds.ipynb
index 7f51e46813..2d89b52fbd 100644
--- a/examples/vision/ipynb/vit_small_ds.ipynb
+++ b/examples/vision/ipynb/vit_small_ds.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498)
\n",
"**Date created:** 2022/01/07
\n",
- "**Last modified:** 2022/01/10
\n",
+ "**Last modified:** 2024/11/27
\n",
"**Description:** Training a ViT from scratch on smaller datasets with shifted patch tokenization and locality self-attention."
]
},
@@ -47,13 +47,7 @@
"example is inspired from\n",
"[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).\n",
"\n",
- "_Note_: This example requires TensorFlow 2.6 or higher, as well as\n",
- "[TensorFlow Addons](https://www.tensorflow.org/addons), which can be\n",
- "installed using the following command:\n",
- "\n",
- "```python\n",
- "pip install -qq -U tensorflow-addons\n",
- "```"
+ "_Note_: This example requires TensorFlow 2.6 or higher."
]
},
{
@@ -75,11 +69,11 @@
"source": [
"import math\n",
"import numpy as np\n",
+ "import keras\n",
+ "from keras import ops\n",
+ "from keras import layers\n",
"import tensorflow as tf\n",
- "from tensorflow import keras\n",
- "import tensorflow_addons as tfa\n",
"import matplotlib.pyplot as plt\n",
- "from tensorflow.keras import layers\n",
"\n",
"# Setting seed for reproducibiltiy\n",
"SEED = 42\n",
@@ -279,17 +273,17 @@
" shift_width = self.half_patch\n",
"\n",
" # Crop the shifted images and pad them\n",
- " crop = tf.image.crop_to_bounding_box(\n",
+ " crop = ops.image.crop_images(\n",
" images,\n",
- " offset_height=crop_height,\n",
- " offset_width=crop_width,\n",
+ " top_cropping=crop_height,\n",
+ " left_cropping=crop_width,\n",
" target_height=self.image_size - self.half_patch,\n",
" target_width=self.image_size - self.half_patch,\n",
" )\n",
- " shift_pad = tf.image.pad_to_bounding_box(\n",
+ " shift_pad = ops.image.pad_images(\n",
" crop,\n",
- " offset_height=shift_height,\n",
- " offset_width=shift_width,\n",
+ " top_padding=shift_height,\n",
+ " left_padding=shift_width,\n",
" target_height=self.image_size,\n",
" target_width=self.image_size,\n",
" )\n",
@@ -298,7 +292,7 @@
" def call(self, images):\n",
" if not self.vanilla:\n",
" # Concat the shifted images with the original image\n",
- " images = tf.concat(\n",
+ " images = ops.concatenate(\n",
" [\n",
" images,\n",
" self.crop_shift_pad(images, mode=\"left-up\"),\n",
@@ -309,11 +303,11 @@
" axis=-1,\n",
" )\n",
" # Patchify the images and flatten it\n",
- " patches = tf.image.extract_patches(\n",
+ " patches = ops.image.extract_patches(\n",
" images=images,\n",
- " sizes=[1, self.patch_size, self.patch_size, 1],\n",
+ " size=(self.patch_size, self.patch_size),\n",
" strides=[1, self.patch_size, self.patch_size, 1],\n",
- " rates=[1, 1, 1, 1],\n",
+ " dilation_rate=1,\n",
" padding=\"VALID\",\n",
" )\n",
" flat_patches = self.flatten_patches(patches)\n",
@@ -324,8 +318,7 @@
" else:\n",
" # Linearly project the flat patches\n",
" tokens = self.projection(flat_patches)\n",
- " return (tokens, patches)\n",
- ""
+ " return (tokens, patches)\n"
]
},
{
@@ -348,8 +341,9 @@
"# Get a random image from the training dataset\n",
"# and resize the image\n",
"image = x_train[np.random.choice(range(x_train.shape[0]))]\n",
- "resized_image = tf.image.resize(\n",
- " tf.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)\n",
+ "resized_image = ops.cast(\n",
+ " ops.image.resize(ops.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)),\n",
+ " dtype=\"float32\",\n",
")\n",
"\n",
"# Vanilla patch maker: This takes an image and divides into\n",
@@ -363,7 +357,7 @@
" for col in range(n):\n",
" plt.subplot(n, n, count)\n",
" count = count + 1\n",
- " image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))\n",
+ " image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))\n",
" plt.imshow(image)\n",
" plt.axis(\"off\")\n",
"plt.show()\n",
@@ -382,7 +376,7 @@
" for col in range(n):\n",
" plt.subplot(n, n, count)\n",
" count = count + 1\n",
- " image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))\n",
+ " image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))\n",
" plt.imshow(image[..., 3 * index : 3 * index + 3])\n",
" plt.axis(\"off\")\n",
" plt.show()"
@@ -418,13 +412,12 @@
" self.position_embedding = layers.Embedding(\n",
" input_dim=num_patches, output_dim=projection_dim\n",
" )\n",
- " self.positions = tf.range(start=0, limit=self.num_patches, delta=1)\n",
+ " self.positions = ops.arange(start=0, stop=self.num_patches, step=1)\n",
"\n",
" def call(self, encoded_patches):\n",
" encoded_positions = self.position_embedding(self.positions)\n",
" encoded_patches = encoded_patches + encoded_positions\n",
- " return encoded_patches\n",
- ""
+ " return encoded_patches\n"
]
},
{
@@ -479,25 +472,24 @@
"outputs": [],
"source": [
"\n",
- "class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):\n",
+ "class MultiHeadAttentionLSA(layers.MultiHeadAttention):\n",
" def __init__(self, **kwargs):\n",
" super().__init__(**kwargs)\n",
" # The trainable temperature term. The initial value is\n",
" # the square root of the key dimension.\n",
- " self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)\n",
+ " self.tau = keras.Variable(math.sqrt(float(self._key_dim)), trainable=True)\n",
"\n",
" def _compute_attention(self, query, key, value, attention_mask=None, training=None):\n",
- " query = tf.multiply(query, 1.0 / self.tau)\n",
- " attention_scores = tf.einsum(self._dot_product_equation, key, query)\n",
+ " query = ops.multiply(query, 1.0 / self.tau)\n",
+ " attention_scores = ops.einsum(self._dot_product_equation, key, query)\n",
" attention_scores = self._masked_softmax(attention_scores, attention_mask)\n",
" attention_scores_dropout = self._dropout_layer(\n",
" attention_scores, training=training\n",
" )\n",
- " attention_output = tf.einsum(\n",
+ " attention_output = ops.einsum(\n",
" self._combine_equation, attention_scores_dropout, value\n",
" )\n",
- " return attention_output, attention_scores\n",
- ""
+ " return attention_output, attention_scores\n"
]
},
{
@@ -520,14 +512,14 @@
"\n",
"def mlp(x, hidden_units, dropout_rate):\n",
" for units in hidden_units:\n",
- " x = layers.Dense(units, activation=tf.nn.gelu)(x)\n",
+ " x = layers.Dense(units, activation=\"gelu\")(x)\n",
" x = layers.Dropout(dropout_rate)(x)\n",
" return x\n",
"\n",
"\n",
"# Build the diagonal attention mask\n",
- "diag_attn_mask = 1 - tf.eye(NUM_PATCHES)\n",
- "diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8)"
+ "diag_attn_mask = 1 - ops.eye(NUM_PATCHES)\n",
+ "diag_attn_mask = ops.cast([diag_attn_mask], dtype=\"int8\")"
]
},
{
@@ -589,8 +581,7 @@
" logits = layers.Dense(NUM_CLASSES)(features)\n",
" # Create the Keras model.\n",
" model = keras.Model(inputs=inputs, outputs=logits)\n",
- " return model\n",
- ""
+ " return model\n"
]
},
{
@@ -622,15 +613,15 @@
" self.total_steps = total_steps\n",
" self.warmup_learning_rate = warmup_learning_rate\n",
" self.warmup_steps = warmup_steps\n",
- " self.pi = tf.constant(np.pi)\n",
+ " self.pi = ops.array(np.pi)\n",
"\n",
" def __call__(self, step):\n",
" if self.total_steps < self.warmup_steps:\n",
" raise ValueError(\"Total_steps must be larger or equal to warmup_steps.\")\n",
"\n",
- " cos_annealed_lr = tf.cos(\n",
+ " cos_annealed_lr = ops.cos(\n",
" self.pi\n",
- " * (tf.cast(step, tf.float32) - self.warmup_steps)\n",
+ " * (ops.cast(step, dtype=\"float32\") - self.warmup_steps)\n",
" / float(self.total_steps - self.warmup_steps)\n",
" )\n",
" learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)\n",
@@ -644,11 +635,13 @@
" slope = (\n",
" self.learning_rate_base - self.warmup_learning_rate\n",
" ) / self.warmup_steps\n",
- " warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate\n",
- " learning_rate = tf.where(\n",
+ " warmup_rate = (\n",
+ " slope * ops.cast(step, dtype=\"float32\") + self.warmup_learning_rate\n",
+ " )\n",
+ " learning_rate = ops.where(\n",
" step < self.warmup_steps, warmup_rate, learning_rate\n",
" )\n",
- " return tf.where(\n",
+ " return ops.where(\n",
" step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n",
" )\n",
"\n",
@@ -664,7 +657,7 @@
" warmup_steps=warmup_steps,\n",
" )\n",
"\n",
- " optimizer = tfa.optimizers.AdamW(\n",
+ " optimizer = keras.optimizers.AdamW(\n",
" learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY\n",
" )\n",
"\n",
@@ -720,7 +713,7 @@
"I would like to thank [Jarvislabs.ai](https://jarvislabs.ai/) for\n",
"generously helping with GPU credits.\n",
"\n",
- "You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/vit_small_ds_v2) ",
+ "You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/vit_small_ds_v2) \n",
"and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/vit-small-ds)."
]
}
@@ -754,4 +747,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
-}
\ No newline at end of file
+}
diff --git a/examples/vision/md/vit_small_ds.md b/examples/vision/md/vit_small_ds.md
index 8a7f824b78..d916f6728a 100644
--- a/examples/vision/md/vit_small_ds.md
+++ b/examples/vision/md/vit_small_ds.md
@@ -2,7 +2,7 @@
**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498)
**Date created:** 2022/01/07
-**Last modified:** 2022/01/10
+**Last modified:** 2024/11/27
**Description:** Training a ViT from scratch on smaller datasets with shifted patch tokenization and locality self-attention.
@@ -38,13 +38,7 @@ This example implements the ideas of the paper. A large part of this
example is inspired from
[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).
-_Note_: This example requires TensorFlow 2.6 or higher, as well as
-[TensorFlow Addons](https://www.tensorflow.org/addons), which can be
-installed using the following command:
-
-```python
-pip install -qq -U tensorflow-addons
-```
+_Note_: This example requires TensorFlow 2.6 or higher.
---
## Setup
@@ -53,11 +47,11 @@ pip install -qq -U tensorflow-addons
```python
import math
import numpy as np
+import keras
+from keras import ops
+from keras import layers
import tensorflow as tf
-from tensorflow import keras
-import tensorflow_addons as tfa
import matplotlib.pyplot as plt
-from tensorflow.keras import layers
# Setting seed for reproducibiltiy
SEED = 42
@@ -221,17 +215,17 @@ class ShiftedPatchTokenization(layers.Layer):
shift_width = self.half_patch
# Crop the shifted images and pad them
- crop = tf.image.crop_to_bounding_box(
+ crop = ops.image.crop_images(
images,
- offset_height=crop_height,
- offset_width=crop_width,
+ top_cropping=crop_height,
+ left_cropping=crop_width,
target_height=self.image_size - self.half_patch,
target_width=self.image_size - self.half_patch,
)
- shift_pad = tf.image.pad_to_bounding_box(
+ shift_pad = ops.image.pad_images(
crop,
- offset_height=shift_height,
- offset_width=shift_width,
+ top_padding=shift_height,
+ left_padding=shift_width,
target_height=self.image_size,
target_width=self.image_size,
)
@@ -240,7 +234,7 @@ class ShiftedPatchTokenization(layers.Layer):
def call(self, images):
if not self.vanilla:
# Concat the shifted images with the original image
- images = tf.concat(
+ images = ops.concatenate(
[
images,
self.crop_shift_pad(images, mode="left-up"),
@@ -251,11 +245,11 @@ class ShiftedPatchTokenization(layers.Layer):
axis=-1,
)
# Patchify the images and flatten it
- patches = tf.image.extract_patches(
+ patches = ops.image.extract_patches(
images=images,
- sizes=[1, self.patch_size, self.patch_size, 1],
+ size=(self.patch_size, self.patch_size),
strides=[1, self.patch_size, self.patch_size, 1],
- rates=[1, 1, 1, 1],
+ dilation_rate=1,
padding="VALID",
)
flat_patches = self.flatten_patches(patches)
@@ -277,8 +271,9 @@ class ShiftedPatchTokenization(layers.Layer):
# Get a random image from the training dataset
# and resize the image
image = x_train[np.random.choice(range(x_train.shape[0]))]
-resized_image = tf.image.resize(
- tf.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)
+resized_image = ops.cast(
+ ops.image.resize(ops.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)),
+ dtype="float32",
)
# Vanilla patch maker: This takes an image and divides into
@@ -292,7 +287,7 @@ for row in range(n):
for col in range(n):
plt.subplot(n, n, count)
count = count + 1
- image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))
+ image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))
plt.imshow(image)
plt.axis("off")
plt.show()
@@ -311,7 +306,7 @@ for index, name in enumerate(shifted_images):
for col in range(n):
plt.subplot(n, n, count)
count = count + 1
- image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))
+ image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))
plt.imshow(image[..., 3 * index : 3 * index + 3])
plt.axis("off")
plt.show()
@@ -401,7 +396,7 @@ class PatchEncoder(layers.Layer):
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
- self.positions = tf.range(start=0, limit=self.num_patches, delta=1)
+ self.positions = ops.arange(start=0, stop=self.num_patches, step=1)
def call(self, encoded_patches):
encoded_positions = self.position_embedding(self.positions)
@@ -450,21 +445,21 @@ at a later stage.
```python
-class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
+class MultiHeadAttentionLSA(layers.MultiHeadAttention):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# The trainable temperature term. The initial value is
# the square root of the key dimension.
- self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)
+ self.tau = keras.Variable(math.sqrt(float(self._key_dim)), trainable=True)
def _compute_attention(self, query, key, value, attention_mask=None, training=None):
- query = tf.multiply(query, 1.0 / self.tau)
- attention_scores = tf.einsum(self._dot_product_equation, key, query)
+ query = ops.multiply(query, 1.0 / self.tau)
+ attention_scores = ops.einsum(self._dot_product_equation, key, query)
attention_scores = self._masked_softmax(attention_scores, attention_mask)
attention_scores_dropout = self._dropout_layer(
attention_scores, training=training
)
- attention_output = tf.einsum(
+ attention_output = ops.einsum(
self._combine_equation, attention_scores_dropout, value
)
return attention_output, attention_scores
@@ -479,14 +474,14 @@ class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
- x = layers.Dense(units, activation=tf.nn.gelu)(x)
+ x = layers.Dense(units, activation="gelu")(x)
x = layers.Dropout(dropout_rate)(x)
return x
# Build the diagonal attention mask
-diag_attn_mask = 1 - tf.eye(NUM_PATCHES)
-diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8)
+diag_attn_mask = 1 - ops.eye(NUM_PATCHES)
+diag_attn_mask = ops.cast([diag_attn_mask], dtype="int8")
```
---
@@ -557,15 +552,15 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
self.total_steps = total_steps
self.warmup_learning_rate = warmup_learning_rate
self.warmup_steps = warmup_steps
- self.pi = tf.constant(np.pi)
+ self.pi = ops.array(np.pi)
def __call__(self, step):
if self.total_steps < self.warmup_steps:
raise ValueError("Total_steps must be larger or equal to warmup_steps.")
- cos_annealed_lr = tf.cos(
+ cos_annealed_lr = ops.cos(
self.pi
- * (tf.cast(step, tf.float32) - self.warmup_steps)
+ * (ops.cast(step, dtype="float32") - self.warmup_steps)
/ float(self.total_steps - self.warmup_steps)
)
learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
@@ -579,11 +574,13 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
slope = (
self.learning_rate_base - self.warmup_learning_rate
) / self.warmup_steps
- warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
- learning_rate = tf.where(
+ warmup_rate = (
+ slope * ops.cast(step, dtype="float32") + self.warmup_learning_rate
+ )
+ learning_rate = ops.where(
step < self.warmup_steps, warmup_rate, learning_rate
)
- return tf.where(
+ return ops.where(
step > self.total_steps, 0.0, learning_rate, name="learning_rate"
)
@@ -599,7 +596,7 @@ def run_experiment(model):
warmup_steps=warmup_steps,
)
- optimizer = tfa.optimizers.AdamW(
+ optimizer = keras.optimizers.AdamW(
learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
diff --git a/examples/vision/vit_small_ds.py b/examples/vision/vit_small_ds.py
index 40ef2c526c..525967fca6 100644
--- a/examples/vision/vit_small_ds.py
+++ b/examples/vision/vit_small_ds.py
@@ -2,9 +2,10 @@
Title: Train a Vision Transformer on small datasets
Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498)
Date created: 2022/01/07
-Last modified: 2022/01/10
+Last modified: 2024/11/27
Description: Training a ViT from scratch on smaller datasets with shifted patch tokenization and locality self-attention.
Accelerator: GPU
+Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
"""
"""
@@ -35,25 +36,19 @@
example is inspired from
[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).
-_Note_: This example requires TensorFlow 2.6 or higher, as well as
-[TensorFlow Addons](https://www.tensorflow.org/addons), which can be
-installed using the following command:
-
-```python
-pip install -qq -U tensorflow-addons
+_Note_: This example requires TensorFlow 2.6 or higher.
```
"""
"""
## Setup
"""
-
import math
import numpy as np
+import keras
+from keras import ops
+from keras import layers
import tensorflow as tf
-from tensorflow import keras
-import tensorflow_addons as tfa
import matplotlib.pyplot as plt
-from tensorflow.keras import layers
# Setting seed for reproducibiltiy
SEED = 42
@@ -197,17 +192,17 @@ def crop_shift_pad(self, images, mode):
shift_width = self.half_patch
# Crop the shifted images and pad them
- crop = tf.image.crop_to_bounding_box(
+ crop = ops.image.crop_images(
images,
- offset_height=crop_height,
- offset_width=crop_width,
+ top_cropping=crop_height,
+ left_cropping=crop_width,
target_height=self.image_size - self.half_patch,
target_width=self.image_size - self.half_patch,
)
- shift_pad = tf.image.pad_to_bounding_box(
+ shift_pad = ops.image.pad_images(
crop,
- offset_height=shift_height,
- offset_width=shift_width,
+ top_padding=shift_height,
+ left_padding=shift_width,
target_height=self.image_size,
target_width=self.image_size,
)
@@ -216,7 +211,7 @@ def crop_shift_pad(self, images, mode):
def call(self, images):
if not self.vanilla:
# Concat the shifted images with the original image
- images = tf.concat(
+ images = ops.concatenate(
[
images,
self.crop_shift_pad(images, mode="left-up"),
@@ -227,11 +222,11 @@ def call(self, images):
axis=-1,
)
# Patchify the images and flatten it
- patches = tf.image.extract_patches(
+ patches = ops.image.extract_patches(
images=images,
- sizes=[1, self.patch_size, self.patch_size, 1],
+ size=(self.patch_size, self.patch_size),
strides=[1, self.patch_size, self.patch_size, 1],
- rates=[1, 1, 1, 1],
+ dilation_rate=1,
padding="VALID",
)
flat_patches = self.flatten_patches(patches)
@@ -252,8 +247,9 @@ def call(self, images):
# Get a random image from the training dataset
# and resize the image
image = x_train[np.random.choice(range(x_train.shape[0]))]
-resized_image = tf.image.resize(
- tf.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)
+resized_image = ops.cast(
+ ops.image.resize(ops.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)),
+ dtype="float32",
)
# Vanilla patch maker: This takes an image and divides into
@@ -267,7 +263,7 @@ def call(self, images):
for col in range(n):
plt.subplot(n, n, count)
count = count + 1
- image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))
+ image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))
plt.imshow(image)
plt.axis("off")
plt.show()
@@ -286,7 +282,7 @@ def call(self, images):
for col in range(n):
plt.subplot(n, n, count)
count = count + 1
- image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))
+ image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))
plt.imshow(image[..., 3 * index : 3 * index + 3])
plt.axis("off")
plt.show()
@@ -308,7 +304,7 @@ def __init__(
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
- self.positions = tf.range(start=0, limit=self.num_patches, delta=1)
+ self.positions = ops.arange(start=0, stop=self.num_patches, step=1)
def call(self, encoded_patches):
encoded_positions = self.position_embedding(self.positions)
@@ -355,21 +351,21 @@ def call(self, encoded_patches):
"""
-class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
+class MultiHeadAttentionLSA(layers.MultiHeadAttention):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# The trainable temperature term. The initial value is
# the square root of the key dimension.
- self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)
+ self.tau = keras.Variable(math.sqrt(float(self._key_dim)), trainable=True)
def _compute_attention(self, query, key, value, attention_mask=None, training=None):
- query = tf.multiply(query, 1.0 / self.tau)
- attention_scores = tf.einsum(self._dot_product_equation, key, query)
+ query = ops.multiply(query, 1.0 / self.tau)
+ attention_scores = ops.einsum(self._dot_product_equation, key, query)
attention_scores = self._masked_softmax(attention_scores, attention_mask)
attention_scores_dropout = self._dropout_layer(
attention_scores, training=training
)
- attention_output = tf.einsum(
+ attention_output = ops.einsum(
self._combine_equation, attention_scores_dropout, value
)
return attention_output, attention_scores
@@ -382,14 +378,14 @@ def _compute_attention(self, query, key, value, attention_mask=None, training=No
def mlp(x, hidden_units, dropout_rate):
for units in hidden_units:
- x = layers.Dense(units, activation=tf.nn.gelu)(x)
+ x = layers.Dense(units, activation="gelu")(x)
x = layers.Dropout(dropout_rate)(x)
return x
# Build the diagonal attention mask
-diag_attn_mask = 1 - tf.eye(NUM_PATCHES)
-diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8)
+diag_attn_mask = 1 - ops.eye(NUM_PATCHES)
+diag_attn_mask = ops.cast([diag_attn_mask], dtype="int8")
"""
## Build the ViT
@@ -457,15 +453,15 @@ def __init__(
self.total_steps = total_steps
self.warmup_learning_rate = warmup_learning_rate
self.warmup_steps = warmup_steps
- self.pi = tf.constant(np.pi)
+ self.pi = ops.array(np.pi)
def __call__(self, step):
if self.total_steps < self.warmup_steps:
raise ValueError("Total_steps must be larger or equal to warmup_steps.")
- cos_annealed_lr = tf.cos(
+ cos_annealed_lr = ops.cos(
self.pi
- * (tf.cast(step, tf.float32) - self.warmup_steps)
+ * (ops.cast(step, dtype="float32") - self.warmup_steps)
/ float(self.total_steps - self.warmup_steps)
)
learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
@@ -479,11 +475,13 @@ def __call__(self, step):
slope = (
self.learning_rate_base - self.warmup_learning_rate
) / self.warmup_steps
- warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
- learning_rate = tf.where(
+ warmup_rate = (
+ slope * ops.cast(step, dtype="float32") + self.warmup_learning_rate
+ )
+ learning_rate = ops.where(
step < self.warmup_steps, warmup_rate, learning_rate
)
- return tf.where(
+ return ops.where(
step > self.total_steps, 0.0, learning_rate, name="learning_rate"
)
@@ -499,7 +497,7 @@ def run_experiment(model):
warmup_steps=warmup_steps,
)
- optimizer = tfa.optimizers.AdamW(
+ optimizer = keras.optimizers.AdamW(
learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)