From 57ec923caf48aef4f6ece059e59fa386c54d6835 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 30 Nov 2024 09:46:36 -0600 Subject: [PATCH 01/26] flux: use sage attention if available --- helpers/models/flux/attention.py | 83 ++++++++++++++++++++++++++++++ helpers/models/flux/transformer.py | 15 +++++- 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/helpers/models/flux/attention.py b/helpers/models/flux/attention.py index 9009c858..edf2d05b 100644 --- a/helpers/models/flux/attention.py +++ b/helpers/models/flux/attention.py @@ -8,6 +8,12 @@ from flash_attn_interface import flash_attn_func except: pass +try: + from sageattention import sageattn + + F.scaled_dot_product_attention = sageattn +except: + pass def fa3_sdpa( @@ -98,6 +104,83 @@ def __call__( return hidden_states +class FluxSingleSageAttnProcessor3_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn, + hidden_states: Tensor, + encoder_hidden_states: Tensor = None, + attention_mask: FloatTensor = None, + image_rotary_emb: Tensor = None, + ) -> Tensor: + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, _, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = fa3_sdpa(query, key, value) + hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + return hidden_states + + class FluxAttnProcessor3_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" diff --git a/helpers/models/flux/transformer.py b/helpers/models/flux/transformer.py index 7a9c80f2..806bcc27 100644 --- a/helpers/models/flux/transformer.py +++ b/helpers/models/flux/transformer.py @@ -47,6 +47,15 @@ except: pass +is_sage_attn_available = False +try: + from sageattention import sageattn + + is_sage_attn_available = True + +except: + pass + from helpers.models.flux.attention import ( FluxSingleAttnProcessor3_0, FluxAttnProcessor3_0, @@ -217,7 +226,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): ) primary_device = torch.cuda.get_device_properties(rank) if primary_device.major == 9 and primary_device.minor == 0: - if is_flash_attn_available: + if is_sage_attn_available: + if rank == 0: + print("Using SageAttention for H100 GPU (Single block)") + processor = FluxSingleSageAttnProcessor() + elif is_flash_attn_available: if rank == 0: print("Using FlashAttention3_0 for H100 GPU (Single block)") processor = FluxSingleAttnProcessor3_0() From 68c5a5b6d7e1a49a4f6395c3b0406c882a8ddb14 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 30 Nov 2024 09:48:27 -0600 Subject: [PATCH 02/26] update deepspeed call to ensure compliance with new transformers version --- helpers/training/ema.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/helpers/training/ema.py b/helpers/training/ema.py index d2153796..7ec7b991 100644 --- a/helpers/training/ema.py +++ b/helpers/training/ema.py @@ -267,7 +267,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter], global_step: int = None context_manager = contextlib.nullcontext if ( is_transformers_available() - and transformers.deepspeed.is_deepspeed_zero3_enabled() + and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled() ): import deepspeed @@ -309,7 +309,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter], global_step: int = None for s_param, param in zip(self.shadow_params, parameters): if ( is_transformers_available() - and transformers.deepspeed.is_deepspeed_zero3_enabled() + and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled() ): context_manager = deepspeed.zero.GatheredParameters( param, modifier_rank=None @@ -374,7 +374,9 @@ def cuda(self, device=None): def cpu(self): return self.to(device="cpu") - def state_dict(self, destination=None, prefix="", keep_vars=False, exclude_params: bool = False): + def state_dict( + self, destination=None, prefix="", keep_vars=False, exclude_params: bool = False + ): r""" Returns a dictionary containing a whole state of the EMA model. """ From c8fe4e5e8f6e40775136eaa589650bdbb9a96e22 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 10:36:01 -0600 Subject: [PATCH 03/26] configurator should offer option to enable SageAttention for user --- configure.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/configure.py b/configure.py index f14e13b9..b5c26d28 100644 --- a/configure.py +++ b/configure.py @@ -429,7 +429,20 @@ def configure_env(): ).lower() == "y" ) - report_to_str = "" + + env_contents["--attention_mechanism"] = "diffusers" + use_sageattention = ( + prompt_user( + "Would you like to use SageAttention? This is an experimental option that can greatly speed up training. (y/[n])", + "n", + ).lower() + == "y" + ) + if use_sageattention: + env_contents["--attention_mechanism"] = "sageattention" + + # properly disable wandb/tensorboard/comet_ml etc by default + report_to_str = "none" if report_to_wandb or report_to_tensorboard: tracker_project_name = prompt_user( "Enter the name of your Weights & Biases project", f"{model_type}-training" @@ -440,17 +453,17 @@ def configure_env(): f"simpletuner-{model_type}", ) env_contents["--tracker_run_name"] = tracker_run_name - report_to_str = None if report_to_wandb: report_to_str = "wandb" if report_to_tensorboard: - if report_to_wandb: + if report_to_str != "none": + # report to both WandB and Tensorboard if the user wanted. report_to_str += "," else: + # remove 'none' from the option report_to_str = "" report_to_str += "tensorboard" - if report_to_str: - env_contents["--report_to"] = report_to_str + env_contents["--report_to"] = report_to_str print_config(env_contents, extra_args) From ab02d128f3a0bad4bea1cba559bf23c9506c30d4 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 10:36:26 -0600 Subject: [PATCH 04/26] add --attention_mechanism option for sageattention --- helpers/configuration/cmd_args.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 7ea0dc61..1308c50d 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1131,7 +1131,7 @@ def get_argument_parser(): " When using 'ema_only', the validations will rely mostly on the EMA weights." " When using 'comparison' (default) mode, the validations will first run on the checkpoint before also running for" " the EMA weights. In comparison mode, the resulting images will be provided side-by-side." - ) + ), ) parser.add_argument( "--ema_cpu_only", @@ -1708,6 +1708,15 @@ def get_argument_parser(): default=-1, help="For distributed training: local_rank", ) + parser.add_argument( + "--attention_mechanism", + type=str, + choices=["diffusers", "xformers", "sageattention"], + default="diffusers", + help=( + "On NVIDIA CUDA devices, we can use Xformers or SageAttention as an alternative to Pytorch SDPA (Diffusers)." + ), + ) parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", From 2a5e028ed37361df570034efd509f96cb04e0d10 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 10:37:02 -0600 Subject: [PATCH 05/26] flux: remove model-specific sageattention code --- helpers/models/flux/attention.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/helpers/models/flux/attention.py b/helpers/models/flux/attention.py index edf2d05b..2e880d6e 100644 --- a/helpers/models/flux/attention.py +++ b/helpers/models/flux/attention.py @@ -8,12 +8,6 @@ from flash_attn_interface import flash_attn_func except: pass -try: - from sageattention import sageattn - - F.scaled_dot_product_attention = sageattn -except: - pass def fa3_sdpa( From e6b1919110bf3df2595fb4fff0829461c37c2b38 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 10:37:34 -0600 Subject: [PATCH 06/26] sageattention cannot be enabled concurrently to xformers --- helpers/training/default_settings/safety_check.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index f586972c..01444a32 100644 --- a/helpers/training/default_settings/safety_check.py +++ b/helpers/training/default_settings/safety_check.py @@ -116,4 +116,13 @@ def safety_check(args, accelerator): logger.error( f"--flux_schedule_auto_shift cannot be combined with --flux_schedule_shift. Please set --flux_schedule_shift to 0 if you want to train with --flux_schedule_auto_shift." ) - sys.exit(1) \ No newline at end of file + sys.exit(1) + + if ( + args.enable_xformers_memory_efficient_attention + and args.attention_mechanism == "sageattention" + ): + logger.error( + f"--enable_xformers_memory_efficient_attention is only compatible with --attention_mechanism=diffusers. Please set --attention_mechanism=diffusers to enable this feature or disable xformers to use alternative attention mechanisms." + ) + sys.exit(1) From 72e7d9ec29027c476a08d6e6ea116ea9bc83d842 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 10:38:29 -0600 Subject: [PATCH 07/26] sageattention should overwrite sdpa at startup if enabled --- helpers/training/trainer.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index a510027d..2aa15894 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -1636,7 +1636,22 @@ def move_models(self, destination: str = "accelerator"): target_device, dtype=self.config.weight_dtype ) ) - if ( + + if self.config.attention_mechanism == "sageattention": + # we'll try and load SageAttention and overload pytorch's sdpa function. + try: + from sageattention import sageattn + + torch.nn.functional.scaled_dot_product_attention = sageattn + logger.warning( + "Using SageAttention for flash attention mechanism. This is an experimental option, and you may receive unexpected or poor results. To disable SageAttention, remove or set --attention_mechanism to a different value." + ) + except ImportError: + logger.error( + "Could not import SageAttention. Please install it to use this --attention_mechanism=sageattention" + ) + sys.exit(1) + elif ( self.config.enable_xformers_memory_efficient_attention and self.config.model_family not in [ From bbf20d0eff804ce5ebff6cadece736782204637a Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 10:39:59 -0600 Subject: [PATCH 08/26] flux: remove more model-specific sageattn code --- helpers/models/flux/attention.py | 77 ------------------------------ helpers/models/flux/transformer.py | 15 +----- 2 files changed, 1 insertion(+), 91 deletions(-) diff --git a/helpers/models/flux/attention.py b/helpers/models/flux/attention.py index 2e880d6e..9009c858 100644 --- a/helpers/models/flux/attention.py +++ b/helpers/models/flux/attention.py @@ -98,83 +98,6 @@ def __call__( return hidden_states -class FluxSingleSageAttnProcessor3_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - - def __call__( - self, - attn, - hidden_states: Tensor, - encoder_hidden_states: Tensor = None, - attention_mask: FloatTensor = None, - image_rotary_emb: Tensor = None, - ) -> Tensor: - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view( - batch_size, channel, height * width - ).transpose(1, 2) - - batch_size, _, _ = ( - hidden_states.shape - if encoder_hidden_states is None - else encoder_hidden_states.shape - ) - - query = attn.to_q(hidden_states) - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = fa3_sdpa(query, key, value) - hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") - - hidden_states = hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - hidden_states = hidden_states.to(query.dtype) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape( - batch_size, channel, height, width - ) - - return hidden_states - - class FluxAttnProcessor3_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" diff --git a/helpers/models/flux/transformer.py b/helpers/models/flux/transformer.py index 806bcc27..7a9c80f2 100644 --- a/helpers/models/flux/transformer.py +++ b/helpers/models/flux/transformer.py @@ -47,15 +47,6 @@ except: pass -is_sage_attn_available = False -try: - from sageattention import sageattn - - is_sage_attn_available = True - -except: - pass - from helpers.models.flux.attention import ( FluxSingleAttnProcessor3_0, FluxAttnProcessor3_0, @@ -226,11 +217,7 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): ) primary_device = torch.cuda.get_device_properties(rank) if primary_device.major == 9 and primary_device.minor == 0: - if is_sage_attn_available: - if rank == 0: - print("Using SageAttention for H100 GPU (Single block)") - processor = FluxSingleSageAttnProcessor() - elif is_flash_attn_available: + if is_flash_attn_available: if rank == 0: print("Using FlashAttention3_0 for H100 GPU (Single block)") processor = FluxSingleAttnProcessor3_0() From a86c256d4a97819e2475a41d1f1d1f9593e17a4f Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 10:47:45 -0600 Subject: [PATCH 09/26] deprecate --enable_xformers_memory_efficient_attention in favour of --attention_mechanism=xformers --- helpers/configuration/cmd_args.py | 18 +++++++++++++++--- helpers/publishing/metadata.py | 4 +++- helpers/training/trainer.py | 7 +++++-- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 1308c50d..c15cdc01 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -2447,11 +2447,11 @@ def parse_cmdline_args(input_args=None): elif "standard" == args.lora_type.lower(): if hasattr(args, "lora_init_type") and args.lora_init_type is not None: if torch.backends.mps.is_available() and args.lora_init_type == "loftq": - logger.error( + error_log( "Apple MPS cannot make use of LoftQ initialisation. Overriding to 'default'." ) elif args.is_quantized and args.lora_init_type == "loftq": - logger.error( + error_log( "LoftQ initialisation is not supported with quantised models. Overriding to 'default'." ) else: @@ -2460,7 +2460,7 @@ def parse_cmdline_args(input_args=None): ) if args.use_dora: if "quanto" in args.base_model_precision: - logger.error( + error_log( "Quanto does not yet support DoRA training in PEFT. Disabling DoRA. 😴" ) args.use_dora = False @@ -2497,4 +2497,16 @@ def parse_cmdline_args(input_args=None): logger.error(f"Could not load skip layers: {e}") raise + if args.enable_xformers_memory_efficient_attention: + if args.attention_mechanism != "xformers": + warning_log( + "The option --enable_xformers_memory_efficient_attention is deprecated. Please use --attention_mechanism=xformers instead." + ) + args.attention_mechanism = "xformers" + + if args.attention_mechanism != "diffusers" and not torch.cuda.is_available(): + warning_log( + "For non-CUDA systems, only Diffusers attention mechanism is officially supported." + ) + return args diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py index 92bd3568..c761f27e 100644 --- a/helpers/publishing/metadata.py +++ b/helpers/publishing/metadata.py @@ -119,6 +119,7 @@ def ema_info(args): return ema_information return "" + def lycoris_download_info(): """output a function to download the adapter""" output_fn = """ @@ -556,7 +557,8 @@ def save_model_card( - Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''} - Trainable parameter precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'} - Caption dropout probability: {StateTracker.get_args().caption_dropout_probability * 100}% -{'- Xformers: Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else ''} +{'- Xformers: Enabled' if StateTracker.get_args().attention_mechanism == 'xformers' else ''} +{'- SageAttention: Enabled' if StateTracker.get_args().attention_mechanism == 'sageattention' else ''} {lora_info(args=StateTracker.get_args())} ## Datasets diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 2aa15894..637006d1 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -1652,7 +1652,7 @@ def move_models(self, destination: str = "accelerator"): ) sys.exit(1) elif ( - self.config.enable_xformers_memory_efficient_attention + self.config.attention_mechanism == "xformers" and self.config.model_family not in [ "sd3", @@ -1676,11 +1676,14 @@ def move_models(self, destination: str = "accelerator"): raise ValueError( "xformers is not available. Make sure it is installed correctly" ) - elif self.config.enable_xformers_memory_efficient_attention: + elif self.config.attention_mechanism == "xformers": logger.warning( "xformers is not enabled, as it is incompatible with this model type." + " Falling back to diffusers attention mechanism (Pytorch SDPA)." + " Alternatively, provide --attention_mechanism=sageattention for a more efficient option on CUDA systems." ) self.config.enable_xformers_memory_efficient_attention = False + self.config.attention_mechanism = "diffusers" if self.config.controlnet: self.controlnet.train() From 73c6f1b89d487f73dd72be08e5e60e6503943dec Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 11:44:44 -0600 Subject: [PATCH 10/26] hackish awful workaround for VAE decode in SageAttention --- helpers/models/flux/pipeline.py | 15 +++++++++++-- helpers/models/omnigen/pipeline.py | 17 ++++++++++++--- helpers/models/pixart/pipeline.py | 17 +++++++++++++-- helpers/models/sd3/pipeline.py | 34 ++++++++++++++++++++++++++++-- helpers/models/sdxl/pipeline.py | 12 +++++++++++ 5 files changed, 86 insertions(+), 9 deletions(-) diff --git a/helpers/models/flux/pipeline.py b/helpers/models/flux/pipeline.py index ba7eaa44..1d152def 100644 --- a/helpers/models/flux/pipeline.py +++ b/helpers/models/flux/pipeline.py @@ -906,11 +906,22 @@ def __call__( latents = ( latents / self.vae.config.scaling_factor ) + self.vae.config.shift_factor + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) image = self.vae.decode( - latents.to(device=self.vae.device, dtype=self.vae.dtype), - return_dict=False, + latents.to(dtype=self.vae.dtype), return_dict=False )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models diff --git a/helpers/models/omnigen/pipeline.py b/helpers/models/omnigen/pipeline.py index fadbaf05..5693967f 100644 --- a/helpers/models/omnigen/pipeline.py +++ b/helpers/models/omnigen/pipeline.py @@ -345,9 +345,20 @@ def __call__( ) else: samples = samples / self.vae.config.scaling_factor - samples = self.vae.decode( - samples.to(dtype=self.vae.dtype, device=self.vae.device) - ).sample + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + + image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) if self.model_cpu_offload: self.vae.to("cpu") diff --git a/helpers/models/pixart/pipeline.py b/helpers/models/pixart/pipeline.py index 6efb54cd..244df2e2 100644 --- a/helpers/models/pixart/pipeline.py +++ b/helpers/models/pixart/pipeline.py @@ -1231,11 +1231,24 @@ def denoising_value_valid(dnv): callback(step_idx, t, latents) if not output_type == "latent": + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( - latents.to(device=self.vae.device, dtype=self.vae.dtype) - / self.vae.config.scaling_factor, + latents.to(dtype=self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False, + generator=generator, )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor( image, orig_width, orig_height diff --git a/helpers/models/sd3/pipeline.py b/helpers/models/sd3/pipeline.py index 653c2a6a..ed25c954 100644 --- a/helpers/models/sd3/pipeline.py +++ b/helpers/models/sd3/pipeline.py @@ -1097,7 +1097,22 @@ def __call__( latents / self.vae.config.scaling_factor ) + self.vae.config.shift_factor - image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + + image = self.vae.decode( + latents.to(dtype=self.vae.dtype), return_dict=False + )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models @@ -2053,7 +2068,22 @@ def __call__( latents / self.vae.config.scaling_factor ) + self.vae.config.shift_factor - image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + + image = self.vae.decode( + latents.to(dtype=self.vae.dtype), return_dict=False + )[0] + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models diff --git a/helpers/models/sdxl/pipeline.py b/helpers/models/sdxl/pipeline.py index bfd93f29..10b01216 100644 --- a/helpers/models/sdxl/pipeline.py +++ b/helpers/models/sdxl/pipeline.py @@ -1488,10 +1488,22 @@ def __call__( else: latents = latents / self.vae.config.scaling_factor + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( latents.to(dtype=self.vae.dtype), return_dict=False )[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) From d406e3a817e57411df8afd9153e0c834a6a3fbb8 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 11:45:34 -0600 Subject: [PATCH 11/26] add more sageattention API choices --- helpers/configuration/cmd_args.py | 11 +++++++-- helpers/training/trainer.py | 40 ++++++++++++++++++++++++++----- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index c15cdc01..a0617f80 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1711,7 +1711,14 @@ def get_argument_parser(): parser.add_argument( "--attention_mechanism", type=str, - choices=["diffusers", "xformers", "sageattention"], + choices=[ + "diffusers", + "xformers", + "sageattention", + "sageattention-int8-fp16-triton", + "sageattention-int8-fp16-cuda", + "sageattention-int8-fp8-cuda", + ], default="diffusers", help=( "On NVIDIA CUDA devices, we can use Xformers or SageAttention as an alternative to Pytorch SDPA (Diffusers)." @@ -2427,7 +2434,7 @@ def parse_cmdline_args(input_args=None): args.lycoris_config, os.R_OK ): raise ValueError( - f"Could not find the JSON configuration file at {args.lycoris_config}" + f"Could not find the JSON configuration file at '{args.lycoris_config}'" ) import json diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 637006d1..a78ede95 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -1637,19 +1637,47 @@ def move_models(self, destination: str = "accelerator"): ) ) - if self.config.attention_mechanism == "sageattention": + if "sageattention" in self.config.attention_mechanism: # we'll try and load SageAttention and overload pytorch's sdpa function. try: - from sageattention import sageattn + from sageattention import ( + sageattn, + sageattn_qk_int8_pv_fp16_triton, + sageattn_qk_int8_pv_fp16_cuda, + sageattn_qk_int8_pv_fp8_cuda, + ) + + sageattn_functions = { + "sageattention": sageattn, + "sageattention-int8-fp16-triton": sageattn_qk_int8_pv_fp16_triton, + "sageattention-int8-fp16-cuda": sageattn_qk_int8_pv_fp16_cuda, + "sageattention-int8-fp8-cuda": sageattn_qk_int8_pv_fp8_cuda, + } + # store the old SDPA for validations to use during VAE decode + setattr( + torch.nn.functional, + "scaled_dot_product_attention_sdpa", + torch.nn.functional.scaled_dot_product_attention, + ) + torch.nn.functional.scaled_dot_product_attention = ( + sageattn_functions.get( + self.config.attention_mechanism, "sageattention" + ) + ) + setattr( + torch.nn.functional, + "scaled_dot_product_attention_sage", + torch.nn.functional.scaled_dot_product_attention, + ) - torch.nn.functional.scaled_dot_product_attention = sageattn logger.warning( - "Using SageAttention for flash attention mechanism. This is an experimental option, and you may receive unexpected or poor results. To disable SageAttention, remove or set --attention_mechanism to a different value." + f"Using {self.config.attention_mechanism} for flash attention mechanism. This is an experimental option, and you may receive unexpected or poor results. To disable SageAttention, remove or set --attention_mechanism to a different value." ) - except ImportError: + except ImportError as e: logger.error( - "Could not import SageAttention. Please install it to use this --attention_mechanism=sageattention" + "Could not import SageAttention. Please install it to use this --attention_mechanism=sageattention." ) + logger.error(repr(e)) sys.exit(1) elif ( self.config.attention_mechanism == "xformers" From cb286c2dcca43694e6a296f44ca140bc040dc9e6 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 11:47:42 -0600 Subject: [PATCH 12/26] SD 1.5/2.x fix for SageAttention decode --- helpers/legacy/pipeline.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/helpers/legacy/pipeline.py b/helpers/legacy/pipeline.py index 01b0ffe8..c186b48d 100644 --- a/helpers/legacy/pipeline.py +++ b/helpers/legacy/pipeline.py @@ -1160,18 +1160,27 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + has_nsfw_concept = None if not output_type == "latent": + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( - latents.to(self.vae.dtype) / self.vae.config.scaling_factor, + latents.to(dtype=self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False, generator=generator, )[0] - image, has_nsfw_concept = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) + + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) else: image = latents - has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] From 11b3cf175607c61e4f09ab989eb648b0698fec51 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 11:48:24 -0600 Subject: [PATCH 13/26] set attention mechanism to the default for tests --- tests/test_model_card.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_model_card.py b/tests/test_model_card.py index 1c596be5..2e2ed5b1 100644 --- a/tests/test_model_card.py +++ b/tests/test_model_card.py @@ -65,6 +65,7 @@ def setUp(self): self.args.flux_guidance_value = 1.0 self.args.t5_padding = "unmodified" self.args.enable_xformers_memory_efficient_attention = False + self.args.attention_mechanism = "diffusers" def test_model_imports(self): self.args.lora_type = "standard" From d1c227cc71af063615ff1357fa9c196765ddf1df Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 12:14:56 -0600 Subject: [PATCH 14/26] add sageattention to OPTIONS doc, update recommendations in --help output --- OPTIONS.md | 38 ++++++++++++++++++++++++++++--- helpers/configuration/cmd_args.py | 8 ++++++- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/OPTIONS.md b/OPTIONS.md index edeea098..c3d59c8d 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -452,7 +452,8 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--lr_scheduler {linear,sine,cosine,cosine_with_restarts,polynomial,constant,constant_with_warmup}] [--lr_warmup_steps LR_WARMUP_STEPS] [--lr_num_cycles LR_NUM_CYCLES] [--lr_power LR_POWER] - [--use_ema] [--ema_device {cpu,accelerator}] [--ema_cpu_only] + [--use_ema] [--ema_device {cpu,accelerator}] + [--ema_validation {none,ema_only,comparison}] [--ema_cpu_only] [--ema_foreach_disable] [--ema_update_interval EMA_UPDATE_INTERVAL] [--ema_decay EMA_DECAY] [--non_ema_revision NON_EMA_REVISION] @@ -473,8 +474,9 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--model_card_safe_for_work] [--logging_dir LOGGING_DIR] [--benchmark_base_model] [--disable_benchmark] [--evaluation_type {clip,none}] - [--pretrained_evaluation_model_name_or_path pretrained_evaluation_model_name_or_path] + [--pretrained_evaluation_model_name_or_path PRETRAINED_EVALUATION_MODEL_NAME_OR_PATH] [--validation_on_startup] [--validation_seed_source {gpu,cpu}] + [--validation_lycoris_strength VALIDATION_LYCORIS_STRENGTH] [--validation_torch_compile] [--validation_torch_compile_mode {max-autotune,reduce-overhead,default}] [--validation_guidance_skip_layers VALIDATION_GUIDANCE_SKIP_LAYERS] @@ -509,6 +511,7 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--text_encoder_2_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}] [--text_encoder_3_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}] [--local_rank LOCAL_RANK] + [--attention_mechanism {diffusers,xformers,sageattention,sageattention-int8-fp16-triton,sageattention-int8-fp16-cuda,sageattention-int8-fp8-cuda}] [--enable_xformers_memory_efficient_attention] [--set_grads_to_none] [--noise_offset NOISE_OFFSET] [--noise_offset_probability NOISE_OFFSET_PROBABILITY] @@ -1137,12 +1140,21 @@ options: cosine_with_restarts scheduler. --lr_power LR_POWER Power factor of the polynomial scheduler. --use_ema Whether to use EMA (exponential moving average) model. + Works with LoRA, Lycoris, and full training. --ema_device {cpu,accelerator} The device to use for the EMA model. If set to 'accelerator', the EMA model will be placed on the accelerator. This provides the fastest EMA update times, but is not ultimately necessary for EMA to function. + --ema_validation {none,ema_only,comparison} + When 'none' is set, no EMA validation will be done. + When using 'ema_only', the validations will rely + mostly on the EMA weights. When using 'comparison' + (default) mode, the validations will first run on the + checkpoint before also running for the EMA weights. In + comparison mode, the resulting images will be provided + side-by-side. --ema_cpu_only When using EMA, the shadow model is moved to the accelerator before we update its parameters. When provided, this option will disable the moving of the @@ -1248,7 +1260,7 @@ options: function. The default is to use no evaluator, and 'clip' will use a CLIP model to evaluate the resulting model's performance during validations. - --pretrained_evaluation_model_name_or_path pretrained_evaluation_model_name_or_path + --pretrained_evaluation_model_name_or_path PRETRAINED_EVALUATION_MODEL_NAME_OR_PATH Optionally provide a custom model to use for ViT evaluations. The default is currently clip-vit-large- patch14-336, allowing for lower patch sizes (greater @@ -1264,6 +1276,12 @@ options: validation errors. If so, please set SIMPLETUNER_LOG_LEVEL=DEBUG and submit debug.log to a new Github issue report. + --validation_lycoris_strength VALIDATION_LYCORIS_STRENGTH + When inferencing for validations, the Lycoris model + will by default be run at its training strength, 1.0. + However, this value can be increased to a value of + around 1.3 or 1.5 to get a stronger effect from the + model. --validation_torch_compile Supply `--validation_torch_compile=true` to enable the use of torch.compile() on the validation pipeline. For @@ -1453,6 +1471,20 @@ options: quantisation (Apple Silicon, NVIDIA, AMD). --local_rank LOCAL_RANK For distributed training: local_rank + --attention_mechanism {diffusers,xformers,sageattention,sageattention-int8-fp16-triton,sageattention-int8-fp16-cuda,sageattention-int8-fp8-cuda} + On NVIDIA CUDA devices, alternative flash attention + implementations are offered, with the default being + native pytorch SDPA. SageAttention has multiple + backends to select from. The recommended value, + 'sageattention', guesses what would be the 'best' + option for SageAttention on your hardware (usually + this is the int8-fp16-cuda backend). However, manually + setting this value to int8-fp16-triton may provide + better averages for per-step training and inference + performance while the cuda backend may provide the + highest maximum speed (with also a lower minimum + speed). NOTE: SageAttention training quality has not + been validated. --enable_xformers_memory_efficient_attention Whether or not to use xformers. --set_grads_to_none Save more memory by using setting grads to None diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index a0617f80..fce256e3 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1721,7 +1721,13 @@ def get_argument_parser(): ], default="diffusers", help=( - "On NVIDIA CUDA devices, we can use Xformers or SageAttention as an alternative to Pytorch SDPA (Diffusers)." + "On NVIDIA CUDA devices, alternative flash attention implementations are offered, with the default being native pytorch SDPA." + " SageAttention has multiple backends to select from." + " The recommended value, 'sageattention', guesses what would be the 'best' option for SageAttention on your hardware" + " (usually this is the int8-fp16-cuda backend). However, manually setting this value to int8-fp16-triton" + " may provide better averages for per-step training and inference performance while the cuda backend" + " may provide the highest maximum speed (with also a lower minimum speed). NOTE: SageAttention training quality" + " has not been validated." ), ) parser.add_argument( From 0976b5ef22705078bc666a32ce671f2919c27a03 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 12:32:52 -0600 Subject: [PATCH 15/26] kolors: enable vae decode hack for sageattention --- helpers/kolors/pipeline.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/helpers/kolors/pipeline.py b/helpers/kolors/pipeline.py index 964a8e95..2f395e4d 100644 --- a/helpers/kolors/pipeline.py +++ b/helpers/kolors/pipeline.py @@ -1249,11 +1249,22 @@ def denoising_value_valid(dnv): # unscale/denormalize the latents latents = latents / self.vae.config.scaling_factor + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # we have SageAttention loaded. fallback to SDPA for decode. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sdpa + ) + image = self.vae.decode( - latents.to(device=self.vae.device, dtype=self.vae.dtype), - return_dict=False, + latents.to(device=self.vae.device), return_dict=False )[0] + if hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + # reenable SageAttention for training. + torch.nn.functional.scaled_dot_product_attention = ( + torch.nn.functional.scaled_dot_product_attention_sage + ) + # cast back to fp16 if needed if needs_upcasting: self.vae.to(dtype=torch.float16) From c98d0f7a9d49c08a146732c562d22d57e5c989f1 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 12:40:42 -0600 Subject: [PATCH 16/26] add sageattention to the quickstart docs and a specific section under "Precision" in the OPTIONS doc --- OPTIONS.md | 8 ++++++++ documentation/LYCORIS.md | 8 ++++++++ documentation/quickstart/FLUX.md | 10 ++++++++++ documentation/quickstart/SD3.md | 8 ++++++++ documentation/quickstart/SIGMA.md | 8 ++++++++ 5 files changed, 42 insertions(+) diff --git a/OPTIONS.md b/OPTIONS.md index c3d59c8d..ddc32383 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -109,6 +109,14 @@ Carefully answer the questions and use bf16 mixed precision training when prompt Note that the first several steps of training will be slower than usual because of compilation occuring in the background. +### `--attention_mechanism` + +Setting `sageattention` or `xformers` here will allow the use of other memory-efficient attention mechanisms for the forward pass during training and inference, potentially resulting in major performance improvement. + +Using `sageattention` enables the use of [SageAttention](https://github.com/thu-ml/SageAttention) on NVIDIA CUDA equipment (sorry, AMD and Apple users). + +In simple terms, this will quantise the attention calculations for lower compute and memory overhead, **massively** speeding up training while minimally impacting quality. + --- ## 📰 Publishing diff --git a/documentation/LYCORIS.md b/documentation/LYCORIS.md index 64aba565..853d3ec6 100644 --- a/documentation/LYCORIS.md +++ b/documentation/LYCORIS.md @@ -59,6 +59,14 @@ Mandatory fields: For more information on LyCORIS, please refer to the [documentation in the library](https://github.com/KohakuBlueleaf/LyCORIS/tree/main/docs). +## Potential problems + +When using Lycoris on SDXL, it's noted that training the FeedForward modules may break the model and send loss into `NaN` (Not-a-Number) territory. + +This seems to be potentially exacerbated when using SageAttention, making it all but guaranteed that the model will immediately fail. + +The solution is to remove the `FeedForward` modules from the lycoris config and train only the `Attention` blocks. + ## LyCORIS Inference Example Here is a simple FLUX.1-dev inference script showing how to wrap your unet or transformer with create_lycoris_from_weights and then use it for inference. diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index ddb86e1d..189b1a72 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -414,9 +414,18 @@ Currently, the lowest VRAM utilisation (9090M) can be attained with: - DeepSpeed: disabled / unconfigured - PyTorch: 2.6 Nightly (Sept 29th build) - Using `--quantize_via=cpu` to avoid outOfMemory error during startup on <=16G cards. +- With `--attention_mechanism=sageattention` to further reduce VRAM by 0.1GB and improve training speed. Speed was approximately 1.4 iterations per second on a 4090. +### SageAttention + +When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations. + +In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training. + +**Note**: This isn't compatible with _every_ configuration, but it's worth trying. + ### NF4-quantised training In simplest terms, NF4 is a 4bit-_ish_ representation of the model, which means training has serious stability concerns to address. @@ -428,6 +437,7 @@ In early tests, the following holds true: - NF4, AdamW8bit, and a higher batch size all help to overcome the stability issues, at the cost of more time spent training or VRAM used - Upping the resolution from 512px to 1024px slows training down from, for example, 1.4 seconds per step to 3.5 seconds per step (batch size of 1, 4090) - Anything that's difficult to train on int8 or bf16 becomes harder in NF4 +- It's less compatible with options like SageAttention NF4 does not work with torch.compile, so whatever you get for speed is what you get. diff --git a/documentation/quickstart/SD3.md b/documentation/quickstart/SD3.md index 5d646562..764bf64a 100644 --- a/documentation/quickstart/SD3.md +++ b/documentation/quickstart/SD3.md @@ -339,6 +339,14 @@ These options have been known to keep SD3.5 in-tact for as long as possible: - DeepSpeed: disabled / unconfigured - PyTorch: 2.5 +### SageAttention + +When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations. + +In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training. + +**Note**: This isn't compatible with _every_ configuration, but it's worth trying. + ### Masked loss If you are training a subject or style and would like to mask one or the other, see the [masked loss training](/documentation/DREAMBOOTH.md#masked-loss) section of the Dreambooth guide. diff --git a/documentation/quickstart/SIGMA.md b/documentation/quickstart/SIGMA.md index cf3389aa..aadd46f2 100644 --- a/documentation/quickstart/SIGMA.md +++ b/documentation/quickstart/SIGMA.md @@ -220,3 +220,11 @@ For more information, see the [dataloader](/documentation/DATALOADER.md) and [tu ### CLIP score tracking If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores. + +### SageAttention + +When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations. + +In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training. + +**Note**: This isn't compatible with _every_ configuration, but it's worth trying. \ No newline at end of file From 8719e738fea3964868d5bcdbffa57e6faddb9c2b Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 13:05:48 -0600 Subject: [PATCH 17/26] sd3 skip_layer_guidance fix from upstream for num images per prompt > 1 --- helpers/models/sd3/pipeline.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/helpers/models/sd3/pipeline.py b/helpers/models/sd3/pipeline.py index ed25c954..70698f6e 100644 --- a/helpers/models/sd3/pipeline.py +++ b/helpers/models/sd3/pipeline.py @@ -999,11 +999,14 @@ def __call__( continue # expand the latents if we are doing classifier free guidance + # added fix from: https://github.com/huggingface/diffusers/pull/10086/files + # to allow for num_images_per_prompt > 1 latent_model_input = ( torch.cat([latents] * 2) - if self.do_classifier_free_guidance and skip_guidance_layers is None + if self.do_classifier_free_guidance else latents ) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) @@ -1033,6 +1036,8 @@ def __call__( else False ) if skip_guidance_layers is not None and should_skip_layers: + timestep = t.expand(latents.shape[0]) + latent_model_input = latents noise_pred_skip_layers = self.transformer( hidden_states=latent_model_input.to( device=self.transformer.device, From 964e065336eb384e1ad90ad05297990a6d91d17a Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 2 Dec 2024 13:41:07 -0600 Subject: [PATCH 18/26] disable nf4 + sageattention --- helpers/training/default_settings/safety_check.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index 01444a32..518afd63 100644 --- a/helpers/training/default_settings/safety_check.py +++ b/helpers/training/default_settings/safety_check.py @@ -126,3 +126,9 @@ def safety_check(args, accelerator): f"--enable_xformers_memory_efficient_attention is only compatible with --attention_mechanism=diffusers. Please set --attention_mechanism=diffusers to enable this feature or disable xformers to use alternative attention mechanisms." ) sys.exit(1) + + if "nf4" in args.base_model_precision: + logger.error( + f"{args.base_model_precision} is not supported with SageAttention. Please select from int8 or fp8, or, disable quantisation to use SageAttention." + ) + sys.exit(1) From 9bda6329d3548d17dee47b47bfe418554a3ebd88 Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 3 Dec 2024 11:00:28 -0600 Subject: [PATCH 19/26] add --sageattention_usage and default to inference-only --- helpers/configuration/cmd_args.py | 13 +- helpers/publishing/metadata.py | 2 +- .../training/default_settings/safety_check.py | 29 ++-- helpers/training/trainer.py | 151 +++++++++++++----- helpers/training/validation.py | 104 +++++++++--- 5 files changed, 216 insertions(+), 83 deletions(-) diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index fce256e3..c90923e5 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1730,10 +1730,21 @@ def get_argument_parser(): " has not been validated." ), ) + parser.add_argument( + "--sageattention_usage", + type=str, + choices=["training", "inference", "training+inference"], + default="inference", + help=( + "SageAttention breaks gradient tracking through the backward pass, leading to untrained QKV layers." + " This can result in substantial problems for training, so it is recommended to use SageAttention only for inference (default behaviour)." + " If you are confident in your training setup or do not wish to train QKV layers, you may use 'training' to enable SageAttention for training." + ), + ) parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", - help="Whether or not to use xformers.", + help="Whether or not to use xformers. Deprecated and slated for future removal.", ) parser.add_argument( "--set_grads_to_none", diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py index c761f27e..2bfe4104 100644 --- a/helpers/publishing/metadata.py +++ b/helpers/publishing/metadata.py @@ -558,7 +558,7 @@ def save_model_card( - Trainable parameter precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'} - Caption dropout probability: {StateTracker.get_args().caption_dropout_probability * 100}% {'- Xformers: Enabled' if StateTracker.get_args().attention_mechanism == 'xformers' else ''} -{'- SageAttention: Enabled' if StateTracker.get_args().attention_mechanism == 'sageattention' else ''} +{f'- SageAttention: Enabled {StateTracker.get_args().sageattention_usage}' if StateTracker.get_args().attention_mechanism == 'sageattention' else ''} {lora_info(args=StateTracker.get_args())} ## Datasets diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index 518afd63..36f321c5 100644 --- a/helpers/training/default_settings/safety_check.py +++ b/helpers/training/default_settings/safety_check.py @@ -118,17 +118,18 @@ def safety_check(args, accelerator): ) sys.exit(1) - if ( - args.enable_xformers_memory_efficient_attention - and args.attention_mechanism == "sageattention" - ): - logger.error( - f"--enable_xformers_memory_efficient_attention is only compatible with --attention_mechanism=diffusers. Please set --attention_mechanism=diffusers to enable this feature or disable xformers to use alternative attention mechanisms." - ) - sys.exit(1) - - if "nf4" in args.base_model_precision: - logger.error( - f"{args.base_model_precision} is not supported with SageAttention. Please select from int8 or fp8, or, disable quantisation to use SageAttention." - ) - sys.exit(1) + if args.attention_mechanism == "sageattention": + if args.sageattention_usage != "inference": + logger.error( + f"SageAttention usage is set to '{args.sageattention_usage}' instead of 'inference'. This is not an officially supported configuration, please be sure you understand the implications. It is recommended to set this value to 'inference' for safety." + ) + if args.enable_xformers_memory_efficient_attention: + logger.error( + f"--enable_xformers_memory_efficient_attention is only compatible with --attention_mechanism=diffusers. Please set --attention_mechanism=diffusers to enable this feature or disable xformers to use alternative attention mechanisms." + ) + sys.exit(1) + if "nf4" in args.base_model_precision: + logger.error( + f"{args.base_model_precision} is not supported with SageAttention. Please select from int8 or fp8, or, disable quantisation to use SageAttention." + ) + sys.exit(1) diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index a78ede95..dfe7593d 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -1613,6 +1613,98 @@ def resume_and_prepare(self): lr_scheduler = self.init_resume_checkpoint(lr_scheduler=lr_scheduler) self.init_post_load_freeze() + def enable_sageattention_inference(self): + # if the sageattention is inference-only, we'll enable it. + # if it's training only, we'll disable it. + # if it's inference+training, we leave it alone. + if ( + "sageattention" not in self.config.attention_mechanism + or self.config.sageattention_usage == "training+inference" + ): + return + if self.config.sageattention_usage == "inference": + self.enable_sageattention() + if self.config.sageattention_usage == "training": + self.disable_sageattention() + + def disable_sageattention_inference(self): + # if the sageattention is inference-only, we'll disable it. + # if it's training only, we'll enable it. + # if it's inference+training, we leave it alone. + if ( + "sageattention" not in self.config.attention_mechanism + or self.config.sageattention_usage == "training+inference" + ): + return + if self.config.sageattention_usage == "inference": + self.disable_sageattention() + if self.config.sageattention_usage == "training": + self.enable_sageattention() + + def disable_sageattention(self): + if "sageattention" not in self.config.attention_mechanism: + return + + if ( + hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa") + and torch.nn.functional + != torch.nn.functional.scaled_dot_product_attention_sdpa + ): + logger.info("Disabling SageAttention.") + setattr( + torch.nn.functional, + "scaled_dot_product_attention", + torch.nn.functional.scaled_dot_product_attention_sdpa, + ) + + def enable_sageattention(self): + if "sageattention" not in self.config.attention_mechanism: + return + + # we'll try and load SageAttention and overload pytorch's sdpa function. + try: + logger.info("Enabling SageAttention.") + from sageattention import ( + sageattn, + sageattn_qk_int8_pv_fp16_triton, + sageattn_qk_int8_pv_fp16_cuda, + sageattn_qk_int8_pv_fp8_cuda, + ) + + sageattn_functions = { + "sageattention": sageattn, + "sageattention-int8-fp16-triton": sageattn_qk_int8_pv_fp16_triton, + "sageattention-int8-fp16-cuda": sageattn_qk_int8_pv_fp16_cuda, + "sageattention-int8-fp8-cuda": sageattn_qk_int8_pv_fp8_cuda, + } + # store the old SDPA for validations to use during VAE decode + if not hasattr(torch.nn.functional, "scaled_dot_product_attention_sdpa"): + setattr( + torch.nn.functional, + "scaled_dot_product_attention_sdpa", + torch.nn.functional.scaled_dot_product_attention, + ) + torch.nn.functional.scaled_dot_product_attention = sageattn_functions.get( + self.config.attention_mechanism, "sageattention" + ) + if not hasattr(torch.nn.functional, "scaled_dot_product_attention_sage"): + setattr( + torch.nn.functional, + "scaled_dot_product_attention_sage", + torch.nn.functional.scaled_dot_product_attention, + ) + + if "training" in self.config.sageattention_usage: + logger.warning( + f"Using {self.config.attention_mechanism} for attention calculations during training. Your attention layers will not be trained. To disable SageAttention, remove or set --attention_mechanism to a different value." + ) + except ImportError as e: + logger.error( + "Could not import SageAttention. Please install it to use this --attention_mechanism=sageattention." + ) + logger.error(repr(e)) + sys.exit(1) + def move_models(self, destination: str = "accelerator"): target_device = "cpu" if destination == "accelerator": @@ -1637,48 +1729,14 @@ def move_models(self, destination: str = "accelerator"): ) ) - if "sageattention" in self.config.attention_mechanism: - # we'll try and load SageAttention and overload pytorch's sdpa function. - try: - from sageattention import ( - sageattn, - sageattn_qk_int8_pv_fp16_triton, - sageattn_qk_int8_pv_fp16_cuda, - sageattn_qk_int8_pv_fp8_cuda, - ) - - sageattn_functions = { - "sageattention": sageattn, - "sageattention-int8-fp16-triton": sageattn_qk_int8_pv_fp16_triton, - "sageattention-int8-fp16-cuda": sageattn_qk_int8_pv_fp16_cuda, - "sageattention-int8-fp8-cuda": sageattn_qk_int8_pv_fp8_cuda, - } - # store the old SDPA for validations to use during VAE decode - setattr( - torch.nn.functional, - "scaled_dot_product_attention_sdpa", - torch.nn.functional.scaled_dot_product_attention, - ) - torch.nn.functional.scaled_dot_product_attention = ( - sageattn_functions.get( - self.config.attention_mechanism, "sageattention" - ) - ) - setattr( - torch.nn.functional, - "scaled_dot_product_attention_sage", - torch.nn.functional.scaled_dot_product_attention, - ) - - logger.warning( - f"Using {self.config.attention_mechanism} for flash attention mechanism. This is an experimental option, and you may receive unexpected or poor results. To disable SageAttention, remove or set --attention_mechanism to a different value." - ) - except ImportError as e: - logger.error( - "Could not import SageAttention. Please install it to use this --attention_mechanism=sageattention." - ) - logger.error(repr(e)) - sys.exit(1) + if ( + "sageattention" in self.config.attention_mechanism + and "training" in self.config.sageattention_usage + ): + logger.info( + "Using SageAttention for training. This is an unsupported, experimental configuration." + ) + self.enable_sageattention() elif ( self.config.attention_mechanism == "xformers" and self.config.model_family @@ -2129,7 +2187,9 @@ def train(self): self.mark_optimizer_eval() # normal run-of-the-mill validation on startup. if self.validation is not None: + self.enable_sageattention_inference() self.validation.run_validations(validation_type="base_model", step=0) + self.disable_sageattention_inference() self.mark_optimizer_train() @@ -2851,9 +2911,13 @@ def train(self): progress_bar.set_postfix(**logs) self.mark_optimizer_eval() if self.validation is not None: + if self.validation.would_validate(): + self.enable_sageattention_inference() self.validation.run_validations( validation_type="intermediary", step=step ) + if self.validation.would_validate(): + self.disable_sageattention_inference() self.mark_optimizer_train() if ( self.config.push_to_hub @@ -2902,12 +2966,15 @@ def train(self): if self.accelerator.is_main_process: self.mark_optimizer_eval() if self.validation is not None: + self.enable_sageattention_inference() validation_images = self.validation.run_validations( validation_type="final", step=self.state["global_step"], force_evaluation=True, skip_execution=True, ).validation_images + # we don't have to do this but we will anyway. + self.disable_sageattention_inference() if self.unet is not None: self.unet = unwrap_model(self.accelerator, self.unet) if self.transformer is not None: diff --git a/helpers/training/validation.py b/helpers/training/validation.py index fe001cf2..c5a062bb 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -745,7 +745,11 @@ def _benchmark_path(self, benchmark: str = "base_model"): return os.path.join(self.args.output_dir, "benchmarks", benchmark) def stitch_benchmark_image( - self, validation_image_result, benchmark_image, separator_width=5, labels=["base model", "checkpoint"] + self, + validation_image_result, + benchmark_image, + separator_width=5, + labels=["base model", "checkpoint"], ): """ For each image, make a new canvas and place it side by side with its equivalent from {self.validation_image_inputs} @@ -754,7 +758,9 @@ def stitch_benchmark_image( """ # Calculate new dimensions - new_width = benchmark_image.size[0] + validation_image_result.size[0] + separator_width + new_width = ( + benchmark_image.size[0] + validation_image_result.size[0] + separator_width + ) new_height = benchmark_image.size[1] # Create a new image with a white background @@ -877,6 +883,18 @@ def _update_state(self): self.global_step = StateTracker.get_global_step() self.global_resume_step = StateTracker.get_global_resume_step() or 1 + def would_validate( + self, + step: int = 0, + validation_type="intermediary", + force_evaluation: bool = False, + ): + # a wrapper for should_perform_validation that can run in the training loop + self._update_state() + return self.should_perform_validation( + step, self.validation_prompts, validation_type + ) or (step == 0 and validation_type == "base_model") + def run_validations( self, step: int = 0, @@ -1168,9 +1186,11 @@ def process_prompts(self): ) self.validation_prompt_dict[shortname] = prompt logger.debug(f"Processing validation for prompt: {prompt}") - stitched_validation_images, checkpoint_validation_images, ema_validation_images = ( - self.validate_prompt(prompt, shortname, validation_input_image) - ) + ( + stitched_validation_images, + checkpoint_validation_images, + ema_validation_images, + ) = self.validate_prompt(prompt, shortname, validation_input_image) validation_images.update(stitched_validation_images) self._save_images(validation_images, shortname, prompt) logger.debug(f"Completed generating image: {prompt}") @@ -1364,15 +1384,17 @@ def validate_prompt( ) if current_validation_type == "ema": self.enable_ema_for_inference() - all_validation_type_results[current_validation_type] = self.pipeline( - **pipeline_kwargs - ).images + all_validation_type_results[current_validation_type] = ( + self.pipeline(**pipeline_kwargs).images + ) if current_validation_type == "ema": self.disable_ema_for_inference() # retrieve the default image result for stitching to controlnet inputs. ema_image_results = all_validation_type_results.get("ema") - validation_image_results = all_validation_type_results.get("checkpoint", ema_image_results) + validation_image_results = all_validation_type_results.get( + "checkpoint", ema_image_results + ) original_validation_image_results = validation_image_results benchmark_image = None if self.args.controlnet: @@ -1387,7 +1409,9 @@ def validate_prompt( validation_shortname, resolution ) if benchmark_image is not None: - for idx, validation_image in enumerate(validation_image_results): + for idx, validation_image in enumerate( + validation_image_results + ): validation_image_results[idx] = self.stitch_benchmark_image( validation_image_result=validation_image, benchmark_image=benchmark_image, @@ -1396,9 +1420,13 @@ def validate_prompt( checkpoint_validation_images[validation_shortname].extend( original_validation_image_results ) - stitched_validation_images[validation_shortname].extend(validation_image_results) + stitched_validation_images[validation_shortname].extend( + validation_image_results + ) if self.args.use_ema: - ema_validation_images[validation_shortname].extend(ema_image_results) + ema_validation_images[validation_shortname].extend( + ema_image_results + ) except Exception as e: import traceback @@ -1407,15 +1435,31 @@ def validate_prompt( f"Error generating validation image: {e}, {traceback.format_exc()}" ) continue - if self.args.use_ema and self.args.ema_validation == "comparison" and benchmark_image is not None: - for idx, validation_image in enumerate(stitched_validation_images[validation_shortname]): - stitched_validation_images[validation_shortname][idx] = self.stitch_benchmark_image( - validation_image_result=ema_validation_images[validation_shortname][idx], - benchmark_image=stitched_validation_images[validation_shortname][idx], - labels=[None, "EMA"] + if ( + self.args.use_ema + and self.args.ema_validation == "comparison" + and benchmark_image is not None + ): + for idx, validation_image in enumerate( + stitched_validation_images[validation_shortname] + ): + stitched_validation_images[validation_shortname][idx] = ( + self.stitch_benchmark_image( + validation_image_result=ema_validation_images[ + validation_shortname + ][idx], + benchmark_image=stitched_validation_images[ + validation_shortname + ][idx], + labels=[None, "EMA"], + ) ) - return stitched_validation_images, checkpoint_validation_images, ema_validation_images + return ( + stitched_validation_images, + checkpoint_validation_images, + ema_validation_images, + ) def _save_images(self, validation_images, validation_shortname, validation_prompt): validation_img_idx = 0 @@ -1549,9 +1593,15 @@ def enable_ema_for_inference(self, pipeline=None): logger.info("Setting Lycoris multiplier to 1.0") self.accelerator._lycoris_wrapped_network.set_multiplier(1.0) logger.info("Storing Lycoris weights for later recovery.") - self.ema_model.store(self.accelerator._lycoris_wrapped_network.parameters()) - logger.info("Storing the EMA weights into the Lycoris adapter for inference.") - self.ema_model.copy_to(self.accelerator._lycoris_wrapped_network.parameters()) + self.ema_model.store( + self.accelerator._lycoris_wrapped_network.parameters() + ) + logger.info( + "Storing the EMA weights into the Lycoris adapter for inference." + ) + self.ema_model.copy_to( + self.accelerator._lycoris_wrapped_network.parameters() + ) elif self.args.lora_type.lower() == "standard": _trainable_parameters = [ x for x in self._primary_model().parameters() if x.requires_grad @@ -1583,11 +1633,16 @@ def disable_ema_for_inference(self): if self.args.use_ema: logger.info("Disabling EMA.") self.ema_enabled = False - if self.args.model_type == "lora" and self.args.lora_type.lower() == "lycoris": + if ( + self.args.model_type == "lora" + and self.args.lora_type.lower() == "lycoris" + ): logger.info("Setting Lycoris network multiplier to 1.0.") self.accelerator._lycoris_wrapped_network.set_multiplier(1.0) logger.info("Restoring Lycoris weights.") - self.ema_model.restore(self.accelerator._lycoris_wrapped_network.parameters()) + self.ema_model.restore( + self.accelerator._lycoris_wrapped_network.parameters() + ) else: logger.info("Restoring trainable parameters.") self.ema_model.restore(self.trainable_parameters()) @@ -1601,7 +1656,6 @@ def disable_ema_for_inference(self): "Skipping EMA model restoration for validation, as we are not using EMA." ) - def finalize_validation(self, validation_type): """Cleans up and restores original state if necessary.""" if not self.args.keep_vae_loaded and not self.args.vae_cache_ondemand: From 96eb1164a4c0a52b7c5bf2f562247ddcdddd7a0a Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 3 Dec 2024 11:30:38 -0600 Subject: [PATCH 20/26] update documentation --- OPTIONS.md | 23 +++++++++++++++++++---- documentation/LYCORIS.md | 2 +- documentation/quickstart/FLUX.md | 8 +++----- documentation/quickstart/SD3.md | 6 ++---- documentation/quickstart/SIGMA.md | 6 ++---- 5 files changed, 27 insertions(+), 18 deletions(-) diff --git a/OPTIONS.md b/OPTIONS.md index ddc32383..d0069deb 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -111,11 +111,15 @@ Note that the first several steps of training will be slower than usual because ### `--attention_mechanism` -Setting `sageattention` or `xformers` here will allow the use of other memory-efficient attention mechanisms for the forward pass during training and inference, potentially resulting in major performance improvement. +Alternative attention mechanisms are supported, with varying levels of compatibility or other trade-offs; -Using `sageattention` enables the use of [SageAttention](https://github.com/thu-ml/SageAttention) on NVIDIA CUDA equipment (sorry, AMD and Apple users). +- `diffusers` uses the native Pytorch SDPA functions and is the default attention mechanism +- `xformers` allows the use of Meta's [xformers](https://github.com/facebook/xformers) attention implementation which supports both training and inference fully +- `sageattention` is an inference-focused attention mechanism which does not fully support being used for training ([SageAttention](https://github.com/thu-ml/SageAttention) project page) + - In simplest terms, SageAttention reduces compute requirement for inference -In simple terms, this will quantise the attention calculations for lower compute and memory overhead, **massively** speeding up training while minimally impacting quality. +Using `--sageattention_usage` to enable training with SageAttention should be enabled with care, as it does not track or propagate gradients from its custom CUDA implementations for the QKV linears. + - This results in these layers being completely untrained, which might cause model collapse or, slight improvements in short training runs. --- @@ -520,6 +524,7 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr] [--text_encoder_3_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}] [--local_rank LOCAL_RANK] [--attention_mechanism {diffusers,xformers,sageattention,sageattention-int8-fp16-triton,sageattention-int8-fp16-cuda,sageattention-int8-fp8-cuda}] + [--sageattention_usage {training,inference,training+inference}] [--enable_xformers_memory_efficient_attention] [--set_grads_to_none] [--noise_offset NOISE_OFFSET] [--noise_offset_probability NOISE_OFFSET_PROBABILITY] @@ -1493,8 +1498,18 @@ options: highest maximum speed (with also a lower minimum speed). NOTE: SageAttention training quality has not been validated. + --sageattention_usage {training,inference,training+inference} + SageAttention breaks gradient tracking through the + backward pass, leading to untrained QKV layers. This + can result in substantial problems for training, so it + is recommended to use SageAttention only for inference + (default behaviour). If you are confident in your + training setup or do not wish to train QKV layers, you + may use 'training' to enable SageAttention for + training. --enable_xformers_memory_efficient_attention - Whether or not to use xformers. + Whether or not to use xformers. Deprecated and slated + for future removal. Use --attention_mechanism. --set_grads_to_none Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain behaviors, so disable this argument if it causes any diff --git a/documentation/LYCORIS.md b/documentation/LYCORIS.md index 853d3ec6..5d0c5f55 100644 --- a/documentation/LYCORIS.md +++ b/documentation/LYCORIS.md @@ -63,7 +63,7 @@ For more information on LyCORIS, please refer to the [documentation in the libra When using Lycoris on SDXL, it's noted that training the FeedForward modules may break the model and send loss into `NaN` (Not-a-Number) territory. -This seems to be potentially exacerbated when using SageAttention, making it all but guaranteed that the model will immediately fail. +This seems to be potentially exacerbated when using SageAttention (with `--sageattention_usage=training`), making it all but guaranteed that the model will immediately fail. The solution is to remove the `FeedForward` modules from the lycoris config and train only the `Attention` blocks. diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index 189b1a72..6b2c60b1 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -414,17 +414,15 @@ Currently, the lowest VRAM utilisation (9090M) can be attained with: - DeepSpeed: disabled / unconfigured - PyTorch: 2.6 Nightly (Sept 29th build) - Using `--quantize_via=cpu` to avoid outOfMemory error during startup on <=16G cards. -- With `--attention_mechanism=sageattention` to further reduce VRAM by 0.1GB and improve training speed. +- With `--attention_mechanism=sageattention` to further reduce VRAM by 0.1GB and improve training validation image generation speed. Speed was approximately 1.4 iterations per second on a 4090. ### SageAttention -When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations. +When using `--attention_mechanism=sageattention`, inference can be sped-up at validation time. -In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training. - -**Note**: This isn't compatible with _every_ configuration, but it's worth trying. +**Note**: This isn't compatible with _every_ model configuration, but it's worth trying. ### NF4-quantised training diff --git a/documentation/quickstart/SD3.md b/documentation/quickstart/SD3.md index 764bf64a..832eaa9f 100644 --- a/documentation/quickstart/SD3.md +++ b/documentation/quickstart/SD3.md @@ -341,11 +341,9 @@ These options have been known to keep SD3.5 in-tact for as long as possible: ### SageAttention -When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations. +When using `--attention_mechanism=sageattention`, inference can be sped-up at validation time. -In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training. - -**Note**: This isn't compatible with _every_ configuration, but it's worth trying. +**Note**: This isn't compatible with _every_ model configuration, but it's worth trying. ### Masked loss diff --git a/documentation/quickstart/SIGMA.md b/documentation/quickstart/SIGMA.md index aadd46f2..518bc006 100644 --- a/documentation/quickstart/SIGMA.md +++ b/documentation/quickstart/SIGMA.md @@ -223,8 +223,6 @@ If you wish to enable evaluations to score the model's performance, see [this do ### SageAttention -When using `--attention_mechanism=sageattention`, quantised operations are performed during SDPA calculations. +When using `--attention_mechanism=sageattention`, inference can be sped-up at validation time. -In simpler terms, this can very slightly improve VRAM usage while substantially speeding up training. - -**Note**: This isn't compatible with _every_ configuration, but it's worth trying. \ No newline at end of file +**Note**: This isn't compatible with _every_ model configuration, but it's worth trying. From 62f1034b0e417b5772dbcc88da3fd001cfb09c01 Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 3 Dec 2024 11:48:24 -0600 Subject: [PATCH 21/26] update configure.py to use inference-only sageattention --- configure.py | 18 +++++++++++++++++- helpers/configuration/cmd_args.py | 2 +- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/configure.py b/configure.py index b5c26d28..7aced716 100644 --- a/configure.py +++ b/configure.py @@ -433,13 +433,29 @@ def configure_env(): env_contents["--attention_mechanism"] = "diffusers" use_sageattention = ( prompt_user( - "Would you like to use SageAttention? This is an experimental option that can greatly speed up training. (y/[n])", + "Would you like to use SageAttention for image validation generation? (y/[n])", "n", ).lower() == "y" ) if use_sageattention: env_contents["--attention_mechanism"] = "sageattention" + env_contents["--sageattention_usage"] = "inference" + use_sageattention_training = ( + prompt_user( + ( + "Would you like to use SageAttention to cover the forward and backward pass during training?" + " This has the undesirable consequence of leaving the attention layers untrained," + " as SageAttention lacks the capability to fully track gradients through quantisation." + " If you are not training the attention layers for some reason, this may not matter and" + " you can safely enable this. For all other use-cases, reconsideration and caution are warranted." + ), + "n", + ).lower() + == "y" + ) + if use_sageattention_training: + env_contents["--sageattention_usage"] = "both" # properly disable wandb/tensorboard/comet_ml etc by default report_to_str = "none" diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index c90923e5..0805ded8 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1744,7 +1744,7 @@ def get_argument_parser(): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", - help="Whether or not to use xformers. Deprecated and slated for future removal.", + help="Whether or not to use xformers. Deprecated and slated for future removal. Use --attention_mechanism.", ) parser.add_argument( "--set_grads_to_none", From a14065728d7caef8ef2cf53b57c97b4f18c22e34 Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 3 Dec 2024 15:47:18 -0600 Subject: [PATCH 22/26] add --gradient_checkpointing_interval and disable gradient checkpointing before validations run for extra inference performance boost --- helpers/configuration/cmd_args.py | 9 ++++ helpers/models/flux/transformer.py | 27 +++++++++++- .../training/default_settings/safety_check.py | 18 ++++++++ helpers/training/diffusion_model.py | 29 ++++++++++++- .../gradient_checkpointing_interval.py | 42 +++++++++++++++++++ helpers/training/trainer.py | 34 +++++++++++++++ tests/test_trainer.py | 1 + 7 files changed, 157 insertions(+), 3 deletions(-) create mode 100644 helpers/training/gradient_checkpointing_interval.py diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 0805ded8..457e1118 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1054,6 +1054,15 @@ def get_argument_parser(): action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) + parser.add_argument( + "--gradient_checkpointing_interval", + default=None, + type=int, + help=( + "Some models (Flux, SDXL, SD1.x/2.x) can have their gradient checkpointing limited to every nth block." + " This can speed up training but will use more memory with larger intervals." + ), + ) parser.add_argument( "--learning_rate", type=float, diff --git a/helpers/models/flux/transformer.py b/helpers/models/flux/transformer.py index 7a9c80f2..280ff804 100644 --- a/helpers/models/flux/transformer.py +++ b/helpers/models/flux/transformer.py @@ -489,11 +489,16 @@ def __init__( ) self.gradient_checkpointing = False + # added for users to disable checkpointing every nth step + self.gradient_checkpointing_interval = None def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value + def set_gradient_checkpointing_interval(self, value: int): + self.gradient_checkpointing_interval = value + def forward( self, hidden_states: torch.Tensor, @@ -574,7 +579,18 @@ def forward( image_rotary_emb = self.pos_embed(ids) for index_block, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: + if ( + self.training + and self.gradient_checkpointing + and ( + self.gradient_checkpointing_interval is None + or index_block % self.gradient_checkpointing_interval == 0 + ) + ): + + print( + f"checkpointing index {index_block} at interval: {self.gradient_checkpointing_interval}" + ) def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -614,7 +630,14 @@ def custom_forward(*inputs): hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): - if self.training and self.gradient_checkpointing: + if ( + self.training + and self.gradient_checkpointing + or ( + self.gradient_checkpointing_interval is not None + and index_block % self.gradient_checkpointing_interval == 0 + ) + ): def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index 36f321c5..e6522a40 100644 --- a/helpers/training/default_settings/safety_check.py +++ b/helpers/training/default_settings/safety_check.py @@ -133,3 +133,21 @@ def safety_check(args, accelerator): f"{args.base_model_precision} is not supported with SageAttention. Please select from int8 or fp8, or, disable quantisation to use SageAttention." ) sys.exit(1) + + gradient_checkpointing_interval_supported_models = [ + "flux", + "sdxl", + ] + if args.gradient_checkpointing_interval is not None: + if ( + args.model_family.lower() + not in gradient_checkpointing_interval_supported_models + ): + logger.error( + f"Gradient checkpointing is not supported with {args.model_family} models. Please disable --gradient_checkpointing_interval by setting it to None, or remove it from your configuration. Currently supported models: {gradient_checkpointing_interval_supported_models}" + ) + sys.exit(1) + if args.gradient_checkpointing_interval == 0: + raise ValueError( + "Gradient checkpointing interval must be greater than 0. Please set it to a positive integer." + ) diff --git a/helpers/training/diffusion_model.py b/helpers/training/diffusion_model.py index 5ac0c207..290858bc 100644 --- a/helpers/training/diffusion_model.py +++ b/helpers/training/diffusion_model.py @@ -52,7 +52,9 @@ def load_diffusion_model(args, weight_dtype): elif ( args.model_family.lower() == "flux" and not args.flux_attention_masked_training ): - from diffusers.models import FluxTransformer2DModel + from helpers.models.flux.transformer import ( + FluxTransformer2DModelWithMasking as FluxTransformer2DModel, + ) import torch if torch.cuda.is_available(): @@ -92,6 +94,10 @@ def load_diffusion_model(args, weight_dtype): subfolder=determine_subfolder(args.pretrained_transformer_subfolder), **pretrained_load_args, ) + if args.gradient_checkpointing_interval is not None: + transformer.set_gradient_checkpointing_interval( + int(args.gradient_checkpointing_interval) + ) elif args.model_family.lower() == "flux" and args.flux_attention_masked_training: from helpers.models.flux.transformer import ( FluxTransformer2DModelWithMasking, @@ -103,6 +109,10 @@ def load_diffusion_model(args, weight_dtype): subfolder=determine_subfolder(args.pretrained_transformer_subfolder), **pretrained_load_args, ) + if args.gradient_checkpointing_interval is not None: + transformer.set_gradient_checkpointing_interval( + int(args.gradient_checkpointing_interval) + ) elif args.model_family == "pixart_sigma": from diffusers.models import PixArtTransformer2DModel @@ -145,5 +155,22 @@ def load_diffusion_model(args, weight_dtype): subfolder=determine_subfolder(args.pretrained_unet_subfolder), **pretrained_load_args, ) + if ( + args.gradient_checkpointing_interval is not None + and args.gradient_checkpointing_interval > 0 + ): + logger.warning( + "Using experimental gradient checkpointing monkeypatch for a checkpoint interval of {}".format( + args.gradient_checkpointing_interval + ) + ) + # monkey-patch the gradient checkpointing function for pytorch to run every nth call only. + # definitely one of the more awful things I've ever done while programming, but it's easier than + # modifying every one of the unet blocks' forward calls in Diffusers to make it work properly. + from helpers.training.gradient_checkpointing_interval import ( + set_checkpoint_interval, + ) + + set_checkpoint_interval(int(args.gradient_checkpointing_interval)) return unet, transformer diff --git a/helpers/training/gradient_checkpointing_interval.py b/helpers/training/gradient_checkpointing_interval.py new file mode 100644 index 00000000..18026044 --- /dev/null +++ b/helpers/training/gradient_checkpointing_interval.py @@ -0,0 +1,42 @@ +import torch +from torch.utils.checkpoint import checkpoint as original_checkpoint + + +# Global variables to keep track of the checkpointing state +_checkpoint_call_count = 0 +_checkpoint_interval = 4 # You can set this to any interval you prefer + + +def reset_checkpoint_counter(): + """Resets the checkpoint call counter. Call this at the beginning of the forward pass.""" + global _checkpoint_call_count + _checkpoint_call_count = 0 + + +def set_checkpoint_interval(n): + """Sets the interval at which checkpointing is skipped.""" + global _checkpoint_interval + _checkpoint_interval = n + + +def checkpoint_wrapper(function, *args, use_reentrant=True, **kwargs): + """Wrapper function for torch.utils.checkpoint.checkpoint.""" + global _checkpoint_call_count, _checkpoint_interval + _checkpoint_call_count += 1 + + if ( + _checkpoint_interval > 0 + and (_checkpoint_call_count % _checkpoint_interval) == 0 + ): + # Use the original checkpoint function + return original_checkpoint( + function, *args, use_reentrant=use_reentrant, **kwargs + ) + else: + # Skip checkpointing: execute the function directly + # Do not pass 'use_reentrant' to the function + return function(*args, **kwargs) + + +# Monkeypatch torch.utils.checkpoint.checkpoint +torch.utils.checkpoint.checkpoint = checkpoint_wrapper diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index dfe7593d..f6a15696 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -997,8 +997,11 @@ def init_post_load_freeze(self): self.transformer = apply_bitfit_freezing( unwrap_model(self.accelerator, self.transformer), self.config ) + self.enable_gradient_checkpointing() + def enable_gradient_checkpointing(self): if self.config.gradient_checkpointing: + logger.info("Enabling gradient checkpointing.") if self.unet is not None: unwrap_model( self.accelerator, self.unet @@ -1022,6 +1025,32 @@ def init_post_load_freeze(self): self.accelerator, self.text_encoder_2 ).gradient_checkpointing_enable() + def disable_gradient_checkpointing(self): + if self.config.gradient_checkpointing: + logger.info("Disabling gradient checkpointing.") + if self.unet is not None: + unwrap_model( + self.accelerator, self.unet + ).disable_gradient_checkpointing() + if self.transformer is not None and self.config.model_family != "smoldit": + unwrap_model( + self.accelerator, self.transformer + ).disable_gradient_checkpointing() + if self.config.controlnet: + unwrap_model( + self.accelerator, self.controlnet + ).disable_gradient_checkpointing() + if ( + hasattr(self.config, "train_text_encoder") + and self.config.train_text_encoder + ): + unwrap_model( + self.accelerator, self.text_encoder_1 + ).gradient_checkpointing_disable() + unwrap_model( + self.accelerator, self.text_encoder_2 + ).gradient_checkpointing_disable() + def _get_trainable_parameters(self): # Return just a list of the currently trainable parameters. if self.config.model_type == "lora": @@ -2188,8 +2217,10 @@ def train(self): # normal run-of-the-mill validation on startup. if self.validation is not None: self.enable_sageattention_inference() + self.disable_gradient_checkpointing() self.validation.run_validations(validation_type="base_model", step=0) self.disable_sageattention_inference() + self.enable_gradient_checkpointing() self.mark_optimizer_train() @@ -2913,11 +2944,13 @@ def train(self): if self.validation is not None: if self.validation.would_validate(): self.enable_sageattention_inference() + self.disable_gradient_checkpointing() self.validation.run_validations( validation_type="intermediary", step=step ) if self.validation.would_validate(): self.disable_sageattention_inference() + self.enable_gradient_checkpointing() self.mark_optimizer_train() if ( self.config.push_to_hub @@ -2967,6 +3000,7 @@ def train(self): self.mark_optimizer_eval() if self.validation is not None: self.enable_sageattention_inference() + self.disable_gradient_checkpointing() validation_images = self.validation.run_validations( validation_type="final", step=self.state["global_step"], diff --git a/tests/test_trainer.py b/tests/test_trainer.py index c54789d6..3a036762 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -140,6 +140,7 @@ def test_stats_memory_used_none( flux_schedule_shift=3, flux_schedule_auto_shift=False, validation_guidance_skip_layers=None, + gradient_checkpointing_interval=None, ), ) def test_misc_init( From a4e5393f9d5787fdc47010c9bb17345f36b46307 Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 3 Dec 2024 15:55:58 -0600 Subject: [PATCH 23/26] remove print message --- helpers/models/flux/transformer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/helpers/models/flux/transformer.py b/helpers/models/flux/transformer.py index 280ff804..77097648 100644 --- a/helpers/models/flux/transformer.py +++ b/helpers/models/flux/transformer.py @@ -588,10 +588,6 @@ def forward( ) ): - print( - f"checkpointing index {index_block} at interval: {self.gradient_checkpointing_interval}" - ) - def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: From 2ba64ab53e9e979ad4fc78222a38cc2ec1073fc7 Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 3 Dec 2024 16:02:03 -0600 Subject: [PATCH 24/26] configurator should ask about gradient checkpointing interval --- configure.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/configure.py b/configure.py index 7aced716..447ab5fb 100644 --- a/configure.py +++ b/configure.py @@ -543,6 +543,18 @@ def configure_env(): ) ) env_contents["--gradient_checkpointing"] = "true" + gradient_checkpointing_interval = prompt_user( + "Would you like to configure a gradient checkpointing interval? A value larger than 1 will increase VRAM usage but speed up training by skipping checkpoint creation every Nth layer", + 0, + ) + try: + if int(gradient_checkpointing_interval) > 0: + env_contents["--gradient_checkpointing_interval"] = int( + gradient_checkpointing_interval + ) + except: + print("Could not parse gradient checkpointing interval. Not enabling.") + pass env_contents["--caption_dropout_probability"] = float( prompt_user( From 96d477ea3ec2bed49866028377c8387ba6911b63 Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 3 Dec 2024 16:02:31 -0600 Subject: [PATCH 25/26] configurator logic fix --- configure.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configure.py b/configure.py index 447ab5fb..26620662 100644 --- a/configure.py +++ b/configure.py @@ -544,11 +544,11 @@ def configure_env(): ) env_contents["--gradient_checkpointing"] = "true" gradient_checkpointing_interval = prompt_user( - "Would you like to configure a gradient checkpointing interval? A value larger than 1 will increase VRAM usage but speed up training by skipping checkpoint creation every Nth layer", + "Would you like to configure a gradient checkpointing interval? A value larger than 1 will increase VRAM usage but speed up training by skipping checkpoint creation every Nth layer, and a zero will disable this feature.", 0, ) try: - if int(gradient_checkpointing_interval) > 0: + if int(gradient_checkpointing_interval) > 1: env_contents["--gradient_checkpointing_interval"] = int( gradient_checkpointing_interval ) From 8f8e1cf06c1aa5b5bd72bddfc1b9539aac91c952 Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 3 Dec 2024 16:14:41 -0600 Subject: [PATCH 26/26] add gradient checkpointing option to docs --- OPTIONS.md | 9 +++++++++ documentation/quickstart/FLUX.md | 5 +++++ 2 files changed, 14 insertions(+) diff --git a/OPTIONS.md b/OPTIONS.md index d0069deb..79c77603 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -40,6 +40,15 @@ The script `configure.py` in the project root can be used via `python configure. - **What**: Path to the pretrained T5 model or its identifier from https://huggingface.co/models. - **Why**: When training PixArt, you might want to use a specific source for your T5 weights so that you can avoid downloading them multiple times when switching the base model you train from. +### `--gradient_checkpointing` + +- **What**: During training, gradients will be calculated layerwise and accumulated to save on peak VRAM requirements at the cost of slower training. + +### `--gradient_checkpointing_interval` + +- **What**: Checkpoint only every _n_ blocks, where _n_ is a value greater than zero. A value of 1 is effectively the same as just leaving `--gradient_checkpointing` enabled, and a value of 2 will checkpoint every other block. +- **Note**: SDXL and Flux are currently the only models supporting this option. SDXL uses a hackish implementation. + ### `--refiner_training` - **What**: Enables training a custom mixture-of-experts model series. See [Mixture-of-Experts](/documentation/MIXTURE_OF_EXPERTS.md) for more information on these options. diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index 6b2c60b1..ee5b98cc 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -144,6 +144,8 @@ There, you will possibly need to modify the following variables: - This option causes update steps to be accumulated over several steps. This will increase the training runtime linearly, such that a value of 2 will make your training run half as quickly, and take twice as long. - `optimizer` - Beginners are recommended to stick with adamw_bf16, though optimi-lion and optimi-stableadamw are also good choices. - `mixed_precision` - Beginners should keep this in `bf16` +- `gradient_checkpointing` - set this to true in practically every situation on every device +- `gradient_checkpointing_interval` - this could be set to a value of 2 or higher on larger GPUs to only checkpoint every _n_ blocks. A value of 2 would checkpoint half of the blocks, and 3 would be one-third. Multi-GPU users can reference [this document](/OPTIONS.md#environment-configuration-variables) for information on configuring the number of GPUs to use. @@ -415,6 +417,9 @@ Currently, the lowest VRAM utilisation (9090M) can be attained with: - PyTorch: 2.6 Nightly (Sept 29th build) - Using `--quantize_via=cpu` to avoid outOfMemory error during startup on <=16G cards. - With `--attention_mechanism=sageattention` to further reduce VRAM by 0.1GB and improve training validation image generation speed. +- Be sure to enable `--gradient_checkpointing` or nothing you do will stop it from OOMing + +**NOTE**: Pre-caching of VAE embeds and text encoder outputs may use more memory and still OOM. If so, text encoder quantisation and VAE tiling can be enabled. Speed was approximately 1.4 iterations per second on a 4090.