From 0fba5f6d6931558a5c9768de5bd9182eb1dcb0ad Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 17 Nov 2024 12:39:56 -0600 Subject: [PATCH 01/13] (experimental) Allow EMA on LoRA/Lycoris networks --- helpers/configuration/cmd_args.py | 14 +- helpers/training/ema.py | 189 +++++++++++++--------- helpers/training/quantisation/__init__.py | 54 ++++++- helpers/training/save_hooks.py | 47 +++--- helpers/training/trainer.py | 81 ++++++---- tests/test_ema.py | 106 ++++++++++++ train.py | 3 + 7 files changed, 354 insertions(+), 140 deletions(-) create mode 100644 tests/test_ema.py diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index ebeb4235..ab41b5a0 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1338,7 +1338,7 @@ def get_argument_parser(): help=( "Validations must be enabled for model evaluation to function. The default is to use no evaluator," " and 'clip' will use a CLIP model to evaluate the resulting model's performance during validations." - ) + ), ) parser.add_argument( "--pretrained_evaluation_model_name_or_path", @@ -1348,7 +1348,7 @@ def get_argument_parser(): "Optionally provide a custom model to use for ViT evaluations." " The default is currently clip-vit-large-patch14-336, allowing for lower patch sizes (greater accuracy)" " and an input resolution of 336x336." - ) + ), ) parser.add_argument( "--validation_on_startup", @@ -2351,13 +2351,9 @@ def parse_cmdline_args(input_args=None): ) args.gradient_precision = "fp32" - if args.use_ema: - if args.model_family == "sd3": - raise ValueError( - "Using EMA is not currently supported for Stable Diffusion 3 training." - ) - if "lora" in args.model_type: - raise ValueError("Using EMA is not currently supported for LoRA training.") + # if args.use_ema: + # if "lora" in args.model_type: + # raise ValueError("Using EMA is not currently supported for LoRA training.") args.logging_dir = os.path.join(args.output_dir, args.logging_dir) args.accelerator_project_config = ProjectConfiguration( project_dir=args.output_dir, logging_dir=args.logging_dir diff --git a/helpers/training/ema.py b/helpers/training/ema.py index 9fdb5f2d..de04c148 100644 --- a/helpers/training/ema.py +++ b/helpers/training/ema.py @@ -7,9 +7,10 @@ from typing import Any, Dict, Iterable, Optional, Union from diffusers.utils.deprecation_utils import deprecate from diffusers.utils import is_transformers_available +from helpers.training.state_tracker import StateTracker logger = logging.getLogger("EMAModel") -logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) +logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "WARNING")) def should_update_ema(args, step): @@ -119,6 +120,62 @@ def __init__( self.model_config = model_config self.args = args self.accelerator = accelerator + self.training = True # To emulate nn.Module's training mode + + def save_state_dict(self, path: str) -> None: + """ + Save the EMA model's state directly to a file. + + Args: + path (str): The file path where the EMA state will be saved. + """ + # if the folder containing the path does not exist, create it + os.makedirs(os.path.dirname(path), exist_ok=True) + # grab state dict + state_dict = self.state_dict() + # save it using torch.save + torch.save(state_dict, path) + logger.info(f"EMA model state saved to {path}") + + def load_state_dict(self, path: str) -> None: + """ + Load the EMA model's state from a file and apply it to this instance. + + Args: + path (str): The file path from where the EMA state will be loaded. + """ + state_dict = torch.load(path, map_location="cpu", weights_only=True) + + # Load metadata + self.decay = state_dict.get("decay", self.decay) + self.min_decay = state_dict.get("min_decay", self.min_decay) + self.optimization_step = state_dict.get( + "optimization_step", self.optimization_step + ) + self.update_after_step = state_dict.get( + "update_after_step", self.update_after_step + ) + self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) + self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) + self.power = state_dict.get("power", self.power) + + # Load shadow parameters + shadow_params = [] + idx = 0 + while f"shadow_params.{idx}" in state_dict: + shadow_params.append(state_dict[f"shadow_params.{idx}"]) + idx += 1 + + if len(shadow_params) != len(self.shadow_params): + raise ValueError( + f"Mismatch in number of shadow parameters: expected {len(self.shadow_params)}, " + f"but found {len(shadow_params)} in the state dict." + ) + + for current_param, loaded_param in zip(self.shadow_params, shadow_params): + current_param.data.copy_(loaded_param.data) + + logger.info(f"EMA model state loaded from {path}") @classmethod def from_pretrained(cls, path, model_cls) -> "EMAModel": @@ -176,7 +233,6 @@ def get_decay(self, optimization_step: int = None) -> float: @torch.no_grad() def step(self, parameters: Iterable[torch.nn.Parameter], global_step: int = None): if not should_update_ema(self.args, global_step): - return if self.args.ema_device == "cpu" and not self.args.ema_cpu_only: @@ -290,6 +346,7 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: ) else: for s_param, param in zip(self.shadow_params, parameters): + print(f"From shape: {s_param.shape}, to shape: {param.shape}") param.data.copy_(s_param.to(param.device).data) def pin_memory(self) -> None: @@ -307,31 +364,22 @@ def pin_memory(self) -> None: # This probably won't work, but we'll do it anyway. self.shadow_params = [p.pin_memory() for p in self.shadow_params] - def to(self, device=None, dtype=None, non_blocking=False) -> None: - r"""Move internal buffers of the ExponentialMovingAverage to `device`. + def to(self, *args, **kwargs): + for param in self.shadow_params: + param.data = param.data.to(*args, **kwargs) + return self - Args: - device: like `device` argument to `torch.Tensor.to` - """ - # .to() on the tensors handles None correctly - self.shadow_params = [ - ( - p.to(device=device, dtype=dtype, non_blocking=non_blocking) - if p.is_floating_point() - else p.to(device=device, non_blocking=non_blocking) - ) - for p in self.shadow_params - ] + def cuda(self, device=None): + return self.to(device="cuda" if device is None else f"cuda:{device}") + + def cpu(self): + return self.to(device="cpu") - def state_dict(self) -> dict: + def state_dict(self, destination=None, prefix="", keep_vars=False): r""" - Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during - checkpointing to save the ema state dict. + Returns a dictionary containing a whole state of the EMA model. """ - # Following PyTorch conventions, references to tensors are returned: - # "returns a reference to the state and not its copy!" - - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict - return { + state_dict = { "decay": self.decay, "min_decay": self.min_decay, "optimization_step": self.optimization_step, @@ -339,27 +387,22 @@ def state_dict(self) -> dict: "use_ema_warmup": self.use_ema_warmup, "inv_gamma": self.inv_gamma, "power": self.power, - "shadow_params": self.shadow_params, } + for idx, param in enumerate(self.shadow_params): + state_dict[f"{prefix}shadow_params.{idx}"] = ( + param if keep_vars else param.detach() + ) + return state_dict def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" - Args: Save the current parameters for restoring later. - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. """ self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" - Args: - Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: - affecting the original optimization process. Store the parameters before the `copy_to()` method. After - validation (or model saving), use this to restore the former parameters. - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. If `None`, the parameters with which this - `ExponentialMovingAverage` was initialized will be used. + Restore the parameters stored with the `store` method. """ if self.temp_stored_params is None: raise RuntimeError( @@ -378,53 +421,45 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: # Better memory-wise. self.temp_stored_params = None - def load_state_dict(self, state_dict: dict) -> None: - r""" - Args: - Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the - ema state dict. - state_dict (dict): EMA state. Should be an object returned - from a call to :meth:`state_dict`. - """ - # deepcopy, to be consistent with module API - state_dict = copy.deepcopy(state_dict) + def parameter_count(self) -> int: + return sum(p.numel() for p in self.shadow_params) - self.decay = state_dict.get("decay", self.decay) - if self.decay < 0.0 or self.decay > 1.0: - raise ValueError("Decay must be between 0 and 1") + # Implementing nn.Module methods to emulate its behavior - self.min_decay = state_dict.get("min_decay", self.min_decay) - if not isinstance(self.min_decay, float): - raise ValueError("Invalid min_decay") + def named_children(self): + # No child modules + return iter([]) - self.optimization_step = state_dict.get( - "optimization_step", self.optimization_step - ) - if not isinstance(self.optimization_step, int): - raise ValueError("Invalid optimization_step") + def children(self): + return iter([]) - self.update_after_step = state_dict.get( - "update_after_step", self.update_after_step - ) - if not isinstance(self.update_after_step, int): - raise ValueError("Invalid update_after_step") + def modules(self): + yield self - self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) - if not isinstance(self.use_ema_warmup, bool): - raise ValueError("Invalid use_ema_warmup") + def named_modules(self, memo=None, prefix=""): + yield prefix, self - self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) - if not isinstance(self.inv_gamma, (float, int)): - raise ValueError("Invalid inv_gamma") + def parameters(self, recurse=True): + return iter(self.shadow_params) - self.power = state_dict.get("power", self.power) - if not isinstance(self.power, (float, int)): - raise ValueError("Invalid power") - - shadow_params = state_dict.get("shadow_params", None) - if shadow_params is not None: - self.shadow_params = shadow_params - if not isinstance(self.shadow_params, list): - raise ValueError("shadow_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): - raise ValueError("shadow_params must all be Tensors") + def named_parameters(self, prefix="", recurse=True): + for i, param in enumerate(self.shadow_params): + name = f"{prefix}shadow_params.{i}" + yield name, param + + def buffers(self, recurse=True): + return iter([]) + + def named_buffers(self, prefix="", recurse=True): + return iter([]) + + def train(self, mode=True): + self.training = mode + return self + + def eval(self): + return self.train(False) + + def zero_grad(self): + # No gradients to zero in EMA model + pass diff --git a/helpers/training/quantisation/__init__.py b/helpers/training/quantisation/__init__.py index ac96770b..f39c21fe 100644 --- a/helpers/training/quantisation/__init__.py +++ b/helpers/training/quantisation/__init__.py @@ -206,7 +206,15 @@ def get_quant_fn(base_model_precision): def quantise_model( - unet, transformer, text_encoder_1, text_encoder_2, text_encoder_3, controlnet, args + unet=None, + transformer=None, + text_encoder_1=None, + text_encoder_2=None, + text_encoder_3=None, + controlnet=None, + ema=None, + args=None, + return_dict: bool = False, ): """ Quantizes the provided models using the specified precision settings. @@ -218,6 +226,7 @@ def quantise_model( text_encoder_2: The second text encoder to quantize. text_encoder_3: The third text encoder to quantize. controlnet: The ControlNet model to quantize. + ema: An EMAModel to quantize. args: An object containing precision settings and other arguments. Returns: @@ -273,6 +282,14 @@ def quantise_model( "base_model_precision": args.base_model_precision, }, ), + ( + ema, + { + "quant_fn": get_quant_fn(args.base_model_precision), + "model_precision": args.base_model_precision, + "quantize_activations": args.quantize_activations, + }, + ), ] # Iterate over the models and apply quantization if the model is not None @@ -293,8 +310,33 @@ def quantise_model( models[i] = (quant_fn(model, **quant_args_combined), quant_args) # Unpack the quantized models - transformer, unet, controlnet, text_encoder_1, text_encoder_2, text_encoder_3 = [ - model for model, _ in models - ] - - return unet, transformer, text_encoder_1, text_encoder_2, text_encoder_3, controlnet + ( + transformer, + unet, + controlnet, + text_encoder_1, + text_encoder_2, + text_encoder_3, + ema, + ) = [model for model, _ in models] + + if return_dict: + return { + "unet": unet, + "transformer": transformer, + "text_encoder_1": text_encoder_1, + "text_encoder_2": text_encoder_2, + "text_encoder_3": text_encoder_3, + "controlnet": controlnet, + "ema": ema, + } + + return ( + unet, + transformer, + text_encoder_1, + text_encoder_2, + text_encoder_3, + controlnet, + ema, + ) diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py index f6b75807..695b678b 100644 --- a/helpers/training/save_hooks.py +++ b/helpers/training/save_hooks.py @@ -1,4 +1,5 @@ -from diffusers.training_utils import EMAModel, _set_state_dict_into_text_encoder +from diffusers.training_utils import _set_state_dict_into_text_encoder +from helpers.training.ema import EMAModel from helpers.training.wrappers import unwrap_model from helpers.training.multi_process import _get_rank as get_rank from diffusers.utils import ( @@ -22,7 +23,7 @@ logger = logging.getLogger("SaveHookManager") -logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL") or "INFO") +logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "WARNING")) try: from diffusers import ( @@ -176,13 +177,10 @@ def __init__( self.ema_model_subdir = None if unet is not None: self.ema_model_subdir = "unet_ema" - self.ema_model_cls = UNet2DConditionModel + self.ema_model_cls = unet.__class__ if transformer is not None: self.ema_model_subdir = "transformer_ema" - if self.args.model_family == "sd3": - self.ema_model_cls = SD3Transformer2DModel - elif self.args.model_family == "pixart_sigma": - self.ema_model_cls = PixArtTransformer2DModel + self.ema_model_cls = transformer.__class__ self.training_state_path = "training_state.json" if self.accelerator is not None: rank = get_rank() @@ -295,11 +293,14 @@ def _save_full_model(self, models, weights, output_dir): os.makedirs(temporary_dir, exist_ok=True) if self.args.use_ema: - tqdm.write("Saving EMA model") - self.ema_model.save_pretrained( - os.path.join(temporary_dir, self.ema_model_subdir), - max_shard_size="10GB", + ema_model_path = os.path.join( + temporary_dir, self.ema_model_subdir, "ema_model.pt" ) + logger.info(f"Saving EMA model to {ema_model_path}") + try: + self.ema_model.save_state_dict(ema_model_path) + except Exception as e: + logger.error(f"Error saving EMA model: {e}") if self.unet is not None: sub_dir = "unet" @@ -334,6 +335,15 @@ def save_model_hook(self, models, weights, output_dir): ) if not self.accelerator.is_main_process: return + if self.args.use_ema: + ema_model_path = os.path.join( + output_dir, self.ema_model_subdir, "ema_model.pt" + ) + logger.info(f"Saving EMA model to {ema_model_path}") + try: + self.ema_model.save_state_dict(ema_model_path) + except Exception as e: + logger.error(f"Error saving EMA model: {e}") if "lora" in self.args.model_type and self.args.lora_type == "standard": self._save_lora(models=models, weights=weights, output_dir=output_dir) return @@ -455,13 +465,6 @@ def _load_lycoris(self, models, input_dir): lycoris_logger.setLevel(logging.ERROR) def _load_full_model(self, models, input_dir): - if self.args.use_ema: - load_model = EMAModel.from_pretrained( - os.path.join(input_dir, self.ema_model_subdir), self.ema_model_cls - ) - self.ema_model.load_state_dict(load_model.state_dict()) - self.ema_model.to(self.accelerator.device) - del load_model if self.args.model_type == "full": return_exception = False for i in range(len(models)): @@ -508,6 +511,14 @@ def load_model_hook(self, models, input_dir): logger.warning( f"Could not find {training_state_path} in checkpoint dir {input_dir}" ) + if self.args.use_ema: + try: + self.ema_model.load_state_dict( + os.path.join(input_dir, self.ema_model_subdir, "ema_model.pt") + ) + # self.ema_model.to(self.accelerator.device) + except Exception as e: + logger.error(f"Could not load EMA model: {e}") if "lora" in self.args.model_type and self.args.lora_type == "standard": self._load_lora(models=models, input_dir=input_dir) elif "lora" in self.args.model_type and self.args.lora_type == "lycoris": diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index b1525391..bc144212 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -176,6 +176,7 @@ def __init__( self.text_encoder_2 = None self.text_encoder_3 = None self.controlnet = None + self.ema_model = None self.validation = None def _config_to_obj(self, config): @@ -350,11 +351,7 @@ def _misc_init(self): self.config.is_torchao = True elif "bnb" in self.config.base_model_precision: self.config.is_bnb = True - if self.config.is_quanto: - from helpers.training.quantisation import quantise_model - - self.quantise_model = quantise_model - elif self.config.is_torchao: + if self.config.is_quanto or self.config.is_torchao: from helpers.training.quantisation import quantise_model self.quantise_model = quantise_model @@ -756,7 +753,9 @@ def init_unload_text_encoder(self): " The real memories were the friends we trained a model on along the way." ) - def init_precision(self, preprocessing_models_only: bool = False): + def init_precision( + self, preprocessing_models_only: bool = False, ema_only: bool = False + ): self.config.enable_adamw_bf16 = ( True if self.config.weight_dtype == torch.bfloat16 else False ) @@ -765,7 +764,7 @@ def init_precision(self, preprocessing_models_only: bool = False): ) if "bnb" in self.config.base_model_precision: - # can't cast or move bitsandbytes modelsthis + # can't cast or move bitsandbytes models return if not self.config.disable_accelerator and self.config.is_quantized: @@ -793,6 +792,10 @@ def init_precision(self, preprocessing_models_only: bool = False): if self.config.is_quanto: with self.accelerator.local_main_process_first(): + if ema_only: + self.quantise_model(ema=self.ema_model, args=self.config) + + return self.quantise_model( unet=self.unet if not preprocessing_models_only else None, transformer=( @@ -802,10 +805,17 @@ def init_precision(self, preprocessing_models_only: bool = False): text_encoder_2=self.text_encoder_2, text_encoder_3=self.text_encoder_3, controlnet=None, + ema=self.ema_model, args=self.config, ) elif self.config.is_torchao: with self.accelerator.local_main_process_first(): + if ema_only: + self.ema_model = self.quantise_model( + ema=self.ema_model, args=self.config, return_dict=True + )["ema"] + + return ( self.unet, self.transformer, @@ -813,6 +823,7 @@ def init_precision(self, preprocessing_models_only: bool = False): self.text_encoder_2, self.text_encoder_3, self.controlnet, + self.ema_model, ) = self.quantise_model( unet=self.unet if not preprocessing_models_only else None, transformer=( @@ -822,6 +833,7 @@ def init_precision(self, preprocessing_models_only: bool = False): text_encoder_2=self.text_encoder_2, text_encoder_3=self.text_encoder_3, controlnet=None, + ema=self.ema_model, args=self.config, ) @@ -1010,6 +1022,22 @@ def init_post_load_freeze(self): self.accelerator, self.text_encoder_2 ).gradient_checkpointing_enable() + def _get_trainable_parameters(self): + # Return just a list of the currently trainable parameters. + if self.config.model_type == "lora": + if self.config.lora_type == "lycoris": + return self.lycoris_wrapped_network.parameters() + if self.config.controlnet: + return [ + param for param in self.controlnet.parameters() if param.requires_grad + ] + if self.unet is not None: + return [param for param in self.unet.parameters() if param.requires_grad] + if self.transformer is not None: + return [ + param for param in self.transformer.parameters() if param.requires_grad + ] + def _recalculate_training_steps(self): # Scheduler and math around the number of training steps. if not hasattr(self.config, "overrode_max_train_steps"): @@ -1190,37 +1218,33 @@ def init_ema_model(self): logger.info("Using EMA. Creating EMAModel.") ema_model_cls = None - if self.unet is not None: - ema_model_cls = UNet2DConditionModel - elif self.config.model_family == "pixart_sigma": - ema_model_cls = PixArtTransformer2DModel - elif self.config.model_family == "flux": - ema_model_cls = FluxTransformer2DModel - else: - raise ValueError( - f"Please open a bug report or disable EMA. Unknown EMA model family: {self.config.model_family}" - ) - ema_model_config = None - if self.unet is not None: + if self.config.controlnet: + ema_model_cls = self.controlnet.__class__ + ema_model_config = self.controlnet.config + elif self.unet is not None: + ema_model_cls = self.unet.__class__ ema_model_config = self.unet.config elif self.transformer is not None: + ema_model_cls = self.transformer.__class__ ema_model_config = self.transformer.config + else: + raise ValueError( + f"Please open a bug report or disable EMA. Unknown EMA model family: {self.config.model_family}" + ) self.ema_model = EMAModel( self.config, self.accelerator, - parameters=( - self.unet.parameters() - if self.unet is not None - else self.transformer.parameters() - ), + parameters=self._get_trainable_parameters(), model_cls=ema_model_cls, model_config=ema_model_config, decay=self.config.ema_decay, foreach=not self.config.ema_foreach_disable, ) - logger.info("EMA model creation complete.") + logger.info( + f"EMA model creation completed with {self.ema_model.parameter_count():,} parameters" + ) self.accelerator.wait_for_everyone() @@ -1296,6 +1320,7 @@ def init_prepare_models(self, lr_scheduler): if self.config.use_ema and self.ema_model is not None: if self.config.ema_device == "accelerator": logger.info("Moving EMA model weights to accelerator...") + print(f"EMA model: {self.ema_model}") self.ema_model.to( ( self.accelerator.device @@ -2624,11 +2649,7 @@ def train(self): if self.ema_model is not None: training_logger.debug("Stepping EMA forward") self.ema_model.step( - parameters=( - self.unet.parameters() - if self.unet is not None - else self.transformer.parameters() - ), + parameters=self._get_trainable_parameters(), global_step=self.state["global_step"], ) wandb_logs["ema_decay_value"] = self.ema_model.get_decay() diff --git a/tests/test_ema.py b/tests/test_ema.py new file mode 100644 index 00000000..a05099e9 --- /dev/null +++ b/tests/test_ema.py @@ -0,0 +1,106 @@ +import unittest +import torch +import tempfile +import os +from helpers.training.ema import EMAModel + + +class TestEMAModel(unittest.TestCase): + def setUp(self): + # Set up a simple model and its parameters + self.model = torch.nn.Linear(10, 5) # Simple linear model + self.args = type( + "Args", + (), + {"ema_update_interval": None, "ema_device": "cpu", "ema_cpu_only": True}, + ) + self.accelerator = None # For simplicity, assuming no accelerator in tests + self.ema_model = EMAModel( + args=self.args, + accelerator=self.accelerator, + parameters=self.model.parameters(), + decay=0.999, + min_decay=0.999, # Force decay to be 0.999 + update_after_step=-1, # Ensure decay is used from step 1 + use_ema_warmup=False, # Disable EMA warmup + foreach=False, + ) + + def test_ema_initialization(self): + """Test that the EMA model initializes correctly.""" + self.assertEqual( + len(self.ema_model.shadow_params), len(list(self.model.parameters())) + ) + for shadow_param, model_param in zip( + self.ema_model.shadow_params, self.model.parameters() + ): + self.assertTrue(torch.equal(shadow_param, model_param)) + + def test_ema_step(self): + """Test that the EMA model updates correctly after a step.""" + # Perform a model parameter update + optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01) + dummy_input = torch.randn(1, 10) # Adjust to match input size + dummy_output = self.model(dummy_input) + loss = dummy_output.sum() # A dummy loss function + loss.backward() + optimizer.step() + + # Save a copy of the model parameters after the update but before the EMA update. + model_params = [p.clone() for p in self.model.parameters()] + # Save a copy of the shadow parameters before the EMA update. + shadow_params_before = [p.clone() for p in self.ema_model.shadow_params] + + # Perform an EMA update + self.ema_model.step(self.model.parameters(), global_step=1) + decay = self.ema_model.cur_decay_value # This should be 0.999 + + # Verify that the decay used is as expected + self.assertAlmostEqual( + decay, 0.999, places=6, msg="Decay value is not as expected." + ) + + # Verify shadow parameters have changed + for shadow_param, shadow_param_before in zip( + self.ema_model.shadow_params, shadow_params_before + ): + self.assertFalse( + torch.equal(shadow_param, shadow_param_before), + "Shadow parameters did not update correctly.", + ) + + # Compute and check expected shadow parameter values + for shadow_param, shadow_param_before, model_param in zip( + self.ema_model.shadow_params, shadow_params_before, self.model.parameters() + ): + expected_shadow = decay * shadow_param_before + (1 - decay) * model_param + self.assertTrue( + torch.allclose(shadow_param, expected_shadow, atol=1e-6), + f"Shadow parameter does not match expected value.", + ) + + def test_save_and_load_state_dict(self): + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = os.path.join(temp_dir, "ema_model_state.pth") + + # Save the state + self.ema_model.save_state_dict(temp_path) + + # Create a new EMA model and load the state + new_ema_model = EMAModel( + args=self.args, + accelerator=self.accelerator, + parameters=self.model.parameters(), + decay=0.999, + ) + new_ema_model.load_state_dict(temp_path) + + # Check that the new EMA model's shadow parameters match the saved state + for shadow_param, new_shadow_param in zip( + self.ema_model.shadow_params, new_ema_model.shadow_params + ): + self.assertTrue(torch.equal(shadow_param, new_shadow_param)) + + +if __name__ == "__main__": + unittest.main() diff --git a/train.py b/train.py index 7e00e46b..13642a4d 100644 --- a/train.py +++ b/train.py @@ -27,6 +27,7 @@ trainer.init_huggingface_hub() trainer.init_preprocessing_models() + trainer.init_precision(preprocessing_models_only=True) trainer.init_data_backend() trainer.init_validation_prompts() trainer.init_unload_text_encoder() @@ -38,6 +39,8 @@ trainer.init_freeze_models() trainer.init_trainable_peft_adapter() trainer.init_ema_model() + # EMA must be quantised if the base model is as well. + trainer.init_precision(ema_only=True) trainer.move_models(destination="accelerator") trainer.init_validations() From 1cb6356978cc086b88806c8fc19005aa86d4f16c Mon Sep 17 00:00:00 2001 From: bghira Date: Sun, 17 Nov 2024 14:01:59 -0600 Subject: [PATCH 02/13] validations: attempt to copy trainable parameters only --- helpers/training/trainer.py | 1 + helpers/training/validation.py | 47 +++++++++++++++++++--------------- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index bc144212..30af7ed3 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -1394,6 +1394,7 @@ def init_validations(self): return model_evaluator = ModelEvaluator.from_config(args=self.config) self.validation = Validation( + trainable_parameters=self._get_trainable_parameters(), accelerator=self.accelerator, unet=self.unet, transformer=self.transformer, diff --git a/helpers/training/validation.py b/helpers/training/validation.py index 7597bd24..25898c31 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -399,7 +399,9 @@ def __init__( tokenizer_3=None, is_deepspeed: bool = False, model_evaluator=None, + trainable_parameters=None, ): + self.trainable_parameters = trainable_parameters self.accelerator = accelerator self.prompt_handler = None self.unet = unet @@ -958,17 +960,13 @@ def setup_scheduler(self): def setup_pipeline(self, validation_type, enable_ema_model: bool = True): if hasattr(self.accelerator, "_lycoris_wrapped_network"): - self.accelerator._lycoris_wrapped_network.set_multiplier(float(getattr( - self.args, "validation_lycoris_strength", 1.0 - ))) + self.accelerator._lycoris_wrapped_network.set_multiplier( + float(getattr(self.args, "validation_lycoris_strength", 1.0)) + ) if validation_type == "intermediary" and self.args.use_ema: if enable_ema_model: - if self.unet is not None: - self.ema_model.store(self.unet.parameters()) - self.ema_model.copy_to(self.unet.parameters()) - if self.transformer is not None: - self.ema_model.store(self.transformer.parameters()) - self.ema_model.copy_to(self.transformer.parameters()) + self.ema_model.store(self.trainable_parameters) + self.ema_model.copy_to(self.trainable_parameters) if self.args.ema_device != "accelerator": logger.info("Moving EMA weights to GPU for inference.") self.ema_model.to(self.inference_device) @@ -1094,9 +1092,13 @@ def setup_pipeline(self, validation_type, enable_ema_model: bool = True): break if self.args.validation_torch_compile: if self.deepspeed: - logger.warning("DeepSpeed does not support torch compile. Disabling. Set --validation_torch_compile=False to suppress this warning.") + logger.warning( + "DeepSpeed does not support torch compile. Disabling. Set --validation_torch_compile=False to suppress this warning." + ) elif self.args.lora_type.lower() == "lycoris": - logger.warning("LyCORIS does not support torch compile for validation due to graph compile breaks. Disabling. Set --validation_torch_compile=False to suppress this warning.") + logger.warning( + "LyCORIS does not support torch compile for validation due to graph compile breaks. Disabling. Set --validation_torch_compile=False to suppress this warning." + ) else: if self.unet is not None and not is_compiled_module(self.unet): logger.warning( @@ -1165,10 +1167,10 @@ def process_prompts(self): ) self.validation_prompt_dict[shortname] = prompt logger.debug(f"Processing validation for prompt: {prompt}") - stitched_validation_images, original_validation_images = self.validate_prompt(prompt, shortname, validation_input_image) - validation_images.update( - stitched_validation_images + stitched_validation_images, original_validation_images = ( + self.validate_prompt(prompt, shortname, validation_input_image) ) + validation_images.update(stitched_validation_images) self._save_images(validation_images, shortname, prompt) logger.debug(f"Completed generating image: {prompt}") self.validation_images = validation_images @@ -1181,7 +1183,7 @@ def process_prompts(self): def get_eval_result(self): return self.evaluation_result or {} - + def clear_eval_result(self): self.evaluation_result = None @@ -1342,11 +1344,14 @@ def validate_prompt( pipeline_kwargs.pop("negative_mask")[0], dim=0 ).to(device=self.inference_device, dtype=self.weight_dtype) - original_validation_image_results = self.pipeline(**pipeline_kwargs).images + original_validation_image_results = self.pipeline( + **pipeline_kwargs + ).images validation_image_results = original_validation_image_results.copy() if self.args.controlnet: validation_image_results = self.stitch_conditioning_images( - original_validation_image_results, extra_validation_kwargs["image"] + original_validation_image_results, + extra_validation_kwargs["image"], ) elif not self.args.disable_benchmark and self.benchmark_exists( "base_model" @@ -1360,7 +1365,9 @@ def validate_prompt( validation_image_results[0], benchmark_image ) validation_images[validation_shortname].extend(validation_image_results) - original_validation_images[validation_shortname].extend(original_validation_image_results) + original_validation_images[validation_shortname].extend( + original_validation_image_results + ) except Exception as e: import traceback @@ -1510,7 +1517,7 @@ def evaluate_images(self, images: list = None): for shortname, image_list in images.items(): if shortname in self.eval_scores: continue - prompt = self.validation_prompt_dict.get(shortname, '') + prompt = self.validation_prompt_dict.get(shortname, "") for image in image_list: evaluation_score = self.model_evaluator.evaluate([image], [prompt]) self.eval_scores[shortname] = round(float(evaluation_score), 4) @@ -1522,4 +1529,4 @@ def evaluate_images(self, images: list = None): "clip/std": np.std(list(self.eval_scores.values())), } - return result \ No newline at end of file + return result From 9a18a2ada855ba605d9f730a079430af0c00d54a Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 18 Nov 2024 04:16:28 +0000 Subject: [PATCH 03/13] ema validations overhaul to allow side by side comparison --- helpers/configuration/cmd_args.py | 11 ++ helpers/training/ema.py | 1 - helpers/training/save_hooks.py | 2 +- helpers/training/trainer.py | 1 + helpers/training/validation.py | 194 +++++++++++++++++++----------- helpers/webhooks/handler.py | 4 + 6 files changed, 142 insertions(+), 71 deletions(-) diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index ab41b5a0..223e52fb 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1122,6 +1122,17 @@ def get_argument_parser(): " This provides the fastest EMA update times, but is not ultimately necessary for EMA to function." ), ) + parser.add_argument( + "--ema_validation", + choices=["none", "ema_only", "comparison"], + default="comparison", + help=( + "When 'none' is set, no EMA validation will be done." + " When using 'ema_only', the validations will rely mostly on the EMA weights." + " When using 'comparison' (default) mode, the validations will first run on the checkpoint before also running for" + " the EMA weights. In comparison mode, the resulting images will be provided side-by-side." + ) + ) parser.add_argument( "--ema_cpu_only", action="store_true", diff --git a/helpers/training/ema.py b/helpers/training/ema.py index de04c148..fa8f4f0c 100644 --- a/helpers/training/ema.py +++ b/helpers/training/ema.py @@ -346,7 +346,6 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: ) else: for s_param, param in zip(self.shadow_params, parameters): - print(f"From shape: {s_param.shape}, to shape: {param.shape}") param.data.copy_(s_param.to(param.device).data) def pin_memory(self) -> None: diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py index 695b678b..f9b27204 100644 --- a/helpers/training/save_hooks.py +++ b/helpers/training/save_hooks.py @@ -511,7 +511,7 @@ def load_model_hook(self, models, input_dir): logger.warning( f"Could not find {training_state_path} in checkpoint dir {input_dir}" ) - if self.args.use_ema: + if self.args.use_ema and self.accelerator.is_main_process: try: self.ema_model.load_state_dict( os.path.join(input_dir, self.ema_model_subdir, "ema_model.pt") diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 30af7ed3..786bc0df 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -2654,6 +2654,7 @@ def train(self): global_step=self.state["global_step"], ) wandb_logs["ema_decay_value"] = self.ema_model.get_decay() + ema_decay_value = wandb_logs["ema_decay_value"] self.accelerator.wait_for_everyone() # Log scatter plot to wandb diff --git a/helpers/training/validation.py b/helpers/training/validation.py index 25898c31..3ad6a234 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -434,6 +434,7 @@ def __init__( else validation_negative_prompt_embeds[0] ) self.ema_model = ema_model + self.ema_enabled = False self.vae = vae self.pipeline = None self.deepfloyd = True if "deepfloyd" in self.args.model_type else False @@ -735,7 +736,7 @@ def _benchmark_path(self, benchmark: str = "base_model"): return os.path.join(self.args.output_dir, "benchmarks", benchmark) def stitch_benchmark_image( - self, validation_image_result, benchmark_image, separator_width=5 + self, validation_image_result, benchmark_image, separator_width=5, labels=["base model", "checkpoint"] ): """ For each image, make a new canvas and place it side by side with its equivalent from {self.validation_image_inputs} @@ -744,8 +745,8 @@ def stitch_benchmark_image( """ # Calculate new dimensions - new_width = validation_image_result.size[0] * 2 + separator_width - new_height = validation_image_result.size[1] + new_width = benchmark_image.size[0] + validation_image_result.size[0] + separator_width + new_height = benchmark_image.size[1] # Create a new image with a white background new_image = Image.new("RGB", (new_width, new_height), color="white") @@ -766,24 +767,26 @@ def stitch_benchmark_image( font = ImageFont.load_default() # Add text to the left image - draw.text( - (10, 10), - "base model", - fill=(255, 255, 255), - font=font, - stroke_width=2, - stroke_fill=(0, 0, 0), - ) + if labels[0] is not None: + draw.text( + (10, 10), + labels[0], + fill=(255, 255, 255), + font=font, + stroke_width=2, + stroke_fill=(0, 0, 0), + ) - # Add text to the right image - draw.text( - (validation_image_result.size[0] + separator_width + 10, 10), - "checkpoint", - fill=(255, 255, 255), - font=font, - stroke_width=2, - stroke_fill=(0, 0, 0), - ) + if labels[1] is not None: + # Add text to the right image + draw.text( + (benchmark_image.size[0] + separator_width + 10, 10), + labels[1], + fill=(255, 255, 255), + font=font, + stroke_width=2, + stroke_fill=(0, 0, 0), + ) # Draw a vertical line as a separator line_color = (200, 200, 200) # Light gray @@ -958,22 +961,11 @@ def setup_scheduler(self): self.pipeline.scheduler = scheduler return scheduler - def setup_pipeline(self, validation_type, enable_ema_model: bool = True): + def setup_pipeline(self, validation_type): if hasattr(self.accelerator, "_lycoris_wrapped_network"): self.accelerator._lycoris_wrapped_network.set_multiplier( float(getattr(self.args, "validation_lycoris_strength", 1.0)) ) - if validation_type == "intermediary" and self.args.use_ema: - if enable_ema_model: - self.ema_model.store(self.trainable_parameters) - self.ema_model.copy_to(self.trainable_parameters) - if self.args.ema_device != "accelerator": - logger.info("Moving EMA weights to GPU for inference.") - self.ema_model.to(self.inference_device) - else: - logger.debug( - "Skipping EMA model setup for validation, as enable_ema_model=False." - ) if self.pipeline is None: pipeline_cls = self._pipeline_cls() @@ -1167,14 +1159,14 @@ def process_prompts(self): ) self.validation_prompt_dict[shortname] = prompt logger.debug(f"Processing validation for prompt: {prompt}") - stitched_validation_images, original_validation_images = ( + stitched_validation_images, checkpoint_validation_images, ema_validation_images = ( self.validate_prompt(prompt, shortname, validation_input_image) ) validation_images.update(stitched_validation_images) self._save_images(validation_images, shortname, prompt) logger.debug(f"Completed generating image: {prompt}") self.validation_images = validation_images - self.evaluation_result = self.evaluate_images(original_validation_images) + self.evaluation_result = self.evaluate_images(checkpoint_validation_images) self._log_validations_to_webhook(validation_images, shortname, prompt) try: self._log_validations_to_trackers(validation_images) @@ -1202,6 +1194,19 @@ def stitch_conditioning_images(self, validation_image_results, conditioning_imag return stitched_validation_images + def _validation_types(self): + types = ["checkpoint"] + if self.args.use_ema: + # ema has different validations we can add or overwrite. + if self.args.ema_validation == "ema_only": + # then we do not sample the base ckpt being trained, only the EMA weights. + types = ["ema"] + if self.args.ema_validation == "comparison": + # then we sample both. + types.append("ema") + + return types + def validate_prompt( self, prompt, validation_shortname, validation_input_image=None ): @@ -1209,16 +1214,12 @@ def validate_prompt( # Placeholder for actual image generation and logging logger.debug(f"Validating prompt: {prompt}") # benchmarked / stitched validation images - validation_images = {} + stitched_validation_images = {} # untouched / un-stitched validation images - original_validation_images = {} + checkpoint_validation_images = {} + ema_validation_images = {} for resolution in self.validation_resolutions: extra_validation_kwargs = {} - if not self.args.validation_randomize: - extra_validation_kwargs["generator"] = self._get_generator() - logger.debug( - f"Using a generator? {extra_validation_kwargs['generator']}" - ) if validation_input_image is not None: extra_validation_kwargs["image"] = validation_input_image if self.deepfloyd_stage2: @@ -1275,9 +1276,10 @@ def validate_prompt( logger.debug( f"Processing width/height: {validation_resolution_width}x{validation_resolution_height}" ) - if validation_shortname not in validation_images: - validation_images[validation_shortname] = [] - original_validation_images[validation_shortname] = [] + if validation_shortname not in stitched_validation_images: + stitched_validation_images[validation_shortname] = [] + checkpoint_validation_images[validation_shortname] = [] + ema_validation_images[validation_shortname] = [] try: extra_validation_kwargs.update(self._gather_prompt_embeds(prompt)) except Exception as e: @@ -1289,7 +1291,6 @@ def validate_prompt( continue try: - # print(f"pipeline dtype: {self.pipeline.unet.device}") pipeline_kwargs = { "prompt": None, "negative_prompt": None, @@ -1344,10 +1345,27 @@ def validate_prompt( pipeline_kwargs.pop("negative_mask")[0], dim=0 ).to(device=self.inference_device, dtype=self.weight_dtype) - original_validation_image_results = self.pipeline( - **pipeline_kwargs - ).images - validation_image_results = original_validation_image_results.copy() + validation_types = self._validation_types() + all_validation_type_results = {} + for current_validation_type in validation_types: + if not self.args.validation_randomize: + pipeline_kwargs["generator"] = self._get_generator() + logger.debug( + f"Using a generator? {pipeline_kwargs['generator']}" + ) + if current_validation_type == "ema": + self.enable_ema_for_inference() + all_validation_type_results[current_validation_type] = self.pipeline( + **pipeline_kwargs + ).images + if current_validation_type == "ema": + self.disable_ema_for_inference() + + # retrieve the default image result for stitching to controlnet inputs. + ema_image_results = all_validation_type_results.get("ema") + validation_image_results = all_validation_type_results.get("checkpoint", ema_image_results) + original_validation_image_results = validation_image_results + benchmark_image = None if self.args.controlnet: validation_image_results = self.stitch_conditioning_images( original_validation_image_results, @@ -1360,14 +1378,19 @@ def validate_prompt( validation_shortname, resolution ) if benchmark_image is not None: - # user might have added new resolutions or something. - validation_image_results[0] = self.stitch_benchmark_image( - validation_image_results[0], benchmark_image - ) - validation_images[validation_shortname].extend(validation_image_results) - original_validation_images[validation_shortname].extend( + for idx, validation_image in enumerate(validation_image_results): + validation_image_results[idx] = self.stitch_benchmark_image( + validation_image_result=validation_image, + benchmark_image=benchmark_image, + ) + + checkpoint_validation_images[validation_shortname].extend( original_validation_image_results ) + stitched_validation_images[validation_shortname].extend(validation_image_results) + ema_validation_images[validation_shortname].extend(ema_image_results) + + except Exception as e: import traceback @@ -1375,8 +1398,15 @@ def validate_prompt( f"Error generating validation image: {e}, {traceback.format_exc()}" ) continue + if self.args.use_ema and self.args.ema_validation == "comparison" and benchmark_image is not None: + for idx, validation_image in enumerate(stitched_validation_images[validation_shortname]): + stitched_validation_images[validation_shortname][idx] = self.stitch_benchmark_image( + validation_image_result=ema_validation_images[validation_shortname][idx], + benchmark_image=stitched_validation_images[validation_shortname][idx], + labels=[None, "EMA"] + ) - return validation_images, original_validation_images + return stitched_validation_images, checkpoint_validation_images, ema_validation_images def _save_images(self, validation_images, validation_shortname, validation_prompt): validation_img_idx = 0 @@ -1490,20 +1520,46 @@ def _log_validations_to_trackers(self, validation_images): # Log all images in one call to prevent the global step from ticking tracker.log(gallery_images, step=StateTracker.get_global_step()) - def finalize_validation(self, validation_type, enable_ema_model: bool = True): - """Cleans up and restores original state if necessary.""" - if validation_type == "intermediary" and self.args.use_ema: - if enable_ema_model: - if self.unet is not None: - self.ema_model.restore(self.unet.parameters()) - if self.transformer is not None: - self.ema_model.restore(self.transformer.parameters()) - if self.args.ema_device != "accelerator": - self.ema_model.to(self.args.ema_device) + def enable_ema_for_inference(self): + if self.ema_enabled: + logger.info("EMA already on GPU.") + return + if self.args.use_ema: + self.ema_enabled = True + if self.args.model_type == "lora" and self.args.lora_type.lower() == "lycoris": + self.accelerator._lycoris_wrapped_network.set_multiplier(1.0) + self.ema_model.store(self.accelerator._lycoris_wrapped_network.parameters()) + self.ema_model.copy_to(self.accelerator._lycoris_wrapped_network.parameters()) else: - logger.debug( - "Skipping EMA model restoration for validation, as enable_ema_model=False." - ) + self.ema_model.store(self.trainable_parameters) + self.ema_model.copy_to(self.trainable_parameters) + if self.args.ema_device != "accelerator": + logger.info("Moving EMA weights to GPU for inference.") + self.ema_model.to(self.inference_device) + else: + logger.debug( + "Skipping EMA model setup for validation, as enable_ema_model=False." + ) + + def disable_ema_for_inference(self): + if not self.ema_enabled: + return + if self.args.use_ema: + if self.args.model_type == "lora" and self.args.lora_type.lower() == "lycoris": + self.accelerator._lycoris_wrapped_network.set_multiplier(1.0) + self.ema_enabled = False + self.ema_model.restore(self.trainable_parameters) + if self.args.ema_device != "accelerator": + logger.info("Moving EMA weights to CPU for storage.") + self.ema_model.to(self.args.ema_device) + else: + logger.debug( + "Skipping EMA model restoration for validation, as enable_ema_model=False." + ) + + + def finalize_validation(self, validation_type): + """Cleans up and restores original state if necessary.""" if not self.args.keep_vae_loaded and not self.args.vae_cache_ondemand: self.vae = self.vae.to("cpu") self.vae = None diff --git a/helpers/webhooks/handler.py b/helpers/webhooks/handler.py index 51c4ab85..a4ca01ad 100644 --- a/helpers/webhooks/handler.py +++ b/helpers/webhooks/handler.py @@ -100,6 +100,10 @@ def _send_request( def _prepare_images(self, images: list): """Convert images to file objects for Discord uploads.""" files = {} + if not images: + return files + if type(images) is not list: + raise ValueError(f"Images must be a list of PIL images. Received: {images}") if images: for index, img in enumerate(images): img_byte_array = BytesIO() From 9f7fc68823de8154bb601fe9a8a236749d3bf419 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 18 Nov 2024 19:47:31 +0000 Subject: [PATCH 04/13] fix EMA comparison mode during validations by correctly unloading/restoring the lycoris weights during ema enable/disable --- .../training/default_settings/safety_check.py | 9 ++++++ helpers/training/validation.py | 32 +++++++++++++++---- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index 7b8d92de..5316ff6c 100644 --- a/helpers/training/default_settings/safety_check.py +++ b/helpers/training/default_settings/safety_check.py @@ -117,3 +117,12 @@ def safety_check(args, accelerator): f"--flux_schedule_auto_shift cannot be combined with --flux_schedule_shift. Please set --flux_schedule_shift to 0 if you want to train with --flux_schedule_auto_shift." ) sys.exit(1) + + if args.use_ema and args.model_type == "lora" and args.lora_type.lower() == "standard": + if args.ema_validation != "none": + logger.error( + "EMA validation is only supported via full rank training or Lycoris." + " To continue with Standard PEFT LoRA, set --ema_validation=none in your config file." + " You can still use EMA with Standard PEFT LoRA, but will be unable to see its outputs during training." + ) + sys.exit(1) \ No newline at end of file diff --git a/helpers/training/validation.py b/helpers/training/validation.py index 3ad6a234..4cd9a097 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -1364,6 +1364,8 @@ def validate_prompt( # retrieve the default image result for stitching to controlnet inputs. ema_image_results = all_validation_type_results.get("ema") validation_image_results = all_validation_type_results.get("checkpoint", ema_image_results) + print(f"ema image results: {ema_image_results}") + print(f"validation_image_results: {validation_image_results}") original_validation_image_results = validation_image_results benchmark_image = None if self.args.controlnet: @@ -1379,6 +1381,7 @@ def validate_prompt( ) if benchmark_image is not None: for idx, validation_image in enumerate(validation_image_results): + print(f"stitching benchmark image {benchmark_image} to {validation_image}") validation_image_results[idx] = self.stitch_benchmark_image( validation_image_result=validation_image, benchmark_image=benchmark_image, @@ -1389,6 +1392,7 @@ def validate_prompt( ) stitched_validation_images[validation_shortname].extend(validation_image_results) ema_validation_images[validation_shortname].extend(ema_image_results) + print(f"Generated {len(validation_image_results)} images, {len(ema_validation_images[validation_shortname])} EMA images, {len(stitched_validation_images[validation_shortname])} for {validation_shortname}") except Exception as e: @@ -1400,6 +1404,7 @@ def validate_prompt( continue if self.args.use_ema and self.args.ema_validation == "comparison" and benchmark_image is not None: for idx, validation_image in enumerate(stitched_validation_images[validation_shortname]): + print(f"idx={idx} stitching EMA image {ema_validation_images[validation_shortname][idx]} to {stitched_validation_images[validation_shortname][idx]}") stitched_validation_images[validation_shortname][idx] = self.stitch_benchmark_image( validation_image_result=ema_validation_images[validation_shortname][idx], benchmark_image=stitched_validation_images[validation_shortname][idx], @@ -1522,39 +1527,52 @@ def _log_validations_to_trackers(self, validation_images): def enable_ema_for_inference(self): if self.ema_enabled: - logger.info("EMA already on GPU.") + logger.info("EMA already enabled. Not enabling EMA.") return if self.args.use_ema: + logger.info("Enabling EMA.") self.ema_enabled = True if self.args.model_type == "lora" and self.args.lora_type.lower() == "lycoris": + logger.info("Setting Lycoris multiplier to 1.0") self.accelerator._lycoris_wrapped_network.set_multiplier(1.0) + logger.info("Storing Lycoris weights for later recovery.") self.ema_model.store(self.accelerator._lycoris_wrapped_network.parameters()) + logger.info("Storing the EMA weights into the Lycoris adapter for inference.") self.ema_model.copy_to(self.accelerator._lycoris_wrapped_network.parameters()) else: + logger.info("Storing EMA weights for later recovery.") self.ema_model.store(self.trainable_parameters) + logger.info("Storing the EMA weights into the model for inference.") self.ema_model.copy_to(self.trainable_parameters) if self.args.ema_device != "accelerator": logger.info("Moving EMA weights to GPU for inference.") self.ema_model.to(self.inference_device) else: - logger.debug( - "Skipping EMA model setup for validation, as enable_ema_model=False." + logger.info( + "Skipping EMA model setup for validation, as we are not using EMA." ) def disable_ema_for_inference(self): if not self.ema_enabled: + logger.info("EMA was not enabled. Not disabling EMA.") return if self.args.use_ema: + logger.info("Disabling EMA.") + self.ema_enabled = False if self.args.model_type == "lora" and self.args.lora_type.lower() == "lycoris": + logger.info("Setting Lycoris network multiplier to 1.0.") self.accelerator._lycoris_wrapped_network.set_multiplier(1.0) - self.ema_enabled = False - self.ema_model.restore(self.trainable_parameters) + logger.info("Restoring Lycoris weights.") + self.ema_model.restore(self.accelerator._lycoris_wrapped_network.parameters()) + else: + logger.info("Restoring trainable parameters.") + self.ema_model.restore(self.trainable_parameters) if self.args.ema_device != "accelerator": logger.info("Moving EMA weights to CPU for storage.") self.ema_model.to(self.args.ema_device) else: - logger.debug( - "Skipping EMA model restoration for validation, as enable_ema_model=False." + logger.info( + "Skipping EMA model restoration for validation, as we are not using EMA." ) From ffea1788ffb6d0f36b0b4ce636e635c671aecabb Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 19 Nov 2024 02:06:12 +0000 Subject: [PATCH 05/13] add EMA information to model card --- helpers/publishing/huggingface.py | 23 +++++++++++++++++++++++ helpers/publishing/metadata.py | 18 +++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/helpers/publishing/huggingface.py b/helpers/publishing/huggingface.py index 33034086..2e74b184 100644 --- a/helpers/publishing/huggingface.py +++ b/helpers/publishing/huggingface.py @@ -11,6 +11,7 @@ LORA_SAFETENSORS_FILENAME = "pytorch_lora_weights.safetensors" +EMA_SAFETENSORS_FILENAME = "ema_model.safetensors" class HubManager: @@ -110,6 +111,8 @@ def upload_model(self, validation_images, webhook_handler=None, override_path=No self.upload_full_model(override_path=override_path) else: self.upload_lora_model(override_path=override_path) + if self.config.use_ema: + self.upload_ema_model(override_path=override_path) break except Exception as e: if webhook_handler: @@ -155,6 +158,26 @@ def upload_lora_model(self, override_path=None): except Exception as e: logger.error(f"Failed to upload LoRA weights to hub: {e}") + def upload_ema_model(self, override_path=None): + try: + check_ema_paths = ["transformer_ema", "unet_ema", "controlnet_ema", "ema"] + # if any of the folder names are present in the checkpoint dir, we will upload them too + for check_ema_path in check_ema_paths: + print(f"Checking for EMA path: {check_ema_path}") + ema_path = os.path.join( + override_path or self.config.output_dir, check_ema_path + ) + if os.path.exists(ema_path): + print(f"Found EMA checkpoint!") + upload_folder( + repo_id=self._repo_id, + folder_path=ema_path, + path_in_repo="/ema", + commit_message="LoRA EMA checkpoint auto-generated by SimpleTuner", + ) + except Exception as e: + logger.error(f"Failed to upload LoRA EMA weights to hub: {e}") + def find_latest_checkpoint(self): checkpoints = list(Path(self.config.output_dir).rglob("checkpoint-*")) highest_checkpoint_value = None diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py index 10011c96..8f62cba6 100644 --- a/helpers/publishing/metadata.py +++ b/helpers/publishing/metadata.py @@ -105,6 +105,20 @@ def _model_imports(args): return f"{output}" +def ema_info(args): + if args.use_ema: + ema_information = """ +## Elastic Moving Average (EMA) + +SimpleTuner generates a safetensors variant of the EMA weights and a pt file. + +The safetensors file is intended to be used for inference, and the pt file is for continuing finetuning. + +The EMA model may provide a more well-rounded result, but typically will feel undertrained compared to the full model as it is a running decayed average of the model weights. +""" + return ema_information + return "" + def lycoris_download_info(): """output a function to download the adapter""" output_fn = """ @@ -145,7 +159,7 @@ def _model_load(args, repo_id: str = None): output = ( f"model_id = '{args.pretrained_model_name_or_path}'" f"\nadapter_id = '{repo_id if repo_id is not None else args.output_dir}'" - f"\npipeline = DiffusionPipeline.from_pretrained(model_id), torch_dtype={StateTracker.get_weight_dtype()}) # loading directly in bf16" + f"\npipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype={StateTracker.get_weight_dtype()}) # loading directly in bf16" f"\npipeline.load_lora_weights(adapter_id)" ) elif args.lora_type.lower() == "lycoris": @@ -558,3 +572,5 @@ def save_model_card( logger.debug(f"Model Card:\n{model_card_content}") with open(os.path.join(repo_folder, "README.md"), "w", encoding="utf-8") as f: f.write(yaml_content + model_card_content) + +{ema_info(args=StateTracker.get_args())} \ No newline at end of file From 4c7715622d919f80dfa9451226a2ed7c4f5e2d3c Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 19 Nov 2024 02:06:39 +0000 Subject: [PATCH 06/13] allow PEFT LoRA to use EMA validations --- helpers/training/default_settings/safety_check.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py index 5316ff6c..f586972c 100644 --- a/helpers/training/default_settings/safety_check.py +++ b/helpers/training/default_settings/safety_check.py @@ -116,13 +116,4 @@ def safety_check(args, accelerator): logger.error( f"--flux_schedule_auto_shift cannot be combined with --flux_schedule_shift. Please set --flux_schedule_shift to 0 if you want to train with --flux_schedule_auto_shift." ) - sys.exit(1) - - if args.use_ema and args.model_type == "lora" and args.lora_type.lower() == "standard": - if args.ema_validation != "none": - logger.error( - "EMA validation is only supported via full rank training or Lycoris." - " To continue with Standard PEFT LoRA, set --ema_validation=none in your config file." - " You can still use EMA with Standard PEFT LoRA, but will be unable to see its outputs during training." - ) - sys.exit(1) \ No newline at end of file + sys.exit(1) \ No newline at end of file From 37cc1aca59cdc28dcf7c1bb4254cb1a6ef029f98 Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 19 Nov 2024 02:07:11 +0000 Subject: [PATCH 07/13] save EMA weights for Lycoris and LoRA to mirror the proper safetensors formats --- helpers/training/save_hooks.py | 42 +++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py index f9b27204..4b04590c 100644 --- a/helpers/training/save_hooks.py +++ b/helpers/training/save_hooks.py @@ -187,6 +187,14 @@ def __init__( if rank > 0: self.training_state_path = f"training_state-rank{rank}.json" + def _primary_model(self): + if self.args.controlnet: + return self.controlnet + if self.unet is not None: + return self.unet + if self.transformer is not None: + return self.transformer + def _save_lora(self, models, weights, output_dir): # for SDXL/others, there are only two options here. Either are just the unet attn processor layers # or there are the unet and text encoder atten layers. @@ -197,6 +205,24 @@ def _save_lora(self, models, weights, output_dir): # Diffusers does not train the third text encoder. # text_encoder_3_lora_layers_to_save = None + if self.args.use_ema: + # we'll temporarily overwrite teh LoRA parameters with the EMA parameters to save it. + logger.info("Saving EMA model to disk.") + trainable_parameters = [ + p + for p in self.ema_model.parameters() + if p.requires_grad + ] + self.ema_model.store(trainable_parameters) + self.ema_model.copy_to(trainable_parameters) + self.pipeline_class.save_lora_weights( + os.path.join(output_dir, "ema"), + transformer_lora_layers=convert_state_dict_to_diffusers( + get_peft_model_state_dict(self._primary_model()) + ), + ) + self.ema_model.restore(trainable_parameters) + for model in models: if isinstance(model, type(unwrap_model(self.accelerator, self.unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers( @@ -264,7 +290,7 @@ def _save_lycoris(self, models, weights, output_dir): save wrappers for lycoris. For now, text encoders are not trainable via lycoris. """ - from helpers.publishing.huggingface import LORA_SAFETENSORS_FILENAME + from helpers.publishing.huggingface import LORA_SAFETENSORS_FILENAME, EMA_SAFETENSORS_FILENAME for _ in models: if weights: @@ -279,6 +305,19 @@ def _save_lycoris(self, models, weights, output_dir): list(self.accelerator._lycoris_wrapped_network.parameters())[0].dtype, {"lycoris_config": json.dumps(lycoris_config)}, # metadata ) + if self.args.use_ema: + # we'll store lycoris weights. + self.ema_model.store(self.accelerator._lycoris_wrapped_network.parameters()) + # we'll write EMA to the lycoris adapter temporarily. + self.ema_model.copy_to(self.accelerator._lycoris_wrapped_network.parameters()) + # now we can write the lycoris weights using the EMA_SAFETENSORS_FILENAME instead. + os.makedirs(os.path.join(output_dir, "ema"), exist_ok=True) + self.accelerator._lycoris_wrapped_network.save_weights( + os.path.join(output_dir, "ema", EMA_SAFETENSORS_FILENAME), + list(self.accelerator._lycoris_wrapped_network.parameters())[0].dtype, + {"lycoris_config": json.dumps(lycoris_config)}, # metadata + ) + self.ema_model.restore(self.accelerator._lycoris_wrapped_network.parameters()) # copy the config into the repo shutil.copy2( @@ -336,6 +375,7 @@ def save_model_hook(self, models, weights, output_dir): if not self.accelerator.is_main_process: return if self.args.use_ema: + # we'll save this EMA checkpoint for restoring the state easier. ema_model_path = os.path.join( output_dir, self.ema_model_subdir, "ema_model.pt" ) From c1b7ccadff1bff3e37a22d65fac31fef9fdb7d98 Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 19 Nov 2024 02:07:57 +0000 Subject: [PATCH 08/13] validation for PEFT LoRA EMA weights --- helpers/training/validation.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/helpers/training/validation.py b/helpers/training/validation.py index 4cd9a097..72c55872 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -1525,20 +1525,35 @@ def _log_validations_to_trackers(self, validation_images): # Log all images in one call to prevent the global step from ticking tracker.log(gallery_images, step=StateTracker.get_global_step()) - def enable_ema_for_inference(self): + def _primary_model(self): + if self.args.controlnet: + return self.controlnet + if self.unet is not None: + return self.unet + if self.transformer is not None: + return self.transformer + + def enable_ema_for_inference(self, pipeline=None): if self.ema_enabled: logger.info("EMA already enabled. Not enabling EMA.") return if self.args.use_ema: logger.info("Enabling EMA.") self.ema_enabled = True - if self.args.model_type == "lora" and self.args.lora_type.lower() == "lycoris": - logger.info("Setting Lycoris multiplier to 1.0") - self.accelerator._lycoris_wrapped_network.set_multiplier(1.0) - logger.info("Storing Lycoris weights for later recovery.") - self.ema_model.store(self.accelerator._lycoris_wrapped_network.parameters()) - logger.info("Storing the EMA weights into the Lycoris adapter for inference.") - self.ema_model.copy_to(self.accelerator._lycoris_wrapped_network.parameters()) + if self.args.model_type == "lora": + if self.args.lora_type.lower() == "lycoris": + logger.info("Setting Lycoris multiplier to 1.0") + self.accelerator._lycoris_wrapped_network.set_multiplier(1.0) + logger.info("Storing Lycoris weights for later recovery.") + self.ema_model.store(self.accelerator._lycoris_wrapped_network.parameters()) + logger.info("Storing the EMA weights into the Lycoris adapter for inference.") + self.ema_model.copy_to(self.accelerator._lycoris_wrapped_network.parameters()) + elif self.args.lora_type.lower() == "standard": + self.trainable_parameters = [ + x for x in self._primary_model().parameters() if x.requires_grad + ] + self.ema_model.store(self.trainable_parameters) + self.ema_model.copy_to(self.trainable_parameters) else: logger.info("Storing EMA weights for later recovery.") self.ema_model.store(self.trainable_parameters) @@ -1567,6 +1582,7 @@ def disable_ema_for_inference(self): else: logger.info("Restoring trainable parameters.") self.ema_model.restore(self.trainable_parameters) + self.trainable_parameters = None if self.args.ema_device != "accelerator": logger.info("Moving EMA weights to CPU for storage.") self.ema_model.to(self.args.ema_device) From 9ecd59d1bc62a2673fcc4e343a24f8397cb3a4a3 Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 19 Nov 2024 02:13:05 +0000 Subject: [PATCH 09/13] fix model card info --- helpers/publishing/metadata.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py index 8f62cba6..8f53c565 100644 --- a/helpers/publishing/metadata.py +++ b/helpers/publishing/metadata.py @@ -566,11 +566,11 @@ def save_model_card( ## Inference {code_example(args=StateTracker.get_args(), repo_id=repo_id)} + +{ema_info(args=StateTracker.get_args())} """ logger.debug(f"YAML:\n{yaml_content}") logger.debug(f"Model Card:\n{model_card_content}") with open(os.path.join(repo_folder, "README.md"), "w", encoding="utf-8") as f: f.write(yaml_content + model_card_content) - -{ema_info(args=StateTracker.get_args())} \ No newline at end of file From 2118dccc96e62f993bc308a7bf9fd610716a9b87 Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 19 Nov 2024 03:39:25 +0000 Subject: [PATCH 10/13] ema: validation, saving and restoring for full model training --- helpers/publishing/metadata.py | 2 +- helpers/training/ema.py | 6 ++++-- helpers/training/save_hooks.py | 9 +++++++-- helpers/training/trainer.py | 6 +++--- helpers/training/validation.py | 35 ++++++++++++++++++++++++---------- 5 files changed, 40 insertions(+), 18 deletions(-) diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py index 8f53c565..a4e27500 100644 --- a/helpers/publishing/metadata.py +++ b/helpers/publishing/metadata.py @@ -404,7 +404,7 @@ def ddpm_schedule_info(args): f"training_scheduler_timestep_spacing={args.training_scheduler_timestep_spacing}" ) output_args.append( - f"validation_scheduler_timestep_spacing={args.validation_scheduler_timestep_spacing}" + f"inference_scheduler_timestep_spacing={args.inference_scheduler_timestep_spacing}" ) output_str = ( f" (extra parameters={output_args})" diff --git a/helpers/training/ema.py b/helpers/training/ema.py index fa8f4f0c..d2153796 100644 --- a/helpers/training/ema.py +++ b/helpers/training/ema.py @@ -201,7 +201,7 @@ def save_pretrained(self, path, max_shard_size: str = "10GB"): ) model = self.model_cls.from_config(self.model_config) - state_dict = self.state_dict() + state_dict = self.state_dict(exclude_params=True) state_dict.pop("shadow_params", None) model.register_to_config(**state_dict) @@ -374,7 +374,7 @@ def cuda(self, device=None): def cpu(self): return self.to(device="cpu") - def state_dict(self, destination=None, prefix="", keep_vars=False): + def state_dict(self, destination=None, prefix="", keep_vars=False, exclude_params: bool = False): r""" Returns a dictionary containing a whole state of the EMA model. """ @@ -387,6 +387,8 @@ def state_dict(self, destination=None, prefix="", keep_vars=False): "inv_gamma": self.inv_gamma, "power": self.power, } + if exclude_params: + return state_dict for idx, param in enumerate(self.shadow_params): state_dict[f"{prefix}shadow_params.{idx}"] = ( param if keep_vars else param.detach() diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py index 4b04590c..876d4200 100644 --- a/helpers/training/save_hooks.py +++ b/helpers/training/save_hooks.py @@ -331,7 +331,8 @@ def _save_full_model(self, models, weights, output_dir): temporary_dir = output_dir.replace("checkpoint", "temporary") os.makedirs(temporary_dir, exist_ok=True) - if self.args.use_ema: + if self.args.use_ema and self.accelerator.is_main_process: + # even with deepspeed, EMA should only save on the main process. ema_model_path = os.path.join( temporary_dir, self.ema_model_subdir, "ema_model.pt" ) @@ -340,7 +341,11 @@ def _save_full_model(self, models, weights, output_dir): self.ema_model.save_state_dict(ema_model_path) except Exception as e: logger.error(f"Error saving EMA model: {e}") - + logger.info(f"Saving EMA safetensors variant.") + self.ema_model.save_pretrained( + os.path.join(temporary_dir, self.ema_model_subdir), + max_shard_size="10GB", + ) if self.unet is not None: sub_dir = "unet" if self.transformer is not None: diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 786bc0df..a510027d 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -1394,7 +1394,7 @@ def init_validations(self): return model_evaluator = ModelEvaluator.from_config(args=self.config) self.validation = Validation( - trainable_parameters=self._get_trainable_parameters(), + trainable_parameters=self._get_trainable_parameters, accelerator=self.accelerator, unet=self.unet, transformer=self.transformer, @@ -1416,6 +1416,7 @@ def init_validations(self): vae=self.vae, controlnet=self.controlnet if self.config.controlnet else None, model_evaluator=model_evaluator, + is_deepspeed=self.config.use_deepspeed_optimizer, ) if not self.config.train_text_encoder and self.validation is not None: self.validation.clear_text_encoders() @@ -1533,7 +1534,7 @@ def init_resume_checkpoint(self, lr_scheduler): logger.debug(f"Training state inside checkpoint: {training_state_in_ckpt}") if hasattr(lr_scheduler, "last_step"): lr_scheduler.last_step = self.state["global_resume_step"] - logger.info(f"Resuming from global_step {self.state['global_resume_step']}).") + logger.info(f"Resuming from global_step {self.state['global_resume_step']}.") # Log the current state of each data backend. for _, backend in StateTracker.get_data_backends().items(): @@ -2648,7 +2649,6 @@ def train(self): ema_decay_value = "None (EMA not in use)" if self.config.use_ema: if self.ema_model is not None: - training_logger.debug("Stepping EMA forward") self.ema_model.step( parameters=self._get_trainable_parameters(), global_step=self.state["global_step"], diff --git a/helpers/training/validation.py b/helpers/training/validation.py index 72c55872..1f791d0b 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -2,6 +2,7 @@ import os import wandb import logging +import sys import numpy as np from tqdm import tqdm from helpers.training.wrappers import unwrap_model @@ -452,6 +453,14 @@ def __init__( and self.args.flow_matching_loss != "diffusion" ) or self.args.model_family == "flux" self.deepspeed = is_deepspeed + if is_deepspeed: + if args.use_ema: + if args.ema_validation != "none": + logger.error( + "EMA validation is not supported via DeepSpeed." + " Please use --ema_validation=none or disable DeepSpeed." + ) + sys.exit(1) self.inference_device = ( accelerator.device if not is_deepspeed @@ -1549,19 +1558,24 @@ def enable_ema_for_inference(self, pipeline=None): logger.info("Storing the EMA weights into the Lycoris adapter for inference.") self.ema_model.copy_to(self.accelerator._lycoris_wrapped_network.parameters()) elif self.args.lora_type.lower() == "standard": - self.trainable_parameters = [ + _trainable_parameters = [ x for x in self._primary_model().parameters() if x.requires_grad ] - self.ema_model.store(self.trainable_parameters) - self.ema_model.copy_to(self.trainable_parameters) + self.ema_model.store(_trainable_parameters) + self.ema_model.copy_to(_trainable_parameters) else: + # if self.args.ema_device != "accelerator": + # logger.info("Moving checkpoint to CPU for storage.") + # self._primary_model().to("cpu") logger.info("Storing EMA weights for later recovery.") - self.ema_model.store(self.trainable_parameters) + self.ema_model.store(self.trainable_parameters()) logger.info("Storing the EMA weights into the model for inference.") - self.ema_model.copy_to(self.trainable_parameters) - if self.args.ema_device != "accelerator": - logger.info("Moving EMA weights to GPU for inference.") - self.ema_model.to(self.inference_device) + self.ema_model.copy_to(self.trainable_parameters()) + # if self.args.ema_device != "accelerator": + # logger.info("Moving checkpoint to CPU for storage.") + # self._primary_model().to("cpu") + # logger.info("Moving EMA weights to GPU for inference.") + # self.ema_model.to(self.inference_device) else: logger.info( "Skipping EMA model setup for validation, as we are not using EMA." @@ -1581,11 +1595,12 @@ def disable_ema_for_inference(self): self.ema_model.restore(self.accelerator._lycoris_wrapped_network.parameters()) else: logger.info("Restoring trainable parameters.") - self.ema_model.restore(self.trainable_parameters) - self.trainable_parameters = None + self.ema_model.restore(self.trainable_parameters()) if self.args.ema_device != "accelerator": logger.info("Moving EMA weights to CPU for storage.") self.ema_model.to(self.args.ema_device) + self._primary_model().to(self.inference_device) + else: logger.info( "Skipping EMA model restoration for validation, as we are not using EMA." From 6790c258e705cc9764d1c284b4a5f700dfbc04c2 Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 19 Nov 2024 03:40:56 +0000 Subject: [PATCH 11/13] ema lora save for sdxl unet --- helpers/training/save_hooks.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py index 876d4200..dbb33b68 100644 --- a/helpers/training/save_hooks.py +++ b/helpers/training/save_hooks.py @@ -215,12 +215,20 @@ def _save_lora(self, models, weights, output_dir): ] self.ema_model.store(trainable_parameters) self.ema_model.copy_to(trainable_parameters) - self.pipeline_class.save_lora_weights( - os.path.join(output_dir, "ema"), - transformer_lora_layers=convert_state_dict_to_diffusers( - get_peft_model_state_dict(self._primary_model()) - ), - ) + if self.transformer is not None: + self.pipeline_class.save_lora_weights( + os.path.join(output_dir, "ema"), + transformer_lora_layers=convert_state_dict_to_diffusers( + get_peft_model_state_dict(self._primary_model()) + ), + ) + elif self.unet is not None: + self.pipeline_class.save_lora_weights( + os.path.join(output_dir, "ema"), + unet_lora_layers=convert_state_dict_to_diffusers( + get_peft_model_state_dict(self._primary_model()) + ), + ) self.ema_model.restore(trainable_parameters) for model in models: From 8dda33d5651c5cc5ca6f38708db369d21277b87a Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 19 Nov 2024 04:04:42 +0000 Subject: [PATCH 12/13] fix non-ema validations --- helpers/configuration/cmd_args.py | 2 +- helpers/training/validation.py | 9 ++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py index 223e52fb..7ea0dc61 100644 --- a/helpers/configuration/cmd_args.py +++ b/helpers/configuration/cmd_args.py @@ -1111,7 +1111,7 @@ def get_argument_parser(): parser.add_argument( "--use_ema", action="store_true", - help="Whether to use EMA (exponential moving average) model.", + help="Whether to use EMA (exponential moving average) model. Works with LoRA, Lycoris, and full training.", ) parser.add_argument( "--ema_device", diff --git a/helpers/training/validation.py b/helpers/training/validation.py index 1f791d0b..fe001cf2 100644 --- a/helpers/training/validation.py +++ b/helpers/training/validation.py @@ -1373,8 +1373,6 @@ def validate_prompt( # retrieve the default image result for stitching to controlnet inputs. ema_image_results = all_validation_type_results.get("ema") validation_image_results = all_validation_type_results.get("checkpoint", ema_image_results) - print(f"ema image results: {ema_image_results}") - print(f"validation_image_results: {validation_image_results}") original_validation_image_results = validation_image_results benchmark_image = None if self.args.controlnet: @@ -1390,7 +1388,6 @@ def validate_prompt( ) if benchmark_image is not None: for idx, validation_image in enumerate(validation_image_results): - print(f"stitching benchmark image {benchmark_image} to {validation_image}") validation_image_results[idx] = self.stitch_benchmark_image( validation_image_result=validation_image, benchmark_image=benchmark_image, @@ -1400,9 +1397,8 @@ def validate_prompt( original_validation_image_results ) stitched_validation_images[validation_shortname].extend(validation_image_results) - ema_validation_images[validation_shortname].extend(ema_image_results) - print(f"Generated {len(validation_image_results)} images, {len(ema_validation_images[validation_shortname])} EMA images, {len(stitched_validation_images[validation_shortname])} for {validation_shortname}") - + if self.args.use_ema: + ema_validation_images[validation_shortname].extend(ema_image_results) except Exception as e: import traceback @@ -1413,7 +1409,6 @@ def validate_prompt( continue if self.args.use_ema and self.args.ema_validation == "comparison" and benchmark_image is not None: for idx, validation_image in enumerate(stitched_validation_images[validation_shortname]): - print(f"idx={idx} stitching EMA image {ema_validation_images[validation_shortname][idx]} to {stitched_validation_images[validation_shortname][idx]}") stitched_validation_images[validation_shortname][idx] = self.stitch_benchmark_image( validation_image_result=ema_validation_images[validation_shortname][idx], benchmark_image=stitched_validation_images[validation_shortname][idx], From fc68c9eb60355b820cc523ab84b464e9f55cbcc0 Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 19 Nov 2024 04:07:02 +0000 Subject: [PATCH 13/13] update EMA docs --- README.md | 6 +++--- documentation/DEEPSPEED.md | 4 +--- documentation/DREAMBOOTH.md | 6 ++++++ 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index b332a696..bafd385c 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ For multi-node distributed training, [this guide](/documentation/DISTRIBUTED.md) - LoRA/LyCORIS training for PixArt, SDXL, SD3, and SD 2.x that uses less than 16G VRAM - DeepSpeed integration allowing for [training SDXL's full u-net on 12G of VRAM](/documentation/DEEPSPEED.md), albeit very slowly. - Quantised NF4/INT8/FP8 LoRA training, using low-precision base model to reduce VRAM consumption. -- Optional EMA (Exponential moving average) weight network to counteract model overfitting and improve training stability. **Note:** This does not apply to LoRA. +- Optional EMA (Exponential moving average) weight network to counteract model overfitting and improve training stability. - Train directly from an S3-compatible storage provider, eliminating the requirement for expensive local storage. (Tested with Cloudflare R2 and Wasabi S3) - For only SDXL and SD 1.x/2.x, full [ControlNet model training](/documentation/CONTROLNET.md) (not ControlLoRA or ControlLite) - Training [Mixture of Experts](/documentation/MIXTURE_OF_EXPERTS.md) for lightweight, high-quality diffusion models @@ -137,8 +137,8 @@ Flux prefers being trained with multiple large GPUs but a single 16G card should - A100-80G (EMA, large batches, LoRA @ insane batch sizes) - A6000-48G (EMA@768px, no EMA@1024px, LoRA @ high batch sizes) -- A100-40G (no EMA@1024px, no EMA@768px, EMA@512px, LoRA @ high batch sizes) -- 4090-24G (no EMA@1024px, batch size 1-4, LoRA @ medium-high batch sizes) +- A100-40G (EMA@1024px, EMA@768px, EMA@512px, LoRA @ high batch sizes) +- 4090-24G (EMA@1024px, batch size 1-4, LoRA @ medium-high batch sizes) - 4080-12G (LoRA @ low-medium batch sizes) ### Stable Diffusion 2.x, 768px diff --git a/documentation/DEEPSPEED.md b/documentation/DEEPSPEED.md index 2d272459..a29619e9 100644 --- a/documentation/DEEPSPEED.md +++ b/documentation/DEEPSPEED.md @@ -162,6 +162,4 @@ While EMA is a great way to smooth out gradients and improve generalisation abil EMA holds a shadow copy of the model parameters in memory, essentially doubling the footprint of the model. For SimpleTuner, EMA is not passed through the Accelerator module, which means it is not impacted by DeepSpeed. This means the memory savings that we saw with the base U-net, are not realised with the EMA model. -That said, any memory savings that applied to the base U-net could possibly allow EMA weights to load and operate effectively. - -Future work is planned to allow EMA to run on the CPU only. \ No newline at end of file +However, by default, the EMA model is kept on CPU. \ No newline at end of file diff --git a/documentation/DREAMBOOTH.md b/documentation/DREAMBOOTH.md index 46cf3892..e5e82cd8 100644 --- a/documentation/DREAMBOOTH.md +++ b/documentation/DREAMBOOTH.md @@ -222,6 +222,12 @@ Alternatively, one might use the real name of their subject, or a 'similar enoug After a number of training experiments, it seems as though a 'similar enough' celebrity is the best choice, especially if prompting the model for the person's real name ends up looking dissimilar. +# Exponential moving average (EMA) + +A second model can be trained in parallel to your checkpoint, nearly for free - only the resulting system memory (by default) is consumed, rather than more VRAM. + +Applying `use_ema=true` in your config file will enable this feature. + # CLIP score tracking If you wish to enable evaluations to score the model's performance, see [this document](/documentation/evaluation/CLIP_SCORES.md) for information on configuring and interpreting CLIP scores.