Skip to content

Commit

Permalink
Add Stable Diffusion XL example
Browse files Browse the repository at this point in the history
  • Loading branch information
gongy committed Aug 9, 2023
1 parent f5ea162 commit 9758bb4
Show file tree
Hide file tree
Showing 3 changed files with 263 additions and 0 deletions.
92 changes: 92 additions & 0 deletions 06_gpu_and_ml/stable_diffusion/frontend/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
<html>
<head>
<script
defer
src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js"
></script>
<script src="https://cdn.tailwindcss.com"></script>

<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Modal - SD XL 1.0</title>
</head>
<body x-data="state()">
<div class="max-w-3xl mx-auto pt-4 pb-8 px-10 sm:py-12 sm:px-6 lg:px-8">
<h2 class="text-3xl font-medium text-center mb-10">
Stable Diffusion XL 1.0 on Modal
</h2>

<form @submit.prevent="submitPrompt" class="flex items-center justify-center gap-x-4 gap-y-2 w-full mx-auto mb-4">
<input
x-data
x-model="prompt"
x-init="$nextTick(() => { $el.focus(); });"
type="text"
class="flex w-full px-3 py-3 text-md bg-white border rounded-md border-neutral-300 ring-offset-background placeholder:text-neutral-500 focus:border-neutral-300 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-neutral-400 disabled:cursor-not-allowed disabled:opacity-50 text-center"
/>
<button
type="submit"
class="inline-flex items-center justify-center px-4 py-3 text-sm font-medium tracking-wide text-white transition-colors duration-200 rounded-md bg-neutral-950 hover:bg-neutral-900 focus:ring-2 focus:ring-offset-2 focus:ring-neutral-900 focus:shadow-outline focus:outline-none"
:disabled="loading"
>
<span x-show="!loading">Submit</span>
<div class="animate-spin w-6 h-6 mx-3" x-show="loading">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
class="lucide lucide-loader-2"
>
<path d="M21 12a9 9 0 1 1-6.219-8.56" />
</svg>
</div>
</button>
</form>

<div class="mx-auto w-full max-w-[768px] relative grid">
<div style="padding-top: 100%;" x-show="loading" class="absolute w-full h-full animate-pulse bg-neutral-100 rounded-md"></div>
<img
x-show="imageURL"
class="rounded-md self-center justify-self-center"
:src="imageURL"
/>
</div>
</div>

<script>
function state() {
return {
prompt: "a beautiful Japanese temple, butterflies flying around",
features: [],
submitted: "",
submittedFeatures: [],
loading: false,
imageURL: '',
async submitPrompt() {
if (!this.prompt) return;
this.submitted = this.prompt;
this.submittedFeatures = [...this.features];
this.loading = true;

const queryString = new URLSearchParams(this.features.map((f) => ['features', f])).toString();
const res = await fetch(`/infer/${this.submitted}?${queryString}`);

console.log("res returned ... but so slow?");
const blob = await res.blob();
console.log("got blob")
this.imageURL = URL.createObjectURL(blob);
this.loading = false;
console.log(this.imageURL);
},
toggleFeature(featureName) {
let index = this.features.indexOf(featureName);
index == -1 ? this.features.push(featureName) : this.features.splice(index, 1);
}
}
}
</script>
</body>
</html>
2 changes: 2 additions & 0 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# that makes it run faster on Modal. The example takes about 10s to cold start
# and about 1.0s per image generated.
#
# To use the new XL 1.0 model, see the example posted [here](/docs/guide/ex/stable_diffusion_xl).
#
# For instance, here are 9 images produced by the prompt
# `An 1600s oil painting of the New York City skyline`
#
Expand Down
169 changes: 169 additions & 0 deletions 06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# # Stable Diffusion XL 1.0
#
# This example is similar to the [Stable Diffusion CLI](/docs/guide/ex/stable_diffusion_cli)
# example, but it generates images from the larger XL 1.0 model. Specifically, it runs the
# first set of steps with the base model, followed by the refiner model.
#
# [Try out the live demo here!](https://modal-labs--stable-diffusion-xl-app.modal.run/) The first
# generation may include a cold-start, which takes around 20 seconds. The inference speed depends on the GPU
# and step count (for reference, an A100 runs 40 steps in 8 seconds).

