Commit
Merge pull request #1014 from bghira/feature/quanto-activations
quanto: activations sledgehammer
bghira authored Oct 1, 2024
2 parents 43a9e37 + 262bfbc commit fec2827
Showing 6 changed files with 79 additions and 35 deletions.
10 changes: 10 additions & 0 deletions OPTIONS.md
@@ -316,6 +316,16 @@ A lot of settings are instead set through the [dataloader config](/documentation
- **Why**: Enables integration with platforms like TensorBoard, wandb, or comet_ml for monitoring. Use multiple values separated by a comma to report to multiple trackers;
- **Choices**: wandb, tensorboard, comet_ml

# Environment configuration variables

The above options apply, for the most part, to `config.json`, but some entries must be set inside `config.env` instead. The most commonly used variables are listed below, followed by a brief example.

- `TRAINING_NUM_PROCESSES` should be set to the number of GPUs in the system. For most use-cases, this is enough to enable DistributedDataParallel (DDP) training.
- `TRAINING_DYNAMO_BACKEND` defaults to `no`, but can be set to `inductor` for substantial speed improvements on NVIDIA hardware.
- `SIMPLETUNER_LOG_LEVEL` defaults to `INFO`, but can be set to `DEBUG` to write additional detail into `debug.log` for issue reports.
- `VENV_PATH` can be set to the location of your Python virtual env, if it is not in the typical `.venv` location.
- `ACCELERATE_EXTRA_ARGS` can be left unset, or it can contain extra arguments such as `--multi_gpu` or FSDP-specific flags.
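
As a point of reference, a minimal `config.env` for a two-GPU system might look like the following sketch; the values shown are illustrative, not recommendations:

```bash
# config.env - illustrative example for a 2x GPU system
TRAINING_NUM_PROCESSES=2
TRAINING_DYNAMO_BACKEND=no
SIMPLETUNER_LOG_LEVEL=INFO
VENV_PATH=.venv
ACCELERATE_EXTRA_ARGS="--multi_gpu"
```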

---

This is a basic overview meant to help you get started. For a complete list of options and more detailed explanations, please refer to the full specification:
2 changes: 2 additions & 0 deletions documentation/quickstart/FLUX.md
@@ -141,6 +141,8 @@ There, you will possibly need to modify the following variables:
- `optimizer` - Beginners are recommended to stick with adamw_bf16, though optimi-lion and optimi-stableadamw are also good choices.
- `mixed_precision` - Beginners should keep this in `bf16`

Multi-GPU users can reference [this document](/OPTIONS.md#environment-configuration-variables) for information on configuring the number of GPUs to use.

#### Validation prompts

Inside `config/config.json` is the "primary validation prompt", which is typically the main instance_prompt you are training on for your single subject or style. Additionally, a JSON file may be created that contains extra prompts to run through during validations; one possible layout is sketched below.
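
A sketch of what such a prompt library could look like - the key names and prompt text here are assumptions for illustration, so check the validation documentation for the exact format your version expects:

```json
{
  "portrait_shot": "a close-up portrait photograph of the subject, natural window lighting",
  "style_landscape": "a wide landscape painting in the trained style at golden hour"
}
```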
7 changes: 7 additions & 0 deletions helpers/configuration/cmd_args.py
@@ -1553,6 +1553,13 @@ def get_argument_parser():
" Using 'fp8-quanto' will require Quanto for quantisation (Apple Silicon, NVIDIA, AMD)."
),
)
parser.add_argument(
"--quantize_activations",
action="store_true",
help=(
"(EXPERIMENTAL) This option is currently unsupported, and exists solely for development purposes."
),
)
parser.add_argument(
"--base_model_default_dtype",
type=str,
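Because the new option is a plain argparse `store_true` switch, developers experimenting with it would simply append it to the trainer invocation; the entry-point name below is an assumption, and the flag is explicitly marked as unsupported:

```bash
# Hypothetical invocation - exercises the experimental activation-quantisation path
python train.py --base_model_precision=int8-quanto --quantize_activations
```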
4 changes: 2 additions & 2 deletions helpers/training/default_settings/safety_check.py
@@ -23,9 +23,9 @@ def safety_check(args, accelerator):
" Use LORA_TYPE (--lora_type) lycoris for quantised multi-GPU training of LoKr models in FP8."
)
args.base_model_precision = "int8-quanto"
# sys.exit(1)

if (
args.base_model_precision in ["fp8-quanto", "int4-quanto"]
(args.base_model_precision in ["fp8-quanto", "int4-quanto"] or (args.base_model_precision != "no_change" and args.quantize_activations))
and accelerator.state.dynamo_plugin.backend.lower() == "inductor"
):
logger.warning(
90 changes: 57 additions & 33 deletions helpers/training/quantisation/__init__.py
@@ -9,17 +9,44 @@
logger.setLevel(logging.ERROR)


def _quanto_model(model, model_precision, base_model_precision=None):
def _quanto_type_map(model_precision: str):
if model_precision == "no_change":
return None
from optimum.quanto import (
qfloat8,
qfloat8_e4m3fnuz,
qint8,
qint4,
qint2,
)
if model_precision == "int2-quanto":
quant_level = qint2
elif model_precision == "int4-quanto":
quant_level = qint4
elif model_precision == "int8-quanto":
quant_level = qint8
elif model_precision == "fp8-quanto" or model_precision == "fp8uz-quanto":
if torch.backends.mps.is_available():
logger.warning(
"MPS doesn't support dtype float8, you must select another precision level such as bf16, int2, int8, or int8."
)

return model
if model_precision == "fp8-quanto":
quant_level = qfloat8
elif model_precision == "fp8uz-quanto":
quant_level = qfloat8_e4m3fnuz
else:
raise ValueError(f"Invalid quantisation level: {model_precision}")

return quant_level

def _quanto_model(model, model_precision, base_model_precision=None, quantize_activations: bool = False):
try:
from helpers.training.quantisation import quanto_workarounds
from optimum.quanto import (
freeze,
quantize,
qfloat8,
qfloat8_e4m3fnuz,
qint8,
qint4,
qint2,
QTensor,
)
except ImportError as e:
@@ -36,27 +63,22 @@ def _quanto_model(model, model_precision, base_model_precision=None):
return model

logger.info(f"Quantising {model.__class__.__name__}. Using {model_precision}.")
if model_precision == "int2-quanto":
weight_quant = qint2
elif model_precision == "int4-quanto":
weight_quant = qint4
elif model_precision == "int8-quanto":
weight_quant = qint8
elif model_precision == "fp8-quanto" or model_precision == "fp8uz-quanto":
if torch.backends.mps.is_available():
logger.warning(
"MPS doesn't support dtype float8, you must select another precision level such as bf16, int2, int8, or int8."
)

return model
if model_precision == "fp8-quanto":
weight_quant = qfloat8
elif model_precision == "fp8uz-quanto":
weight_quant = qfloat8_e4m3fnuz
weight_quant = _quanto_type_map(model_precision)
extra_quanto_args = {}
if quantize_activations:
logger.info("Freezing model weights and activations")
extra_quanto_args["activations"] = weight_quant
extra_quanto_args["exclude"] = [
"*.norm",
"*.norm1",
"*.norm2",
"*.norm2_context",
"proj_out",
]
else:
raise ValueError(f"Invalid quantisation level: {base_model_precision}")
quantize(model, weights=weight_quant)
logger.info("Freezing model.")
logger.info("Freezing model weights only")

quantize(model, weights=weight_quant, **extra_quanto_args)
freeze(model)

return model
@@ -73,7 +95,7 @@ def _torchao_filter_fn(mod: torch.nn.Module, fqn: str):
return True


def _torchao_model(model, model_precision, base_model_precision=None):
def _torchao_model(model, model_precision, base_model_precision=None, quantize_activations:bool=False):
if model_precision is None:
model_precision = base_model_precision
if model is None:
@@ -95,6 +117,8 @@ def _torchao_model(model, model_precision, base_model_precision=None):
f"To use torchao, please install the torchao library: `pip install torchao`: {e}"
)
logger.info(f"Quantising {model.__class__.__name__}. Using {model_precision}.")
if quantize_activations:
logger.warning("Activation quantisation is not used in TorchAO. This will be ignored.")

if model_precision == "int8-torchao":
quantize_(
@@ -131,23 +155,23 @@ def quantise_model(
logger.info("Loading TorchAO. This may take a few minutes.")
quant_fn = _torchao_model
if transformer is not None:
transformer = quant_fn(transformer, args.base_model_precision)
transformer = quant_fn(transformer, model_precision=args.base_model_precision, quantize_activations=args.quantize_activations)
if unet is not None:
unet = quant_fn(unet, args.base_model_precision)
unet = quant_fn(unet, model_precision=args.base_model_precision, quantize_activations=args.quantize_activations)
if controlnet is not None:
controlnet = quant_fn(controlnet, args.base_model_precision)
controlnet = quant_fn(controlnet, model_precision=args.base_model_precision, quantize_activations=args.quantize_activations)

if text_encoder_1 is not None:
text_encoder_1 = quant_fn(
text_encoder_1, args.text_encoder_1_precision, args.base_model_precision
text_encoder_1, model_precision=args.text_encoder_1_precision, base_model_precision=args.base_model_precision
)
if text_encoder_2 is not None:
text_encoder_2 = quant_fn(
text_encoder_2, args.text_encoder_2_precision, args.base_model_precision
text_encoder_2, model_precision=args.text_encoder_2_precision, base_model_precision=args.base_model_precision
)
if text_encoder_3 is not None:
text_encoder_3 = quant_fn(
text_encoder_3, args.text_encoder_3_precision, args.base_model_precision
text_encoder_3, model_precision=args.text_encoder_3_precision, base_model_precision=args.base_model_precision
)

return unet, transformer, text_encoder_1, text_encoder_2, text_encoder_3, controlnet
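
For readers unfamiliar with how optimum.quanto handles activations, the following standalone sketch shows the general pattern this file now wires up - quantise weights and activations while excluding fragile layers, calibrate, then freeze. The toy model, dtype choice, exclusion pattern, and calibration loop are illustrative assumptions, not code from this commit:

```python
import torch
from optimum.quanto import Calibration, freeze, qint8, quantize

# Toy stand-in for the transformer/unet being quantised (illustrative only).
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.LayerNorm(64),
    torch.nn.Linear(64, 8),
)

# Quantise both weights and activations; the norm layer is excluded by
# module-name pattern, mirroring the `exclude` list used in this commit.
quantize(model, weights=qint8, activations=qint8, exclude=["1"])

# Activation quantisation needs a calibration pass to pick activation scales.
with Calibration():
    for _ in range(8):
        model(torch.randn(4, 64))

# Freezing replaces the float weights with their quantised counterparts.
freeze(model)
```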
1 change: 1 addition & 0 deletions helpers/training/quantisation/quanto_workarounds.py
@@ -37,6 +37,7 @@ def fp8_marlin_gemm_wrapper(

# Monkey-patch the operator
torch.ops.quanto.gemm_f16f8_marlin = fp8_marlin_gemm_wrapper

class TinyGemmQBitsLinearFunction(
optimum.quanto.tensor.function.QuantizedLinearFunction
):
