Skip to content

Commit

Permalink
Add script to auto annotate SD models and variants (#751)
Browse files Browse the repository at this point in the history
* Add script to auto annotate SD models and variants

* Add model config files

* Add script to auto annotate SD models and variants

* Add model config files

* Move config files to shark_tank
  • Loading branch information
yzhang93 authored Jan 4, 2023
1 parent 017dcab commit 782b449
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 8 deletions.
108 changes: 108 additions & 0 deletions shark/examples/shark_inference/stable_diffusion/sd_annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import os
from shark.model_annotation import model_annotation, create_context
from shark.iree_utils._common import run_cmd, iree_target_map
from shark.shark_downloader import (
download_model,
download_public_file,
WORKDIR,
)
from shark.parser import shark_args
from stable_args import args
from opt_params import get_params
from utils import set_init_device_flags


# Downloads the model (Unet or VAE fp16) from shark_tank
set_init_device_flags()
shark_args.local_tank_cache = args.local_tank_cache
# Bucket lookup key for the untuned weights of this model family.
bucket_key = f"{args.variant}/untuned"
use_winograd = False
if args.annotation_model == "unet":
    # Winograd conv annotation for the Unet is only enabled on v2.1-base.
    if args.version == "v2_1base":
        use_winograd = True
    model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/untuned"
elif args.annotation_model == "vae":
    # The VAE path always applies the Winograd annotation.
    use_winograd = True
    # VAE artifacts are stored at a fixed length of 77; "/base" selects the
    # base-VAE variant of the weights.
    is_base = "/base" if args.use_base_vae else ""
    model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/untuned{is_base}"
# NOTE(review): model_key stays unbound for any other --annotation_model value
# and get_params below would raise NameError — confirm argparse limits choices.

# Resolve the tank bucket URL, canonical model name and compile flags, then
# fetch the untuned MLIR module (torch frontend) from shark_tank.
bucket, model_name, iree_flags = get_params(
    bucket_key, model_key, args.annotation_model, "untuned", args.precision
)
mlir_model, func_name, inputs, golden_out = download_model(
    model_name,
    tank_url=bucket,
    frontend="torch",
)

# Downloads the tuned config files from shark_tank
config_bucket = "gs://shark_tank/sd_tuned/configs/"
if use_winograd:
    # Winograd config: JSON with "c,f" channel pairs of convs eligible for
    # the Winograd transform (consumed by model_annotation via winograd=True).
    config_name = f"{args.annotation_model}_winograd.json"
    full_gs_url = config_bucket + config_name
    # NOTE(review): WORKDIR is concatenated without a separator — presumably it
    # already ends with one; verify against shark_downloader.
    winograd_config_dir = f"{WORKDIR}configs/" + config_name
    download_public_file(full_gs_url, winograd_config_dir, True)

if args.annotation_model == "unet":
    # NOTE(review): these variants appear to ship tuned configs only for
    # length 77, so max_length is pinned here — confirm.
    if args.variant in ["anythingv3", "analogdiffusion"]:
        args.max_length = 77
    config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}.json"
    full_gs_url = config_bucket + config_name
    lowering_config_dir = f"{WORKDIR}configs/" + config_name
    download_public_file(full_gs_url, lowering_config_dir, True)

# Annotate the model with Winograd attribute on selected conv ops
if use_winograd:
    with create_context() as ctx:
        winograd_model = model_annotation(
            ctx,
            input_contents=mlir_model,
            config_path=winograd_config_dir,
            search_op="conv",
            winograd=use_winograd,
        )
    # Persist the Winograd-annotated module; for the Unet it is re-read by the
    # iree-compile invocation below, for the VAE it is the final output.
    with open(
        f"{args.annotation_output}/{model_name}_tuned_torch.mlir", "w"
    ) as f:
        f.write(str(winograd_model))

# For Unet annotate the model with tuned lowering configs
if args.annotation_model == "unet":
    if use_winograd:
        input_mlir = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
        # Dump right after the Winograd conversion so the transformed convs
        # are present in the snapshot that gets re-annotated.
        dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
    else:
        input_mlir = f"{WORKDIR}/{model_name}_torch/{model_name}_torch.mlir"
        dump_after = "iree-flow-pad-linalg-ops"

    # Dump IR after padding/img2col/winograd passes
    # iree-compile stops at the flow stage; --mlir-print-ir-after writes the
    # post-pass IR to stderr, which is redirected into dump_after_winograd.mlir.
    run_cmd(
        f"iree-compile {input_mlir} "
        "--iree-input-type=tm_tensor "
        f"--iree-hal-target-backends={iree_target_map(args.device)} "
        f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
        "--iree-stream-resource-index-bits=64 "
        "--iree-vm-target-index-bits=64 "
        "--iree-flow-enable-padding-linalg-ops "
        "--iree-flow-linalg-ops-padding-size=32 "
        "--iree-flow-enable-conv-img2col-transform "
        f"--mlir-print-ir-after={dump_after} "
        "--compile-to=flow "
        f"2>{args.annotation_output}/dump_after_winograd.mlir "
    )

    # Annotate the model with lowering configs in the config file
    with create_context() as ctx:
        tuned_model = model_annotation(
            ctx,
            input_contents=f"{args.annotation_output}/dump_after_winograd.mlir",
            config_path=lowering_config_dir,
            search_op="all",
        )

    # Remove the intermediate mlir and save the final annotated model
    os.remove(f"{args.annotation_output}/dump_after_winograd.mlir")
    output_path = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
    with open(output_path, "w") as f:
        f.write(str(tuned_model))
    print(f"Saved the annotated mlir in {output_path}.")
