diff --git a/06_gpu_and_ml/stable_diffusion/frontend/index.html b/06_gpu_and_ml/stable_diffusion/frontend/index.html
new file mode 100644
index 000000000..39f99aece
--- /dev/null
+++ b/06_gpu_and_ml/stable_diffusion/frontend/index.html
@@ -0,0 +1,92 @@
+<!-- 92-line Alpine.js frontend; page title: "Modal - SD XL 1.0"; heading: "Stable Diffusion XL 1.0 on Modal" -->
diff --git a/06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py b/06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py
index c8c441dd6..d182ec91a 100644
--- a/06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py
+++ b/06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py
@@ -9,6 +9,8 @@
 # that makes it run faster on Modal. The example takes about 10s to cold start
 # and about 1.0s per image generated.
 #
+# To use the new XL 1.0 model, see the example posted [here](/docs/guide/ex/stable_diffusion_xl).
+#
 # For instance, here are 9 images produced by the prompt
 # `An 1600s oil painting of the New York City skyline`
 #
diff --git a/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py b/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py
new file mode 100644
index 000000000..d562e712b
--- /dev/null
+++ b/06_gpu_and_ml/stable_diffusion/stable_diffusion_xl.py
@@ -0,0 +1,169 @@
+# # Stable Diffusion XL 1.0
+#
+# This example is similar to the [Stable Diffusion CLI](/docs/guide/ex/stable_diffusion_cli)
+# example, but it generates images from the larger XL 1.0 model. Specifically, it runs the
+# first set of steps with the base model, followed by the refiner model.
+#
+# [Try out the live demo here!](https://modal-labs--stable-diffusion-xl-app.modal.run/) The first
+# generation may include a cold start, which takes around 20 seconds. The inference speed depends on the GPU
+# and step count (for reference, an A100 runs 40 steps in 8 seconds).
+
+# ## Basic setup
+
+from pathlib import Path
+from modal import Stub, Mount, Image, gpu, method, asgi_app
+
+# ## Define a container image
+#
+# To take advantage of Modal's blazing fast cold-start times, we'll need to download our model weights
+# into our container image with a download function. We ignore binaries, ONNX weights and 32-bit weights.
+#
+# Tip: avoid using global variables in this function to ensure the download step detects model changes and
+# triggers a rebuild.
+
+
+def download_models():
+    from huggingface_hub import snapshot_download
+
+    ignore = ["*.bin", "*.onnx_data", "*/diffusion_pytorch_model.safetensors"]
+    snapshot_download(
+        "stabilityai/stable-diffusion-xl-base-1.0", ignore_patterns=ignore
+    )
+    snapshot_download(
+        "stabilityai/stable-diffusion-xl-refiner-1.0", ignore_patterns=ignore
+    )
+
+
+image = (
+    Image.debian_slim()
+    .apt_install(
+        "libglib2.0-0", "libsm6", "libxrender1", "libxext6", "ffmpeg", "libgl1"
+    )
+    .pip_install(
+        "diffusers~=0.19",
+        "invisible_watermark~=0.1",
+        "transformers~=4.31",
+        "accelerate~=0.21",
+        "safetensors~=0.3",
+    )
+    .run_function(download_models)
+)
+
+stub = Stub("stable-diffusion-xl", image=image)
+
+# ## Load model and run inference
+#
+# The container lifecycle [`__enter__` function](https://modal.com/docs/guide/lifecycle-functions#container-lifecycle-beta)
+# loads the model at startup. Then, we evaluate it in the `inference` function.
+#
+# To avoid excessive cold starts, we set the idle timeout to 240 seconds, meaning once a GPU has loaded the model
+# it will stay online for 4 minutes before spinning down. This can be adjusted for cost/experience trade-offs.
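+#
+# To make the base/refiner split concrete, here is a rough sketch using the default parameters of the
+# `inference` method defined below (the exact step boundary depends on the scheduler, so treat the
+# numbers as approximate):
+#
+# ```python
+# n_steps = 24           # total denoising steps (default below)
+# high_noise_frac = 0.8  # fraction of the noise schedule handled by the base model
+#
+# base_steps = int(n_steps * high_noise_frac)  # ~19 steps run by the base model (denoising_end=0.8)
+# refiner_steps = n_steps - base_steps         # ~5 steps run by the refiner (denoising_start=0.8)
+# ```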
+
+
+@stub.cls(gpu=gpu.A10G(), container_idle_timeout=240)
+class Model:
+    def __enter__(self):
+        from diffusers import DiffusionPipeline
+        import torch
+
+        load_options = dict(
+            torch_dtype=torch.float16,
+            use_safetensors=True,
+            variant="fp16",
+            device_map="auto",
+        )
+
+        # Load base model
+        self.base = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", **load_options
+        )
+
+        # Load refiner model
+        self.refiner = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-refiner-1.0",
+            text_encoder_2=self.base.text_encoder_2,
+            vae=self.base.vae,
+            **load_options,
+        )
+
+        # These suggested compile commands actually increased inference time in our testing
+        # (possibly because we are misusing them), so they are left commented out.
+        # self.base.unet = torch.compile(self.base.unet, mode="reduce-overhead", fullgraph=True)
+        # self.refiner.unet = torch.compile(self.refiner.unet, mode="reduce-overhead", fullgraph=True)
+
+    @method()
+    def inference(self, prompt, n_steps=24, high_noise_frac=0.8):
+        negative_prompt = "disfigured, ugly, deformed"
+        image = self.base(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            num_inference_steps=n_steps,
+            denoising_end=high_noise_frac,
+            output_type="latent",
+        ).images
+        image = self.refiner(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            num_inference_steps=n_steps,
+            denoising_start=high_noise_frac,
+            image=image,
+        ).images[0]
+
+        import io
+
+        image_bytes = io.BytesIO()
+        image.save(image_bytes, format="PNG")
+
+        return image_bytes.getvalue()
+
+
+# This is our command-line entrypoint. Explore CLI options with:
+# `modal run stable_diffusion_xl.py --prompt 'An astronaut riding a green horse'`
+
+
+@stub.local_entrypoint()
+def main(prompt: str):
+    image_bytes = Model().inference.call(prompt)
+
+    output_path = "output.png"
+    print(f"Saving it to {output_path}")
+    with open(output_path, "wb") as f:
+        f.write(image_bytes)
+
+
+# ## A user interface
+#
+# Here we ship a simple web application that exposes a front-end (written in Alpine.js) for
+# our backend deployment.
+#
+# The `Model` class will automatically serve multiple users from its own shared pool of warm GPU containers.
+#
+# We can deploy this with `modal deploy stable_diffusion_xl.py`.
+
+frontend_path = Path(__file__).parent / "frontend"
+
+
+@stub.function(
+    mounts=[Mount.from_local_dir(frontend_path, remote_path="/assets")],
+    allow_concurrent_inputs=20,
+)
+@asgi_app()
+def app():
+    from fastapi import FastAPI
+    import fastapi.staticfiles
+
+    web_app = FastAPI()
+
+    @web_app.get("/infer/{prompt}")
+    async def infer(prompt: str):
+        from fastapi.responses import Response
+
+        image_bytes = Model().inference.call(prompt)
+
+        return Response(image_bytes, media_type="image/png")
+
+    web_app.mount(
+        "/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True)
+    )
+
+    return web_app
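+
+
+# Once deployed, the `/infer/{prompt}` endpoint can be called from any HTTP client. Here is a minimal
+# sketch (the URL below is the public demo linked at the top; your own deployment will have its own
+# `*.modal.run` URL):
+#
+# ```python
+# import urllib.parse
+# import urllib.request
+#
+# base_url = "https://modal-labs--stable-diffusion-xl-app.modal.run/infer/"
+# prompt = urllib.parse.quote("An astronaut riding a green horse")
+#
+# # The endpoint responds with raw PNG bytes (media type image/png).
+# with urllib.request.urlopen(base_url + prompt) as response:
+#     with open("output.png", "wb") as f:
+#         f.write(response.read())
+# ```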