diff --git a/guides/custom_train_step_in_torch.py b/guides/custom_train_step_in_torch.py
index 07a89cfed4..b493411ab5 100644
--- a/guides/custom_train_step_in_torch.py
+++ b/guides/custom_train_step_in_torch.py
@@ -2,7 +2,7 @@
Title: Customizing what happens in `fit()` with PyTorch
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2023/06/27
-Last modified: 2023/06/27
+Last modified: 2024/08/01
Description: Overriding the training step of the Model class with PyTorch.
Accelerator: GPU
"""
@@ -397,7 +397,7 @@ def compile(self, d_optimizer, g_optimizer, loss_fn):
def train_step(self, real_images):
device = "cuda" if torch.cuda.is_available() else "cpu"
- if isinstance(real_images, tuple):
+ if isinstance(real_images, tuple) or isinstance(real_images, list):
real_images = real_images[0]
# Sample random points in the latent space
batch_size = real_images.shape[0]
diff --git a/guides/ipynb/custom_train_step_in_torch.ipynb b/guides/ipynb/custom_train_step_in_torch.ipynb
index 0ae7a8b9f0..9f3d3576d1 100644
--- a/guides/ipynb/custom_train_step_in_torch.ipynb
+++ b/guides/ipynb/custom_train_step_in_torch.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [fchollet](https://twitter.com/fchollet)
\n",
"**Date created:** 2023/06/27
\n",
- "**Last modified:** 2023/06/27
\n",
+ "**Last modified:** 2024/08/01
\n",
"**Description:** Overriding the training step of the Model class with PyTorch."
]
},
@@ -50,17 +50,6 @@
"Let's see how that works."
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "!pip install keras --upgrade --quiet"
- ]
- },
{
"cell_type": "markdown",
"metadata": {
@@ -282,7 +271,7 @@
"outputs = keras.layers.Dense(1)(inputs)\n",
"model = CustomModel(inputs, outputs)\n",
"\n",
- "# We don't passs a loss or metrics here.\n",
+ "# We don't pass a loss or metrics here.\n",
"model.compile(optimizer=\"adam\")\n",
"\n",
"# Just use `fit` as usual -- you can use callbacks, etc.\n",
@@ -531,7 +520,7 @@
"\n",
" def train_step(self, real_images):\n",
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
- " if isinstance(real_images, tuple):\n",
+ " if isinstance(real_images, tuple) or isinstance(real_images, list):\n",
" real_images = real_images[0]\n",
" # Sample random points in the latent space\n",
" batch_size = real_images.shape[0]\n",
diff --git a/guides/md/custom_train_step_in_torch.md b/guides/md/custom_train_step_in_torch.md
index c112a44be6..0b62d818f0 100644
--- a/guides/md/custom_train_step_in_torch.md
+++ b/guides/md/custom_train_step_in_torch.md
@@ -2,7 +2,7 @@
**Author:** [fchollet](https://twitter.com/fchollet)
**Date created:** 2023/06/27
-**Last modified:** 2023/06/27
+**Last modified:** 2024/08/01
**Description:** Overriding the training step of the Model class with PyTorch.
@@ -142,14 +142,14 @@ model.fit(x, y, epochs=3)