diff --git a/examples/vision/gradient_centralization.py b/examples/vision/gradient_centralization.py index 121977bed6..f53a45a5c7 100644 --- a/examples/vision/gradient_centralization.py +++ b/examples/vision/gradient_centralization.py @@ -2,9 +2,10 @@ Title: Gradient Centralization for Better Training Performance Author: [Rishit Dagli](https://github.com/Rishit-dagli) Date created: 06/18/21 -Last modified: 06/18/21 +Last modified: 07/25/23 Description: Implement Gradient Centralization to improve training performance of DNNs. Accelerator: GPU +Converted to Keras 3 by: [Muhammad Anas Raza](https://anasrz.com) """ """ ## Introduction @@ -19,16 +20,11 @@ the loss function and its gradient so that the training process becomes more efficient and stable. -This example requires TensorFlow 2.2 or higher as well as `tensorflow_datasets` which can -be installed with this command: +This example requires `tensorflow_datasets` which can be installed with this command: ``` pip install tensorflow-datasets ``` - -We will be implementing Gradient Centralization in this example but you could also use -this very easily with a package I built, -[gradient-centralization-tf](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow). """ """ @@ -37,10 +33,14 @@ from time import time -import tensorflow as tf +import keras +from keras import layers +from keras.optimizers import RMSprop +from keras import ops + +from tensorflow import data as tf_data import tensorflow_datasets as tfds -from tensorflow.keras import layers -from tensorflow.keras.optimizers import RMSprop + """ ## Prepare the data @@ -53,7 +53,7 @@ input_shape = (300, 300, 3) dataset_name = "horses_or_humans" batch_size = 128 -AUTOTUNE = tf.data.AUTOTUNE +AUTOTUNE = tf_data.AUTOTUNE (train_ds, test_ds), metadata = tfds.load( name=dataset_name, @@ -74,13 +74,18 @@ rescale = layers.Rescaling(1.0 / 255) -data_augmentation = tf.keras.Sequential( - [ - layers.RandomFlip("horizontal_and_vertical"), - layers.RandomRotation(0.3), - layers.RandomZoom(0.2), - ] -) +data_augmentation = [ + layers.RandomFlip("horizontal_and_vertical"), + layers.RandomRotation(0.3), + layers.RandomZoom(0.2), +] + + +# Helper to apply augmentation +def apply_aug(x): + for aug in data_augmentation: + x = aug(x) + return x def prepare(ds, shuffle=False, augment=False): @@ -96,7 +101,7 @@ def prepare(ds, shuffle=False, augment=False): # Use data augmentation only on the training set if augment: ds = ds.map( - lambda x, y: (data_augmentation(x, training=True), y), + lambda x, y: (apply_aug(x), y), num_parallel_calls=AUTOTUNE, ) @@ -110,16 +115,16 @@ def prepare(ds, shuffle=False, augment=False): train_ds = prepare(train_ds, shuffle=True, augment=True) test_ds = prepare(test_ds) - """ ## Define a model In this section we will define a Convolutional neural network. """ -model = tf.keras.Sequential( +model = keras.Sequential( [ - layers.Conv2D(16, (3, 3), activation="relu", input_shape=(300, 300, 3)), + layers.Input(shape=input_shape), + layers.Conv2D(16, (3, 3), activation="relu"), layers.MaxPooling2D(2, 2), layers.Conv2D(32, (3, 3), activation="relu"), layers.Dropout(0.5), @@ -143,7 +148,7 @@ def prepare(ds, shuffle=False, augment=False): We will now subclass the `RMSProp` optimizer class modifying the -`tf.keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient +`keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient Centralization. 
At a high level, the idea is as follows: say we obtain our gradients through backpropagation
for a Dense or Convolution layer; we then compute the mean of the column vectors of the
weight matrix, and then remove the mean from each column vector.
@@ -174,7 +179,7 @@ def get_gradients(self, loss, params):
            grad_len = len(grad.shape)
            if grad_len > 1:
                axis = list(range(grad_len - 1))
-                grad -= tf.reduce_mean(grad, axis=axis, keep_dims=True)
+                grad -= ops.mean(grad, axis=axis, keepdims=True)
            grads.append(grad)

        return grads
@@ -191,7 +196,7 @@ def get_gradients(self, loss, params):
"""


-class TimeHistory(tf.keras.callbacks.Callback):
+class TimeHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.times = []

diff --git a/examples/vision/ipynb/gradient_centralization.ipynb b/examples/vision/ipynb/gradient_centralization.ipynb
index 1971a2036a..c17e33056a 100644
--- a/examples/vision/ipynb/gradient_centralization.ipynb
+++ b/examples/vision/ipynb/gradient_centralization.ipynb
@@ -10,7 +10,7 @@
    "\n",
    "**Author:** [Rishit Dagli](https://github.com/Rishit-dagli)
\n", "**Date created:** 06/18/21
\n", - "**Last modified:** 06/18/21
\n", + "**Last modified:** 07/25/23
\n", "**Description:** Implement Gradient Centralization to improve training performance of DNNs." ] }, @@ -32,16 +32,11 @@ "the loss function and its gradient so that the training process becomes more efficient\n", "and stable.\n", "\n", - "This example requires TensorFlow 2.2 or higher as well as `tensorflow_datasets` which can\n", - "be installed with this command:\n", + "This example requires `tensorflow_datasets` which can be installed with this command:\n", "\n", "```\n", "pip install tensorflow-datasets\n", - "```\n", - "\n", - "We will be implementing Gradient Centralization in this example but you could also use\n", - "this very easily with a package I built,\n", - "[gradient-centralization-tf](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow)." + "```" ] }, { @@ -63,10 +58,14 @@ "source": [ "from time import time\n", "\n", - "import tensorflow as tf\n", + "import keras\n", + "from keras import layers\n", + "from keras.optimizers import RMSprop\n", + "from keras import ops\n", + "\n", + "from tensorflow import data as tf_data\n", "import tensorflow_datasets as tfds\n", - "from tensorflow.keras import layers\n", - "from tensorflow.keras.optimizers import RMSprop" + "" ] }, { @@ -93,7 +92,7 @@ "input_shape = (300, 300, 3)\n", "dataset_name = \"horses_or_humans\"\n", "batch_size = 128\n", - "AUTOTUNE = tf.data.AUTOTUNE\n", + "AUTOTUNE = tf_data.AUTOTUNE\n", "\n", "(train_ds, test_ds), metadata = tfds.load(\n", " name=dataset_name,\n", @@ -128,13 +127,18 @@ "source": [ "rescale = layers.Rescaling(1.0 / 255)\n", "\n", - "data_augmentation = tf.keras.Sequential(\n", - " [\n", - " layers.RandomFlip(\"horizontal_and_vertical\"),\n", - " layers.RandomRotation(0.3),\n", - " layers.RandomZoom(0.2),\n", - " ]\n", - ")\n", + "data_augmentation = [\n", + " layers.RandomFlip(\"horizontal_and_vertical\"),\n", + " layers.RandomRotation(0.3),\n", + " layers.RandomZoom(0.2),\n", + "]\n", + "\n", + "\n", + "# Helper to apply augmentation\n", + "def apply_aug(x):\n", + " for aug in data_augmentation:\n", + " x = aug(x)\n", + " return x\n", "\n", "\n", "def prepare(ds, shuffle=False, augment=False):\n", @@ -150,7 +154,7 @@ " # Use data augmentation only on the training set\n", " if augment:\n", " ds = ds.map(\n", - " lambda x, y: (data_augmentation(x, training=True), y),\n", + " lambda x, y: (apply_aug(x), y),\n", " num_parallel_calls=AUTOTUNE,\n", " )\n", "\n", @@ -199,9 +203,10 @@ }, "outputs": [], "source": [ - "model = tf.keras.Sequential(\n", + "model = keras.Sequential(\n", " [\n", - " layers.Conv2D(16, (3, 3), activation=\"relu\", input_shape=(300, 300, 3)),\n", + " layers.Input(shape=input_shape),\n", + " layers.Conv2D(16, (3, 3), activation=\"relu\"),\n", " layers.MaxPooling2D(2, 2),\n", " layers.Conv2D(32, (3, 3), activation=\"relu\"),\n", " layers.Dropout(0.5),\n", @@ -231,7 +236,7 @@ "\n", "We will now\n", "subclass the `RMSProp` optimizer class modifying the\n", - "`tf.keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient\n", + "`keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient\n", "Centralization. 
At a high level, the idea is as follows: say we obtain our gradients\n",
    "through backpropagation for a Dense or Convolution layer; we then compute the mean of the\n",
    "column vectors of the weight matrix, and then remove the mean from each column vector.\n",
@@ -270,7 +275,7 @@
    "            grad_len = len(grad.shape)\n",
    "            if grad_len > 1:\n",
    "                axis = list(range(grad_len - 1))\n",
-    "                grad -= tf.reduce_mean(grad, axis=axis, keep_dims=True)\n",
+    "                grad -= ops.mean(grad, axis=axis, keepdims=True)\n",
    "            grads.append(grad)\n",
    "\n",
    "        return grads\n",
@@ -301,7 +306,7 @@
    "outputs": [],
    "source": [
    "\n",
-    "class TimeHistory(tf.keras.callbacks.Callback):\n",
+    "class TimeHistory(keras.callbacks.Callback):\n",
    "    def on_train_begin(self, logs={}):\n",
    "        self.times = []\n",
    "\n",
diff --git a/examples/vision/md/gradient_centralization.md b/examples/vision/md/gradient_centralization.md
index a741af2b57..413207c5fc 100644
--- a/examples/vision/md/gradient_centralization.md
+++ b/examples/vision/md/gradient_centralization.md
@@ -2,7 +2,7 @@

**Author:** [Rishit Dagli](https://github.com/Rishit-dagli)
**Date created:** 06/18/21
-**Last modified:** 06/18/21
+**Last modified:** 07/25/23
**Description:** Implement Gradient Centralization to improve training performance of DNNs. @@ -10,6 +10,7 @@ +--- ## Introduction This example implements [Gradient Centralization](https://arxiv.org/abs/2004.01461), a @@ -22,29 +23,30 @@ vectors to have zero mean. Gradient Centralization morever improves the Lipschit the loss function and its gradient so that the training process becomes more efficient and stable. -This example requires TensorFlow 2.2 or higher as well as `tensorflow_datasets` which can -be installed with this command: +This example requires `tensorflow_datasets` which can be installed with this command: ``` pip install tensorflow-datasets ``` -We will be implementing Gradient Centralization in this example but you could also use -this very easily with a package I built, -[gradient-centralization-tf](https://github.com/Rishit-dagli/Gradient-Centralization-TensorFlow). - +--- ## Setup ```python from time import time -import tensorflow as tf +import keras +from keras import layers +from keras.optimizers import RMSprop +from keras import ops + +from tensorflow import data as tf_data import tensorflow_datasets as tfds -from tensorflow.keras import layers -from tensorflow.keras.optimizers import RMSprop + ``` +--- ## Prepare the data For this example, we will be using the [Horses or Humans @@ -56,7 +58,7 @@ num_classes = 2 input_shape = (300, 300, 3) dataset_name = "horses_or_humans" batch_size = 128 -AUTOTUNE = tf.data.AUTOTUNE +AUTOTUNE = tf_data.AUTOTUNE (train_ds, test_ds), metadata = tfds.load( name=dataset_name, @@ -78,6 +80,7 @@ Test images: 256 ``` +--- ## Use Data Augmentation We will rescale the data to `[0, 1]` and perform simple augmentations to our data. @@ -86,13 +89,18 @@ We will rescale the data to `[0, 1]` and perform simple augmentations to our dat ```python rescale = layers.Rescaling(1.0 / 255) -data_augmentation = tf.keras.Sequential( - [ - layers.RandomFlip("horizontal_and_vertical"), - layers.RandomRotation(0.3), - layers.RandomZoom(0.2), - ] -) +data_augmentation = [ + layers.RandomFlip("horizontal_and_vertical"), + layers.RandomRotation(0.3), + layers.RandomZoom(0.2), +] + + +# Helper to apply augmentation +def apply_aug(x): + for aug in data_augmentation: + x = aug(x) + return x def prepare(ds, shuffle=False, augment=False): @@ -108,7 +116,7 @@ def prepare(ds, shuffle=False, augment=False): # Use data augmentation only on the training set if augment: ds = ds.map( - lambda x, y: (data_augmentation(x, training=True), y), + lambda x, y: (apply_aug(x), y), num_parallel_calls=AUTOTUNE, ) @@ -125,15 +133,17 @@ train_ds = prepare(train_ds, shuffle=True, augment=True) test_ds = prepare(test_ds) ``` +--- ## Define a model In this section we will define a Convolutional neural network. ```python -model = tf.keras.Sequential( +model = keras.Sequential( [ - layers.Conv2D(16, (3, 3), activation="relu", input_shape=(300, 300, 3)), + layers.Input(shape=input_shape), + layers.Conv2D(16, (3, 3), activation="relu"), layers.MaxPooling2D(2, 2), layers.Conv2D(32, (3, 3), activation="relu"), layers.Dropout(0.5), @@ -153,11 +163,12 @@ model = tf.keras.Sequential( ) ``` +--- ## Implement Gradient Centralization We will now subclass the `RMSProp` optimizer class modifying the -`tf.keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient +`keras.optimizers.Optimizer.get_gradients()` method where we now implement Gradient Centralization. 
At a high level, the idea is as follows: say we obtain our gradients through backpropagation
for a Dense or Convolution layer; we then compute the mean of the column vectors of the
weight matrix, and then remove the mean from each column vector.
@@ -189,7 +200,7 @@ class GCRMSprop(RMSprop):
            grad_len = len(grad.shape)
            if grad_len > 1:
                axis = list(range(grad_len - 1))
-                grad -= tf.reduce_mean(grad, axis=axis, keep_dims=True)
+                grad -= ops.mean(grad, axis=axis, keepdims=True)
            grads.append(grad)

        return grads
@@ -198,6 +209,7 @@ class GCRMSprop(RMSprop):
optimizer = GCRMSprop(learning_rate=1e-4)
```

+---
## Training utilities

We will also create a callback which allows us to easily measure the total training time
@@ -207,7 +219,7 @@ Gradient Centralization on the model we built above.


```python

-class TimeHistory(tf.keras.callbacks.Callback):
+class TimeHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.times = []

@@ -219,6 +231,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
```

+---
## Train the model without GC

We now train the model we built earlier without Gradient Centralization which we can
@@ -236,51 +249,70 @@ model.compile(
model.summary()
```

-
-``` -Model: "sequential_1" -_________________________________________________________________ -Layer (type) Output Shape Param # -================================================================= -conv2d (Conv2D) (None, 298, 298, 16) 448 -_________________________________________________________________ -max_pooling2d (MaxPooling2D) (None, 149, 149, 16) 0 -_________________________________________________________________ -conv2d_1 (Conv2D) (None, 147, 147, 32) 4640 -_________________________________________________________________ -dropout (Dropout) (None, 147, 147, 32) 0 -_________________________________________________________________ -max_pooling2d_1 (MaxPooling2 (None, 73, 73, 32) 0 -_________________________________________________________________ -conv2d_2 (Conv2D) (None, 71, 71, 64) 18496 -_________________________________________________________________ -dropout_1 (Dropout) (None, 71, 71, 64) 0 -_________________________________________________________________ -max_pooling2d_2 (MaxPooling2 (None, 35, 35, 64) 0 -_________________________________________________________________ -conv2d_3 (Conv2D) (None, 33, 33, 64) 36928 -_________________________________________________________________ -max_pooling2d_3 (MaxPooling2 (None, 16, 16, 64) 0 -_________________________________________________________________ -conv2d_4 (Conv2D) (None, 14, 14, 64) 36928 -_________________________________________________________________ -max_pooling2d_4 (MaxPooling2 (None, 7, 7, 64) 0 -_________________________________________________________________ -flatten (Flatten) (None, 3136) 0 -_________________________________________________________________ -dropout_2 (Dropout) (None, 3136) 0 -_________________________________________________________________ -dense (Dense) (None, 512) 1606144 -_________________________________________________________________ -dense_1 (Dense) (None, 1) 513 -================================================================= -Total params: 1,704,097 -Trainable params: 1,704,097 -Non-trainable params: 0 -_________________________________________________________________ -``` -
+
Model: "sequential"
+
+ + + + +
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
+┃ Layer (type)                     Output Shape                  Param # ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
+│ conv2d (Conv2D)                 │ (None, 298, 298, 16)      │        448 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ max_pooling2d (MaxPooling2D)    │ (None, 149, 149, 16)      │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ conv2d_1 (Conv2D)               │ (None, 147, 147, 32)      │      4,640 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ dropout (Dropout)               │ (None, 147, 147, 32)      │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ max_pooling2d_1 (MaxPooling2D)  │ (None, 73, 73, 32)        │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ conv2d_2 (Conv2D)               │ (None, 71, 71, 64)        │     18,496 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ dropout_1 (Dropout)             │ (None, 71, 71, 64)        │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ max_pooling2d_2 (MaxPooling2D)  │ (None, 35, 35, 64)        │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ conv2d_3 (Conv2D)               │ (None, 33, 33, 64)        │     36,928 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ max_pooling2d_3 (MaxPooling2D)  │ (None, 16, 16, 64)        │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ conv2d_4 (Conv2D)               │ (None, 14, 14, 64)        │     36,928 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ max_pooling2d_4 (MaxPooling2D)  │ (None, 7, 7, 64)          │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ flatten (Flatten)               │ (None, 3136)              │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ dropout_2 (Dropout)             │ (None, 3136)              │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ dense (Dense)                   │ (None, 512)               │  1,606,144 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ dense_1 (Dense)                 │ (None, 1)                 │        513 │
+└─────────────────────────────────┴───────────────────────────┴────────────┘
+
+ + + + +
 Total params: 1,704,097 (6.50 MB)
+
+ + + + +
 Trainable params: 1,704,097 (6.50 MB)
+
+ + + + +
 Non-trainable params: 0 (0.00 B)
+
+ + + We also save the history since we later want to compare our model trained with and not trained with Gradient Centralization @@ -294,28 +326,29 @@ history_no_gc = model.fit(
``` Epoch 1/10 -9/9 [==============================] - 5s 571ms/step - loss: 0.7427 - accuracy: 0.5073 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 24s 778ms/step - accuracy: 0.4772 - loss: 0.7405 Epoch 2/10 -9/9 [==============================] - 6s 667ms/step - loss: 0.6757 - accuracy: 0.5433 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 597ms/step - accuracy: 0.5434 - loss: 0.6861 Epoch 3/10 -9/9 [==============================] - 6s 660ms/step - loss: 0.6616 - accuracy: 0.6144 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 700ms/step - accuracy: 0.5402 - loss: 0.6911 Epoch 4/10 -9/9 [==============================] - 6s 642ms/step - loss: 0.6598 - accuracy: 0.6203 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 586ms/step - accuracy: 0.5884 - loss: 0.6788 Epoch 5/10 -9/9 [==============================] - 6s 666ms/step - loss: 0.6782 - accuracy: 0.6329 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 588ms/step - accuracy: 0.6570 - loss: 0.6564 Epoch 6/10 -9/9 [==============================] - 6s 655ms/step - loss: 0.6550 - accuracy: 0.6524 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 591ms/step - accuracy: 0.6671 - loss: 0.6395 Epoch 7/10 -9/9 [==============================] - 6s 645ms/step - loss: 0.6157 - accuracy: 0.7186 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 594ms/step - accuracy: 0.7010 - loss: 0.6161 Epoch 8/10 -9/9 [==============================] - 6s 654ms/step - loss: 0.6095 - accuracy: 0.6913 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 593ms/step - accuracy: 0.6946 - loss: 0.6129 Epoch 9/10 -9/9 [==============================] - 6s 677ms/step - loss: 0.5880 - accuracy: 0.7147 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 699ms/step - accuracy: 0.6972 - loss: 0.5987 Epoch 10/10 -9/9 [==============================] - 6s 663ms/step - loss: 0.5814 - accuracy: 0.6933 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 11s 623ms/step - accuracy: 0.6839 - loss: 0.6197 ```
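Before training the same model with Gradient Centralization below, here is a minimal, self-contained sketch of what the centralization step itself does to a single gradient tensor. This snippet is an illustrative aside rather than part of the original example: the gradient values are made up, and it relies only on `keras.ops` (note that the argument name in `keras.ops.mean` is `keepdims`).

```python
from keras import ops

# A made-up "gradient" for a Dense kernel of shape (4, 3):
# 4 input units and 3 output units, i.e. 3 column vectors.
grad = ops.reshape(ops.arange(12, dtype="float32"), (4, 3))

# Gradient Centralization: subtract the mean taken over every axis except the
# last one, so each column of the weight-gradient matrix becomes zero-mean.
axis = list(range(len(grad.shape) - 1))
centralized = grad - ops.mean(grad, axis=axis, keepdims=True)

print(ops.mean(centralized, axis=axis))  # approximately [0. 0. 0.]
```

The same rule applies to a Conv2D kernel of shape `(kh, kw, in_channels, out_channels)`: the mean is removed over all axes except the last.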
+--- ## Train the model with GC We will now train the same model, this time using Gradient Centralization, @@ -331,71 +364,96 @@ model.summary() history_gc = model.fit(train_ds, epochs=10, verbose=1, callbacks=[time_callback_gc]) ``` + +
Model: "sequential"
+
+ + + + +
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
+┃ Layer (type)                     Output Shape                  Param # ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
+│ conv2d (Conv2D)                 │ (None, 298, 298, 16)      │        448 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ max_pooling2d (MaxPooling2D)    │ (None, 149, 149, 16)      │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ conv2d_1 (Conv2D)               │ (None, 147, 147, 32)      │      4,640 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ dropout (Dropout)               │ (None, 147, 147, 32)      │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ max_pooling2d_1 (MaxPooling2D)  │ (None, 73, 73, 32)        │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ conv2d_2 (Conv2D)               │ (None, 71, 71, 64)        │     18,496 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ dropout_1 (Dropout)             │ (None, 71, 71, 64)        │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ max_pooling2d_2 (MaxPooling2D)  │ (None, 35, 35, 64)        │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ conv2d_3 (Conv2D)               │ (None, 33, 33, 64)        │     36,928 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ max_pooling2d_3 (MaxPooling2D)  │ (None, 16, 16, 64)        │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ conv2d_4 (Conv2D)               │ (None, 14, 14, 64)        │     36,928 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ max_pooling2d_4 (MaxPooling2D)  │ (None, 7, 7, 64)          │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ flatten (Flatten)               │ (None, 3136)              │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ dropout_2 (Dropout)             │ (None, 3136)              │          0 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ dense (Dense)                   │ (None, 512)               │  1,606,144 │
+├─────────────────────────────────┼───────────────────────────┼────────────┤
+│ dense_1 (Dense)                 │ (None, 1)                 │        513 │
+└─────────────────────────────────┴───────────────────────────┴────────────┘
+
+ + + + +
 Total params: 1,704,097 (6.50 MB)
+
+ + + + +
 Trainable params: 1,704,097 (6.50 MB)
+
+ + + + +
 Non-trainable params: 0 (0.00 B)
+
+ + +
``` -Model: "sequential_1" -_________________________________________________________________ -Layer (type) Output Shape Param # -================================================================= -conv2d (Conv2D) (None, 298, 298, 16) 448 -_________________________________________________________________ -max_pooling2d (MaxPooling2D) (None, 149, 149, 16) 0 -_________________________________________________________________ -conv2d_1 (Conv2D) (None, 147, 147, 32) 4640 -_________________________________________________________________ -dropout (Dropout) (None, 147, 147, 32) 0 -_________________________________________________________________ -max_pooling2d_1 (MaxPooling2 (None, 73, 73, 32) 0 -_________________________________________________________________ -conv2d_2 (Conv2D) (None, 71, 71, 64) 18496 -_________________________________________________________________ -dropout_1 (Dropout) (None, 71, 71, 64) 0 -_________________________________________________________________ -max_pooling2d_2 (MaxPooling2 (None, 35, 35, 64) 0 -_________________________________________________________________ -conv2d_3 (Conv2D) (None, 33, 33, 64) 36928 -_________________________________________________________________ -max_pooling2d_3 (MaxPooling2 (None, 16, 16, 64) 0 -_________________________________________________________________ -conv2d_4 (Conv2D) (None, 14, 14, 64) 36928 -_________________________________________________________________ -max_pooling2d_4 (MaxPooling2 (None, 7, 7, 64) 0 -_________________________________________________________________ -flatten (Flatten) (None, 3136) 0 -_________________________________________________________________ -dropout_2 (Dropout) (None, 3136) 0 -_________________________________________________________________ -dense (Dense) (None, 512) 1606144 -_________________________________________________________________ -dense_1 (Dense) (None, 1) 513 -================================================================= -Total params: 1,704,097 -Trainable params: 1,704,097 -Non-trainable params: 0 -_________________________________________________________________ Epoch 1/10 -9/9 [==============================] - 6s 673ms/step - loss: 0.6022 - accuracy: 0.7147 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 12s 649ms/step - accuracy: 0.7118 - loss: 0.5594 Epoch 2/10 -9/9 [==============================] - 6s 662ms/step - loss: 0.5385 - accuracy: 0.7371 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 592ms/step - accuracy: 0.7249 - loss: 0.5817 Epoch 3/10 -9/9 [==============================] - 6s 673ms/step - loss: 0.4832 - accuracy: 0.7945 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 587ms/step - accuracy: 0.8060 - loss: 0.4448 Epoch 4/10 -9/9 [==============================] - 6s 645ms/step - loss: 0.4692 - accuracy: 0.7799 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 693ms/step - accuracy: 0.8472 - loss: 0.4051 Epoch 5/10 -9/9 [==============================] - 6s 720ms/step - loss: 0.4792 - accuracy: 0.7799 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 594ms/step - accuracy: 0.8386 - loss: 0.3978 Epoch 6/10 -9/9 [==============================] - 6s 658ms/step - loss: 0.4623 - accuracy: 0.7838 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 593ms/step - accuracy: 0.8442 - loss: 0.3976 Epoch 7/10 -9/9 [==============================] - 6s 651ms/step - loss: 0.4413 - accuracy: 0.8072 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 585ms/step - accuracy: 0.7409 - loss: 0.6626 Epoch 8/10 -9/9 [==============================] - 6s 682ms/step - loss: 0.4542 - accuracy: 0.8014 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 587ms/step - accuracy: 0.8191 - loss: 0.4357 Epoch 9/10 -9/9 [==============================] 
- 6s 649ms/step - loss: 0.4235 - accuracy: 0.8053 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 587ms/step - accuracy: 0.8248 - loss: 0.3974 Epoch 10/10 -9/9 [==============================] - 6s 686ms/step - loss: 0.4445 - accuracy: 0.7936 + 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 646ms/step - accuracy: 0.8022 - loss: 0.4589 ```
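Before comparing the final numbers, it can also be instructive to look at the two training curves side by side. The following snippet is an optional aside, not part of the original example: it assumes `matplotlib` is installed and reuses the `history_no_gc` and `history_gc` objects recorded above.

```python
import matplotlib.pyplot as plt

# Plot the training loss of both runs for a visual comparison.
plt.plot(history_no_gc.history["loss"], label="RMSprop (no GC)")
plt.plot(history_gc.history["loss"], label="RMSprop + Gradient Centralization")
plt.xlabel("Epoch")
plt.ylabel("Training loss")
plt.legend()
plt.show()
```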
+--- ## Comparing performance @@ -414,13 +472,13 @@ print(f"Training Time: {sum(time_callback_gc.times)}")
``` Not using Gradient Centralization -Loss: 0.5814347863197327 -Accuracy: 0.6932814121246338 -Training Time: 136.35903406143188 +Loss: 0.5345584154129028 +Accuracy: 0.7604166865348816 +Training Time: 112.48799777030945 Using Gradient Centralization -Loss: 0.4444807469844818 -Accuracy: 0.7935734987258911 -Training Time: 131.61780261993408 +Loss: 0.4014038145542145 +Accuracy: 0.8153935074806213 +Training Time: 98.31573963165283 ```
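As an optional follow-up (again not part of the original example), the two runs can be summarized a little further, assuming the `History` objects and `TimeHistory` callbacks from the runs above are still in scope:

```python
import numpy as np

# Average wall-clock time per epoch, as recorded by the TimeHistory callbacks.
print("Mean epoch time without GC:", np.mean(time_callback_no_gc.times))
print("Mean epoch time with GC:", np.mean(time_callback_gc.times))

# Difference in final-epoch training accuracy between the two runs.
acc_gain = history_gc.history["accuracy"][-1] - history_no_gc.history["accuracy"][-1]
print("Final accuracy gain with GC:", acc_gain)
```

In the run shown above, Gradient Centralization reached a higher final accuracy in less total training time; exact numbers will vary across runs and hardware.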