Skip to content

Commit

Permalink
Merge pull request #1010 from bghira/main
Browse files Browse the repository at this point in the history
1.1 merge window
  • Loading branch information
bghira authored Oct 1, 2024
2 parents c721c6a + eafb09e commit 696760e
Show file tree
Hide file tree
Showing 33 changed files with 5,637 additions and 350 deletions.
16 changes: 11 additions & 5 deletions .github/workflows/python-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,23 @@ jobs:
runs-on: ubuntu-latest

steps:
- name: Maximize build space
uses: AdityaGarg8/remove-unwanted-software@v4.1
with:
remove-android: 'true'

- uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.8
python-version: 3.11

- name: Install Poetry
run: python -m pip install --upgrade pip poetry

- name: Install Dependencies
run: |
python -m pip install --upgrade pip poetry
poetry install
run: poetry -C install/apple install

- name: Run Tests
run: poetry run python -m unittest discover tests/
run: poetry -C install/apple run python -m unittest discover tests/
203 changes: 155 additions & 48 deletions OPTIONS.md

Large diffs are not rendered by default.

15 changes: 10 additions & 5 deletions configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
from helpers.training.optimizer_param import optimizer_choices

bf16_only_optims = [
key for key, value in optimizer_choices.items() if value["precision"] == "bf16"
key
for key, value in optimizer_choices.items()
if value.get("precision", "any") == "bf16"
]
any_precision_optims = [
key for key, value in optimizer_choices.items() if value["precision"] == "any"
key
for key, value in optimizer_choices.items()
if value.get("precision", "any") == "any"
]
model_classes = {
"full": [
Expand All @@ -17,10 +21,10 @@
"pixart_sigma",
"kolors",
"sd3",
"stable_diffusion_legacy",
"legacy",
],
"lora": ["flux", "sdxl", "kolors", "sd3", "stable_diffusion_legacy"],
"controlnet": ["sdxl", "stable_diffusion_legacy"],
"lora": ["flux", "sdxl", "kolors", "sd3", "legacy"],
"controlnet": ["sdxl", "legacy"],
}

default_models = {
Expand All @@ -30,6 +34,7 @@
"kolors": "kwai-kolors/kolors-diffusers",
"terminus": "ptx0/terminus-xl-velocity-v2",
"sd3": "stabilityai/stable-diffusion-3-medium-diffusers",
"legacy": "stabilityai/stable-diffusion-2-1-base",
}