# ## Basic setup

from pathlib import Path
from modal import Stub, Mount, Image, gpu, method, asgi_app

# ## Define a container image
#
# To take advantage of Modal's blazing fast cold-start times, we'll need to download our model weights
# inside our container image with a download function. We ignore binaries, ONNX weights and 32-bit weights.
#
# Tip: avoid using global variables in this function to ensure the download step detects model changes and
# triggers a rebuild.


def download_models():
from huggingface_hub import snapshot_download

ignore = ["*.bin", "*.onnx_data", "*/diffusion_pytorch_model.safetensors"]
snapshot_download(
"stabilityai/stable-diffusion-xl-base-1.0", ignore_patterns=ignore
)
snapshot_download(
"stabilityai/stable-diffusion-xl-refiner-1.0", ignore_patterns=ignore
)


image = (
Image.debian_slim()
.apt_install(
"libglib2.0-0", "libsm6", "libxrender1", "libxext6", "ffmpeg", "libgl1"
)
.pip_install(
"diffusers~=0.19",
"invisible_watermark~=0.1",
"transformers~=4.31",
"accelerate~=0.21",
"safetensors~=0.3",
)
.run_function(download_models)
)

stub = Stub("stable-diffusion-xl", image=image)

# ## Load model and run inference
#
# The container lifecycle [`__enter__` function](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta)
# loads the model at startup. Then, we evaluate it in the `run_inference` function.
#
# To avoid excessive cold-starts, we set the idle timeout to 240 seconds, meaning once a GPU has loaded the model it will stay
# online for 4 minutes before spinning down. This can be adjusted for cost/experience trade-offs.


@stub.cls(gpu=gpu.A10G(), container_idle_timeout=240)
class Model:
def __enter__(self):
from diffusers import DiffusionPipeline
import torch

load_options = dict(
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
device_map="auto",
)

# Load base model
self.base = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", **load_options
)

# Load refiner model
self.refiner = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
text_encoder_2=self.base.text_encoder_2,
vae=self.base.vae,
**load_options,
)

# These suggested compile commands actually increase inference time, but may be mis-used.
# self.base.unet = torch.compile(self.base.unet, mode="reduce-overhead", fullgraph=True)
# self.refiner.unet = torch.compile(self.refiner.unet, mode="reduce-overhead", fullgraph=True)

@method()
def inference(self, prompt, n_steps=24, high_noise_frac=0.8):
negative_prompt = "disfigured, ugly, deformed"
image = self.base(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=n_steps,
denoising_end=high_noise_frac,
output_type="latent",
).images
image = self.refiner(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=n_steps,
denoising_start=high_noise_frac,
image=image,
).images[0]

import io

image_bytes = io.BytesIO()
image.save(image_bytes, format="PNG")
image_bytes = image_bytes.getvalue()

return image_bytes


# And this is our entrypoint; where the CLI is invoked. Explore CLI options
# with: `modal run stable_diffusion_xl.py --prompt 'An astronaut riding a green horse'`


@stub.local_entrypoint()
def main(prompt: str):
image_bytes = Model().inference.call(prompt)

output_path = f"output.png"
print(f"Saving it to {output_path}")
with open(output_path, "wb") as f:
f.write(image_bytes)


# ## A user interface
#
# Here we ship a simple web application that exposes a front-end (written in Alpine.js) for
# our backend deployment.
#
# The Model class will serve multiple users from a its own shared pool of warm GPU containers automatically.
#
# We can deploy this with `modal deploy stable_diffusino_xl.py`.

frontend_path = Path(__file__).parent / "frontend"


@stub.function(
mounts=[Mount.from_local_dir(frontend_path, remote_path="/assets")],
allow_concurrent_inputs=20,
)
@asgi_app()
def app():
from fastapi import FastAPI
import fastapi.staticfiles

web_app = FastAPI()

@web_app.get("/infer/{prompt}")
async def infer(prompt: str):
from fastapi.responses import Response

image_bytes = Model().inference.call(prompt)

return Response(image_bytes, media_type="image/png")

web_app.mount(
"/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True)
)

return web_app

0 comments on commit 9758bb4

Please sign in to comment.