Description
System Info
Intel(R) Xeon(R) CPU @ 2.20GHz
Architecture: x86_64
NVIDIA A100-SXM4-40G
Ubuntu
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
I followed the official Llama example: https://github.com/NVIDIA/TensorRT-LLM/tree/v0.8.0/examples/llama
Inference slows down significantly when the return_context_logits flag is turned on. I am using the Llama example and enabled the gather_context_logits flag when building the engine with trtllm-build.
At request time, I pass return_context_logits through the triton_client to retrieve the logits for the input sentences. Since I only need the context logits, I set request_output_len (or output_len) to 1.
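For a sense of scale, here is a rough sketch of how large the context logits tensor could be for a single request of the length I am testing. The vocabulary size (32000) and float32 logits are assumptions for illustration, not values read from my engine config, so treat the result as an order-of-magnitude estimate only.

```python
# Rough size estimate of the context logits returned for one request.
input_len = 2000        # input length used in my test
vocab_size = 32000      # ASSUMPTION: typical Llama vocabulary size
bytes_per_logit = 4     # ASSUMPTION: logits returned as float32

# Context logits have one row of vocab_size values per input token.
logits_bytes = input_len * vocab_size * bytes_per_logit
logits_mib = logits_bytes / (1024 ** 2)

# For comparison: logits for only the last position.
last_token_bytes = vocab_size * bytes_per_logit

print(f"context logits: {logits_bytes} bytes ({logits_mib:.1f} MiB)")
print(f"last-token logits: {last_token_bytes} bytes")
print(f"ratio: {logits_bytes // last_token_bytes}x")
```

If these assumptions are close, each request has to produce and move a tensor that is input_len times larger than the last-token logits, which may be relevant to the slowdown.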
Expected behavior
Enabling return_context_logits should cause only a modest decrease in speed, with throughput close to what I get when the flag is off. Performance should be on par with or better than the forward pass speed of the HuggingFace implementation.
Actual behavior
Execution is roughly 8 times slower when I request context logits with an output length of 1. This is, surprisingly, slower than the forward pass of a comparable HuggingFace model.
Here's a comparative table of performance with and without the return_context_logits flag:
| return_context_logits | max_gen_token | input_len | Execution Time | Average Time per Example |
|---|---|---|---|---|
| On | 1 | 2000 | 0:49 | 0.98s |
| Off | 1 | 2000 | 0:06 | 0.12s |
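As a quick sanity check on the numbers in the table, the slowdown factor can be computed as follows. This is only arithmetic on the reported values; the helper `to_seconds` is a hypothetical name, and it assumes the Execution Time column is in m:ss format.

```python
def to_seconds(mmss: str) -> int:
    """Convert an 'm:ss' string to seconds (assumes the table uses m:ss)."""
    minutes, seconds = mmss.split(":")
    return int(minutes) * 60 + int(seconds)

# Values copied from the table.
per_example_on = 0.98    # seconds, return_context_logits on
per_example_off = 0.12   # seconds, return_context_logits off
total_on = to_seconds("0:49")
total_off = to_seconds("0:06")

# Slowdown measured two ways; both should agree if the table is consistent.
slowdown_per_example = per_example_on / per_example_off
slowdown_total = total_on / total_off

# Number of examples implied by total time / per-example time.
implied_examples = round(total_on / per_example_on)

print(f"slowdown (per example): {slowdown_per_example:.2f}x")
print(f"slowdown (total time):  {slowdown_total:.2f}x")
print(f"implied number of examples: {implied_examples}")
```

Both ratios come out slightly above 8x, which matches the slowdown described above.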
Additional notes
I ran trtllm-build with the following configuration:
trtllm-build --checkpoint_dir {model_dir}/tensorrt/{tp_size}-gpu \
--remove_input_padding enable \
--gpt_attention_plugin float16 \
--context_fmha enable \
--gemm_plugin float16 \
--output_dir {model_dir}/tensorrt_llm/context_fmha \
--paged_kv_cache disable \
--enable_xqa disable \
--multi_block_mode disable \
--use_custom_all_reduce disable \
--tp_size {tp_size} \
--workers {tp_size} \
--max_batch_size 1 \
--max_input_len 8192 \
--max_output_len 8192 \
--max_num_tokens 8192 \
--gather_context_logits
Any insights or assistance in addressing this unexpected slowdown would be greatly appreciated. If there are any further experiments or specific areas you would recommend investigating, please advise.