
Add multi-backend examples/generative/text_generation_with_miniature_gpt
mattdangerw committed Nov 9, 2023
1 parent a7787f6 commit f9aebfa
Showing 3 changed files with 204 additions and 146 deletions.
82 changes: 51 additions & 31 deletions examples/generative/ipynb/text_generation_with_miniature_gpt.ipynb
@@ -58,14 +58,30 @@
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"from tensorflow.keras.layers import TextVectorization\n",
"# We set the backend to TensorFlow. The code works with\n",
"# both `tensorflow` and `torch`. It does not work with JAX\n",
"# due to the behavior of `jax.numpy.tile` in a jit scope\n",
"# (used in `causal_attention_mask()`: `tile` in JAX does\n",
"# not support a dynamic `reps` argument.\n",
"# You can make the code work in JAX by wrapping the\n",
"# inside of the `causal_attention_mask` function in\n",
"# a decorator to prevent jit compilation:\n",
"# `with jax.ensure_compile_time_eval():`.\n",
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"\n",
"import keras\n",
"from keras import layers\n",
"from keras import ops\n",
"from keras.layers import TextVectorization\n",
"import numpy as np\n",
"import os\n",
"import string\n",
"import random\n",
"import tensorflow\n",
"import tensorflow.data as tf_data\n",
"import tensorflow.strings as tf_strings\n",
""
]
},
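For reference, a minimal sketch of the JAX workaround the comment above describes (untested; it assumes the `keras.ops` API imported in this cell, the JAX backend, and a hypothetical variant function name):

import jax

def causal_attention_mask_jax(batch_size, n_dest, n_src, dtype):
    # Hypothetical variant: evaluate the mask eagerly at trace time so that
    # `ops.tile` never receives a traced `reps` argument inside a jit scope.
    with jax.ensure_compile_time_eval():
        i = ops.arange(n_dest)[:, None]
        j = ops.arange(n_src)
        mask = ops.cast(i >= j - n_src + n_dest, dtype)
        mask = ops.reshape(mask, [1, n_dest, n_src])
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 0
        )
        return ops.tile(mask, mult)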
@@ -93,34 +109,37 @@
" This prevents flow of information from future tokens to current token.\n",
" 1's in the lower triangle, counting from the lower right corner.\n",
" \"\"\"\n",
" i = tf.range(n_dest)[:, None]\n",
" j = tf.range(n_src)\n",
" i = ops.arange(n_dest)[:, None]\n",
" j = ops.arange(n_src)\n",
" m = i >= j - n_src + n_dest\n",
" mask = tf.cast(m, dtype)\n",
" mask = tf.reshape(mask, [1, n_dest, n_src])\n",
" mult = tf.concat(\n",
" [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0\n",
" mask = ops.cast(m, dtype)\n",
" mask = ops.reshape(mask, [1, n_dest, n_src])\n",
" mult = ops.concatenate(\n",
" [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 0\n",
" )\n",
" return tf.tile(mask, mult)\n",
" return ops.tile(mask, mult)\n",
"\n",
"\n",
"class TransformerBlock(layers.Layer):\n",
" def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):\n",
" super().__init__()\n",
" self.att = layers.MultiHeadAttention(num_heads, embed_dim)\n",
" self.ffn = keras.Sequential(\n",
" [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n",
" [\n",
" layers.Dense(ff_dim, activation=\"relu\"),\n",
" layers.Dense(embed_dim),\n",
" ]\n",
" )\n",
" self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n",
" self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n",
" self.dropout1 = layers.Dropout(rate)\n",
" self.dropout2 = layers.Dropout(rate)\n",
"\n",
" def call(self, inputs):\n",
" input_shape = tf.shape(inputs)\n",
" input_shape = ops.shape(inputs)\n",
" batch_size = input_shape[0]\n",
" seq_len = input_shape[1]\n",
" causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)\n",
" causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, \"bool\")\n",
" attention_output = self.att(inputs, inputs, attention_mask=causal_mask)\n",
" attention_output = self.dropout1(attention_output)\n",
" out1 = self.layernorm1(inputs + attention_output)\n",
@@ -158,8 +177,8 @@
" self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)\n",
"\n",
" def call(self, x):\n",
" maxlen = tf.shape(x)[-1]\n",
" positions = tf.range(start=0, limit=maxlen, delta=1)\n",
" maxlen = ops.shape(x)[-1]\n",
" positions = ops.arange(0, maxlen, 1)\n",
" positions = self.pos_emb(positions)\n",
" x = self.token_emb(x)\n",
" return x + positions\n",
@@ -191,16 +210,17 @@
"\n",
"\n",
"def create_model():\n",
" inputs = layers.Input(shape=(maxlen,), dtype=tf.int32)\n",
" inputs = layers.Input(shape=(maxlen,), dtype=\"int32\")\n",
" embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)\n",
" x = embedding_layer(inputs)\n",
" transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)\n",
" x = transformer_block(x)\n",
" outputs = layers.Dense(vocab_size)(x)\n",
" model = keras.Model(inputs=inputs, outputs=[outputs, x])\n",
" loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
" loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
" model.compile(\n",
" \"adam\", loss=[loss_fn, None],\n",
" \"adam\",\n",
" loss=[loss_fn, None],\n",
" ) # No loss and optimization based on word embeddings from transformer block\n",
" return model\n",
""
@@ -259,16 +279,16 @@
"\n",
"# Create a dataset from text files\n",
"random.shuffle(filenames)\n",
"text_ds = tf.data.TextLineDataset(filenames)\n",
"text_ds = tf_data.TextLineDataset(filenames)\n",
"text_ds = text_ds.shuffle(buffer_size=256)\n",
"text_ds = text_ds.batch(batch_size)\n",
"\n",
"\n",
"def custom_standardization(input_string):\n",
" \"\"\" Remove html line-break tags and handle punctuation \"\"\"\n",
" lowercased = tf.strings.lower(input_string)\n",
" stripped_html = tf.strings.regex_replace(lowercased, \"<br />\", \" \")\n",
" return tf.strings.regex_replace(stripped_html, f\"([{string.punctuation}])\", r\" \\1\")\n",
" \"\"\"Remove html line-break tags and handle punctuation\"\"\"\n",
" lowercased = tf_strings.lower(input_string)\n",
" stripped_html = tf_strings.regex_replace(lowercased, \"<br />\", \" \")\n",
" return tf_strings.regex_replace(stripped_html, f\"([{string.punctuation}])\", r\" \\1\")\n",
"\n",
"\n",
"# Create a vectorization layer and adapt it to the text\n",
@@ -288,15 +308,15 @@
" word at position (i+1). The model will use all words up till position (i)\n",
" to predict the next word.\n",
" \"\"\"\n",
" text = tf.expand_dims(text, -1)\n",
" text = tensorflow.expand_dims(text, -1)\n",
" tokenized_sentences = vectorize_layer(text)\n",
" x = tokenized_sentences[:, :-1]\n",
" y = tokenized_sentences[:, 1:]\n",
" return x, y\n",
"\n",
"\n",
"text_ds = text_ds.map(prepare_lm_inputs_labels)\n",
"text_ds = text_ds.prefetch(tf.data.AUTOTUNE)\n",
"text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE)\n",
"text_ds = text_ds.prefetch(tf_data.AUTOTUNE)\n",
""
]
},
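The input/label shift is easiest to see on a toy tokenized sequence (values are illustrative):

tokens = [12, 7, 3, 9, 2]   # hypothetical token ids for one review
x = tokens[:-1]             # [12, 7, 3, 9] -> what the model reads
y = tokens[1:]              # [7, 3, 9, 2]  -> the next word to predict at each position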
@@ -342,9 +362,9 @@
" self.k = top_k\n",
"\n",
" def sample_from(self, logits):\n",
" logits, indices = tf.math.top_k(logits, k=self.k, sorted=True)\n",
" logits, indices = ops.top_k(logits, k=self.k, sorted=True)\n",
" indices = np.asarray(indices).astype(\"int32\")\n",
" preds = keras.activations.softmax(tf.expand_dims(logits, 0))[0]\n",
" preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0]\n",
" preds = np.asarray(preds).astype(\"float32\")\n",
" return np.random.choice(indices, p=preds)\n",
"\n",
@@ -368,7 +388,7 @@
" else:\n",
" x = start_tokens\n",
" x = np.array([x])\n",
" y, _ = self.model.predict(x)\n",
" y, _ = self.model.predict(x, verbose=0)\n",
" sample_token = self.sample_from(y[0][sample_index])\n",
" tokens_generated.append(sample_token)\n",
" start_tokens.append(sample_token)\n",
@@ -445,4 +465,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}