Skip to content

Commit

Permalink
[WEB] Update models to 8dec and also default values (#620)
Browse files Browse the repository at this point in the history
1. Update the models to 8 dec.
2. precision is default to `fp16` in CLI.
3. version is default to `v2.1base` in CLI as well as web.
4. The default scheduler is set to `EulerDiscrete` now.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
  • Loading branch information
Shukla-Gaurav authored Dec 13, 2022
1 parent 08e373a commit d913453
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
p.add_argument(
"--version",
type=str,
default="v1.4",
default="v2.1base",
help="Specify version of stable diffusion model",
)

Expand All @@ -48,7 +48,7 @@
)

p.add_argument(
"--precision", type=str, default="fp32", help="precision to run the model."
"--precision", type=str, default="fp16", help="precision to run the model."
)

p.add_argument(
Expand Down
6 changes: 3 additions & 3 deletions web/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,13 @@ def resource_path(relative_path):
)
version = gr.Radio(
label="Version",
value="v1.4",
value="v2.1base",
choices=["v1.4", "v2.1base"],
)
with gr.Row():
scheduler_key = gr.Dropdown(
label="Scheduler",
value="DPMSolverMultistep",
value="EulerDiscrete",
choices=[
"DDIM",
"PNDM",
Expand Down Expand Up @@ -174,9 +174,9 @@ def resource_path(relative_path):
outputs=[generated_img, std_output],
)

shark_web.queue()
shark_web.launch(
share=False,
server_name="0.0.0.0",
server_port=8080,
enable_queue=True,
)
11 changes: 11 additions & 0 deletions web/models/stable_diffusion/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

model_config = {
"v2": "stabilityai/stable-diffusion-2",
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
"v1.4": "CompVis/stable-diffusion-v1-4",
}

Expand All @@ -19,6 +20,16 @@
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v2.1base": {
"clip": (torch.randint(1, 2, (1, 77)),),
"vae": (torch.randn(1, 4, 64, 64),),
"unet": (
torch.randn(1, 4, 64, 64), # latents
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, 77, 1024), # embedding
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v1.4": {
"clip": (torch.randint(1, 2, (1, 77)),),
"vae": (torch.randn(1, 4, 64, 64),),
Expand Down
6 changes: 3 additions & 3 deletions web/models/stable_diffusion/opt_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_unet(args):
return get_shark_model(args, bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "unet_1dec_fp16"
model_name = "unet_8dec_fp16"
if args.version == "v2.1base":
model_name = "unet2base_8dec_fp16"
iree_flags += [
Expand Down Expand Up @@ -56,7 +56,7 @@ def get_vae(args):
)
if args.precision == "fp16":
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_1dec_fp16"
model_name = "vae_8dec_fp16"
if args.version == "v2.1base":
model_name = "vae2base_8dec_fp16"
iree_flags += [
Expand Down Expand Up @@ -119,7 +119,7 @@ def get_clip(args):
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
bucket = "gs://shark_tank/stable_diffusion"
model_name = "clip_1dec_fp32"
model_name = "clip_8dec_fp32"
if args.version == "v2.1base":
model_name = "clip2base_8dec_fp32"
iree_flags += [
Expand Down
6 changes: 3 additions & 3 deletions web/models/stable_diffusion/stable_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
p.add_argument(
"--version",
type=str,
default="v1.4",
default="v2.1base",
help="Specify version of stable diffusion model",
)

Expand Down Expand Up @@ -60,8 +60,8 @@
p.add_argument(
"--scheduler",
type=str,
default="DPMSolverMultistep",
help="can be [PNDM, LMSDiscrete, DDIM, DPMSolverMultistep]",
default="EulerDiscrete",
help="can be [PNDM, LMSDiscrete, DDIM, DPMSolverMultistep, EulerDiscrete]",
)

p.add_argument(
Expand Down

0 comments on commit d913453

Please sign in to comment.