diff --git a/src/tgate/SDXL_DeepCache.py b/src/tgate/SDXL_DeepCache.py index e56924b..44cf820 100644 --- a/src/tgate/SDXL_DeepCache.py +++ b/src/tgate/SDXL_DeepCache.py @@ -467,6 +467,7 @@ def tgate( xm.mark_step() self.deepcache.disable() self.deepcache.enable() + if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast @@ -474,8 +475,27 @@ def tgate( if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: