Skip to content

Commit

Permalink
Port KerasCV's BaseImageAugmentationLayer guide to Keras 3 (#1664)
Browse files Browse the repository at this point in the history
  • Loading branch information
tirthasheshpatel authored Nov 30, 2023
1 parent 1148db9 commit c03f594
Show file tree
Hide file tree
Showing 10 changed files with 245 additions and 173 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
141 changes: 82 additions & 59 deletions guides/ipynb/keras_cv/custom_image_augmentations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Author:** [lukewood](https://twitter.com/luke_wood_ml)<br>\n",
"**Date created:** 2022/04/26<br>\n",
"**Last modified:** 2022/04/26<br>\n",
"**Last modified:** 2023/11/29<br>\n",
"**Description:** Use BaseImageAugmentationLayer to implement custom data augmentations."
]
},
Expand All @@ -31,7 +31,9 @@
"\n",
"This guide will show you how to implement your own custom augmentation layers using\n",
"`BaseImageAugmentationLayer`. As an example, we will implement a layer that tints all\n",
"images blue."
"images blue.\n",
"\n",
"Currently, KerasCV's preprocessing layers only support the TensorFlow backend with Keras 3."
]
},
{
Expand All @@ -42,15 +44,27 @@
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"import keras_cv\n",
"from tensorflow.keras import layers\n",
"from keras_cv import utils\n",
"from keras_cv.layers import BaseImageAugmentationLayer\n",
"import matplotlib.pyplot as plt\n",
"!pip install -q --upgrade keras-cv\n",
"!pip install -q --upgrade keras # Upgrade to Keras 3"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"\n",
"tf.autograph.set_verbosity(0)"
"import keras\n",
"from keras import ops\n",
"from keras import layers\n",
"import keras_cv\n",
"import matplotlib.pyplot as plt"
]
},
{
Expand All @@ -59,7 +73,7 @@
"colab_type": "text"
},
"source": [
"First, let's implement some helper functions to visualize intermediate results"
"First, let's implement some helper functions for visualization and some transformations."
]
},
{
Expand All @@ -86,6 +100,22 @@
" plt.imshow(image.astype(\"uint8\"))\n",
" plt.axis(\"off\")\n",
" plt.show()\n",
"\n",
"\n",
"def transform_value_range(images, original_range, target_range):\n",
" images = (images - original_range[0]) / (original_range[1] - original_range[0])\n",
" scale_factor = target_range[1] - target_range[0]\n",
" return (images * scale_factor) + target_range[0]\n",
"\n",
"\n",
"def parse_factor(param, min_value=0.0, max_value=1.0, seed=None):\n",
" if isinstance(param, keras_cv.core.FactorSampler):\n",
" return param\n",
" if isinstance(param, float) or isinstance(param, int):\n",
" param = (min_value, param)\n",
" if param[0] == param[1]:\n",
" return keras_cv.core.ConstantFactorSampler(param[0])\n",
" return keras_cv.core.UniformFactorSampler(param[0], param[1], seed=seed)\n",
""
]
},
Expand Down Expand Up @@ -113,15 +143,8 @@
"present in the input images. KerasCV offers the `value_range` API to simplify the handling of this.\n",
"\n",
"In our example, we will use the `FactorSampler` API, the `value_range` API, and\n",
"`BaseImageAugmentationLayer` to implement a robust, configurable, and correct `RandomBlueTint` layer."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"`BaseImageAugmentationLayer` to implement a robust, configurable, and correct `RandomBlueTint` layer.\n",
"\n",
"## Overriding `augment_image()`\n",
"\n",
"Let's start off with the minimum:"
Expand All @@ -137,11 +160,11 @@
"source": [
"\n",
"class RandomBlueTint(keras_cv.layers.BaseImageAugmentationLayer):\n",
" def augment_image(self, image, transformation=None):\n",
" def augment_image(self, image, *args, transformation=None, **kwargs):\n",
" # image is of shape (height, width, channels)\n",
" [*others, blue] = tf.unstack(image, axis=-1)\n",
" blue = tf.clip_by_value(blue + 100, 0.0, 255.0)\n",
" return tf.stack([*others, blue], axis=-1)\n",
" [*others, blue] = ops.unstack(image, axis=-1)\n",
" blue = ops.clip(blue + 100, 0.0, 255.0)\n",
" return ops.stack([*others, blue], axis=-1)\n",
""
]
},
Expand Down Expand Up @@ -172,11 +195,11 @@
"outputs": [],
"source": [
"SIZE = (300, 300)\n",
"elephants = tf.keras.utils.get_file(\n",
"elephants = keras.utils.get_file(\n",
" \"african_elephant.jpg\", \"https://i.imgur.com/Bvro0YD.png\"\n",
")\n",
"elephants = tf.keras.utils.load_img(elephants, target_size=SIZE)\n",
"elephants = tf.keras.utils.img_to_array(elephants)\n",
"elephants = keras.utils.load_img(elephants, target_size=SIZE)\n",
"elephants = keras.utils.img_to_array(elephants)\n",
"imshow(elephants)"
]
},
Expand All @@ -199,7 +222,7 @@
"source": [
"layer = RandomBlueTint()\n",
"augmented = layer(elephants)\n",
"imshow(augmented.numpy())"
"imshow(ops.convert_to_numpy(augmented))"
]
},
{
Expand All @@ -220,8 +243,8 @@
"outputs": [],
"source": [
"layer = RandomBlueTint()\n",
"augmented = layer(tf.expand_dims(elephants, axis=0))\n",
"imshow(augmented.numpy()[0])"
"augmented = layer(ops.expand_dims(elephants, axis=0))\n",
"imshow(ops.convert_to_numpy(augmented)[0])"
]
},
{
Expand Down Expand Up @@ -266,13 +289,13 @@
"\n",
" def __init__(self, factor, **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.factor = utils.parse_factor(factor)\n",
" self.factor = parse_factor(factor)\n",
"\n",
" def augment_image(self, image, transformation=None):\n",
" [*others, blue] = tf.unstack(image, axis=-1)\n",
" def augment_image(self, image, *args, transformation=None, **kwargs):\n",
" [*others, blue] = ops.unstack(image, axis=-1)\n",
" blue_shift = self.factor() * 255\n",
" blue = tf.clip_by_value(blue + blue_shift, 0.0, 255.0)\n",
" return tf.stack([*others, blue], axis=-1)\n",
" blue = ops.clip(blue + blue_shift, 0.0, 255.0)\n",
" return ops.stack([*others, blue], axis=-1)\n",
""
]
},
Expand All @@ -294,10 +317,10 @@
},
"outputs": [],
"source": [
"many_elephants = tf.repeat(tf.expand_dims(elephants, axis=0), 9, axis=0)\n",
"many_elephants = ops.repeat(ops.expand_dims(elephants, axis=0), 9, axis=0)\n",
"layer = RandomBlueTint(factor=0.5)\n",
"augmented = layer(many_elephants)\n",
"gallery_show(augmented.numpy())"
"gallery_show(ops.convert_to_numpy(augmented))"
]
},
{
Expand All @@ -320,13 +343,13 @@
},
"outputs": [],
"source": [
"many_elephants = tf.repeat(tf.expand_dims(elephants, axis=0), 9, axis=0)\n",
"factor = keras_cv.NormalFactorSampler(\n",
"many_elephants = ops.repeat(ops.expand_dims(elephants, axis=0), 9, axis=0)\n",
"factor = keras_cv.core.NormalFactorSampler(\n",
" mean=0.3, stddev=0.1, min_value=0.0, max_value=1.0\n",
")\n",
"layer = RandomBlueTint(factor=factor)\n",
"augmented = layer(many_elephants)\n",
"gallery_show(augmented.numpy())"
"gallery_show(ops.convert_to_numpy(augmented))"
]
},
{
Expand Down Expand Up @@ -383,16 +406,16 @@
"\n",
" def __init__(self, factor, **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.factor = utils.parse_factor(factor)\n",
" self.factor = parse_factor(factor)\n",
"\n",
" def get_random_transformation(self, **kwargs):\n",
" # kwargs holds {\"images\": image, \"labels\": label, etc...}\n",
" return self.factor() * 255\n",
"\n",
" def augment_image(self, image, transformation=None, **kwargs):\n",
" [*others, blue] = tf.unstack(image, axis=-1)\n",
" blue = tf.clip_by_value(blue + transformation, 0.0, 255.0)\n",
" return tf.stack([*others, blue], axis=-1)\n",
" [*others, blue] = ops.unstack(image, axis=-1)\n",
" blue = ops.clip(blue + transformation, 0.0, 255.0)\n",
" return ops.stack([*others, blue], axis=-1)\n",
"\n",
" def augment_label(self, label, transformation=None, **kwargs):\n",
" # you can use transformation somehow if you want\n",
Expand Down Expand Up @@ -436,8 +459,8 @@
},
"outputs": [],
"source": [
"labels = tf.constant([[1, 0]])\n",
"inputs = {\"images\": elephants, \"labels\": labels}"
"labels = ops.array([[1, 0]])\n",
"inputs = {\"images\": ops.convert_to_tensor(elephants), \"labels\": labels}"
]
},
{
Expand Down Expand Up @@ -496,10 +519,10 @@
"augmented = layer(elephants_0_1)\n",
"print(\n",
" \"min and max after augmentation:\",\n",
" (augmented.numpy()).min(),\n",
" augmented.numpy().max(),\n",
" ops.convert_to_numpy(augmented).min(),\n",
" ops.convert_to_numpy(augmented).max(),\n",
")\n",
"imshow((augmented * 255).numpy().astype(int))"
"imshow(ops.convert_to_numpy(augmented * 255).astype(int))"
]
},
{
Expand Down Expand Up @@ -547,18 +570,18 @@
" def __init__(self, value_range, factor, **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.value_range = value_range\n",
" self.factor = utils.parse_factor(factor)\n",
" self.factor = parse_factor(factor)\n",
"\n",
" def get_random_transformation(self, **kwargs):\n",
" # kwargs holds {\"images\": image, \"labels\": label, etc...}\n",
" return self.factor() * 255\n",
"\n",
" def augment_image(self, image, transformation=None, **kwargs):\n",
" image = utils.transform_value_range(image, self.value_range, (0, 255))\n",
" [*others, blue] = tf.unstack(image, axis=-1)\n",
" blue = tf.clip_by_value(blue + transformation, 0.0, 255.0)\n",
" result = tf.stack([*others, blue], axis=-1)\n",
" result = utils.transform_value_range(result, (0, 255), self.value_range)\n",
" image = transform_value_range(image, self.value_range, (0, 255))\n",
" [*others, blue] = ops.unstack(image, axis=-1)\n",
" blue = ops.clip(blue + transformation, 0.0, 255.0)\n",
" result = ops.stack([*others, blue], axis=-1)\n",
" result = transform_value_range(result, (0, 255), self.value_range)\n",
" return result\n",
"\n",
" def augment_label(self, label, transformation=None, **kwargs):\n",
Expand All @@ -582,10 +605,10 @@
"augmented = layer(elephants_0_1)\n",
"print(\n",
" \"min and max after augmentation:\",\n",
" augmented.numpy().min(),\n",
" augmented.numpy().max(),\n",
" ops.convert_to_numpy(augmented).min(),\n",
" ops.convert_to_numpy(augmented).max(),\n",
")\n",
"imshow((augmented * 255).numpy().astype(int))"
"imshow(ops.convert_to_numpy(augmented * 255).astype(int))"
]
},
{
Expand Down Expand Up @@ -713,4 +736,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
Loading

0 comments on commit c03f594

Please sign in to comment.