diff --git a/examples/generative/img/molecule_generation/molecule_generation_21_18.png b/examples/generative/img/molecule_generation/molecule_generation_21_18.png
new file mode 100644
index 0000000000..768d4cea86
Binary files /dev/null and b/examples/generative/img/molecule_generation/molecule_generation_21_18.png differ
diff --git a/examples/generative/img/molecule_generation/molecule_generation_23_39.png b/examples/generative/img/molecule_generation/molecule_generation_23_39.png
new file mode 100644
index 0000000000..fa5efcc123
Binary files /dev/null and b/examples/generative/img/molecule_generation/molecule_generation_23_39.png differ
diff --git a/examples/generative/ipynb/molecule_generation.ipynb b/examples/generative/ipynb/molecule_generation.ipynb
index 4ab407a6c1..94c9722177 100644
--- a/examples/generative/ipynb/molecule_generation.ipynb
+++ b/examples/generative/ipynb/molecule_generation.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Victor Basu](https://www.linkedin.com/in/victor-basu-520958147)\n",
"**Date created:** 2022/03/10\n",
- "**Last modified:** 2022/03/24\n",
+ "**Last modified:** 2024/12/17\n",
"**Description:** Implementing a Convolutional Variational AutoEncoder (VAE) for Drug Discovery."
]
},
@@ -85,7 +85,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -96,20 +96,25 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
+ "import os\n",
+ "\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
+ "\n",
"import ast\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"import tensorflow as tf\n",
- "from tensorflow import keras\n",
- "from tensorflow.keras import layers\n",
+ "import keras\n",
+ "from keras import layers\n",
+ "from keras import ops\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from rdkit import Chem, RDLogger\n",
@@ -127,27 +132,27 @@
"source": [
"## Dataset\n",
"\n",
- "We use the [**ZINC – A Free Database of Commercially Available Compounds for\n",
+ "We use the [**ZINC \u2013 A Free Database of Commercially Available Compounds for\n",
"Virtual Screening**](https://bit.ly/3IVBI4x) dataset. The dataset comes with molecule\n",
"formulas in SMILES representation along with their respective molecular properties such as\n",
- "**logP** (water–octanal partition coefficient), **SAS** (synthetic\n",
+ "**logP** (water\u2013octanol partition coefficient), **SAS** (synthetic\n",
"accessibility score) and **QED** (Quantitative Estimate of Drug-likeness)."
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"csv_path = keras.utils.get_file(\n",
- " \"/content/250k_rndm_zinc_drugs_clean_3.csv\",\n",
+ " \"250k_rndm_zinc_drugs_clean_3.csv\",\n",
" \"https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv\",\n",
")\n",
"\n",
- "df = pd.read_csv(\"/content/250k_rndm_zinc_drugs_clean_3.csv\")\n",
+ "df = pd.read_csv(csv_path)\n",
"df[\"smiles\"] = df[\"smiles\"].apply(lambda s: s.replace(\"\\n\", \"\"))\n",
"df.head()"
]
@@ -163,7 +168,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -247,7 +252,7 @@
" # Add bonds between atoms in molecule; based on the upper triangles\n",
" # of the [symmetric] adjacency tensor\n",
" (bonds_ij, atoms_i, atoms_j) = np.where(np.triu(adjacency) == 1)\n",
- " for (bond_ij, atom_i, atom_j) in zip(bonds_ij, atoms_i, atoms_j):\n",
+ " for bond_ij, atom_i, atom_j in zip(bonds_ij, atoms_i, atoms_j):\n",
" if atom_i == atom_j or bond_ij == BOND_DIM - 1:\n",
" continue\n",
" bond_type = bond_mapping[bond_ij]\n",
@@ -260,7 +265,8 @@
" if flag != Chem.SanitizeFlags.SANITIZE_NONE:\n",
" return None\n",
"\n",
- " return molecule\n"
+ " return molecule\n",
+ ""
]
},
{
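> Reviewer note: the `np.triu` trick in `graph_to_molecule` above enumerates each undirected bond exactly once. A minimal self-contained sketch with toy sizes (illustrative values only, not code from the example):

```python
import numpy as np

# Toy adjacency tensor: (bond_types, atoms, atoms), symmetric per bond-type slice.
BOND_DIM, NUM_ATOMS = 5, 3
adjacency = np.zeros((BOND_DIM, NUM_ATOMS, NUM_ATOMS), dtype="float32")
adjacency[0, 0, 1] = adjacency[0, 1, 0] = 1.0  # single bond between atoms 0 and 1
adjacency[1, 1, 2] = adjacency[1, 2, 1] = 1.0  # double bond between atoms 1 and 2

# np.triu zeroes the lower triangle of the last two axes, so each undirected
# bond yields exactly one (bond_type, atom_i, atom_j) triple.
bonds_ij, atoms_i, atoms_j = np.where(np.triu(adjacency) == 1)
print(list(zip(bonds_ij, atoms_i, atoms_j)))  # [(0, 0, 1), (1, 1, 2)]
```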
@@ -274,7 +280,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -328,7 +334,7 @@
" regularizer=self.kernel_regularizer,\n",
" trainable=True,\n",
" name=\"W\",\n",
- " dtype=tf.float32,\n",
+ " dtype=\"float32\",\n",
" )\n",
"\n",
" if self.use_bias:\n",
@@ -338,7 +344,7 @@
" regularizer=self.bias_regularizer,\n",
" trainable=True,\n",
" name=\"b\",\n",
- " dtype=tf.float32,\n",
+ " dtype=\"float32\",\n",
" )\n",
"\n",
" self.built = True\n",
@@ -346,15 +352,16 @@
" def call(self, inputs, training=False):\n",
" adjacency, features = inputs\n",
" # Aggregate information from neighbors\n",
- " x = tf.matmul(adjacency, features[:, None, :, :])\n",
+ " x = ops.matmul(adjacency, features[:, None])\n",
" # Apply linear transformation\n",
- " x = tf.matmul(x, self.kernel)\n",
+ " x = ops.matmul(x, self.kernel)\n",
" if self.use_bias:\n",
" x += self.bias\n",
" # Reduce bond types dim\n",
- " x_reduced = tf.reduce_sum(x, axis=1)\n",
+ " x_reduced = ops.sum(x, axis=1)\n",
" # Apply non-linear transformation\n",
- " return self.activation(x_reduced)\n"
+ " return self.activation(x_reduced)\n",
+ ""
]
},
{
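> Reviewer note on the `features[:, None]` change above: adjacency is `(batch, bond_types, atoms, atoms)` while features is `(batch, atoms, atom_dim)`, so inserting a singleton bond-type axis lets `ops.matmul` broadcast the aggregation across bond types. A hedged sketch with illustrative shapes:

```python
import numpy as np
from keras import ops

# Illustrative sizes: batch=2, bond_types=3, atoms=4, atom_dim=5 (assumed layout).
adjacency = ops.convert_to_tensor(np.random.rand(2, 3, 4, 4).astype("float32"))
features = ops.convert_to_tensor(np.random.rand(2, 4, 5).astype("float32"))

x = ops.matmul(adjacency, features[:, None])  # features[:, None]: (2, 1, 4, 5)
print(ops.shape(x))  # (2, 3, 4, 5): neighbour aggregation per bond type
```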
@@ -374,9 +381,9 @@
"non-linearly transformed neighbourhood aggregations. We can define these layers as\n",
"follows:\n",
"\n",
- "`H_hat**(l+1) = σ(D_hat**(-1) * A_hat * H_hat**(l+1) * W**(l))`\n",
+ "`H_hat**(l+1) = \u03c3(D_hat**(-1) * A_hat * H_hat**(l) * W**(l))`\n",
"\n",
- "Where `σ` denotes the non-linear transformation (commonly a ReLU activation), `A` the\n",
+ "Where `\u03c3` denotes the non-linear transformation (commonly a ReLU activation), `A` the\n",
"adjacency tensor, `H_hat**(l)` the feature tensor at the `l-th` layer, `D_hat**(-1)` the\n",
"inverse diagonal degree tensor of `A_hat`, and `W**(l)` the trainable weight tensor\n",
"at the `l-th` layer. Specifically, for each bond type (relation), the degree tensor\n",
@@ -391,7 +398,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"colab_type": "code"
},
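> To make the graph-convolution formula above concrete, here is a hedged NumPy sketch of one relational step with illustrative shapes. Note it applies the `D_hat**(-1)` normalization from the formula, whereas the `RelationalGraphConvLayer` above aggregates without an explicit degree term:

```python
import numpy as np

# Illustrative sizes, not the example's constants.
num_bond_types, num_atoms, atom_dim, out_dim = 4, 5, 6, 8
A_hat = np.random.rand(num_bond_types, num_atoms, num_atoms)  # adjacency per bond type
H = np.random.rand(num_atoms, atom_dim)                       # H_hat**(l)
W = np.random.rand(num_bond_types, atom_dim, out_dim)         # W**(l)

# D_hat**(-1): inverse degree per bond type (guarded against isolated atoms).
D_inv = 1.0 / np.maximum(A_hat.sum(axis=-1, keepdims=True), 1e-9)

x = (D_inv * A_hat) @ H                  # aggregate neighbours, per relation
x = x @ W                                # linear transformation, per relation
H_next = np.maximum(x.sum(axis=0), 0.0)  # sum over bond types, then ReLU (sigma)
print(H_next.shape)  # (5, 8)
```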
@@ -401,8 +408,8 @@
"def get_encoder(\n",
" gconv_units, latent_dim, adjacency_shape, feature_shape, dense_units, dropout_rate\n",
"):\n",
- " adjacency = keras.layers.Input(shape=adjacency_shape)\n",
- " features = keras.layers.Input(shape=feature_shape)\n",
+ " adjacency = layers.Input(shape=adjacency_shape)\n",
+ " features = layers.Input(shape=feature_shape)\n",
"\n",
" # Propagate through one or more graph convolutional layers\n",
" features_transformed = features\n",
@@ -411,7 +418,7 @@
" [adjacency, features_transformed]\n",
" )\n",
" # Reduce 2-D representation of molecule to 1-D\n",
- " x = keras.layers.GlobalAveragePooling1D()(features_transformed)\n",
+ " x = layers.GlobalAveragePooling1D()(features_transformed)\n",
"\n",
" # Propagate through one or more densely connected layers\n",
" for units in dense_units:\n",
@@ -431,26 +438,27 @@
"\n",
" x = latent_inputs\n",
" for units in dense_units:\n",
- " x = keras.layers.Dense(units, activation=\"tanh\")(x)\n",
- " x = keras.layers.Dropout(dropout_rate)(x)\n",
+ " x = layers.Dense(units, activation=\"tanh\")(x)\n",
+ " x = layers.Dropout(dropout_rate)(x)\n",
"\n",
" # Map outputs of previous layer (x) to [continuous] adjacency tensors (x_adjacency)\n",
- " x_adjacency = keras.layers.Dense(tf.math.reduce_prod(adjacency_shape))(x)\n",
- " x_adjacency = keras.layers.Reshape(adjacency_shape)(x_adjacency)\n",
+ " x_adjacency = layers.Dense(np.prod(adjacency_shape))(x)\n",
+ " x_adjacency = layers.Reshape(adjacency_shape)(x_adjacency)\n",
" # Symmetrify tensors in the last two dimensions\n",
- " x_adjacency = (x_adjacency + tf.transpose(x_adjacency, (0, 1, 3, 2))) / 2\n",
- " x_adjacency = keras.layers.Softmax(axis=1)(x_adjacency)\n",
+ " x_adjacency = (x_adjacency + ops.transpose(x_adjacency, (0, 1, 3, 2))) / 2\n",
+ " x_adjacency = layers.Softmax(axis=1)(x_adjacency)\n",
"\n",
" # Map outputs of previous layer (x) to [continuous] feature tensors (x_features)\n",
- " x_features = keras.layers.Dense(tf.math.reduce_prod(feature_shape))(x)\n",
- " x_features = keras.layers.Reshape(feature_shape)(x_features)\n",
- " x_features = keras.layers.Softmax(axis=2)(x_features)\n",
+ " x_features = layers.Dense(np.prod(feature_shape))(x)\n",
+ " x_features = layers.Reshape(feature_shape)(x_features)\n",
+ " x_features = layers.Softmax(axis=2)(x_features)\n",
"\n",
" decoder = keras.Model(\n",
" latent_inputs, outputs=[x_adjacency, x_features], name=\"decoder\"\n",
" )\n",
"\n",
- " return decoder\n"
+ " return decoder\n",
+ ""
]
},
{
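> A quick hedged check of the decoder's symmetrification step above: averaging a tensor with its transpose over the last two axes always yields a symmetric adjacency.

```python
import numpy as np
from keras import ops

x = ops.convert_to_tensor(np.random.rand(1, 2, 3, 3).astype("float32"))
x_sym = (x + ops.transpose(x, (0, 1, 3, 2))) / 2
# x_sym equals its own transpose over the last two axes.
print(np.allclose(ops.convert_to_numpy(x_sym),
                  ops.convert_to_numpy(ops.transpose(x_sym, (0, 1, 3, 2)))))  # True
```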
@@ -464,7 +472,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -472,12 +480,16 @@
"source": [
"\n",
"class Sampling(layers.Layer):\n",
+ " def __init__(self, seed=None, **kwargs):\n",
+ " super().__init__(**kwargs)\n",
+ " self.seed_generator = keras.random.SeedGenerator(seed)\n",
+ "\n",
" def call(self, inputs):\n",
" z_mean, z_log_var = inputs\n",
- " batch = tf.shape(z_log_var)[0]\n",
- " dim = tf.shape(z_log_var)[1]\n",
- " epsilon = tf.keras.backend.random_normal(shape=(batch, dim))\n",
- " return z_mean + tf.exp(0.5 * z_log_var) * epsilon\n"
+ " batch, dim = ops.shape(z_log_var)\n",
+ " epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)\n",
+ " return z_mean + ops.exp(0.5 * z_log_var) * epsilon\n",
+ ""
]
},
{
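> A hedged usage sketch of the `Sampling` layer above, i.e. the reparameterization trick: draw `epsilon ~ N(0, I)`, then shift and scale it so gradients flow through `z_mean` and `z_log_var`. Also note the new `keras.random.SeedGenerator`: it advances its state on each draw, so repeated calls produce different samples while staying reproducible across runs.

```python
from keras import ops

sampling = Sampling(seed=42)        # the layer defined in the cell above
z_mean = ops.zeros((2, 4))
z_log_var = ops.zeros((2, 4))       # log-variance 0 -> unit standard deviation
z1 = sampling([z_mean, z_log_var])  # z = z_mean + exp(0.5 * z_log_var) * epsilon
z2 = sampling([z_mean, z_log_var])  # differs from z1: the SeedGenerator advances
print(ops.shape(z1))                # (2, 4)
```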
@@ -512,7 +524,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -520,12 +532,14 @@
"source": [
"\n",
"class MoleculeGenerator(keras.Model):\n",
- " def __init__(self, encoder, decoder, max_len, **kwargs):\n",
+ " def __init__(self, encoder, decoder, max_len, seed=None, **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
" self.property_prediction_layer = layers.Dense(1)\n",
" self.max_len = max_len\n",
+ " self.seed_generator = keras.random.SeedGenerator(seed)\n",
+ " self.sampling_layer = Sampling(seed=seed)\n",
"\n",
" self.train_total_loss_tracker = keras.metrics.Mean(name=\"train_total_loss\")\n",
" self.val_total_loss_tracker = keras.metrics.Mean(name=\"val_total_loss\")\n",
@@ -533,7 +547,7 @@
" def train_step(self, data):\n",
" adjacency_tensor, feature_tensor, qed_tensor = data[0]\n",
" graph_real = [adjacency_tensor, feature_tensor]\n",
- " self.batch_size = tf.shape(qed_tensor)[0]\n",
+ " self.batch_size = ops.shape(qed_tensor)[0]\n",
" with tf.GradientTape() as tape:\n",
" z_mean, z_log_var, qed_pred, gen_adjacency, gen_features = self(\n",
" graph_real, training=True\n",
@@ -552,29 +566,30 @@
" def _compute_loss(\n",
" self, z_log_var, z_mean, qed_true, qed_pred, graph_real, graph_generated\n",
" ):\n",
- "\n",
" adjacency_real, features_real = graph_real\n",
" adjacency_gen, features_gen = graph_generated\n",
"\n",
- " adjacency_loss = tf.reduce_mean(\n",
- " tf.reduce_sum(\n",
- " keras.losses.categorical_crossentropy(adjacency_real, adjacency_gen),\n",
+ " adjacency_loss = ops.mean(\n",
+ " ops.sum(\n",
+ " keras.losses.categorical_crossentropy(\n",
+ " adjacency_real, adjacency_gen, axis=1\n",
+ " ),\n",
" axis=(1, 2),\n",
" )\n",
" )\n",
- " features_loss = tf.reduce_mean(\n",
- " tf.reduce_sum(\n",
+ " features_loss = ops.mean(\n",
+ " ops.sum(\n",
" keras.losses.categorical_crossentropy(features_real, features_gen),\n",
" axis=(1),\n",
" )\n",
" )\n",
- " kl_loss = -0.5 * tf.reduce_sum(\n",
- " 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), 1\n",
+ " kl_loss = -0.5 * ops.sum(\n",
+ " 1 + z_log_var - z_mean**2 - ops.minimum(ops.exp(z_log_var), 1e6), 1\n",
" )\n",
- " kl_loss = tf.reduce_mean(kl_loss)\n",
+ " kl_loss = ops.mean(kl_loss)\n",
"\n",
- " property_loss = tf.reduce_mean(\n",
- " keras.losses.binary_crossentropy(qed_true, qed_pred)\n",
+ " property_loss = ops.mean(\n",
+ " keras.losses.binary_crossentropy(qed_true, ops.squeeze(qed_pred, axis=1))\n",
" )\n",
"\n",
" graph_loss = self._gradient_penalty(graph_real, graph_generated)\n",
@@ -587,11 +602,13 @@
" adjacency_generated, features_generated = graph_generated\n",
"\n",
" # Generate interpolated graphs (adjacency_interp and features_interp)\n",
- " alpha = tf.random.uniform([self.batch_size])\n",
- " alpha = tf.reshape(alpha, (self.batch_size, 1, 1, 1))\n",
- " adjacency_interp = (adjacency_real * alpha) + (1 - alpha) * adjacency_generated\n",
- " alpha = tf.reshape(alpha, (self.batch_size, 1, 1))\n",
- " features_interp = (features_real * alpha) + (1 - alpha) * features_generated\n",
+ " alpha = keras.random.uniform(shape=(self.batch_size,), seed=self.seed_generator)\n",
+ " alpha = ops.reshape(alpha, (self.batch_size, 1, 1, 1))\n",
+ " adjacency_interp = (adjacency_real * alpha) + (\n",
+ " 1.0 - alpha\n",
+ " ) * adjacency_generated\n",
+ " alpha = ops.reshape(alpha, (self.batch_size, 1, 1))\n",
+ " features_interp = (features_real * alpha) + (1.0 - alpha) * features_generated\n",
"\n",
" # Compute the logits of interpolated graphs\n",
" with tf.GradientTape() as tape:\n",
@@ -604,24 +621,26 @@
" # Compute the gradients with respect to the interpolated graphs\n",
" grads = tape.gradient(logits, [adjacency_interp, features_interp])\n",
" # Compute the gradient penalty\n",
- " grads_adjacency_penalty = (1 - tf.norm(grads[0], axis=1)) ** 2\n",
- " grads_features_penalty = (1 - tf.norm(grads[1], axis=2)) ** 2\n",
- " return tf.reduce_mean(\n",
- " tf.reduce_mean(grads_adjacency_penalty, axis=(-2, -1))\n",
- " + tf.reduce_mean(grads_features_penalty, axis=(-1))\n",
+ " grads_adjacency_penalty = (1 - ops.norm(grads[0], axis=1)) ** 2\n",
+ " grads_features_penalty = (1 - ops.norm(grads[1], axis=2)) ** 2\n",
+ " return ops.mean(\n",
+ " ops.mean(grads_adjacency_penalty, axis=(-2, -1))\n",
+ " + ops.mean(grads_features_penalty, axis=(-1))\n",
" )\n",
"\n",
" def inference(self, batch_size):\n",
- " z = tf.random.normal((batch_size, LATENT_DIM))\n",
+ " z = keras.random.normal(\n",
+ " shape=(batch_size, LATENT_DIM), seed=self.seed_generator\n",
+ " )\n",
- " reconstruction_adjacency, reconstruction_features = model.decoder.predict(z)\n",
+ " reconstruction_adjacency, reconstruction_features = self.decoder.predict(z)\n",
" # obtain one-hot encoded adjacency tensor\n",
- " adjacency = tf.argmax(reconstruction_adjacency, axis=1)\n",
- " adjacency = tf.one_hot(adjacency, depth=BOND_DIM, axis=1)\n",
+ " adjacency = ops.argmax(reconstruction_adjacency, axis=1)\n",
+ " adjacency = ops.one_hot(adjacency, num_classes=BOND_DIM, axis=1)\n",
" # Remove potential self-loops from adjacency\n",
- " adjacency = tf.linalg.set_diag(adjacency, tf.zeros(tf.shape(adjacency)[:-1]))\n",
+ " adjacency = adjacency * (1.0 - ops.eye(NUM_ATOMS, dtype=\"float32\")[None, None])\n",
" # obtain one-hot encoded feature tensor\n",
- " features = tf.argmax(reconstruction_features, axis=2)\n",
- " features = tf.one_hot(features, depth=ATOM_DIM, axis=2)\n",
+ " features = ops.argmax(reconstruction_features, axis=2)\n",
+ " features = ops.one_hot(features, num_classes=ATOM_DIM, axis=2)\n",
" return [\n",
" graph_to_molecule([adjacency[i].numpy(), features[i].numpy()])\n",
" for i in range(batch_size)\n",
@@ -629,13 +648,14 @@
"\n",
" def call(self, inputs):\n",
" z_mean, log_var = self.encoder(inputs)\n",
- " z = Sampling()([z_mean, log_var])\n",
+ " z = self.sampling_layer([z_mean, log_var])\n",
"\n",
" gen_adjacency, gen_features = self.decoder(z)\n",
"\n",
" property_pred = self.property_prediction_layer(z_mean)\n",
"\n",
- " return z_mean, log_var, property_pred, gen_adjacency, gen_features\n"
+ " return z_mean, log_var, property_pred, gen_adjacency, gen_features\n",
+ ""
]
},
{
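> Hedged sketch isolating the clipped KL term in `_compute_loss` above: `ops.minimum` caps `exp(z_log_var)` at `1e6` so a runaway log-variance cannot produce an infinite loss. Toy inputs:

```python
import numpy as np
from keras import ops

z_mean = ops.convert_to_tensor(np.zeros((2, 4), dtype="float32"))
z_log_var = ops.convert_to_tensor(np.zeros((2, 4), dtype="float32"))
kl = -0.5 * ops.sum(1 + z_log_var - z_mean**2 - ops.minimum(ops.exp(z_log_var), 1e6), 1)
print(ops.mean(kl))  # 0.0: a unit-Gaussian posterior matches the prior exactly
```

> Likewise, the `(1 - I)` mask in `inference()` above is the backend-agnostic replacement for `tf.linalg.set_diag`: multiplying by it zeroes the diagonal of every bond-type slice, removing self-loops. A small sketch with illustrative sizes:

```python
from keras import ops

num_atoms, bond_dim = 4, 5  # illustrative, not the example's constants
adjacency = ops.ones((1, bond_dim, num_atoms, num_atoms))
mask = 1.0 - ops.eye(num_atoms, dtype="float32")[None, None]
print(adjacency * mask)  # ones off the diagonal, zeros on it, in every slice
```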
@@ -649,13 +669,13 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
- "vae_optimizer = tf.keras.optimizers.Adam(learning_rate=VAE_LR)\n",
+ "vae_optimizer = keras.optimizers.Adam(learning_rate=VAE_LR)\n",
"\n",
"encoder = get_encoder(\n",
" gconv_units=[9],\n",
@@ -701,7 +721,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -725,7 +745,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 0,
"metadata": {
"colab_type": "code"
},
@@ -800,4 +820,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
-}
+}
\ No newline at end of file
diff --git a/examples/generative/md/molecule_generation.md b/examples/generative/md/molecule_generation.md
index a0c4217081..5cc7ccd9d2 100644
--- a/examples/generative/md/molecule_generation.md
+++ b/examples/generative/md/molecule_generation.md
@@ -2,7 +2,7 @@
**Author:** [Victor Basu](https://www.linkedin.com/in/victor-basu-520958147)
**Date created:** 2022/03/10
-**Last modified:** 2022/03/24
+**Last modified:** 2024/12/17
**Description:** Implementing a Convolutional Variational AutoEncoder (VAE) for Drug Discovery.
@@ -73,14 +73,19 @@ a molecule object, which can then be used to compute a great number of molecular
```
```python
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
import ast
import pandas as pd
import numpy as np
import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
+import keras
+from keras import layers
+from keras import ops
import matplotlib.pyplot as plt
from rdkit import Chem, RDLogger
@@ -89,13 +94,7 @@ from rdkit.Chem.Draw import MolsToGridImage
RDLogger.DisableLog("rdApp.*")
```
-
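> One caveat worth flagging on the import block above: Keras 3 reads `KERAS_BACKEND` once, at import time, so the environment variable must be set before the first `import keras`. A minimal sketch:

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"  # or "jax" / "torch"; must precede the import

import keras

print(keras.backend.backend())  # "tensorflow"
```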