
[Perf] Tunings for SM100 FP8 CUTLASS kernel #18778

Open
mgoin wants to merge 3 commits into main

Conversation

@mgoin (Member) commented on May 27, 2025

I noticed that the FP8 CUTLASS kernel for Blackwell only had one default set of configs. This PR adds new configs for small M < 128.
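
To make the shape-based selection concrete, here is a minimal C++ sketch of that kind of M-based config dispatch. The struct names, cutoffs, and shapes are hypothetical stand-ins for illustration only, not the actual vLLM CUTLASS configs added in this PR.

```cpp
// Minimal sketch only: illustrates selecting a CUTLASS kernel config by M,
// as this PR does for the SM100 FP8 path. All names, shapes, and cutoffs
// below are hypothetical stand-ins for the real cute::Shape-based configs.
#include <cstdint>
#include <cstdio>

// Stand-ins for kernel configurations (tile shape, cluster shape, schedule).
struct Sm100Fp8ConfigM64     { static constexpr const char* name = "small-M (< 64) config"; };
struct Sm100Fp8ConfigM128    { static constexpr const char* name = "mid-M (< 128) config"; };
struct Sm100Fp8ConfigDefault { static constexpr const char* name = "default config"; };

template <typename Config>
void run_scaled_mm(int64_t m, int64_t n, int64_t k) {
  // The real code would instantiate and launch a CUTLASS GEMM here.
  std::printf("M=%lld N=%lld K=%lld -> %s\n",
              (long long)m, (long long)n, (long long)k, Config::name);
}

// Heuristic dispatch on M: small batches get the new tunings, everything
// else keeps the original default configuration.
void dispatch_sm100_fp8(int64_t m, int64_t n, int64_t k) {
  if (m < 64) {
    run_scaled_mm<Sm100Fp8ConfigM64>(m, n, k);
  } else if (m < 128) {
    run_scaled_mm<Sm100Fp8ConfigM128>(m, n, k);
  } else {
    run_scaled_mm<Sm100Fp8ConfigDefault>(m, n, k);
  }
}

int main() {
  dispatch_sm100_fp8(16, 6144, 4096);    // small M -> new tuning
  dispatch_sm100_fp8(1024, 6144, 4096);  // large M -> original tuning
  return 0;
}
```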

For Llama 8B on B200, these tunings offer a GEMM improvement of:

  • 1.7 to 2.5x speedup at M<64
  • 1.1 to 1.3x speedup at 64<=M<128
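
For example, in the N=6144, K=4096 tables below, the per-tensor FP8 GEMM at M=16 goes from about 42.8 to 96.0 TFLOP/s (roughly 2.2x), and at M=64 from about 288 to 384 TFLOP/s (roughly 1.3x).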

Accuracy eval:

lm_eval --model vllm --model_args pretrained=RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8-dynamic --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
vllm (pretrained=RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8-dynamic,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7657|±  |0.0117|
|     |       |strict-match    |     5|exact_match|↑  |0.7422|±  |0.0120|

Kernel benchmarks using #17126

# B200 original tunings
python benchmarks/kernels/bench_fp8_gemm.py --model meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1
meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
   batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0         1.0     5.947784               2.190874               2.072256                       2.506103                       2.457149
1        16.0    94.680891              42.792760              39.819737                      47.809155                      47.631335
2        64.0   402.859620             288.133123             254.794430                     372.338300                     373.705553
3       128.0   728.629250             569.629964             505.291218                     736.715692                     740.021782
4       256.0  1095.642978            1078.797214             898.837969                    1464.385495                    1478.845473
5       512.0  1261.717282            1316.419568            1082.039073                    1754.103111                    1795.102080
6      1024.0  1421.279286            1790.971627            1448.677562                    2235.177663                    2584.518414
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
   batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0         1.0     2.748527               2.011939               1.852286                       2.426765                       2.356377
1        16.0    72.977131              38.338769              34.676703                      45.512563                      45.085737
2        64.0   286.363606             193.600910             170.867183                     251.441611                     251.175358
3       128.0   562.207888             383.696405             339.840832                     498.322879                     497.543227
4       256.0   870.002978             727.559934             603.567390                     993.429300                     993.718536
5       512.0  1278.707903            1219.430597             928.872100                    1859.222222                    1932.512698
6      1024.0  1404.598978            1542.923304            1147.494197                    2087.175763                    2354.134809
meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
   batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0         1.0     5.782370               2.746741               2.707684                       3.076109                       2.971836
1        16.0    80.890917              57.057816              56.033710                      59.886153                      59.472853
2        64.0   315.610284             439.068431             449.404687                     475.453743                     499.493704
3       128.0   605.145298             889.769393             854.905235                     914.885575                     996.511367
4       256.0  1070.665944            1645.908127            1636.642023                    1723.292399                    1847.836941
5       512.0  1193.333425            2045.287912            2040.134063                    2035.227478                    2357.713930
6      1024.0  1304.686942            2096.869473            2323.544617                    2108.071352                    2593.971492
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
   batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0         1.0     4.897150               2.804295               2.710499                       3.428506                       3.439044
1        16.0    80.956437              75.659048              71.651283                      91.678247                      91.692121
2        64.0   307.759226             299.856325             284.288350                     370.326019                     370.149516
3       128.0   598.344683             593.513235             567.052452                     729.721261                     729.418140
4       256.0   872.655528             980.745535            1013.446415                    1455.929961                    1457.608605
5       512.0  1165.633163            1575.324418            1564.636469                    2301.963073                    2626.473590
6      1024.0  1367.848100            1851.543351            1646.237302                    2299.788792                    2808.035307

# B200 new tunings
python benchmarks/kernels/bench_fp8_gemm.py --model meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1
meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
   batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0         1.0     5.910482               6.019409               5.222792                       8.450960                       8.513710
1        16.0    94.847035              95.974738              83.468890                     137.105099                     139.305386
2        64.0   402.608787             383.513195             332.526351                     533.221407                     542.513655
3       128.0   727.344333             668.354537             585.709400                     905.743648                     914.948318
4       256.0  1111.895929            1063.121536             898.299833                    1462.903418                    1478.033168
5       512.0  1337.772701            1308.382300            1082.112777                    1759.884804                    1795.570613
6      1024.0  1432.973854            1778.950801            1450.251744                    2207.502274                    2585.290571
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
   batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0         1.0     2.737313               4.082587               3.519601                       5.853116                       5.868324
1        16.0    72.871574              65.314403              56.451493                      93.371674                      93.612592
2        64.0   286.086195             258.552930             223.655973                     367.486248                     368.629133
3       128.0   562.275811             448.684298             392.292799                     610.722210                     612.259419
4       256.0   893.528563             724.776694             603.564815                     992.761187                     993.035383
5       512.0  1309.727992            1238.251236             930.249372                    1875.319022                    1936.937988
6      1024.0  1446.266181            1504.181444            1146.414136                    2277.382550                    2333.145483
meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
   batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0         1.0     5.874980               8.907531               8.513217                      10.032002                      10.016935
1        16.0    81.733219             140.734419             134.216982                     158.030780                     157.949288
2        64.0   320.265939             547.445855             522.185380                     615.991357                     616.723404
3       128.0   602.637950             937.018227             910.115237                    1055.155568                    1056.821424
4       256.0  1088.675311            1526.447920            1638.574638                    1741.023667                    1957.404538
5       512.0  1126.877028            1714.897685            2107.324511                    2042.790650                    2611.255491
6      1024.0  1255.755813            2064.446702            2314.894491                    2191.506895                    2736.695630
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
   batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0         1.0     4.862665               7.579185               6.920945                      10.217567                      10.249600
1        16.0    81.973118             119.307032             109.946944                     162.638848                     163.243717
2        64.0   303.985531             469.113775             433.804909                     642.102177                     645.760931
3       128.0   597.201158             718.894275             685.251755                     931.641494                     935.816374
4       256.0   870.699430            1120.576212            1011.127380                    1455.045031                    1455.170319
5       512.0  1282.913614            1284.473964            1530.047594                    2387.802820                    2663.654061
6      1024.0  1276.525418            1640.552066            1695.358034                    2447.932588                    2800.132639


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors. You can run the other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@houseroad (Collaborator) left a comment


Looks good. Shall we do some accuracy tests?

@houseroad (Collaborator) commented:

cc: @chenyang78 @drisspg

@mgoin added the performance (Performance-related issues) and ready (ONLY add when PR is ready to merge/full CI is needed) labels on Jun 3, 2025