Migrating "Train a Vision Transformer on small datasets" example to Keras 3 #1991

Merged · 3 commits · Nov 27, 2024
95 changes: 44 additions & 51 deletions examples/vision/ipynb/vit_small_ds.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498)<br>\n",
"**Date created:** 2022/01/07<br>\n",
"**Last modified:** 2022/01/10<br>\n",
"**Last modified:** 2024/11/27<br>\n",
"**Description:** Training a ViT from scratch on smaller datasets with shifted patch tokenization and locality self-attention."
]
},
@@ -47,13 +47,7 @@
"example is inspired from\n",
"[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).\n",
"\n",
"_Note_: This example requires TensorFlow 2.6 or higher, as well as\n",
"[TensorFlow Addons](https://www.tensorflow.org/addons), which can be\n",
"installed using the following command:\n",
"\n",
"```python\n",
"pip install -qq -U tensorflow-addons\n",
"```"
"_Note_: This example requires TensorFlow 2.6 or higher."
]
},
{
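TensorFlow Addons is no longer needed because Keras 3 ships `AdamW` natively (the actual swap appears in the optimizer hunk further down). A minimal sketch of the replacement; the hyperparameter values here are placeholders:

```python
import keras

# Keras 3 equivalent of tfa.optimizers.AdamW; values are placeholders.
optimizer = keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4)
```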
@@ -75,11 +69,11 @@
"source": [
"import math\n",
"import numpy as np\n",
"import keras\n",
"from keras import ops\n",
"from keras import layers\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"import tensorflow_addons as tfa\n",
"import matplotlib.pyplot as plt\n",
"from tensorflow.keras import layers\n",
"\n",
"# Setting seed for reproducibiltiy\n",
"SEED = 42\n",
@@ -279,17 +273,17 @@
" shift_width = self.half_patch\n",
"\n",
" # Crop the shifted images and pad them\n",
" crop = tf.image.crop_to_bounding_box(\n",
" crop = ops.image.crop_images(\n",
" images,\n",
" offset_height=crop_height,\n",
" offset_width=crop_width,\n",
" top_cropping=crop_height,\n",
" left_cropping=crop_width,\n",
" target_height=self.image_size - self.half_patch,\n",
" target_width=self.image_size - self.half_patch,\n",
" )\n",
" shift_pad = tf.image.pad_to_bounding_box(\n",
" shift_pad = ops.image.pad_images(\n",
" crop,\n",
" offset_height=shift_height,\n",
" offset_width=shift_width,\n",
" top_padding=shift_height,\n",
" left_padding=shift_width,\n",
" target_height=self.image_size,\n",
" target_width=self.image_size,\n",
" )\n",
@@ -298,7 +292,7 @@
" def call(self, images):\n",
" if not self.vanilla:\n",
" # Concat the shifted images with the original image\n",
" images = tf.concat(\n",
" images = ops.concatenate(\n",
" [\n",
" images,\n",
" self.crop_shift_pad(images, mode=\"left-up\"),\n",
@@ -309,11 +303,11 @@
" axis=-1,\n",
" )\n",
" # Patchify the images and flatten it\n",
" patches = tf.image.extract_patches(\n",
" patches = ops.image.extract_patches(\n",
" images=images,\n",
" sizes=[1, self.patch_size, self.patch_size, 1],\n",
" size=(self.patch_size, self.patch_size),\n",
" strides=[1, self.patch_size, self.patch_size, 1],\n",
" rates=[1, 1, 1, 1],\n",
" dilation_rate=1,\n",
" padding=\"VALID\",\n",
" )\n",
" flat_patches = self.flatten_patches(patches)\n",
@@ -324,8 +318,7 @@
" else:\n",
" # Linearly project the flat patches\n",
" tokens = self.projection(flat_patches)\n",
" return (tokens, patches)\n",
""
" return (tokens, patches)\n"
]
},
{
@@ -348,8 +341,9 @@
"# Get a random image from the training dataset\n",
"# and resize the image\n",
"image = x_train[np.random.choice(range(x_train.shape[0]))]\n",
"resized_image = tf.image.resize(\n",
" tf.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)\n",
"resized_image = ops.cast(\n",
" ops.image.resize(ops.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)),\n",
" dtype=\"float32\",\n",
")\n",
"\n",
"# Vanilla patch maker: This takes an image and divides into\n",
@@ -363,7 +357,7 @@
" for col in range(n):\n",
" plt.subplot(n, n, count)\n",
" count = count + 1\n",
" image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))\n",
" image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))\n",
" plt.imshow(image)\n",
" plt.axis(\"off\")\n",
"plt.show()\n",
@@ -382,7 +376,7 @@
" for col in range(n):\n",
" plt.subplot(n, n, count)\n",
" count = count + 1\n",
" image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))\n",
" image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))\n",
" plt.imshow(image[..., 3 * index : 3 * index + 3])\n",
" plt.axis(\"off\")\n",
" plt.show()"
@@ -418,13 +412,12 @@
" self.position_embedding = layers.Embedding(\n",
" input_dim=num_patches, output_dim=projection_dim\n",
" )\n",
" self.positions = tf.range(start=0, limit=self.num_patches, delta=1)\n",
" self.positions = ops.arange(start=0, stop=self.num_patches, step=1)\n",
"\n",
" def call(self, encoded_patches):\n",
" encoded_positions = self.position_embedding(self.positions)\n",
" encoded_patches = encoded_patches + encoded_positions\n",
" return encoded_patches\n",
""
" return encoded_patches\n"
]
},
{
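`tf.range` becomes `ops.arange` (note that `limit`/`delta` are `stop`/`step` in the NumPy-style API). A toy version of the positional-embedding lookup this feeds:

```python
from keras import layers, ops

num_patches, projection_dim = 4, 8  # toy sizes
positions = ops.arange(start=0, stop=num_patches, step=1)  # [0, 1, 2, 3]
position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)
print(position_embedding(positions).shape)  # (4, 8)
```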
@@ -479,25 +472,24 @@
"outputs": [],
"source": [
"\n",
"class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):\n",
"class MultiHeadAttentionLSA(layers.MultiHeadAttention):\n",
" def __init__(self, **kwargs):\n",
" super().__init__(**kwargs)\n",
" # The trainable temperature term. The initial value is\n",
" # the square root of the key dimension.\n",
" self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)\n",
" self.tau = keras.Variable(math.sqrt(float(self._key_dim)), trainable=True)\n",
"\n",
" def _compute_attention(self, query, key, value, attention_mask=None, training=None):\n",
" query = tf.multiply(query, 1.0 / self.tau)\n",
" attention_scores = tf.einsum(self._dot_product_equation, key, query)\n",
" query = ops.multiply(query, 1.0 / self.tau)\n",
" attention_scores = ops.einsum(self._dot_product_equation, key, query)\n",
" attention_scores = self._masked_softmax(attention_scores, attention_mask)\n",
" attention_scores_dropout = self._dropout_layer(\n",
" attention_scores, training=training\n",
" )\n",
" attention_output = tf.einsum(\n",
" attention_output = ops.einsum(\n",
" self._combine_equation, attention_scores_dropout, value\n",
" )\n",
" return attention_output, attention_scores\n",
""
" return attention_output, attention_scores\n"
]
},
{
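The substantive changes in the LSA layer are `tf.Variable` → `keras.Variable` and `tf.*` math → `ops.*`; the temperature trick itself is unchanged. Queries are divided by a learned `tau` (initialized to `sqrt(key_dim)`), and a smaller `tau` sharpens the attention weights. A toy illustration:

```python
import math
import numpy as np
from keras import ops

scores = ops.convert_to_tensor(np.array([1.0, 2.0, 3.0], dtype="float32"))
key_dim = 64
print(ops.softmax(scores / math.sqrt(key_dim)))  # near-uniform weights
print(ops.softmax(scores / 1.0))                 # sharper weights
```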
@@ -520,14 +512,14 @@
"\n",
"def mlp(x, hidden_units, dropout_rate):\n",
" for units in hidden_units:\n",
" x = layers.Dense(units, activation=tf.nn.gelu)(x)\n",
" x = layers.Dense(units, activation=\"gelu\")(x)\n",
" x = layers.Dropout(dropout_rate)(x)\n",
" return x\n",
"\n",
"\n",
"# Build the diagonal attention mask\n",
"diag_attn_mask = 1 - tf.eye(NUM_PATCHES)\n",
"diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8)"
"diag_attn_mask = 1 - ops.eye(NUM_PATCHES)\n",
"diag_attn_mask = ops.cast([diag_attn_mask], dtype=\"int8\")"
]
},
{
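`ops.eye` and `ops.cast` stand in for `tf.eye` and `tf.cast` here; the mask zeroes the diagonal so that, combined with LSA, no token attends to itself. A toy equivalent (using `expand_dims` for the batch axis):

```python
from keras import ops

n = 4  # toy token count
diag_attn_mask = 1 - ops.eye(n)  # 0 on the diagonal, 1 elsewhere
diag_attn_mask = ops.cast(ops.expand_dims(diag_attn_mask, 0), dtype="int8")
print(diag_attn_mask.shape)  # (1, 4, 4)
```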
@@ -589,8 +581,7 @@
" logits = layers.Dense(NUM_CLASSES)(features)\n",
" # Create the Keras model.\n",
" model = keras.Model(inputs=inputs, outputs=logits)\n",
" return model\n",
""
" return model\n"
]
},
{
@@ -622,15 +613,15 @@
" self.total_steps = total_steps\n",
" self.warmup_learning_rate = warmup_learning_rate\n",
" self.warmup_steps = warmup_steps\n",
" self.pi = tf.constant(np.pi)\n",
" self.pi = ops.array(np.pi)\n",
"\n",
" def __call__(self, step):\n",
" if self.total_steps < self.warmup_steps:\n",
" raise ValueError(\"Total_steps must be larger or equal to warmup_steps.\")\n",
"\n",
" cos_annealed_lr = tf.cos(\n",
" cos_annealed_lr = ops.cos(\n",
" self.pi\n",
" * (tf.cast(step, tf.float32) - self.warmup_steps)\n",
" * (ops.cast(step, dtype=\"float32\") - self.warmup_steps)\n",
" / float(self.total_steps - self.warmup_steps)\n",
" )\n",
" learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)\n",
@@ -644,11 +635,13 @@
" slope = (\n",
" self.learning_rate_base - self.warmup_learning_rate\n",
" ) / self.warmup_steps\n",
" warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate\n",
" learning_rate = tf.where(\n",
" warmup_rate = (\n",
" slope * ops.cast(step, dtype=\"float32\") + self.warmup_learning_rate\n",
" )\n",
" learning_rate = ops.where(\n",
" step < self.warmup_steps, warmup_rate, learning_rate\n",
" )\n",
" return tf.where(\n",
" return ops.where(\n",
" step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n",
" )\n",
"\n",
@@ -664,7 +657,7 @@
" warmup_steps=warmup_steps,\n",
" )\n",
"\n",
" optimizer = tfa.optimizers.AdamW(\n",
" optimizer = keras.optimizers.AdamW(\n",
" learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY\n",
" )\n",
"\n",
@@ -720,7 +713,7 @@
"I would like to thank [Jarvislabs.ai](https://jarvislabs.ai/) for\n",
"generously helping with GPU credits.\n",
"\n",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/vit_small_ds_v2) ",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/vit_small_ds_v2) \n",
"and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/vit-small-ds)."
]
}
@@ -754,4 +747,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}