Is W4A(FP)8 quant not supported with bf16 datatype? #1843

Closed
2 of 4 tasks
wxsms opened this issue Jun 26, 2024 · 6 comments

@wxsms

wxsms commented Jun 26, 2024

System Info

Ubuntu, with Ada GPUs. TensorRT-LLM version: 0.11.0.dev2024061800

Who can help?

@Tracin

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Use examples/quantization/quantize.py to quantize a model like this (I am using Llama):

python3 ./quantization/quantize.py \
        --model_dir /mnt/models/source \
        --dtype bfloat16 \
        --qformat w4a8_awq \
        --output_dir /tmp/checkpoint \
        --calib_tp_size 4 \
        --tp_size 1

Expected behavior

The quantization should succeed.

Actual behavior

It fails with the error: FP8 is unsupported on with BF16 scales and zero-points!

Additional notes

I noticed that tensorrt_llm/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp contains the following snippet:

#if defined(ENABLE_BF16)
    else if (mType == nvinfer1::DataType::kBF16)
    {
        if (quant_algo & FP8_ALPHA)
        {
            // FP8 requires at least sm89 devices
            if (mArch < 89)
            {
                TLLM_THROW("W4A(fp)8 kernel is unsupported on pre-Ada (sm<89) architectures!");
            }
            TLLM_THROW("FP8 is unsupported on with BF16 scales and zero-points!");
        }
        else
        {
            if (quant_algo & ZERO)
            {
                // has zeros
                m_weightOnlyGroupwiseGemmRunner
                    = std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16,
                        cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>>();
            }
            else
            {
                // no zeros
                m_weightOnlyGroupwiseGemmRunner
                    = std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16,
                        cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>>();
            }
        }
        mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
            mArch, tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise);
        mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise;
    }
#endif

I am not sure, but is this a mistake? The error message mentions zero-points, yet the code throws without checking the ZERO condition (which, I think, is only checked in the next block).
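
To make the observation concrete, here is a minimal Python model of the branch order in the quoted BF16 block. It is only a sketch: the flag values and the names `select_bf16_runner` and `try_select` are hypothetical, chosen for illustration, and are not the plugin's real constants or API.

```python
# Hypothetical bit flags; the real values are defined in the TensorRT-LLM
# plugin and are not shown in the quoted snippet.
ZERO = 1 << 0
FP8_ALPHA = 1 << 1


def select_bf16_runner(quant_algo: int, arch: int) -> str:
    """Mirror the branch order of the quoted BF16 block (illustrative only)."""
    if quant_algo & FP8_ALPHA:
        if arch < 89:
            raise RuntimeError(
                "W4A(fp)8 kernel is unsupported on pre-Ada (sm<89) architectures!")
        # Raised for every FP8 request, whether or not ZERO is set.
        raise RuntimeError("FP8 is unsupported on with BF16 scales and zero-points!")
    # ZERO is only inspected on the non-FP8 path.
    if quant_algo & ZERO:
        return "FINEGRAINED_SCALE_AND_ZEROS"
    return "FINEGRAINED_SCALE_ONLY"


def try_select(quant_algo: int, arch: int) -> str:
    """Return the selected runner name, or the error text if selection throws."""
    try:
        return select_bf16_runner(quant_algo, arch)
    except RuntimeError as err:
        return f"error: {err}"
```

In this model, `try_select(FP8_ALPHA, 89)` and `try_select(FP8_ALPHA | ZERO, 89)` return the same error string, so the ZERO flag has no influence on the FP8 path. That matches the reading that the message is about BF16 with FP8 in general rather than about zero-points specifically.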

@wxsms wxsms added the bug Something isn't working label Jun 26, 2024
@nv-guomingz nv-guomingz added the question Further information is requested label Jun 26, 2024
@Barry-Delaney
Collaborator

@wxsms thanks for the feedback. w4a8_awq with the BF16 data type is not supported yet; we will add it in a future update.

@nv-guomingz
Collaborator

Hi @wxsms could we close this ticket now?

@wxsms
Author

wxsms commented Jul 3, 2024

> Hi @wxsms could we close this ticket now?

That's okay. We could also close this issue once the feature is fully supported, but you may close it whenever you like. Thanks.

@nv-guomingz
Collaborator

Thanks @wxsms. Please feel free to reopen it if needed.

@youki-sada

@Barry-Delaney Do you have any updates on this? It still seems to be unsupported in v0.15.0. We are waiting for this feature, since bfloat16 is crucial for Gemma 2 27B.

@youki-sada

It looks like this is now supported as of v0.16.0. Thank you.
https://github.com/NVIDIA/TensorRT-LLM/releases

Added W4A8 quantization support to BF16 models on Ada (SM89).
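
For anyone scripting around this, a small guard along the following lines could decide whether to attempt `--dtype bfloat16` together with `--qformat w4a8_awq`. It is a sketch under two assumptions: that the release note above applies to v0.16.0 and later, and that support is only stated for SM89. The function names are hypothetical and the check is not part of TensorRT-LLM.

```python
import re


def parse_release(version: str) -> tuple:
    """Extract (major, minor, patch) from strings like '0.11.0.dev2024061800'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def bf16_w4a8_expected(version: str, sm: int) -> bool:
    """Assumption: support starts at 0.16.0 and is only stated for SM89.

    Dev builds such as '0.16.0.devYYYYMMDD' parse as (0, 16, 0) here and
    may predate the feature, so verify those builds separately.
    """
    return parse_release(version) >= (0, 16, 0) and sm == 89
```

With the version from the original report, `bf16_w4a8_expected("0.11.0.dev2024061800", 89)` evaluates to `False`, which is consistent with the error reported there.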
