[Perf] Tune scaled_fp8_quant by increasing vectorization #18844
base: main
Conversation
csrc/quantization/vectorization.cuh (outdated)
- template <typename quant_type_t>
- struct __align__(4) q8x4_t {
+ template <typename quant_type_t, size_t vec_size>
+ struct __align__(4) q8_n_t {
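For context, a minimal sketch of what the generalized type could look like (only the declaration above is from the PR; the array member and the static_assert are assumptions):

#include <cstddef>

// N-wide 8-bit quant vector, generalizing the old 4-wide q8x4_t.
// Keeping __align__(4) rather than __align__(vec_size) avoids imposing
// stricter alignment requirements on the tensors fed to the kernel.
template <typename quant_type_t, size_t vec_size>
struct __align__(4) q8_n_t {
  static_assert(vec_size % 4 == 0, "assumed: vec_size is a multiple of 4");
  quant_type_t val[vec_size];
};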
Should we align this to vec_size?
I think that is overkill and would place more restrictions on the allowed tensors
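To illustrate that trade-off (purely illustrative helper, not vLLM code): an over-aligned vector type can only be used on pointers that satisfy its alignment, so __align__(16) would exclude tensors whose data pointer or row offset is only 4-byte aligned.

#include <cstdint>

// A wider alignment on the vector type directly narrows the set of tensors
// that can take the vectorized path.
template <typename vec_t>
__host__ __device__ bool can_use_vectorized_path(const void* ptr) {
  return reinterpret_cast<std::uintptr_t>(ptr) % alignof(vec_t) == 0;
}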
Thank you for the PR! Do we have some quality evals to ensure that the quality is not impacted?
We should also try benchmarking a Triton implementation for this as it would be much easier to maintain, especially given this is an elementwise op.
@ProExpertProg @ekagra-ranjan I've added gsm8k evals now to the description, PTAL. I think I'll add benchmarking against other impls in another PR.
}
}
- using BlockReduce = cub::BlockReduce<float, 1024>;
+ using BlockReduce = cub::BlockReduce<float, 256>;
The 256 here must match the block_size passed to the kernel launch below. Propose passing block_size as a template argument to dynamic_per_token_scaled_fp8_quant_kernel:

template <typename scalar_t, typename fp8_type, int BLOCK_SIZE>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
...
using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
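A minimal sketch of how that suggestion keeps the reduction width and the launch configuration in sync (the parameter list and launch details are illustrative, not the actual vLLM code):

#include <cub/cub.cuh>

template <typename scalar_t, typename fp8_type, int BLOCK_SIZE>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(/* args elided */) {
  // The reduction width is now guaranteed to equal the launch block size.
  using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
  __shared__ typename BlockReduce::TempStorage tmp_storage;
  // ... per-token absmax reduction and quantization ...
}

// Host side, reusing the same compile-time constant for the launch:
// constexpr int BLOCK_SIZE = 256;
// dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_type, BLOCK_SIZE>
//     <<<grid, BLOCK_SIZE, 0, stream>>>(/* args elided */);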
Awesome PR! How do you tune this, or does it need some HW feature support?
I have been experimenting with completely replacing the CUDA kernels with simple torch implementations using torch.compile, which seems to be better in all cases on H100, especially at small sizes. I need to test on B200, but I think this might be a better way forward.
Increase the vectorization from 4 to 16 elements and decrease the block_size used to launch the kernel for the per-tensor and per-token fp8 quantization CUDA kernels. The improvements are visible on H100 and obvious on B200.
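A minimal sketch of the idea (not the vLLM kernel itself: the struct names, fp16 input type, and scale convention are assumptions, and remainder/saturation handling is omitted): each thread loads 16 elements via wide vector types, converts them to fp8, and writes all 16 bytes back in one store, so fewer and smaller thread blocks are needed to saturate memory bandwidth.

#include <cuda_fp16.h>
#include <cuda_fp8.h>

// 16 input elements (fp16 here for illustration) loaded per thread.
struct __align__(32) in16_t {
  __half val[16];
};

// 16 fp8 output elements stored per thread.
struct __align__(16) fp8x16_t {
  __nv_fp8_e4m3 val[16];
};

__global__ void scaled_fp8_quant_vec16(fp8x16_t* __restrict__ out,
                                       const in16_t* __restrict__ in,
                                       const float* __restrict__ scale,
                                       int num_vecs) {
  const float inv_scale = 1.0f / *scale;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num_vecs;
       i += gridDim.x * blockDim.x) {
    in16_t v = in[i];  // one wide load (2 x 128 bits for fp16 input)
    fp8x16_t q;
#pragma unroll
    for (int j = 0; j < 16; ++j) {
      q.val[j] = __nv_fp8_e4m3(__half2float(v.val[j]) * inv_scale);
    }
    out[i] = q;  // one 128-bit store
  }
}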
Evaluations
B200
e2e throughput benchmark on B200:
On B200, this results in a kernel speedup of up to 99% once processing > 2k tokens at hidden_size=4k.
H100
e2e throughput benchmark on H100:
On H100, this results in a kernel speedup of up to 40% once processing > 512 tokens at hidden_size=4k. There is a small degradation for < 256 tokens, but I feel this is not a big deal.