From 8a55f0a98f1e67af09cfdbdc73f552157180a881 Mon Sep 17 00:00:00 2001 From: Sitam Meur <103279526+sitamgithub-MSIT@users.noreply.github.com> Date: Wed, 27 Nov 2024 15:39:37 +0530 Subject: [PATCH] Migrating Train a Vision Transformer on small datasets example to Keras 3 (#1991) * migrate vit small dataset example to keras3 * requested changes added and backend agnostic done * other generated files are added too --- examples/vision/ipynb/vit_small_ds.ipynb | 95 +++++++++++------------- examples/vision/md/vit_small_ds.md | 79 ++++++++++---------- examples/vision/vit_small_ds.py | 80 ++++++++++---------- 3 files changed, 121 insertions(+), 133 deletions(-) diff --git a/examples/vision/ipynb/vit_small_ds.ipynb b/examples/vision/ipynb/vit_small_ds.ipynb index 7f51e46813..2d89b52fbd 100644 --- a/examples/vision/ipynb/vit_small_ds.ipynb +++ b/examples/vision/ipynb/vit_small_ds.ipynb @@ -10,7 +10,7 @@ "\n", "**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498)
\n", "**Date created:** 2022/01/07
\n", - "**Last modified:** 2022/01/10
\n", + "**Last modified:** 2024/11/27
\n", "**Description:** Training a ViT from scratch on smaller datasets with shifted patch tokenization and locality self-attention." ] }, @@ -47,13 +47,7 @@ "example is inspired from\n", "[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).\n", "\n", - "_Note_: This example requires TensorFlow 2.6 or higher, as well as\n", - "[TensorFlow Addons](https://www.tensorflow.org/addons), which can be\n", - "installed using the following command:\n", - "\n", - "```python\n", - "pip install -qq -U tensorflow-addons\n", - "```" + "_Note_: This example requires TensorFlow 2.6 or higher." ] }, { @@ -75,11 +69,11 @@ "source": [ "import math\n", "import numpy as np\n", + "import keras\n", + "from keras import ops\n", + "from keras import layers\n", "import tensorflow as tf\n", - "from tensorflow import keras\n", - "import tensorflow_addons as tfa\n", "import matplotlib.pyplot as plt\n", - "from tensorflow.keras import layers\n", "\n", "# Setting seed for reproducibiltiy\n", "SEED = 42\n", @@ -279,17 +273,17 @@ " shift_width = self.half_patch\n", "\n", " # Crop the shifted images and pad them\n", - " crop = tf.image.crop_to_bounding_box(\n", + " crop = ops.image.crop_images(\n", " images,\n", - " offset_height=crop_height,\n", - " offset_width=crop_width,\n", + " top_cropping=crop_height,\n", + " left_cropping=crop_width,\n", " target_height=self.image_size - self.half_patch,\n", " target_width=self.image_size - self.half_patch,\n", " )\n", - " shift_pad = tf.image.pad_to_bounding_box(\n", + " shift_pad = ops.image.pad_images(\n", " crop,\n", - " offset_height=shift_height,\n", - " offset_width=shift_width,\n", + " top_padding=shift_height,\n", + " left_padding=shift_width,\n", " target_height=self.image_size,\n", " target_width=self.image_size,\n", " )\n", @@ -298,7 +292,7 @@ " def call(self, images):\n", " if not self.vanilla:\n", " # Concat the shifted images with the original image\n", - " images = tf.concat(\n", + " images = ops.concatenate(\n", " [\n", " images,\n", " self.crop_shift_pad(images, mode=\"left-up\"),\n", @@ -309,11 +303,11 @@ " axis=-1,\n", " )\n", " # Patchify the images and flatten it\n", - " patches = tf.image.extract_patches(\n", + " patches = ops.image.extract_patches(\n", " images=images,\n", - " sizes=[1, self.patch_size, self.patch_size, 1],\n", + " size=(self.patch_size, self.patch_size),\n", " strides=[1, self.patch_size, self.patch_size, 1],\n", - " rates=[1, 1, 1, 1],\n", + " dilation_rate=1,\n", " padding=\"VALID\",\n", " )\n", " flat_patches = self.flatten_patches(patches)\n", @@ -324,8 +318,7 @@ " else:\n", " # Linearly project the flat patches\n", " tokens = self.projection(flat_patches)\n", - " return (tokens, patches)\n", - "" + " return (tokens, patches)\n" ] }, { @@ -348,8 +341,9 @@ "# Get a random image from the training dataset\n", "# and resize the image\n", "image = x_train[np.random.choice(range(x_train.shape[0]))]\n", - "resized_image = tf.image.resize(\n", - " tf.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)\n", + "resized_image = ops.cast(\n", + " ops.image.resize(ops.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)),\n", + " dtype=\"float32\",\n", ")\n", "\n", "# Vanilla patch maker: This takes an image and divides into\n", @@ -363,7 +357,7 @@ " for col in range(n):\n", " plt.subplot(n, n, count)\n", " count = count + 1\n", - " image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))\n", + " image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 
3))\n", " plt.imshow(image)\n", " plt.axis(\"off\")\n", "plt.show()\n", @@ -382,7 +376,7 @@ " for col in range(n):\n", " plt.subplot(n, n, count)\n", " count = count + 1\n", - " image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))\n", + " image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))\n", " plt.imshow(image[..., 3 * index : 3 * index + 3])\n", " plt.axis(\"off\")\n", " plt.show()" @@ -418,13 +412,12 @@ " self.position_embedding = layers.Embedding(\n", " input_dim=num_patches, output_dim=projection_dim\n", " )\n", - " self.positions = tf.range(start=0, limit=self.num_patches, delta=1)\n", + " self.positions = ops.arange(start=0, stop=self.num_patches, step=1)\n", "\n", " def call(self, encoded_patches):\n", " encoded_positions = self.position_embedding(self.positions)\n", " encoded_patches = encoded_patches + encoded_positions\n", - " return encoded_patches\n", - "" + " return encoded_patches\n" ] }, { @@ -479,25 +472,24 @@ "outputs": [], "source": [ "\n", - "class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):\n", + "class MultiHeadAttentionLSA(layers.MultiHeadAttention):\n", " def __init__(self, **kwargs):\n", " super().__init__(**kwargs)\n", " # The trainable temperature term. The initial value is\n", " # the square root of the key dimension.\n", - " self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)\n", + " self.tau = keras.Variable(math.sqrt(float(self._key_dim)), trainable=True)\n", "\n", " def _compute_attention(self, query, key, value, attention_mask=None, training=None):\n", - " query = tf.multiply(query, 1.0 / self.tau)\n", - " attention_scores = tf.einsum(self._dot_product_equation, key, query)\n", + " query = ops.multiply(query, 1.0 / self.tau)\n", + " attention_scores = ops.einsum(self._dot_product_equation, key, query)\n", " attention_scores = self._masked_softmax(attention_scores, attention_mask)\n", " attention_scores_dropout = self._dropout_layer(\n", " attention_scores, training=training\n", " )\n", - " attention_output = tf.einsum(\n", + " attention_output = ops.einsum(\n", " self._combine_equation, attention_scores_dropout, value\n", " )\n", - " return attention_output, attention_scores\n", - "" + " return attention_output, attention_scores\n" ] }, { @@ -520,14 +512,14 @@ "\n", "def mlp(x, hidden_units, dropout_rate):\n", " for units in hidden_units:\n", - " x = layers.Dense(units, activation=tf.nn.gelu)(x)\n", + " x = layers.Dense(units, activation=\"gelu\")(x)\n", " x = layers.Dropout(dropout_rate)(x)\n", " return x\n", "\n", "\n", "# Build the diagonal attention mask\n", - "diag_attn_mask = 1 - tf.eye(NUM_PATCHES)\n", - "diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8)" + "diag_attn_mask = 1 - ops.eye(NUM_PATCHES)\n", + "diag_attn_mask = ops.cast([diag_attn_mask], dtype=\"int8\")" ] }, { @@ -589,8 +581,7 @@ " logits = layers.Dense(NUM_CLASSES)(features)\n", " # Create the Keras model.\n", " model = keras.Model(inputs=inputs, outputs=logits)\n", - " return model\n", - "" + " return model\n" ] }, { @@ -622,15 +613,15 @@ " self.total_steps = total_steps\n", " self.warmup_learning_rate = warmup_learning_rate\n", " self.warmup_steps = warmup_steps\n", - " self.pi = tf.constant(np.pi)\n", + " self.pi = ops.array(np.pi)\n", "\n", " def __call__(self, step):\n", " if self.total_steps < self.warmup_steps:\n", " raise ValueError(\"Total_steps must be larger or equal to warmup_steps.\")\n", "\n", - " cos_annealed_lr = tf.cos(\n", + " cos_annealed_lr = ops.cos(\n", " self.pi\n", - " * (tf.cast(step, 
tf.float32) - self.warmup_steps)\n", + " * (ops.cast(step, dtype=\"float32\") - self.warmup_steps)\n", " / float(self.total_steps - self.warmup_steps)\n", " )\n", " learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)\n", @@ -644,11 +635,13 @@ " slope = (\n", " self.learning_rate_base - self.warmup_learning_rate\n", " ) / self.warmup_steps\n", - " warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate\n", - " learning_rate = tf.where(\n", + " warmup_rate = (\n", + " slope * ops.cast(step, dtype=\"float32\") + self.warmup_learning_rate\n", + " )\n", + " learning_rate = ops.where(\n", " step < self.warmup_steps, warmup_rate, learning_rate\n", " )\n", - " return tf.where(\n", + " return ops.where(\n", " step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n", " )\n", "\n", @@ -664,7 +657,7 @@ " warmup_steps=warmup_steps,\n", " )\n", "\n", - " optimizer = tfa.optimizers.AdamW(\n", + " optimizer = keras.optimizers.AdamW(\n", " learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY\n", " )\n", "\n", @@ -720,7 +713,7 @@ "I would like to thank [Jarvislabs.ai](https://jarvislabs.ai/) for\n", "generously helping with GPU credits.\n", "\n", - "You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/vit_small_ds_v2) ", + "You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/vit_small_ds_v2) \n", "and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/vit-small-ds)." ] } @@ -754,4 +747,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/examples/vision/md/vit_small_ds.md b/examples/vision/md/vit_small_ds.md index 8a7f824b78..d916f6728a 100644 --- a/examples/vision/md/vit_small_ds.md +++ b/examples/vision/md/vit_small_ds.md @@ -2,7 +2,7 @@ **Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498)
**Date created:** 2022/01/07
-**Last modified:** 2022/01/10
+**Last modified:** 2024/11/27
**Description:** Training a ViT from scratch on smaller datasets with shifted patch tokenization and locality self-attention. @@ -38,13 +38,7 @@ This example implements the ideas of the paper. A large part of this example is inspired from [Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/). -_Note_: This example requires TensorFlow 2.6 or higher, as well as -[TensorFlow Addons](https://www.tensorflow.org/addons), which can be -installed using the following command: - -```python -pip install -qq -U tensorflow-addons -``` +_Note_: This example requires TensorFlow 2.6 or higher. --- ## Setup @@ -53,11 +47,11 @@ pip install -qq -U tensorflow-addons ```python import math import numpy as np +import keras +from keras import ops +from keras import layers import tensorflow as tf -from tensorflow import keras -import tensorflow_addons as tfa import matplotlib.pyplot as plt -from tensorflow.keras import layers # Setting seed for reproducibiltiy SEED = 42 @@ -221,17 +215,17 @@ class ShiftedPatchTokenization(layers.Layer): shift_width = self.half_patch # Crop the shifted images and pad them - crop = tf.image.crop_to_bounding_box( + crop = ops.image.crop_images( images, - offset_height=crop_height, - offset_width=crop_width, + top_cropping=crop_height, + left_cropping=crop_width, target_height=self.image_size - self.half_patch, target_width=self.image_size - self.half_patch, ) - shift_pad = tf.image.pad_to_bounding_box( + shift_pad = ops.image.pad_images( crop, - offset_height=shift_height, - offset_width=shift_width, + top_padding=shift_height, + left_padding=shift_width, target_height=self.image_size, target_width=self.image_size, ) @@ -240,7 +234,7 @@ class ShiftedPatchTokenization(layers.Layer): def call(self, images): if not self.vanilla: # Concat the shifted images with the original image - images = tf.concat( + images = ops.concatenate( [ images, self.crop_shift_pad(images, mode="left-up"), @@ -251,11 +245,11 @@ class ShiftedPatchTokenization(layers.Layer): axis=-1, ) # Patchify the images and flatten it - patches = tf.image.extract_patches( + patches = ops.image.extract_patches( images=images, - sizes=[1, self.patch_size, self.patch_size, 1], + size=(self.patch_size, self.patch_size), strides=[1, self.patch_size, self.patch_size, 1], - rates=[1, 1, 1, 1], + dilation_rate=1, padding="VALID", ) flat_patches = self.flatten_patches(patches) @@ -277,8 +271,9 @@ class ShiftedPatchTokenization(layers.Layer): # Get a random image from the training dataset # and resize the image image = x_train[np.random.choice(range(x_train.shape[0]))] -resized_image = tf.image.resize( - tf.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE) +resized_image = ops.cast( + ops.image.resize(ops.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)), + dtype="float32", ) # Vanilla patch maker: This takes an image and divides into @@ -292,7 +287,7 @@ for row in range(n): for col in range(n): plt.subplot(n, n, count) count = count + 1 - image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3)) + image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3)) plt.imshow(image) plt.axis("off") plt.show() @@ -311,7 +306,7 @@ for index, name in enumerate(shifted_images): for col in range(n): plt.subplot(n, n, count) count = count + 1 - image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3)) + image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3)) plt.imshow(image[..., 3 * index : 3 * index + 3]) plt.axis("off") 
plt.show() @@ -401,7 +396,7 @@ class PatchEncoder(layers.Layer): self.position_embedding = layers.Embedding( input_dim=num_patches, output_dim=projection_dim ) - self.positions = tf.range(start=0, limit=self.num_patches, delta=1) + self.positions = ops.arange(start=0, stop=self.num_patches, step=1) def call(self, encoded_patches): encoded_positions = self.position_embedding(self.positions) @@ -450,21 +445,21 @@ at a later stage. ```python -class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention): +class MultiHeadAttentionLSA(layers.MultiHeadAttention): def __init__(self, **kwargs): super().__init__(**kwargs) # The trainable temperature term. The initial value is # the square root of the key dimension. - self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True) + self.tau = keras.Variable(math.sqrt(float(self._key_dim)), trainable=True) def _compute_attention(self, query, key, value, attention_mask=None, training=None): - query = tf.multiply(query, 1.0 / self.tau) - attention_scores = tf.einsum(self._dot_product_equation, key, query) + query = ops.multiply(query, 1.0 / self.tau) + attention_scores = ops.einsum(self._dot_product_equation, key, query) attention_scores = self._masked_softmax(attention_scores, attention_mask) attention_scores_dropout = self._dropout_layer( attention_scores, training=training ) - attention_output = tf.einsum( + attention_output = ops.einsum( self._combine_equation, attention_scores_dropout, value ) return attention_output, attention_scores @@ -479,14 +474,14 @@ class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention): def mlp(x, hidden_units, dropout_rate): for units in hidden_units: - x = layers.Dense(units, activation=tf.nn.gelu)(x) + x = layers.Dense(units, activation="gelu")(x) x = layers.Dropout(dropout_rate)(x) return x # Build the diagonal attention mask -diag_attn_mask = 1 - tf.eye(NUM_PATCHES) -diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8) +diag_attn_mask = 1 - ops.eye(NUM_PATCHES) +diag_attn_mask = ops.cast([diag_attn_mask], dtype="int8") ``` --- @@ -557,15 +552,15 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule): self.total_steps = total_steps self.warmup_learning_rate = warmup_learning_rate self.warmup_steps = warmup_steps - self.pi = tf.constant(np.pi) + self.pi = ops.array(np.pi) def __call__(self, step): if self.total_steps < self.warmup_steps: raise ValueError("Total_steps must be larger or equal to warmup_steps.") - cos_annealed_lr = tf.cos( + cos_annealed_lr = ops.cos( self.pi - * (tf.cast(step, tf.float32) - self.warmup_steps) + * (ops.cast(step, dtype="float32") - self.warmup_steps) / float(self.total_steps - self.warmup_steps) ) learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr) @@ -579,11 +574,13 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule): slope = ( self.learning_rate_base - self.warmup_learning_rate ) / self.warmup_steps - warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate - learning_rate = tf.where( + warmup_rate = ( + slope * ops.cast(step, dtype="float32") + self.warmup_learning_rate + ) + learning_rate = ops.where( step < self.warmup_steps, warmup_rate, learning_rate ) - return tf.where( + return ops.where( step > self.total_steps, 0.0, learning_rate, name="learning_rate" ) @@ -599,7 +596,7 @@ def run_experiment(model): warmup_steps=warmup_steps, ) - optimizer = tfa.optimizers.AdamW( + optimizer = keras.optimizers.AdamW( learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY ) diff --git 
a/examples/vision/vit_small_ds.py b/examples/vision/vit_small_ds.py index 40ef2c526c..525967fca6 100644 --- a/examples/vision/vit_small_ds.py +++ b/examples/vision/vit_small_ds.py @@ -2,9 +2,10 @@ Title: Train a Vision Transformer on small datasets Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498) Date created: 2022/01/07 -Last modified: 2022/01/10 +Last modified: 2024/11/27 Description: Training a ViT from scratch on smaller datasets with shifted patch tokenization and locality self-attention. Accelerator: GPU +Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT) """ """ @@ -35,25 +36,19 @@ example is inspired from [Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/). -_Note_: This example requires TensorFlow 2.6 or higher, as well as -[TensorFlow Addons](https://www.tensorflow.org/addons), which can be -installed using the following command: - -```python -pip install -qq -U tensorflow-addons +_Note_: This example requires TensorFlow 2.6 or higher. ``` """ """ ## Setup """ - import math import numpy as np +import keras +from keras import ops +from keras import layers import tensorflow as tf -from tensorflow import keras -import tensorflow_addons as tfa import matplotlib.pyplot as plt -from tensorflow.keras import layers # Setting seed for reproducibiltiy SEED = 42 @@ -197,17 +192,17 @@ def crop_shift_pad(self, images, mode): shift_width = self.half_patch # Crop the shifted images and pad them - crop = tf.image.crop_to_bounding_box( + crop = ops.image.crop_images( images, - offset_height=crop_height, - offset_width=crop_width, + top_cropping=crop_height, + left_cropping=crop_width, target_height=self.image_size - self.half_patch, target_width=self.image_size - self.half_patch, ) - shift_pad = tf.image.pad_to_bounding_box( + shift_pad = ops.image.pad_images( crop, - offset_height=shift_height, - offset_width=shift_width, + top_padding=shift_height, + left_padding=shift_width, target_height=self.image_size, target_width=self.image_size, ) @@ -216,7 +211,7 @@ def crop_shift_pad(self, images, mode): def call(self, images): if not self.vanilla: # Concat the shifted images with the original image - images = tf.concat( + images = ops.concatenate( [ images, self.crop_shift_pad(images, mode="left-up"), @@ -227,11 +222,11 @@ def call(self, images): axis=-1, ) # Patchify the images and flatten it - patches = tf.image.extract_patches( + patches = ops.image.extract_patches( images=images, - sizes=[1, self.patch_size, self.patch_size, 1], + size=(self.patch_size, self.patch_size), strides=[1, self.patch_size, self.patch_size, 1], - rates=[1, 1, 1, 1], + dilation_rate=1, padding="VALID", ) flat_patches = self.flatten_patches(patches) @@ -252,8 +247,9 @@ def call(self, images): # Get a random image from the training dataset # and resize the image image = x_train[np.random.choice(range(x_train.shape[0]))] -resized_image = tf.image.resize( - tf.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE) +resized_image = ops.cast( + ops.image.resize(ops.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)), + dtype="float32", ) # Vanilla patch maker: This takes an image and divides into @@ -267,7 +263,7 @@ def call(self, images): for col in range(n): plt.subplot(n, n, count) count = count + 1 - image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3)) + image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3)) plt.imshow(image) plt.axis("off") plt.show() @@ -286,7 +282,7 @@ def 
call(self, images): for col in range(n): plt.subplot(n, n, count) count = count + 1 - image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3)) + image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3)) plt.imshow(image[..., 3 * index : 3 * index + 3]) plt.axis("off") plt.show() @@ -308,7 +304,7 @@ def __init__( self.position_embedding = layers.Embedding( input_dim=num_patches, output_dim=projection_dim ) - self.positions = tf.range(start=0, limit=self.num_patches, delta=1) + self.positions = ops.arange(start=0, stop=self.num_patches, step=1) def call(self, encoded_patches): encoded_positions = self.position_embedding(self.positions) @@ -355,21 +351,21 @@ def call(self, encoded_patches): """ -class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention): +class MultiHeadAttentionLSA(layers.MultiHeadAttention): def __init__(self, **kwargs): super().__init__(**kwargs) # The trainable temperature term. The initial value is # the square root of the key dimension. - self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True) + self.tau = keras.Variable(math.sqrt(float(self._key_dim)), trainable=True) def _compute_attention(self, query, key, value, attention_mask=None, training=None): - query = tf.multiply(query, 1.0 / self.tau) - attention_scores = tf.einsum(self._dot_product_equation, key, query) + query = ops.multiply(query, 1.0 / self.tau) + attention_scores = ops.einsum(self._dot_product_equation, key, query) attention_scores = self._masked_softmax(attention_scores, attention_mask) attention_scores_dropout = self._dropout_layer( attention_scores, training=training ) - attention_output = tf.einsum( + attention_output = ops.einsum( self._combine_equation, attention_scores_dropout, value ) return attention_output, attention_scores @@ -382,14 +378,14 @@ def _compute_attention(self, query, key, value, attention_mask=None, training=No def mlp(x, hidden_units, dropout_rate): for units in hidden_units: - x = layers.Dense(units, activation=tf.nn.gelu)(x) + x = layers.Dense(units, activation="gelu")(x) x = layers.Dropout(dropout_rate)(x) return x # Build the diagonal attention mask -diag_attn_mask = 1 - tf.eye(NUM_PATCHES) -diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8) +diag_attn_mask = 1 - ops.eye(NUM_PATCHES) +diag_attn_mask = ops.cast([diag_attn_mask], dtype="int8") """ ## Build the ViT @@ -457,15 +453,15 @@ def __init__( self.total_steps = total_steps self.warmup_learning_rate = warmup_learning_rate self.warmup_steps = warmup_steps - self.pi = tf.constant(np.pi) + self.pi = ops.array(np.pi) def __call__(self, step): if self.total_steps < self.warmup_steps: raise ValueError("Total_steps must be larger or equal to warmup_steps.") - cos_annealed_lr = tf.cos( + cos_annealed_lr = ops.cos( self.pi - * (tf.cast(step, tf.float32) - self.warmup_steps) + * (ops.cast(step, dtype="float32") - self.warmup_steps) / float(self.total_steps - self.warmup_steps) ) learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr) @@ -479,11 +475,13 @@ def __call__(self, step): slope = ( self.learning_rate_base - self.warmup_learning_rate ) / self.warmup_steps - warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate - learning_rate = tf.where( + warmup_rate = ( + slope * ops.cast(step, dtype="float32") + self.warmup_learning_rate + ) + learning_rate = ops.where( step < self.warmup_steps, warmup_rate, learning_rate ) - return tf.where( + return ops.where( step > self.total_steps, 0.0, learning_rate, name="learning_rate" ) @@ -499,7 +497,7 @@ 
def run_experiment(model): warmup_steps=warmup_steps, ) - optimizer = tfa.optimizers.AdamW( + optimizer = keras.optimizers.AdamW( learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY )
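
For quick reference, here is a minimal, backend-agnostic sketch (not part of the patch itself) of the main substitutions the migration makes: `keras.ops` calls in place of `tf.*` tensor ops, `keras.Variable` in place of `tf.Variable`, and the built-in `keras.optimizers.AdamW` in place of `tfa.optimizers.AdamW`. The sizes and hyperparameter values below are illustrative placeholders, not values taken from the example.

```python
# Minimal, self-contained sketch of the migration patterns applied above.
# IMAGE_SIZE, PATCH_SIZE, and the optimizer settings are placeholders.
import numpy as np
import keras
from keras import ops

# tf.image.extract_patches -> keras.ops.image.extract_patches.
# Strides default to the patch size, giving non-overlapping patches.
IMAGE_SIZE, PATCH_SIZE = 72, 6
images = ops.convert_to_tensor(
    np.random.rand(1, IMAGE_SIZE, IMAGE_SIZE, 3).astype("float32")
)
patches = ops.image.extract_patches(
    images, size=(PATCH_SIZE, PATCH_SIZE), padding="valid"
)
# Shape: (batch, rows, cols, patch_h * patch_w * channels)
print(patches.shape)

# tf.Variable -> keras.Variable (e.g. the trainable temperature in LSA).
tau = keras.Variable(8.0, trainable=True)

# tfa.optimizers.AdamW -> keras.optimizers.AdamW, now built into Keras.
optimizer = keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4)
```

Since `keras.ops` dispatches to whichever backend is active, the same calls should run unchanged under the TensorFlow, JAX, or PyTorch backends.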