From a124a3f0391d957b31dd99eef0a5591a866b9c8d Mon Sep 17 00:00:00 2001 From: Mansi Mehta Date: Wed, 18 Dec 2024 22:56:27 +0530 Subject: [PATCH] Update timeseries_classification_transformer tutorial with good accuracy --- ...imeseries_classification_transformer.ipynb | 1053 ++++++++++++----- .../timeseries_classification_transformer.py | 35 +- 2 files changed, 810 insertions(+), 278 deletions(-) diff --git a/examples/timeseries/ipynb/timeseries_classification_transformer.ipynb b/examples/timeseries/ipynb/timeseries_classification_transformer.ipynb index 78c164603a..6c665ab82b 100644 --- a/examples/timeseries/ipynb/timeseries_classification_transformer.ipynb +++ b/examples/timeseries/ipynb/timeseries_classification_transformer.ipynb @@ -1,269 +1,790 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "# Timeseries classification with a Transformer model\n", - "\n", - "**Author:** [Theodoros Ntakouris](https://github.com/ntakouris)
\n", - "**Date created:** 2021/06/25
\n", - "**Last modified:** 2021/08/05
\n", - "**Description:** This notebook demonstrates how to do timeseries classification using a Transformer model." - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "1vqRPhIg4w3J" + }, + "source": [ + "# Timeseries classification with a Transformer model\n", + "\n", + "**Author:** [Theodoros Ntakouris](https://github.com/ntakouris)
\n", + "**Date created:** 2021/06/25
\n", + "**Last modified:** 2024/12/18
\n", + "**Description:** This notebook demonstrates how to do timeseries classification using a Transformer model." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6EeV4iem4w3K" + }, + "source": [ + "## Introduction\n", + "\n", + "This is the Transformer architecture from\n", + "[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n", + "applied to timeseries instead of natural language.\n", + "\n", + "This example requires TensorFlow 2.4 or higher.\n", + "\n", + "## Load the dataset\n", + "\n", + "We are going to use the same dataset and preprocessing as the\n", + "[TimeSeries Classification from Scratch](https://keras.io/examples/timeseries/timeseries_classification_from_scratch)\n", + "example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "w67sxoEc4w3K" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import keras\n", + "from keras import layers\n", + "\n", + "\n", + "def readucr(filename):\n", + " data = np.loadtxt(filename, delimiter=\"\\t\")\n", + " y = data[:, 0]\n", + " x = data[:, 1:]\n", + " return x, y.astype(int)\n", + "\n", + "\n", + "root_url = \"https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/\"\n", + "\n", + "x_train, y_train = readucr(root_url + \"FordA_TRAIN.tsv\")\n", + "x_test, y_test = readucr(root_url + \"FordA_TEST.tsv\")\n", + "\n", + "x_train = x_train.reshape((x_train.shape[0], x_train.shape[1], 1))\n", + "x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], 1))\n", + "\n", + "n_classes = len(np.unique(y_train))\n", + "\n", + "idx = np.random.permutation(len(x_train))\n", + "x_train = x_train[idx]\n", + "y_train = y_train[idx]\n", + "\n", + "y_train[y_train == -1] = 0\n", + "y_test[y_test == -1] = 0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_TER3qzj4w3L" + }, + "source": [ + "## Build the model\n", + "\n", + "Our model processes a tensor of shape `(batch size, sequence length, features)`,\n", + "where `sequence length` is the number of time steps and `features` is each input\n", + "timeseries.\n", + "\n", + "You can replace your classification RNN layers with this one: the\n", + "inputs are fully compatible!\n", + "\n", + "We include residual connections, layer normalization, and dropout.\n", + "The resulting layer can be stacked multiple times.\n", + "\n", + "The projection layers are implemented through `keras.layers.Conv1D`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qQPKnxqX4w3L" + }, + "outputs": [], + "source": [ + "\n", + "def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):\n", + " # Attention and Normalization\n", + " x = layers.MultiHeadAttention(\n", + " key_dim=head_size, num_heads=num_heads, dropout=dropout\n", + " )(inputs, inputs)\n", + " x = layers.Dropout(dropout)(x)\n", + " x = layers.LayerNormalization(epsilon=1e-6)(x)\n", + " res = x + inputs\n", + "\n", + " # Feed Forward Part\n", + " x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation=\"relu\")(res)\n", + " x = layers.Dropout(dropout)(x)\n", + " x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)\n", + " x = layers.LayerNormalization(epsilon=1e-6)(x)\n", + " return x + res\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uRYsZ56C4w3L" + }, + "source": [ + "The main part of our model is now complete. We can stack multiple of those\n", + "`transformer_encoder` blocks and we can also proceed to add the final\n", + "Multi-Layer Perceptron classification head. Apart from a stack of `Dense`\n", + "layers, we need to reduce the output tensor of the `TransformerEncoder` part of\n", + "our model down to a vector of features for each data point in the current\n", + "batch. A common way to achieve this is to use a pooling layer. For\n", + "this example, a `GlobalAveragePooling1D` layer is sufficient." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2kFpWK064w3L" + }, + "outputs": [], + "source": [ + "\n", + "def build_model(\n", + " input_shape,\n", + " head_size,\n", + " num_heads,\n", + " ff_dim,\n", + " num_transformer_blocks,\n", + " mlp_units,\n", + " dropout=0,\n", + " mlp_dropout=0,\n", + "):\n", + " inputs = keras.Input(shape=input_shape)\n", + " x = inputs\n", + " for _ in range(num_transformer_blocks):\n", + " x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)\n", + " print(f'Transformer Encoder: {x}')\n", + " x = layers.GlobalAveragePooling1D(data_format=\"channels_first\")(x)\n", + " print(f'Global Average Pooling: {x}')\n", + " for dim in mlp_units:\n", + " x = layers.Dense(dim, activation=\"relu\")(x)\n", + " x = layers.Dropout(mlp_dropout)(x)\n", + " outputs = layers.Dense(n_classes, activation=\"softmax\")(x)\n", + " return keras.Model(inputs, outputs)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "V1vu7wBI4w3M" + }, + "source": [ + "## Train and evaluate" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "BZ5JILKHQA4x", + "outputId": "e01ab94d-68c9-4548-df26-5a3c07988b4a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transformer Encoder: \n", + "Global Average Pooling: \n" + ] + }, + { + "data": { + "text/html": [ + "
Model: \"functional\"\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"functional\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+              "┃ Layer (type)               Output Shape                   Param #  Connected to           ┃\n",
+              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+              "│ input_layer (InputLayer)  │ (None, 500, 1)         │              0 │ -                      │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ multi_head_attention      │ (None, 500, 1)         │          7,169 │ input_layer[0][0],     │\n",
+              "│ (MultiHeadAttention)      │                        │                │ input_layer[0][0]      │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ dropout_1 (Dropout)       │ (None, 500, 1)         │              0 │ multi_head_attention[ │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ layer_normalization       │ (None, 500, 1)         │              2 │ dropout_1[0][0]        │\n",
+              "│ (LayerNormalization)      │                        │                │                        │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ add (Add)                 │ (None, 500, 1)         │              0 │ layer_normalization[0… │\n",
+              "│                           │                        │                │ input_layer[0][0]      │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ conv1d (Conv1D)           │ (None, 500, 4)         │              8 │ add[0][0]              │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ dropout_2 (Dropout)       │ (None, 500, 4)         │              0 │ conv1d[0][0]           │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ conv1d_1 (Conv1D)         │ (None, 500, 1)         │              5 │ dropout_2[0][0]        │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ layer_normalization_1     │ (None, 500, 1)         │              2 │ conv1d_1[0][0]         │\n",
+              "│ (LayerNormalization)      │                        │                │                        │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ add_1 (Add)               │ (None, 500, 1)         │              0 │ layer_normalization_1… │\n",
+              "│                           │                        │                │ add[0][0]              │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ multi_head_attention_1    │ (None, 500, 1)         │          7,169 │ add_1[0][0],           │\n",
+              "│ (MultiHeadAttention)      │                        │                │ add_1[0][0]            │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ dropout_4 (Dropout)       │ (None, 500, 1)         │              0 │ multi_head_attention_… │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ layer_normalization_2     │ (None, 500, 1)         │              2 │ dropout_4[0][0]        │\n",
+              "│ (LayerNormalization)      │                        │                │                        │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ add_2 (Add)               │ (None, 500, 1)         │              0 │ layer_normalization_2… │\n",
+              "│                           │                        │                │ add_1[0][0]            │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ conv1d_2 (Conv1D)         │ (None, 500, 4)         │              8 │ add_2[0][0]            │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ dropout_5 (Dropout)       │ (None, 500, 4)         │              0 │ conv1d_2[0][0]         │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ conv1d_3 (Conv1D)         │ (None, 500, 1)         │              5 │ dropout_5[0][0]        │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ layer_normalization_3     │ (None, 500, 1)         │              2 │ conv1d_3[0][0]         │\n",
+              "│ (LayerNormalization)      │                        │                │                        │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ add_3 (Add)               │ (None, 500, 1)         │              0 │ layer_normalization_3… │\n",
+              "│                           │                        │                │ add_2[0][0]            │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ multi_head_attention_2    │ (None, 500, 1)         │          7,169 │ add_3[0][0],           │\n",
+              "│ (MultiHeadAttention)      │                        │                │ add_3[0][0]            │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ dropout_7 (Dropout)       │ (None, 500, 1)         │              0 │ multi_head_attention_… │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ layer_normalization_4     │ (None, 500, 1)         │              2 │ dropout_7[0][0]        │\n",
+              "│ (LayerNormalization)      │                        │                │                        │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ add_4 (Add)               │ (None, 500, 1)         │              0 │ layer_normalization_4… │\n",
+              "│                           │                        │                │ add_3[0][0]            │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ conv1d_4 (Conv1D)         │ (None, 500, 4)         │              8 │ add_4[0][0]            │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ dropout_8 (Dropout)       │ (None, 500, 4)         │              0 │ conv1d_4[0][0]         │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ conv1d_5 (Conv1D)         │ (None, 500, 1)         │              5 │ dropout_8[0][0]        │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ layer_normalization_5     │ (None, 500, 1)         │              2 │ conv1d_5[0][0]         │\n",
+              "│ (LayerNormalization)      │                        │                │                        │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ add_5 (Add)               │ (None, 500, 1)         │              0 │ layer_normalization_5… │\n",
+              "│                           │                        │                │ add_4[0][0]            │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ multi_head_attention_3    │ (None, 500, 1)         │          7,169 │ add_5[0][0],           │\n",
+              "│ (MultiHeadAttention)      │                        │                │ add_5[0][0]            │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ dropout_10 (Dropout)      │ (None, 500, 1)         │              0 │ multi_head_attention_… │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ layer_normalization_6     │ (None, 500, 1)         │              2 │ dropout_10[0][0]       │\n",
+              "│ (LayerNormalization)      │                        │                │                        │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ add_6 (Add)               │ (None, 500, 1)         │              0 │ layer_normalization_6… │\n",
+              "│                           │                        │                │ add_5[0][0]            │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ conv1d_6 (Conv1D)         │ (None, 500, 4)         │              8 │ add_6[0][0]            │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ dropout_11 (Dropout)      │ (None, 500, 4)         │              0 │ conv1d_6[0][0]         │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ conv1d_7 (Conv1D)         │ (None, 500, 1)         │              5 │ dropout_11[0][0]       │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ layer_normalization_7     │ (None, 500, 1)         │              2 │ conv1d_7[0][0]         │\n",
+              "│ (LayerNormalization)      │                        │                │                        │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ add_7 (Add)               │ (None, 500, 1)         │              0 │ layer_normalization_7… │\n",
+              "│                           │                        │                │ add_6[0][0]            │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ global_average_pooling1d  │ (None, 500)            │              0 │ add_7[0][0]            │\n",
+              "│ (GlobalAveragePooling1D)  │                        │                │                        │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ dense (Dense)             │ (None, 128)            │         64,128 │ global_average_poolin… │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ dropout_12 (Dropout)      │ (None, 128)            │              0 │ dense[0][0]            │\n",
+              "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+              "│ dense_1 (Dense)           │ (None, 2)              │            258 │ dropout_12[0][0]       │\n",
+              "└───────────────────────────┴────────────────────────┴────────────────┴────────────────────────┘\n",
+              "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ input_layer (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ multi_head_attention │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m7,169\u001b[0m │ input_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mMultiHeadAttention\u001b[0m) │ │ │ input_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ dropout_1 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ multi_head_attention[\u001b[38;5;34m…\u001b[0m │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ layer_normalization │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ dropout_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ add (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization[\u001b[38;5;34m0\u001b[0m… │\n", + "│ │ │ │ input_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ conv1d (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m8\u001b[0m │ add[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ dropout_2 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv1d[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ conv1d_1 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m5\u001b[0m │ dropout_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ layer_normalization_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ conv1d_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ add_1 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization_1… │\n", + "│ │ │ │ add[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ multi_head_attention_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m7,169\u001b[0m │ add_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mMultiHeadAttention\u001b[0m) │ │ │ add_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ dropout_4 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ multi_head_attention_… │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ layer_normalization_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ dropout_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ add_2 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization_2… │\n", + "│ │ │ │ add_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ conv1d_2 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m8\u001b[0m │ add_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ dropout_5 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv1d_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ conv1d_3 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m5\u001b[0m │ dropout_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ layer_normalization_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ conv1d_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ add_3 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization_3… │\n", + "│ │ │ │ add_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ multi_head_attention_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m7,169\u001b[0m │ add_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mMultiHeadAttention\u001b[0m) │ │ │ add_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ dropout_7 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ multi_head_attention_… │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ layer_normalization_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ dropout_7[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ add_4 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization_4… │\n", + "│ │ │ │ add_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ conv1d_4 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m8\u001b[0m │ add_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ dropout_8 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv1d_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ conv1d_5 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m5\u001b[0m │ dropout_8[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ layer_normalization_5 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ conv1d_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ add_5 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization_5… │\n", + "│ │ │ │ add_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ multi_head_attention_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m7,169\u001b[0m │ add_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", + "│ (\u001b[38;5;33mMultiHeadAttention\u001b[0m) │ │ │ add_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ dropout_10 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ multi_head_attention_… │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ layer_normalization_6 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ dropout_10[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ add_6 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization_6… │\n", + "│ │ │ │ add_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ conv1d_6 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m8\u001b[0m │ add_6[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ dropout_11 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv1d_6[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ conv1d_7 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m5\u001b[0m │ dropout_11[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ layer_normalization_7 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ conv1d_7[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ add_7 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization_7… │\n", + "│ │ │ │ add_6[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ global_average_pooling1d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ add_7[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "│ (\u001b[38;5;33mGlobalAveragePooling1D\u001b[0m) │ │ │ │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m64,128\u001b[0m │ global_average_poolin… │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ dropout_12 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dense[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", + "│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m258\u001b[0m │ dropout_12[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", + "└───────────────────────────┴────────────────────────┴────────────────┴────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 93,130 (363.79 KB)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m93,130\u001b[0m (363.79 KB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Trainable params: 93,130 (363.79 KB)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m93,130\u001b[0m (363.79 KB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "input_shape = x_train.shape[1:]\n", + "\n", + "model = build_model(\n", + " input_shape,\n", + " head_size=256,\n", + " num_heads=4,\n", + " ff_dim=4,\n", + " num_transformer_blocks=4,\n", + " mlp_units=[128],\n", + " mlp_dropout=0.4,\n", + " dropout=0.25,\n", + ")\n", + "\n", + "model.compile(\n", + " loss=\"sparse_categorical_crossentropy\",\n", + " optimizer=keras.optimizers.Adam(learning_rate=1e-4),\n", + " metrics=[\"sparse_categorical_accuracy\"],\n", + ")\n", + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "eDOvm4BR4w3M", + "outputId": "fa53a57d-d9af-41b6-f0bd-5dc4d7dfb621" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m54s\u001b[0m 401ms/step - loss: 1.1535 - sparse_categorical_accuracy: 0.4893 - val_loss: 0.7982 - val_sparse_categorical_accuracy: 0.5298\n", + "Epoch 2/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 352ms/step - loss: 0.8630 - sparse_categorical_accuracy: 0.5636 - val_loss: 0.7076 - val_sparse_categorical_accuracy: 0.6019\n", + "Epoch 3/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 341ms/step - loss: 0.8010 - sparse_categorical_accuracy: 0.5744 - val_loss: 0.6721 - val_sparse_categorical_accuracy: 0.6283\n", + "Epoch 4/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 352ms/step - loss: 0.7667 - sparse_categorical_accuracy: 0.6024 - val_loss: 0.6411 - val_sparse_categorical_accuracy: 0.6546\n", + "Epoch 5/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 341ms/step - loss: 0.6981 - sparse_categorical_accuracy: 0.6420 - val_loss: 0.6224 - val_sparse_categorical_accuracy: 0.6574\n", + "Epoch 6/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 334ms/step - loss: 0.6731 - sparse_categorical_accuracy: 0.6317 - val_loss: 0.6064 - val_sparse_categorical_accuracy: 0.6713\n", + "Epoch 7/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 338ms/step - loss: 0.6521 - sparse_categorical_accuracy: 0.6693 - val_loss: 0.5913 - val_sparse_categorical_accuracy: 0.6727\n", + "Epoch 8/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 373ms/step - loss: 0.6322 - sparse_categorical_accuracy: 0.6663 - val_loss: 0.5807 - val_sparse_categorical_accuracy: 0.6893\n", + "Epoch 9/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 371ms/step - loss: 0.5975 - sparse_categorical_accuracy: 0.6800 - val_loss: 0.5698 - val_sparse_categorical_accuracy: 0.6976\n", + "Epoch 10/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 339ms/step - loss: 0.6067 - sparse_categorical_accuracy: 0.6786 - val_loss: 0.5639 - val_sparse_categorical_accuracy: 0.7046\n", + "Epoch 11/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 343ms/step - loss: 0.5685 - sparse_categorical_accuracy: 0.7052 - val_loss: 0.5599 - val_sparse_categorical_accuracy: 0.7101\n", + "Epoch 12/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 337ms/step - loss: 0.5386 - sparse_categorical_accuracy: 0.7203 - val_loss: 0.5500 - val_sparse_categorical_accuracy: 0.7198\n", + "Epoch 13/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 340ms/step - loss: 0.5474 - sparse_categorical_accuracy: 0.7134 - val_loss: 0.5431 - val_sparse_categorical_accuracy: 0.7198\n", + "Epoch 14/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 350ms/step - loss: 0.5267 - sparse_categorical_accuracy: 0.7421 - val_loss: 0.5363 - val_sparse_categorical_accuracy: 0.7295\n", + "Epoch 15/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 372ms/step - loss: 0.5117 - sparse_categorical_accuracy: 0.7442 - val_loss: 0.5323 - val_sparse_categorical_accuracy: 0.7337\n", + "Epoch 16/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 366ms/step - loss: 0.5199 - sparse_categorical_accuracy: 0.7359 - val_loss: 0.5258 - val_sparse_categorical_accuracy: 0.7323\n", + "Epoch 17/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 362ms/step - loss: 0.5032 - sparse_categorical_accuracy: 0.7527 - val_loss: 0.5224 - val_sparse_categorical_accuracy: 0.7351\n", + "Epoch 18/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 375ms/step - loss: 0.4854 - sparse_categorical_accuracy: 0.7615 - val_loss: 0.5164 - val_sparse_categorical_accuracy: 0.7406\n", + "Epoch 19/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 346ms/step - loss: 0.4763 - sparse_categorical_accuracy: 0.7718 - val_loss: 0.5121 - val_sparse_categorical_accuracy: 0.7434\n", + "Epoch 20/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 370ms/step - loss: 0.4577 - sparse_categorical_accuracy: 0.7865 - val_loss: 0.5084 - val_sparse_categorical_accuracy: 0.7476\n", + "Epoch 21/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 333ms/step - loss: 0.4633 - sparse_categorical_accuracy: 0.7759 - val_loss: 0.5052 - val_sparse_categorical_accuracy: 0.7365\n", + "Epoch 22/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 371ms/step - loss: 0.4557 - sparse_categorical_accuracy: 0.7855 - val_loss: 0.4984 - val_sparse_categorical_accuracy: 0.7462\n", + "Epoch 23/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 368ms/step - loss: 0.4451 - sparse_categorical_accuracy: 0.7897 - val_loss: 0.4939 - val_sparse_categorical_accuracy: 0.7503\n", + "Epoch 24/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 370ms/step - loss: 0.4348 - sparse_categorical_accuracy: 0.8004 - val_loss: 0.4916 - val_sparse_categorical_accuracy: 0.7628\n", + "Epoch 25/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 342ms/step - loss: 0.4248 - sparse_categorical_accuracy: 0.8200 - val_loss: 0.4855 - val_sparse_categorical_accuracy: 0.7587\n", + "Epoch 26/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 368ms/step - loss: 0.4233 - sparse_categorical_accuracy: 0.8066 - val_loss: 0.4833 - val_sparse_categorical_accuracy: 0.7628\n", + "Epoch 27/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 342ms/step - loss: 0.4194 - sparse_categorical_accuracy: 0.8025 - val_loss: 0.4799 - val_sparse_categorical_accuracy: 0.7587\n", + "Epoch 28/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 342ms/step - loss: 0.3837 - sparse_categorical_accuracy: 0.8354 - val_loss: 0.4757 - val_sparse_categorical_accuracy: 0.7628\n", + "Epoch 29/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 368ms/step - loss: 0.3742 - sparse_categorical_accuracy: 0.8414 - val_loss: 0.4733 - val_sparse_categorical_accuracy: 0.7656\n", + "Epoch 30/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 342ms/step - loss: 0.3777 - sparse_categorical_accuracy: 0.8434 - val_loss: 0.4722 - val_sparse_categorical_accuracy: 0.7698\n", + "Epoch 31/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 342ms/step - loss: 0.3796 - sparse_categorical_accuracy: 0.8412 - val_loss: 0.4660 - val_sparse_categorical_accuracy: 0.7712\n", + "Epoch 32/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 369ms/step - loss: 0.3903 - sparse_categorical_accuracy: 0.8277 - val_loss: 0.4629 - val_sparse_categorical_accuracy: 0.7725\n", + "Epoch 33/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 345ms/step - loss: 0.3521 - sparse_categorical_accuracy: 0.8533 - val_loss: 0.4605 - val_sparse_categorical_accuracy: 0.7684\n", + "Epoch 34/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 336ms/step - loss: 0.3624 - sparse_categorical_accuracy: 0.8480 - val_loss: 0.4596 - val_sparse_categorical_accuracy: 0.7725\n", + "Epoch 35/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 371ms/step - loss: 0.3579 - sparse_categorical_accuracy: 0.8444 - val_loss: 0.4543 - val_sparse_categorical_accuracy: 0.7809\n", + "Epoch 36/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 371ms/step - loss: 0.3583 - sparse_categorical_accuracy: 0.8565 - val_loss: 0.4502 - val_sparse_categorical_accuracy: 0.7864\n", + "Epoch 37/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 363ms/step - loss: 0.3323 - sparse_categorical_accuracy: 0.8678 - val_loss: 0.4474 - val_sparse_categorical_accuracy: 0.7878\n", + "Epoch 38/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 368ms/step - loss: 0.3440 - sparse_categorical_accuracy: 0.8574 - val_loss: 0.4444 - val_sparse_categorical_accuracy: 0.7892\n", + "Epoch 39/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 375ms/step - loss: 0.3346 - sparse_categorical_accuracy: 0.8589 - val_loss: 0.4429 - val_sparse_categorical_accuracy: 0.7975\n", + "Epoch 40/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 342ms/step - loss: 0.3274 - sparse_categorical_accuracy: 0.8565 - val_loss: 0.4393 - val_sparse_categorical_accuracy: 0.7975\n", + "Epoch 41/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 366ms/step - loss: 0.3294 - sparse_categorical_accuracy: 0.8717 - val_loss: 0.4349 - val_sparse_categorical_accuracy: 0.8031\n", + "Epoch 42/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 368ms/step - loss: 0.3222 - sparse_categorical_accuracy: 0.8725 - val_loss: 0.4323 - val_sparse_categorical_accuracy: 0.8031\n", + "Epoch 43/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 373ms/step - loss: 0.3095 - sparse_categorical_accuracy: 0.8963 - val_loss: 0.4315 - val_sparse_categorical_accuracy: 0.8003\n", + "Epoch 44/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 340ms/step - loss: 0.2991 - sparse_categorical_accuracy: 0.8872 - val_loss: 0.4290 - val_sparse_categorical_accuracy: 0.8100\n", + "Epoch 45/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 339ms/step - loss: 0.3148 - sparse_categorical_accuracy: 0.8842 - val_loss: 0.4271 - val_sparse_categorical_accuracy: 0.8072\n", + "Epoch 46/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 342ms/step - loss: 0.3020 - sparse_categorical_accuracy: 0.8880 - val_loss: 0.4237 - val_sparse_categorical_accuracy: 0.8100\n", + "Epoch 47/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 372ms/step - loss: 0.2914 - sparse_categorical_accuracy: 0.8968 - val_loss: 0.4184 - val_sparse_categorical_accuracy: 0.8114\n", + "Epoch 48/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 369ms/step - loss: 0.2889 - sparse_categorical_accuracy: 0.8976 - val_loss: 0.4179 - val_sparse_categorical_accuracy: 0.8141\n", + "Epoch 49/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 369ms/step - loss: 0.2961 - sparse_categorical_accuracy: 0.8957 - val_loss: 0.4166 - val_sparse_categorical_accuracy: 0.8197\n", + "Epoch 50/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 369ms/step - loss: 0.2956 - sparse_categorical_accuracy: 0.8883 - val_loss: 0.4146 - val_sparse_categorical_accuracy: 0.8197\n", + "Epoch 51/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 347ms/step - loss: 0.2708 - sparse_categorical_accuracy: 0.9043 - val_loss: 0.4144 - val_sparse_categorical_accuracy: 0.8183\n", + "Epoch 52/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 345ms/step - loss: 0.2883 - sparse_categorical_accuracy: 0.8916 - val_loss: 0.4095 - val_sparse_categorical_accuracy: 0.8183\n", + "Epoch 53/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 339ms/step - loss: 0.2698 - sparse_categorical_accuracy: 0.8986 - val_loss: 0.4055 - val_sparse_categorical_accuracy: 0.8225\n", + "Epoch 54/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 333ms/step - loss: 0.2878 - sparse_categorical_accuracy: 0.8807 - val_loss: 0.4033 - val_sparse_categorical_accuracy: 0.8239\n", + "Epoch 55/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 372ms/step - loss: 0.2751 - sparse_categorical_accuracy: 0.9022 - val_loss: 0.4020 - val_sparse_categorical_accuracy: 0.8294\n", + "Epoch 56/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 341ms/step - loss: 0.2604 - sparse_categorical_accuracy: 0.9021 - val_loss: 0.4014 - val_sparse_categorical_accuracy: 0.8266\n", + "Epoch 57/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 368ms/step - loss: 0.2639 - sparse_categorical_accuracy: 0.9062 - val_loss: 0.3990 - val_sparse_categorical_accuracy: 0.8239\n", + "Epoch 58/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 373ms/step - loss: 0.2597 - sparse_categorical_accuracy: 0.8953 - val_loss: 0.4008 - val_sparse_categorical_accuracy: 0.8322\n", + "Epoch 59/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 344ms/step - loss: 0.2633 - sparse_categorical_accuracy: 0.9069 - val_loss: 0.3969 - val_sparse_categorical_accuracy: 0.8308\n", + "Epoch 60/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 364ms/step - loss: 0.2682 - sparse_categorical_accuracy: 0.8988 - val_loss: 0.3959 - val_sparse_categorical_accuracy: 0.8322\n", + "Epoch 61/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 340ms/step - loss: 0.2789 - sparse_categorical_accuracy: 0.8952 - val_loss: 0.3948 - val_sparse_categorical_accuracy: 0.8280\n", + "Epoch 62/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 352ms/step - loss: 0.2493 - sparse_categorical_accuracy: 0.9158 - val_loss: 0.3915 - val_sparse_categorical_accuracy: 0.8308\n", + "Epoch 63/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 367ms/step - loss: 0.2594 - sparse_categorical_accuracy: 0.9033 - val_loss: 0.3909 - val_sparse_categorical_accuracy: 0.8239\n", + "Epoch 64/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 368ms/step - loss: 0.2526 - sparse_categorical_accuracy: 0.9003 - val_loss: 0.3869 - val_sparse_categorical_accuracy: 0.8322\n", + "Epoch 65/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 367ms/step - loss: 0.2413 - sparse_categorical_accuracy: 0.9174 - val_loss: 0.3876 - val_sparse_categorical_accuracy: 0.8266\n", + "Epoch 66/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 367ms/step - loss: 0.2448 - sparse_categorical_accuracy: 0.9144 - val_loss: 0.3879 - val_sparse_categorical_accuracy: 0.8252\n", + "Epoch 67/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 366ms/step - loss: 0.2592 - sparse_categorical_accuracy: 0.9118 - val_loss: 0.3851 - val_sparse_categorical_accuracy: 0.8336\n", + "Epoch 68/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 372ms/step - loss: 0.2469 - sparse_categorical_accuracy: 0.9032 - val_loss: 0.3866 - val_sparse_categorical_accuracy: 0.8308\n", + "Epoch 69/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 369ms/step - loss: 0.2234 - sparse_categorical_accuracy: 0.9246 - val_loss: 0.3806 - val_sparse_categorical_accuracy: 0.8294\n", + "Epoch 70/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 367ms/step - loss: 0.2204 - sparse_categorical_accuracy: 0.9304 - val_loss: 0.3804 - val_sparse_categorical_accuracy: 0.8322\n", + "Epoch 71/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 371ms/step - loss: 0.2259 - sparse_categorical_accuracy: 0.9244 - val_loss: 0.3782 - val_sparse_categorical_accuracy: 0.8322\n", + "Epoch 72/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 366ms/step - loss: 0.2165 - sparse_categorical_accuracy: 0.9274 - val_loss: 0.3779 - val_sparse_categorical_accuracy: 0.8336\n", + "Epoch 73/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 343ms/step - loss: 0.2285 - sparse_categorical_accuracy: 0.9248 - val_loss: 0.3760 - val_sparse_categorical_accuracy: 0.8391\n", + "Epoch 74/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 371ms/step - loss: 0.2211 - sparse_categorical_accuracy: 0.9183 - val_loss: 0.3762 - val_sparse_categorical_accuracy: 0.8377\n", + "Epoch 75/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 335ms/step - loss: 0.2274 - sparse_categorical_accuracy: 0.9190 - val_loss: 0.3753 - val_sparse_categorical_accuracy: 0.8350\n", + "Epoch 76/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 338ms/step - loss: 0.2096 - sparse_categorical_accuracy: 0.9314 - val_loss: 0.3737 - val_sparse_categorical_accuracy: 0.8350\n", + "Epoch 77/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 373ms/step - loss: 0.2072 - sparse_categorical_accuracy: 0.9336 - val_loss: 0.3751 - val_sparse_categorical_accuracy: 0.8336\n", + "Epoch 78/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 372ms/step - loss: 0.2125 - sparse_categorical_accuracy: 0.9287 - val_loss: 0.3735 - val_sparse_categorical_accuracy: 0.8350\n", + "Epoch 79/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 341ms/step - loss: 0.1964 - sparse_categorical_accuracy: 0.9248 - val_loss: 0.3710 - val_sparse_categorical_accuracy: 0.8350\n", + "Epoch 80/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 341ms/step - loss: 0.2226 - sparse_categorical_accuracy: 0.9141 - val_loss: 0.3678 - val_sparse_categorical_accuracy: 0.8405\n", + "Epoch 81/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 374ms/step - loss: 0.2151 - sparse_categorical_accuracy: 0.9234 - val_loss: 0.3679 - val_sparse_categorical_accuracy: 0.8377\n", + "Epoch 82/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 343ms/step - loss: 0.2053 - sparse_categorical_accuracy: 0.9285 - val_loss: 0.3687 - val_sparse_categorical_accuracy: 0.8363\n", + "Epoch 83/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 335ms/step - loss: 0.2062 - sparse_categorical_accuracy: 0.9326 - val_loss: 0.3638 - val_sparse_categorical_accuracy: 0.8363\n", + "Epoch 84/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 339ms/step - loss: 0.2051 - sparse_categorical_accuracy: 0.9268 - val_loss: 0.3663 - val_sparse_categorical_accuracy: 0.8350\n", + "Epoch 85/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 347ms/step - loss: 0.1903 - sparse_categorical_accuracy: 0.9364 - val_loss: 0.3627 - val_sparse_categorical_accuracy: 0.8350\n", + "Epoch 86/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 342ms/step - loss: 0.2070 - sparse_categorical_accuracy: 0.9210 - val_loss: 0.3616 - val_sparse_categorical_accuracy: 0.8363\n", + "Epoch 87/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 344ms/step - loss: 0.1906 - sparse_categorical_accuracy: 0.9325 - val_loss: 0.3611 - val_sparse_categorical_accuracy: 0.8377\n", + "Epoch 88/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 363ms/step - loss: 0.1934 - sparse_categorical_accuracy: 0.9334 - val_loss: 0.3584 - val_sparse_categorical_accuracy: 0.8391\n", + "Epoch 89/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 369ms/step - loss: 0.1987 - sparse_categorical_accuracy: 0.9330 - val_loss: 0.3603 - val_sparse_categorical_accuracy: 0.8322\n", + "Epoch 90/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 368ms/step - loss: 0.1913 - sparse_categorical_accuracy: 0.9335 - val_loss: 0.3551 - val_sparse_categorical_accuracy: 0.8377\n", + "Epoch 91/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 368ms/step - loss: 0.1960 - sparse_categorical_accuracy: 0.9313 - val_loss: 0.3568 - val_sparse_categorical_accuracy: 0.8447\n", + "Epoch 92/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 370ms/step - loss: 0.1989 - sparse_categorical_accuracy: 0.9260 - val_loss: 0.3508 - val_sparse_categorical_accuracy: 0.8433\n", + "Epoch 93/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 369ms/step - loss: 0.1978 - sparse_categorical_accuracy: 0.9303 - val_loss: 0.3582 - val_sparse_categorical_accuracy: 0.8447\n", + "Epoch 94/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 368ms/step - loss: 0.1761 - sparse_categorical_accuracy: 0.9435 - val_loss: 0.3536 - val_sparse_categorical_accuracy: 0.8391\n", + "Epoch 95/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 341ms/step - loss: 0.1945 - sparse_categorical_accuracy: 0.9276 - val_loss: 0.3527 - val_sparse_categorical_accuracy: 0.8377\n", + "Epoch 96/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 345ms/step - loss: 0.1700 - sparse_categorical_accuracy: 0.9453 - val_loss: 0.3540 - val_sparse_categorical_accuracy: 0.8391\n", + "Epoch 97/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 365ms/step - loss: 0.1737 - sparse_categorical_accuracy: 0.9419 - val_loss: 0.3532 - val_sparse_categorical_accuracy: 0.8336\n", + "Epoch 98/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 365ms/step - loss: 0.1868 - sparse_categorical_accuracy: 0.9330 - val_loss: 0.3536 - val_sparse_categorical_accuracy: 0.8377\n", + "Epoch 99/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 371ms/step - loss: 0.1722 - sparse_categorical_accuracy: 0.9450 - val_loss: 0.3532 - val_sparse_categorical_accuracy: 0.8377\n", + "Epoch 100/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 371ms/step - loss: 0.1893 - sparse_categorical_accuracy: 0.9311 - val_loss: 0.3526 - val_sparse_categorical_accuracy: 0.8350\n", + "Epoch 101/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 371ms/step - loss: 0.1819 - sparse_categorical_accuracy: 0.9418 - val_loss: 0.3552 - val_sparse_categorical_accuracy: 0.8391\n", + "Epoch 102/150\n", + "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 368ms/step - loss: 0.1700 - sparse_categorical_accuracy: 0.9460 - val_loss: 0.3525 - val_sparse_categorical_accuracy: 0.8447\n", + "\u001b[1m42/42\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 103ms/step - loss: 0.3497 - sparse_categorical_accuracy: 0.8483\n" + ] + }, + { + "data": { + "text/plain": [ + "[0.36210155487060547, 0.8446969985961914]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "\n", + "callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]\n", + "\n", + "model.fit(\n", + " x_train,\n", + " y_train,\n", + " validation_split=0.2,\n", + " epochs=150,\n", + " batch_size=64,\n", + " callbacks=callbacks,\n", + ")\n", + "\n", + "model.evaluate(x_test, y_test, verbose=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "B7nRTbzn4w3M" + }, + "source": [ + "## Conclusions\n", + "\n", + "In about 100-102 epochs (25s each on Colab), the model reaches a training\n", + "accuracy of ~0.94, validation accuracy of ~84 and a testing\n", + "accuracy of ~85, without hyperparameter tuning. And that is for a model\n", + "with less than 100k parameters. Of course, parameter count and accuracy could be\n", + "improved by a hyperparameter search and a more sophisticated learning rate\n", + "schedule, or a different optimizer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GZeFCnD6YW1l" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "## Introduction\n", - "\n", - "This is the Transformer architecture from\n", - "[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n", - "applied to timeseries instead of natural language.\n", - "\n", - "This example requires TensorFlow 2.4 or higher.\n", - "\n", - "## Load the dataset\n", - "\n", - "We are going to use the same dataset and preprocessing as the\n", - "[TimeSeries Classification from Scratch](https://keras.io/examples/timeseries/timeseries_classification_from_scratch)\n", - "example." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "import numpy as np\n", - "import keras\n", - "from keras import layers\n", - "\n", - "\n", - "def readucr(filename):\n", - " data = np.loadtxt(filename, delimiter=\"\\t\")\n", - " y = data[:, 0]\n", - " x = data[:, 1:]\n", - " return x, y.astype(int)\n", - "\n", - "\n", - "root_url = \"https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/\"\n", - "\n", - "x_train, y_train = readucr(root_url + \"FordA_TRAIN.tsv\")\n", - "x_test, y_test = readucr(root_url + \"FordA_TEST.tsv\")\n", - "\n", - "x_train = x_train.reshape((x_train.shape[0], x_train.shape[1], 1))\n", - "x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], 1))\n", - "\n", - "n_classes = len(np.unique(y_train))\n", - "\n", - "idx = np.random.permutation(len(x_train))\n", - "x_train = x_train[idx]\n", - "y_train = y_train[idx]\n", - "\n", - "y_train[y_train == -1] = 0\n", - "y_test[y_test == -1] = 0" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "## Build the model\n", - "\n", - "Our model processes a tensor of shape `(batch size, sequence length, features)`,\n", - "where `sequence length` is the number of time steps and `features` is each input\n", - "timeseries.\n", - "\n", - "You can replace your classification RNN layers with this one: the\n", - "inputs are fully compatible!\n", - "\n", - "We include residual connections, layer normalization, and dropout.\n", - "The resulting layer can be stacked multiple times.\n", - "\n", - "The projection layers are implemented through `keras.layers.Conv1D`." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "\n", - "def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):\n", - " # Attention and Normalization\n", - " x = layers.MultiHeadAttention(\n", - " key_dim=head_size, num_heads=num_heads, dropout=dropout\n", - " )(inputs, inputs)\n", - " x = layers.Dropout(dropout)(x)\n", - " x = layers.LayerNormalization(epsilon=1e-6)(x)\n", - " res = x + inputs\n", - "\n", - " # Feed Forward Part\n", - " x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation=\"relu\")(res)\n", - " x = layers.Dropout(dropout)(x)\n", - " x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)\n", - " x = layers.LayerNormalization(epsilon=1e-6)(x)\n", - " return x + res\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "The main part of our model is now complete. We can stack multiple of those\n", - "`transformer_encoder` blocks and we can also proceed to add the final\n", - "Multi-Layer Perceptron classification head. Apart from a stack of `Dense`\n", - "layers, we need to reduce the output tensor of the `TransformerEncoder` part of\n", - "our model down to a vector of features for each data point in the current\n", - "batch. A common way to achieve this is to use a pooling layer. For\n", - "this example, a `GlobalAveragePooling1D` layer is sufficient." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "\n", - "def build_model(\n", - " input_shape,\n", - " head_size,\n", - " num_heads,\n", - " ff_dim,\n", - " num_transformer_blocks,\n", - " mlp_units,\n", - " dropout=0,\n", - " mlp_dropout=0,\n", - "):\n", - " inputs = keras.Input(shape=input_shape)\n", - " x = inputs\n", - " for _ in range(num_transformer_blocks):\n", - " x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)\n", - "\n", - " x = layers.GlobalAveragePooling1D(data_format=\"channels_last\")(x)\n", - " for dim in mlp_units:\n", - " x = layers.Dense(dim, activation=\"relu\")(x)\n", - " x = layers.Dropout(mlp_dropout)(x)\n", - " outputs = layers.Dense(n_classes, activation=\"softmax\")(x)\n", - " return keras.Model(inputs, outputs)\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "## Train and evaluate" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "input_shape = x_train.shape[1:]\n", - "\n", - "model = build_model(\n", - " input_shape,\n", - " head_size=256,\n", - " num_heads=4,\n", - " ff_dim=4,\n", - " num_transformer_blocks=4,\n", - " mlp_units=[128],\n", - " mlp_dropout=0.4,\n", - " dropout=0.25,\n", - ")\n", - "\n", - "model.compile(\n", - " loss=\"sparse_categorical_crossentropy\",\n", - " optimizer=keras.optimizers.Adam(learning_rate=1e-4),\n", - " metrics=[\"sparse_categorical_accuracy\"],\n", - ")\n", - "model.summary()\n", - "\n", - "callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]\n", - "\n", - "model.fit(\n", - " x_train,\n", - " y_train,\n", - " validation_split=0.2,\n", - " epochs=150,\n", - " batch_size=64,\n", - " callbacks=callbacks,\n", - ")\n", - "\n", - "model.evaluate(x_test, y_test, verbose=1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "## Conclusions\n", - "\n", - "In about 110-120 epochs (25s each on Colab), the model reaches a training\n", - "accuracy of ~0.95, validation accuracy of ~84 and a testing\n", - "accuracy of ~85, without hyperparameter tuning. And that is for a model\n", - "with less than 100k parameters. Of course, parameter count and accuracy could be\n", - "improved by a hyperparameter search and a more sophisticated learning rate\n", - "schedule, or a different optimizer." - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "timeseries_classification_transformer", - "private_outputs": false, - "provenance": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/examples/timeseries/timeseries_classification_transformer.py b/examples/timeseries/timeseries_classification_transformer.py index fde0aa5f73..b9d95bd242 100644 --- a/examples/timeseries/timeseries_classification_transformer.py +++ b/examples/timeseries/timeseries_classification_transformer.py @@ -1,10 +1,19 @@ """ -Title: Timeseries classification with a Transformer model -Author: [Theodoros Ntakouris](https://github.com/ntakouris) -Date created: 2021/06/25 -Last modified: 2021/08/05 -Description: This notebook demonstrates how to do timeseries classification using a Transformer model. -Accelerator: GPU +Title: FILLME +Author: FILLME +Date created: FILLME +Last modified: FILLME +Description: FILLME +""" + +""" +# Timeseries classification with a Transformer model + +**Author:** [Theodoros Ntakouris](https://github.com/ntakouris)
+**Date created:** 2021/06/25
+**Last modified:** 2024/12/18
+**Description:** This notebook demonstrates how to do timeseries classification using a +Transformer model. """ """ @@ -19,7 +28,8 @@ ## Load the dataset We are going to use the same dataset and preprocessing as the -[TimeSeries Classification from Scratch](https://keras.io/examples/timeseries/timeseries_classification_from_scratch) +[TimeSeries Classification from +Scratch](https://keras.io/examples/timeseries/timeseries_classification_from_scratch) example. """ @@ -111,8 +121,9 @@ def build_model( x = inputs for _ in range(num_transformer_blocks): x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout) - - x = layers.GlobalAveragePooling1D(data_format="channels_last")(x) + print(f"Transformer Encoder: {x}") + x = layers.GlobalAveragePooling1D(data_format="channels_first")(x) + print(f"Global Average Pooling: {x}") for dim in mlp_units: x = layers.Dense(dim, activation="relu")(x) x = layers.Dropout(mlp_dropout)(x) @@ -144,6 +155,7 @@ def build_model( ) model.summary() + callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)] model.fit( @@ -160,11 +172,10 @@ def build_model( """ ## Conclusions -In about 110-120 epochs (25s each on Colab), the model reaches a training -accuracy of ~0.95, validation accuracy of ~84 and a testing +In about 100-102 epochs (25s each on Colab), the model reaches a training +accuracy of ~0.94, validation accuracy of ~84 and a testing accuracy of ~85, without hyperparameter tuning. And that is for a model with less than 100k parameters. Of course, parameter count and accuracy could be improved by a hyperparameter search and a more sophisticated learning rate schedule, or a different optimizer. - """