Commit
Update gradient centralization example for Keras 3. (#1598)
hertschuh authored Nov 10, 2023
1 parent e9358d0 commit c8fa98c
Showing 3 changed files with 252 additions and 184 deletions.
55 changes: 30 additions & 25 deletions examples/vision/gradient_centralization.py
@@ -2,9 +2,10 @@
Title: Gradient Centralization for Better Training Performance
Author: [Rishit Dagli](https://github.com/Rishit-dagli)
Date created: 06/18/21
-Last modified: 06/18/21
+Last modified: 07/25/23
Description: Implement Gradient Centralization to improve training performance of DNNs.
Accelerator: GPU
+Converted to Keras 3 by: [Muhammad Anas Raza](https://anasrz.com)
"""
"""
## Introduction
@@ -19,16 +20,11 @@
the loss function and its gradient so that the training process becomes more efficient
and stable.
-This example requires TensorFlow 2.2 or higher as well as `tensorflow_datasets` which can
-be installed with this command:
+This example requires `tensorflow_datasets` which can be installed with this command:
```
pip install tensorflow-datasets
```
-We will be implementing Gradient Centralization in this example but you could also use
-this very easily with a package I built,
-[gradient-centralization-tf](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow).
"""

"""
@@ -37,10 +33,14 @@

from time import time

-import tensorflow as tf
+import keras
+from keras import layers
+from keras.optimizers import RMSprop
+from keras import ops
+
+from tensorflow import data as tf_data
import tensorflow_datasets as tfds
-from tensorflow.keras import layers
-from tensorflow.keras.optimizers import RMSprop


"""
## Prepare the data
@@ -53,7 +53,7 @@
input_shape = (300, 300, 3)
dataset_name = "horses_or_humans"
batch_size = 128
-AUTOTUNE = tf.data.AUTOTUNE
+AUTOTUNE = tf_data.AUTOTUNE

(train_ds, test_ds), metadata = tfds.load(
    name=dataset_name,
@@ -74,13 +74,18 @@

rescale = layers.Rescaling(1.0 / 255)

-data_augmentation = tf.keras.Sequential(
-    [
-        layers.RandomFlip("horizontal_and_vertical"),
-        layers.RandomRotation(0.3),
-        layers.RandomZoom(0.2),
-    ]
-)
+data_augmentation = [
+    layers.RandomFlip("horizontal_and_vertical"),
+    layers.RandomRotation(0.3),
+    layers.RandomZoom(0.2),
+]
+
+
+# Helper to apply augmentation
+def apply_aug(x):
+    for aug in data_augmentation:
+        x = aug(x)
+    return x


def prepare(ds, shuffle=False, augment=False):
@@ -96,7 +101,7 @@ def prepare(ds, shuffle=False, augment=False):
    # Use data augmentation only on the training set
    if augment:
        ds = ds.map(
-            lambda x, y: (data_augmentation(x, training=True), y),
+            lambda x, y: (apply_aug(x), y),
            num_parallel_calls=AUTOTUNE,
        )

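Note: the hunk above shows only the middle of `prepare()`. For context, here is a sketch of the full pipeline as it appears in the example file (rescale, shuffle, batch, augment, prefetch); the exact shuffle buffer size is assumed from the upstream example:

```python
def prepare(ds, shuffle=False, augment=False):
    # Rescale pixel values to [0, 1]
    ds = ds.map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)

    if shuffle:
        ds = ds.shuffle(1024)  # buffer size assumed from the upstream example

    # Batch the dataset
    ds = ds.batch(batch_size)

    # Use data augmentation only on the training set
    if augment:
        ds = ds.map(
            lambda x, y: (apply_aug(x), y),
            num_parallel_calls=AUTOTUNE,
        )

    # Use buffered prefetching
    return ds.prefetch(buffer_size=AUTOTUNE)
```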
@@ -110,16 +115,16 @@

train_ds = prepare(train_ds, shuffle=True, augment=True)
test_ds = prepare(test_ds)

"""
## Define a model
In this section we will define a Convolutional neural network.
"""

-model = tf.keras.Sequential(
+model = keras.Sequential(
    [
-        layers.Conv2D(16, (3, 3), activation="relu", input_shape=(300, 300, 3)),
+        layers.Input(shape=input_shape),
+        layers.Conv2D(16, (3, 3), activation="relu"),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(32, (3, 3), activation="relu"),
        layers.Dropout(0.5),
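As a usage sketch (not part of this diff): once the Gradient Centralization optimizer defined below is instantiated, the model is compiled and trained the usual way. The subclass name `GCRMSprop` and the hyperparameters are assumed from the full example file:

```python
# Hypothetical wiring, assuming the `GCRMSprop` subclass defined below.
optimizer = GCRMSprop(learning_rate=1e-4)

model.compile(
    loss="binary_crossentropy",  # horses vs. humans is binary classification
    optimizer=optimizer,
    metrics=["accuracy"],
)
history = model.fit(train_ds, epochs=10, validation_data=test_ds)
```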
@@ -143,7 +148,7 @@
We will now
subclass the `RMSProp` optimizer class modifying the
-`tf.keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient
+`keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient
Centralization. On a high level, the idea is this: say we obtain our gradients
through backpropagation for a Dense or Convolution layer; we then compute the mean of the
column vectors of the weight matrix, and then remove the mean from each column vector.
@@ -174,7 +179,7 @@ def get_gradients(self, loss, params):
            grad_len = len(grad.shape)
            if grad_len > 1:
                axis = list(range(grad_len - 1))
-                grad -= tf.reduce_mean(grad, axis=axis, keep_dims=True)
+                grad -= ops.mean(grad, axis=axis, keepdims=True)
            grads.append(grad)

        return grads
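To make the centralization step concrete, here is a small standalone illustration (not part of the commit) of what subtracting `ops.mean(..., keepdims=True)` does to a toy gradient, assuming Keras 3 on any backend:

```python
import numpy as np
from keras import ops

# A toy 3x2 "gradient" for a Dense kernel: rows index input units,
# columns index output units.
grad = ops.convert_to_tensor(np.array([[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]))

# Centralize: average over every axis except the last, i.e. take the
# mean of each column vector, and subtract it.
axis = list(range(len(grad.shape) - 1))
centralized = grad - ops.mean(grad, axis=axis, keepdims=True)

print(ops.convert_to_numpy(centralized))
# [[-1. -1.]
#  [ 0.  0.]
#  [ 1.  1.]]   -> each column now has zero mean
```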
@@ -191,7 +196,7 @@ def get_gradients(self, loss, params):
"""


-class TimeHistory(tf.keras.callbacks.Callback):
+class TimeHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.times = []

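The diff for this file ends here; for context, the remainder of this callback in the example simply records wall-clock time per epoch. A sketch of the likely full class, consistent with the upstream example:

```python
from time import time
import keras


class TimeHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.times = []  # elapsed seconds for each epoch

    def on_epoch_begin(self, batch, logs={}):
        self.epoch_time_start = time()  # stamp the epoch start

    def on_epoch_end(self, batch, logs={}):
        # record how long this epoch took
        self.times.append(time() - self.epoch_time_start)
```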
55 changes: 30 additions & 25 deletions examples/vision/ipynb/gradient_centralization.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Rishit Dagli](https://github.com/Rishit-dagli)<br>\n",
"**Date created:** 06/18/21<br>\n",
"**Last modified:** 06/18/21<br>\n",
"**Last modified:** 07/25/23<br>\n",
"**Description:** Implement Gradient Centralization to improve training performance of DNNs."
]
},
@@ -32,16 +32,11 @@
"the loss function and its gradient so that the training process becomes more efficient\n",
"and stable.\n",
"\n",
"This example requires TensorFlow 2.2 or higher as well as `tensorflow_datasets` which can\n",
"be installed with this command:\n",
"This example requires `tensorflow_datasets` which can be installed with this command:\n",
"\n",
"```\n",
"pip install tensorflow-datasets\n",
"```\n",
"\n",
"We will be implementing Gradient Centralization in this example but you could also use\n",
"this very easily with a package I built,\n",
"[gradient-centralization-tf](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow)."
"```"
]
},
{
@@ -63,10 +58,14 @@
"source": [
"from time import time\n",
"\n",
"import tensorflow as tf\n",
"import keras\n",
"from keras import layers\n",
"from keras.optimizers import RMSprop\n",
"from keras import ops\n",
"\n",
"from tensorflow import data as tf_data\n",
"import tensorflow_datasets as tfds\n",
"from tensorflow.keras import layers\n",
"from tensorflow.keras.optimizers import RMSprop"
""
]
},
{
@@ -93,7 +92,7 @@
"input_shape = (300, 300, 3)\n",
"dataset_name = \"horses_or_humans\"\n",
"batch_size = 128\n",
"AUTOTUNE = tf.data.AUTOTUNE\n",
"AUTOTUNE = tf_data.AUTOTUNE\n",
"\n",
"(train_ds, test_ds), metadata = tfds.load(\n",
" name=dataset_name,\n",
@@ -128,13 +127,18 @@
"source": [
"rescale = layers.Rescaling(1.0 / 255)\n",
"\n",
"data_augmentation = tf.keras.Sequential(\n",
" [\n",
" layers.RandomFlip(\"horizontal_and_vertical\"),\n",
" layers.RandomRotation(0.3),\n",
" layers.RandomZoom(0.2),\n",
" ]\n",
")\n",
"data_augmentation = [\n",
" layers.RandomFlip(\"horizontal_and_vertical\"),\n",
" layers.RandomRotation(0.3),\n",
" layers.RandomZoom(0.2),\n",
"]\n",
"\n",
"\n",
"# Helper to apply augmentation\n",
"def apply_aug(x):\n",
" for aug in data_augmentation:\n",
" x = aug(x)\n",
" return x\n",
"\n",
"\n",
"def prepare(ds, shuffle=False, augment=False):\n",
@@ -150,7 +154,7 @@
" # Use data augmentation only on the training set\n",
" if augment:\n",
" ds = ds.map(\n",
" lambda x, y: (data_augmentation(x, training=True), y),\n",
" lambda x, y: (apply_aug(x), y),\n",
" num_parallel_calls=AUTOTUNE,\n",
" )\n",
"\n",
@@ -199,9 +203,10 @@
},
"outputs": [],
"source": [
"model = tf.keras.Sequential(\n",
"model = keras.Sequential(\n",
" [\n",
" layers.Conv2D(16, (3, 3), activation=\"relu\", input_shape=(300, 300, 3)),\n",
" layers.Input(shape=input_shape),\n",
" layers.Conv2D(16, (3, 3), activation=\"relu\"),\n",
" layers.MaxPooling2D(2, 2),\n",
" layers.Conv2D(32, (3, 3), activation=\"relu\"),\n",
" layers.Dropout(0.5),\n",
@@ -231,7 +236,7 @@
"\n",
"We will now\n",
"subclass the `RMSProp` optimizer class modifying the\n",
"`tf.keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient\n",
"`keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient\n",
"Centralization. On a high level the idea is that let us say we obtain our gradients\n",
"through back propogation for a Dense or Convolution layer we then compute the mean of the\n",
"column vectors of the weight matrix, and then remove the mean from each column vector.\n",
@@ -270,7 +275,7 @@
" grad_len = len(grad.shape)\n",
" if grad_len > 1:\n",
" axis = list(range(grad_len - 1))\n",
" grad -= tf.reduce_mean(grad, axis=axis, keep_dims=True)\n",
" grad -= ops.mean(grad, axis=axis, keep_dims=True)\n",
" grads.append(grad)\n",
"\n",
" return grads\n",
@@ -301,7 +306,7 @@
"outputs": [],
"source": [
"\n",
"class TimeHistory(tf.keras.callbacks.Callback):\n",
"class TimeHistory(keras.callbacks.Callback):\n",
" def on_train_begin(self, logs={}):\n",
" self.times = []\n",
"\n",
