Skip to content

Commit

Permalink
[SD][WEB] Add vae tuned model in the SD web (#653)
Browse files Browse the repository at this point in the history
1. Add tuned vae model in the SD web.
2. Use tuned models in case of rdna3 cards.

Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
  • Loading branch information
Shukla-Gaurav authored Dec 16, 2022
1 parent 72976a2 commit 10160a0
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
4 changes: 4 additions & 0 deletions web/models/stable_diffusion/cache_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
subfolder="scheduler",
)

# use tuned unet model in case of rdna3 cards.
if "rdna3" in get_vulkan_triple_flag():
args.use_tuned = True

# set iree-runtime flags
set_iree_runtime_flags()

Expand Down
40 changes: 26 additions & 14 deletions web/models/stable_diffusion/opt_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,32 @@ def get_vae():
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if args.precision in ["fp16", "int8"]:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_8dec_fp16"
if args.version == "v2.1base":
model_name = "vae2base_8dec_fp16"
if args.version == "v2.1":
model_name = "vae2_14dec_fp16"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
if args.import_mlir:
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)
if args.use_tuned:
bucket = "gs://shark_tank/vivian"
if args.version == "v2.1base":
model_name = "vae2base_8dec_fp16_tuned"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
"--iree-flow-enable-conv-winograd-transform",
]
return get_shark_model(bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_8dec_fp16"
if args.version == "v2.1base":
model_name = "vae2base_8dec_fp16"
if args.version == "v2.1":
model_name = "vae2_14dec_fp16"
iree_flags += [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
]
if args.import_mlir:
return get_vae_mlir(model_name, iree_flags)
return get_shark_model(bucket, model_name, iree_flags)

if args.precision == "fp32":
bucket = "gs://shark_tank/stable_diffusion"
Expand Down

0 comments on commit 10160a0

Please sign in to comment.