diff --git a/OPTIONS.md b/OPTIONS.md
index 29eb601b..67f1b1ec 100644
--- a/OPTIONS.md
+++ b/OPTIONS.md
@@ -316,6 +316,16 @@ A lot of settings are instead set through the [dataloader config](/documentation
 - **Why**: Enables integration with platforms like TensorBoard, wandb, or comet_ml for monitoring. Use multiple values separated by a comma to report to multiple trackers;
 - **Choices**: wandb, tensorboard, comet_ml
 
+# Environment configuration variables
+
+The above options apply, for the most part, to `config.json`, but some entries must be set inside `config.env` instead.
+
+- `TRAINING_NUM_PROCESSES` should be set to the number of GPUs in the system. For most use cases, this is enough to enable DistributedDataParallel (DDP) training.
+- `TRAINING_DYNAMO_BACKEND` defaults to `no`, but can be set to `inductor` for substantial speed improvements on NVIDIA hardware.
+- `SIMPLETUNER_LOG_LEVEL` defaults to `INFO`, but can be set to `DEBUG` to add more information to `debug.log` for issue reports.
+- `VENV_PATH` can be set to the location of your Python virtual environment, if it is not in the typical `.venv` location.
+- `ACCELERATE_EXTRA_ARGS` can be left unset, or contain extra arguments such as `--multi_gpu` or FSDP-specific flags.
+
 ---
 
 This is a basic overview meant to help you get started. For a complete list of options and more detailed explanations, please refer to the full specification:
diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md
index 94639341..4e4f61c9 100644
--- a/documentation/quickstart/FLUX.md
+++ b/documentation/quickstart/FLUX.md
@@ -141,6 +141,8 @@ There, you will possibly need to modify the following variables:
 - `optimizer` - Beginners are recommended to stick with adamw_bf16, though optimi-lion and optimi-stableadamw are also good choices.
 - `mixed_precision` - Beginners should keep this in `bf16`
 
+Multi-GPU users can refer to [this document](/OPTIONS.md#environment-configuration-variables) for information on configuring the number of GPUs to use.
+
 #### Validation prompts
 
 Inside `config/config.json` is the "primary validation prompt", which is typically the main instance_prompt you are training on for your single subject or style. Additionally, a JSON file may be created that contains extra prompts to run through during validations.
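As an illustration of the `config.env` variables documented above, a minimal two-GPU DDP configuration might look like the following sketch (the values are examples only, not recommendations):

```bash
# config.env: example values only; adjust to your hardware.
TRAINING_NUM_PROCESSES=2             # one process per GPU enables DDP
TRAINING_DYNAMO_BACKEND=inductor     # or keep the default "no" if torch.compile causes issues
SIMPLETUNER_LOG_LEVEL=INFO           # switch to DEBUG when collecting logs for an issue report
ACCELERATE_EXTRA_ARGS="--multi_gpu"  # e.g. --multi_gpu or FSDP-specific flags
# VENV_PATH=/path/to/custom/venv     # only needed when the venv is not at .venv
```

Note that the safety check further below warns when the `inductor` backend is combined with `fp8-quanto` or activation quantisation, so those options should be chosen together with care.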
diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py
index ce812d65..4be1ffce 100644
--- a/helpers/configuration/cmd_args.py
+++ b/helpers/configuration/cmd_args.py
@@ -1553,6 +1553,13 @@ def get_argument_parser():
             " Using 'fp8-quanto' will require Quanto for quantisation (Apple Silicon, NVIDIA, AMD)."
         ),
     )
+    parser.add_argument(
+        "--quantize_activations",
+        action="store_true",
+        help=(
+            "(EXPERIMENTAL) This option is currently unsupported, and exists solely for development purposes."
+        ),
+    )
     parser.add_argument(
         "--base_model_default_dtype",
         type=str,
diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py
index bd21e7ab..cc7ec57e 100644
--- a/helpers/training/default_settings/safety_check.py
+++ b/helpers/training/default_settings/safety_check.py
@@ -23,9 +23,9 @@ def safety_check(args, accelerator):
             " Use LORA_TYPE (--lora_type) lycoris for quantised multi-GPU training of LoKr models in FP8."
         )
         args.base_model_precision = "int8-quanto"
-        # sys.exit(1)
+
     if (
-        args.base_model_precision in ["fp8-quanto", "int4-quanto"]
+        (args.base_model_precision in ["fp8-quanto", "int4-quanto"] or (args.base_model_precision != "no_change" and args.quantize_activations))
         and accelerator.state.dynamo_plugin.backend.lower() == "inductor"
     ):
         logger.warning(
diff --git a/helpers/training/quantisation/__init__.py b/helpers/training/quantisation/__init__.py
index 5089bbcd..8d332399 100644
--- a/helpers/training/quantisation/__init__.py
+++ b/helpers/training/quantisation/__init__.py
@@ -9,17 +9,44 @@
 logger.setLevel(logging.ERROR)
 
 
-def _quanto_model(model, model_precision, base_model_precision=None):
+def _quanto_type_map(model_precision: str):
+    if model_precision == "no_change":
+        return None
+    from optimum.quanto import (
+        qfloat8,
+        qfloat8_e4m3fnuz,
+        qint8,
+        qint4,
+        qint2,
+    )
+    if model_precision == "int2-quanto":
+        quant_level = qint2
+    elif model_precision == "int4-quanto":
+        quant_level = qint4
+    elif model_precision == "int8-quanto":
+        quant_level = qint8
+    elif model_precision == "fp8-quanto" or model_precision == "fp8uz-quanto":
+        if torch.backends.mps.is_available():
+            logger.warning(
+                "MPS doesn't support dtype float8, you must select another precision level such as bf16, int2, int4, or int8."
+            )
+
+            return None
+        if model_precision == "fp8-quanto":
+            quant_level = qfloat8
+        elif model_precision == "fp8uz-quanto":
+            quant_level = qfloat8_e4m3fnuz
+    else:
+        raise ValueError(f"Invalid quantisation level: {model_precision}")
+
+    return quant_level
+
+def _quanto_model(model, model_precision, base_model_precision=None, quantize_activations: bool = False):
     try:
         from helpers.training.quantisation import quanto_workarounds
         from optimum.quanto import (
             freeze,
             quantize,
-            qfloat8,
-            qfloat8_e4m3fnuz,
-            qint8,
-            qint4,
-            qint2,
             QTensor,
         )
     except ImportError as e:
@@ -36,27 +63,22 @@ def _quanto_model(model, model_precision, base_model_precision=None):
         return model
     logger.info(f"Quantising {model.__class__.__name__}. Using {model_precision}.")
-    if model_precision == "int2-quanto":
-        weight_quant = qint2
-    elif model_precision == "int4-quanto":
-        weight_quant = qint4
-    elif model_precision == "int8-quanto":
-        weight_quant = qint8
-    elif model_precision == "fp8-quanto" or model_precision == "fp8uz-quanto":
-        if torch.backends.mps.is_available():
-            logger.warning(
-                "MPS doesn't support dtype float8, you must select another precision level such as bf16, int2, int8, or int8."
-            )
-
-            return model
-        if model_precision == "fp8-quanto":
-            weight_quant = qfloat8
-        elif model_precision == "fp8uz-quanto":
-            weight_quant = qfloat8_e4m3fnuz
+    weight_quant = _quanto_type_map(model_precision)
+    extra_quanto_args = {}
+    if quantize_activations:
+        logger.info("Freezing model weights and activations")
+        extra_quanto_args["activations"] = weight_quant
+        extra_quanto_args["exclude"] = [
+            "*.norm",
+            "*.norm1",
+            "*.norm2",
+            "*.norm2_context",
+            "proj_out",
+        ]
     else:
-        raise ValueError(f"Invalid quantisation level: {base_model_precision}")
-    quantize(model, weights=weight_quant)
-    logger.info("Freezing model.")
+        logger.info("Freezing model weights only")
+
+    quantize(model, weights=weight_quant, **extra_quanto_args)
     freeze(model)
     return model
 
 
@@ -73,7 +95,7 @@ def _torchao_filter_fn(mod: torch.nn.Module, fqn: str):
     return True
 
 
-def _torchao_model(model, model_precision, base_model_precision=None):
+def _torchao_model(model, model_precision, base_model_precision=None, quantize_activations: bool = False):
     if model_precision is None:
         model_precision = base_model_precision
     if model is None:
@@ -95,6 +117,8 @@
             f"To use torchao, please install the torchao library: `pip install torchao`: {e}"
         )
     logger.info(f"Quantising {model.__class__.__name__}. Using {model_precision}.")
+    if quantize_activations:
+        logger.warning("Activation quantisation is not used in TorchAO. This will be ignored.")
 
     if model_precision == "int8-torchao":
         quantize_(
@@ -131,23 +155,23 @@
         logger.info("Loading TorchAO. This may take a few minutes.")
         quant_fn = _torchao_model
     if transformer is not None:
-        transformer = quant_fn(transformer, args.base_model_precision)
+        transformer = quant_fn(transformer, model_precision=args.base_model_precision, quantize_activations=args.quantize_activations)
     if unet is not None:
-        unet = quant_fn(unet, args.base_model_precision)
+        unet = quant_fn(unet, model_precision=args.base_model_precision, quantize_activations=args.quantize_activations)
     if controlnet is not None:
-        controlnet = quant_fn(controlnet, args.base_model_precision)
+        controlnet = quant_fn(controlnet, model_precision=args.base_model_precision, quantize_activations=args.quantize_activations)
     if text_encoder_1 is not None:
         text_encoder_1 = quant_fn(
-            text_encoder_1, args.text_encoder_1_precision, args.base_model_precision
+            text_encoder_1, model_precision=args.text_encoder_1_precision, base_model_precision=args.base_model_precision
         )
     if text_encoder_2 is not None:
         text_encoder_2 = quant_fn(
-            text_encoder_2, args.text_encoder_2_precision, args.base_model_precision
+            text_encoder_2, model_precision=args.text_encoder_2_precision, base_model_precision=args.base_model_precision
         )
     if text_encoder_3 is not None:
         text_encoder_3 = quant_fn(
-            text_encoder_3, args.text_encoder_3_precision, args.base_model_precision
+            text_encoder_3, model_precision=args.text_encoder_3_precision, base_model_precision=args.base_model_precision
         )
 
     return unet, transformer, text_encoder_1, text_encoder_2, text_encoder_3, controlnet
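For context on what the new `quantize_activations` path hands to optimum-quanto: `quantize()` accepts `weights`, `activations`, and `exclude` arguments, and `freeze()` then materialises the quantised weights. Below is a minimal, self-contained sketch of the same call pattern; the `TinyBlock` module and its exclusion names are illustrative stand-ins, not SimpleTuner code.

```python
import torch
from optimum.quanto import freeze, qint8, quantize


class TinyBlock(torch.nn.Module):
    """A stand-in for a transformer block, used only to show the call pattern."""

    def __init__(self):
        super().__init__()
        self.proj_in = torch.nn.Linear(64, 64)
        self.norm = torch.nn.LayerNorm(64)
        self.proj_out = torch.nn.Linear(64, 8)

    def forward(self, x):
        return self.proj_out(self.norm(self.proj_in(x)))


model = TinyBlock()

# Weights-only quantisation (the default path) would simply be:
#   quantize(model, weights=qint8)
# With --quantize_activations, activations use the same dtype, and the norm and
# output-projection modules are excluded, mirroring the exclude list above.
quantize(
    model,
    weights=qint8,
    activations=qint8,
    exclude=["norm", "proj_out"],
)
freeze(model)  # replace the float weights with their quantised counterparts in place
```

optimum-quanto also provides a `Calibration` context manager for recording activation ranges before freezing; it is not used in this patch.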
diff --git a/helpers/training/quantisation/quanto_workarounds.py b/helpers/training/quantisation/quanto_workarounds.py
index 86d3741e..d48bd1b2 100644
--- a/helpers/training/quantisation/quanto_workarounds.py
+++ b/helpers/training/quantisation/quanto_workarounds.py
@@ -37,6 +37,7 @@ def fp8_marlin_gemm_wrapper(
     # Monkey-patch the operator
     torch.ops.quanto.gemm_f16f8_marlin = fp8_marlin_gemm_wrapper
 
+
 class TinyGemmQBitsLinearFunction(
     optimum.quanto.tensor.function.QuantizedLinearFunction
 ):