Convert conv_lstm #1603

Merged
merged 1 commit on Nov 10, 2023
9 changes: 4 additions & 5 deletions examples/vision/conv_lstm.py
@@ -2,7 +2,7 @@
Title: Next-Frame Video Prediction with Convolutional LSTMs
Author: [Amogh Joshi](https://github.com/amogh7joshi)
Date created: 2021/06/02
-Last modified: 2021/06/05
+Last modified: 2023/11/10
Description: How to build and train a convolutional LSTM model for next-frame video prediction.
Accelerator: GPU
"""
@@ -25,9 +25,8 @@
import numpy as np
import matplotlib.pyplot as plt

-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
+import keras
+from keras import layers

import io
import imageio
@@ -279,7 +278,7 @@ def create_shifted_frames(data):

# Construct a GIF from the frames.
with io.BytesIO() as gif:
imageio.mimsave(gif, current_frames, "GIF", fps=5)
imageio.mimsave(gif, current_frames, "GIF", duration=200)
predicted_videos.append(gif.getvalue())

# Display the videos.
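
For context on the import change above: Keras 3 ships as a standalone, backend-agnostic package, so the example no longer needs `import tensorflow as tf` or the `tensorflow.keras` shim. A minimal sketch of the new import pattern (the explicit backend selection is an optional illustration, not part of this PR):

```python
import os

# Keras 3 reads this variable before its first import; it defaults to
# "tensorflow", so this line is optional.
os.environ["KERAS_BACKEND"] = "tensorflow"  # or "jax" / "torch"

import keras
from keras import layers

# The same layer symbols resolve as before, e.g. the ConvLSTM2D this example uses:
layer = layers.ConvLSTM2D(filters=64, kernel_size=(5, 5), padding="same")
```
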
Binary file added examples/vision/img/conv_lstm/conv_lstm_13_1.png
Binary file modified examples/vision/img/conv_lstm/conv_lstm_7_1.png
15 changes: 8 additions & 7 deletions examples/vision/ipynb/conv_lstm.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Amogh Joshi](https://github.com/amogh7joshi)<br>\n",
"**Date created:** 2021/06/02<br>\n",
"**Last modified:** 2021/06/05<br>\n",
"**Last modified:** 2023/11/10<br>\n",
"**Description:** How to build and train a convolutional LSTM model for next-frame video prediction."
]
},
@@ -50,9 +50,8 @@
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import keras\n",
"from keras import layers\n",
"\n",
"import io\n",
"import imageio\n",
@@ -116,6 +115,7 @@
"train_dataset = train_dataset / 255\n",
"val_dataset = val_dataset / 255\n",
"\n",
"\n",
"# We'll define a helper function to shift the frames, where\n",
"# `x` is frames 0 to n - 1, and `y` is frames 1 to n.\n",
"def create_shifted_frames(data):\n",
@@ -226,7 +226,8 @@
"# Next, we will build the complete model and compile it.\n",
"model = keras.models.Model(inp, x)\n",
"model.compile(\n",
" loss=keras.losses.binary_crossentropy, optimizer=keras.optimizers.Adam(),\n",
" loss=keras.losses.binary_crossentropy,\n",
" optimizer=keras.optimizers.Adam(),\n",
")"
]
},
@@ -342,7 +343,7 @@
"and construct some GIFs with them to see the model's\n",
"predicted videos.\n",
"\n",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/conv-lstm) ",
"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/conv-lstm)\n",
"and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/conv-lstm)."
]
},
@@ -386,7 +387,7 @@
"\n",
" # Construct a GIF from the frames.\n",
" with io.BytesIO() as gif:\n",
" imageio.mimsave(gif, current_frames, \"GIF\", fps=5)\n",
" imageio.mimsave(gif, current_frames, \"GIF\", duration=200)\n",
" predicted_videos.append(gif.getvalue())\n",
"\n",
"# Display the videos.\n",
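
The `fps=5` → `duration=200` edits in this PR track an imageio API change: newer imageio releases drop the `fps` keyword for GIF output in favor of a per-frame `duration`, interpreted in milliseconds, and 1000 ms / 5 fps = 200 ms per frame, so playback speed is unchanged. A small standalone sketch, assuming such a recent imageio version:

```python
import numpy as np
import imageio

# Ten random 64x64 grayscale frames, uint8 as imageio expects.
frames = [np.random.randint(0, 256, (64, 64), dtype=np.uint8) for _ in range(10)]

# Older imageio accepted a frame rate:
#     imageio.mimsave("demo.gif", frames, "GIF", fps=5)
# Newer imageio takes a per-frame duration in milliseconds instead
# (200 ms per frame == 5 frames per second).
imageio.mimsave("demo.gif", frames, "GIF", duration=200)
```
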
147 changes: 135 additions & 12 deletions examples/vision/md/conv_lstm.md
@@ -2,7 +2,7 @@

**Author:** [Amogh Joshi](https://github.com/amogh7joshi)<br>
**Date created:** 2021/06/02<br>
-**Last modified:** 2021/06/05<br>
+**Last modified:** 2023/11/10<br>
**Description:** How to build and train a convolutional LSTM model for next-frame video prediction.


@@ -28,9 +28,8 @@ of predicting what video frames come next given a series of past frames.
import numpy as np
import matplotlib.pyplot as plt

-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
+import keras
+from keras import layers

import io
import imageio
@@ -82,6 +81,7 @@ val_dataset = dataset[val_index]
train_dataset = train_dataset / 255
val_dataset = val_dataset / 255


# We'll define a helper function to shift the frames, where
# `x` is frames 0 to n - 1, and `y` is frames 1 to n.
def create_shifted_frames(data):
@@ -101,6 +101,8 @@ print("Validation Dataset Shapes: " + str(x_val.shape) + ", " + str(y_val.shape)

<div class="k-default-codeblock">
```
+Downloading data from http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy
+819200096/819200096 ━━━━━━━━━━━━━━━━━━━━ 116s 0us/step
Training Dataset Shapes: (900, 19, 64, 64, 1), (900, 19, 64, 64, 1)
Validation Dataset Shapes: (100, 19, 64, 64, 1), (100, 19, 64, 64, 1)

@@ -132,7 +134,7 @@

<div class="k-default-codeblock">
```
-Displaying frames for example 130.
+Displaying frames for example 95.

```
</div>
@@ -186,7 +188,8 @@ x = layers.Conv3D(
# Next, we will build the complete model and compile it.
model = keras.models.Model(inp, x)
model.compile(
-loss=keras.losses.binary_crossentropy, optimizer=keras.optimizers.Adam(),
+loss=keras.losses.binary_crossentropy,
+optimizer=keras.optimizers.Adam(),
)
```

@@ -216,6 +219,53 @@ model.fit(
)
```

<div class="k-default-codeblock">
```
Epoch 1/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 50s 226ms/step - loss: 0.1510 - val_loss: 0.2966 - learning_rate: 0.0010
Epoch 2/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0287 - val_loss: 0.1766 - learning_rate: 0.0010
Epoch 3/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0269 - val_loss: 0.0661 - learning_rate: 0.0010
Epoch 4/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0264 - val_loss: 0.0279 - learning_rate: 0.0010
Epoch 5/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0258 - val_loss: 0.0254 - learning_rate: 0.0010
Epoch 6/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0256 - val_loss: 0.0253 - learning_rate: 0.0010
Epoch 7/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0251 - val_loss: 0.0248 - learning_rate: 0.0010
Epoch 8/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0251 - val_loss: 0.0251 - learning_rate: 0.0010
Epoch 9/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0247 - val_loss: 0.0243 - learning_rate: 0.0010
Epoch 10/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0246 - val_loss: 0.0246 - learning_rate: 0.0010
Epoch 11/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0245 - val_loss: 0.0247 - learning_rate: 0.0010
Epoch 12/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0241 - val_loss: 0.0243 - learning_rate: 0.0010
Epoch 13/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0244 - val_loss: 0.0245 - learning_rate: 0.0010
Epoch 14/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0241 - val_loss: 0.0241 - learning_rate: 0.0010
Epoch 15/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0243 - val_loss: 0.0241 - learning_rate: 0.0010
Epoch 16/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0242 - val_loss: 0.0242 - learning_rate: 0.0010
Epoch 17/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0240 - val_loss: 0.0240 - learning_rate: 0.0010
Epoch 18/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0240 - val_loss: 0.0243 - learning_rate: 0.0010
Epoch 19/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0240 - val_loss: 0.0244 - learning_rate: 0.0010
Epoch 20/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0237 - val_loss: 0.0238 - learning_rate: 1.0000e-04

<keras.src.callbacks.history.History at 0x7ff294f9c340>

```
</div>
---
## Frame Prediction Visualizations

@@ -266,9 +316,23 @@ for idx, ax in enumerate(axes[1]):
plt.show()
```

<div class="k-default-codeblock">
```
1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 800ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 805ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 790ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 821ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 824ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 928ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 813ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 810ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 814ms/step

```
</div>

![png](/img/examples/vision/conv_lstm/conv_lstm_13_0.png)
![png](/img/examples/vision/conv_lstm/conv_lstm_13_1.png)



@@ -279,7 +343,8 @@ Finally, we'll pick a few examples from the validation set
and construct some GIFs with them to see the model's
predicted videos.

-You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/conv-lstm) and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/conv-lstm).
+You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/conv-lstm)
+and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/conv-lstm).


```python
@@ -315,7 +380,7 @@

# Construct a GIF from the frames.
with io.BytesIO() as gif:
-imageio.mimsave(gif, current_frames, "GIF", fps=5)
+imageio.mimsave(gif, current_frames, "GIF", duration=200)
predicted_videos.append(gif.getvalue())

# Display the videos.
@@ -333,9 +398,67 @@

<div class="k-default-codeblock">
```
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 790ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step
Truth Prediction

```
</div>
HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xf8\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfd\xfd\xfd\xfc\xfc\xfc\xfb\xfb\xfb\xf4\…

-![Imgur](https://i.imgur.com/UYMTsw7.gif)
HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…

HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfd\xfd\xfd\xfc\xfc\xfc\xf9\xf9\xf9\xf7\…

```
</div>
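
The prose above points readers at the trained model on the Hugging Face Hub. As a hedged aside (not part of this PR), loading it would look roughly like the sketch below, assuming the `huggingface_hub` package and its `from_pretrained_keras` helper:

```python
from huggingface_hub import from_pretrained_keras

# Downloads and deserializes the hosted Keras model (network access required).
model = from_pretrained_keras("keras-io/conv-lstm")
model.summary()
```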