diff --git a/examples/generative/img/molecule_generation/molecule_generation_21_18.png b/examples/generative/img/molecule_generation/molecule_generation_21_18.png new file mode 100644 index 0000000000..768d4cea86 Binary files /dev/null and b/examples/generative/img/molecule_generation/molecule_generation_21_18.png differ diff --git a/examples/generative/img/molecule_generation/molecule_generation_23_39.png b/examples/generative/img/molecule_generation/molecule_generation_23_39.png new file mode 100644 index 0000000000..fa5efcc123 Binary files /dev/null and b/examples/generative/img/molecule_generation/molecule_generation_23_39.png differ diff --git a/examples/generative/ipynb/molecule_generation.ipynb b/examples/generative/ipynb/molecule_generation.ipynb index 4ab407a6c1..94c9722177 100644 --- a/examples/generative/ipynb/molecule_generation.ipynb +++ b/examples/generative/ipynb/molecule_generation.ipynb @@ -10,7 +10,7 @@ "\n", "**Author:** [Victor Basu](https://www.linkedin.com/in/victor-basu-520958147)
\n", "**Date created:** 2022/03/10
\n", - "**Last modified:** 2022/03/24
\n", + "**Last modified:** 2024/12/17
\n", "**Description:** Implementing a Convolutional Variational AutoEncoder (VAE) for Drug Discovery." ] }, @@ -85,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -96,20 +96,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ + "import os\n", + "\n", + "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n", + "\n", "import ast\n", "\n", "import pandas as pd\n", "import numpy as np\n", "\n", "import tensorflow as tf\n", - "from tensorflow import keras\n", - "from tensorflow.keras import layers\n", + "import keras\n", + "from keras import layers\n", + "from keras import ops\n", "\n", "import matplotlib.pyplot as plt\n", "from rdkit import Chem, RDLogger\n", @@ -127,27 +132,27 @@ "source": [ "## Dataset\n", "\n", - "We use the [**ZINC – A Free Database of Commercially Available Compounds for\n", + "We use the [**ZINC \u2013 A Free Database of Commercially Available Compounds for\n", "Virtual Screening**](https://bit.ly/3IVBI4x) dataset. The dataset comes with molecule\n", "formula in SMILE representation along with their respective molecular properties such as\n", - "**logP** (water–octanal partition coefficient), **SAS** (synthetic\n", + "**logP** (water\u2013octanal partition coefficient), **SAS** (synthetic\n", "accessibility score) and **QED** (Qualitative Estimate of Drug-likeness)." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ "csv_path = keras.utils.get_file(\n", - " \"/content/250k_rndm_zinc_drugs_clean_3.csv\",\n", + " \"250k_rndm_zinc_drugs_clean_3.csv\",\n", " \"https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv\",\n", ")\n", "\n", - "df = pd.read_csv(\"/content/250k_rndm_zinc_drugs_clean_3.csv\")\n", + "df = pd.read_csv(csv_path)\n", "df[\"smiles\"] = df[\"smiles\"].apply(lambda s: s.replace(\"\\n\", \"\"))\n", "df.head()" ] @@ -163,7 +168,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -247,7 +252,7 @@ " # Add bonds between atoms in molecule; based on the upper triangles\n", " # of the [symmetric] adjacency tensor\n", " (bonds_ij, atoms_i, atoms_j) = np.where(np.triu(adjacency) == 1)\n", - " for (bond_ij, atom_i, atom_j) in zip(bonds_ij, atoms_i, atoms_j):\n", + " for bond_ij, atom_i, atom_j in zip(bonds_ij, atoms_i, atoms_j):\n", " if atom_i == atom_j or bond_ij == BOND_DIM - 1:\n", " continue\n", " bond_type = bond_mapping[bond_ij]\n", @@ -260,7 +265,8 @@ " if flag != Chem.SanitizeFlags.SANITIZE_NONE:\n", " return None\n", "\n", - " return molecule\n" + " return molecule\n", + "" ] }, { @@ -274,7 +280,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -328,7 +334,7 @@ " regularizer=self.kernel_regularizer,\n", " trainable=True,\n", " name=\"W\",\n", - " dtype=tf.float32,\n", + " dtype=\"float32\",\n", " )\n", "\n", " if self.use_bias:\n", @@ -338,7 +344,7 @@ " regularizer=self.bias_regularizer,\n", " trainable=True,\n", " name=\"b\",\n", - " dtype=tf.float32,\n", + " dtype=\"float32\",\n", " )\n", "\n", " self.built = True\n", @@ -346,15 +352,16 @@ " def call(self, inputs, training=False):\n", " adjacency, features = inputs\n", " # Aggregate information from neighbors\n", - " x = 
tf.matmul(adjacency, features[:, None, :, :])\n", + " x = ops.matmul(adjacency, features[:, None])\n", " # Apply linear transformation\n", - " x = tf.matmul(x, self.kernel)\n", + " x = ops.matmul(x, self.kernel)\n", " if self.use_bias:\n", " x += self.bias\n", " # Reduce bond types dim\n", - " x_reduced = tf.reduce_sum(x, axis=1)\n", + " x_reduced = ops.sum(x, axis=1)\n", " # Apply non-linear transformation\n", - " return self.activation(x_reduced)\n" + " return self.activation(x_reduced)\n", + "" ] }, { @@ -374,9 +381,9 @@ "non-linearly transformed neighbourhood aggregations. We can define these layers as\n", "follows:\n", "\n", - "`H_hat**(l+1) = σ(D_hat**(-1) * A_hat * H_hat**(l+1) * W**(l))`\n", + "`H_hat**(l+1) = \u03c3(D_hat**(-1) * A_hat * H_hat**(l+1) * W**(l))`\n", "\n", - "Where `σ` denotes the non-linear transformation (commonly a ReLU activation), `A` the\n", + "Where `\u03c3` denotes the non-linear transformation (commonly a ReLU activation), `A` the\n", "adjacency tensor, `H_hat**(l)` the feature tensor at the `l-th` layer, `D_hat**(-1)` the\n", "inverse diagonal degree tensor of `A_hat`, and `W_hat**(l)` the trainable weight tensor\n", "at the `l-th` layer. Specifically, for each bond type (relation), the degree tensor\n", @@ -391,7 +398,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -401,8 +408,8 @@ "def get_encoder(\n", " gconv_units, latent_dim, adjacency_shape, feature_shape, dense_units, dropout_rate\n", "):\n", - " adjacency = keras.layers.Input(shape=adjacency_shape)\n", - " features = keras.layers.Input(shape=feature_shape)\n", + " adjacency = layers.Input(shape=adjacency_shape)\n", + " features = layers.Input(shape=feature_shape)\n", "\n", " # Propagate through one or more graph convolutional layers\n", " features_transformed = features\n", @@ -411,7 +418,7 @@ " [adjacency, features_transformed]\n", " )\n", " # Reduce 2-D representation of molecule to 1-D\n", - " x = keras.layers.GlobalAveragePooling1D()(features_transformed)\n", + " x = layers.GlobalAveragePooling1D()(features_transformed)\n", "\n", " # Propagate through one or more densely connected layers\n", " for units in dense_units:\n", @@ -431,26 +438,27 @@ "\n", " x = latent_inputs\n", " for units in dense_units:\n", - " x = keras.layers.Dense(units, activation=\"tanh\")(x)\n", - " x = keras.layers.Dropout(dropout_rate)(x)\n", + " x = layers.Dense(units, activation=\"tanh\")(x)\n", + " x = layers.Dropout(dropout_rate)(x)\n", "\n", " # Map outputs of previous layer (x) to [continuous] adjacency tensors (x_adjacency)\n", - " x_adjacency = keras.layers.Dense(tf.math.reduce_prod(adjacency_shape))(x)\n", - " x_adjacency = keras.layers.Reshape(adjacency_shape)(x_adjacency)\n", + " x_adjacency = layers.Dense(np.prod(adjacency_shape))(x)\n", + " x_adjacency = layers.Reshape(adjacency_shape)(x_adjacency)\n", " # Symmetrify tensors in the last two dimensions\n", - " x_adjacency = (x_adjacency + tf.transpose(x_adjacency, (0, 1, 3, 2))) / 2\n", - " x_adjacency = keras.layers.Softmax(axis=1)(x_adjacency)\n", + " x_adjacency = (x_adjacency + ops.transpose(x_adjacency, (0, 1, 3, 2))) / 2\n", + " x_adjacency = layers.Softmax(axis=1)(x_adjacency)\n", "\n", " # Map outputs of previous layer (x) to [continuous] feature tensors (x_features)\n", - " x_features = keras.layers.Dense(tf.math.reduce_prod(feature_shape))(x)\n", - " x_features = keras.layers.Reshape(feature_shape)(x_features)\n", - " x_features = keras.layers.Softmax(axis=2)(x_features)\n", + 
" x_features = layers.Dense(np.prod(feature_shape))(x)\n", + " x_features = layers.Reshape(feature_shape)(x_features)\n", + " x_features = layers.Softmax(axis=2)(x_features)\n", "\n", " decoder = keras.Model(\n", " latent_inputs, outputs=[x_adjacency, x_features], name=\"decoder\"\n", " )\n", "\n", - " return decoder\n" + " return decoder\n", + "" ] }, { @@ -464,7 +472,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -472,12 +480,16 @@ "source": [ "\n", "class Sampling(layers.Layer):\n", + " def __init__(self, seed=None, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.seed_generator = keras.random.SeedGenerator(seed)\n", + "\n", " def call(self, inputs):\n", " z_mean, z_log_var = inputs\n", - " batch = tf.shape(z_log_var)[0]\n", - " dim = tf.shape(z_log_var)[1]\n", - " epsilon = tf.keras.backend.random_normal(shape=(batch, dim))\n", - " return z_mean + tf.exp(0.5 * z_log_var) * epsilon\n" + " batch, dim = ops.shape(z_log_var)\n", + " epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)\n", + " return z_mean + ops.exp(0.5 * z_log_var) * epsilon\n", + "" ] }, { @@ -512,7 +524,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -520,12 +532,14 @@ "source": [ "\n", "class MoleculeGenerator(keras.Model):\n", - " def __init__(self, encoder, decoder, max_len, **kwargs):\n", + " def __init__(self, encoder, decoder, max_len, seed=None, **kwargs):\n", " super().__init__(**kwargs)\n", " self.encoder = encoder\n", " self.decoder = decoder\n", " self.property_prediction_layer = layers.Dense(1)\n", " self.max_len = max_len\n", + " self.seed_generator = keras.random.SeedGenerator(seed)\n", + " self.sampling_layer = Sampling(seed=seed)\n", "\n", " self.train_total_loss_tracker = keras.metrics.Mean(name=\"train_total_loss\")\n", " self.val_total_loss_tracker = keras.metrics.Mean(name=\"val_total_loss\")\n", @@ -533,7 +547,7 @@ " def train_step(self, data):\n", " adjacency_tensor, feature_tensor, qed_tensor = data[0]\n", " graph_real = [adjacency_tensor, feature_tensor]\n", - " self.batch_size = tf.shape(qed_tensor)[0]\n", + " self.batch_size = ops.shape(qed_tensor)[0]\n", " with tf.GradientTape() as tape:\n", " z_mean, z_log_var, qed_pred, gen_adjacency, gen_features = self(\n", " graph_real, training=True\n", @@ -552,29 +566,30 @@ " def _compute_loss(\n", " self, z_log_var, z_mean, qed_true, qed_pred, graph_real, graph_generated\n", " ):\n", - "\n", " adjacency_real, features_real = graph_real\n", " adjacency_gen, features_gen = graph_generated\n", "\n", - " adjacency_loss = tf.reduce_mean(\n", - " tf.reduce_sum(\n", - " keras.losses.categorical_crossentropy(adjacency_real, adjacency_gen),\n", + " adjacency_loss = ops.mean(\n", + " ops.sum(\n", + " keras.losses.categorical_crossentropy(\n", + " adjacency_real, adjacency_gen, axis=1\n", + " ),\n", " axis=(1, 2),\n", " )\n", " )\n", - " features_loss = tf.reduce_mean(\n", - " tf.reduce_sum(\n", + " features_loss = ops.mean(\n", + " ops.sum(\n", " keras.losses.categorical_crossentropy(features_real, features_gen),\n", " axis=(1),\n", " )\n", " )\n", - " kl_loss = -0.5 * tf.reduce_sum(\n", - " 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), 1\n", + " kl_loss = -0.5 * ops.sum(\n", + " 1 + z_log_var - z_mean**2 - ops.minimum(ops.exp(z_log_var), 1e6), 1\n", " )\n", - " kl_loss = tf.reduce_mean(kl_loss)\n", + " kl_loss = ops.mean(kl_loss)\n", "\n", - " property_loss = 
tf.reduce_mean(\n", - " keras.losses.binary_crossentropy(qed_true, qed_pred)\n", + " property_loss = ops.mean(\n", + " keras.losses.binary_crossentropy(qed_true, ops.squeeze(qed_pred, axis=1))\n", " )\n", "\n", " graph_loss = self._gradient_penalty(graph_real, graph_generated)\n", @@ -587,11 +602,13 @@ " adjacency_generated, features_generated = graph_generated\n", "\n", " # Generate interpolated graphs (adjacency_interp and features_interp)\n", - " alpha = tf.random.uniform([self.batch_size])\n", - " alpha = tf.reshape(alpha, (self.batch_size, 1, 1, 1))\n", - " adjacency_interp = (adjacency_real * alpha) + (1 - alpha) * adjacency_generated\n", - " alpha = tf.reshape(alpha, (self.batch_size, 1, 1))\n", - " features_interp = (features_real * alpha) + (1 - alpha) * features_generated\n", + " alpha = keras.random.uniform(shape=(self.batch_size,), seed=self.seed_generator)\n", + " alpha = ops.reshape(alpha, (self.batch_size, 1, 1, 1))\n", + " adjacency_interp = (adjacency_real * alpha) + (\n", + " 1.0 - alpha\n", + " ) * adjacency_generated\n", + " alpha = ops.reshape(alpha, (self.batch_size, 1, 1))\n", + " features_interp = (features_real * alpha) + (1.0 - alpha) * features_generated\n", "\n", " # Compute the logits of interpolated graphs\n", " with tf.GradientTape() as tape:\n", @@ -604,24 +621,26 @@ " # Compute the gradients with respect to the interpolated graphs\n", " grads = tape.gradient(logits, [adjacency_interp, features_interp])\n", " # Compute the gradient penalty\n", - " grads_adjacency_penalty = (1 - tf.norm(grads[0], axis=1)) ** 2\n", - " grads_features_penalty = (1 - tf.norm(grads[1], axis=2)) ** 2\n", - " return tf.reduce_mean(\n", - " tf.reduce_mean(grads_adjacency_penalty, axis=(-2, -1))\n", - " + tf.reduce_mean(grads_features_penalty, axis=(-1))\n", + " grads_adjacency_penalty = (1 - ops.norm(grads[0], axis=1)) ** 2\n", + " grads_features_penalty = (1 - ops.norm(grads[1], axis=2)) ** 2\n", + " return ops.mean(\n", + " ops.mean(grads_adjacency_penalty, axis=(-2, -1))\n", + " + ops.mean(grads_features_penalty, axis=(-1))\n", " )\n", "\n", " def inference(self, batch_size):\n", - " z = tf.random.normal((batch_size, LATENT_DIM))\n", + " z = keras.random.normal(\n", + " shape=(batch_size, LATENT_DIM), seed=self.seed_generator\n", + " )\n", " reconstruction_adjacency, reconstruction_features = model.decoder.predict(z)\n", " # obtain one-hot encoded adjacency tensor\n", - " adjacency = tf.argmax(reconstruction_adjacency, axis=1)\n", - " adjacency = tf.one_hot(adjacency, depth=BOND_DIM, axis=1)\n", + " adjacency = ops.argmax(reconstruction_adjacency, axis=1)\n", + " adjacency = ops.one_hot(adjacency, num_classes=BOND_DIM, axis=1)\n", " # Remove potential self-loops from adjacency\n", - " adjacency = tf.linalg.set_diag(adjacency, tf.zeros(tf.shape(adjacency)[:-1]))\n", + " adjacency = adjacency * (1.0 - ops.eye(NUM_ATOMS, dtype=\"float32\")[None, None])\n", " # obtain one-hot encoded feature tensor\n", - " features = tf.argmax(reconstruction_features, axis=2)\n", - " features = tf.one_hot(features, depth=ATOM_DIM, axis=2)\n", + " features = ops.argmax(reconstruction_features, axis=2)\n", + " features = ops.one_hot(features, num_classes=ATOM_DIM, axis=2)\n", " return [\n", " graph_to_molecule([adjacency[i].numpy(), features[i].numpy()])\n", " for i in range(batch_size)\n", @@ -629,13 +648,14 @@ "\n", " def call(self, inputs):\n", " z_mean, log_var = self.encoder(inputs)\n", - " z = Sampling()([z_mean, log_var])\n", + " z = self.sampling_layer([z_mean, log_var])\n", "\n", " 
gen_adjacency, gen_features = self.decoder(z)\n", "\n", " property_pred = self.property_prediction_layer(z_mean)\n", "\n", - " return z_mean, log_var, property_pred, gen_adjacency, gen_features\n" + " return z_mean, log_var, property_pred, gen_adjacency, gen_features\n", + "" ] }, { @@ -649,13 +669,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ - "vae_optimizer = tf.keras.optimizers.Adam(learning_rate=VAE_LR)\n", + "vae_optimizer = keras.optimizers.Adam(learning_rate=VAE_LR)\n", "\n", "encoder = get_encoder(\n", " gconv_units=[9],\n", @@ -701,7 +721,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -725,7 +745,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 0, "metadata": { "colab_type": "code" }, @@ -800,4 +820,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/examples/generative/md/molecule_generation.md b/examples/generative/md/molecule_generation.md index a0c4217081..5cc7ccd9d2 100644 --- a/examples/generative/md/molecule_generation.md +++ b/examples/generative/md/molecule_generation.md @@ -2,7 +2,7 @@ **Author:** [Victor Basu](https://www.linkedin.com/in/victor-basu-520958147)
**Date created:** 2022/03/10
-**Last modified:** 2022/03/24
+**Last modified:** 2024/12/17
**Description:** Implementing a Convolutional Variational AutoEncoder (VAE) for Drug Discovery. @@ -73,14 +73,19 @@ a molecule object, which can then be used to compute a great number of molecular ``` ```python +import os + +os.environ["KERAS_BACKEND"] = "tensorflow" + import ast import pandas as pd import numpy as np import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import layers +import keras +from keras import layers +from keras import ops import matplotlib.pyplot as plt from rdkit import Chem, RDLogger @@ -89,13 +94,7 @@ from rdkit.Chem.Draw import MolsToGridImage RDLogger.DisableLog("rdApp.*") ``` -
-```
- |████████████████████████████████| 20.6 MB 1.2 MB/s
-[?25h
-```
-</div>
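A note for reviewers, outside the patch itself: the core of this migration is swapping `tensorflow.keras` for standalone Keras 3, which resolves its backend from the `KERAS_BACKEND` environment variable at import time. Below is a minimal sketch of the idiom the commit adopts, assuming a Keras 3 installation; the variable must be set before `keras` is first imported, since the backend cannot be switched later in the same process:

```python
import os

# Must run before the first `import keras`; "jax" and "torch" are the
# other supported values in Keras 3.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import ops

# keras.ops mirrors the NumPy API and dispatches to the active backend,
# which is why the patch replaces tf.reduce_sum / tf.matmul with ops.* calls.
x = ops.ones((2, 3))
print(ops.sum(x, axis=1))  # tensor of shape (2,) on the active backend
```

Note that the training loop in this example still uses `tf.GradientTape`, so even after the `ops` migration the notebook remains TensorFlow-only in practice.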
--- ## Dataset @@ -108,22 +107,18 @@ accessibility score) and **QED** (Qualitative Estimate of Drug-likeness). ```python csv_path = keras.utils.get_file( - "/content/250k_rndm_zinc_drugs_clean_3.csv", + "250k_rndm_zinc_drugs_clean_3.csv", "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv", ) -df = pd.read_csv("/content/250k_rndm_zinc_drugs_clean_3.csv") +df = pd.read_csv(csv_path) df["smiles"] = df["smiles"].apply(lambda s: s.replace("\n", "")) df.head() ``` -
-```
-Downloading data from https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv
-22606589/22606589 [==============================] - 0s 0us/step
-```
-</div>
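For reviewers who want to sanity-check the data path above, here is a small sketch, not part of the patch, that parses one SMILES string with RDKit and recomputes two of the properties the ZINC csv ships with. The molecule chosen (ibuprofen) is an arbitrary stand-in, not a row from the dataset:

```python
from rdkit import Chem
from rdkit.Chem import Descriptors, QED

smiles = "CC(C)Cc1ccc(cc1)C(C)C(=O)O"  # ibuprofen, used as an illustrative input
mol = Chem.MolFromSmiles(smiles)  # returns None for an invalid SMILES string

if mol is not None:
    print("QED :", QED.qed(mol))  # Qualitative Estimate of Drug-likeness, in [0, 1]
    print("logP:", Descriptors.MolLogP(mol))  # Crippen estimate of the
    # water-octanol partition coefficient
```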
+
+<div>
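The remainder of the markdown diff is truncated above. The most consequential functional change in the notebook hunks already shown is the move from module-level `tf.random` calls to a stateful `keras.random.SeedGenerator`, which keeps sampling reproducible across backends. The patch's `Sampling` layer, extracted into a self-contained sketch assuming Keras 3 (the shape check at the bottom is illustrative, not from the patch):

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
from keras import layers, ops


class Sampling(layers.Layer):
    """Reparameterization trick: draw z = mu + sigma * eps with eps ~ N(0, I)."""

    def __init__(self, seed=None, **kwargs):
        super().__init__(**kwargs)
        # A SeedGenerator is tracked layer state, so repeated calls advance
        # the random stream instead of reusing one fixed seed.
        self.seed_generator = keras.random.SeedGenerator(seed)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch, dim = ops.shape(z_log_var)
        epsilon = keras.random.normal(shape=(batch, dim), seed=self.seed_generator)
        # Sampling stays differentiable with respect to z_mean and z_log_var.
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon


# Illustrative shape check (8 is an arbitrary latent size, not the example's):
z = Sampling(seed=42)([ops.zeros((4, 8)), ops.zeros((4, 8))])
print(ops.shape(z))  # (4, 8)
```

A related numerical-safety change appears in `_compute_loss`: the KL term now uses `ops.minimum(ops.exp(z_log_var), 1e6)`, which bounds the exponential, presumably so that a large `z_log_var` early in training cannot blow up the loss.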