Skip to content

Commit

Permalink
Fix an issue with hardcoded iterations
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters-amd committed Aug 19, 2024
1 parent d23b528 commit 6f6585b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion models/turbine_models/custom_models/pipeline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def prepare_all(
vmfbs: dict = {},
weights: dict = {},
interactive: bool = False,
num_steps: int = 20,
):
ready = self.is_prepared(vmfbs, weights)
match ready:
Expand All @@ -463,7 +464,7 @@ def prepare_all(
if not self.map[submodel].get("weights") and self.map[submodel][
"export_args"
].get("external_weights"):
self.export_submodel(submodel, weights_only=True)
self.export_submodel(submodel, weights_only=True, num_steps=num_steps)
return self.prepare_all(mlirs, vmfbs, weights, interactive)

def is_prepared(self, vmfbs, weights):
Expand Down Expand Up @@ -581,6 +582,7 @@ def export_submodel(
submodel: str,
input_mlir: str = None,
weights_only: bool = False,
num_steps: int = 20,
):
if not os.path.exists(self.pipeline_dir):
os.makedirs(self.pipeline_dir)
Expand Down Expand Up @@ -672,6 +674,7 @@ def export_submodel(
self.map[submodel]["export_args"]["max_length"],
"produce_img_split",
unet_module_name=self.map["unet"]["module_name"],
num_steps=num_steps,
)
dims = [
self.map[submodel]["export_args"]["width"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ def numpy_to_pil_image(images):
False,
args.compiled_pipeline,
)
sd_pipe.prepare_all()
sd_pipe.prepare_all(num_steps=args.num_inference_steps)
sd_pipe.load_map()
sd_pipe.generate_images(
args.prompt,
Expand Down

0 comments on commit 6f6585b

Please sign in to comment.