default_cfg = {
Expand Down
63 changes: 33 additions & 30 deletions documentation/quickstart/FLUX.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,26 @@

![image](https://github.com/user-attachments/assets/6409d790-3bb4-457c-a4b4-a51a45fc91d1)

In this example, we'll be training a Flux.1 LoRA model using the SimpleTuner toolkit.
In this example, we'll be training a Flux.1 LoRA.

### Hardware requirements

Flux requires a lot of **system RAM** in addition to GPU memory. Simply quantising the model at startup requires about 50GB of system memory. If it takes an excessively long time, you may need to assess your hardware's capabilities and whether any changes are needed.

When you're training every component of a rank-16 LoRA (MLP, projections, multimodal blocks), it ends up using:
- a bit more than 32G VRAM when not quantising the base model
- a bit more than 20G VRAM when quantising to int8 + bf16 base/LoRA weights
- a bit more than 13G VRAM when quantising to int2 + bf16 base/LoRA weights
- a bit more than 30G VRAM when not quantising the base model
- a bit more than 18G VRAM when quantising to int8 + bf16 base/LoRA weights
- a bit more than 13G VRAM when quantising to int4 + bf16 base/LoRA weights
- a bit more than 9G VRAM when quantising to int2 + bf16 base/LoRA weights

To have reliable results, you'll need:
- **at minimum** a single 3090 or V100 GPU
- **ideally** multiple A6000s
You'll need:
- **the absolute minimum** is a single 4060 Ti 16GB
- **a realistic minimum** is a single 3090 or V100 GPU
- **ideally** multiple 4090, A6000, L40S, or better

Luckily, these are readily available through providers such as [LambdaLabs](https://lambdalabs.com) which provides the lowest available rates, and localised clusters for multi-node training.

**Unlike other models, AMD and Apple GPUs do not work for training Flux.**
**Unlike other models, Apple GPUs do not currently work for training Flux.**

### Prerequisites

Expand Down Expand Up @@ -123,7 +125,7 @@ There, you will possibly need to modify the following variables:
- `model_type` - Set this to `lora`.
- `model_family` - Set this to `flux`.
- `pretrained_model_name_or_path` - Set this to `black-forest-labs/FLUX.1-dev`.
- Note that you will *probably* need to log in to Huggingface and be granted access to download this model. We will go over logging in to Huggingface later in this tutorial.
- Note that you will need to log in to Huggingface and be granted access to download this model. We will go over logging in to Huggingface later in this tutorial.
- `output_dir` - Set this to the directory where you want to store your checkpoints and validation images. It's recommended to use a full path here.
- `train_batch_size` - this should be kept at 1, especially if you have a very small dataset.
- `validation_resolution` - As Flux is a 1024px model, you can set this to `1024x1024`.
Expand All @@ -139,6 +141,8 @@ There, you will possibly need to modify the following variables:
- `optimizer` - Beginners are recommended to stick with adamw_bf16, though optimi-lion and optimi-stableadamw are also good choices.
- `mixed_precision` - Beginners should keep this in `bf16`

Multi-GPU users can reference [this document](/OPTIONS.md#environment-configuration-variables) for information on configuring the number of GPUs to use.

#### Validation prompts

Inside `config/config.json` is the "primary validation prompt", which is typically the main instance_prompt you are training on for your single subject or style. Additionally, a JSON file may be created that contains extra prompts to run through during validations.
Expand Down Expand Up @@ -185,7 +189,7 @@ A set of diverse prompt will help determine whether the model is collapsing as i
#### Quantised model training

Tested on Apple and NVIDIA systems, Hugging Face Optimum-Quanto can be used to reduce the precision and VRAM requirements, training Flux on just 20GB.
Tested on Apple and NVIDIA systems, Hugging Face Optimum-Quanto can be used to reduce the precision and VRAM requirements, training Flux on just 16GB.

Inside your SimpleTuner venv:

Expand All @@ -203,9 +207,7 @@ For `config.json` users:
"base_model_default_dtype": "bf16"
```

#################################################
# Below guidance is for LoRA, not LyCORIS. #
#################################################
##### LoRA-specific settings (not LyCORIS)


```bash
Expand All @@ -215,6 +217,7 @@ For `config.json` users:
# - This mode has been reported to lack portability, and platforms such as ComfyUI might not be able to load the LoRA.
# The option to train only the 'context' blocks is offered as well, but its impact is unknown, and is offered as an experimental choice.
# - An extension to this mode, 'context+ffs' is also available, which is useful for pretraining new tokens into a LoRA before continuing finetuning it via `--init_lora`.
# Other options include 'tiny' and 'nano' which train just 1 or 2 layers.
"--flux_lora_target": "all",

# If you want to use LoftQ initialisation, you can't use Quanto to quantise the base model.
Expand Down Expand Up @@ -254,7 +257,8 @@ Create a `--data_backend_config` (`config/multidatabackend.json`) document conta
"disabled": false,
"skip_file_discovery": "",
"caption_strategy": "filename",
"metadata_backend": "discovery"
"metadata_backend": "discovery",
"repeats": 0
},
{
"id": "dreambooth-subject",
Expand All @@ -269,7 +273,8 @@ Create a `--data_backend_config` (`config/multidatabackend.json`) document conta
"instance_data_dir": "datasets/dreambooth-subject",
"caption_strategy": "instanceprompt",
"instance_prompt": "the name of your subject goes here",
"metadata_backend": "discovery"
"metadata_backend": "discovery",
"repeats": 1000
},
{
"id": "dreambooth-subject-512",
Expand All @@ -284,7 +289,8 @@ Create a `--data_backend_config` (`config/multidatabackend.json`) document conta
"instance_data_dir": "datasets/dreambooth-subject",
"caption_strategy": "instanceprompt",
"instance_prompt": "the name of your subject goes here",
"metadata_backend": "discovery"
"metadata_backend": "discovery",
"repeats": 1000
},
{
"id": "text-embeds",
Expand Down Expand Up @@ -390,35 +396,32 @@ We can partially reintroduce distillation to a de-distilled model by continuing
- Inference workflows for ComfyUI or other applications (eg. AUTOMATIC1111) will need to be modified to also enable "true" CFG, which might not be currently possible out of the box.

### Quantisation
- Minimum 8bit quantisation is required for a 24G card to train this model - but 32G (V100) cards suffer a more tragic fate.
- Without quantising the model, a rank-1 LoRA sits at just over 32GB of mem use, in a way that prevents a 32G V100 from actually working
- Using the optimi-lion optimiser may reduce training just enough to make the V100 work.
- Quantising the model doesn't harm training
- Minimum 8bit quantisation is required for a 16G card to train this model
- In bfloat16/float16, a rank-1 LoRA sits at just over 30GB of mem use
- Quantising the model to 8bit doesn't harm training
- It allows you to push higher batch sizes and possibly obtain a better result
- Behaves the same as full-precision training - fp32 won't make your model any better than bf16+int8.
- As usual, **fp8 quantisation runs more slowly** than **int8** and might have a worse result due to the use of `e4m3fn` in Quanto
- fp16 training similarly is bad for Flux; this model wants the range of bf16
- `e5m2` level precision is better at fp8 but haven't looked into how to enable it yet. Sorry, H100 owners. We weep for you.
- **int8** has hardware acceleration and `torch.compile()` support on newer NVIDIA hardware (3090 or better)
- **nf4** does not seem to benefit training as much as it benefits inference
- When loading the LoRA in ComfyUI later, you **must** use the same base model precision as you trained your LoRA on.
- `int4` is weird and really only works on A100 and H100 cards due to a reliance on custom bf16 kernels
- **int4** is weird and really only works on A100 and H100 cards due to a reliance on custom bf16 kernels

### Crashing
- If you get SIGKILL after the text encoders are unloaded, this means you do not have enough system memory to quantise Flux.
- Try loading the `--base_model_precision=bf16` but if that does not work, you might just need more memory..
- Try `--quantize_via=accelerator` to use the GPU instead

### Schnell
- Direct Schnell training really needs a bit more time in the oven - currently, the results do not look good
- If you absolutely must train Schnell, try the x-flux trainer from X-Labs
- Ostris' ai-toolkit uses a low-rank adapter probably pulled from OpenFLUX.1 as a source of CFG that can be inverted from the final result - this will probably be implemented here eventually after results are more widely available and tests have completed
- Training a LoRA on Dev will however, run just fine on Schnell
- Dev+Schnell merge 50/50 just fine, and the LoRAs can possibly be trained from that, which will then run on Schnell **or** Dev
- If you train a LyCORIS LoKr on Dev, it **generally** works very well on Schnell at just 4 steps later.
- Direct Schnell training really needs a bit more time in the oven - currently, the results do not look good

> ℹ️ When merging Schnell with Dev in any way, the license of Dev takes over and it becomes non-commercial. This shouldn't really matter for most users, but it's worth noting.
### Learning rates

#### LoRA (--lora_type=standard)
- It's been reported that Flux trains similarly to SD 1.5 LoRAs
- LoRA has overall worse performance than LoKr for larger datasets
- It's been reported that Flux LoRA trains similarly to SD 1.5 LoRAs
- However, a model as large as 12B has empirically performed better with **lower learning rates.**
- LoRA at 1e-3 might totally roast the thing. LoRA at 1e-5 does nearly nothing.
- Ranks as large as 64 through 128 might be undesirable on a 12B model due to general difficulties that scale up with the size of the base model.
Expand Down
4 changes: 2 additions & 2 deletions helpers/caching/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ def reclaim_memory():
import torch

if torch.cuda.is_available():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

if torch.backends.mps.is_available():
torch.mps.empty_cache()
torch.mps.synchronize()

gc.collect()
gc.collect()
74 changes: 61 additions & 13 deletions helpers/caching/text_embeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,11 +767,23 @@ def compute_embeddings_for_sdxl_prompts(
self.debug_log(f"Adding embed to write queue: {filename}")
self.save_to_cache(filename, (prompt_embeds, add_text_embeds))

if self.webhook_handler is not None and int(self.write_thread_bar.n % self.webhook_progress_interval) < 10:
last_reported_index = int(self.write_thread_bar.n % self.webhook_progress_interval)
if (
self.webhook_handler is not None
and int(
self.write_thread_bar.n % self.webhook_progress_interval
)
< 10
):
last_reported_index = int(
self.write_thread_bar.n % self.webhook_progress_interval
)
self.send_progress_update(
type="init_cache_text_embeds_status_update",
progress=int(self.write_thread_bar.n // len(local_caption_split) * 100),
progress=int(
self.write_thread_bar.n
// len(local_caption_split)
* 100
),
total=len(local_caption_split),
current=0,
)
Expand Down Expand Up @@ -947,11 +959,23 @@ def compute_embeddings_for_legacy_prompts(

self.save_to_cache(filename, prompt_embeds)

if self.webhook_handler is not None and int(self.write_thread_bar.n % self.webhook_progress_interval) < 10:
last_reported_index = int(self.write_thread_bar.n % self.webhook_progress_interval)
if (
self.webhook_handler is not None
and int(
self.write_thread_bar.n % self.webhook_progress_interval
)
< 10
):
last_reported_index = int(
self.write_thread_bar.n % self.webhook_progress_interval
)
self.send_progress_update(
type="init_cache_text_embeds_status_update",
progress=int(self.write_thread_bar.n // len(local_caption_split) * 100),
progress=int(
self.write_thread_bar.n
// len(local_caption_split)
* 100
),
total=len(local_caption_split),
current=0,
)
Expand Down Expand Up @@ -1118,15 +1142,27 @@ def compute_embeddings_for_flux_prompts(
self.save_to_cache(
filename, (prompt_embeds, add_text_embeds, time_ids, masks)
)
if self.webhook_handler is not None and int(self.write_thread_bar.n % self.webhook_progress_interval) < 10:
last_reported_index = int(self.write_thread_bar.n % self.webhook_progress_interval)
if (
self.webhook_handler is not None
and int(
self.write_thread_bar.n % self.webhook_progress_interval
)
< 10
):
last_reported_index = int(
self.write_thread_bar.n % self.webhook_progress_interval
)
self.send_progress_update(
type="init_cache_text_embeds_status_update",
progress=int(self.write_thread_bar.n // len(local_caption_split) * 100),
progress=int(
self.write_thread_bar.n
// len(local_caption_split)
* 100
),
total=len(local_caption_split),
current=0,
)

if return_concat:
prompt_embeds = prompt_embeds.to(self.accelerator.device)
add_text_embeds = add_text_embeds.to(self.accelerator.device)
Expand Down Expand Up @@ -1292,11 +1328,23 @@ def compute_embeddings_for_sd3_prompts(
self.debug_log(f"Adding embed to write queue: {filename}")
self.save_to_cache(filename, (prompt_embeds, add_text_embeds))

if self.webhook_handler is not None and int(self.write_thread_bar.n % self.webhook_progress_interval) < 10:
last_reported_index = int(self.write_thread_bar.n % self.webhook_progress_interval)
if (
self.webhook_handler is not None
and int(
self.write_thread_bar.n % self.webhook_progress_interval
)
< 10
):
last_reported_index = int(
self.write_thread_bar.n % self.webhook_progress_interval
)
self.send_progress_update(
type="init_cache_text_embeds_status_update",
progress=int(self.write_thread_bar.n // len(local_caption_split) * 100),
progress=int(
self.write_thread_bar.n
// len(local_caption_split)
* 100
),
total=len(local_caption_split),
current=0,
)
Expand Down
Loading

0 comments on commit 696760e

Please sign in to comment.