Update gradient centralization example for Keras 3. #1598

Merged on Nov 10, 2023 (1 commit)
55 changes: 30 additions & 25 deletions examples/vision/gradient_centralization.py
@@ -2,9 +2,10 @@
Title: Gradient Centralization for Better Training Performance
Author: [Rishit Dagli](https://github.com/Rishit-dagli)
Date created: 06/18/21
Last modified: 06/18/21
Last modified: 07/25/23
Description: Implement Gradient Centralization to improve training performance of DNNs.
Accelerator: GPU
Converted to Keras 3 by: [Muhammad Anas Raza](https://anasrz.com)
"""
"""
## Introduction
@@ -19,16 +20,11 @@
the loss function and its gradient so that the training process becomes more efficient
and stable.

This example requires TensorFlow 2.2 or higher as well as `tensorflow_datasets` which can
be installed with this command:
This example requires `tensorflow_datasets` which can be installed with this command:

```
pip install tensorflow-datasets
```

We will be implementing Gradient Centralization in this example but you could also use
this very easily with a package I built,
[gradient-centralization-tf](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow).
"""

"""
@@ -37,10 +33,14 @@

from time import time

import tensorflow as tf
import keras
from keras import layers
from keras.optimizers import RMSprop
from keras import ops

from tensorflow import data as tf_data
import tensorflow_datasets as tfds
from tensorflow.keras import layers
from tensorflow.keras.optimizers import RMSprop


"""
## Prepare the data
@@ -53,7 +53,7 @@
input_shape = (300, 300, 3)
dataset_name = "horses_or_humans"
batch_size = 128
AUTOTUNE = tf.data.AUTOTUNE
AUTOTUNE = tf_data.AUTOTUNE

(train_ds, test_ds), metadata = tfds.load(
name=dataset_name,
@@ -74,13 +74,18 @@

rescale = layers.Rescaling(1.0 / 255)

data_augmentation = tf.keras.Sequential(
[
layers.RandomFlip("horizontal_and_vertical"),
layers.RandomRotation(0.3),
layers.RandomZoom(0.2),
]
)
data_augmentation = [
layers.RandomFlip("horizontal_and_vertical"),
layers.RandomRotation(0.3),
layers.RandomZoom(0.2),
]


# Helper to apply augmentation
def apply_aug(x):
for aug in data_augmentation:
x = aug(x)
return x


def prepare(ds, shuffle=False, augment=False):
@@ -96,7 +101,7 @@ def prepare(ds, shuffle=False, augment=False):
# Use data augmentation only on the training set
if augment:
ds = ds.map(
lambda x, y: (data_augmentation(x, training=True), y),
lambda x, y: (apply_aug(x), y),
num_parallel_calls=AUTOTUNE,
)

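Keras 3 note on the augmentation refactor above: the old `tf.keras.Sequential` wrapper is replaced by a plain list of preprocessing layers chained by `apply_aug` inside `tf_data.map`. A quick standalone smoke test of that helper (a hypothetical snippet, not in the PR, assuming the TensorFlow backend is available) could look like this:

```python
import numpy as np
from keras import layers

# Hypothetical smoke test: chain the same augmentation layers over one dummy batch.
augmentation_layers = [
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.3),
    layers.RandomZoom(0.2),
]

def apply_aug(images):
    # Apply each layer in order, just like the helper in the diff.
    for aug in augmentation_layers:
        images = aug(images)
    return images

dummy_batch = np.random.rand(2, 300, 300, 3).astype("float32")
print(apply_aug(dummy_batch).shape)  # (2, 300, 300, 3) -- spatial shape is preserved
```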
@@ -110,16 +115,16 @@ def prepare(ds, shuffle=False, augment=False):

train_ds = prepare(train_ds, shuffle=True, augment=True)
test_ds = prepare(test_ds)

"""
## Define a model

In this section we will define a Convolutional neural network.
"""

model = tf.keras.Sequential(
model = keras.Sequential(
[
layers.Conv2D(16, (3, 3), activation="relu", input_shape=(300, 300, 3)),
layers.Input(shape=input_shape),
layers.Conv2D(16, (3, 3), activation="relu"),
layers.MaxPooling2D(2, 2),
layers.Conv2D(32, (3, 3), activation="relu"),
layers.Dropout(0.5),
@@ -143,7 +148,7 @@ def prepare(ds, shuffle=False, augment=False):

We will now
subclass the `RMSProp` optimizer class modifying the
`tf.keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient
`keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient
Centralization. At a high level, the idea is this: suppose we obtain our gradients
through backpropagation for a Dense or Convolution layer; we then compute the mean of the
column vectors of the weight matrix, and remove that mean from each column vector.
@@ -174,7 +179,7 @@ def get_gradients(self, loss, params):
grad_len = len(grad.shape)
if grad_len > 1:
axis = list(range(grad_len - 1))
grad -= tf.reduce_mean(grad, axis=axis, keep_dims=True)
grad -= ops.mean(grad, axis=axis, keepdims=True)
grads.append(grad)

return grads
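For readers skimming the diff, the core of this change is the switch from `tf.reduce_mean` to the backend-agnostic `keras.ops.mean`, which, like NumPy, spells the flag `keepdims`. Pulled out of the optimizer, the centralization step is only a few lines; the following is a sketch assuming a working Keras 3 install, not the PR's exact code:

```python
from keras import ops

def centralize_gradient(grad):
    """Subtract the mean over every axis except the last (output units / filters)."""
    if len(grad.shape) > 1:
        axis = list(range(len(grad.shape) - 1))
        grad = grad - ops.mean(grad, axis=axis, keepdims=True)
    return grad

# Hypothetical kernel gradients: a Dense kernel (4, 3) and a Conv2D kernel (3, 3, 16, 32).
dense_grad = ops.ones((4, 3))
conv_grad = ops.ones((3, 3, 16, 32))

print(ops.mean(centralize_gradient(dense_grad)))  # ~0.0
print(ops.mean(centralize_gradient(conv_grad)))   # ~0.0
```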
@@ -191,7 +196,7 @@ def get_gradients(self, loss, params):
"""


class TimeHistory(tf.keras.callbacks.Callback):
class TimeHistory(keras.callbacks.Callback):
def on_train_begin(self, logs={}):
self.times = []

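The diff view cuts this callback off after `on_train_begin`. For context, an epoch-timing callback along these lines typically continues as sketched below (a sketch, not necessarily the file's exact code); an instance can then be passed to `model.fit(..., callbacks=[TimeHistory()])` to compare wall-clock training time with and without Gradient Centralization.

```python
from time import time

import keras


class TimeHistory(keras.callbacks.Callback):
    """Record the duration of every training epoch, in seconds."""

    def on_train_begin(self, logs=None):
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time() - self.epoch_time_start)
```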
55 changes: 30 additions & 25 deletions examples/vision/ipynb/gradient_centralization.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Rishit Dagli](https://github.com/Rishit-dagli)<br>\n",
"**Date created:** 06/18/21<br>\n",
"**Last modified:** 06/18/21<br>\n",
"**Last modified:** 07/25/23<br>\n",
"**Description:** Implement Gradient Centralization to improve training performance of DNNs."
]
},
@@ -32,16 +32,11 @@
"the loss function and its gradient so that the training process becomes more efficient\n",
"and stable.\n",
"\n",
"This example requires TensorFlow 2.2 or higher as well as `tensorflow_datasets` which can\n",
"be installed with this command:\n",
"This example requires `tensorflow_datasets` which can be installed with this command:\n",
"\n",
"```\n",
"pip install tensorflow-datasets\n",
"```\n",
"\n",
"We will be implementing Gradient Centralization in this example but you could also use\n",
"this very easily with a package I built,\n",
"[gradient-centralization-tf](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow)."
"```"
]
},
{
@@ -63,10 +58,14 @@
"source": [
"from time import time\n",
"\n",
"import tensorflow as tf\n",
"import keras\n",
"from keras import layers\n",
"from keras.optimizers import RMSprop\n",
"from keras import ops\n",
"\n",
"from tensorflow import data as tf_data\n",
"import tensorflow_datasets as tfds\n",
"from tensorflow.keras import layers\n",
"from tensorflow.keras.optimizers import RMSprop"
""
]
},
{
@@ -93,7 +92,7 @@
"input_shape = (300, 300, 3)\n",
"dataset_name = \"horses_or_humans\"\n",
"batch_size = 128\n",
"AUTOTUNE = tf.data.AUTOTUNE\n",
"AUTOTUNE = tf_data.AUTOTUNE\n",
"\n",
"(train_ds, test_ds), metadata = tfds.load(\n",
" name=dataset_name,\n",
@@ -128,13 +127,18 @@
"source": [
"rescale = layers.Rescaling(1.0 / 255)\n",
"\n",
"data_augmentation = tf.keras.Sequential(\n",
" [\n",
" layers.RandomFlip(\"horizontal_and_vertical\"),\n",
" layers.RandomRotation(0.3),\n",
" layers.RandomZoom(0.2),\n",
" ]\n",
")\n",
"data_augmentation = [\n",
" layers.RandomFlip(\"horizontal_and_vertical\"),\n",
" layers.RandomRotation(0.3),\n",
" layers.RandomZoom(0.2),\n",
"]\n",
"\n",
"\n",
"# Helper to apply augmentation\n",
"def apply_aug(x):\n",
" for aug in data_augmentation:\n",
" x = aug(x)\n",
" return x\n",
"\n",
"\n",
"def prepare(ds, shuffle=False, augment=False):\n",
@@ -150,7 +154,7 @@
" # Use data augmentation only on the training set\n",
" if augment:\n",
" ds = ds.map(\n",
" lambda x, y: (data_augmentation(x, training=True), y),\n",
" lambda x, y: (apply_aug(x), y),\n",
" num_parallel_calls=AUTOTUNE,\n",
" )\n",
"\n",
@@ -199,9 +203,10 @@
},
"outputs": [],
"source": [
"model = tf.keras.Sequential(\n",
"model = keras.Sequential(\n",
" [\n",
" layers.Conv2D(16, (3, 3), activation=\"relu\", input_shape=(300, 300, 3)),\n",
" layers.Input(shape=input_shape),\n",
" layers.Conv2D(16, (3, 3), activation=\"relu\"),\n",
" layers.MaxPooling2D(2, 2),\n",
" layers.Conv2D(32, (3, 3), activation=\"relu\"),\n",
" layers.Dropout(0.5),\n",
@@ -231,7 +236,7 @@
"\n",
"We will now\n",
"subclass the `RMSProp` optimizer class modifying the\n",
"`tf.keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient\n",
"`keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient\n",
"Centralization. On a high level the idea is that let us say we obtain our gradients\n",
"through back propogation for a Dense or Convolution layer we then compute the mean of the\n",
"column vectors of the weight matrix, and then remove the mean from each column vector.\n",
@@ -270,7 +275,7 @@
" grad_len = len(grad.shape)\n",
" if grad_len > 1:\n",
" axis = list(range(grad_len - 1))\n",
" grad -= tf.reduce_mean(grad, axis=axis, keep_dims=True)\n",
" grad -= ops.mean(grad, axis=axis, keep_dims=True)\n",
" grads.append(grad)\n",
"\n",
" return grads\n",
@@ -301,7 +306,7 @@
"outputs": [],
"source": [
"\n",
"class TimeHistory(tf.keras.callbacks.Callback):\n",
"class TimeHistory(keras.callbacks.Callback):\n",
" def on_train_begin(self, logs={}):\n",
" self.times = []\n",
"\n",