
[Perf] Tune scaled_fp8_quant by increasing vectorization #18844


Open
wants to merge 9 commits into main

Conversation

@mgoin (Member) commented May 28, 2025

Increase the vectorization from 4 to 16 elements and decrease the block_size used to launch the per-tensor and per-token fp8 quantization CUDA kernels. The improvements are visible on H100 and much more pronounced on B200.
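For illustration, here is a minimal sketch of the pattern (not the actual vLLM kernel, which packs elements through a vector struct; the function name, the inv_scale argument, and the conversion call are assumptions):

#include <cuda_fp8.h>
#include <cstdint>

// Minimal sketch: each thread quantizes VEC_SIZE contiguous elements per
// grid-stride iteration. With VEC_SIZE = 16 a thread covers 16 inputs per
// step, so fewer threads (and a smaller block) are needed per tensor.
template <typename scalar_t, int VEC_SIZE>
__global__ void scaled_fp8_quant_sketch(__nv_fp8_e4m3* __restrict__ out,
                                        const scalar_t* __restrict__ in,
                                        const float* __restrict__ inv_scale,
                                        int64_t num_elems) {
  const int64_t stride =
      static_cast<int64_t>(gridDim.x) * blockDim.x * VEC_SIZE;
  int64_t base =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * VEC_SIZE;
  const float s = *inv_scale;
  for (; base < num_elems; base += stride) {
#pragma unroll
    for (int i = 0; i < VEC_SIZE; ++i) {
      const int64_t idx = base + i;
      if (idx < num_elems) {
        // Scale then convert; the __nv_fp8_e4m3 constructor saturates to the
        // finite e4m3 range.
        out[idx] = __nv_fp8_e4m3(static_cast<float>(in[idx]) * s);
      }
    }
  }
}

// Example launch (illustrative): 256-thread blocks, each thread covering 16
// elements.
// int64_t threads_needed = (num_elems + 15) / 16;
// int grid = static_cast<int>((threads_needed + 255) / 256);
// scaled_fp8_quant_sketch<half, 16><<<grid, 256>>>(out, in, inv_scale, n);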

Evaluations

# Command for static per-tensor quantization
lm_eval --model vllm --model_args pretrained=RedHatAI/Meta-Llama-3.1-8B-FP8 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

# 0.9.0
vllm (pretrained=RedHatAI/Meta-Llama-3.1-8B-FP8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4973|±  |0.0138|
|     |       |strict-match    |     5|exact_match|↑  |0.4973|±  |0.0138|

# This PR
vllm (pretrained=RedHatAI/Meta-Llama-3.1-8B-FP8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4973|±  |0.0138|
|     |       |strict-match    |     5|exact_match|↑  |0.4973|±  |0.0138|

# Command for dynamic per-token quantization
lm_eval --model vllm --model_args pretrained=RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8-dynamic --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

# 0.9.0
vllm (pretrained=RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8-dynamic,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7657|±  |0.0117|
|     |       |strict-match    |     5|exact_match|↑  |0.7415|±  |0.0121|

# This PR
vllm (pretrained=RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8-dynamic,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7650|±  |0.0117|
|     |       |strict-match    |     5|exact_match|↑  |0.7407|±  |0.0121|

B200

e2e throughput benchmark on B200:

# static per-tensor quant
VLLM_ATTENTION_BACKEND=FLASHINFER vllm bench throughput --model RedHatAI/Meta-Llama-3.1-8B-FP8 --load-format dummy --input-len 1000 --output-len 100
# Before
Throughput: 69.24 requests/s, 79479.95 total tokens/s, 6924.14 output tokens/s
# After
Throughput: 73.78 requests/s, 81110.68 total tokens/s, 7378.19 output tokens/s
# dynamic per-token quant
VLLM_ATTENTION_BACKEND=FLASHINFER vllm bench throughput --model RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8-dynamic --load-format dummy --input-len 1000 --output-len 100
# Before
Throughput: 60.19 requests/s, 69073.58 total tokens/s, 6018.77 output tokens/s
# After
Throughput: 67.09 requests/s, 73723.27 total tokens/s, 6709.12 output tokens/s

On B200, this results in a kernel speedup of up to 99% once processing > 2k tokens at hidden_size=4k.

(Screenshot: B200 kernel benchmark plot)

H100

e2e throughput benchmark on H100:

vllm bench throughput --model RedHatAI/Meta-Llama-3.1-8B-FP8 --load-format dummy --input-len 1000 --output-len 100
# Before
Throughput: 37.31 requests/s, 42796.08 total tokens/s, 3730.78 output tokens/s
# After
Throughput: 40.92 requests/s, 44963.62 total tokens/s, 4091.79 output tokens/s

On H100, this results in a kernel speedup of up to 40% once processing > 512 tokens at hidden_size=4k. There is a small degradation for < 256 tokens, but I feel this is not a big deal.

(Screenshot: H100 kernel benchmark plot)


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

mgoin added 2 commits May 28, 2025 16:53
Signed-off-by: mgoin <[email protected]>
Signed-off-by: mgoin <[email protected]>
- template <typename quant_type_t>
- struct __align__(4) q8x4_t {
+ template <typename quant_type_t, size_t vec_size>
+ struct __align__(4) q8_n_t {
Contributor

Should we align this to vec_size?

@mgoin (Member, Author) replied May 28, 2025

I think that is overkill and would place more restrictions on the allowed tensors.
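To illustrate the trade-off (hypothetical struct name; only the template shape mirrors the diff above): with __align__(4), any 4-byte-aligned tensor can be reinterpreted through the struct, while __align__(vec_size) with vec_size = 16 would require 16-byte-aligned inputs and outputs.

#include <cstddef>

// Hypothetical mirror of the struct above: vec_size packed 8-bit quantized
// values, but only 4-byte alignment is demanded of the underlying pointer.
template <typename quant_type_t, size_t vec_size>
struct __align__(4) q8_n_sketch {
  quant_type_t val[vec_size];
};

// With vec_size == 16 the struct is 16 bytes but still only 4-byte aligned,
// so any 4-byte-aligned buffer can be viewed through it; __align__(16) would
// tighten that requirement to 16 bytes.
static_assert(sizeof(q8_n_sketch<unsigned char, 16>) == 16, "16 packed bytes");
static_assert(alignof(q8_n_sketch<unsigned char, 16>) == 4, "4-byte alignment");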

Signed-off-by: mgoin <[email protected]>
@ekagra-ranjan (Contributor) left a comment

Thank you for the PR! Do we have some quality evals to ensure that the quality is not impacted?

@mgoin added the performance and ready labels on May 28, 2025
@ProExpertProg (Contributor)

We should also try benchmarking a Triton implementation for this, as it would be much easier to maintain, especially given that this is an elementwise op!

@mgoin changed the title on May 29, 2025
Signed-off-by: mgoin <[email protected]>
@mgoin (Member, Author) commented May 29, 2025

@ProExpertProg @ekagra-ranjan I've now added gsm8k evals to the description, PTAL. I think I'll add benchmarking against other impls in another PR.

}
}

- using BlockReduce = cub::BlockReduce<float, 1024>;
+ using BlockReduce = cub::BlockReduce<float, 256>;
Contributor

The 256 here must be the same as the block_size passed to the kernel launch below. I propose passing block_size as a template arg to dynamic_per_token_scaled_fp8_quant_kernel:

template <typename scalar_t, typename fp8_type, int BLOCK_SIZE>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
...
  using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
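To expand on that suggestion, a rough sketch under assumed names and signature (the quantization step is omitted); the point is that the launch and the BlockReduce width share one compile-time constant, so they cannot drift apart:

#include <cub/cub.cuh>
#include <cstdint>

// Sketch: one block per token row; BLOCK_SIZE is both the launch block size
// and the cub::BlockReduce width, so the two always agree at compile time.
template <typename scalar_t, typename fp8_type, int BLOCK_SIZE>
__global__ void dynamic_per_token_scaled_fp8_quant_sketch(
    fp8_type* __restrict__ out, const scalar_t* __restrict__ in,
    float* __restrict__ scale, int hidden_size) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
  __shared__ typename BlockReduce::TempStorage tmp;

  const int64_t row = static_cast<int64_t>(blockIdx.x) * hidden_size;
  float local_max = 0.f;
  for (int i = threadIdx.x; i < hidden_size; i += BLOCK_SIZE) {
    local_max = fmaxf(local_max, fabsf(static_cast<float>(in[row + i])));
  }
  const float row_max = BlockReduce(tmp).Reduce(local_max, cub::Max());
  if (threadIdx.x == 0) {
    scale[blockIdx.x] = row_max / 448.f;  // 448 = max finite e4m3 value
  }
  // ... quantize the row into `out` with the computed per-token scale ...
}

// Host side (illustrative): the same constexpr feeds both the template
// argument and the launch configuration.
// constexpr int kBlockSize = 256;
// dynamic_per_token_scaled_fp8_quant_sketch<half, uint8_t, kBlockSize>
//     <<<num_tokens, kBlockSize>>>(out, in, scales, hidden_size);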

@jinyouzhi (Contributor)

Awesome PR! I ran a simple test on an H20 (sorry, this is the only Hopper-series card I have on hand).

vllm bench throughput --model RedHatAI/Meta-Llama-3.1-8B-FP8 --load-format dummy --input-len 1000 --output-len 100
w/ PR
Throughput: 13.04 requests/s, 14319.47 total tokens/s, 1303.64 output tokens/s
w/o PR
Throughput: 12.97 requests/s, 14235.45 total tokens/s, 1296.72 output tokens/s

How should this be tuned, or does it need some hardware feature support? And another question: can this be extended to other ops? I am interested in it. Thank you!

@mgoin (Member, Author) commented May 30, 2025

I have been experimenting with completely replacing the CUDA kernels with simple torch implementations using torch.compile, which seems to be better in all cases on H100, especially at small sizes. I need to test on B200, but I think this might be a better way forward.

Labels: performance (Performance-related issues), ready (ONLY add when PR is ready to merge/full CI is needed)
5 participants