Fix compiled pipeline
gpetters-amd committed Aug 29, 2024
1 parent 5eb013d commit 1d71f8c
Showing 6 changed files with 239 additions and 81 deletions.
36 changes: 30 additions & 6 deletions models/turbine_models/custom_models/pipeline_base.py
@@ -437,6 +437,7 @@ def prepare_all(
vmfbs: dict = {},
weights: dict = {},
interactive: bool = False,
num_steps: int = 20,
):
ready = self.is_prepared(vmfbs, weights)
match ready:
@@ -463,7 +464,9 @@ def prepare_all(
if not self.map[submodel].get("weights") and self.map[submodel][
"export_args"
].get("external_weights"):
self.export_submodel(submodel, weights_only=True)
self.export_submodel(
submodel, weights_only=True, num_steps=num_steps
)
return self.prepare_all(mlirs, vmfbs, weights, interactive)

def is_prepared(self, vmfbs, weights):
@@ -581,6 +584,7 @@ def export_submodel(
submodel: str,
input_mlir: str = None,
weights_only: bool = False,
num_steps: int = 20,
):
if not os.path.exists(self.pipeline_dir):
os.makedirs(self.pipeline_dir)
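The new keyword threads a step count through the export path. A minimal standalone illustration (toy functions, not the project's code) of the plumbing this hunk and the prepare_all change above set up:

# Toy illustration: prepare_all() forwards num_steps to export_submodel(),
# which only needs it when exporting the compiled pipeline module.
def export_submodel(submodel, weights_only=False, num_steps=20):
    print(f"exporting {submodel} (weights_only={weights_only}, num_steps={num_steps})")

def prepare_all(submodels, num_steps=20):
    for submodel in submodels:
        export_submodel(submodel, weights_only=True, num_steps=num_steps)

prepare_all(["unet", "scheduler", "vae"], num_steps=30)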
@@ -670,7 +674,9 @@ def export_submodel(
self.map[submodel]["export_args"]["precision"],
self.map[submodel]["export_args"]["batch_size"],
self.map[submodel]["export_args"]["max_length"],
"tokens_to_image",
"produce_img_split",
unet_module_name=self.map["unet"]["module_name"],
num_steps=num_steps,
)
dims = [
self.map[submodel]["export_args"]["width"],
@@ -699,8 +705,8 @@ def export_submodel(
return_path=True,
mlir_source="str",
)
self.map[submodel]["vmfb"] = vmfb_path
self.map[submodel]["weights"] = None
self.map[submodel]["vmfb"] = [vmfb_path]
self.map[submodel]["weights"] = []
case _:
export_args = self.map[submodel].get("export_args", {})
if weights_only:
@@ -721,10 +727,24 @@ def export_submodel(

# LOAD
def load_map(self):
for submodel in self.map.keys():
# Make sure fullpipeline is imported last
submodels = list(self.map.keys() - {"fullpipeline"})
submodels += ["fullpipeline"] if "fullpipeline" in self.map.keys() else []
for submodel in submodels:
if not self.map[submodel]["load"]:
self.printer.print("Skipping load for ", submodel)
self.printer.print(f"Skipping load for {submodel}")
continue
elif self.map[submodel].get("wraps"):
vmfbs = []
weights = []
for wrapped in self.map[submodel]["wraps"]:
vmfbs.append(self.map[wrapped]["vmfb"])
if "weights" in self.map[wrapped]:
weights.append(self.map[wrapped]["weights"])
self.map[submodel]["vmfb"] = vmfbs + self.map[submodel]["vmfb"]
self.map[submodel]["weights"] = weights + self.map[submodel]["weights"]

print(f"Loading {submodel}")
self.load_submodel(submodel)
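A standalone sketch (toy data, not the project's code) of the ordering and aggregation this hunk introduces: the fullpipeline entry is loaded last, and the vmfbs and weights of the submodels it wraps are prepended to its own before its runner is created:

# Toy version of the new load_map() ordering and vmfb aggregation.
submodel_map = {
    "unet": {"vmfb": "unet.vmfb", "weights": "unet.irpa", "load": False},
    "scheduler": {"vmfb": "sched.vmfb", "load": False},
    "fullpipeline": {
        "wraps": ["unet", "scheduler"],
        "vmfb": ["pipeline.vmfb"],
        "weights": [],
        "load": True,
    },
}
# fullpipeline goes last so its wrapped modules are already exported.
order = list(submodel_map.keys() - {"fullpipeline"}) + ["fullpipeline"]
for name in order:
    entry = submodel_map[name]
    if not entry["load"]:
        continue
    if entry.get("wraps"):
        vmfbs, weights = [], []
        for wrapped in entry["wraps"]:
            vmfbs.append(submodel_map[wrapped]["vmfb"])
            if "weights" in submodel_map[wrapped]:
                weights.append(submodel_map[wrapped]["weights"])
        entry["vmfb"] = vmfbs + entry["vmfb"]
        entry["weights"] = weights + entry["weights"]
    print(f"loading {name} with vmfbs {entry['vmfb']}")
    # -> loading fullpipeline with vmfbs ['unet.vmfb', 'sched.vmfb', 'pipeline.vmfb']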

def load_submodel(self, submodel):
@@ -751,6 +771,10 @@ def load_submodel(self, submodel):

def unload_submodel(self, submodel):
self.map[submodel]["runner"].unload()
self.map[submodel]["vmfb"] = None
self.map[submodel]["mlir"] = None
self.map[submodel]["weights"] = None
self.map[submodel]["export_args"]["input_mlir"] = None
setattr(self, submodel, None)
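unload_submodel() now also clears the cached artifact paths. A hedged usage sketch, assuming pipe is an already-constructed pipeline instance and that prepare_all() re-resolves or re-exports the cleared artifacts (an assumption, not shown in this diff):

pipe.unload_submodel("unet")   # frees the runner and nulls vmfb/mlir/weights/input_mlir
pipe.prepare_all()             # assumed to re-resolve or re-export the cleared artifacts
pipe.load_submodel("unet")     # reloads the runner from the refreshed paths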


187 changes: 133 additions & 54 deletions models/turbine_models/custom_models/sd_inference/sd_pipeline.py
@@ -118,23 +118,11 @@
"decomp_attn": None,
},
},
"unetloop": {
"module_name": "sdxl_compiled_pipeline",
"load": False,
"keywords": ["unetloop"],
"wraps": ["unet", "scheduler"],
"export_args": {
"batch_size": 1,
"height": 1024,
"width": 1024,
"max_length": 64,
},
},
"fullpipeline": {
"module_name": "sdxl_compiled_pipeline",
"load": False,
"load": True,
"keywords": ["fullpipeline"],
"wraps": ["text_encoder", "unet", "scheduler", "vae"],
"wraps": ["unet", "scheduler", "vae"],
"export_args": {
"batch_size": 1,
"height": 1024,
@@ -190,6 +178,7 @@ def get_sd_model_map(hf_model_name):
"stabilityai/sdxl-turbo",
"stabilityai/stable-diffusion-xl-base-1.0",
"/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16/checkpoint_pipe",
"/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16//checkpoint_pipe",
]:
return sdxl_model_map
elif "stabilityai/stable-diffusion-3" in name:
@@ -233,6 +222,7 @@ def __init__(
benchmark: bool | dict[bool] = False,
verbose: bool = False,
batch_prompts: bool = False,
compiled_pipeline: bool = False,
):
common_export_args = {
"hf_model_name": None,
@@ -243,11 +233,11 @@ def __init__(
"exit_on_vmfb": False,
"pipeline_dir": pipeline_dir,
"input_mlir": None,
"attn_spec": None,
"attn_spec": attn_spec,
"external_weights": None,
"external_weight_path": None,
}
sd_model_map = get_sd_model_map(hf_model_name)
sd_model_map = copy.deepcopy(get_sd_model_map(hf_model_name))
for submodel in sd_model_map:
if "load" not in sd_model_map[submodel]:
sd_model_map[submodel]["load"] = True
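The deepcopy guards the module-level model map from per-instance mutation (this pipeline edits keywords, export_args, and load flags in place). A small self-contained illustration of the failure mode it avoids:

import copy

base_map = {"unet": {"keywords": ["fp16"]}}

# With a deep copy, per-pipeline edits stay local to that pipeline.
instance_map = copy.deepcopy(base_map)
instance_map["unet"]["keywords"].append("punet")
assert base_map["unet"]["keywords"] == ["fp16"]

# Without it, the shared nested lists are mutated for every later pipeline.
aliased_map = base_map
aliased_map["unet"]["keywords"].append("punet")
assert base_map["unet"]["keywords"] == ["fp16", "punet"]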
@@ -311,6 +301,7 @@ def __init__(
self.scheduler = None

self.split_scheduler = True
self.compiled_pipeline = compiled_pipeline

self.base_model_name = (
hf_model_name
@@ -321,11 +312,6 @@ def __init__(
self.is_sdxl = "xl" in self.base_model_name.lower()
self.is_sd3 = "stable-diffusion-3" in self.base_model_name
if self.is_sdxl:
if self.split_scheduler:
if self.map.get("unetloop"):
self.map.pop("unetloop")
if self.map.get("fullpipeline"):
self.map.pop("fullpipeline")
self.tokenizers = [
CLIPTokenizer.from_pretrained(
self.base_model_name, subfolder="tokenizer"
@@ -339,6 +325,20 @@ def __init__(
self.scheduler_device = self.map["unet"]["device"]
self.scheduler_driver = self.map["unet"]["driver"]
self.scheduler_target = self.map["unet"]["target"]
if not self.compiled_pipeline:
if self.map.get("unetloop"):
self.map.pop("unetloop")
if self.map.get("fullpipeline"):
self.map.pop("fullpipeline")
elif self.compiled_pipeline:
self.map["unet"]["load"] = False
self.map["vae"]["load"] = False
self.load_scheduler(
scheduler_id,
num_inference_steps,
)
self.map["scheduler"]["runner"].unload()
self.map["scheduler"]["load"] = False
elif not self.is_sd3:
self.tokenizer = CLIPTokenizer.from_pretrained(
self.base_model_name, subfolder="tokenizer"
@@ -351,23 +351,27 @@

self.latents_dtype = torch_dtypes[self.latents_precision]
self.use_i8_punet = self.use_punet = use_i8_punet
if self.use_punet:
self.setup_punet()
else:
self.map["unet"]["keywords"].append("!punet")
self.map["unet"]["function_name"] = "run_forward"

def setup_punet(self):
if self.use_i8_punet:
self.map["unet"]["export_args"]["precision"] = "i8"
self.map["unet"]["export_args"]["use_punet"] = True
self.map["unet"]["use_weights_for_export"] = True
self.map["unet"]["keywords"].append("punet")
self.map["unet"]["module_name"] = "compiled_punet"
self.map["unet"]["function_name"] = "main"
self.map["unet"]["export_args"]["external_weight_path"] = (
utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa"
)
for idx, word in enumerate(self.map["unet"]["keywords"]):
if word in ["fp32", "fp16"]:
self.map["unet"]["keywords"][idx] = "i8"
break
else:
self.map["unet"]["keywords"].append("!punet")
self.map["unet"]["function_name"] = "run_forward"
self.map["unet"]["export_args"]["use_punet"] = True
self.map["unet"]["use_weights_for_export"] = True
self.map["unet"]["keywords"].append("punet")
self.map["unet"]["module_name"] = "compiled_punet"
self.map["unet"]["function_name"] = "main"

# LOAD

@@ -376,10 +380,6 @@ def load_scheduler(
scheduler_id: str,
steps: int = 30,
):
if self.is_sd3:
scheduler_device = self.mmdit.device
else:
scheduler_device = self.unet.device
if not self.cpu_scheduling:
self.map["scheduler"] = {
"module_name": "compiled_scheduler",
@@ -425,7 +425,11 @@ def load_scheduler(
except:
print("JIT export of scheduler failed. Loading CPU scheduler.")
self.cpu_scheduling = True
if self.cpu_scheduling:
elif self.cpu_scheduling:
if self.is_sd3:
scheduler_device = self.mmdit.device
else:
scheduler_device = self.unet.device
scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id)
self.scheduler = schedulers.SharkSchedulerCPUWrapper(
scheduler,
@@ -461,13 +465,10 @@ def encode_prompts_sdxl(self, prompt, negative_prompt):
text_input_ids_list += text_inputs.input_ids.unsqueeze(0)
uncond_input_ids_list += uncond_input.input_ids.unsqueeze(0)

if self.compiled_pipeline:
return text_input_ids_list, uncond_input_ids_list
else:
prompt_embeds, add_text_embeds = self.text_encoder(
"encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
)
return prompt_embeds, add_text_embeds
prompt_embeds, add_text_embeds = self.text_encoder(
"encode_prompts", [*text_input_ids_list, *uncond_input_ids_list]
)
return prompt_embeds, add_text_embeds
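With text_encoder dropped from the fullpipeline wrap list, prompt encoding is no longer special-cased here: the compiled path consumes the same embeddings as the eager path. A hedged one-line sketch (pipe assumed to be a loaded SDXL pipeline):

prompt_embeds, add_text_embeds = pipe.encode_prompts_sdxl("a photo of a corgi", "")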

def prepare_latents(
self,
@@ -565,9 +566,11 @@ def _produce_latents_sdxl(
[guidance_scale],
dtype=self.map["unet"]["np_dtype"],
)
# Disable progress bar if we aren't in verbose mode or if we're printing
# benchmark latencies for unet.
for i, t in tqdm(
enumerate(timesteps),
disable=(self.map["unet"].get("benchmark") and self.verbose),
disable=(self.map["unet"].get("benchmark") or not self.verbose),
):
if self.cpu_scheduling:
latent_model_input, t = self.scheduler.scale_model_input(
@@ -608,6 +611,75 @@ def _produce_latents_sdxl(
latents = self.scheduler("run_step", [noise_pred, t, latents])
return latents

def produce_images_compiled(
self,
sample,
prompt_embeds,
text_embeds,
guidance_scale,
):
pipe_inputs = [
sample,
prompt_embeds,
text_embeds,
torch.as_tensor([guidance_scale], dtype=sample.dtype),
]
# image = self.compiled_pipeline("produce_img_latents", pipe_inputs)
image = self.map["fullpipeline"]["runner"]("produce_image_latents", pipe_inputs)
return image

def prepare_sampling_inputs(
self,
prompt: str,
negative_prompt: str = "",
steps: int = 30,
batch_count: int = 1,
guidance_scale: float = 7.5,
seed: float = -1,
cpu_scheduling: bool = True,
scheduler_id: str = "EulerDiscrete",
return_imgs: bool = False,
):
needs_new_scheduler = (
(steps and steps != self.num_inference_steps)
or (cpu_scheduling != self.cpu_scheduling)
and self.split_scheduler
)
if not self.scheduler and not self.compiled_pipeline:
needs_new_scheduler = True

if guidance_scale == 0:
negative_prompt = prompt
prompt = ""

self.cpu_scheduling = cpu_scheduling
if steps and needs_new_scheduler:
self.num_inference_steps = steps
self.load_scheduler(scheduler_id, steps)

pipe_start = time.time()
numpy_images = []

samples = self.get_rand_latents(seed, batch_count)

# Tokenize prompt and negative prompt.
if self.is_sdxl:
prompt_embeds, negative_embeds = self.encode_prompts_sdxl(
prompt, negative_prompt
)
else:
prompt_embeds, negative_embeds = encode_prompt(
self, prompt, negative_prompt
)
produce_latents_input = [
samples[0],
prompt_embeds,
negative_embeds,
steps,
guidance_scale,
]
return produce_latents_input
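A hedged sketch of how these helpers fit together on the compiled path (pipe is assumed to be an SDXL pipeline built with compiled_pipeline=True and already loaded; argument values are illustrative):

sample, prompt_embeds, negative_embeds, steps, guidance_scale = pipe.prepare_sampling_inputs(
    "a photo of a corgi", steps=20, guidance_scale=7.5
)
# The compiled fullpipeline module runs scheduling, denoising, and VAE decode in one call.
image = pipe.produce_images_compiled(sample, prompt_embeds, negative_embeds, guidance_scale)
numpy_image = image.to_host()  # device array -> host, as in generate_images() below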

def generate_images(
self,
prompt: str,
@@ -653,18 +725,23 @@ def generate_images(
)

for i in range(batch_count):
produce_latents_input = [
samples[i],
prompt_embeds,
negative_embeds,
steps,
guidance_scale,
]
if self.is_sdxl:
latents = self._produce_latents_sdxl(*produce_latents_input)
if self.compiled_pipeline:
image = self.produce_images_compiled(
samples[i], prompt_embeds, negative_embeds, guidance_scale
).to_host()
else:
latents = self._produce_latents_sd(*produce_latents_input)
image = self.vae("decode", [latents])
produce_latents_input = [
samples[i],
prompt_embeds,
negative_embeds,
steps,
guidance_scale,
]
if self.is_sdxl:
latents = self._produce_latents_sdxl(*produce_latents_input)
else:
latents = self._produce_latents_sd(*produce_latents_input)
image = self.vae("decode", [latents])
numpy_images.append(image)
pipe_end = time.time()

@@ -750,8 +827,10 @@ def numpy_to_pil_image(images):
args.use_i8_punet,
benchmark,
args.verbose,
False,
args.compiled_pipeline,
)
sd_pipe.prepare_all()
sd_pipe.prepare_all(num_steps=args.num_inference_steps)
sd_pipe.load_map()
sd_pipe.generate_images(
args.prompt,
(The remaining 4 changed files in this commit are not shown.)
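Taken together, the compiled flow mirrors the CLI main above. A hedged end-to-end sketch: SharkSDPipeline is assumed to be the pipeline class defined in sd_pipeline.py, constructor arguments are elided because the full signature is not shown in this diff, and the method names and keywords are the ones used above:

# Hypothetical end-to-end usage of the compiled pipeline path.
sd_pipe = SharkSDPipeline(..., compiled_pipeline=True)   # other constructor args elided
sd_pipe.prepare_all(num_steps=20)    # export vmfbs/weights with the step count baked in
sd_pipe.load_map()                   # wrapped submodels load first, fullpipeline last
sd_pipe.generate_images("a photo of a corgi", steps=20)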
