Skip to content

Commit

Permalink
Update SD_DeepCache.py
Browse files Browse the repository at this point in the history
  • Loading branch information
WentianZhang-ML authored Apr 22, 2024
1 parent a5428f7 commit a29ec6e
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions src/tgate/SD_DeepCache.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,24 +251,28 @@ def tgate(
for i, t in enumerate(timesteps):
if self.interrupt:
continue
register_tgate_forward(self.unet,
'Attention',
gate_step=gate_step,
inference_num_per_image = num_inference_steps,
lcm=False,
cur_step=i+1
)

# expand the latents if we are doing classifier free guidance
if self.do_classifier_free_guidance and i < gate_step:
if self.do_classifier_free_guidance and (i-num_warmup_steps) < gate_step:
latent_model_input = torch.cat([latents] * 2)
prompt_embeds = prompt_cfg_embeds
else:
latent_model_input = latents
prompt_embeds = negative_prompt_embeds
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
if i == gate_step or i == gate_step-1:
if (i-num_warmup_steps) == gate_step or (i-num_warmup_steps) == gate_step-1:
self.deepcache.disable()
self.deepcache.enable()

# TGATE
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
register_tgate_forward(self.unet,
'Attention',
gate_step=gate_step,
inference_num_per_image = num_inference_steps,
cur_step=i+1-num_warmup_steps,
)

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
Expand All @@ -281,11 +285,11 @@ def tgate(
)[0]

# perform guidance
if self.do_classifier_free_guidance and i < gate_step:
if self.do_classifier_free_guidance and (i-num_warmup_steps) < gate_step:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

if self.do_classifier_free_guidance and self.guidance_rescale > 0.0 and i < gate_step:
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0 and (i-num_warmup_steps) < gate_step:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

Expand Down

0 comments on commit a29ec6e

Please sign in to comment.