
Performance Issue with return_context_logits Enabled in TensorRT-LLM #419


Description

@metterian

System Info

Intel(R) Xeon(R) CPU @ 2.20GHz
Architecture: x86_64
NVIDIA A100-SXM4-40GB
Ubuntu

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I followed the official example for the Llama model: https://github.com/NVIDIA/TensorRT-LLM/tree/v0.8.0/examples/llama

I have been experiencing a significant slowdown when the return_context_logits flag is turned on. For context, I am using the Llama example and specifically enabled the gather_context_logits flag during the TensorRT-LLM engine build.

Additionally, I pass return_context_logits through the Triton client to retrieve the context logits for the prompt tokens. To keep generation out of the measurement, I set request_output_len (output_len) to 1.
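For reference, a minimal sketch of the request I send is below. The model name (`tensorrt_llm`), the tensor names (`input_ids`, `input_lengths`, `request_output_len`, `return_context_logits`, `context_logits`), and the endpoint follow the tensorrtllm_backend examples and are assumptions about this deployment, not an exact copy of my client code.

```python
# Sketch of the Triton gRPC request; tensor/model names are assumptions
# based on the tensorrtllm_backend examples.
import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient("localhost:8001")

input_ids = np.array([[1, 15043, 3186]], dtype=np.int32)  # example token IDs

inputs = [
    grpcclient.InferInput("input_ids", input_ids.shape, "INT32"),
    grpcclient.InferInput("input_lengths", [1, 1], "INT32"),
    grpcclient.InferInput("request_output_len", [1, 1], "INT32"),
    grpcclient.InferInput("return_context_logits", [1, 1], "BOOL"),
]
inputs[0].set_data_from_numpy(input_ids)
inputs[1].set_data_from_numpy(np.array([[input_ids.shape[1]]], dtype=np.int32))
inputs[2].set_data_from_numpy(np.array([[1]], dtype=np.int32))   # output_len = 1
inputs[3].set_data_from_numpy(np.array([[True]], dtype=bool))    # request context logits

outputs = [grpcclient.InferRequestedOutput("context_logits")]

result = client.infer("tensorrt_llm", inputs=inputs, outputs=outputs)
context_logits = result.as_numpy("context_logits")  # [batch, input_len, vocab]
```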

Expected behavior

The expected behavior when enabling return_context_logits is a modest decrease in speed, not deviating significantly from the throughput with the flag off. Performance should ideally be on par with, or better than, the forward-pass speed of the equivalent HuggingFace implementation.

Actual behavior

The observed behavior is an almost 8-fold increase in execution time when retrieving context logits with a maximum generation length of 1. Surprisingly, this is slower than the forward pass of the comparable HuggingFace model.

Here's a comparative table of performance with and without the return_context_logits flag:

| Logit Status | max_gen_token | input_len | Execution Time | Average Time per Example |
|--------------|---------------|-----------|----------------|--------------------------|
| On           | 1             | 2000      | 0:49           | 0.98 s                   |
| Off          | 1             | 2000      | 0:06           | 0.12 s                   |
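For reference, the HuggingFace baseline mentioned above corresponds to a single forward pass over the ~2000-token prompt, measured roughly as in the sketch below; the checkpoint name and the timing loop are illustrative assumptions rather than the exact script I used.

```python
# Rough sketch of the HuggingFace forward-pass baseline (one forward over
# a ~2000-token prompt, no generation). Model name is an assumption.
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # illustrative Llama checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="cuda"
)

prompt = "some long prompt " * 500  # roughly 2000 tokens
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

with torch.no_grad():
    model(**inputs)                   # warm-up
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = model(**inputs)             # logits over the full context
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

print(out.logits.shape, f"{elapsed:.3f}s per forward pass")
```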

Additional notes

I have executed the trtllm-build with the following configuration:

```bash
trtllm-build --checkpoint_dir {model_dir}/tensorrt/{tp_size}-gpu \
             --remove_input_padding enable \
             --gpt_attention_plugin float16 \
             --context_fmha enable \
             --gemm_plugin float16 \
             --output_dir {model_dir}/tensorrt_llm/context_fmha \
             --paged_kv_cache disable \
             --enable_xqa disable \
             --multi_block_mode disable \
             --use_custom_all_reduce disable \
             --tp_size {tp_size} \
             --workers {tp_size} \
             --max_batch_size 1 \
             --max_input_len 8192 \
             --max_output_len 8192 \
             --max_num_tokens 8192 \
             --gather_context_logits
```

Any insights or assistance in addressing this unexpected slowdown would be greatly appreciated. If there are any further experiments or specific areas you would recommend investigating, please advise.
