Skip to content

Commit

Permalink
Merge pull request #1186 from bghira/main
Browse files Browse the repository at this point in the history
merge
  • Loading branch information
bghira authored Dec 3, 2024
2 parents 2f8fc6e + aa42e05 commit 07d9ea7
Show file tree
Hide file tree
Showing 24 changed files with 690 additions and 67 deletions.
72 changes: 68 additions & 4 deletions OPTIONS.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ The script `configure.py` in the project root can be used via `python configure.
- **What**: Path to the pretrained T5 model or its identifier from https://huggingface.co/models.
- **Why**: When training PixArt, you might want to use a specific source for your T5 weights so that you can avoid downloading them multiple times when switching the base model you train from.

### `--gradient_checkpointing`

- **What**: During training, gradients will be calculated layerwise and accumulated to save on peak VRAM requirements at the cost of slower training.

### `--gradient_checkpointing_interval`

- **What**: Checkpoint only every _n_ blocks, where _n_ is a value greater than zero. A value of 1 is effectively the same as just leaving `--gradient_checkpointing` enabled, and a value of 2 will checkpoint every other block.
- **Note**: SDXL and Flux are currently the only models supporting this option. SDXL uses a hackish implementation.

### `--refiner_training`

- **What**: Enables training a custom mixture-of-experts model series. See [Mixture-of-Experts](/documentation/MIXTURE_OF_EXPERTS.md) for more information on these options.
Expand Down Expand Up @@ -109,6 +118,18 @@ Carefully answer the questions and use bf16 mixed precision training when prompt

Note that the first several steps of training will be slower than usual because of compilation occuring in the background.

### `--attention_mechanism`

Alternative attention mechanisms are supported, with varying levels of compatibility or other trade-offs;

- `diffusers` uses the native Pytorch SDPA functions and is the default attention mechanism
- `xformers` allows the use of Meta's [xformers](https://github.com/facebook/xformers) attention implementation which supports both training and inference fully
- `sageattention` is an inference-focused attention mechanism which does not fully support being used for training ([SageAttention](https://github.com/thu-ml/SageAttention) project page)
- In simplest terms, SageAttention reduces compute requirement for inference

Using `--sageattention_usage` to enable training with SageAttention should be enabled with care, as it does not track or propagate gradients from its custom CUDA implementations for the QKV linears.
- This results in these layers being completely untrained, which might cause model collapse or, slight improvements in short training runs.

---

## 📰 Publishing
Expand Down Expand Up @@ -452,7 +473,8 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
[--lr_scheduler {linear,sine,cosine,cosine_with_restarts,polynomial,constant,constant_with_warmup}]
[--lr_warmup_steps LR_WARMUP_STEPS]
[--lr_num_cycles LR_NUM_CYCLES] [--lr_power LR_POWER]
[--use_ema] [--ema_device {cpu,accelerator}] [--ema_cpu_only]
[--use_ema] [--ema_device {cpu,accelerator}]
[--ema_validation {none,ema_only,comparison}] [--ema_cpu_only]
[--ema_foreach_disable]
[--ema_update_interval EMA_UPDATE_INTERVAL]
[--ema_decay EMA_DECAY] [--non_ema_revision NON_EMA_REVISION]
Expand All @@ -473,8 +495,9 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
[--model_card_safe_for_work] [--logging_dir LOGGING_DIR]
[--benchmark_base_model] [--disable_benchmark]
[--evaluation_type {clip,none}]
[--pretrained_evaluation_model_name_or_path pretrained_evaluation_model_name_or_path]
[--pretrained_evaluation_model_name_or_path PRETRAINED_EVALUATION_MODEL_NAME_OR_PATH]
[--validation_on_startup] [--validation_seed_source {gpu,cpu}]
[--validation_lycoris_strength VALIDATION_LYCORIS_STRENGTH]
[--validation_torch_compile]
[--validation_torch_compile_mode {max-autotune,reduce-overhead,default}]
[--validation_guidance_skip_layers VALIDATION_GUIDANCE_SKIP_LAYERS]
Expand Down Expand Up @@ -509,6 +532,8 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
[--text_encoder_2_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}]
[--text_encoder_3_precision {no_change,int8-quanto,int4-quanto,int2-quanto,int8-torchao,nf4-bnb,fp8-quanto,fp8uz-quanto}]
[--local_rank LOCAL_RANK]
[--attention_mechanism {diffusers,xformers,sageattention,sageattention-int8-fp16-triton,sageattention-int8-fp16-cuda,sageattention-int8-fp8-cuda}]
[--sageattention_usage {training,inference,training+inference}]
[--enable_xformers_memory_efficient_attention]
[--set_grads_to_none] [--noise_offset NOISE_OFFSET]
[--noise_offset_probability NOISE_OFFSET_PROBABILITY]
Expand Down Expand Up @@ -1137,12 +1162,21 @@ options:
cosine_with_restarts scheduler.
--lr_power LR_POWER Power factor of the polynomial scheduler.
--use_ema Whether to use EMA (exponential moving average) model.
Works with LoRA, Lycoris, and full training.
--ema_device {cpu,accelerator}
The device to use for the EMA model. If set to
'accelerator', the EMA model will be placed on the
accelerator. This provides the fastest EMA update
times, but is not ultimately necessary for EMA to
function.
--ema_validation {none,ema_only,comparison}
When 'none' is set, no EMA validation will be done.
When using 'ema_only', the validations will rely
mostly on the EMA weights. When using 'comparison'
(default) mode, the validations will first run on the
checkpoint before also running for the EMA weights. In
comparison mode, the resulting images will be provided
side-by-side.
--ema_cpu_only When using EMA, the shadow model is moved to the
accelerator before we update its parameters. When
provided, this option will disable the moving of the
Expand Down Expand Up @@ -1248,7 +1282,7 @@ options:
function. The default is to use no evaluator, and
'clip' will use a CLIP model to evaluate the resulting
model's performance during validations.
--pretrained_evaluation_model_name_or_path pretrained_evaluation_model_name_or_path
--pretrained_evaluation_model_name_or_path PRETRAINED_EVALUATION_MODEL_NAME_OR_PATH
Optionally provide a custom model to use for ViT
evaluations. The default is currently clip-vit-large-
patch14-336, allowing for lower patch sizes (greater
Expand All @@ -1264,6 +1298,12 @@ options:
validation errors. If so, please set
SIMPLETUNER_LOG_LEVEL=DEBUG and submit debug.log to a
new Github issue report.
--validation_lycoris_strength VALIDATION_LYCORIS_STRENGTH
When inferencing for validations, the Lycoris model
will by default be run at its training strength, 1.0.
However, this value can be increased to a value of
around 1.3 or 1.5 to get a stronger effect from the
model.
--validation_torch_compile
Supply `--validation_torch_compile=true` to enable the
use of torch.compile() on the validation pipeline. For
Expand Down Expand Up @@ -1453,8 +1493,32 @@ options:
quantisation (Apple Silicon, NVIDIA, AMD).
--local_rank LOCAL_RANK
For distributed training: local_rank
--attention_mechanism {diffusers,xformers,sageattention,sageattention-int8-fp16-triton,sageattention-int8-fp16-cuda,sageattention-int8-fp8-cuda}
On NVIDIA CUDA devices, alternative flash attention
implementations are offered, with the default being
native pytorch SDPA. SageAttention has multiple
backends to select from. The recommended value,
'sageattention', guesses what would be the 'best'
option for SageAttention on your hardware (usually
this is the int8-fp16-cuda backend). However, manually
setting this value to int8-fp16-triton may provide
better averages for per-step training and inference
performance while the cuda backend may provide the
highest maximum speed (with also a lower minimum
speed). NOTE: SageAttention training quality has not
been validated.
--sageattention_usage {training,inference,training+inference}
SageAttention breaks gradient tracking through the
backward pass, leading to untrained QKV layers. This
can result in substantial problems for training, so it
is recommended to use SageAttention only for inference
(default behaviour). If you are confident in your
training setup or do not wish to train QKV layers, you
may use 'training' to enable SageAttention for
training.
--enable_xformers_memory_efficient_attention
Whether or not to use xformers.
Whether or not to use xformers. Deprecated and slated
for future removal. Use --attention_mechanism.
--set_grads_to_none Save more memory by using setting grads to None
instead of zero. Be aware, that this changes certain
behaviors, so disable this argument if it causes any
Expand Down
51 changes: 46 additions & 5 deletions configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,36 @@ def configure_env():
).lower()
== "y"
)
report_to_str = ""

env_contents["--attention_mechanism"] = "diffusers"
use_sageattention = (
prompt_user(
"Would you like to use SageAttention for image validation generation? (y/[n])",
"n",
).lower()
== "y"
)
if use_sageattention:
env_contents["--attention_mechanism"] = "sageattention"
env_contents["--sageattention_usage"] = "inference"
use_sageattention_training = (
prompt_user(
(
"Would you like to use SageAttention to cover the forward and backward pass during training?"
" This has the undesirable consequence of leaving the attention layers untrained,"
" as SageAttention lacks the capability to fully track gradients through quantisation."
" If you are not training the attention layers for some reason, this may not matter and"
" you can safely enable this. For all other use-cases, reconsideration and caution are warranted."
),
"n",
).lower()
== "y"
)
if use_sageattention_training:
env_contents["--sageattention_usage"] = "both"

# properly disable wandb/tensorboard/comet_ml etc by default
report_to_str = "none"
if report_to_wandb or report_to_tensorboard:
tracker_project_name = prompt_user(
"Enter the name of your Weights & Biases project", f"{model_type}-training"
Expand All @@ -440,17 +469,17 @@ def configure_env():
f"simpletuner-{model_type}",
)
env_contents["--tracker_run_name"] = tracker_run_name
report_to_str = None
if report_to_wandb:
report_to_str = "wandb"
if report_to_tensorboard:
if report_to_wandb:
if report_to_str != "none":
# report to both WandB and Tensorboard if the user wanted.
report_to_str += ","
else:
# remove 'none' from the option
report_to_str = ""
report_to_str += "tensorboard"
if report_to_str:
env_contents["--report_to"] = report_to_str
env_contents["--report_to"] = report_to_str

print_config(env_contents, extra_args)

Expand Down Expand Up @@ -514,6 +543,18 @@ def configure_env():
)
)
env_contents["--gradient_checkpointing"] = "true"
gradient_checkpointing_interval = prompt_user(
"Would you like to configure a gradient checkpointing interval? A value larger than 1 will increase VRAM usage but speed up training by skipping checkpoint creation every Nth layer, and a zero will disable this feature.",
0,
)
try:
if int(gradient_checkpointing_interval) > 1:
env_contents["--gradient_checkpointing_interval"] = int(
gradient_checkpointing_interval
)
except:
print("Could not parse gradient checkpointing interval. Not enabling.")
pass

env_contents["--caption_dropout_probability"] = float(
prompt_user(
Expand Down
8 changes: 8 additions & 0 deletions documentation/LYCORIS.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ Mandatory fields:

For more information on LyCORIS, please refer to the [documentation in the library](https://github.com/KohakuBlueleaf/LyCORIS/tree/main/docs).

## Potential problems

When using Lycoris on SDXL, it's noted that training the FeedForward modules may break the model and send loss into `NaN` (Not-a-Number) territory.

This seems to be potentially exacerbated when using SageAttention (with `--sageattention_usage=training`), making it all but guaranteed that the model will immediately fail.

The solution is to remove the `FeedForward` modules from the lycoris config and train only the `Attention` blocks.

## LyCORIS Inference Example

Here is a simple FLUX.1-dev inference script showing how to wrap your unet or transformer with create_lycoris_from_weights and then use it for inference.
Expand Down
13 changes: 13 additions & 0 deletions documentation/quickstart/FLUX.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ There, you will possibly need to modify the following variables:
- This option causes update steps to be accumulated over several steps. This will increase the training runtime linearly, such that a value of 2 will make your training run half as quickly, and take twice as long.
- `optimizer` - Beginners are recommended to stick with adamw_bf16, though optimi-lion and optimi-stableadamw are also good choices.
- `mixed_precision` - Beginners should keep this in `bf16`
- `gradient_checkpointing` - set this to true in practically every situation on every device
- `gradient_checkpointing_interval` - this could be set to a value of 2 or higher on larger GPUs to only checkpoint every _n_ blocks. A value of 2 would checkpoint half of the blocks, and 3 would be one-third.

Multi-GPU users can reference [this document](/OPTIONS.md#environment-configuration-variables) for information on configuring the number of GPUs to use.

Expand Down Expand Up @@ -414,9 +416,19 @@ Currently, the lowest VRAM utilisation (9090M) can be attained with:
- DeepSpeed: disabled / unconfigured
- PyTorch: 2.6 Nightly (Sept 29th build)
- Using `--quantize_via=cpu` to avoid outOfMemory error during startup on <=16G cards.
- With `--attention_mechanism=sageattention` to further reduce VRAM by 0.1GB and improve training validation image generation speed.
- Be sure to enable `--gradient_checkpointing` or nothing you do will stop it from OOMing

**NOTE**: Pre-caching of VAE embeds and text encoder outputs may use more memory and still OOM. If so, text encoder quantisation and VAE tiling can be enabled.

Speed was approximately 1.4 iterations per second on a 4090.

### SageAttention

When using `--attention_mechanism=sageattention`, inference can be sped-up at validation time.

**Note**: This isn't compatible with _every_ model configuration, but it's worth trying.

### NF4-quantised training

In simplest terms, NF4 is a 4bit-_ish_ representation of the model, which means training has serious stability concerns to address.
Expand All @@ -428,6 +440,7 @@ In early tests, the following holds true:
- NF4, AdamW8bit, and a higher batch size all help to overcome the stability issues, at the cost of more time spent training or VRAM used
- Upping the resolution from 512px to 1024px slows training down from, for example, 1.4 seconds per step to 3.5 seconds per step (batch size of 1, 4090)
- Anything that's difficult to train on int8 or bf16 becomes harder in NF4
- It's less compatible with options like SageAttention

NF4 does not work with torch.compile, so whatever you get for speed is what you get.

Expand Down
6 changes: 6 additions & 0 deletions documentation/quickstart/SD3.md
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,12 @@ These options have been known to keep SD3.5 in-tact for as long as possible:
- DeepSpeed: disabled / unconfigured
- PyTorch: 2.5

### SageAttention

When using `--attention_mechanism=sageattention`, inference can be sped-up at validation time.

**Note**: This isn't compatible with _every_ model configuration, but it's worth trying.

### Masked loss

If you are training a subject or style and would like to mask one or the other, see the [masked loss training](/documentation/DREAMBOOTH.md#masked-loss) section of the Dreambooth guide.
Expand Down
6 changes: 6 additions & 0 deletions documentation/quickstart/SIGMA.md
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,9 @@ For more information, see the [dataloader](/documentation/DATALOADER.md) and [tu
### CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.

### SageAttention

When using `--attention_mechanism=sageattention`, inference can be sped-up at validation time.

**Note**: This isn't compatible with _every_ model configuration, but it's worth trying.
Loading

0 comments on commit 07d9ea7

Please sign in to comment.