
What does "weights_scaling_factor_2" mean in safetensor results of awq_w4a8 #2561

Closed
gujiewen opened this issue Dec 11, 2024 · 3 comments


gujiewen commented Dec 11, 2024

I followed this step to quantize a Qwen2 model, and the resulting safetensors checkpoint looks like this:

[image: list of safetensors keys and shapes]

What do `prequant_scaling_factor`, `activation_scaling_factor`, `weights_scaling_factor`, and `weights_scaling_factor_2` mean, and how are they used in the w4a8 GEMM?

Barry-Delaney (Collaborator) commented:

For a linear layer with GEMM shape [M, N, K], we need these components in the TRT-LLM layer:

| Name | Dtype | Shape | Layout |
| --- | --- | --- | --- |
| `{LAYER_NAME}.weight` | float16 | [K, N / 4] | Interleaved and packed INT4 |
| `{LAYER_NAME}.weight_scaling_factor` | float16 | [K / group_size, N] | Row-major |
| `{LAYER_NAME}.activation_scaling_factor` | float16 | [K] | Row-major |
| `{LAYER_NAME}.alpha` | float32 | [1] | - |

The calculation process is:
output = FP16(FP8(act * activation_scaling_factor) * FP8(weight * weight_scaling_factor) * alpha)

However, the checkpoint will contain more parameters; here is how they are converted when building the engine.
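For intuition, here is a minimal NumPy sketch of that computation. It is only a toy emulation under assumptions: FP8 is faked by clipping to the E4M3 range and rounding through float16, and the INT4 weight is shown already dequantized to a [K, N] matrix; the real TRT-LLM kernel fuses all of this.

```python
import numpy as np

def fake_fp8(x):
    # Crude stand-in for FP8 (E4M3): clip to the representable range and
    # round through float16 to mimic precision loss.
    return np.clip(x, -448.0, 448.0).astype(np.float16).astype(np.float32)

M, N, K, group_size = 4, 8, 16, 8

act = np.random.randn(M, K).astype(np.float32)                # FP16 activations (toy)
weight = np.random.randn(K, N).astype(np.float32)             # dequantized INT4 weight (toy)
act_sf = np.random.rand(K).astype(np.float32)                 # activation_scaling_factor, [K]
w_sf = np.random.rand(K // group_size, N).astype(np.float32)  # weight_scaling_factor, [K / group_size, N]
alpha = np.float32(0.05)                                      # per-tensor de-quantization scale

# Broadcast the per-group weight scale up to the full [K, N] weight shape.
w_sf_full = np.repeat(w_sf, group_size, axis=0)

# output = FP16(FP8(act * activation_scaling_factor) @ FP8(weight * weight_scaling_factor) * alpha)
output = (fake_fp8(act * act_sf) @ fake_fp8(weight * w_sf_full) * alpha).astype(np.float16)
print(output.shape)  # (M, N)
```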


gujiewen commented Dec 18, 2024

> For a linear layer with GEMM shape [M, N, K], we need these components in the TRT-LLM layer:
>
> | Name | Dtype | Shape | Layout |
> | --- | --- | --- | --- |
> | `{LAYER_NAME}.weight` | float16 | [K, N / 4] | Interleaved and packed INT4 |
> | `{LAYER_NAME}.weight_scaling_factor` | float16 | [K / group_size, N] | Row-major |
> | `{LAYER_NAME}.activation_scaling_factor` | float16 | [K] | Row-major |
> | `{LAYER_NAME}.alpha` | float32 | [1] | - |
>
> The calculation process is: output = FP16(FP8(act * activation_scaling_factor) * FP8(weight * weight_scaling_factor) * alpha)
>
> However, the checkpoint will contain more parameters; here is how they are converted when building the engine.

Thanks for your reply. However, in w4a8_awq I found that prequant_scaling_factor has shape [K]. According to the source code in modeling_utils.py:

```python
if quant_algo == QuantAlgo.W4A8_AWQ:
    for name in list(weights):
        if name.endswith('weights_scaling_factor'):
            # Pop the two per-tensor FP8 scales from the checkpoint.
            activation_scaling_factor = weights.pop(
                name.replace('weights_scaling_factor',
                             'activation_scaling_factor'))
            weights_scaling_factor_2 = weights.pop(
                name.replace('weights_scaling_factor',
                             'weights_scaling_factor_2'))
            # Fold weights_scaling_factor_2 into the per-group weight scale.
            weights[name] /= weights_scaling_factor_2
            weights[name] = weights[name].to(torch.float16).view(
                str_dtype_to_torch(model_config.dtype))
            # Fold activation_scaling_factor into the pre-quant (smoothing) scale.
            weights[name.replace(
                'weights_scaling_factor',
                'prequant_scaling_factor')] /= activation_scaling_factor
            # alpha de-quantizes the FP8 GEMM output back to FP16.
            weights[name.replace(
                'weights_scaling_factor', 'alpha'
            )] = activation_scaling_factor * weights_scaling_factor_2
```

so alpha seems to be computed as activation_scaling_factor * weights_scaling_factor_2.

So the calculation process of w4a8 is:

output = FP16(FP8(act * prequant_scaling_factor / activation_scaling_factor) * FP8(weight * weights_scaling_factor / weights_scaling_factor_2) * activation_scaling_factor * weights_scaling_factor_2)

If we set

activation_scaling_factor' = prequant_scaling_factor / activation_scaling_factor

and

weights_scaling_factor' = weights_scaling_factor / weights_scaling_factor_2

the formula becomes

output = FP16(FP8(act * activation_scaling_factor') * FP8(weight * weights_scaling_factor') * alpha)

which is your form. Am I right?
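A quick toy check of this equivalence (hypothetical shapes, with weights_scaling_factor simplified to per-output-channel and a best-effort FP8 cast; not TRT-LLM code):

```python
import torch

torch.manual_seed(0)
M, N, K = 4, 8, 16

def fp8(x):
    # Round-trip through FP8 E4M3 if this torch build has it; otherwise pass through.
    if hasattr(torch, "float8_e4m3fn"):
        return x.to(torch.float8_e4m3fn).to(torch.float32)
    return x

act = torch.randn(M, K)
weight = torch.randn(K, N)                    # conceptually the dequantized INT4 weight
prequant_scaling_factor = torch.rand(K)       # AWQ smoothing scale, [K]
activation_scaling_factor = torch.rand(())    # per-tensor FP8 scale
weights_scaling_factor = torch.rand(N)        # simplified to per-channel for this check
weights_scaling_factor_2 = torch.rand(())     # per-tensor FP8 scale

# Explicit form, using the raw checkpoint scales.
out_explicit = (
    fp8(act * prequant_scaling_factor / activation_scaling_factor)
    @ fp8(weight * weights_scaling_factor / weights_scaling_factor_2)
) * (activation_scaling_factor * weights_scaling_factor_2)

# Folded form, as produced by the conversion in modeling_utils.py.
prequant_folded = prequant_scaling_factor / activation_scaling_factor
weight_sf_folded = weights_scaling_factor / weights_scaling_factor_2
alpha = activation_scaling_factor * weights_scaling_factor_2

out_folded = (fp8(act * prequant_folded) @ fp8(weight * weight_sf_folded)) * alpha

print(torch.allclose(out_explicit, out_folded))  # expected: True (up to float rounding)
```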

Barry-Delaney (Collaborator) commented:

Exactly.
For a clearer understanding, you can think of W4A8_AWQ as W4A16_AWQ + FP8.
In addition to the components of W4A16_AWQ, i.e. prequant_scaling_factor and weight_scaling_factor, FP8 provides 2 more per-tensor scaling factors, activation_scaling_factor and weights_scaling_factor_2. In order to have them combined in one GEMM, we have:

  • Folded the per-tensor activation_scaling_factor into prequant_scaling_factor
  • Folded the per-tensor weights_scaling_factor_2 into weight_scaling_factor
  • Exposed alpha (= activation_scaling_factor * weights_scaling_factor_2) as a layer parameter for de-quantization
