Migrating Train a Vision Transformer on small datasets example to Keras 3 (#1991)

* migrate the ViT small dataset example to Keras 3

* apply requested changes and make the example backend-agnostic

* add the other generated files as well
sitamgithub-MSIT authored Nov 27, 2024
1 parent ede75bc commit 8a55f0a
Showing 3 changed files with 121 additions and 133 deletions.
95 changes: 44 additions & 51 deletions examples/vision/ipynb/vit_small_ds.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498)<br>\n",
"**Date created:** 2022/01/07<br>\n",
"**Last modified:** 2022/01/10<br>\n",
"**Last modified:** 2024/11/27<br>\n",
"**Description:** Training a ViT from scratch on smaller datasets with shifted patch tokenization and locality self-attention."
]
},
@@ -47,13 +47,7 @@
"example is inspired from\n",
"[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).\n",
"\n",
"_Note_: This example requires TensorFlow 2.6 or higher, as well as\n",
"[TensorFlow Addons](https://www.tensorflow.org/addons), which can be\n",
"installed using the following command:\n",
"\n",
"```python\n",
"pip install -qq -U tensorflow-addons\n",
"```"
"_Note_: This example requires TensorFlow 2.6 or higher."
]
},
{
@@ -75,11 +69,11 @@
"source": [
"import math\n",
"import numpy as np\n",
"import keras\n",
"from keras import ops\n",
"from keras import layers\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"import tensorflow_addons as tfa\n",
"import matplotlib.pyplot as plt\n",
"from tensorflow.keras import layers\n",
"\n",
"# Setting seed for reproducibiltiy\n",
"SEED = 42\n",
@@ -279,17 +273,17 @@
" shift_width = self.half_patch\n",
"\n",
" # Crop the shifted images and pad them\n",
" crop = tf.image.crop_to_bounding_box(\n",
" crop = ops.image.crop_images(\n",
" images,\n",
" offset_height=crop_height,\n",
" offset_width=crop_width,\n",
" top_cropping=crop_height,\n",
" left_cropping=crop_width,\n",
" target_height=self.image_size - self.half_patch,\n",
" target_width=self.image_size - self.half_patch,\n",
" )\n",
" shift_pad = tf.image.pad_to_bounding_box(\n",
" shift_pad = ops.image.pad_images(\n",
" crop,\n",
" offset_height=shift_height,\n",
" offset_width=shift_width,\n",
" top_padding=shift_height,\n",
" left_padding=shift_width,\n",
" target_height=self.image_size,\n",
" target_width=self.image_size,\n",
" )\n",
@@ -298,7 +292,7 @@
" def call(self, images):\n",
" if not self.vanilla:\n",
" # Concat the shifted images with the original image\n",
" images = tf.concat(\n",
" images = ops.concatenate(\n",
" [\n",
" images,\n",
" self.crop_shift_pad(images, mode=\"left-up\"),\n",
@@ -309,11 +303,11 @@
" axis=-1,\n",
" )\n",
" # Patchify the images and flatten it\n",
" patches = tf.image.extract_patches(\n",
" patches = ops.image.extract_patches(\n",
" images=images,\n",
" sizes=[1, self.patch_size, self.patch_size, 1],\n",
" size=(self.patch_size, self.patch_size),\n",
" strides=[1, self.patch_size, self.patch_size, 1],\n",
" rates=[1, 1, 1, 1],\n",
" dilation_rate=1,\n",
" padding=\"VALID\",\n",
" )\n",
" flat_patches = self.flatten_patches(patches)\n",
@@ -324,8 +318,7 @@
" else:\n",
" # Linearly project the flat patches\n",
" tokens = self.projection(flat_patches)\n",
" return (tokens, patches)\n",
""
" return (tokens, patches)\n"
]
},
{
@@ -348,8 +341,9 @@
"# Get a random image from the training dataset\n",
"# and resize the image\n",
"image = x_train[np.random.choice(range(x_train.shape[0]))]\n",
"resized_image = tf.image.resize(\n",
" tf.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)\n",
"resized_image = ops.cast(\n",
" ops.image.resize(ops.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)),\n",
" dtype=\"float32\",\n",
")\n",
"\n",
"# Vanilla patch maker: This takes an image and divides into\n",
@@ -363,7 +357,7 @@
" for col in range(n):\n",
" plt.subplot(n, n, count)\n",
" count = count + 1\n",
" image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))\n",
" image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))\n",
" plt.imshow(image)\n",
" plt.axis(\"off\")\n",
"plt.show()\n",
@@ -382,7 +376,7 @@
" for col in range(n):\n",
" plt.subplot(n, n, count)\n",
" count = count + 1\n",
" image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))\n",
" image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))\n",
" plt.imshow(image[..., 3 * index : 3 * index + 3])\n",
" plt.axis(\"off\")\n",
" plt.show()"
@@ -418,13 +412,12 @@
" self.position_embedding = layers.Embedding(\n",
" input_dim=num_patches, output_dim=projection_dim\n",
" )\n",
" self.positions = tf.range(start=0, limit=self.num_patches, delta=1)\n",
" self.positions = ops.arange(start=0, stop=self.num_patches, step=1)\n",
"\n",
" def call(self, encoded_patches):\n",
" encoded_positions = self.position_embedding(self.positions)\n",
" encoded_patches = encoded_patches + encoded_positions\n",
" return encoded_patches\n",
""
" return encoded_patches\n"
]
},
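`ops.arange` is the backend-agnostic stand-in for `tf.range`, with `stop`/`step` keywords replacing `limit`/`delta`. A one-line sanity check, assuming 16 patches:

```python
from keras import ops

positions = ops.arange(start=0, stop=16, step=1)
print(positions.shape)  # (16,)
```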
{
@@ -479,25 +472,24 @@
"outputs": [],
"source": [
"\n",
"class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):\n",
"class MultiHeadAttentionLSA(layers.MultiHeadAttention):\n",
" def __init__(self, **kwargs):\n",
" super().__init__(**kwargs)\n",
" # The trainable temperature term. The initial value is\n",
" # the square root of the key dimension.\n",
" self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)\n",
" self.tau = keras.Variable(math.sqrt(float(self._key_dim)), trainable=True)\n",
"\n",
" def _compute_attention(self, query, key, value, attention_mask=None, training=None):\n",
" query = tf.multiply(query, 1.0 / self.tau)\n",
" attention_scores = tf.einsum(self._dot_product_equation, key, query)\n",
" query = ops.multiply(query, 1.0 / self.tau)\n",
" attention_scores = ops.einsum(self._dot_product_equation, key, query)\n",
" attention_scores = self._masked_softmax(attention_scores, attention_mask)\n",
" attention_scores_dropout = self._dropout_layer(\n",
" attention_scores, training=training\n",
" )\n",
" attention_output = tf.einsum(\n",
" attention_output = ops.einsum(\n",
" self._combine_equation, attention_scores_dropout, value\n",
" )\n",
" return attention_output, attention_scores\n",
""
" return attention_output, attention_scores\n"
]
},
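The locality self-attention keeps the stock einsum-based scoring and only swaps the fixed `1 / sqrt(key_dim)` scale for a trainable temperature, now stored in a `keras.Variable`. A standalone sketch of that scaled dot-product, assuming `key_dim = 64` and hypothetical 4-token query/key tensors:

```python
import math
import numpy as np
import keras
from keras import ops

key_dim = 64
# Trainable temperature, initialized to sqrt(key_dim) as in the layer above.
tau = keras.Variable(math.sqrt(float(key_dim)), trainable=True)

query = ops.convert_to_tensor(np.random.rand(1, 4, key_dim).astype("float32"))
key = ops.convert_to_tensor(np.random.rand(1, 4, key_dim).astype("float32"))

# Dividing by a learnable tau lets the model sharpen or flatten the
# attention distribution during training.
scores = ops.einsum("bqd,bkd->bqk", ops.divide(query, tau), key)
weights = ops.softmax(scores, axis=-1)
print(weights.shape)  # (1, 4, 4)
```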
{
@@ -520,14 +512,14 @@
"\n",
"def mlp(x, hidden_units, dropout_rate):\n",
" for units in hidden_units:\n",
" x = layers.Dense(units, activation=tf.nn.gelu)(x)\n",
" x = layers.Dense(units, activation=\"gelu\")(x)\n",
" x = layers.Dropout(dropout_rate)(x)\n",
" return x\n",
"\n",
"\n",
"# Build the diagonal attention mask\n",
"diag_attn_mask = 1 - tf.eye(NUM_PATCHES)\n",
"diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8)"
"diag_attn_mask = 1 - ops.eye(NUM_PATCHES)\n",
"diag_attn_mask = ops.cast([diag_attn_mask], dtype=\"int8\")"
]
},
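The diagonal mask zeroes each token's attention to itself. A tiny sketch with 4 patches showing what `1 - ops.eye(n)` produces:

```python
from keras import ops

mask = ops.cast(1 - ops.eye(4), dtype="int8")
print(mask)
# [[0 1 1 1]
#  [1 0 1 1]
#  [1 1 0 1]
#  [1 1 1 0]]
```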
{
@@ -589,8 +581,7 @@
" logits = layers.Dense(NUM_CLASSES)(features)\n",
" # Create the Keras model.\n",
" model = keras.Model(inputs=inputs, outputs=logits)\n",
" return model\n",
""
" return model\n"
]
},
{
@@ -622,15 +613,15 @@
" self.total_steps = total_steps\n",
" self.warmup_learning_rate = warmup_learning_rate\n",
" self.warmup_steps = warmup_steps\n",
" self.pi = tf.constant(np.pi)\n",
" self.pi = ops.array(np.pi)\n",
"\n",
" def __call__(self, step):\n",
" if self.total_steps < self.warmup_steps:\n",
" raise ValueError(\"Total_steps must be larger or equal to warmup_steps.\")\n",
"\n",
" cos_annealed_lr = tf.cos(\n",
" cos_annealed_lr = ops.cos(\n",
" self.pi\n",
" * (tf.cast(step, tf.float32) - self.warmup_steps)\n",
" * (ops.cast(step, dtype=\"float32\") - self.warmup_steps)\n",
" / float(self.total_steps - self.warmup_steps)\n",
" )\n",
" learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)\n",
@@ -644,11 +635,13 @@
" slope = (\n",
" self.learning_rate_base - self.warmup_learning_rate\n",
" ) / self.warmup_steps\n",
" warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate\n",
" learning_rate = tf.where(\n",
" warmup_rate = (\n",
" slope * ops.cast(step, dtype=\"float32\") + self.warmup_learning_rate\n",
" )\n",
" learning_rate = ops.where(\n",
" step < self.warmup_steps, warmup_rate, learning_rate\n",
" )\n",
" return tf.where(\n",
" return ops.where(\n",
" step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n",
" )\n",
"\n",
@@ -664,7 +657,7 @@
" warmup_steps=warmup_steps,\n",
" )\n",
"\n",
" optimizer = tfa.optimizers.AdamW(\n",
" optimizer = keras.optimizers.AdamW(\n",
" learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY\n",
" )\n",
"\n",
@@ -720,7 +713,7 @@
"I would like to thank [Jarvislabs.ai](https://jarvislabs.ai/) for\n",
"generously helping with GPU credits.\n",
"\n",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/vit_small_ds_v2) ",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/vit_small_ds_v2) \n",
"and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/vit-small-ds)."
]
}
@@ -754,4 +747,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}