Update RandAugment example to Keras 3 (keras-team#1683)
sachinprasadhs authored and SuryanarayanaY committed Jan 19, 2024
1 parent 764b786 commit afe0d2c
Showing 5 changed files with 129 additions and 223 deletions.
99 changes: 27 additions & 72 deletions examples/vision/ipynb/randaugment.ipynb
Expand Up @@ -8,9 +8,9 @@
"source": [
"# RandAugment for Image Classification for Improved Robustness\n",
"\n",
"**Author:** [Sayak Paul](https://twitter.com/RisingSayak)<br>\n",
"**Authors:** [Sayak Paul](https://twitter.com/RisingSayak), [Sachin Prasad](https://github.com/sachinprasadhs)<br>\n",
"**Date created:** 2021/03/13<br>\n",
"**Last modified:** 2021/03/17<br>\n",
"**Last modified:** 2023/12/12<br>\n",
"**Description:** RandAugment for training an image classification model with improved robustness."
]
},
Expand All @@ -28,12 +28,6 @@
"saturations, etc. along with more traditional augmentation transforms such as\n",
"random crops.\n",
"\n",
"RandAugment has two parameters:\n",
"\n",
"* `n` that denotes the number of randomly selected augmentation transforms to apply\n",
"sequentially\n",
"* `m` strength of all the augmentation transforms\n",
"\n",
"These parameters are tuned for a given dataset and a network architecture. The authors of\n",
"RandAugment also provide pseudocode of RandAugment in the original paper (Figure 2).\n",
"\n",
Expand All @@ -43,12 +37,9 @@
"It has also been central to the\n",
"success of [EfficientNets](https://arxiv.org/abs/1905.11946).\n",
"\n",
"This example requires TensorFlow 2.4 or higher, as well as\n",
"[`imgaug`](https://imgaug.readthedocs.io/),\n",
"which can be installed using the following command:\n",
"\n",
"```python\n",
"pip install imgaug\n",
"pip install keras-cv\n",
"```"
]
},
Expand All @@ -69,17 +60,20 @@
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"import keras\n",
"import keras_cv\n",
"from keras import ops\n",
"from keras import layers\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from tensorflow.keras import layers\n",
"import tensorflow_datasets as tfds\n",
"from imgaug import augmenters as iaa\n",
"import imgaug as ia\n",
"\n",
"tfds.disable_progress_bar()\n",
"tf.random.set_seed(42)\n",
"ia.seed(42)"
"keras.utils.set_random_seed(42)"
]
},
{
Expand All @@ -102,7 +96,7 @@
},
"outputs": [],
"source": [
"(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()\n",
"(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()\n",
"print(f\"Total training examples: {len(x_train)}\")\n",
"print(f\"Total test examples: {len(x_test)}\")"
]
Expand Down Expand Up @@ -150,15 +144,9 @@
},
"outputs": [],
"source": [
"rand_aug = iaa.RandAugment(n=3, m=7)\n",
"\n",
"\n",
"def augment(images):\n",
" # Input to `augment()` is a TensorFlow tensor which\n",
" # is not supported by `imgaug`. This is why we first\n",
" # convert it to its `numpy` variant.\n",
" images = tf.cast(images, tf.uint8)\n",
" return rand_aug(images=images.numpy())\n",
"rand_augment = keras_cv.layers.RandAugment(\n",
" value_range=(0, 255), augmentations_per_image=3, magnitude=0.8\n",
")\n",
""
]
},
Expand All @@ -168,20 +156,7 @@
"colab_type": "text"
},
"source": [
"## Create TensorFlow `Dataset` objects\n",
"\n",
"Because `RandAugment` can only process NumPy arrays, it\n",
"cannot be applied directly as part of the `Dataset` object (which expects TensorFlow\n",
"tensors). To make `RandAugment` part of the dataset, we need to wrap it in a\n",
"[`tf.py_function`](https://www.tensorflow.org/api_docs/python/tf/py_function).\n",
"\n",
"A `tf.py_function` is a TensorFlow operation (which, like any other TensorFlow operation,\n",
"takes TF tensors as arguments and returns TensorFlow tensors) that is capable of running\n",
"arbitrary Python code. Naturally, this Python code can only be executed on CPU (whereas\n",
"the rest of the TensorFlow graph can be accelerated on GPU), which in some cases can\n",
"cause significant slowdowns -- however, in this case, the `Dataset` pipeline will run\n",
"asynchronously together with the model, and doing preprocessing on CPU will remain\n",
"performant."
"## Create TensorFlow `Dataset` objects"
]
},
{
Expand All @@ -201,7 +176,7 @@
" num_parallel_calls=AUTO,\n",
" )\n",
" .map(\n",
" lambda x, y: (tf.py_function(augment, [x], [tf.float32])[0], y),\n",
" lambda x, y: (rand_augment(tf.cast(x, tf.uint8)), y),\n",
" num_parallel_calls=AUTO,\n",
" )\n",
" .prefetch(AUTO)\n",
Expand All @@ -218,23 +193,6 @@
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"**Note about using `tf.py_function`**:\n",
"\n",
"* As our `augment()` function is not a native TensorFlow operation chances are likely\n",
"that it can turn into an expensive operation. This is why it is much better to apply it\n",
"_after_ batching our dataset.\n",
"* `tf.py_function` is [not compatible](https://github.com/tensorflow/tensorflow/issues/38762)\n",
"with TPUs. So, if you have distributed TensorFlow training pipelines that use TPUs\n",
"you cannot use `tf.py_function`. In that case, consider switching to a multi-GPU environment,\n",
"or rewriting the contents of the function in pure TensorFlow."
]
},
{
"cell_type": "markdown",
"metadata": {
Expand All @@ -253,14 +211,12 @@
},
"outputs": [],
"source": [
"simple_aug = tf.keras.Sequential(\n",
"simple_aug = keras.Sequential(\n",
" [\n",
" layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),\n",
" layers.RandomFlip(\"horizontal\"),\n",
" layers.RandomRotation(factor=0.02),\n",
" layers.RandomZoom(\n",
" height_factor=0.2, width_factor=0.2\n",
" ),\n",
" layers.RandomZoom(height_factor=0.2, width_factor=0.2),\n",
" ]\n",
")\n",
"\n",
Expand Down Expand Up @@ -359,13 +315,13 @@
"source": [
"\n",
"def get_training_model():\n",
" resnet50_v2 = tf.keras.applications.ResNet50V2(\n",
" resnet50_v2 = keras.applications.ResNet50V2(\n",
" weights=None,\n",
" include_top=True,\n",
" input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),\n",
" classes=10,\n",
" )\n",
" model = tf.keras.Sequential(\n",
" model = keras.Sequential(\n",
" [\n",
" layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),\n",
" layers.Rescaling(scale=1.0 / 127.5, offset=-1),\n",
Expand Down Expand Up @@ -415,7 +371,7 @@
"outputs": [],
"source": [
"initial_model = get_training_model()\n",
"initial_model.save_weights(\"initial_weights.h5\")"
"initial_model.save_weights(\"initial.weights.h5\")"
]
},
{
Expand All @@ -436,7 +392,7 @@
"outputs": [],
"source": [
"rand_aug_model = get_training_model()\n",
"rand_aug_model.load_weights(\"initial_weights.h5\")\n",
"rand_aug_model.load_weights(\"initial.weights.h5\")\n",
"rand_aug_model.compile(\n",
" loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"]\n",
")\n",
Expand All @@ -463,7 +419,7 @@
"outputs": [],
"source": [
"simple_aug_model = get_training_model()\n",
"simple_aug_model.load_weights(\"initial_weights.h5\")\n",
"simple_aug_model.load_weights(\"initial.weights.h5\")\n",
"simple_aug_model.compile(\n",
" loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"]\n",
")\n",
Expand Down Expand Up @@ -523,8 +479,7 @@
"For the purpose of this example, we trained the models for only a single epoch. On the\n",
"CIFAR-10-C dataset, the model with RandAugment can perform better with a higher accuracy\n",
"(for example, 76.64% in one experiment) compared with the model trained with `simple_aug`\n",
"(e.g., 64.80%). RandAugment can also help stabilize the training. You can explore this\n",
"[notebook](https://nbviewer.jupyter.org/github/sayakpaul/Keras-Examples-RandAugment/blob/main/RandAugment.ipynb) to check some of the results.\n",
"(e.g., 64.80%). RandAugment can also help stabilize the training.\n",
"\n",
"In the notebook, you may notice that, at the expense of increased training time with RandAugment,\n",
"we are able to carve out far better performance on the CIFAR-10-C dataset. You can\n",
Expand All @@ -541,8 +496,8 @@
"[FixMatch](https://arxiv.org/abs/2001.07685). This makes RandAugment quite a useful\n",
"recipe for training different vision models.\n",
"\n",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/randaugment) ",
"and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/randaugment).",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/randaugment)\n",
"and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/randaugment)."
]
}
],
Expand Down
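
The notebook text removed in this commit described RandAugment's two parameters: `n`, the number of randomly selected transforms applied sequentially, and `m`, a strength shared by all transforms. The NumPy-only sketch below illustrates that idea under stated assumptions. It is not the `keras_cv.layers.RandAugment` implementation: the three toy transforms, the `[0, 1]` strength scale, and the name `rand_augment_sketch` are hypothetical and exist only for illustration.

```python
import numpy as np


def _brightness(img, strength):
    # Toy transform (assumption): shift all pixels up by a strength-scaled offset.
    return np.clip(img + 255.0 * 0.5 * strength, 0, 255)


def _contrast(img, strength):
    # Toy transform (assumption): stretch deviations from the image mean.
    mean = img.mean()
    return np.clip((img - mean) * (1.0 + strength) + mean, 0, 255)


def _invert_blend(img, strength):
    # Toy transform (assumption): blend the image with its inverted version.
    return np.clip((1.0 - strength) * img + strength * (255.0 - img), 0, 255)


TOY_TRANSFORMS = [_brightness, _contrast, _invert_blend]


def rand_augment_sketch(image, n=3, m=0.5, rng=None):
    """Apply `n` randomly chosen toy transforms, each at shared strength `m`."""
    rng = np.random.default_rng() if rng is None else rng
    out = image.astype(np.float32)
    # Sample n transform indices uniformly at random, with replacement,
    # and apply the selected transforms one after another.
    for idx in rng.integers(0, len(TOY_TRANSFORMS), size=n):
        out = TOY_TRANSFORMS[idx](out, m)
    # Round before casting so tiny float errors do not change pixel values.
    return np.rint(out).astype(np.uint8)


rng = np.random.default_rng(42)
image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
augmented = rand_augment_sketch(image, n=3, m=0.5, rng=rng)
print(augmented.shape, augmented.dtype)
```

In the updated notebook the same two knobs appear as `augmentations_per_image` and `magnitude` on the KerasCV layer. This sketch makes no claim about how a KerasCV `magnitude` value corresponds to the `m` previously passed to `imgaug`.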