T5 model, large difference in results when remove_input_padding is enabled #1999
Comments
After converting the T5 model with remove_input_padding set to false and a maximum batch size of 8, running inference with a batch size of 1, 4, or 8 produces slightly different results in float32, and the differences are even more pronounced in float16. In theory, inference results should be consistent across batch sizes. Is this a normal phenomenon?
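As a hedged illustration of the consistency check being described (a minimal sketch; the model name, prompts, and greedy decoding settings below are assumptions, not taken from this issue), the same inputs can be run through the HF reference at batch sizes 1, 4, and 8 and compared output-for-output; the same harness can then be pointed at a TRT-LLM runner:

# Sketch: verify that greedy decoding gives identical outputs regardless of batching.
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

model_name = "t5-small"  # assumed checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.float16).eval().cuda()

texts = [
    "translate English to German: The house is wonderful.",
    "summarize: TensorRT-LLM accelerates inference of large language models.",
] * 4  # 8 inputs in total

def generate(batch):
    enc = tokenizer(batch, return_tensors="pt", padding=True).to("cuda")
    out = model.generate(**enc, max_new_tokens=64, num_beams=1, do_sample=False)
    return tokenizer.batch_decode(out, skip_special_tokens=True)

# Run the same 8 inputs at batch size 1, 4, and 8 and count mismatching outputs.
results = {}
for bs in (1, 4, 8):
    outs = []
    for i in range(0, len(texts), bs):
        outs.extend(generate(texts[i:i + bs]))
    results[bs] = outs

for bs in (4, 8):
    mismatches = sum(a != b for a, b in zip(results[1], results[bs]))
    print(f"batch size {bs}: {mismatches} outputs differ from batch size 1")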
I faced the same issue while running with the latest version of the TensorRT-LLM Triton backend: the output is not consistent and does not match the HF output.
@0xd8b @thanhlt998 Thanks for reporting this! It might be a bug introduced in recent versions; we'll investigate soon. By the way, if I remember correctly you have been actively using enc-dec TRT-LLM for a while. Do you have any additional info on whether this is a regression, i.e. did everything work without problems on your previous TRT-LLM version? If so, can you share which version that was?
@symphonylyh Yes! As I remember, it worked normally at TRT-LLM version
@symphonylyh Did you find out the issue here?
@thanhlt998 We found that (1) the issue only appears for batch size >= 18, (2) the issue is token-specific, i.e. it only happens for specific tokens, and (3) it was due to some NaN elements after the 1st layer. That's the progress so far -- we're still investigating the root cause. By the way, we found the 0604 version also had the same issue (I noticed you have changed the version to 0618; we can try that as well).
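For anyone doing the same kind of triage on their own model, a generic way to locate the first layer that emits NaNs is a PyTorch forward hook. This is a sketch under the assumption that a PyTorch-side reproduction exists; it is not the team's internal debugging code:

# Sketch: flag modules whose outputs contain NaNs during a forward pass.
import torch

def attach_nan_hooks(model):
    handles = []
    def make_hook(name):
        def hook(module, inputs, output):
            tensors = output if isinstance(output, (tuple, list)) else (output,)
            for t in tensors:
                if torch.is_tensor(t) and torch.isnan(t).any():
                    print(f"NaN detected in output of {name}")
        return hook
    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    return handles  # call h.remove() on each handle when done

Running one problematic batch with these hooks attached narrows the NaN down to a specific layer, which is essentially the observation reported above.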
@symphonylyh Let me know if there are any new insights. I am very eager to apply TensorRT-LLM for enc-dec in production.
@thanhlt998 Just to share my own experience: When I use
@ogaloglu I know, but I am eager to apply the inflight batching feature, not the dynamic batching feature. If I use the above parameters, it is basically the same as FasterTransformer with dynamic batching.
@thanhlt998 I see, I got your point! Just want to add that I can actually use
@symphonylyh Do you have any progress here?
@thanhlt998 @ogaloglu @0xd8b Root caused and being fixed internally now! It was due to the padding_offset setup in gptAttentionCommon. Basically we need one padding_offset for the encoder input lengths and a separate one for the decoder output lengths. Otherwise it fails under certain edge cases when length(1st input seq in the batch) < batch size (and also only under a certain config: dtype != fp32 && remove_padding mode)... Expect a fix in next week's main branch update!
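To make the padding_offset remark concrete, here is a simplified numpy illustration (not the actual TensorRT-LLM kernel, and the lengths below are made up): under remove-padding mode the tokens of a batch are packed back to back, and each packed token carries an offset that maps it back to its row in the padded [batch, max_len] layout. The offsets derived from the encoder input lengths and from the decoder output lengths are different arrays, which is why reusing one for the other breaks in the edge case described above.

import numpy as np

def padding_offsets(seq_lens, max_len):
    # For every real (non-padded) token, record how many padding slots precede
    # it in the padded layout, so padded_index = packed_index + offset.
    offsets, removed = [], 0
    for length in seq_lens:
        offsets.extend([removed] * length)   # same offset for all tokens of this sequence
        removed += max_len - length          # padding contributed by this sequence
    return np.array(offsets)

encoder_lens = [5, 2, 7]   # hypothetical encoder input lengths per batch entry
decoder_lens = [3, 3, 3]   # hypothetical decoder lengths for the same batch

print(padding_offsets(encoder_lens, max_len=7))  # offsets for encoder-side attention
print(padding_offsets(decoder_lens, max_len=3))  # offsets for decoder-side attention

The two arrays differ in both length and values, so a kernel that indexes one side with offsets built from the other side's lengths reads the wrong rows.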
@symphonylyh Looking forward to your fix next week!
@symphonylyh When will the fixes land on the main branch this week?
@thanhlt998 @ogaloglu @0xd8b Updated on the main branch now! Closing the issue. Please verify.
@symphonylyh I am trying your fixes with the TensorRT-LLM Triton backend at the latest version on the main branch, but I cannot launch the engines with the Triton backend after serializing the encoder and decoder engines. I suspected the cause was Triton 24.07, but when I build the docker image with Triton 24.05 (which launched the model normally before your fixes), it still raises the same error:
So I cannot verify your fixes now :(( The same problem is reported here.
@thanhlt998 You can always follow enc_dec/README.md and use examples/run.py or examples/enc_dec/run.py to verify by changing the input text. The Triton backend error you saw should be either a version or setup issue on your end, because we have tested end-to-end and didn't see such an error. Maybe you can verify with non-Triton first?
@symphonylyh Yes, I successfully ran the TensorRT engines with the C++ runtime via examples/enc_dec/run.py in the built docker container, following enc_dec/README.md. But I still cannot launch the engines with the Triton backend.
@thanhlt998 I see. It looks like a general Triton issue not limited to enc-dec. We need to wait for the backend folks to investigate and fix it then.
@symphonylyh Thanks a lot; I can confirm that, for this input_text, there is no difference in the results depending on that setting. So, the following experiments are still based on this build configuration:
export MODEL_NAME="t5-small"
export MODEL_TYPE="t5"
export INFERENCE_PRECISION="bfloat16"
export TP_SIZE=1
export PP_SIZE=1
export WORLD_SIZE=1
export MAX_BEAM_WIDTH=1
export GPUS_PER_NODE=1
export PAGED_KV_CACHE="disable"
export MAX_BATCH_SIZE=32
export MAX_INPUT_LEN=512
export MAX_SEQ_LEN=201
export REMOVE_INPUT_PADDING="enable"
When I pass a batch with the first two sentences of the input text:
HF output (basically the same as the input):
TRT output:
Can there still be some token-specific problems? Thanks a lot in advance! Please let me know if you think I should rather create a new issue for that!
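A small helper like the following (a hypothetical sketch, not code from this thread) makes it easy to quantify such discrepancies once the HF and TRT-LLM outputs for the same batch have been decoded to strings:

def compare_outputs(hf_outputs, trt_outputs, max_show=5):
    # Report the exact-match rate between two lists of decoded strings and
    # print the first few mismatching pairs for manual inspection.
    assert len(hf_outputs) == len(trt_outputs)
    mismatches = [(i, h, t) for i, (h, t) in enumerate(zip(hf_outputs, trt_outputs)) if h != t]
    match_rate = 1.0 - len(mismatches) / len(hf_outputs)
    print(f"exact match: {match_rate:.1%} ({len(mismatches)}/{len(hf_outputs)} differ)")
    for i, h, t in mismatches[:max_show]:
        print(f"[{i}] HF : {h}")
        print(f"[{i}] TRT: {t}")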
@ogaloglu Thanks for the info. We can continue using this issue to investigate. Reopened now.
@symphonylyh Thank you! Looking forward to your updates!
@symphonylyh I've tested your fixes with the Triton backend; it still has the same issues as before, but with lower frequency. Looking forward to your insights on these issues!
@ogaloglu The "a man, a man" issue appears to occur only with bfloat16. When you switch to float16, it seems fine. This is not a fix yet, just sharing findings.
@symphonylyh In my experiments I only used an RTX 2080 Ti (which supports float16 only), so I think float16 vs. bfloat16 is not the cause of this issue.
@symphonylyh Thank you for the insight! Then I will run some experiments early next week and share the outcomes.
@0xd8b @ogaloglu @thanhlt998 Good news! The main issue regarding remove padding has been fixed and will be released next week. If you want to verify locally first, you can search for invokeAddFusedQKVBiasTranspose in bertAttentionPlugin.cpp and copy the cudaMemset call (linked above) to just before invokeAddFusedQKVBiasTranspose. This should fix a failure when doing IFB + remove padding with a large batch size. The "a man, a man" issue is still somewhat vague, but I would suggest @ogaloglu switch to float16, because I saw the "a man" output in bfloat16 while switching to float16 gives a 100% match with HF. That cause must be something else, and I would suggest filing another bug. The more fundamental issue should be fixed by the solution in my comment here, and will be released next week.
@symphonylyh It is great to hear there is a fix for this issue. I am looking forward to your update next week!
@symphonylyh I have tested the fixes following your instructions. After a load test with 999 samples (randomly shuffled):
The results are much better than before the fixes, but there are still slight differences between TRT-LLM and FasterTransformer, and between two different runs.
@thanhlt998 Do you have a concrete input text to reproduce the difference? So far we think the TRT-LLM results are correct and a good match with HF, while an exact 100% match cannot be guaranteed among TRT-LLM, FT, and HF because the kernels selected are indeed different. But if you find weird output, please post it here with a reproducer and we'll investigate.
@symphonylyh I think it's OK for non-structured output tasks. However, I have tested it with tasks that have structured outputs, like function calling, and the TRT-LLM output causes accuracy to decrease by 10-30% compared to FasterTransformer on the same test set.
@thanhlt998 Understood, but do you have a model name + reproduction text? We can't investigate the accuracy issue without such info. In this case it also makes sense to create a separate issue for better tracking -- can you file one?
@symphonylyh I use a private model that follows the same config as the UL2 model architecture (using silu instead of relu activation), which raises the above issues. I will find a way to reproduce these issues with an open-source model and let you know.
@symphonylyh You can reproduce the bug with the model ul2-base-en-nl, with the following inputs:
The outputs after running:
@symphonylyh Did you reproduce the bug with the information provided above? Do you have any insights?
@symphonylyh Do you have any progress here?
@symphonylyh We have converted a T5 model using TensorRT-LLM, with the input sequence length set to 2048, in float16, using both the GPT and BERT attention plugins. When we built the engine with max_batch_size set to 3 and ran inference with batch size 1, the results were normal. However, when we built the engine with max_batch_size set to 1 and kept batch size 1 for inference, the results were abnormal. We have identified that the issue is with the GPT plugin, but since we have retrained the model, we cannot provide a sample to reproduce the issue. We are using TensorRT-LLM version 0.9, and we would like to ask for direction on troubleshooting. We look forward to your reply.
System Info
Container: nvidia/cuda:12.4.1-devel-ubuntu22.04
GPU: L4
TensorRT-LLM version: 0.12.0.dev2024071600
Who can help?
No response
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
git clone https://huggingface.co/t5-small tmp/hf_models/t5-small
Expected behavior
HF outputs and TRT-LLM outputs should be roughly the same. As a side note, this is the case with FasterTransformer.
Actual behavior
TRT-LLM outputs differ significantly from HF outputs. Generation quality is clearly worse: repetitions, outputs consisting only of special tokens, etc.
Additional notes
I just realized before submitting the issue that the TRT-LLM and HF results are much closer to each other when remove_input_padding is disabled. Therefore, I changed the title accordingly. Additionally, similar to previously created issues, I notice that the discrepancy between TRT-LLM and HF outputs increases as the batch size increases.