Skip to content

Commit

Permalink
Add compiled pipeline option
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed May 30, 2024
1 parent fe142c8 commit 151d300
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
12 changes: 8 additions & 4 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,17 @@ def __init__(
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
gc.collect()

def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
def prepare_pipe(
self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
):
print(f"\n[LOG] Preparing pipeline...")
self.is_img2img = False
mlirs = copy.deepcopy(self.model_map)
vmfbs = copy.deepcopy(self.model_map)
weights = copy.deepcopy(self.model_map)
if not self.is_sdxl:
compiled_pipeline = False
self.compiled_pipeline = compiled_pipeline

if custom_weights:
custom_weights = os.path.join(
Expand Down Expand Up @@ -253,7 +258,6 @@ def generate_images(
guidance_scale,
seed,
ondemand,
repeatable_seeds,
resample_type,
control_mode,
hints,
Expand Down Expand Up @@ -306,7 +310,7 @@ def shark_sd_fn(
device: str,
target_triple: str,
ondemand: bool,
repeatable_seeds: bool,
compiled_pipeline: bool,
resample_type: str,
controlnets: dict,
embeddings: dict,
Expand Down Expand Up @@ -369,6 +373,7 @@ def shark_sd_fn(
"adapters": adapters,
"embeddings": embeddings,
"is_img2img": is_img2img,
"compiled_pipeline": compiled_pipeline,
}
submit_run_kwargs = {
"prompt": prompt,
Expand All @@ -378,7 +383,6 @@ def shark_sd_fn(
"guidance_scale": guidance_scale,
"seed": seed,
"ondemand": ondemand,
"repeatable_seeds": repeatable_seeds,
"resample_type": resample_type,
"control_mode": control_mode,
"hints": hints,
Expand Down
14 changes: 7 additions & 7 deletions apps/shark_studio/web/ui/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def pull_sd_configs(
device,
target_triple,
ondemand,
repeatable_seeds,
compiled_pipeline,
resample_type,
controlnets,
embeddings,
Expand Down Expand Up @@ -179,7 +179,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
sd_json["device"],
sd_json["target_triple"],
sd_json["ondemand"],
sd_json["repeatable_seeds"],
sd_json["compiled_pipeline"],
sd_json["resample_type"],
sd_json["controlnets"],
sd_json["embeddings"],
Expand Down Expand Up @@ -606,9 +606,9 @@ def base_model_changed(base_model_id):
interactive=True,
visible=True,
)
repeatable_seeds = gr.Checkbox(
cmd_opts.repeatable_seeds,
label="Use Repeatable Seeds for Batches",
compiled_pipeline = gr.Checkbox(
False,
label="Faster txt2img (SDXL only)",
)
with gr.Row():
stable_diffusion = gr.Button("Start")
Expand Down Expand Up @@ -685,7 +685,7 @@ def base_model_changed(base_model_id):
device,
target_triple,
ondemand,
repeatable_seeds,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,
Expand Down Expand Up @@ -741,7 +741,7 @@ def base_model_changed(base_model_id):
device,
target_triple,
ondemand,
repeatable_seeds,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,
Expand Down

0 comments on commit 151d300

Please sign in to comment.