Skip to content

Commit

Permalink
validations: disable torch compile for lycoris or deepspeed
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Nov 10, 2024
1 parent c2701f6 commit cbebba4
Showing 1 changed file with 25 additions and 20 deletions.
45 changes: 25 additions & 20 deletions helpers/training/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,26 +1083,31 @@ def setup_pipeline(self, validation_type, enable_ema_model: bool = True):
continue
break
if self.args.validation_torch_compile:
if self.unet is not None and not is_compiled_module(self.unet):
logger.warning(
f"Compiling the UNet for validation ({self.args.validation_torch_compile})"
)
self.pipeline.unet = torch.compile(
self.pipeline.unet,
mode=self.args.validation_torch_compile_mode,
fullgraph=False,
)
if self.transformer is not None and not is_compiled_module(
self.transformer
):
logger.warning(
f"Compiling the transformer for validation ({self.args.validation_torch_compile})"
)
self.pipeline.transformer = torch.compile(
self.pipeline.transformer,
mode=self.args.validation_torch_compile_mode,
fullgraph=False,
)
if self.deepspeed:
logger.warning("DeepSpeed does not support torch compile. Disabling. Set --validation_torch_compile=False to suppress this warning.")
elif self.lora_type.lower() == "lycoris":
logger.warning("LyCORIS does not support torch compile for validation due to graph compile breaks. Disabling. Set --validation_torch_compile=False to suppress this warning.")
else:
if self.unet is not None and not is_compiled_module(self.unet):
logger.warning(
f"Compiling the UNet for validation ({self.args.validation_torch_compile})"
)
self.pipeline.unet = torch.compile(
self.pipeline.unet,
mode=self.args.validation_torch_compile_mode,
fullgraph=False,
)
if self.transformer is not None and not is_compiled_module(
self.transformer
):
logger.warning(
f"Compiling the transformer for validation ({self.args.validation_torch_compile})"
)
self.pipeline.transformer = torch.compile(
self.pipeline.transformer,
mode=self.args.validation_torch_compile_mode,
fullgraph=False,
)

self.pipeline = self.pipeline.to(self.inference_device)
self.pipeline.set_progress_bar_config(disable=True)
Expand Down

0 comments on commit cbebba4

Please sign in to comment.