Skip to content

Commit

Permalink
Port stable diffision guide to Keras 3
Browse files Browse the repository at this point in the history
  • Loading branch information
tirthasheshpatel committed Nov 28, 2023
1 parent 9511acb commit 0dc793d
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions guides/keras_cv/generate_images_with_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,17 @@
"""

"""shell
pip install --upgrade keras-cv
pip install -Uq keras-cv >> /dev/null
pip install -Uq keras >> /dev/null # Upgrade to Keras 3.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"

import time
import keras_cv
from tensorflow import keras
import keras
import matplotlib.pyplot as plt

"""
Expand Down Expand Up @@ -275,7 +280,7 @@ def plot_images(images):
"""
### XLA Compilation
TensorFlow comes with the
TensorFlow and JAX come with the
[XLA: Accelerated Linear Algebra](https://www.tensorflow.org/xla) compiler built-in.
`keras_cv.models.StableDiffusion` supports a `jit_compile` argument out of the box.
Setting this argument to `True` enables XLA compilation, resulting in a significant
Expand Down

0 comments on commit 0dc793d

Please sign in to comment.