Skip to content

Commit

Permalink
Add token_learner.py example for Keras 3 (#1596)
Browse files (browse the repository at this point in the history)
* initial changes

* generate the files
  • Loading branch information
haifeng-jin authored Nov 10, 2023
1 parent 6145fd2 commit fdaa9de
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 185 deletions.
90 changes: 41 additions & 49 deletions examples/vision/ipynb/token_learner.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
"source": [
"# Learning to tokenize in Vision Transformers\n",
"\n",
"**Authors:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution)<br>\n",
"**Authors:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Sayak Paul](https://twitter.com/RisingSayak) (equal contribution), converted to Keras 3 by [Muhammad Anas Raza](https://anasrz.com)<br>\n",
"**Date created:** 2021/12/10<br>\n",
"**Last modified:** 2021/12/15<br>\n",
"**Last modified:** 2023/08/14<br>\n",
"**Description:** Adaptively generating a smaller number of tokens for Vision Transformers."
]
},
Expand Down Expand Up @@ -56,22 +56,6 @@
"* [TokenLearner slides from NeurIPS 2021](https://nips.cc/media/neurips-2021/Slides/26578.pdf)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"## Setup\n",
"\n",
"We need to install TensorFlow Addons to run this example. To install it, execute the\n",
"following:\n",
"\n",
"```shell\n",
"pip install tensorflow-addons\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand All @@ -89,11 +73,11 @@
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import keras\n",
"from keras import layers\n",
"from keras import ops\n",
"from tensorflow import data as tf_data\n",
"\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import tensorflow_addons as tfa\n",
"\n",
"from datetime import datetime\n",
"import matplotlib.pyplot as plt\n",
Expand Down Expand Up @@ -124,7 +108,7 @@
"source": [
"# DATA\n",
"BATCH_SIZE = 256\n",
"AUTO = tf.data.AUTOTUNE\n",
"AUTO = tf_data.AUTOTUNE\n",
"INPUT_SHAPE = (32, 32, 3)\n",
"NUM_CLASSES = 10\n",
"\n",
Expand All @@ -133,7 +117,7 @@
"WEIGHT_DECAY = 1e-4\n",
"\n",
"# TRAINING\n",
"EPOCHS = 20\n",
"EPOCHS = 1\n",
"\n",
"# AUGMENTATION\n",
"IMAGE_SIZE = 48 # We will resize input images to this size.\n",
Expand Down Expand Up @@ -182,13 +166,13 @@
"print(f\"Testing samples: {len(x_test)}\")\n",
"\n",
"# Convert to tf.data.Dataset objects.\n",
"train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
"train_ds = tf_data.Dataset.from_tensor_slices((x_train, y_train))\n",
"train_ds = train_ds.shuffle(BATCH_SIZE * 100).batch(BATCH_SIZE).prefetch(AUTO)\n",
"\n",
"val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))\n",
"val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val))\n",
"val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)\n",
"\n",
"test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n",
"test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test))\n",
"test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)"
]
},
Expand Down Expand Up @@ -266,19 +250,25 @@
"outputs": [],
"source": [
"\n",
"def position_embedding(\n",
" projected_patches, num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM\n",
"):\n",
" # Build the positions.\n",
" positions = tf.range(start=0, limit=num_patches, delta=1)\n",
"\n",
" # Encode the positions with an Embedding layer.\n",
" encoded_positions = layers.Embedding(\n",
" input_dim=num_patches, output_dim=projection_dim\n",
" )(positions)\n",
"\n",
" # Add encoded positions to the projected patches.\n",
" return projected_patches + encoded_positions\n",
"class PatchEncoder(layers.Layer):\n",
" def __init__(self, num_patches, projection_dim):\n",
" super().__init__()\n",
" self.num_patches = num_patches\n",
" self.position_embedding = layers.Embedding(\n",
" input_dim=num_patches, output_dim=projection_dim\n",
" )\n",
"\n",
" def call(self, patch):\n",
" positions = ops.expand_dims(\n",
" ops.arange(start=0, stop=self.num_patches, step=1), axis=0\n",
" )\n",
" encoded = patch + self.position_embedding(positions)\n",
" return encoded\n",
"\n",
" def get_config(self):\n",
" config = super().get_config()\n",
" config.update({\"num_patches\": self.num_patches})\n",
" return config\n",
""
]
},
Expand Down Expand Up @@ -306,7 +296,7 @@
" # Iterate over the hidden units and\n",
" # add Dense => Dropout.\n",
" for units in hidden_units:\n",
" x = layers.Dense(units, activation=tf.nn.gelu)(x)\n",
" x = layers.Dense(units, activation=ops.gelu)(x)\n",
" x = layers.Dropout(dropout_rate)(x)\n",
" return x\n",
""
Expand Down Expand Up @@ -360,21 +350,21 @@
" layers.Conv2D(\n",
" filters=number_of_tokens,\n",
" kernel_size=(3, 3),\n",
" activation=tf.nn.gelu,\n",
" activation=ops.gelu,\n",
" padding=\"same\",\n",
" use_bias=False,\n",
" ),\n",
" layers.Conv2D(\n",
" filters=number_of_tokens,\n",
" kernel_size=(3, 3),\n",
" activation=tf.nn.gelu,\n",
" activation=ops.gelu,\n",
" padding=\"same\",\n",
" use_bias=False,\n",
" ),\n",
" layers.Conv2D(\n",
" filters=number_of_tokens,\n",
" kernel_size=(3, 3),\n",
" activation=tf.nn.gelu,\n",
" activation=ops.gelu,\n",
" padding=\"same\",\n",
" use_bias=False,\n",
" ),\n",
Expand All @@ -400,11 +390,11 @@
"\n",
" # Element-Wise multiplication of the attention maps and the inputs\n",
" attended_inputs = (\n",
" attention_maps[..., tf.newaxis] * inputs\n",
" ops.expand_dims(attention_maps, axis=-1) * inputs\n",
" ) # (B, num_tokens, H*W, C)\n",
"\n",
" # Global average pooling the element wise multiplication result.\n",
" outputs = tf.reduce_mean(attended_inputs, axis=2) # (B, num_tokens, C)\n",
" outputs = ops.mean(attended_inputs, axis=2) # (B, num_tokens, C)\n",
" return outputs\n",
""
]
Expand Down Expand Up @@ -488,7 +478,9 @@
" ) # (B, number_patches, projection_dim)\n",
"\n",
" # Add positional embeddings to the projected patches.\n",
" encoded_patches = position_embedding(\n",
" encoded_patches = PatchEncoder(\n",
" num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM\n",
" )(\n",
" projected_patches\n",
" ) # (B, number_patches, projection_dim)\n",
" encoded_patches = layers.Dropout(0.1)(encoded_patches)\n",
Expand Down Expand Up @@ -556,7 +548,7 @@
"\n",
"def run_experiment(model):\n",
" # Initialize the AdamW optimizer.\n",
" optimizer = tfa.optimizers.AdamW(\n",
" optimizer = keras.optimizers.AdamW(\n",
" learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY\n",
" )\n",
"\n",
Expand All @@ -572,7 +564,7 @@
" )\n",
"\n",
" # Define callbacks\n",
" checkpoint_filepath = \"/tmp/checkpoint\"\n",
" checkpoint_filepath = \"/tmp/checkpoint.weights.h5\"\n",
" checkpoint_callback = keras.callbacks.ModelCheckpoint(\n",
" checkpoint_filepath,\n",
" monitor=\"val_accuracy\",\n",
Expand Down
Loading

0 comments on commit fdaa9de

Please sign in to comment.