24 changes: 24 additions & 0 deletions shark/examples/shark_inference/stable_diffusion/stable_args.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import argparse
from pathlib import Path


def path_expand(s):
    """Expand a leading '~' in *s* and return it as an absolute, normalized Path."""
    expanded = Path(s).expanduser()
    return expanded.resolve()


p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
Expand Down Expand Up @@ -223,4 +229,22 @@
help="flag for removing the pregress bar animation during image generation",
)

##############################################################################
### SD model auto-annotation flags
##############################################################################

p.add_argument(
    "--annotation_output",
    type=path_expand,
    default="./",
    help="Directory to save the annotated mlir file",
)

p.add_argument(
    "--annotation_model",
    type=str,
    default="unet",
    # Enforce the documented options up front: an unknown value would
    # otherwise surface much later as a NameError (unbound model_key) in
    # sd_annotation.py instead of a clear argparse error.
    choices=["unet", "vae"],
    help="Options are unet and vae.",
)

args = p.parse_args()
58 changes: 50 additions & 8 deletions shark/model_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,23 @@ def model_annotation(
input_contents: str,
config_path: str,
search_op: str,
winograd: bool = False,
):
if os.path.isfile(input_contents):
with open(input_contents, "rb") as f:
input_contents = f.read()
module = ir.Module.parse(input_contents)

configs = load_model_configs(config_path)
if winograd:
with open(config_path, "r") as f:
data = json.load(f)
configs = data["c,f"]
else:
configs = load_model_configs(config_path)

# The Python API does not expose a general walk() function, so we just
# do it ourselves.
walk_children(module.operation, configs, search_op)
walk_children(module.operation, configs, search_op, winograd)

if not module.operation.verify():
raise RuntimeError("Modified program does not verify!")
Expand Down Expand Up @@ -92,7 +98,9 @@ def load_model_configs(config_path: str):
return config


def walk_children(op: ir.Operation, configs: List[Dict], search_op: str):
def walk_children(
op: ir.Operation, configs: List[Dict], search_op: str, winograd: bool
):
if search_op == "matmul":
op_names = ["linalg.matmul", "mhlo.dot"]
elif search_op == "bmm":
Expand Down Expand Up @@ -121,6 +129,11 @@ def walk_children(op: ir.Operation, configs: List[Dict], search_op: str):
# 'operation' and 'name' attributes.
if isinstance(child_op, ir.OpView):
child_op = child_op.operation
if winograd and child_op.name in [
"linalg.conv_2d_nchw_fchw",
"linalg.conv_2d_nhwc_hwcf",
]:
add_winograd_attribute(child_op, configs)
if child_op.name in op_names:
if child_op.name == "linalg.generic":
# This is for generic op that has contractionOpInterface
Expand Down Expand Up @@ -151,7 +164,7 @@ def walk_children(op: ir.Operation, configs: List[Dict], search_op: str):
)
print(f"Updated op {child_op}", file=sys.stderr)

walk_children(child_op, configs, search_op)
walk_children(child_op, configs, search_op, winograd)


def get_op_shape(op: ir.Operation, search_op: str):
Expand Down Expand Up @@ -294,10 +307,6 @@ def add_attributes(op: ir.Operation, config: List[Dict]):
pipeline_depth = config["pipeline_depth"]
if "split_k" in config.keys():
split_k = config["split_k"]
if "devices" in config.keys():
devices = config["devices"]
if "shard_sizes" in config.keys():
shard_sizes = config["shard_sizes"]
elif "SPIRV" in config["pipeline"]:
pipeline = config["pipeline"]
tile_sizes = [
Expand Down Expand Up @@ -355,6 +364,39 @@ def add_attributes(op: ir.Operation, config: List[Dict]):
add_attribute_by_name(op, "iree_flow_split_k", split_k)


def add_winograd_attribute(op: ir.Operation, config: List):
    """Tag *op* with an `iree_winograd_conv` unit attribute when eligible.

    Eligibility: a 3x3 conv with stride 1 and dilation 1 whose (input-channel,
    filter-count) pair appears in *config*. The filter shape is recovered by
    scraping the op's textual IR, since the Python bindings expose no direct
    accessor here.
    """

    def _dense_scalar(attr_name):
        # Pull the integer out of e.g. 'dense<1> : tensor<2xi64>'.
        return int(str(op.attributes[attr_name]).split("dense<")[1].split(">")[0])

    dilation = _dense_scalar("dilations")
    stride = _dense_scalar("strides")

    # The second tensor<...> inside the ins(...) clause is the filter operand;
    # splitting on "x" yields its dimensions (trailing element is the dtype).
    ins_text = str(op.results[0]).split("ins(")[1]
    dims = ins_text.split("tensor<")[2].split("x")
    if op.name == "linalg.conv_2d_nchw_fchw":
        f, c, kh, kw = (int(d) for d in dims[:4])
    else:
        kh, kw, c, f = (int(d) for d in dims[:4])

    eligible = (
        dilation == 1
        and stride == 1
        and kh == 3
        and kw == 3
        and [c, f] in config
    )
    if eligible:
        i64 = ir.IntegerType.get_signless(64)
        op.attributes["iree_winograd_conv"] = ir.IntegerAttr.get(i64, 1)
        print("Apply Winograd on selected conv op: ", op)


def add_attribute_by_name(op: ir.Operation, name: str, val: int):
    """Attach *val* to *op* as a 64-bit signless integer attribute named *name*."""
    i64_type = ir.IntegerType.get_signless(64)
    op.attributes[name] = ir.IntegerAttr.get(i64_type, val)
Expand Down

0 comments on commit 782b449

Please sign in to comment.