diff --git a/OPTIONS.md b/OPTIONS.md index edeea098..ddc32383 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -109,6 +109,14 @@ Carefully answer the questions and use bf16 mixed precision training when prompt Note that the first several steps of training will be slower than usual because of compilation occuring in the background. +### `--attention_mechanism` + +Setting `sageattention` or `xformers` here will allow the use of other memory-efficient attention mechanisms for the forward pass during training and inference, potentially resulting in major performance improvement. + +Using `sageattention` enables the use of [SageAttention](https://github.com/thu-ml/SageAttention) on NVIDIA CUDA equipment (sorry, AMD and Apple users). + +In simple terms, this will quantise the attention calculations for lower compute and memory overhead, **massively** speeding up training while minimally impacting quality. + --- ## 📰 Publishing @@ -452,7 +460,8 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--lr_scheduler {linear,sine,cosine,cosine_with_restarts,polynomial,constant,constant_with_warmup}] [--lr_warmup_steps LR_WARMUP_STEPS] [--lr_num_cycles LR_NUM_CYCLES] [--lr_power LR_POWER] - [--use_ema] [--ema_device {cpu,accelerator}] [--ema_cpu_only] + [--use_ema] [--ema_device {cpu,accelerator}] + [--ema_validation {none,ema_only,comparison}] [--ema_cpu_only] [--ema_foreach_disable] [--ema_update_interval EMA_UPDATE_INTERVAL] [--ema_decay EMA_DECAY] [--non_ema_revision NON_EMA_REVISION] @@ -473,8 +482,9 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--model_card_safe_for_work] [--logging_dir LOGGING_DIR] [--benchmark_base_model] [--disable_benchmark] [--evaluation_type {clip,none}] - [--pretrained_evaluation_model_name_or_path pretrained_evaluation_model_name_or_path] + [--pretrained_evaluation_model_name_or_path PRETRAINED_EVALUATION_MODEL_NAME_OR_PATH] [--validation_on_startup] [--validation_seed_source {gpu,cpu}] + [--validation_lycoris_strength VALIDATION_LYCORIS_STRENGTH] [--validation_torch_compile] [--validation_torch_compile_mode {max-autotune,reduce-overhead,default}] [--validation_guidance_skip_layers VALIDATION_GUIDANCE_SKIP_LAYERS] @@ -509,6 +519,7 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--text_encoder_2_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}] [--text_encoder_3_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}] [--local_rank LOCAL_RANK] + [--attention_mechanism {diffusers,xformers,sageattention,sageattention-int8-fp16-triton,sageattention-int8-fp16-cuda,sageattention-int8-fp8-cuda}] [--enable_xformers_memory_efficient_attention] [--set_grads_to_none] [--noise_offset NOISE_OFFSET] [--noise_offset_probability NOISE_OFFSET_PROBABILITY] @@ -1137,12 +1148,21 @@ options: cosine_with_restarts scheduler. --lr_power LR_POWER Power factor of the polynomial scheduler. --use_ema Whether to use EMA (exponential moving average) model. + Works with LoRA, Lycoris, and full training. --ema_device {cpu,accelerator} The device to use for the EMA model. If set to 'accelerator', the EMA model will be placed on the accelerator. This provides the fastest EMA update times, but is not ultimately necessary for EMA to function. + --ema_validation {none,ema_only,comparison} + When 'none' is set, no EMA validation will be done. + When using 'ema_only', the validations will rely + mostly on the EMA weights. When using 'comparison' + (default) mode, the validations will first run on the + checkpoint before also running for the EMA weights. In + comparison mode, the resulting images will be provided + side-by-side. --ema_cpu_only When using EMA, the shadow model is moved to the accelerator before we update its parameters. When provided, this option will disable the moving of the @@ -1248,7 +1268,7 @@ options: function. The default is to use no evaluator, and 'clip' will use a CLIP model to evaluate the resulting model's performance during validations. - --pretrained_evaluation_model_name_or_path pretrained_evaluation_model_name_or_path + --pretrained_evaluation_model_name_or_path PRETRAINED_EVALUATION_MODEL_NAME_OR_PATH Optionally provide a custom model to use for ViT evaluations. The default is currently clip-vit-large- patch14-336, allowing for lower patch sizes (greater @@ -1264,6 +1284,12 @@ options: validation errors. If so, please set SIMPLETUNER_LOG_LEVEL=DEBUG and submit debug.log to a new Github issue report. + --validation_lycoris_strength VALIDATION_LYCORIS_STRENGTH + When inferencing for validations, the Lycoris model + will by default be run at its training strength, 1.0. + However, this value can be increased to a value of + around 1.3 or 1.5 to get a stronger effect from the + model. --validation_torch_compile Supply `--validation_torch_compile=true` to enable the use of torch.compile() on the validation pipeline. For @@ -1453,6 +1479,20 @@ options: quantisation (Apple Silicon, NVIDIA, AMD). --local_rank LOCAL_RANK For distributed training: local_rank + --attention_mechanism {diffusers,xformers,sageattention,sageattention-int8-fp16-triton,sageattention-int8-fp16-cuda,sageattention-int8-fp8-cuda} + On NVIDIA CUDA devices, alternative flash attention + implementations are offered, with the default being + native pytorch SDPA. SageAttention has multiple + backends to select from. The recommended value, + 'sageattention', guesses what would be the 'best' + option for SageAttention on your hardware (usually + this is the int8-fp16-cuda backend). However, manually + setting this value to int8-fp16-triton may provide + better averages for per-step training and inference + performance while the cuda backend may provide the + highest maximum speed (with also a lower minimum + speed). NOTE: SageAttention training quality has not + been validated. --enable_xformers_memory_efficient_attention Whether or not to use xformers. --set_grads_to_none Save more memory by using setting grads to None diff --git a/configure.py b/configure.py index f14e13b9..b5c26d28 100644 --- a/configure.py +++ b/configure.py @@ -429,7 +429,20 @@ def configure_env(): ).lower() == "y" ) - report_to_str = "" + + env_contents["--attention_mechanism"] = "diffusers" + use_sageattention = ( + prompt_user( + "Would you like to use SageAttention? This is an experimental option that can greatly speed up training. (y/[n])", + "n", + ).lower() + == "y" + ) + if use_sageattention: + env_contents["--attention_mechanism"] = "sageattention" + + # properly disable wandb/tensorboard/comet_ml etc by default + report_to_str = "none" if report_to_wandb or report_to_tensorboard: tracker_project_name = prompt_user( "Enter the name of your Weights & Biases project", f"{model_type}-training" @@ -440,17 +453,17 @@ def configure_env(): f"simpletuner-{model_type}", ) env_contents["--tracker_run_name"] = tracker_run_name - report_to_str = None if report_to_wandb: report_to_str = "wandb" if report_to_tensorboard: - if report_to_wandb: + if report_to_str != "none": + # report to both WandB and Tensorboard if the user wanted. report_to_str += "," else: + # remove 'none' from the option report_to_str = "" report_to_str += "tensorboard" - if report_to_str: - env_contents["--report_to"] = report_to_str + env_contents["--report_to"] = report_to_str print_config(env_contents, extra_args) diff --git a/documentation/LYCORIS.md b/documentation/LYCORIS.md index 64aba565..853d3ec6 100644 --- a/documentation/LYCORIS.md +++ b/documentation/LYCORIS.md @@ -59,6 +59,14 @@ Mandatory fields: For more information on LyCORIS, please refer to the [documentation in the library](https://github.com/KohakuBlueleaf/LyCORIS/tree/main/docs). +## Potential problems + +When using Lycoris on SDXL, it's noted that training the FeedForward modules may break the model and send loss into `NaN` (Not-a-Number) territory. + +This seems to be potentially exacerbated when using SageAttention, making it all but guaranteed that the model will immediately fail. + +The solution is to remove the `FeedForward` modules from the lycoris config and train only the `Attention` blocks. + ## LyCORIS Inference Example Here is a simple FLUX.1-dev inference script showing how to wrap your unet or transformer with create_lycoris_from_weights and then use it for inference. diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index ddb86e1d..189b1a72 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -414,9 +414,18 @@ Currently, the lowest VRAM utilisation (9090M) can be attained with: - DeepSpeed: disabled / unconfigured - PyTorch: 2.6 Nightly (Sept 29th build) - Using `--quantize_via=cpu` to avoid outOfMemory error during startup on <=16G cards. +- With `--attention_mechanism=sageattention` to further reduce VRAM by 0.1GB and improve training speed. Speed was approximately 1.4 iterations per second on a 4090. +### SageAttention + +When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations. + +In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training. + +**Note**: This isn't compatible with _every_ configuration, but it's worth trying. + ### NF4-quantised training In simplest terms, NF4 is a 4bit-_ish_ representation of the model, which means training has serious stability concerns to address. @@ -428,6 +437,7 @@ In early tests, the following holds true: - NF4, AdamW8bit, and a higher batch size all help to overcome the stability issues, at the cost of more time spent training or VRAM used - Upping the resolution from 512px to 1024px slows training down from, for example, 1.4 seconds per step to 3.5 seconds per step (batch size of 1, 4090) - Anything that's difficult to train on int8 or bf16 becomes harder in NF4 +- It's less compatible with options like SageAttention NF4 does not work with torch.compile, so whatever you get for speed is what you get. diff --git a/documentation/quickstart/SD3.md b/documentation/quickstart/SD3.md index 5d646562..764bf64a 100644 --- a/documentation/quickstart/SD3.md +++ b/documentation/quickstart/SD3.md @@ -339,6 +339,14 @@ These options have been known to keep SD3.5 in-tact for as long as possible: - DeepSpeed: disabled / unconfigured - PyTorch: 2.5 +### SageAttention + +When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations. + +In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training. + +**Note**: This isn't compatible with _every_ configuration, but it's worth trying. + ### Masked loss If you are training a subject or style and would like to mask one or the other, see the [masked loss training](/documentation/DREAMBOOTH.md#masked-loss) section of the Dreambooth guide. diff --git a/documentation/quickstart/SIGMA.md b/documentation/quickstart/SIGMA.md index cf3389aa..aadd46f2 100644 --- a/documentation/quickstart/SIGMA.md +++ b/documentation/quickstart/SIGMA.md @@ -220,3 +220,11 @@ For more information, see the [dataloader](/documentation/DATALOADER.md) and [tu ### CLIP score tracking If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores. + +### SageAttention + +When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations. + +In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training. + +**Note**: This isn't compatible with _every_ configuration, but it's worth trying. \ No newline at end of file diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 7ea0dc61..fce256e3 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1131,7 +1131,7 @@ def get_argument_parser(): " When using 'ema_only', the validations will rely mostly on the EMA weights." " When using 'comparison' (default) mode, the validations will first run on the checkpoint before also running for" " the EMA weights. In comparison mode, the resulting images will be provided side-by-side." - ) + ), ) parser.add_argument( "--ema_cpu_only", @@ -1708,6 +1708,28 @@ def get_argument_parser(): default=-1, help="For distributed training: local_rank", ) + parser.add_argument( + "--attention_mechanism", + type=str, + choices=[ + "diffusers", + "xformers", + "sageattention", + "sageattention-int8-fp16-triton", + "sageattention-int8-fp16-cuda", + "sageattention-int8-fp8-cuda", + ], + default="diffusers", + help=( + "On NVIDIA CUDA devices, alternative flash attention implementations are offered, with the default being native pytorch SDPA." + " SageAttention has multiple backends to select from." + " The recommended value, 'sageattention', guesses what would be the 'best' option for SageAttention on your hardware" + " (usually this is the int8-fp16-cuda backend). However, manually setting this value to int8-fp16-triton" + " may provide better averages for per-step training and inference performance while the cuda backend" + " may provide the highest maximum speed (with also a lower minimum speed). NOTE: SageAttention training quality" + " has not been validated." + ), + ) parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", @@ -2418,7 +2440,7 @@ def parse_cmdline_args(input_args=None): args.lycoris_config, os.R_OK ): raise ValueError( - f"Could not find the JSON configuration file at {args.lycoris_config}" + f"Could not find the JSON configuration file at '{args.lycoris_config}'" ) import json @@ -2438,11 +2460,11 @@ def parse_cmdline_args(input_args=None): elif "standard" == args.lora_type.lower(): if hasattr(args, "lora_init_type") and args.lora_init_type is not None: if torch.backends.mps.is_available() and args.lora_init_type == "loftq": - logger.error( + error_log( "Apple MPS cannot make use of LoftQ initialisation. Overriding to 'default'." ) elif args.is_quantized and args.lora_init_type == "loftq": - logger.error( + error_log( "LoftQ initialisation is not supported with quantised models. Overriding to 'default'." ) else: @@ -2451,7 +2473,7 @@ def parse_cmdline_args(input_args=None): ) if args.use_dora: if "quanto" in args.base_model_precision: - logger.error( + error_log( "Quanto does not yet support DoRA training in PEFT. Disabling DoRA. 😴" ) args.use_dora = False @@ -2488,4 +2510,16 @@ def parse_cmdline_args(input_args=None): logger.error(f"Could not load skip layers: {e}") raise + if args.enable_xformers_memory_efficient_attention: + if args.attention_mechanism != "xformers": + warning_log( + "The option --enable_xformers_memory_efficient_attention is deprecated. Please use --attention_mechanism=xformers instead." + ) + args.attention_mechanism = "xformers" + + if args.attention_mechanism != "diffusers" and not torch.cuda.is_available(): + warning_log( + "For non-CUDA systems, only Diffusers attention mechanism is officially supported." + ) + return args diff --git a/helpers/kolors/pipeline.py b/helpers/kolors/pipeline.py index 964a8e95..2f395e4d 100644 --- a/helpers/kolors/pipeline.py +++ b/helpers/kolors/pipeline.py @@ -1249,11 +1249,22 @@ def denoising_value_valid(dnv): # unscale/denormalize the latents latents = latents / self.vae.config.scaling_factor + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( - latents.to(device=self.vae.device, dtype=self.vae.dtype), - return_dict=False, + latents.to(device=self.vae.device), return_dict=False )[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) diff --git a/helpers/legacy/pipeline.py b/helpers/legacy/pipeline.py index 01b0ffe8..c186b48d 100644 --- a/helpers/legacy/pipeline.py +++ b/helpers/legacy/pipeline.py @@ -1160,18 +1160,27 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + has_nsfw_concept = None if not output_type == "latent": + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( - latents.to(self.vae.dtype) / self.vae.config.scaling_factor, + latents.to(dtype=self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False, generator=generator, )[0] - image, has_nsfw_concept = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) else: image = latents - has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] diff --git a/helpers/models/flux/pipeline.py b/helpers/models/flux/pipeline.py index ba7eaa44..1d152def 100644 --- a/helpers/models/flux/pipeline.py +++ b/helpers/models/flux/pipeline.py @@ -906,11 +906,22 @@ def __call__( latents = ( latents / self.vae.config.scaling_factor ) + self.vae.config.shift_factor + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) image = self.vae.decode( - latents.to(device=self.vae.device, dtype=self.vae.dtype), - return_dict=False, + latents.to(dtype=self.vae.dtype), return_dict=False )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models diff --git a/helpers/models/omnigen/pipeline.py b/helpers/models/omnigen/pipeline.py index fadbaf05..5693967f 100644 --- a/helpers/models/omnigen/pipeline.py +++ b/helpers/models/omnigen/pipeline.py @@ -345,9 +345,20 @@ def __call__( ) else: samples = samples / self.vae.config.scaling_factor - samples = self.vae.decode( - samples.to(dtype=self.vae.dtype, device=self.vae.device) - ).sample + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + + image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) if self.model_cpu_offload: self.vae.to("cpu") diff --git a/helpers/models/pixart/pipeline.py b/helpers/models/pixart/pipeline.py index 6efb54cd..244df2e2 100644 --- a/helpers/models/pixart/pipeline.py +++ b/helpers/models/pixart/pipeline.py @@ -1231,11 +1231,24 @@ def denoising_value_valid(dnv): callback(step_idx, t, latents) if not output_type == "latent": + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( - latents.to(device=self.vae.device, dtype=self.vae.dtype) - / self.vae.config.scaling_factor, + latents.to(dtype=self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False, + generator=generator, )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor( image, orig_width, orig_height diff --git a/helpers/models/sd3/pipeline.py b/helpers/models/sd3/pipeline.py index 653c2a6a..70698f6e 100644 --- a/helpers/models/sd3/pipeline.py +++ b/helpers/models/sd3/pipeline.py @@ -999,11 +999,14 @@ def __call__( continue # expand the latents if we are doing classifier free guidance + # added fix from: https://github.com/huggingface/diffusers/pull/10086/files + # to allow for num_images_per_prompt > 1 latent_model_input = ( torch.cat([latents] * 2) - if self.do_classifier_free_guidance and skip_guidance_layers is None + if self.do_classifier_free_guidance else latents ) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) @@ -1033,6 +1036,8 @@ def __call__( else False ) if skip_guidance_layers is not None and should_skip_layers: + timestep = t.expand(latents.shape[0]) + latent_model_input = latents noise_pred_skip_layers = self.transformer( hidden_states=latent_model_input.to( device=self.transformer.device, @@ -1097,7 +1102,22 @@ def __call__( latents / self.vae.config.scaling_factor ) + self.vae.config.shift_factor - image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + + image = self.vae.decode( + latents.to(dtype=self.vae.dtype), return_dict=False + )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models @@ -2053,7 +2073,22 @@ def __call__( latents / self.vae.config.scaling_factor ) + self.vae.config.shift_factor - image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + + image = self.vae.decode( + latents.to(dtype=self.vae.dtype), return_dict=False + )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models diff --git a/helpers/models/sdxl/pipeline.py b/helpers/models/sdxl/pipeline.py index bfd93f29..10b01216 100644 --- a/helpers/models/sdxl/pipeline.py +++ b/helpers/models/sdxl/pipeline.py @@ -1488,10 +1488,22 @@ def __call__( else: latents = latents / self.vae.config.scaling_factor + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( latents.to(dtype=self.vae.dtype), return_dict=False )[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py index 92bd3568..c761f27e 100644 --- a/helpers/publishing/metadata.py +++ b/helpers/publishing/metadata.py @@ -119,6 +119,7 @@ def ema_info(args): return ema_information return "" + def lycoris_download_info(): """output a function to download the adapter""" output_fn = """ @@ -556,7 +557,8 @@ def save_model_card( - Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''} - Trainable parameter precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'} - Caption dropout probability: {StateTracker.get_args().caption_dropout_probability * 100}% -{'- Xformers: Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else ''} +{'- Xformers: Enabled' if StateTracker.get_args().attention_mechanism == 'xformers' else ''} +{'- SageAttention: Enabled' if StateTracker.get_args().attention_mechanism == 'sageattention' else ''} {lora_info(args=StateTracker.get_args())} ## Datasets diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index f586972c..518afd63 100644 --- a/helpers/training/default_settings/safety_check.py +++ b/helpers/training/default_settings/safety_check.py @@ -116,4 +116,19 @@ def safety_check(args, accelerator): logger.error( f"--flux_schedule_auto_shift cannot be combined with --flux_schedule_shift. Please set --flux_schedule_shift to 0 if you want to train with --flux_schedule_auto_shift." ) - sys.exit(1) \ No newline at end of file + sys.exit(1) + + if ( + args.enable_xformers_memory_efficient_attention + and args.attention_mechanism == "sageattention" + ): + logger.error( + f"--enable_xformers_memory_efficient_attention is only compatible with --attention_mechanism=diffusers. Please set --attention_mechanism=diffusers to enable this feature or disable xformers to use alternative attention mechanisms." + ) + sys.exit(1) + + if "nf4" in args.base_model_precision: + logger.error( + f"{args.base_model_precision} is not supported with SageAttention. Please select from int8 or fp8, or, disable quantisation to use SageAttention." + ) + sys.exit(1) diff --git a/helpers/training/ema.py b/helpers/training/ema.py index d2153796..7ec7b991 100644 --- a/helpers/training/ema.py +++ b/helpers/training/ema.py @@ -267,7 +267,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter], global_step: int = None context_manager = contextlib.nullcontext if ( is_transformers_available() - and transformers.deepspeed.is_deepspeed_zero3_enabled() + and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled() ): import deepspeed @@ -309,7 +309,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter], global_step: int = None for s_param, param in zip(self.shadow_params, parameters): if ( is_transformers_available() - and transformers.deepspeed.is_deepspeed_zero3_enabled() + and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled() ): context_manager = deepspeed.zero.GatheredParameters( param, modifier_rank=None @@ -374,7 +374,9 @@ def cuda(self, device=None): def cpu(self): return self.to(device="cpu") - def state_dict(self, destination=None, prefix="", keep_vars=False, exclude_params: bool = False): + def state_dict( + self, destination=None, prefix="", keep_vars=False, exclude_params: bool = False + ): r""" Returns a dictionary containing a whole state of the EMA model. """ diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index a510027d..a78ede95 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -1636,8 +1636,51 @@ def move_models(self, destination: str = "accelerator"): target_device, dtype=self.config.weight_dtype ) ) - if ( - self.config.enable_xformers_memory_efficient_attention + + if "sageattention" in self.config.attention_mechanism: + # we'll try and load SageAttention and overload pytorch's sdpa function. + try: + from sageattention import ( + sageattn, + sageattn_qk_int8_pv_fp16_triton, + sageattn_qk_int8_pv_fp16_cuda, + sageattn_qk_int8_pv_fp8_cuda, + ) + + sageattn_functions = { + "sageattention": sageattn, + "sageattention-int8-fp16-triton": sageattn_qk_int8_pv_fp16_triton, + "sageattention-int8-fp16-cuda": sageattn_qk_int8_pv_fp16_cuda, + "sageattention-int8-fp8-cuda": sageattn_qk_int8_pv_fp8_cuda, + } + # store the old SDPA for validations to use during VAE decode + setattr( + torch.nn.functional, + "scaled_dot_product_attention_sdpa", + torch.nn.functional.scaled_dot_product_attention, + ) + torch.nn.functional.scaled_dot_product_attention = ( + sageattn_functions.get( + self.config.attention_mechanism, "sageattention" + ) + ) + setattr( + torch.nn.functional, + "scaled_dot_product_attention_sage", + torch.nn.functional.scaled_dot_product_attention, + ) + + logger.warning( + f"Using {self.config.attention_mechanism} for flash attention mechanism. This is an experimental option, and you may receive unexpected or poor results. To disable SageAttention, remove or set --attention_mechanism to a different value." + ) + except ImportError as e: + logger.error( + "Could not import SageAttention. Please install it to use this --attention_mechanism=sageattention." + ) + logger.error(repr(e)) + sys.exit(1) + elif ( + self.config.attention_mechanism == "xformers" and self.config.model_family not in [ "sd3", @@ -1661,11 +1704,14 @@ def move_models(self, destination: str = "accelerator"): raise ValueError( "xformers is not available. Make sure it is installed correctly" ) - elif self.config.enable_xformers_memory_efficient_attention: + elif self.config.attention_mechanism == "xformers": logger.warning( "xformers is not enabled, as it is incompatible with this model type." + " Falling back to diffusers attention mechanism (Pytorch SDPA)." + " Alternatively, provide --attention_mechanism=sageattention for a more efficient option on CUDA systems." ) self.config.enable_xformers_memory_efficient_attention = False + self.config.attention_mechanism = "diffusers" if self.config.controlnet: self.controlnet.train() diff --git a/tests/test_model_card.py b/tests/test_model_card.py index 1c596be5..2e2ed5b1 100644 --- a/tests/test_model_card.py +++ b/tests/test_model_card.py @@ -65,6 +65,7 @@ def setUp(self): self.args.flux_guidance_value = 1.0 self.args.t5_padding = "unmodified" self.args.enable_xformers_memory_efficient_attention = False + self.args.attention_mechanism = "diffusers" def test_model_imports(self): self.args.lora_type = "standard"