Update neural machine translation with transformer example for Keras 3.
hertschuh committed Nov 9, 2023
1 parent a7787f6 · commit f9be98e
Showing 3 changed files with 282 additions and 99 deletions.
@@ -10,8 +10,8 @@
"\n",
"**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
"**Date created:** 2021/05/26<br>\n",
"**Last modified:** 2023/08/17<br>\n",
"**Description:** Implementing a sequence-to-sequene Transformer and training it on a machine translation task."
"**Last modified:** 2023/02/25<br>\n",
"**Description:** Implementing a sequence-to-sequence Transformer and training it on a machine translation task."
]
},
{
@@ -59,15 +59,32 @@
},
"outputs": [],
"source": [
"# We set the backend to TensorFlow. The code works with\n",
"# both `tensorflow` and `torch`. It does not work with JAX\n",
"# due to the behavior of `jax.numpy.tile` in a jit scope\n",
"# (used in `TransformerDecoder.get_causal_attention_mask()`:\n",
"# `tile` in JAX does not support a dynamic `reps` argument.\n",
"# You can make the code work in JAX by wrapping the\n",
"# inside of the `get_causal_attention_mask` method in\n",
"# a decorator to prevent jit compilation:\n",
"# `with jax.ensure_compile_time_eval():`.\n",
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"\n",
"import pathlib\n",
"import random\n",
"import string\n",
"import re\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"from tensorflow.keras.layers import TextVectorization"
"\n",
"import tensorflow.data as tf_data\n",
"import tensorflow.strings as tf_strings\n",
"\n",
"import keras\n",
"from keras import layers\n",
"from keras import ops\n",
"from keras.layers import TextVectorization\n"
]
},
{
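The comment at the top of this cell describes the JAX limitation and a possible workaround. Purely as an illustration (not part of this commit), here is a sketch of where `jax.ensure_compile_time_eval()` would go, reusing the `ops.tile`-based mask construction added to `TransformerDecoder` further down in this diff; the function name is hypothetical and the sketch is untested:

    import jax
    from keras import ops

    def causal_attention_mask_jax_friendly(inputs):
        # Same mask as `TransformerDecoder.get_causal_attention_mask()` below,
        # but the shape-dependent steps run eagerly so that `ops.tile` sees a
        # concrete `reps` argument instead of a traced value.
        input_shape = ops.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        with jax.ensure_compile_time_eval():
            i = ops.arange(sequence_length)[:, None]
            j = ops.arange(sequence_length)
            mask = ops.cast(i >= j, dtype="int32")
            mask = ops.reshape(mask, (1, sequence_length, sequence_length))
            mult = ops.concatenate(
                [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
                axis=0,
            )
            return ops.tile(mask, mult)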
@@ -222,8 +239,8 @@
"\n",
"\n",
"def custom_standardization(input_string):\n",
" lowercase = tf.strings.lower(input_string)\n",
" return tf.strings.regex_replace(lowercase, \"[%s]\" % re.escape(strip_chars), \"\")\n",
" lowercase = tf_strings.lower(input_string)\n",
" return tf_strings.regex_replace(lowercase, \"[%s]\" % re.escape(strip_chars), \"\")\n",
"\n",
"\n",
"eng_vectorization = TextVectorization(\n",
@@ -257,7 +274,7 @@
"As such, the training dataset will yield a tuple `(inputs, targets)`, where:\n",
"\n",
"- `inputs` is a dictionary with the keys `encoder_inputs` and `decoder_inputs`.\n",
"`encoder_inputs` is the vectorized source sentence and `decoder_inputs` is the target sentence \"so far\",\n",
"`encoder_inputs` is the vectorized source sentence and `encoder_inputs` is the target sentence \"so far\",\n",
"that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.\n",
"- `target` is the target sentence offset by one step:\n",
"it provides the next words in the target sentence -- what the model will try to predict."
@@ -288,10 +305,10 @@
" eng_texts, spa_texts = zip(*pairs)\n",
" eng_texts = list(eng_texts)\n",
" spa_texts = list(spa_texts)\n",
" dataset = tf.data.Dataset.from_tensor_slices((eng_texts, spa_texts))\n",
" dataset = tf_data.Dataset.from_tensor_slices((eng_texts, spa_texts))\n",
" dataset = dataset.batch(batch_size)\n",
" dataset = dataset.map(format_dataset)\n",
" return dataset.shuffle(2048).prefetch(16).cache()\n",
" return dataset.cache().shuffle(2048).prefetch(16)\n",
"\n",
"\n",
"train_ds = make_dataset(train_pairs)\n",
@@ -341,7 +358,7 @@
"The `TransformerDecoder` will then seek to predict the next words in the target sequence (N+1 and beyond).\n",
"\n",
"A key detail that makes this possible is causal masking\n",
"(`use_causal_mask=True` in the first attention layer of the `TransformerDecoder`).\n",
"(see method `get_causal_attention_mask()` on the `TransformerDecoder`).\n",
"The `TransformerDecoder` sees the entire sequences at once, and thus we must make\n",
"sure that it only uses information from target tokens 0 to N when predicting token N+1\n",
"(otherwise, it could use information from the future, which would\n",
@@ -356,6 +373,8 @@
},
"outputs": [],
"source": [
"import keras.ops as ops\n",
"\n",
"\n",
"class TransformerEncoder(layers.Layer):\n",
" def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):\n",
@@ -377,7 +396,14 @@
" self.supports_masking = True\n",
"\n",
" def call(self, inputs, mask=None):\n",
" attention_output = self.attention(query=inputs, value=inputs, key=inputs)\n",
" if mask is not None:\n",
" padding_mask = ops.cast(mask[:, None, :], dtype=\"int32\")\n",
" else:\n",
" padding_mask = None\n",
"\n",
" attention_output = self.attention(\n",
" query=inputs, value=inputs, key=inputs, attention_mask=padding_mask\n",
" )\n",
" proj_input = self.layernorm_1(inputs + attention_output)\n",
" proj_output = self.dense_proj(proj_input)\n",
" return self.layernorm_2(proj_input + proj_output)\n",
@@ -408,14 +434,17 @@
" self.embed_dim = embed_dim\n",
"\n",
" def call(self, inputs):\n",
" length = tf.shape(inputs)[-1]\n",
" positions = tf.range(start=0, limit=length, delta=1)\n",
" length = ops.shape(inputs)[-1]\n",
" positions = ops.arange(0, length, 1)\n",
" embedded_tokens = self.token_embeddings(inputs)\n",
" embedded_positions = self.position_embeddings(positions)\n",
" return embedded_tokens + embedded_positions\n",
"\n",
" def compute_mask(self, inputs, mask=None):\n",
" return tf.math.not_equal(inputs, 0)\n",
" if mask is None:\n",
" return None\n",
" else:\n",
" return ops.not_equal(inputs, 0)\n",
"\n",
" def get_config(self):\n",
" config = super().get_config()\n",
@@ -450,24 +479,44 @@
" self.layernorm_1 = layers.LayerNormalization()\n",
" self.layernorm_2 = layers.LayerNormalization()\n",
" self.layernorm_3 = layers.LayerNormalization()\n",
" self.add = layers.Add() # instead of `+` to preserve mask\n",
" self.supports_masking = True\n",
"\n",
" def call(self, inputs, encoder_outputs, mask=None):\n",
" causal_mask = self.get_causal_attention_mask(inputs)\n",
" if mask is not None:\n",
" padding_mask = ops.cast(mask[:, None, :], dtype=\"int32\")\n",
" padding_mask = ops.minimum(padding_mask, causal_mask)\n",
" else:\n",
" padding_mask = None\n",
"\n",
" attention_output_1 = self.attention_1(\n",
" query=inputs, value=inputs, key=inputs, use_causal_mask=True\n",
" query=inputs, value=inputs, key=inputs, attention_mask=causal_mask\n",
" )\n",
" out_1 = self.layernorm_1(self.add([inputs, attention_output_1]))\n",
" out_1 = self.layernorm_1(inputs + attention_output_1)\n",
"\n",
" attention_output_2 = self.attention_2(\n",
" query=out_1,\n",
" value=encoder_outputs,\n",
" key=encoder_outputs,\n",
" attention_mask=padding_mask,\n",
" )\n",
" out_2 = self.layernorm_2(self.add([out_1, attention_output_2]))\n",
" out_2 = self.layernorm_2(out_1 + attention_output_2)\n",
"\n",
" proj_output = self.dense_proj(out_2)\n",
" return self.layernorm_3(self.add([out_2, proj_output]))\n",
" return self.layernorm_3(out_2 + proj_output)\n",
"\n",
" def get_causal_attention_mask(self, inputs):\n",
" input_shape = ops.shape(inputs)\n",
" batch_size, sequence_length = input_shape[0], input_shape[1]\n",
" i = ops.arange(sequence_length)[:, None]\n",
" j = ops.arange(sequence_length)\n",
" mask = ops.cast(i >= j, dtype=\"int32\")\n",
" mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))\n",
" mult = ops.concatenate(\n",
" [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],\n",
" axis=0,\n",
" )\n",
" return ops.tile(mask, mult)\n",
"\n",
" def get_config(self):\n",
" config = super().get_config()\n",
@@ -588,7 +637,10 @@
" tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]\n",
" predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])\n",
"\n",
" sampled_token_index = np.argmax(predictions[0, i, :])\n",
" # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here\n",
" sampled_token_index = ops.convert_to_numpy(\n",
" ops.argmax(predictions[0, i, :])\n",
" ).item(0)\n",
" sampled_token = spa_index_lookup[sampled_token_index]\n",
" decoded_sentence += \" \" + sampled_token\n",
"\n",
@@ -660,4 +712,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}