From 782b449c7104ca032e5d8d58ddeda089969b7a1c Mon Sep 17 00:00:00 2001
From: yzhang93
Date: Wed, 4 Jan 2023 15:53:10 -0800
Subject: [PATCH] Add script to auto annotate SD models and variants (#751)

* Add script to auto annotate SD models and variants

* Add model config files

* Add script to auto annotate SD models and variants

* Add model config files

* Move config files to shark_tank
---
 .../stable_diffusion/sd_annotation.py         | 108 ++++++++++++++++++
 .../stable_diffusion/stable_args.py           |  24 ++++
 shark/model_annotation.py                     |  58 ++++++++--
 3 files changed, 182 insertions(+), 8 deletions(-)
 create mode 100644 shark/examples/shark_inference/stable_diffusion/sd_annotation.py

diff --git a/shark/examples/shark_inference/stable_diffusion/sd_annotation.py b/shark/examples/shark_inference/stable_diffusion/sd_annotation.py
new file mode 100644
index 0000000000..09b84b2d4a
--- /dev/null
+++ b/shark/examples/shark_inference/stable_diffusion/sd_annotation.py
@@ -0,0 +1,108 @@
+import os
+from shark.model_annotation import model_annotation, create_context
+from shark.iree_utils._common import run_cmd, iree_target_map
+from shark.shark_downloader import (
+    download_model,
+    download_public_file,
+    WORKDIR,
+)
+from shark.parser import shark_args
+from stable_args import args
+from opt_params import get_params
+from utils import set_init_device_flags
+
+
+# Download the model (UNet or VAE) from shark_tank
+set_init_device_flags()
+shark_args.local_tank_cache = args.local_tank_cache
+bucket_key = f"{args.variant}/untuned"
+use_winograd = False
+if args.annotation_model == "unet":
+    if args.version == "v2_1base":
+        use_winograd = True
+    model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/untuned"
+elif args.annotation_model == "vae":
+    use_winograd = True
+    is_base = "/base" if args.use_base_vae else ""
+    model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/untuned{is_base}"
+
+bucket, model_name, iree_flags = get_params(
+    bucket_key, model_key, args.annotation_model, "untuned", args.precision
+)
+mlir_model, func_name, inputs, golden_out = download_model(
+    model_name,
+    tank_url=bucket,
+    frontend="torch",
+)
+
+# Download the tuned config files from shark_tank
+config_bucket = "gs://shark_tank/sd_tuned/configs/"
+if use_winograd:
+    config_name = f"{args.annotation_model}_winograd.json"
+    full_gs_url = config_bucket + config_name
+    winograd_config_dir = f"{WORKDIR}configs/" + config_name
+    download_public_file(full_gs_url, winograd_config_dir, True)
+
+if args.annotation_model == "unet":
+    if args.variant in ["anythingv3", "analogdiffusion"]:
+        args.max_length = 77
+    config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}.json"
+    full_gs_url = config_bucket + config_name
+    lowering_config_dir = f"{WORKDIR}configs/" + config_name
+    download_public_file(full_gs_url, lowering_config_dir, True)
+
+# Annotate the model with the Winograd attribute on selected conv ops
+if use_winograd:
+    with create_context() as ctx:
+        winograd_model = model_annotation(
+            ctx,
+            input_contents=mlir_model,
+            config_path=winograd_config_dir,
+            search_op="conv",
+            winograd=use_winograd,
+        )
+        with open(
+            f"{args.annotation_output}/{model_name}_tuned_torch.mlir", "w"
+        ) as f:
+            f.write(str(winograd_model))
+
+# For UNet, annotate the model with tuned lowering configs
+if args.annotation_model == "unet":
+    if use_winograd:
+        input_mlir = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
+        dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
+    else:
+        input_mlir = f"{WORKDIR}/{model_name}_torch/{model_name}_torch.mlir"
+        dump_after = "iree-flow-pad-linalg-ops"
+
+    # Dump IR after padding/img2col/winograd passes
+    run_cmd(
+        f"iree-compile {input_mlir} "
+        "--iree-input-type=tm_tensor "
+        f"--iree-hal-target-backends={iree_target_map(args.device)} "
+        f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
+        "--iree-stream-resource-index-bits=64 "
+        "--iree-vm-target-index-bits=64 "
+        "--iree-flow-enable-padding-linalg-ops "
+        "--iree-flow-linalg-ops-padding-size=32 "
+        "--iree-flow-enable-conv-img2col-transform "
+        f"--mlir-print-ir-after={dump_after} "
+        "--compile-to=flow "
+        f"2>{args.annotation_output}/dump_after_winograd.mlir "
+    )
+
+    # Annotate the model with lowering configs in the config file
+    with create_context() as ctx:
+        tuned_model = model_annotation(
+            ctx,
+            input_contents=f"{args.annotation_output}/dump_after_winograd.mlir",
+            config_path=lowering_config_dir,
+            search_op="all",
+        )
+
+    # Remove the intermediate mlir and save the final annotated model
+    os.remove(f"{args.annotation_output}/dump_after_winograd.mlir")
+    output_path = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
+    with open(output_path, "w") as f:
+        f.write(str(tuned_model))
+    print(f"Saved the annotated mlir in {output_path}.")
diff --git a/shark/examples/shark_inference/stable_diffusion/stable_args.py b/shark/examples/shark_inference/stable_diffusion/stable_args.py
index 4fb6c680a7..4aa274633a 100644
--- a/shark/examples/shark_inference/stable_diffusion/stable_args.py
+++ b/shark/examples/shark_inference/stable_diffusion/stable_args.py
@@ -1,4 +1,10 @@
 import argparse
+from pathlib import Path
+
+
+def path_expand(s):
+    return Path(s).expanduser().resolve()
+
 
 p = argparse.ArgumentParser(
     description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -223,4 +229,22 @@
     help="flag for removing the progress bar animation during image generation",
 )
 
+##############################################################################
+### SD model auto-annotation flags
+##############################################################################
+
+p.add_argument(
+    "--annotation_output",
+    type=path_expand,
+    default="./",
+    help="Directory to save the annotated mlir file",
+)
+
+p.add_argument(
+    "--annotation_model",
+    type=str,
+    default="unet",
+    help="Options are unet and vae.",
+)
+
 args = p.parse_args()
diff --git a/shark/model_annotation.py b/shark/model_annotation.py
index 5836d66691..32ea709747 100644
--- a/shark/model_annotation.py
+++ b/shark/model_annotation.py
@@ -40,17 +40,23 @@ def model_annotation(
     input_contents: str,
     config_path: str,
     search_op: str,
+    winograd: bool = False,
 ):
     if os.path.isfile(input_contents):
         with open(input_contents, "rb") as f:
             input_contents = f.read()
     module = ir.Module.parse(input_contents)
 
-    configs = load_model_configs(config_path)
+    if winograd:
+        with open(config_path, "r") as f:
+            data = json.load(f)
+            configs = data["c,f"]
+    else:
+        configs = load_model_configs(config_path)
 
     # The Python API does not expose a general walk() function, so we just
     # do it ourselves.
-    walk_children(module.operation, configs, search_op)
+    walk_children(module.operation, configs, search_op, winograd)
 
     if not module.operation.verify():
         raise RuntimeError("Modified program does not verify!")
@@ -92,7 +98,9 @@ def load_model_configs(config_path: str):
     return config
 
 
-def walk_children(op: ir.Operation, configs: List[Dict], search_op: str):
+def walk_children(
+    op: ir.Operation, configs: List[Dict], search_op: str, winograd: bool
+):
     if search_op == "matmul":
         op_names = ["linalg.matmul", "mhlo.dot"]
     elif search_op == "bmm":
@@ -121,6 +129,11 @@
                 # 'operation' and 'name' attributes.
                 if isinstance(child_op, ir.OpView):
                     child_op = child_op.operation
+                if winograd and child_op.name in [
+                    "linalg.conv_2d_nchw_fchw",
+                    "linalg.conv_2d_nhwc_hwcf",
+                ]:
+                    add_winograd_attribute(child_op, configs)
                 if child_op.name in op_names:
                     if child_op.name == "linalg.generic":
                         # This is for generic op that has contractionOpInterface
@@ -151,7 +164,7 @@
                     )
                     print(f"Updated op {child_op}", file=sys.stderr)
 
-                walk_children(child_op, configs, search_op)
+                walk_children(child_op, configs, search_op, winograd)
 
 
 def get_op_shape(op: ir.Operation, search_op: str):
@@ -294,10 +307,6 @@ def add_attributes(op: ir.Operation, config: List[Dict]):
             pipeline_depth = config["pipeline_depth"]
         if "split_k" in config.keys():
             split_k = config["split_k"]
-        if "devices" in config.keys():
-            devices = config["devices"]
-        if "shard_sizes" in config.keys():
-            shard_sizes = config["shard_sizes"]
     elif "SPIRV" in config["pipeline"]:
         pipeline = config["pipeline"]
         tile_sizes = [
@@ -355,6 +364,39 @@
         add_attribute_by_name(op, "iree_flow_split_k", split_k)
 
 
+def add_winograd_attribute(op: ir.Operation, config: List):
+    op_result = str(op.results[0]).split("ins(")[1]
+    dilation = int(
+        str(op.attributes["dilations"]).split("dense<")[1].split(">")[0]
+    )
+    stride = int(
+        str(op.attributes["strides"]).split("dense<")[1].split(">")[0]
+    )
+
+    if op.name == "linalg.conv_2d_nchw_fchw":
+        f = int(op_result.split("tensor<")[2].split("x")[0])
+        c = int(op_result.split("tensor<")[2].split("x")[1])
+        kh = int(op_result.split("tensor<")[2].split("x")[2])
+        kw = int(op_result.split("tensor<")[2].split("x")[3])
+    else:
+        kh = int(op_result.split("tensor<")[2].split("x")[0])
+        kw = int(op_result.split("tensor<")[2].split("x")[1])
+        c = int(op_result.split("tensor<")[2].split("x")[2])
+        f = int(op_result.split("tensor<")[2].split("x")[3])
+
+    if (
+        dilation == 1
+        and stride == 1
+        and kh == 3
+        and kw == 3
+        and [c, f] in config
+    ):
+        op.attributes["iree_winograd_conv"] = ir.IntegerAttr.get(
+            ir.IntegerType.get_signless(64), 1
+        )
+        print("Apply Winograd on selected conv op: ", op)
+
+
 def add_attribute_by_name(op: ir.Operation, name: str, val: int):
     attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), val)
     op.attributes[name] = attr
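
A note on the new winograd path, since the config format only appears implicitly above: model_annotation() expects the JSON at config_path to carry a "c,f" key listing the [input-channel, output-filter] pairs eligible for Winograd, and input_contents may be either a path to an .mlir file or the raw MLIR text (see the os.path.isfile check in the first hunk). Below is a minimal sketch of driving the annotation directly; the file names and [c, f] pairs are hypothetical examples, not values shipped in shark_tank:

    import json

    from shark.model_annotation import model_annotation, create_context

    # Hypothetical config: only 3x3, stride-1, dilation-1 convs whose
    # [c, f] pair is listed here receive the iree_winograd_conv attribute.
    with open("unet_winograd.json", "w") as f:
        json.dump({"c,f": [[320, 320], [640, 640]]}, f)

    with create_context() as ctx:
        annotated = model_annotation(
            ctx,
            input_contents="unet_torch.mlir",  # path or raw MLIR text
            config_path="unet_winograd.json",
            search_op="conv",
            winograd=True,
        )

    # model_annotation() verifies the module before returning it,
    # so the result can be written out as-is.
    with open("unet_winograd_annotated.mlir", "w") as f:
        f.write(str(annotated))

For the full pipeline (model download, Winograd annotation, dump-after-pass compile, lowering-config annotation), the script is driven through the new flags, e.g. python sd_annotation.py --annotation_model=unet --annotation_output=./tuned together with the existing stable_args options.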