-
Notifications
You must be signed in to change notification settings - Fork 176
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
263 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
<html> | ||
<head> | ||
<script | ||
defer | ||
src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js" | ||
></script> | ||
<script src="https://cdn.tailwindcss.com"></script> | ||
|
||
<meta name="viewport" content="width=device-width, initial-scale=1" /> | ||
<title>Modal - SD XL 1.0</title> | ||
</head> | ||
<body x-data="state()"> | ||
<div class="max-w-3xl mx-auto pt-4 pb-8 px-10 sm:py-12 sm:px-6 lg:px-8"> | ||
<h2 class="text-3xl font-medium text-center mb-10"> | ||
Stable Diffusion XL 1.0 on Modal | ||
</h2> | ||
|
||
<form @submit.prevent="submitPrompt" class="flex items-center justify-center gap-x-4 gap-y-2 w-full mx-auto mb-4"> | ||
<input | ||
x-data | ||
x-model="prompt" | ||
x-init="$nextTick(() => { $el.focus(); });" | ||
type="text" | ||
class="flex w-full px-3 py-3 text-md bg-white border rounded-md border-neutral-300 ring-offset-background placeholder:text-neutral-500 focus:border-neutral-300 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-neutral-400 disabled:cursor-not-allowed disabled:opacity-50 text-center" | ||
/> | ||
<button | ||
type="submit" | ||
class="inline-flex items-center justify-center px-4 py-3 text-sm font-medium tracking-wide text-white transition-colors duration-200 rounded-md bg-neutral-950 hover:bg-neutral-900 focus:ring-2 focus:ring-offset-2 focus:ring-neutral-900 focus:shadow-outline focus:outline-none" | ||
:disabled="loading" | ||
> | ||
<span x-show="!loading">Submit</span> | ||
<div class="animate-spin w-6 h-6 mx-3" x-show="loading"> | ||
<svg | ||
xmlns="http://www.w3.org/2000/svg" | ||
viewBox="0 0 24 24" | ||
fill="none" | ||
stroke="currentColor" | ||
stroke-width="2" | ||
stroke-linecap="round" | ||
stroke-linejoin="round" | ||
class="lucide lucide-loader-2" | ||
> | ||
<path d="M21 12a9 9 0 1 1-6.219-8.56" /> | ||
</svg> | ||
</div> | ||
</button> | ||
</form> | ||
|
||
<div class="mx-auto w-full max-w-[768px] relative grid"> | ||
<div style="padding-top: 100%;" x-show="loading" class="absolute w-full h-full animate-pulse bg-neutral-100 rounded-md"></div> | ||
<img | ||
x-show="imageURL" | ||
class="rounded-md self-center justify-self-center" | ||
:src="imageURL" | ||
/> | ||
</div> | ||
</div> | ||
|
||
<script> | ||
function state() { | ||
return { | ||
prompt: "a beautiful Japanese temple, butterflies flying around", | ||
features: [], | ||
submitted: "", | ||
submittedFeatures: [], | ||
loading: false, | ||
imageURL: '', | ||
async submitPrompt() { | ||
if (!this.prompt) return; | ||
this.submitted = this.prompt; | ||
this.submittedFeatures = [...this.features]; | ||
this.loading = true; | ||
|
||
const queryString = new URLSearchParams(this.features.map((f) => ['features', f])).toString(); | ||
const res = await fetch(`/infer/${this.submitted}?${queryString}`); | ||
|
||
console.log("res returned ... but so slow?"); | ||
const blob = await res.blob(); | ||
console.log("got blob") | ||
this.imageURL = URL.createObjectURL(blob); | ||
this.loading = false; | ||
console.log(this.imageURL); | ||
}, | ||
toggleFeature(featureName) { | ||
let index = this.features.indexOf(featureName); | ||
index == -1 ? this.features.push(featureName) : this.features.splice(index, 1); | ||
} | ||
} | ||
} | ||
</script> | ||
</body> | ||
</html> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# # Stable Diffusion XL 1.0 | ||
# | ||
# This example is similar to the [Stable Diffusion CLI](/docs/guide/ex/stable_diffusion_cli) | ||
# example, but it generates images from the larger XL 1.0 model. Specifically, it runs the | ||
# first set of steps with the base model, followed by the refiner model. | ||
# | ||
# [Try out the live demo here!](https://modal-labs--stable-diffusion-xl-app.modal.run/) The first | ||
# generation may include a cold-start, which takes around 20 seconds. The inference speed depends on the GPU | ||
# and step count (for reference, an A100 runs 40 steps in 8 seconds). | ||
|
||
# ## Basic setup | ||
|
||
from pathlib import Path | ||
from modal import Stub, Mount, Image, gpu, method, asgi_app | ||
|
||
# ## Define a container image | ||
# | ||
# To take advantage of Modal's blazing fast cold-start times, we'll need to download our model weights | ||
# inside our container image with a download function. We ignore binaries, ONNX weights and 32-bit weights. | ||
# | ||
# Tip: avoid using global variables in this function to ensure the download step detects model changes and | ||
# triggers a rebuild. | ||
|
||
|
||
def download_models(): | ||
from huggingface_hub import snapshot_download | ||
|
||
ignore = ["*.bin", "*.onnx_data", "*/diffusion_pytorch_model.safetensors"] | ||
snapshot_download( | ||
"stabilityai/stable-diffusion-xl-base-1.0", ignore_patterns=ignore | ||
) | ||
snapshot_download( | ||
"stabilityai/stable-diffusion-xl-refiner-1.0", ignore_patterns=ignore | ||
) | ||
|
||
|
||
image = ( | ||
Image.debian_slim() | ||
.apt_install( | ||
"libglib2.0-0", "libsm6", "libxrender1", "libxext6", "ffmpeg", "libgl1" | ||
) | ||
.pip_install( | ||
"diffusers~=0.19", | ||
"invisible_watermark~=0.1", | ||
"transformers~=4.31", | ||
"accelerate~=0.21", | ||
"safetensors~=0.3", | ||
) | ||
.run_function(download_models) | ||
) | ||
|
||
stub = Stub("stable-diffusion-xl", image=image) | ||
|
||
# ## Load model and run inference | ||
# | ||
# The container lifecycle [`__enter__` function](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta) | ||
# loads the model at startup. Then, we evaluate it in the `run_inference` function. | ||
# | ||
# To avoid excessive cold-starts, we set the idle timeout to 240 seconds, meaning once a GPU has loaded the model it will stay | ||
# online for 4 minutes before spinning down. This can be adjusted for cost/experience trade-offs. | ||
|
||
|
||
@stub.cls(gpu=gpu.A10G(), container_idle_timeout=240) | ||
class Model: | ||
def __enter__(self): | ||
from diffusers import DiffusionPipeline | ||
import torch | ||
|
||
load_options = dict( | ||
torch_dtype=torch.float16, | ||
use_safetensors=True, | ||
variant="fp16", | ||
device_map="auto", | ||
) | ||
|
||
# Load base model | ||
self.base = DiffusionPipeline.from_pretrained( | ||
"stabilityai/stable-diffusion-xl-base-1.0", **load_options | ||
) | ||
|
||
# Load refiner model | ||
self.refiner = DiffusionPipeline.from_pretrained( | ||
"stabilityai/stable-diffusion-xl-refiner-1.0", | ||
text_encoder_2=self.base.text_encoder_2, | ||
vae=self.base.vae, | ||
**load_options, | ||
) | ||
|
||
# These suggested compile commands actually increase inference time, but may be mis-used. | ||
# self.base.unet = torch.compile(self.base.unet, mode="reduce-overhead", fullgraph=True) | ||
# self.refiner.unet = torch.compile(self.refiner.unet, mode="reduce-overhead", fullgraph=True) | ||
|
||
@method() | ||
def inference(self, prompt, n_steps=24, high_noise_frac=0.8): | ||
negative_prompt = "disfigured, ugly, deformed" | ||
image = self.base( | ||
prompt=prompt, | ||
negative_prompt=negative_prompt, | ||
num_inference_steps=n_steps, | ||
denoising_end=high_noise_frac, | ||
output_type="latent", | ||
).images | ||
image = self.refiner( | ||
prompt=prompt, | ||
negative_prompt=negative_prompt, | ||
num_inference_steps=n_steps, | ||
denoising_start=high_noise_frac, | ||
image=image, | ||
).images[0] | ||
|
||
import io | ||
|
||
image_bytes = io.BytesIO() | ||
image.save(image_bytes, format="PNG") | ||
image_bytes = image_bytes.getvalue() | ||
|
||
return image_bytes | ||
|
||
|
||
# And this is our entrypoint; where the CLI is invoked. Explore CLI options | ||
# with: `modal run stable_diffusion_xl.py --prompt 'An astronaut riding a green horse'` | ||
|
||
|
||
@stub.local_entrypoint() | ||
def main(prompt: str): | ||
image_bytes = Model().inference.call(prompt) | ||
|
||
output_path = f"output.png" | ||
print(f"Saving it to {output_path}") | ||
with open(output_path, "wb") as f: | ||
f.write(image_bytes) | ||
|
||
|
||
# ## A user interface | ||
# | ||
# Here we ship a simple web application that exposes a front-end (written in Alpine.js) for | ||
# our backend deployment. | ||
# | ||
# The Model class will serve multiple users from a its own shared pool of warm GPU containers automatically. | ||
# | ||
# We can deploy this with `modal deploy stable_diffusino_xl.py`. | ||
|
||
frontend_path = Path(__file__).parent / "frontend" | ||
|
||
|
||
@stub.function( | ||
mounts=[Mount.from_local_dir(frontend_path, remote_path="/assets")], | ||
allow_concurrent_inputs=20, | ||
) | ||
@asgi_app() | ||
def app(): | ||
from fastapi import FastAPI | ||
import fastapi.staticfiles | ||
|
||
web_app = FastAPI() | ||
|
||
@web_app.get("/infer/{prompt}") | ||
async def infer(prompt: str): | ||
from fastapi.responses import Response | ||
|
||
image_bytes = Model().inference.call(prompt) | ||
|
||
return Response(image_bytes, media_type="image/png") | ||
|
||
web_app.mount( | ||
"/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True) | ||
) | ||
|
||
return web_app |