From 59600456be88abd75ade788c5f144abcef841957 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Sun, 2 Jun 2024 16:25:16 -0500 Subject: [PATCH] seed fixes --- apps/shark_studio/api/sd.py | 15 +++++++++------ apps/shark_studio/modules/shared_cmd_opts.py | 2 +- apps/shark_studio/web/ui/sd.py | 6 ++++-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index 380bfe82a3..54b703a9bd 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -459,10 +459,10 @@ def shark_sd_fn( global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs) generated_imgs = [] - if seed in [-1, "-1"]: - seed = randint(0, 4294967295) + if submit_run_kwargs["seed"] in [-1, "-1"]: + submit_run_kwargs["seed"] = randint(0, 4294967295) seed_increment = "random" - print(f"\n[LOG] Random seed: {seed}") + #print(f"\n[LOG] Random seed: {seed}") progress(None, desc=f"Generating...") for current_batch in range(batch_count): @@ -483,20 +483,23 @@ def shark_sd_fn( sd_kwargs, ) generated_imgs.extend(out_imgs) - seed = get_next_seed(seed, seed_increment) + yield generated_imgs, status_label( "Stable Diffusion", current_batch + 1, batch_count, batch_size ) + if batch_count > 1: + submit_run_kwargs["seed"] = get_next_seed(seed, seed_increment) + return (generated_imgs, "") def get_next_seed(seed, seed_increment: str | int = 10): if isinstance(seed_increment, int): - print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}") + #print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}") return int(seed + seed_increment) elif seed_increment == "random": seed = randint(0, 4294967295) - print(f"\n[LOG] Random seed: {seed}") + #print(f"\n[LOG] Random seed: {seed}") return seed diff --git a/apps/shark_studio/modules/shared_cmd_opts.py b/apps/shark_studio/modules/shared_cmd_opts.py index 286703adea..c9cdbf4aad 100644 --- a/apps/shark_studio/modules/shared_cmd_opts.py +++ b/apps/shark_studio/modules/shared_cmd_opts.py @@ -343,7 +343,7 @@ def is_valid_file(arg): p.add_argument( "--batch_count", type=int, - default=4, + default=1, help="Number of batches to be generated with random seeds in " "single execution.", ) diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index 3a605926f4..ddf8276a92 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -256,7 +256,7 @@ def base_model_changed(base_model_id): elif ".py" in base_model_id: new_steps = gr.Dropdown( value=20, - choices=[10, 15, 20, 28], + choices=[10, 15, 20], label="\U0001F3C3\U0000FE0F Steps", allow_custom_value=True, ) @@ -462,7 +462,7 @@ def base_model_changed(base_model_id): ) guidance_scale = gr.Slider( 0, - 50, + 5, #DEMO value=cmd_opts.guidance_scale, step=0.1, label="\U0001F5C3\U0000FE0F CFG Scale", @@ -636,6 +636,7 @@ def base_model_changed(base_model_id): step=1, label="Batch Count", interactive=True, + visible=True, ) batch_size = gr.Slider( 1, @@ -649,6 +650,7 @@ def base_model_changed(base_model_id): compiled_pipeline = gr.Checkbox( True, label="Faster txt2img (SDXL only)", + visible=False, # DEMO ) with gr.Row(): stable_diffusion = gr.Button("Start")