diff --git a/OPTIONS.md b/OPTIONS.md
index edeea098..79c77603 100644
--- a/OPTIONS.md
+++ b/OPTIONS.md
@@ -40,6 +40,15 @@ The script `configure.py` in the project root can be used via `python configure.
 - **What**: Path to the pretrained T5 model or its identifier from https://huggingface.co/models.
 - **Why**: When training PixArt, you might want to use a specific source for your T5 weights so that you can avoid downloading them multiple times when switching the base model you train from.
 
+### `--gradient_checkpointing`
+
+- **What**: During training, intermediate activations are recomputed during the backward pass instead of being kept in memory, reducing peak VRAM requirements at the cost of a slower backward pass.
+
+### `--gradient_checkpointing_interval`
+
+- **What**: Checkpoint only every _n_ blocks, where _n_ is a value greater than zero. A value of 1 is effectively the same as just leaving `--gradient_checkpointing` enabled, and a value of 2 will checkpoint every other block.
+- **Note**: SDXL and Flux are currently the only models supporting this option. SDXL uses a hackish implementation.
+
 ### `--refiner_training`
 
 - **What**: Enables training a custom mixture-of-experts model series. See [Mixture-of-Experts](/documentation/MIXTURE_OF_EXPERTS.md) for more information on these options.
@@ -109,6 +118,18 @@ Carefully answer the questions and use bf16 mixed precision training when prompt
 
 Note that the first several steps of training will be slower than usual because of compilation occurring in the background.
 
+### `--attention_mechanism`
+
+Alternative attention mechanisms are supported, with varying levels of compatibility or other trade-offs:
+
+- `diffusers` uses the native PyTorch SDPA functions and is the default attention mechanism
+- `xformers` allows the use of Meta's [xformers](https://github.com/facebook/xformers) attention implementation which supports both training and inference fully
+- `sageattention` is an inference-focused attention mechanism which does not fully support being used for training ([SageAttention](https://github.com/thu-ml/SageAttention) project page)
+  - In simplest terms, SageAttention reduces the compute requirement for inference
+
+Enabling SageAttention for training via `--sageattention_usage` should be done with care, as it does not track or propagate gradients through its custom CUDA kernels for the QKV linears.
+  - This results in these layers being completely untrained, which might cause model collapse or, conversely, slight improvements in short training runs.
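+
+To make that trade-off concrete, the approach (mirroring what the trainer does elsewhere in this patch when SageAttention is enabled) is roughly the sketch below: PyTorch's `scaled_dot_product_attention` is swapped out wholesale for a fused SageAttention kernel, so anything routed through it no longer produces gradients for the attention projections. This is an illustrative sketch rather than the exact code path:
+
+```python
+import torch
+
+try:
+    from sageattention import sageattn  # provided by the SageAttention package
+
+    # keep a handle on the stock implementation (restored for e.g. VAE decode)
+    torch.nn.functional.scaled_dot_product_attention_sdpa = (
+        torch.nn.functional.scaled_dot_product_attention
+    )
+    # route all SDPA calls through the SageAttention kernel instead
+    torch.nn.functional.scaled_dot_product_attention = sageattn
+except ImportError:
+    pass  # SageAttention not installed; the stock PyTorch SDPA remains in place
+```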
+ --- ## 📰 Publishing @@ -452,7 +473,8 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--lr_scheduler {linear,sine,cosine,cosine_with_restarts,polynomial,constant,constant_with_warmup}] [--lr_warmup_steps LR_WARMUP_STEPS] [--lr_num_cycles LR_NUM_CYCLES] [--lr_power LR_POWER] - [--use_ema] [--ema_device {cpu,accelerator}] [--ema_cpu_only] + [--use_ema] [--ema_device {cpu,accelerator}] + [--ema_validation {none,ema_only,comparison}] [--ema_cpu_only] [--ema_foreach_disable] [--ema_update_interval EMA_UPDATE_INTERVAL] [--ema_decay EMA_DECAY] [--non_ema_revision NON_EMA_REVISION] @@ -473,8 +495,9 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--model_card_safe_for_work] [--logging_dir LOGGING_DIR] [--benchmark_base_model] [--disable_benchmark] [--evaluation_type {clip,none}] - [--pretrained_evaluation_model_name_or_path pretrained_evaluation_model_name_or_path] + [--pretrained_evaluation_model_name_or_path PRETRAINED_EVALUATION_MODEL_NAME_OR_PATH] [--validation_on_startup] [--validation_seed_source {gpu,cpu}] + [--validation_lycoris_strength VALIDATION_LYCORIS_STRENGTH] [--validation_torch_compile] [--validation_torch_compile_mode {max-autotune,reduce-overhead,default}] [--validation_guidance_skip_layers VALIDATION_GUIDANCE_SKIP_LAYERS] @@ -509,6 +532,8 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--text_encoder_2_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}] [--text_encoder_3_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}] [--local_rank LOCAL_RANK] + [--attention_mechanism {diffusers,xformers,sageattention,sageattention-int8-fp16-triton,sageattention-int8-fp16-cuda,sageattention-int8-fp8-cuda}] + [--sageattention_usage {training,inference,training+inference}] [--enable_xformers_memory_efficient_attention] [--set_grads_to_none] [--noise_offset NOISE_OFFSET] [--noise_offset_probability NOISE_OFFSET_PROBABILITY] @@ -1137,12 +1162,21 @@ options: cosine_with_restarts scheduler. --lr_power LR_POWER Power factor of the polynomial scheduler. --use_ema Whether to use EMA (exponential moving average) model. + Works with LoRA, Lycoris, and full training. --ema_device {cpu,accelerator} The device to use for the EMA model. If set to 'accelerator', the EMA model will be placed on the accelerator. This provides the fastest EMA update times, but is not ultimately necessary for EMA to function. + --ema_validation {none,ema_only,comparison} + When 'none' is set, no EMA validation will be done. + When using 'ema_only', the validations will rely + mostly on the EMA weights. When using 'comparison' + (default) mode, the validations will first run on the + checkpoint before also running for the EMA weights. In + comparison mode, the resulting images will be provided + side-by-side. --ema_cpu_only When using EMA, the shadow model is moved to the accelerator before we update its parameters. When provided, this option will disable the moving of the @@ -1248,7 +1282,7 @@ options: function. The default is to use no evaluator, and 'clip' will use a CLIP model to evaluate the resulting model's performance during validations. - --pretrained_evaluation_model_name_or_path pretrained_evaluation_model_name_or_path + --pretrained_evaluation_model_name_or_path PRETRAINED_EVALUATION_MODEL_NAME_OR_PATH Optionally provide a custom model to use for ViT evaluations. 
The default is currently clip-vit-large- patch14-336, allowing for lower patch sizes (greater @@ -1264,6 +1298,12 @@ options: validation errors. If so, please set SIMPLETUNER_LOG_LEVEL=DEBUG and submit debug.log to a new Github issue report. + --validation_lycoris_strength VALIDATION_LYCORIS_STRENGTH + When inferencing for validations, the Lycoris model + will by default be run at its training strength, 1.0. + However, this value can be increased to a value of + around 1.3 or 1.5 to get a stronger effect from the + model. --validation_torch_compile Supply `--validation_torch_compile=true` to enable the use of torch.compile() on the validation pipeline. For @@ -1453,8 +1493,32 @@ options: quantisation (Apple Silicon, NVIDIA, AMD). --local_rank LOCAL_RANK For distributed training: local_rank + --attention_mechanism {diffusers,xformers,sageattention,sageattention-int8-fp16-triton,sageattention-int8-fp16-cuda,sageattention-int8-fp8-cuda} + On NVIDIA CUDA devices, alternative flash attention + implementations are offered, with the default being + native pytorch SDPA. SageAttention has multiple + backends to select from. The recommended value, + 'sageattention', guesses what would be the 'best' + option for SageAttention on your hardware (usually + this is the int8-fp16-cuda backend). However, manually + setting this value to int8-fp16-triton may provide + better averages for per-step training and inference + performance while the cuda backend may provide the + highest maximum speed (with also a lower minimum + speed). NOTE: SageAttention training quality has not + been validated. + --sageattention_usage {training,inference,training+inference} + SageAttention breaks gradient tracking through the + backward pass, leading to untrained QKV layers. This + can result in substantial problems for training, so it + is recommended to use SageAttention only for inference + (default behaviour). If you are confident in your + training setup or do not wish to train QKV layers, you + may use 'training' to enable SageAttention for + training. --enable_xformers_memory_efficient_attention - Whether or not to use xformers. + Whether or not to use xformers. Deprecated and slated + for future removal. Use --attention_mechanism. --set_grads_to_none Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain behaviors, so disable this argument if it causes any diff --git a/configure.py b/configure.py index f14e13b9..26620662 100644 --- a/configure.py +++ b/configure.py @@ -429,7 +429,36 @@ def configure_env(): ).lower() == "y" ) - report_to_str = "" + + env_contents["--attention_mechanism"] = "diffusers" + use_sageattention = ( + prompt_user( + "Would you like to use SageAttention for image validation generation? (y/[n])", + "n", + ).lower() + == "y" + ) + if use_sageattention: + env_contents["--attention_mechanism"] = "sageattention" + env_contents["--sageattention_usage"] = "inference" + use_sageattention_training = ( + prompt_user( + ( + "Would you like to use SageAttention to cover the forward and backward pass during training?" + " This has the undesirable consequence of leaving the attention layers untrained," + " as SageAttention lacks the capability to fully track gradients through quantisation." + " If you are not training the attention layers for some reason, this may not matter and" + " you can safely enable this. For all other use-cases, reconsideration and caution are warranted." 
+ ), + "n", + ).lower() + == "y" + ) + if use_sageattention_training: + env_contents["--sageattention_usage"] = "both" + + # properly disable wandb/tensorboard/comet_ml etc by default + report_to_str = "none" if report_to_wandb or report_to_tensorboard: tracker_project_name = prompt_user( "Enter the name of your Weights & Biases project", f"{model_type}-training" @@ -440,17 +469,17 @@ def configure_env(): f"simpletuner-{model_type}", ) env_contents["--tracker_run_name"] = tracker_run_name - report_to_str = None if report_to_wandb: report_to_str = "wandb" if report_to_tensorboard: - if report_to_wandb: + if report_to_str != "none": + # report to both WandB and Tensorboard if the user wanted. report_to_str += "," else: + # remove 'none' from the option report_to_str = "" report_to_str += "tensorboard" - if report_to_str: - env_contents["--report_to"] = report_to_str + env_contents["--report_to"] = report_to_str print_config(env_contents, extra_args) @@ -514,6 +543,18 @@ def configure_env(): ) ) env_contents["--gradient_checkpointing"] = "true" + gradient_checkpointing_interval = prompt_user( + "Would you like to configure a gradient checkpointing interval? A value larger than 1 will increase VRAM usage but speed up training by skipping checkpoint creation every Nth layer, and a zero will disable this feature.", + 0, + ) + try: + if int(gradient_checkpointing_interval) > 1: + env_contents["--gradient_checkpointing_interval"] = int( + gradient_checkpointing_interval + ) + except: + print("Could not parse gradient checkpointing interval. Not enabling.") + pass env_contents["--caption_dropout_probability"] = float( prompt_user( diff --git a/documentation/LYCORIS.md b/documentation/LYCORIS.md index 64aba565..5d0c5f55 100644 --- a/documentation/LYCORIS.md +++ b/documentation/LYCORIS.md @@ -59,6 +59,14 @@ Mandatory fields: For more information on LyCORIS, please refer to the [documentation in the library](https://github.com/KohakuBlueleaf/LyCORIS/tree/main/docs). +## Potential problems + +When using Lycoris on SDXL, it's noted that training the FeedForward modules may break the model and send loss into `NaN` (Not-a-Number) territory. + +This seems to be potentially exacerbated when using SageAttention (with `--sageattention_usage=training`), making it all but guaranteed that the model will immediately fail. + +The solution is to remove the `FeedForward` modules from the lycoris config and train only the `Attention` blocks. + ## LyCORIS Inference Example Here is a simple FLUX.1-dev inference script showing how to wrap your unet or transformer with create_lycoris_from_weights and then use it for inference. diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index ddb86e1d..ee5b98cc 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -144,6 +144,8 @@ There, you will possibly need to modify the following variables: - This option causes update steps to be accumulated over several steps. This will increase the training runtime linearly, such that a value of 2 will make your training run half as quickly, and take twice as long. - `optimizer` - Beginners are recommended to stick with adamw_bf16, though optimi-lion and optimi-stableadamw are also good choices. - `mixed_precision` - Beginners should keep this in `bf16` +- `gradient_checkpointing` - set this to true in practically every situation on every device +- `gradient_checkpointing_interval` - this could be set to a value of 2 or higher on larger GPUs to only checkpoint every _n_ blocks. 
A value of 2 would checkpoint half of the blocks, and 3 would be one-third. Multi-GPU users can reference [this document](/OPTIONS.md#environment-configuration-variables) for information on configuring the number of GPUs to use. @@ -414,9 +416,19 @@ Currently, the lowest VRAM utilisation (9090M) can be attained with: - DeepSpeed: disabled / unconfigured - PyTorch: 2.6 Nightly (Sept 29th build) - Using `--quantize_via=cpu` to avoid outOfMemory error during startup on <=16G cards. +- With `--attention_mechanism=sageattention` to further reduce VRAM by 0.1GB and improve training validation image generation speed. +- Be sure to enable `--gradient_checkpointing` or nothing you do will stop it from OOMing + +**NOTE**: Pre-caching of VAE embeds and text encoder outputs may use more memory and still OOM. If so, text encoder quantisation and VAE tiling can be enabled. Speed was approximately 1.4 iterations per second on a 4090. +### SageAttention + +When using `--attention_mechanism=sageattention`, inference can be sped-up at validation time. + +**Note**: This isn't compatible with _every_ model configuration, but it's worth trying. + ### NF4-quantised training In simplest terms, NF4 is a 4bit-_ish_ representation of the model, which means training has serious stability concerns to address. @@ -428,6 +440,7 @@ In early tests, the following holds true: - NF4, AdamW8bit, and a higher batch size all help to overcome the stability issues, at the cost of more time spent training or VRAM used - Upping the resolution from 512px to 1024px slows training down from, for example, 1.4 seconds per step to 3.5 seconds per step (batch size of 1, 4090) - Anything that's difficult to train on int8 or bf16 becomes harder in NF4 +- It's less compatible with options like SageAttention NF4 does not work with torch.compile, so whatever you get for speed is what you get. diff --git a/documentation/quickstart/SD3.md b/documentation/quickstart/SD3.md index 5d646562..832eaa9f 100644 --- a/documentation/quickstart/SD3.md +++ b/documentation/quickstart/SD3.md @@ -339,6 +339,12 @@ These options have been known to keep SD3.5 in-tact for as long as possible: - DeepSpeed: disabled / unconfigured - PyTorch: 2.5 +### SageAttention + +When using `--attention_mechanism=sageattention`, inference can be sped-up at validation time. + +**Note**: This isn't compatible with _every_ model configuration, but it's worth trying. + ### Masked loss If you are training a subject or style and would like to mask one or the other, see the [masked loss training](/documentation/DREAMBOOTH.md#masked-loss) section of the Dreambooth guide. diff --git a/documentation/quickstart/SIGMA.md b/documentation/quickstart/SIGMA.md index cf3389aa..518bc006 100644 --- a/documentation/quickstart/SIGMA.md +++ b/documentation/quickstart/SIGMA.md @@ -220,3 +220,9 @@ For more information, see the [dataloader](/documentation/DATALOADER.md) and [tu ### CLIP score tracking If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores. + +### SageAttention + +When using `--attention_mechanism=sageattention`, inference can be sped-up at validation time. + +**Note**: This isn't compatible with _every_ model configuration, but it's worth trying. 
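The `gradient_checkpointing_interval` guidance in the quickstart notes above boils down to wrapping only every _n_-th transformer block in an activation checkpoint while training. A minimal sketch of the idea (illustrative only; the actual implementation lives in the transformer classes and the checkpoint monkeypatch later in this patch) might look like:

```python
import torch
from torch.utils.checkpoint import checkpoint


def run_blocks(blocks, hidden_states, training=True, interval=2):
    """Run transformer blocks, checkpointing only every `interval`-th block."""
    for index, block in enumerate(blocks):
        if training and (interval is None or index % interval == 0):
            # recompute this block's activations during the backward pass (saves VRAM)
            hidden_states = checkpoint(block, hidden_states, use_reentrant=False)
        else:
            # keep activations resident in memory; faster, but higher peak VRAM
            hidden_states = block(hidden_states)
    return hidden_states
```

With `interval=2`, half of the blocks are checkpointed; with `interval=3`, one third, matching the description in the Flux quickstart above.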
diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 7ea0dc61..457e1118 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1054,6 +1054,15 @@ def get_argument_parser(): action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) + parser.add_argument( + "--gradient_checkpointing_interval", + default=None, + type=int, + help=( + "Some models (Flux, SDXL, SD1.x/2.x) can have their gradient checkpointing limited to every nth block." + " This can speed up training but will use more memory with larger intervals." + ), + ) parser.add_argument( "--learning_rate", type=float, @@ -1131,7 +1140,7 @@ def get_argument_parser(): " When using 'ema_only', the validations will rely mostly on the EMA weights." " When using 'comparison' (default) mode, the validations will first run on the checkpoint before also running for" " the EMA weights. In comparison mode, the resulting images will be provided side-by-side." - ) + ), ) parser.add_argument( "--ema_cpu_only", @@ -1708,10 +1717,43 @@ def get_argument_parser(): default=-1, help="For distributed training: local_rank", ) + parser.add_argument( + "--attention_mechanism", + type=str, + choices=[ + "diffusers", + "xformers", + "sageattention", + "sageattention-int8-fp16-triton", + "sageattention-int8-fp16-cuda", + "sageattention-int8-fp8-cuda", + ], + default="diffusers", + help=( + "On NVIDIA CUDA devices, alternative flash attention implementations are offered, with the default being native pytorch SDPA." + " SageAttention has multiple backends to select from." + " The recommended value, 'sageattention', guesses what would be the 'best' option for SageAttention on your hardware" + " (usually this is the int8-fp16-cuda backend). However, manually setting this value to int8-fp16-triton" + " may provide better averages for per-step training and inference performance while the cuda backend" + " may provide the highest maximum speed (with also a lower minimum speed). NOTE: SageAttention training quality" + " has not been validated." + ), + ) + parser.add_argument( + "--sageattention_usage", + type=str, + choices=["training", "inference", "training+inference"], + default="inference", + help=( + "SageAttention breaks gradient tracking through the backward pass, leading to untrained QKV layers." + " This can result in substantial problems for training, so it is recommended to use SageAttention only for inference (default behaviour)." + " If you are confident in your training setup or do not wish to train QKV layers, you may use 'training' to enable SageAttention for training." + ), + ) parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", - help="Whether or not to use xformers.", + help="Whether or not to use xformers. Deprecated and slated for future removal. 
Use --attention_mechanism.", ) parser.add_argument( "--set_grads_to_none", @@ -2418,7 +2460,7 @@ def parse_cmdline_args(input_args=None): args.lycoris_config, os.R_OK ): raise ValueError( - f"Could not find the JSON configuration file at {args.lycoris_config}" + f"Could not find the JSON configuration file at '{args.lycoris_config}'" ) import json @@ -2438,11 +2480,11 @@ def parse_cmdline_args(input_args=None): elif "standard" == args.lora_type.lower(): if hasattr(args, "lora_init_type") and args.lora_init_type is not None: if torch.backends.mps.is_available() and args.lora_init_type == "loftq": - logger.error( + error_log( "Apple MPS cannot make use of LoftQ initialisation. Overriding to 'default'." ) elif args.is_quantized and args.lora_init_type == "loftq": - logger.error( + error_log( "LoftQ initialisation is not supported with quantised models. Overriding to 'default'." ) else: @@ -2451,7 +2493,7 @@ def parse_cmdline_args(input_args=None): ) if args.use_dora: if "quanto" in args.base_model_precision: - logger.error( + error_log( "Quanto does not yet support DoRA training in PEFT. Disabling DoRA. 😴" ) args.use_dora = False @@ -2488,4 +2530,16 @@ def parse_cmdline_args(input_args=None): logger.error(f"Could not load skip layers: {e}") raise + if args.enable_xformers_memory_efficient_attention: + if args.attention_mechanism != "xformers": + warning_log( + "The option --enable_xformers_memory_efficient_attention is deprecated. Please use --attention_mechanism=xformers instead." + ) + args.attention_mechanism = "xformers" + + if args.attention_mechanism != "diffusers" and not torch.cuda.is_available(): + warning_log( + "For non-CUDA systems, only Diffusers attention mechanism is officially supported." + ) + return args diff --git a/helpers/kolors/pipeline.py b/helpers/kolors/pipeline.py index 964a8e95..2f395e4d 100644 --- a/helpers/kolors/pipeline.py +++ b/helpers/kolors/pipeline.py @@ -1249,11 +1249,22 @@ def denoising_value_valid(dnv): # unscale/denormalize the latents latents = latents / self.vae.config.scaling_factor + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( - latents.to(device=self.vae.device, dtype=self.vae.dtype), - return_dict=False, + latents.to(device=self.vae.device), return_dict=False )[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) diff --git a/helpers/legacy/pipeline.py b/helpers/legacy/pipeline.py index 01b0ffe8..c186b48d 100644 --- a/helpers/legacy/pipeline.py +++ b/helpers/legacy/pipeline.py @@ -1160,18 +1160,27 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + has_nsfw_concept = None if not output_type == "latent": + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. 
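+                # scaled_dot_product_attention_sdpa holds the original PyTorch SDPA implementation,
+                # saved by the trainer when SageAttention was patched in.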
+ torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( - latents.to(self.vae.dtype) / self.vae.config.scaling_factor, + latents.to(dtype=self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False, generator=generator, )[0] - image, has_nsfw_concept = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) else: image = latents - has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] diff --git a/helpers/models/flux/pipeline.py b/helpers/models/flux/pipeline.py index ba7eaa44..1d152def 100644 --- a/helpers/models/flux/pipeline.py +++ b/helpers/models/flux/pipeline.py @@ -906,11 +906,22 @@ def __call__( latents = ( latents / self.vae.config.scaling_factor ) + self.vae.config.shift_factor + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) image = self.vae.decode( - latents.to(device=self.vae.device, dtype=self.vae.dtype), - return_dict=False, + latents.to(dtype=self.vae.dtype), return_dict=False )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models diff --git a/helpers/models/flux/transformer.py b/helpers/models/flux/transformer.py index 7a9c80f2..77097648 100644 --- a/helpers/models/flux/transformer.py +++ b/helpers/models/flux/transformer.py @@ -489,11 +489,16 @@ def __init__( ) self.gradient_checkpointing = False + # added for users to disable checkpointing every nth step + self.gradient_checkpointing_interval = None def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value + def set_gradient_checkpointing_interval(self, value: int): + self.gradient_checkpointing_interval = value + def forward( self, hidden_states: torch.Tensor, @@ -574,7 +579,14 @@ def forward( image_rotary_emb = self.pos_embed(ids) for index_block, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: + if ( + self.training + and self.gradient_checkpointing + and ( + self.gradient_checkpointing_interval is None + or index_block % self.gradient_checkpointing_interval == 0 + ) + ): def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -614,7 +626,14 @@ def custom_forward(*inputs): hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): - if self.training and self.gradient_checkpointing: + if ( + self.training + and self.gradient_checkpointing + or ( + self.gradient_checkpointing_interval is not None + and index_block % self.gradient_checkpointing_interval == 0 + ) + ): def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/helpers/models/omnigen/pipeline.py b/helpers/models/omnigen/pipeline.py index 
fadbaf05..5693967f 100644 --- a/helpers/models/omnigen/pipeline.py +++ b/helpers/models/omnigen/pipeline.py @@ -345,9 +345,20 @@ def __call__( ) else: samples = samples / self.vae.config.scaling_factor - samples = self.vae.decode( - samples.to(dtype=self.vae.dtype, device=self.vae.device) - ).sample + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + + image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) if self.model_cpu_offload: self.vae.to("cpu") diff --git a/helpers/models/pixart/pipeline.py b/helpers/models/pixart/pipeline.py index 6efb54cd..244df2e2 100644 --- a/helpers/models/pixart/pipeline.py +++ b/helpers/models/pixart/pipeline.py @@ -1231,11 +1231,24 @@ def denoising_value_valid(dnv): callback(step_idx, t, latents) if not output_type == "latent": + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( - latents.to(device=self.vae.device, dtype=self.vae.dtype) - / self.vae.config.scaling_factor, + latents.to(dtype=self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False, + generator=generator, )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor( image, orig_width, orig_height diff --git a/helpers/models/sd3/pipeline.py b/helpers/models/sd3/pipeline.py index 653c2a6a..70698f6e 100644 --- a/helpers/models/sd3/pipeline.py +++ b/helpers/models/sd3/pipeline.py @@ -999,11 +999,14 @@ def __call__( continue # expand the latents if we are doing classifier free guidance + # added fix from: https://github.com/huggingface/diffusers/pull/10086/files + # to allow for num_images_per_prompt > 1 latent_model_input = ( torch.cat([latents] * 2) - if self.do_classifier_free_guidance and skip_guidance_layers is None + if self.do_classifier_free_guidance else latents ) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) @@ -1033,6 +1036,8 @@ def __call__( else False ) if skip_guidance_layers is not None and should_skip_layers: + timestep = t.expand(latents.shape[0]) + latent_model_input = latents noise_pred_skip_layers = self.transformer( hidden_states=latent_model_input.to( device=self.transformer.device, @@ -1097,7 +1102,22 @@ def __call__( latents / self.vae.config.scaling_factor ) + self.vae.config.shift_factor - image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. 
+ torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + + image = self.vae.decode( + latents.to(dtype=self.vae.dtype), return_dict=False + )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models @@ -2053,7 +2073,22 @@ def __call__( latents / self.vae.config.scaling_factor ) + self.vae.config.shift_factor - image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + + image = self.vae.decode( + latents.to(dtype=self.vae.dtype), return_dict=False + )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models diff --git a/helpers/models/sdxl/pipeline.py b/helpers/models/sdxl/pipeline.py index bfd93f29..10b01216 100644 --- a/helpers/models/sdxl/pipeline.py +++ b/helpers/models/sdxl/pipeline.py @@ -1488,10 +1488,22 @@ def __call__( else: latents = latents / self.vae.config.scaling_factor + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( latents.to(dtype=self.vae.dtype), return_dict=False )[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. 
+ torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py index 92bd3568..2bfe4104 100644 --- a/helpers/publishing/metadata.py +++ b/helpers/publishing/metadata.py @@ -119,6 +119,7 @@ def ema_info(args): return ema_information return "" + def lycoris_download_info(): """output a function to download the adapter""" output_fn = """ @@ -556,7 +557,8 @@ def save_model_card( - Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''} - Trainable parameter precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'} - Caption dropout probability: {StateTracker.get_args().caption_dropout_probability * 100}% -{'- Xformers: Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else ''} +{'- Xformers: Enabled' if StateTracker.get_args().attention_mechanism == 'xformers' else ''} +{f'- SageAttention: Enabled {StateTracker.get_args().sageattention_usage}' if StateTracker.get_args().attention_mechanism == 'sageattention' else ''} {lora_info(args=StateTracker.get_args())} ## Datasets diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index f586972c..e6522a40 100644 --- a/helpers/training/default_settings/safety_check.py +++ b/helpers/training/default_settings/safety_check.py @@ -116,4 +116,38 @@ def safety_check(args, accelerator): logger.error( f"--flux_schedule_auto_shift cannot be combined with --flux_schedule_shift. Please set --flux_schedule_shift to 0 if you want to train with --flux_schedule_auto_shift." ) - sys.exit(1) \ No newline at end of file + sys.exit(1) + + if args.attention_mechanism == "sageattention": + if args.sageattention_usage != "inference": + logger.error( + f"SageAttention usage is set to '{args.sageattention_usage}' instead of 'inference'. This is not an officially supported configuration, please be sure you understand the implications. It is recommended to set this value to 'inference' for safety." + ) + if args.enable_xformers_memory_efficient_attention: + logger.error( + f"--enable_xformers_memory_efficient_attention is only compatible with --attention_mechanism=diffusers. Please set --attention_mechanism=diffusers to enable this feature or disable xformers to use alternative attention mechanisms." + ) + sys.exit(1) + if "nf4" in args.base_model_precision: + logger.error( + f"{args.base_model_precision} is not supported with SageAttention. Please select from int8 or fp8, or, disable quantisation to use SageAttention." + ) + sys.exit(1) + + gradient_checkpointing_interval_supported_models = [ + "flux", + "sdxl", + ] + if args.gradient_checkpointing_interval is not None: + if ( + args.model_family.lower() + not in gradient_checkpointing_interval_supported_models + ): + logger.error( + f"Gradient checkpointing is not supported with {args.model_family} models. Please disable --gradient_checkpointing_interval by setting it to None, or remove it from your configuration. Currently supported models: {gradient_checkpointing_interval_supported_models}" + ) + sys.exit(1) + if args.gradient_checkpointing_interval == 0: + raise ValueError( + "Gradient checkpointing interval must be greater than 0. Please set it to a positive integer." 
+ ) diff --git a/helpers/training/diffusion_model.py b/helpers/training/diffusion_model.py index 5ac0c207..290858bc 100644 --- a/helpers/training/diffusion_model.py +++ b/helpers/training/diffusion_model.py @@ -52,7 +52,9 @@ def load_diffusion_model(args, weight_dtype): elif ( args.model_family.lower() == "flux" and not args.flux_attention_masked_training ): - from diffusers.models import FluxTransformer2DModel + from helpers.models.flux.transformer import ( + FluxTransformer2DModelWithMasking as FluxTransformer2DModel, + ) import torch if torch.cuda.is_available(): @@ -92,6 +94,10 @@ def load_diffusion_model(args, weight_dtype): subfolder=determine_subfolder(args.pretrained_transformer_subfolder), **pretrained_load_args, ) + if args.gradient_checkpointing_interval is not None: + transformer.set_gradient_checkpointing_interval( + int(args.gradient_checkpointing_interval) + ) elif args.model_family.lower() == "flux" and args.flux_attention_masked_training: from helpers.models.flux.transformer import ( FluxTransformer2DModelWithMasking, @@ -103,6 +109,10 @@ def load_diffusion_model(args, weight_dtype): subfolder=determine_subfolder(args.pretrained_transformer_subfolder), **pretrained_load_args, ) + if args.gradient_checkpointing_interval is not None: + transformer.set_gradient_checkpointing_interval( + int(args.gradient_checkpointing_interval) + ) elif args.model_family == "pixart_sigma": from diffusers.models import PixArtTransformer2DModel @@ -145,5 +155,22 @@ def load_diffusion_model(args, weight_dtype): subfolder=determine_subfolder(args.pretrained_unet_subfolder), **pretrained_load_args, ) + if ( + args.gradient_checkpointing_interval is not None + and args.gradient_checkpointing_interval > 0 + ): + logger.warning( + "Using experimental gradient checkpointing monkeypatch for a checkpoint interval of {}".format( + args.gradient_checkpointing_interval + ) + ) + # monkey-patch the gradient checkpointing function for pytorch to run every nth call only. + # definitely one of the more awful things I've ever done while programming, but it's easier than + # modifying every one of the unet blocks' forward calls in Diffusers to make it work properly. 
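+        # The helper replaces torch.utils.checkpoint.checkpoint with a counting wrapper that only
+        # checkpoints every nth call and runs the remaining blocks without checkpointing.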
+ from helpers.training.gradient_checkpointing_interval import ( + set_checkpoint_interval, + ) + + set_checkpoint_interval(int(args.gradient_checkpointing_interval)) return unet, transformer diff --git a/helpers/training/ema.py b/helpers/training/ema.py index d2153796..7ec7b991 100644 --- a/helpers/training/ema.py +++ b/helpers/training/ema.py @@ -267,7 +267,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter], global_step: int = None context_manager = contextlib.nullcontext if ( is_transformers_available() - and transformers.deepspeed.is_deepspeed_zero3_enabled() + and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled() ): import deepspeed @@ -309,7 +309,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter], global_step: int = None for s_param, param in zip(self.shadow_params, parameters): if ( is_transformers_available() - and transformers.deepspeed.is_deepspeed_zero3_enabled() + and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled() ): context_manager = deepspeed.zero.GatheredParameters( param, modifier_rank=None @@ -374,7 +374,9 @@ def cuda(self, device=None): def cpu(self): return self.to(device="cpu") - def state_dict(self, destination=None, prefix="", keep_vars=False, exclude_params: bool = False): + def state_dict( + self, destination=None, prefix="", keep_vars=False, exclude_params: bool = False + ): r""" Returns a dictionary containing a whole state of the EMA model. """ diff --git a/helpers/training/gradient_checkpointing_interval.py b/helpers/training/gradient_checkpointing_interval.py new file mode 100644 index 00000000..18026044 --- /dev/null +++ b/helpers/training/gradient_checkpointing_interval.py @@ -0,0 +1,42 @@ +import torch +from torch.utils.checkpoint import checkpoint as original_checkpoint + + +# Global variables to keep track of the checkpointing state +_checkpoint_call_count = 0 +_checkpoint_interval = 4 # You can set this to any interval you prefer + + +def reset_checkpoint_counter(): + """Resets the checkpoint call counter. 
Call this at the beginning of the forward pass.""" + global _checkpoint_call_count + _checkpoint_call_count = 0 + + +def set_checkpoint_interval(n): + """Sets the interval at which checkpointing is skipped.""" + global _checkpoint_interval + _checkpoint_interval = n + + +def checkpoint_wrapper(function, *args, use_reentrant=True, **kwargs): + """Wrapper function for torch.utils.checkpoint.checkpoint.""" + global _checkpoint_call_count, _checkpoint_interval + _checkpoint_call_count += 1 + + if ( + _checkpoint_interval > 0 + and (_checkpoint_call_count % _checkpoint_interval) == 0 + ): + # Use the original checkpoint function + return original_checkpoint( + function, *args, use_reentrant=use_reentrant, **kwargs + ) + else: + # Skip checkpointing: execute the function directly + # Do not pass 'use_reentrant' to the function + return function(*args, **kwargs) + + +# Monkeypatch torch.utils.checkpoint.checkpoint +torch.utils.checkpoint.checkpoint = checkpoint_wrapper diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index a510027d..f6a15696 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -997,8 +997,11 @@ def init_post_load_freeze(self): self.transformer = apply_bitfit_freezing( unwrap_model(self.accelerator, self.transformer), self.config ) + self.enable_gradient_checkpointing() + def enable_gradient_checkpointing(self): if self.config.gradient_checkpointing: + logger.info("Enabling gradient checkpointing.") if self.unet is not None: unwrap_model( self.accelerator, self.unet @@ -1022,6 +1025,32 @@ def init_post_load_freeze(self): self.accelerator, self.text_encoder_2 ).gradient_checkpointing_enable() + def disable_gradient_checkpointing(self): + if self.config.gradient_checkpointing: + logger.info("Disabling gradient checkpointing.") + if self.unet is not None: + unwrap_model( + self.accelerator, self.unet + ).disable_gradient_checkpointing() + if self.transformer is not None and self.config.model_family != "smoldit": + unwrap_model( + self.accelerator, self.transformer + ).disable_gradient_checkpointing() + if self.config.controlnet: + unwrap_model( + self.accelerator, self.controlnet + ).disable_gradient_checkpointing() + if ( + hasattr(self.config, "train_text_encoder") + and self.config.train_text_encoder + ): + unwrap_model( + self.accelerator, self.text_encoder_1 + ).gradient_checkpointing_disable() + unwrap_model( + self.accelerator, self.text_encoder_2 + ).gradient_checkpointing_disable() + def _get_trainable_parameters(self): # Return just a list of the currently trainable parameters. if self.config.model_type == "lora": @@ -1613,6 +1642,98 @@ def resume_and_prepare(self): lr_scheduler = self.init_resume_checkpoint(lr_scheduler=lr_scheduler) self.init_post_load_freeze() + def enable_sageattention_inference(self): + # if the sageattention is inference-only, we'll enable it. + # if it's training only, we'll disable it. + # if it's inference+training, we leave it alone. + if ( + "sageattention" not in self.config.attention_mechanism + or self.config.sageattention_usage == "training+inference" + ): + return + if self.config.sageattention_usage == "inference": + self.enable_sageattention() + if self.config.sageattention_usage == "training": + self.disable_sageattention() + + def disable_sageattention_inference(self): + # if the sageattention is inference-only, we'll disable it. + # if it's training only, we'll enable it. + # if it's inference+training, we leave it alone. 
+ if ( + "sageattention" not in self.config.attention_mechanism + or self.config.sageattention_usage == "training+inference" + ): + return + if self.config.sageattention_usage == "inference": + self.disable_sageattention() + if self.config.sageattention_usage == "training": + self.enable_sageattention() + + def disable_sageattention(self): + if "sageattention" not in self.config.attention_mechanism: + return + + if ( + hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa") + and torch.nn.functional + != torch.nn.functional.scaled_dot_product_attention_sdpa + ): + logger.info("Disabling SageAttention.") + setattr( + torch.nn.functional, + "scaled_dot_product_attention", + torch.nn.functional.scaled_dot_product_attention_sdpa, + ) + + def enable_sageattention(self): + if "sageattention" not in self.config.attention_mechanism: + return + + # we'll try and load SageAttention and overload pytorch's sdpa function. + try: + logger.info("Enabling SageAttention.") + from sageattention import ( + sageattn, + sageattn_qk_int8_pv_fp16_triton, + sageattn_qk_int8_pv_fp16_cuda, + sageattn_qk_int8_pv_fp8_cuda, + ) + + sageattn_functions = { + "sageattention": sageattn, + "sageattention-int8-fp16-triton": sageattn_qk_int8_pv_fp16_triton, + "sageattention-int8-fp16-cuda": sageattn_qk_int8_pv_fp16_cuda, + "sageattention-int8-fp8-cuda": sageattn_qk_int8_pv_fp8_cuda, + } + # store the old SDPA for validations to use during VAE decode + if not hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + setattr( + torch.nn.functional, + "scaled_dot_product_attention_sdpa", + torch.nn.functional.scaled_dot_product_attention, + ) + torch.nn.functional.scaled_dot_product_attention = sageattn_functions.get( + self.config.attention_mechanism, "sageattention" + ) + if not hasattr(torch.nn.functional, "scaled_dot_product_attention_sage"): + setattr( + torch.nn.functional, + "scaled_dot_product_attention_sage", + torch.nn.functional.scaled_dot_product_attention, + ) + + if "training" in self.config.sageattention_usage: + logger.warning( + f"Using {self.config.attention_mechanism} for attention calculations during training. Your attention layers will not be trained. To disable SageAttention, remove or set --attention_mechanism to a different value." + ) + except ImportError as e: + logger.error( + "Could not import SageAttention. Please install it to use this --attention_mechanism=sageattention." + ) + logger.error(repr(e)) + sys.exit(1) + def move_models(self, destination: str = "accelerator"): target_device = "cpu" if destination == "accelerator": @@ -1636,8 +1757,17 @@ def move_models(self, destination: str = "accelerator"): target_device, dtype=self.config.weight_dtype ) ) + if ( - self.config.enable_xformers_memory_efficient_attention + "sageattention" in self.config.attention_mechanism + and "training" in self.config.sageattention_usage + ): + logger.info( + "Using SageAttention for training. This is an unsupported, experimental configuration." + ) + self.enable_sageattention() + elif ( + self.config.attention_mechanism == "xformers" and self.config.model_family not in [ "sd3", @@ -1661,11 +1791,14 @@ def move_models(self, destination: str = "accelerator"): raise ValueError( "xformers is not available. Make sure it is installed correctly" ) - elif self.config.enable_xformers_memory_efficient_attention: + elif self.config.attention_mechanism == "xformers": logger.warning( "xformers is not enabled, as it is incompatible with this model type." 
+ " Falling back to diffusers attention mechanism (Pytorch SDPA)." + " Alternatively, provide --attention_mechanism=sageattention for a more efficient option on CUDA systems." ) self.config.enable_xformers_memory_efficient_attention = False + self.config.attention_mechanism = "diffusers" if self.config.controlnet: self.controlnet.train() @@ -2083,7 +2216,11 @@ def train(self): self.mark_optimizer_eval() # normal run-of-the-mill validation on startup. if self.validation is not None: + self.enable_sageattention_inference() + self.disable_gradient_checkpointing() self.validation.run_validations(validation_type="base_model", step=0) + self.disable_sageattention_inference() + self.enable_gradient_checkpointing() self.mark_optimizer_train() @@ -2805,9 +2942,15 @@ def train(self): progress_bar.set_postfix(**logs) self.mark_optimizer_eval() if self.validation is not None: + if self.validation.would_validate(): + self.enable_sageattention_inference() + self.disable_gradient_checkpointing() self.validation.run_validations( validation_type="intermediary", step=step ) + if self.validation.would_validate(): + self.disable_sageattention_inference() + self.enable_gradient_checkpointing() self.mark_optimizer_train() if ( self.config.push_to_hub @@ -2856,12 +2999,16 @@ def train(self): if self.accelerator.is_main_process: self.mark_optimizer_eval() if self.validation is not None: + self.enable_sageattention_inference() + self.disable_gradient_checkpointing() validation_images = self.validation.run_validations( validation_type="final", step=self.state["global_step"], force_evaluation=True, skip_execution=True, ).validation_images + # we don't have to do this but we will anyway. + self.disable_sageattention_inference() if self.unet is not None: self.unet = unwrap_model(self.accelerator, self.unet) if self.transformer is not None: diff --git a/helpers/training/validation.py b/helpers/training/validation.py index fe001cf2..c5a062bb 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -745,7 +745,11 @@ def _benchmark_path(self, benchmark: str = "base_model"): return os.path.join(self.args.output_dir, "benchmarks", benchmark) def stitch_benchmark_image( - self, validation_image_result, benchmark_image, separator_width=5, labels=["base model", "checkpoint"] + self, + validation_image_result, + benchmark_image, + separator_width=5, + labels=["base model", "checkpoint"], ): """ For each image, make a new canvas and place it side by side with its equivalent from {self.validation_image_inputs} @@ -754,7 +758,9 @@ def stitch_benchmark_image( """ # Calculate new dimensions - new_width = benchmark_image.size[0] + validation_image_result.size[0] + separator_width + new_width = ( + benchmark_image.size[0] + validation_image_result.size[0] + separator_width + ) new_height = benchmark_image.size[1] # Create a new image with a white background @@ -877,6 +883,18 @@ def _update_state(self): self.global_step = StateTracker.get_global_step() self.global_resume_step = StateTracker.get_global_resume_step() or 1 + def would_validate( + self, + step: int = 0, + validation_type="intermediary", + force_evaluation: bool = False, + ): + # a wrapper for should_perform_validation that can run in the training loop + self._update_state() + return self.should_perform_validation( + step, self.validation_prompts, validation_type + ) or (step == 0 and validation_type == "base_model") + def run_validations( self, step: int = 0, @@ -1168,9 +1186,11 @@ def process_prompts(self): ) 
self.validation_prompt_dict[shortname] = prompt logger.debug(f"Processing validation for prompt: {prompt}") - stitched_validation_images, checkpoint_validation_images, ema_validation_images = ( - self.validate_prompt(prompt, shortname, validation_input_image) - ) + ( + stitched_validation_images, + checkpoint_validation_images, + ema_validation_images, + ) = self.validate_prompt(prompt, shortname, validation_input_image) validation_images.update(stitched_validation_images) self._save_images(validation_images, shortname, prompt) logger.debug(f"Completed generating image: {prompt}") @@ -1364,15 +1384,17 @@ def validate_prompt( ) if current_validation_type == "ema": self.enable_ema_for_inference() - all_validation_type_results[current_validation_type] = self.pipeline( - **pipeline_kwargs - ).images + all_validation_type_results[current_validation_type] = ( + self.pipeline(**pipeline_kwargs).images + ) if current_validation_type == "ema": self.disable_ema_for_inference() # retrieve the default image result for stitching to controlnet inputs. ema_image_results = all_validation_type_results.get("ema") - validation_image_results = all_validation_type_results.get("checkpoint", ema_image_results) + validation_image_results = all_validation_type_results.get( + "checkpoint", ema_image_results + ) original_validation_image_results = validation_image_results benchmark_image = None if self.args.controlnet: @@ -1387,7 +1409,9 @@ def validate_prompt( validation_shortname, resolution ) if benchmark_image is not None: - for idx, validation_image in enumerate(validation_image_results): + for idx, validation_image in enumerate( + validation_image_results + ): validation_image_results[idx] = self.stitch_benchmark_image( validation_image_result=validation_image, benchmark_image=benchmark_image, @@ -1396,9 +1420,13 @@ def validate_prompt( checkpoint_validation_images[validation_shortname].extend( original_validation_image_results ) - stitched_validation_images[validation_shortname].extend(validation_image_results) + stitched_validation_images[validation_shortname].extend( + validation_image_results + ) if self.args.use_ema: - ema_validation_images[validation_shortname].extend(ema_image_results) + ema_validation_images[validation_shortname].extend( + ema_image_results + ) except Exception as e: import traceback @@ -1407,15 +1435,31 @@ def validate_prompt( f"Error generating validation image: {e}, {traceback.format_exc()}" ) continue - if self.args.use_ema and self.args.ema_validation == "comparison" and benchmark_image is not None: - for idx, validation_image in enumerate(stitched_validation_images[validation_shortname]): - stitched_validation_images[validation_shortname][idx] = self.stitch_benchmark_image( - validation_image_result=ema_validation_images[validation_shortname][idx], - benchmark_image=stitched_validation_images[validation_shortname][idx], - labels=[None, "EMA"] + if ( + self.args.use_ema + and self.args.ema_validation == "comparison" + and benchmark_image is not None + ): + for idx, validation_image in enumerate( + stitched_validation_images[validation_shortname] + ): + stitched_validation_images[validation_shortname][idx] = ( + self.stitch_benchmark_image( + validation_image_result=ema_validation_images[ + validation_shortname + ][idx], + benchmark_image=stitched_validation_images[ + validation_shortname + ][idx], + labels=[None, "EMA"], + ) ) - return stitched_validation_images, checkpoint_validation_images, ema_validation_images + return ( + stitched_validation_images, + 
checkpoint_validation_images, + ema_validation_images, + ) def _save_images(self, validation_images, validation_shortname, validation_prompt): validation_img_idx = 0 @@ -1549,9 +1593,15 @@ def enable_ema_for_inference(self, pipeline=None): logger.info("Setting Lycoris multiplier to 1.0") self.accelerator._lycoris_wrapped_network.set_multiplier(1.0) logger.info("Storing Lycoris weights for later recovery.") - self.ema_model.store(self.accelerator._lycoris_wrapped_network.parameters()) - logger.info("Storing the EMA weights into the Lycoris adapter for inference.") - self.ema_model.copy_to(self.accelerator._lycoris_wrapped_network.parameters()) + self.ema_model.store( + self.accelerator._lycoris_wrapped_network.parameters() + ) + logger.info( + "Storing the EMA weights into the Lycoris adapter for inference." + ) + self.ema_model.copy_to( + self.accelerator._lycoris_wrapped_network.parameters() + ) elif self.args.lora_type.lower() == "standard": _trainable_parameters = [ x for x in self._primary_model().parameters() if x.requires_grad @@ -1583,11 +1633,16 @@ def disable_ema_for_inference(self): if self.args.use_ema: logger.info("Disabling EMA.") self.ema_enabled = False - if self.args.model_type == "lora" and self.args.lora_type.lower() == "lycoris": + if ( + self.args.model_type == "lora" + and self.args.lora_type.lower() == "lycoris" + ): logger.info("Setting Lycoris network multiplier to 1.0.") self.accelerator._lycoris_wrapped_network.set_multiplier(1.0) logger.info("Restoring Lycoris weights.") - self.ema_model.restore(self.accelerator._lycoris_wrapped_network.parameters()) + self.ema_model.restore( + self.accelerator._lycoris_wrapped_network.parameters() + ) else: logger.info("Restoring trainable parameters.") self.ema_model.restore(self.trainable_parameters()) @@ -1601,7 +1656,6 @@ def disable_ema_for_inference(self): "Skipping EMA model restoration for validation, as we are not using EMA." ) - def finalize_validation(self, validation_type): """Cleans up and restores original state if necessary.""" if not self.args.keep_vae_loaded and not self.args.vae_cache_ondemand: diff --git a/tests/test_model_card.py b/tests/test_model_card.py index 1c596be5..2e2ed5b1 100644 --- a/tests/test_model_card.py +++ b/tests/test_model_card.py @@ -65,6 +65,7 @@ def setUp(self): self.args.flux_guidance_value = 1.0 self.args.t5_padding = "unmodified" self.args.enable_xformers_memory_efficient_attention = False + self.args.attention_mechanism = "diffusers" def test_model_imports(self): self.args.lora_type = "standard" diff --git a/tests/test_trainer.py b/tests/test_trainer.py index c54789d6..3a036762 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -140,6 +140,7 @@ def test_stats_memory_used_none( flux_schedule_shift=3, flux_schedule_auto_shift=False, validation_guidance_skip_layers=None, + gradient_checkpointing_interval=None, ), ) def test_misc_init(