\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Model: \"functional\"\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1mModel: \"functional\"\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ input_layer (InputLayer) │ (None, 500, 1) │ 0 │ - │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ multi_head_attention │ (None, 500, 1) │ 7,169 │ input_layer[0][0], │\n",
+ "│ (MultiHeadAttention) │ │ │ input_layer[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_1 (Dropout) │ (None, 500, 1) │ 0 │ multi_head_attention[… │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization │ (None, 500, 1) │ 2 │ dropout_1[0][0] │\n",
+ "│ (LayerNormalization) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add (Add) │ (None, 500, 1) │ 0 │ layer_normalization[0… │\n",
+ "│ │ │ │ input_layer[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d (Conv1D) │ (None, 500, 4) │ 8 │ add[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_2 (Dropout) │ (None, 500, 4) │ 0 │ conv1d[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d_1 (Conv1D) │ (None, 500, 1) │ 5 │ dropout_2[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization_1 │ (None, 500, 1) │ 2 │ conv1d_1[0][0] │\n",
+ "│ (LayerNormalization) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add_1 (Add) │ (None, 500, 1) │ 0 │ layer_normalization_1… │\n",
+ "│ │ │ │ add[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ multi_head_attention_1 │ (None, 500, 1) │ 7,169 │ add_1[0][0], │\n",
+ "│ (MultiHeadAttention) │ │ │ add_1[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_4 (Dropout) │ (None, 500, 1) │ 0 │ multi_head_attention_… │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization_2 │ (None, 500, 1) │ 2 │ dropout_4[0][0] │\n",
+ "│ (LayerNormalization) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add_2 (Add) │ (None, 500, 1) │ 0 │ layer_normalization_2… │\n",
+ "│ │ │ │ add_1[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d_2 (Conv1D) │ (None, 500, 4) │ 8 │ add_2[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_5 (Dropout) │ (None, 500, 4) │ 0 │ conv1d_2[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d_3 (Conv1D) │ (None, 500, 1) │ 5 │ dropout_5[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization_3 │ (None, 500, 1) │ 2 │ conv1d_3[0][0] │\n",
+ "│ (LayerNormalization) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add_3 (Add) │ (None, 500, 1) │ 0 │ layer_normalization_3… │\n",
+ "│ │ │ │ add_2[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ multi_head_attention_2 │ (None, 500, 1) │ 7,169 │ add_3[0][0], │\n",
+ "│ (MultiHeadAttention) │ │ │ add_3[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_7 (Dropout) │ (None, 500, 1) │ 0 │ multi_head_attention_… │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization_4 │ (None, 500, 1) │ 2 │ dropout_7[0][0] │\n",
+ "│ (LayerNormalization) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add_4 (Add) │ (None, 500, 1) │ 0 │ layer_normalization_4… │\n",
+ "│ │ │ │ add_3[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d_4 (Conv1D) │ (None, 500, 4) │ 8 │ add_4[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_8 (Dropout) │ (None, 500, 4) │ 0 │ conv1d_4[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d_5 (Conv1D) │ (None, 500, 1) │ 5 │ dropout_8[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization_5 │ (None, 500, 1) │ 2 │ conv1d_5[0][0] │\n",
+ "│ (LayerNormalization) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add_5 (Add) │ (None, 500, 1) │ 0 │ layer_normalization_5… │\n",
+ "│ │ │ │ add_4[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ multi_head_attention_3 │ (None, 500, 1) │ 7,169 │ add_5[0][0], │\n",
+ "│ (MultiHeadAttention) │ │ │ add_5[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_10 (Dropout) │ (None, 500, 1) │ 0 │ multi_head_attention_… │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization_6 │ (None, 500, 1) │ 2 │ dropout_10[0][0] │\n",
+ "│ (LayerNormalization) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add_6 (Add) │ (None, 500, 1) │ 0 │ layer_normalization_6… │\n",
+ "│ │ │ │ add_5[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d_6 (Conv1D) │ (None, 500, 4) │ 8 │ add_6[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_11 (Dropout) │ (None, 500, 4) │ 0 │ conv1d_6[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d_7 (Conv1D) │ (None, 500, 1) │ 5 │ dropout_11[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization_7 │ (None, 500, 1) │ 2 │ conv1d_7[0][0] │\n",
+ "│ (LayerNormalization) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add_7 (Add) │ (None, 500, 1) │ 0 │ layer_normalization_7… │\n",
+ "│ │ │ │ add_6[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ global_average_pooling1d │ (None, 500) │ 0 │ add_7[0][0] │\n",
+ "│ (GlobalAveragePooling1D) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dense (Dense) │ (None, 128) │ 64,128 │ global_average_poolin… │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_12 (Dropout) │ (None, 128) │ 0 │ dense[0][0] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dense_1 (Dense) │ (None, 2) │ 258 │ dropout_12[0][0] │\n",
+ "└───────────────────────────┴────────────────────────┴────────────────┴────────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ input_layer (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ multi_head_attention │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m7,169\u001b[0m │ input_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
+ "│ (\u001b[38;5;33mMultiHeadAttention\u001b[0m) │ │ │ input_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_1 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ multi_head_attention[\u001b[38;5;34m…\u001b[0m │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ dropout_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization[\u001b[38;5;34m0\u001b[0m… │\n",
+ "│ │ │ │ input_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m8\u001b[0m │ add[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_2 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv1d[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d_1 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m5\u001b[0m │ dropout_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ conv1d_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add_1 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization_1… │\n",
+ "│ │ │ │ add[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ multi_head_attention_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m7,169\u001b[0m │ add_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
+ "│ (\u001b[38;5;33mMultiHeadAttention\u001b[0m) │ │ │ add_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_4 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ multi_head_attention_… │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ dropout_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add_2 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization_2… │\n",
+ "│ │ │ │ add_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d_2 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m8\u001b[0m │ add_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_5 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv1d_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d_3 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m5\u001b[0m │ dropout_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ conv1d_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add_3 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization_3… │\n",
+ "│ │ │ │ add_2[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ multi_head_attention_2 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m7,169\u001b[0m │ add_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
+ "│ (\u001b[38;5;33mMultiHeadAttention\u001b[0m) │ │ │ add_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_7 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ multi_head_attention_… │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization_4 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ dropout_7[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add_4 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization_4… │\n",
+ "│ │ │ │ add_3[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d_4 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m8\u001b[0m │ add_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_8 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv1d_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d_5 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m5\u001b[0m │ dropout_8[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization_5 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ conv1d_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add_5 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization_5… │\n",
+ "│ │ │ │ add_4[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ multi_head_attention_3 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m7,169\u001b[0m │ add_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
+ "│ (\u001b[38;5;33mMultiHeadAttention\u001b[0m) │ │ │ add_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_10 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ multi_head_attention_… │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization_6 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ dropout_10[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add_6 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization_6… │\n",
+ "│ │ │ │ add_5[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d_6 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m8\u001b[0m │ add_6[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_11 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m4\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ conv1d_6[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ conv1d_7 (\u001b[38;5;33mConv1D\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m5\u001b[0m │ dropout_11[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ layer_normalization_7 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m2\u001b[0m │ conv1d_7[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ (\u001b[38;5;33mLayerNormalization\u001b[0m) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ add_7 (\u001b[38;5;33mAdd\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ layer_normalization_7… │\n",
+ "│ │ │ │ add_6[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ global_average_pooling1d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m500\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ add_7[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "│ (\u001b[38;5;33mGlobalAveragePooling1D\u001b[0m) │ │ │ │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m64,128\u001b[0m │ global_average_poolin… │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dropout_12 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dense[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
+ "│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m258\u001b[0m │ dropout_12[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
+ "└───────────────────────────┴────────────────────────┴────────────────┴────────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " Total params: 93,130 (363.79 KB)\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m93,130\u001b[0m (363.79 KB)\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " Trainable params: 93,130 (363.79 KB)\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m93,130\u001b[0m (363.79 KB)\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " Non-trainable params: 0 (0.00 B)\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "input_shape = x_train.shape[1:]\n",
+ "\n",
+ "model = build_model(\n",
+ " input_shape,\n",
+ " head_size=256,\n",
+ " num_heads=4,\n",
+ " ff_dim=4,\n",
+ " num_transformer_blocks=4,\n",
+ " mlp_units=[128],\n",
+ " mlp_dropout=0.4,\n",
+ " dropout=0.25,\n",
+ ")\n",
+ "\n",
+ "model.compile(\n",
+ " loss=\"sparse_categorical_crossentropy\",\n",
+ " optimizer=keras.optimizers.Adam(learning_rate=1e-4),\n",
+ " metrics=[\"sparse_categorical_accuracy\"],\n",
+ ")\n",
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "eDOvm4BR4w3M",
+ "outputId": "fa53a57d-d9af-41b6-f0bd-5dc4d7dfb621"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch 1/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m54s\u001b[0m 401ms/step - loss: 1.1535 - sparse_categorical_accuracy: 0.4893 - val_loss: 0.7982 - val_sparse_categorical_accuracy: 0.5298\n",
+ "Epoch 2/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 352ms/step - loss: 0.8630 - sparse_categorical_accuracy: 0.5636 - val_loss: 0.7076 - val_sparse_categorical_accuracy: 0.6019\n",
+ "Epoch 3/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 341ms/step - loss: 0.8010 - sparse_categorical_accuracy: 0.5744 - val_loss: 0.6721 - val_sparse_categorical_accuracy: 0.6283\n",
+ "Epoch 4/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 352ms/step - loss: 0.7667 - sparse_categorical_accuracy: 0.6024 - val_loss: 0.6411 - val_sparse_categorical_accuracy: 0.6546\n",
+ "Epoch 5/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 341ms/step - loss: 0.6981 - sparse_categorical_accuracy: 0.6420 - val_loss: 0.6224 - val_sparse_categorical_accuracy: 0.6574\n",
+ "Epoch 6/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 334ms/step - loss: 0.6731 - sparse_categorical_accuracy: 0.6317 - val_loss: 0.6064 - val_sparse_categorical_accuracy: 0.6713\n",
+ "Epoch 7/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 338ms/step - loss: 0.6521 - sparse_categorical_accuracy: 0.6693 - val_loss: 0.5913 - val_sparse_categorical_accuracy: 0.6727\n",
+ "Epoch 8/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 373ms/step - loss: 0.6322 - sparse_categorical_accuracy: 0.6663 - val_loss: 0.5807 - val_sparse_categorical_accuracy: 0.6893\n",
+ "Epoch 9/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 371ms/step - loss: 0.5975 - sparse_categorical_accuracy: 0.6800 - val_loss: 0.5698 - val_sparse_categorical_accuracy: 0.6976\n",
+ "Epoch 10/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 339ms/step - loss: 0.6067 - sparse_categorical_accuracy: 0.6786 - val_loss: 0.5639 - val_sparse_categorical_accuracy: 0.7046\n",
+ "Epoch 11/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 343ms/step - loss: 0.5685 - sparse_categorical_accuracy: 0.7052 - val_loss: 0.5599 - val_sparse_categorical_accuracy: 0.7101\n",
+ "Epoch 12/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 337ms/step - loss: 0.5386 - sparse_categorical_accuracy: 0.7203 - val_loss: 0.5500 - val_sparse_categorical_accuracy: 0.7198\n",
+ "Epoch 13/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 340ms/step - loss: 0.5474 - sparse_categorical_accuracy: 0.7134 - val_loss: 0.5431 - val_sparse_categorical_accuracy: 0.7198\n",
+ "Epoch 14/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 350ms/step - loss: 0.5267 - sparse_categorical_accuracy: 0.7421 - val_loss: 0.5363 - val_sparse_categorical_accuracy: 0.7295\n",
+ "Epoch 15/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 372ms/step - loss: 0.5117 - sparse_categorical_accuracy: 0.7442 - val_loss: 0.5323 - val_sparse_categorical_accuracy: 0.7337\n",
+ "Epoch 16/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 366ms/step - loss: 0.5199 - sparse_categorical_accuracy: 0.7359 - val_loss: 0.5258 - val_sparse_categorical_accuracy: 0.7323\n",
+ "Epoch 17/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 362ms/step - loss: 0.5032 - sparse_categorical_accuracy: 0.7527 - val_loss: 0.5224 - val_sparse_categorical_accuracy: 0.7351\n",
+ "Epoch 18/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 375ms/step - loss: 0.4854 - sparse_categorical_accuracy: 0.7615 - val_loss: 0.5164 - val_sparse_categorical_accuracy: 0.7406\n",
+ "Epoch 19/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 346ms/step - loss: 0.4763 - sparse_categorical_accuracy: 0.7718 - val_loss: 0.5121 - val_sparse_categorical_accuracy: 0.7434\n",
+ "Epoch 20/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 370ms/step - loss: 0.4577 - sparse_categorical_accuracy: 0.7865 - val_loss: 0.5084 - val_sparse_categorical_accuracy: 0.7476\n",
+ "Epoch 21/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 333ms/step - loss: 0.4633 - sparse_categorical_accuracy: 0.7759 - val_loss: 0.5052 - val_sparse_categorical_accuracy: 0.7365\n",
+ "Epoch 22/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 371ms/step - loss: 0.4557 - sparse_categorical_accuracy: 0.7855 - val_loss: 0.4984 - val_sparse_categorical_accuracy: 0.7462\n",
+ "Epoch 23/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 368ms/step - loss: 0.4451 - sparse_categorical_accuracy: 0.7897 - val_loss: 0.4939 - val_sparse_categorical_accuracy: 0.7503\n",
+ "Epoch 24/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 370ms/step - loss: 0.4348 - sparse_categorical_accuracy: 0.8004 - val_loss: 0.4916 - val_sparse_categorical_accuracy: 0.7628\n",
+ "Epoch 25/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 342ms/step - loss: 0.4248 - sparse_categorical_accuracy: 0.8200 - val_loss: 0.4855 - val_sparse_categorical_accuracy: 0.7587\n",
+ "Epoch 26/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 368ms/step - loss: 0.4233 - sparse_categorical_accuracy: 0.8066 - val_loss: 0.4833 - val_sparse_categorical_accuracy: 0.7628\n",
+ "Epoch 27/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 342ms/step - loss: 0.4194 - sparse_categorical_accuracy: 0.8025 - val_loss: 0.4799 - val_sparse_categorical_accuracy: 0.7587\n",
+ "Epoch 28/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 342ms/step - loss: 0.3837 - sparse_categorical_accuracy: 0.8354 - val_loss: 0.4757 - val_sparse_categorical_accuracy: 0.7628\n",
+ "Epoch 29/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 368ms/step - loss: 0.3742 - sparse_categorical_accuracy: 0.8414 - val_loss: 0.4733 - val_sparse_categorical_accuracy: 0.7656\n",
+ "Epoch 30/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 342ms/step - loss: 0.3777 - sparse_categorical_accuracy: 0.8434 - val_loss: 0.4722 - val_sparse_categorical_accuracy: 0.7698\n",
+ "Epoch 31/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 342ms/step - loss: 0.3796 - sparse_categorical_accuracy: 0.8412 - val_loss: 0.4660 - val_sparse_categorical_accuracy: 0.7712\n",
+ "Epoch 32/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 369ms/step - loss: 0.3903 - sparse_categorical_accuracy: 0.8277 - val_loss: 0.4629 - val_sparse_categorical_accuracy: 0.7725\n",
+ "Epoch 33/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 345ms/step - loss: 0.3521 - sparse_categorical_accuracy: 0.8533 - val_loss: 0.4605 - val_sparse_categorical_accuracy: 0.7684\n",
+ "Epoch 34/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 336ms/step - loss: 0.3624 - sparse_categorical_accuracy: 0.8480 - val_loss: 0.4596 - val_sparse_categorical_accuracy: 0.7725\n",
+ "Epoch 35/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 371ms/step - loss: 0.3579 - sparse_categorical_accuracy: 0.8444 - val_loss: 0.4543 - val_sparse_categorical_accuracy: 0.7809\n",
+ "Epoch 36/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 371ms/step - loss: 0.3583 - sparse_categorical_accuracy: 0.8565 - val_loss: 0.4502 - val_sparse_categorical_accuracy: 0.7864\n",
+ "Epoch 37/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 363ms/step - loss: 0.3323 - sparse_categorical_accuracy: 0.8678 - val_loss: 0.4474 - val_sparse_categorical_accuracy: 0.7878\n",
+ "Epoch 38/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 368ms/step - loss: 0.3440 - sparse_categorical_accuracy: 0.8574 - val_loss: 0.4444 - val_sparse_categorical_accuracy: 0.7892\n",
+ "Epoch 39/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 375ms/step - loss: 0.3346 - sparse_categorical_accuracy: 0.8589 - val_loss: 0.4429 - val_sparse_categorical_accuracy: 0.7975\n",
+ "Epoch 40/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 342ms/step - loss: 0.3274 - sparse_categorical_accuracy: 0.8565 - val_loss: 0.4393 - val_sparse_categorical_accuracy: 0.7975\n",
+ "Epoch 41/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 366ms/step - loss: 0.3294 - sparse_categorical_accuracy: 0.8717 - val_loss: 0.4349 - val_sparse_categorical_accuracy: 0.8031\n",
+ "Epoch 42/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 368ms/step - loss: 0.3222 - sparse_categorical_accuracy: 0.8725 - val_loss: 0.4323 - val_sparse_categorical_accuracy: 0.8031\n",
+ "Epoch 43/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 373ms/step - loss: 0.3095 - sparse_categorical_accuracy: 0.8963 - val_loss: 0.4315 - val_sparse_categorical_accuracy: 0.8003\n",
+ "Epoch 44/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 340ms/step - loss: 0.2991 - sparse_categorical_accuracy: 0.8872 - val_loss: 0.4290 - val_sparse_categorical_accuracy: 0.8100\n",
+ "Epoch 45/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 339ms/step - loss: 0.3148 - sparse_categorical_accuracy: 0.8842 - val_loss: 0.4271 - val_sparse_categorical_accuracy: 0.8072\n",
+ "Epoch 46/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 342ms/step - loss: 0.3020 - sparse_categorical_accuracy: 0.8880 - val_loss: 0.4237 - val_sparse_categorical_accuracy: 0.8100\n",
+ "Epoch 47/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 372ms/step - loss: 0.2914 - sparse_categorical_accuracy: 0.8968 - val_loss: 0.4184 - val_sparse_categorical_accuracy: 0.8114\n",
+ "Epoch 48/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 369ms/step - loss: 0.2889 - sparse_categorical_accuracy: 0.8976 - val_loss: 0.4179 - val_sparse_categorical_accuracy: 0.8141\n",
+ "Epoch 49/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 369ms/step - loss: 0.2961 - sparse_categorical_accuracy: 0.8957 - val_loss: 0.4166 - val_sparse_categorical_accuracy: 0.8197\n",
+ "Epoch 50/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 369ms/step - loss: 0.2956 - sparse_categorical_accuracy: 0.8883 - val_loss: 0.4146 - val_sparse_categorical_accuracy: 0.8197\n",
+ "Epoch 51/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 347ms/step - loss: 0.2708 - sparse_categorical_accuracy: 0.9043 - val_loss: 0.4144 - val_sparse_categorical_accuracy: 0.8183\n",
+ "Epoch 52/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 345ms/step - loss: 0.2883 - sparse_categorical_accuracy: 0.8916 - val_loss: 0.4095 - val_sparse_categorical_accuracy: 0.8183\n",
+ "Epoch 53/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 339ms/step - loss: 0.2698 - sparse_categorical_accuracy: 0.8986 - val_loss: 0.4055 - val_sparse_categorical_accuracy: 0.8225\n",
+ "Epoch 54/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 333ms/step - loss: 0.2878 - sparse_categorical_accuracy: 0.8807 - val_loss: 0.4033 - val_sparse_categorical_accuracy: 0.8239\n",
+ "Epoch 55/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 372ms/step - loss: 0.2751 - sparse_categorical_accuracy: 0.9022 - val_loss: 0.4020 - val_sparse_categorical_accuracy: 0.8294\n",
+ "Epoch 56/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 341ms/step - loss: 0.2604 - sparse_categorical_accuracy: 0.9021 - val_loss: 0.4014 - val_sparse_categorical_accuracy: 0.8266\n",
+ "Epoch 57/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 368ms/step - loss: 0.2639 - sparse_categorical_accuracy: 0.9062 - val_loss: 0.3990 - val_sparse_categorical_accuracy: 0.8239\n",
+ "Epoch 58/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 373ms/step - loss: 0.2597 - sparse_categorical_accuracy: 0.8953 - val_loss: 0.4008 - val_sparse_categorical_accuracy: 0.8322\n",
+ "Epoch 59/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 344ms/step - loss: 0.2633 - sparse_categorical_accuracy: 0.9069 - val_loss: 0.3969 - val_sparse_categorical_accuracy: 0.8308\n",
+ "Epoch 60/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 364ms/step - loss: 0.2682 - sparse_categorical_accuracy: 0.8988 - val_loss: 0.3959 - val_sparse_categorical_accuracy: 0.8322\n",
+ "Epoch 61/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 340ms/step - loss: 0.2789 - sparse_categorical_accuracy: 0.8952 - val_loss: 0.3948 - val_sparse_categorical_accuracy: 0.8280\n",
+ "Epoch 62/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 352ms/step - loss: 0.2493 - sparse_categorical_accuracy: 0.9158 - val_loss: 0.3915 - val_sparse_categorical_accuracy: 0.8308\n",
+ "Epoch 63/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 367ms/step - loss: 0.2594 - sparse_categorical_accuracy: 0.9033 - val_loss: 0.3909 - val_sparse_categorical_accuracy: 0.8239\n",
+ "Epoch 64/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 368ms/step - loss: 0.2526 - sparse_categorical_accuracy: 0.9003 - val_loss: 0.3869 - val_sparse_categorical_accuracy: 0.8322\n",
+ "Epoch 65/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 367ms/step - loss: 0.2413 - sparse_categorical_accuracy: 0.9174 - val_loss: 0.3876 - val_sparse_categorical_accuracy: 0.8266\n",
+ "Epoch 66/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 367ms/step - loss: 0.2448 - sparse_categorical_accuracy: 0.9144 - val_loss: 0.3879 - val_sparse_categorical_accuracy: 0.8252\n",
+ "Epoch 67/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 366ms/step - loss: 0.2592 - sparse_categorical_accuracy: 0.9118 - val_loss: 0.3851 - val_sparse_categorical_accuracy: 0.8336\n",
+ "Epoch 68/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 372ms/step - loss: 0.2469 - sparse_categorical_accuracy: 0.9032 - val_loss: 0.3866 - val_sparse_categorical_accuracy: 0.8308\n",
+ "Epoch 69/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 369ms/step - loss: 0.2234 - sparse_categorical_accuracy: 0.9246 - val_loss: 0.3806 - val_sparse_categorical_accuracy: 0.8294\n",
+ "Epoch 70/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 367ms/step - loss: 0.2204 - sparse_categorical_accuracy: 0.9304 - val_loss: 0.3804 - val_sparse_categorical_accuracy: 0.8322\n",
+ "Epoch 71/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 371ms/step - loss: 0.2259 - sparse_categorical_accuracy: 0.9244 - val_loss: 0.3782 - val_sparse_categorical_accuracy: 0.8322\n",
+ "Epoch 72/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 366ms/step - loss: 0.2165 - sparse_categorical_accuracy: 0.9274 - val_loss: 0.3779 - val_sparse_categorical_accuracy: 0.8336\n",
+ "Epoch 73/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 343ms/step - loss: 0.2285 - sparse_categorical_accuracy: 0.9248 - val_loss: 0.3760 - val_sparse_categorical_accuracy: 0.8391\n",
+ "Epoch 74/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 371ms/step - loss: 0.2211 - sparse_categorical_accuracy: 0.9183 - val_loss: 0.3762 - val_sparse_categorical_accuracy: 0.8377\n",
+ "Epoch 75/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 335ms/step - loss: 0.2274 - sparse_categorical_accuracy: 0.9190 - val_loss: 0.3753 - val_sparse_categorical_accuracy: 0.8350\n",
+ "Epoch 76/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 338ms/step - loss: 0.2096 - sparse_categorical_accuracy: 0.9314 - val_loss: 0.3737 - val_sparse_categorical_accuracy: 0.8350\n",
+ "Epoch 77/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m22s\u001b[0m 373ms/step - loss: 0.2072 - sparse_categorical_accuracy: 0.9336 - val_loss: 0.3751 - val_sparse_categorical_accuracy: 0.8336\n",
+ "Epoch 78/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 372ms/step - loss: 0.2125 - sparse_categorical_accuracy: 0.9287 - val_loss: 0.3735 - val_sparse_categorical_accuracy: 0.8350\n",
+ "Epoch 79/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 341ms/step - loss: 0.1964 - sparse_categorical_accuracy: 0.9248 - val_loss: 0.3710 - val_sparse_categorical_accuracy: 0.8350\n",
+ "Epoch 80/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 341ms/step - loss: 0.2226 - sparse_categorical_accuracy: 0.9141 - val_loss: 0.3678 - val_sparse_categorical_accuracy: 0.8405\n",
+ "Epoch 81/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 374ms/step - loss: 0.2151 - sparse_categorical_accuracy: 0.9234 - val_loss: 0.3679 - val_sparse_categorical_accuracy: 0.8377\n",
+ "Epoch 82/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 343ms/step - loss: 0.2053 - sparse_categorical_accuracy: 0.9285 - val_loss: 0.3687 - val_sparse_categorical_accuracy: 0.8363\n",
+ "Epoch 83/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 335ms/step - loss: 0.2062 - sparse_categorical_accuracy: 0.9326 - val_loss: 0.3638 - val_sparse_categorical_accuracy: 0.8363\n",
+ "Epoch 84/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 339ms/step - loss: 0.2051 - sparse_categorical_accuracy: 0.9268 - val_loss: 0.3663 - val_sparse_categorical_accuracy: 0.8350\n",
+ "Epoch 85/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 347ms/step - loss: 0.1903 - sparse_categorical_accuracy: 0.9364 - val_loss: 0.3627 - val_sparse_categorical_accuracy: 0.8350\n",
+ "Epoch 86/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 342ms/step - loss: 0.2070 - sparse_categorical_accuracy: 0.9210 - val_loss: 0.3616 - val_sparse_categorical_accuracy: 0.8363\n",
+ "Epoch 87/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 344ms/step - loss: 0.1906 - sparse_categorical_accuracy: 0.9325 - val_loss: 0.3611 - val_sparse_categorical_accuracy: 0.8377\n",
+ "Epoch 88/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 363ms/step - loss: 0.1934 - sparse_categorical_accuracy: 0.9334 - val_loss: 0.3584 - val_sparse_categorical_accuracy: 0.8391\n",
+ "Epoch 89/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 369ms/step - loss: 0.1987 - sparse_categorical_accuracy: 0.9330 - val_loss: 0.3603 - val_sparse_categorical_accuracy: 0.8322\n",
+ "Epoch 90/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 368ms/step - loss: 0.1913 - sparse_categorical_accuracy: 0.9335 - val_loss: 0.3551 - val_sparse_categorical_accuracy: 0.8377\n",
+ "Epoch 91/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 368ms/step - loss: 0.1960 - sparse_categorical_accuracy: 0.9313 - val_loss: 0.3568 - val_sparse_categorical_accuracy: 0.8447\n",
+ "Epoch 92/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 370ms/step - loss: 0.1989 - sparse_categorical_accuracy: 0.9260 - val_loss: 0.3508 - val_sparse_categorical_accuracy: 0.8433\n",
+ "Epoch 93/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 369ms/step - loss: 0.1978 - sparse_categorical_accuracy: 0.9303 - val_loss: 0.3582 - val_sparse_categorical_accuracy: 0.8447\n",
+ "Epoch 94/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 368ms/step - loss: 0.1761 - sparse_categorical_accuracy: 0.9435 - val_loss: 0.3536 - val_sparse_categorical_accuracy: 0.8391\n",
+ "Epoch 95/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 341ms/step - loss: 0.1945 - sparse_categorical_accuracy: 0.9276 - val_loss: 0.3527 - val_sparse_categorical_accuracy: 0.8377\n",
+ "Epoch 96/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 345ms/step - loss: 0.1700 - sparse_categorical_accuracy: 0.9453 - val_loss: 0.3540 - val_sparse_categorical_accuracy: 0.8391\n",
+ "Epoch 97/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 365ms/step - loss: 0.1737 - sparse_categorical_accuracy: 0.9419 - val_loss: 0.3532 - val_sparse_categorical_accuracy: 0.8336\n",
+ "Epoch 98/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 365ms/step - loss: 0.1868 - sparse_categorical_accuracy: 0.9330 - val_loss: 0.3536 - val_sparse_categorical_accuracy: 0.8377\n",
+ "Epoch 99/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 371ms/step - loss: 0.1722 - sparse_categorical_accuracy: 0.9450 - val_loss: 0.3532 - val_sparse_categorical_accuracy: 0.8377\n",
+ "Epoch 100/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 371ms/step - loss: 0.1893 - sparse_categorical_accuracy: 0.9311 - val_loss: 0.3526 - val_sparse_categorical_accuracy: 0.8350\n",
+ "Epoch 101/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 371ms/step - loss: 0.1819 - sparse_categorical_accuracy: 0.9418 - val_loss: 0.3552 - val_sparse_categorical_accuracy: 0.8391\n",
+ "Epoch 102/150\n",
+ "\u001b[1m45/45\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 368ms/step - loss: 0.1700 - sparse_categorical_accuracy: 0.9460 - val_loss: 0.3525 - val_sparse_categorical_accuracy: 0.8447\n",
+ "\u001b[1m42/42\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 103ms/step - loss: 0.3497 - sparse_categorical_accuracy: 0.8483\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[0.36210155487060547, 0.8446969985961914]"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\n",
+ "\n",
+ "callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]\n",
+ "\n",
+ "model.fit(\n",
+ " x_train,\n",
+ " y_train,\n",
+ " validation_split=0.2,\n",
+ " epochs=150,\n",
+ " batch_size=64,\n",
+ " callbacks=callbacks,\n",
+ ")\n",
+ "\n",
+ "model.evaluate(x_test, y_test, verbose=1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "B7nRTbzn4w3M"
+ },
+ "source": [
+ "## Conclusions\n",
+ "\n",
+ "In about 100-102 epochs (25s each on Colab), the model reaches a training\n",
+ "accuracy of ~0.94, validation accuracy of ~84 and a testing\n",
+ "accuracy of ~85, without hyperparameter tuning. And that is for a model\n",
+ "with less than 100k parameters. Of course, parameter count and accuracy could be\n",
+ "improved by a hyperparameter search and a more sophisticated learning rate\n",
+ "schedule, or a different optimizer."
+ ]
+ },
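+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# A minimal sketch, not part of the original run: the conclusions above\n",
+ "# mention that a more sophisticated learning rate schedule could improve\n",
+ "# results. One low-effort option is to halve the learning rate whenever\n",
+ "# the validation loss plateaus, alongside the existing EarlyStopping.\n",
+ "callbacks = [\n",
+ " keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),\n",
+ " keras.callbacks.ReduceLROnPlateau(\n",
+ " monitor=\"val_loss\", factor=0.5, patience=5, min_lr=1e-6\n",
+ " ),\n",
+ "]"
+ ]
+ },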
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "GZeFCnD6YW1l"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
},
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "## Introduction\n",
- "\n",
- "This is the Transformer architecture from\n",
- "[Attention Is All You Need](https://arxiv.org/abs/1706.03762),\n",
- "applied to timeseries instead of natural language.\n",
- "\n",
- "This example requires TensorFlow 2.4 or higher.\n",
- "\n",
- "## Load the dataset\n",
- "\n",
- "We are going to use the same dataset and preprocessing as the\n",
- "[TimeSeries Classification from Scratch](https://keras.io/examples/timeseries/timeseries_classification_from_scratch)\n",
- "example."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import keras\n",
- "from keras import layers\n",
- "\n",
- "\n",
- "def readucr(filename):\n",
- " data = np.loadtxt(filename, delimiter=\"\\t\")\n",
- " y = data[:, 0]\n",
- " x = data[:, 1:]\n",
- " return x, y.astype(int)\n",
- "\n",
- "\n",
- "root_url = \"https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/\"\n",
- "\n",
- "x_train, y_train = readucr(root_url + \"FordA_TRAIN.tsv\")\n",
- "x_test, y_test = readucr(root_url + \"FordA_TEST.tsv\")\n",
- "\n",
- "x_train = x_train.reshape((x_train.shape[0], x_train.shape[1], 1))\n",
- "x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], 1))\n",
- "\n",
- "n_classes = len(np.unique(y_train))\n",
- "\n",
- "idx = np.random.permutation(len(x_train))\n",
- "x_train = x_train[idx]\n",
- "y_train = y_train[idx]\n",
- "\n",
- "y_train[y_train == -1] = 0\n",
- "y_test[y_test == -1] = 0"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "## Build the model\n",
- "\n",
- "Our model processes a tensor of shape `(batch size, sequence length, features)`,\n",
- "where `sequence length` is the number of time steps and `features` is each input\n",
- "timeseries.\n",
- "\n",
- "You can replace your classification RNN layers with this one: the\n",
- "inputs are fully compatible!\n",
- "\n",
- "We include residual connections, layer normalization, and dropout.\n",
- "The resulting layer can be stacked multiple times.\n",
- "\n",
- "The projection layers are implemented through `keras.layers.Conv1D`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "\n",
- "def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):\n",
- " # Attention and Normalization\n",
- " x = layers.MultiHeadAttention(\n",
- " key_dim=head_size, num_heads=num_heads, dropout=dropout\n",
- " )(inputs, inputs)\n",
- " x = layers.Dropout(dropout)(x)\n",
- " x = layers.LayerNormalization(epsilon=1e-6)(x)\n",
- " res = x + inputs\n",
- "\n",
- " # Feed Forward Part\n",
- " x = layers.Conv1D(filters=ff_dim, kernel_size=1, activation=\"relu\")(res)\n",
- " x = layers.Dropout(dropout)(x)\n",
- " x = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)\n",
- " x = layers.LayerNormalization(epsilon=1e-6)(x)\n",
- " return x + res\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "The main part of our model is now complete. We can stack multiple of those\n",
- "`transformer_encoder` blocks and we can also proceed to add the final\n",
- "Multi-Layer Perceptron classification head. Apart from a stack of `Dense`\n",
- "layers, we need to reduce the output tensor of the `TransformerEncoder` part of\n",
- "our model down to a vector of features for each data point in the current\n",
- "batch. A common way to achieve this is to use a pooling layer. For\n",
- "this example, a `GlobalAveragePooling1D` layer is sufficient."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "\n",
- "def build_model(\n",
- " input_shape,\n",
- " head_size,\n",
- " num_heads,\n",
- " ff_dim,\n",
- " num_transformer_blocks,\n",
- " mlp_units,\n",
- " dropout=0,\n",
- " mlp_dropout=0,\n",
- "):\n",
- " inputs = keras.Input(shape=input_shape)\n",
- " x = inputs\n",
- " for _ in range(num_transformer_blocks):\n",
- " x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)\n",
- "\n",
- " x = layers.GlobalAveragePooling1D(data_format=\"channels_last\")(x)\n",
- " for dim in mlp_units:\n",
- " x = layers.Dense(dim, activation=\"relu\")(x)\n",
- " x = layers.Dropout(mlp_dropout)(x)\n",
- " outputs = layers.Dense(n_classes, activation=\"softmax\")(x)\n",
- " return keras.Model(inputs, outputs)\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "## Train and evaluate"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "input_shape = x_train.shape[1:]\n",
- "\n",
- "model = build_model(\n",
- " input_shape,\n",
- " head_size=256,\n",
- " num_heads=4,\n",
- " ff_dim=4,\n",
- " num_transformer_blocks=4,\n",
- " mlp_units=[128],\n",
- " mlp_dropout=0.4,\n",
- " dropout=0.25,\n",
- ")\n",
- "\n",
- "model.compile(\n",
- " loss=\"sparse_categorical_crossentropy\",\n",
- " optimizer=keras.optimizers.Adam(learning_rate=1e-4),\n",
- " metrics=[\"sparse_categorical_accuracy\"],\n",
- ")\n",
- "model.summary()\n",
- "\n",
- "callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]\n",
- "\n",
- "model.fit(\n",
- " x_train,\n",
- " y_train,\n",
- " validation_split=0.2,\n",
- " epochs=150,\n",
- " batch_size=64,\n",
- " callbacks=callbacks,\n",
- ")\n",
- "\n",
- "model.evaluate(x_test, y_test, verbose=1)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "## Conclusions\n",
- "\n",
- "In about 110-120 epochs (25s each on Colab), the model reaches a training\n",
- "accuracy of ~0.95, validation accuracy of ~84 and a testing\n",
- "accuracy of ~85, without hyperparameter tuning. And that is for a model\n",
- "with less than 100k parameters. Of course, parameter count and accuracy could be\n",
- "improved by a hyperparameter search and a more sophisticated learning rate\n",
- "schedule, or a different optimizer."
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "collapsed_sections": [],
- "name": "timeseries_classification_transformer",
- "private_outputs": false,
- "provenance": [],
- "toc_visible": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
+ "nbformat": 4,
+ "nbformat_minor": 0
}
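
Side note on the model summary above: the 7,169 parameters reported for each
MultiHeadAttention layer can be verified by hand. A minimal sketch, assuming
Keras' default separate query/key/value projections plus one output
projection, all with biases (input feature dim 1, key_dim=256, num_heads=4):

    # Q, K and V each project dim 1 -> num_heads * key_dim = 1024, plus bias.
    d_model, key_dim, num_heads = 1, 256, 4
    qkv = 3 * (d_model * key_dim * num_heads + key_dim * num_heads)  # 3 * 2048
    # The output projection maps the 1024 concatenated head features back to
    # d_model = 1, again with a bias term.
    out = key_dim * num_heads * d_model + d_model  # 1025
    print(qkv + out)  # 7169, matching the summary table

Likewise, because GlobalAveragePooling1D with data_format="channels_first"
collapses (None, 500, 1) to a 500-dimensional vector, the first Dense layer
has 500 * 128 + 128 = 64,128 parameters, as reported.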
diff --git a/examples/timeseries/timeseries_classification_transformer.py b/examples/timeseries/timeseries_classification_transformer.py
index fde0aa5f73..b9d95bd242 100644
--- a/examples/timeseries/timeseries_classification_transformer.py
+++ b/examples/timeseries/timeseries_classification_transformer.py
@@ -1,10 +1,19 @@
"""
-Title: Timeseries classification with a Transformer model
-Author: [Theodoros Ntakouris](https://github.com/ntakouris)
-Date created: 2021/06/25
-Last modified: 2021/08/05
-Description: This notebook demonstrates how to do timeseries classification using a Transformer model.
-Accelerator: GPU
+Title: FILLME
+Author: FILLME
+Date created: FILLME
+Last modified: FILLME
+Description: FILLME
+"""
+
+"""
+# Timeseries classification with a Transformer model
+
+**Author:** [Theodoros Ntakouris](https://github.com/ntakouris)
+**Date created:** 2021/06/25
+**Last modified:** 2024/12/18
+**Description:** This notebook demonstrates how to do timeseries classification using a
+Transformer model.
"""
"""
@@ -19,7 +28,8 @@
## Load the dataset
We are going to use the same dataset and preprocessing as the
-[TimeSeries Classification from Scratch](https://keras.io/examples/timeseries/timeseries_classification_from_scratch)
+[TimeSeries Classification from
+Scratch](https://keras.io/examples/timeseries/timeseries_classification_from_scratch)
example.
"""
@@ -111,8 +121,9 @@ def build_model(
x = inputs
for _ in range(num_transformer_blocks):
x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)
-
- x = layers.GlobalAveragePooling1D(data_format="channels_last")(x)
+ print(f"Transformer Encoder: {x}")
+ x = layers.GlobalAveragePooling1D(data_format="channels_first")(x)
+ print(f"Global Average Pooling: {x}")
for dim in mlp_units:
x = layers.Dense(dim, activation="relu")(x)
x = layers.Dropout(mlp_dropout)(x)
@@ -144,6 +155,7 @@ def build_model(
)
model.summary()
+
callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]
model.fit(
@@ -160,11 +172,10 @@ def build_model(
"""
## Conclusions
-In about 110-120 epochs (25s each on Colab), the model reaches a training
-accuracy of ~0.95, validation accuracy of ~84 and a testing
+In about 100-102 epochs (25s each on Colab), the model reaches a training
+accuracy of ~0.94, validation accuracy of ~0.84, and a test
accuracy of ~0.85, without hyperparameter tuning. And that is for a model
with less than 100k parameters. Of course, parameter count and accuracy could be
improved by a hyperparameter search and a more sophisticated learning rate
schedule, or a different optimizer.
-
"""