Merge pull request #1170 from bghira/feature/ema-for-lora
(experimental) Allow EMA on LoRA/Lycoris networks
bghira authored Nov 19, 2024
2 parents 9cdf13c + fc68c9e commit a85c6be
Showing 15 changed files with 670 additions and 235 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -51,7 +51,7 @@ For multi-node distributed training, [this guide](/documentation/DISTRIBUTED.md)
- LoRA/LyCORIS training for PixArt, SDXL, SD3, and SD 2.x that uses less than 16G VRAM
- DeepSpeed integration allowing for [training SDXL's full u-net on 12G of VRAM](/documentation/DEEPSPEED.md), albeit very slowly.
- Quantised NF4/INT8/FP8 LoRA training, using a low-precision base model to reduce VRAM consumption.
- Optional EMA (Exponential moving average) weight network to counteract model overfitting and improve training stability. **Note:** This does not apply to LoRA.
- Optional EMA (Exponential moving average) weight network to counteract model overfitting and improve training stability.
- Train directly from an S3-compatible storage provider, eliminating the requirement for expensive local storage. (Tested with Cloudflare R2 and Wasabi S3)
- For only SDXL and SD 1.x/2.x, full [ControlNet model training](/documentation/CONTROLNET.md) (not ControlLoRA or ControlLite)
- Training [Mixture of Experts](/documentation/MIXTURE_OF_EXPERTS.md) for lightweight, high-quality diffusion models
@@ -137,8 +137,8 @@ Flux prefers being trained with multiple large GPUs but a single 16G card should

- A100-80G (EMA, large batches, LoRA @ insane batch sizes)
- A6000-48G (EMA@768px, no EMA@1024px, LoRA @ high batch sizes)
- A100-40G (no EMA@1024px, no EMA@768px, EMA@512px, LoRA @ high batch sizes)
- 4090-24G (no EMA@1024px, batch size 1-4, LoRA @ medium-high batch sizes)
- A100-40G (EMA@1024px, EMA@768px, EMA@512px, LoRA @ high batch sizes)
- 4090-24G (EMA@1024px, batch size 1-4, LoRA @ medium-high batch sizes)
- 4080-12G (LoRA @ low-medium batch sizes)

### Stable Diffusion 2.x, 768px
4 changes: 1 addition & 3 deletions documentation/DEEPSPEED.md
@@ -162,6 +162,4 @@ While EMA is a great way to smooth out gradients and improve generalisation abil

EMA holds a shadow copy of the model parameters in memory, essentially doubling the footprint of the model. For SimpleTuner, EMA is not passed through the Accelerator module, which means it is not impacted by DeepSpeed. As a result, the memory savings we saw with the base U-net are not realised with the EMA model.

That said, any memory savings that applied to the base U-net could possibly allow EMA weights to load and operate effectively.

Future work is planned to allow EMA to run on the CPU only.
However, by default, the EMA model is kept on CPU.
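
For context on the memory trade-off described above, here is a minimal sketch of what an EMA shadow copy kept in system memory looks like. It is illustrative only and is not SimpleTuner's actual EMA implementation; the class and parameter names are made up.

```python
import torch


class CPUShadowEMA:
    """Minimal EMA shadow copy held in system memory (illustrative sketch only)."""

    def __init__(self, parameters, decay: float = 0.999):
        self.decay = decay
        # Detached fp32 copies live on the CPU, so no additional VRAM is consumed.
        self.shadow = [p.detach().to("cpu", dtype=torch.float32, copy=True) for p in parameters]

    @torch.no_grad()
    def step(self, parameters):
        # shadow = decay * shadow + (1 - decay) * current_weights
        for s, p in zip(self.shadow, parameters):
            s.mul_(self.decay).add_(p.detach().to("cpu", dtype=torch.float32), alpha=1 - self.decay)

    @torch.no_grad()
    def copy_to(self, parameters):
        # Swap the averaged weights in, e.g. for validation or export.
        for s, p in zip(self.shadow, parameters):
            p.data.copy_(s.to(device=p.device, dtype=p.dtype))
```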
6 changes: 6 additions & 0 deletions documentation/DREAMBOOTH.md
@@ -222,6 +222,12 @@ Alternatively, one might use the real name of their subject, or a 'similar enoug

After a number of training experiments, a 'similar enough' celebrity seems to be the best choice, especially if prompting the model with the person's real name produces dissimilar results.

# Exponential moving average (EMA)

A second set of EMA weights can be maintained in parallel with your checkpoint at nearly no cost: by default, only system memory is consumed rather than additional VRAM.

Setting `use_ema=true` in your config file enables this feature.

# CLIP score tracking

If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.
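As shown in the DREAMBOOTH.md addition above, EMA is switched on with `use_ema=true`. The sketch below gathers the EMA-related options touched by this commit in one place; the exact config-file syntax depends on how your SimpleTuner configuration is stored, so treat it as an illustration rather than a literal config snippet.

```python
# Illustrative only: the EMA options added or updated in this commit.
# The exact config-file format (JSON, env file, or CLI flags) depends on your setup.
ema_options = {
    "use_ema": True,                 # enable the EMA shadow weights (now also for LoRA/Lycoris)
    "ema_device": "cpu",             # per the DEEPSPEED notes above, the EMA model defaults to CPU
    "ema_validation": "comparison",  # 'none', 'ema_only', or 'comparison' (side-by-side images)
}

# Roughly equivalent command-line flags:
#   --use_ema --ema_device cpu --ema_validation comparison
```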
27 changes: 17 additions & 10 deletions helpers/configuration/cmd_args.py
@@ -1111,7 +1111,7 @@ def get_argument_parser():
parser.add_argument(
"--use_ema",
action="store_true",
help="Whether to use EMA (exponential moving average) model.",
help="Whether to use EMA (exponential moving average) model. Works with LoRA, Lycoris, and full training.",
)
parser.add_argument(
"--ema_device",
@@ -1122,6 +1122,17 @@
" This provides the fastest EMA update times, but is not ultimately necessary for EMA to function."
),
)
parser.add_argument(
"--ema_validation",
choices=["none", "ema_only", "comparison"],
default="comparison",
help=(
"When 'none' is set, no EMA validation will be done."
" When using 'ema_only', the validations will rely mostly on the EMA weights."
" When using 'comparison' (default) mode, the validations will first run on the checkpoint before also running for"
" the EMA weights. In comparison mode, the resulting images will be provided side-by-side."
)
)
parser.add_argument(
"--ema_cpu_only",
action="store_true",
@@ -1338,7 +1349,7 @@ get_argument_parser():
help=(
"Validations must be enabled for model evaluation to function. The default is to use no evaluator,"
" and 'clip' will use a CLIP model to evaluate the resulting model's performance during validations."
)
),
)
parser.add_argument(
"--pretrained_evaluation_model_name_or_path",
Expand All @@ -1348,7 +1359,7 @@ def get_argument_parser():
"Optionally provide a custom model to use for ViT evaluations."
" The default is currently clip-vit-large-patch14-336, allowing for lower patch sizes (greater accuracy)"
" and an input resolution of 336x336."
)
),
)
parser.add_argument(
"--validation_on_startup",
@@ -2351,13 +2362,9 @@ def parse_cmdline_args(input_args=None):
)
args.gradient_precision = "fp32"

if args.use_ema:
if args.model_family == "sd3":
raise ValueError(
"Using EMA is not currently supported for Stable Diffusion 3 training."
)
if "lora" in args.model_type:
raise ValueError("Using EMA is not currently supported for LoRA training.")
# if args.use_ema:
# if "lora" in args.model_type:
# raise ValueError("Using EMA is not currently supported for LoRA training.")
args.logging_dir = os.path.join(args.output_dir, args.logging_dir)
args.accelerator_project_config = ProjectConfiguration(
project_dir=args.output_dir, logging_dir=args.logging_dir
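The new `--ema_validation` choices above determine whether validation images come from the checkpoint, the EMA weights, or both. A minimal sketch of how those three modes could drive a validation pass follows; `validate_base`, `validate_ema` and `make_side_by_side` are placeholder callables, not functions from SimpleTuner.

```python
# Hypothetical sketch of honouring the three --ema_validation modes.
# validate_base(), validate_ema() and make_side_by_side() are placeholders.
def run_validations(args, validate_base, validate_ema, make_side_by_side):
    if not args.use_ema or args.ema_validation == "none":
        return validate_base()   # no EMA images; validate the checkpoint only
    if args.ema_validation == "ema_only":
        return validate_ema()    # rely on the EMA weights for validation images
    # "comparison" (the default): render both sets and pair them side by side
    base_images = validate_base()
    ema_images = validate_ema()
    return [make_side_by_side(b, e) for b, e in zip(base_images, ema_images)]
```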
23 changes: 23 additions & 0 deletions helpers/publishing/huggingface.py
@@ -11,6 +11,7 @@


LORA_SAFETENSORS_FILENAME = "pytorch_lora_weights.safetensors"
EMA_SAFETENSORS_FILENAME = "ema_model.safetensors"


class HubManager:
@@ -110,6 +111,8 @@ def upload_model(self, validation_images, webhook_handler=None, override_path=No
self.upload_full_model(override_path=override_path)
else:
self.upload_lora_model(override_path=override_path)
if self.config.use_ema:
self.upload_ema_model(override_path=override_path)
break
except Exception as e:
if webhook_handler:
@@ -155,6 +158,26 @@ def upload_lora_model(self, override_path=None):
except Exception as e:
logger.error(f"Failed to upload LoRA weights to hub: {e}")

def upload_ema_model(self, override_path=None):
try:
check_ema_paths = ["transformer_ema", "unet_ema", "controlnet_ema", "ema"]
# if any of the folder names are present in the checkpoint dir, we will upload them too
for check_ema_path in check_ema_paths:
print(f"Checking for EMA path: {check_ema_path}")
ema_path = os.path.join(
override_path or self.config.output_dir, check_ema_path
)
if os.path.exists(ema_path):
                print("Found EMA checkpoint!")
upload_folder(
repo_id=self._repo_id,
folder_path=ema_path,
path_in_repo="/ema",
commit_message="LoRA EMA checkpoint auto-generated by SimpleTuner",
)
except Exception as e:
logger.error(f"Failed to upload LoRA EMA weights to hub: {e}")

def find_latest_checkpoint(self):
checkpoints = list(Path(self.config.output_dir).rglob("checkpoint-*"))
highest_checkpoint_value = None
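`upload_ema_model()` above pushes whichever `*_ema` folder it finds into an `ema/` path in the hub repository. A sketch of retrieving those weights afterwards is shown below; it assumes the uploaded folder contains the `ema_model.safetensors` file named by `EMA_SAFETENSORS_FILENAME`, which may not hold for every model type, and the repo id is a placeholder.

```python
# Sketch: fetch the EMA weights that upload_ema_model() pushed under the hub's "ema/" path.
# Assumes the folder contains ema_model.safetensors (see EMA_SAFETENSORS_FILENAME above);
# adjust repo_id and filename to match your own upload.
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

ema_file = hf_hub_download(
    repo_id="your-username/your-simpletuner-model",  # placeholder
    filename="ema_model.safetensors",
    subfolder="ema",
)
ema_state_dict = load_file(ema_file)
print(f"Loaded {len(ema_state_dict)} EMA tensors")
```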
20 changes: 18 additions & 2 deletions helpers/publishing/metadata.py
@@ -105,6 +105,20 @@ def _model_imports(args):
return f"{output}"


def ema_info(args):
if args.use_ema:
ema_information = """
## Exponential Moving Average (EMA)
SimpleTuner generates a safetensors variant of the EMA weights and a pt file.
The safetensors file is intended to be used for inference, and the pt file is for continuing finetuning.
The EMA model may provide a more well-rounded result, but typically will feel undertrained compared to the full model as it is a running decayed average of the model weights.
"""
return ema_information
return ""

def lycoris_download_info():
"""output a function to download the adapter"""
output_fn = """
@@ -145,7 +159,7 @@ def _model_load(args, repo_id: str = None):
output = (
f"model_id = '{args.pretrained_model_name_or_path}'"
f"\nadapter_id = '{repo_id if repo_id is not None else args.output_dir}'"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id), torch_dtype={StateTracker.get_weight_dtype()}) # loading directly in bf16"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype={StateTracker.get_weight_dtype()}) # loading directly in bf16"
f"\npipeline.load_lora_weights(adapter_id)"
)
elif args.lora_type.lower() == "lycoris":
@@ -390,7 +404,7 @@ def ddpm_schedule_info(args):
f"training_scheduler_timestep_spacing={args.training_scheduler_timestep_spacing}"
)
output_args.append(
f"validation_scheduler_timestep_spacing={args.validation_scheduler_timestep_spacing}"
f"inference_scheduler_timestep_spacing={args.inference_scheduler_timestep_spacing}"
)
output_str = (
f" (extra parameters={output_args})"
@@ -552,6 +566,8 @@
## Inference
{code_example(args=StateTracker.get_args(), repo_id=repo_id)}
{ema_info(args=StateTracker.get_args())}
"""

logger.debug(f"YAML:\n{yaml_content}")
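The model-card text added above says the safetensors EMA variant is meant for inference. A hedged sketch of loading it alongside the base model, mirroring the `_model_load()` example in this diff, is below; it assumes the EMA file is a standard-format LoRA stored under `ema/` in the adapter repo (Lycoris and full-model EMA weights would load differently), and both ids are placeholders.

```python
# Sketch: run inference with the EMA LoRA weights instead of the regular adapter file.
# Assumes a standard-format LoRA saved as ema/ema_model.safetensors; ids are placeholders.
import torch
from diffusers import DiffusionPipeline

model_id = "path/or/repo-of-the-base-model"
adapter_id = "your-username/your-simpletuner-lora"
pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipeline.load_lora_weights(adapter_id, weight_name="ema_model.safetensors", subfolder="ema")
```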
2 changes: 1 addition & 1 deletion helpers/training/default_settings/safety_check.py
@@ -116,4 +116,4 @@ def safety_check(args, accelerator):
logger.error(
f"--flux_schedule_auto_shift cannot be combined with --flux_schedule_shift. Please set --flux_schedule_shift to 0 if you want to train with --flux_schedule_auto_shift."
)
sys.exit(1)
sys.exit(1)