Skip to content

Commit

Permalink
[WEB][SD] Make unet tuned model default for rdna3 devices (#642)
Browse files Browse the repository at this point in the history
  • Loading branch information
Shukla-Gaurav authored Dec 15, 2022
1 parent 2928179 commit e7e7635
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
4 changes: 4 additions & 0 deletions web/models/stable_diffusion/cache_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
from models.stable_diffusion.utils import set_iree_runtime_flags
from models.stable_diffusion.stable_args import args
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag


model_config = {
Expand Down Expand Up @@ -39,6 +40,9 @@
subfolder="scheduler",
)

# set use_tuned
if "rdna3" not in get_vulkan_triple_flag():
args.use_tuned = False

# set iree-runtime flags
set_iree_runtime_flags(args)
Expand Down
5 changes: 4 additions & 1 deletion web/models/stable_diffusion/opt_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ def get_unet(args):
if args.precision == "fp16":
if args.use_tuned:
bucket = "gs://shark_tank/vivian"
model_name = "unet_1dec_fp16_tuned"
if args.version == "v1.4":
model_name = "unet_1dec_fp16_tuned"
if args.version == "v2.1base":
model_name = "unet2base_8dec_fp16_tuned"
return get_shark_model(args, bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
Expand Down
2 changes: 1 addition & 1 deletion web/models/stable_diffusion/stable_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@

p.add_argument(
"--use_tuned",
default=False,
default=True,
action=argparse.BooleanOptionalAction,
help="Download and use the tuned version of the model if available",
)
Expand Down

0 comments on commit e7e7635

Please sign in to comment.