-
-
Notifications
You must be signed in to change notification settings - Fork 7.7k
[Quantization] Add compressed-tensors NVFP4 support #18312
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
dsikka
wants to merge
7
commits into
vllm-project:main
Choose a base branch
from
neuralmagic:nvfp4_emulation
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+273
−19
Draft
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
27a239b
add ct nvfp4 emulation support
dsikka b30168b
fix conditions; add test models
dsikka f6c7914
add cutlass support
dsikka b2e8c26
clean-up
dsikka 8c62660
update
dsikka 739f512
remove extra value
dsikka e9b0f07
Merge branch 'main' into nvfp4_emulation
mgoin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,10 +23,10 @@ | |
CompressedTensorsMoEMethod) | ||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( | ||
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24, | ||
CompressedTensorsScheme, CompressedTensorsW4A16Fp4, | ||
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, | ||
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, | ||
CompressedTensorsWNA16) | ||
CompressedTensorsScheme, CompressedTensorsW4A4Fp4, | ||
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, | ||
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, | ||
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) | ||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( | ||
find_matched_target, is_activation_quantization_format, | ||
should_ignore_layer) | ||
|
@@ -217,10 +217,33 @@ def _check_scheme_supported(self, | |
else: | ||
return False | ||
|
||
def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel): | ||
|
||
if weight_quant is None or input_quant is None: | ||
return False | ||
|
||
is_group_quant = ( | ||
weight_quant.strategy == QuantizationStrategy.GROUP.value) | ||
is_symmetric = weight_quant.symmetric and input_quant.symmetric | ||
|
||
is_group_size_16 = (weight_quant.group_size == 16 | ||
and input_quant.group_size == 16) | ||
is_float_type = (weight_quant.type == QuantizationType.FLOAT | ||
and input_quant.type == QuantizationType.FLOAT) | ||
is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4 | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should check for dynamic input as well |
||
return (is_group_quant and is_float_type and is_4_bits | ||
and is_group_size_16 and is_symmetric) | ||
|
||
def _is_fp4a16_nvfp4(self, weight_quant: BaseModel, | ||
input_quant: BaseModel): | ||
|
||
is_weight_only = weight_quant is not None and input_quant is None | ||
if weight_quant is None: | ||
return False | ||
|
||
if input_quant is not None: | ||
return False | ||
|
||
is_group_quant = ( | ||
weight_quant.strategy == QuantizationStrategy.GROUP.value) | ||
is_symmetric = weight_quant.symmetric | ||
|
@@ -229,8 +252,8 @@ def _is_fp4a16_nvfp4(self, weight_quant: BaseModel, | |
is_float_type = weight_quant.type == QuantizationType.FLOAT | ||
is_4_bits = weight_quant.num_bits == 4 | ||
|
||
return (is_weight_only and is_group_quant and is_float_type | ||
and is_4_bits and is_group_size_16 and is_symmetric) | ||
return (is_group_quant and is_float_type and is_4_bits | ||
and is_group_size_16 and is_symmetric) | ||
|
||
def _is_static_tensor_w8a8(self, weight_quant: BaseModel, | ||
input_quant: BaseModel) -> bool: | ||
|
@@ -352,6 +375,9 @@ def _get_scheme_from_parts( | |
actorder=weight_quant.actorder) | ||
|
||
if is_activation_quantization_format(self.quant_format): | ||
if self._is_fp4a4_nvfp4(weight_quant, input_quant): | ||
return CompressedTensorsW4A4Fp4() | ||
|
||
if self._is_fp8_w8a8(weight_quant, input_quant): | ||
is_fp8_w8a8_supported = self._check_scheme_supported( | ||
CompressedTensorsW8A8Fp8.get_min_capability(), error=False) | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
178 changes: 178 additions & 0 deletions
178
..._executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from typing import Callable, Optional | ||
|
||
import torch | ||
from torch.nn.parameter import Parameter | ||
|
||
from vllm._custom_ops import (cutlass_scaled_fp4_mm, | ||
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) | ||
from vllm.logger import init_logger | ||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( | ||
CompressedTensorsScheme) | ||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 | ||
dequantize_to_dtype, ref_nvfp4_quant) | ||
from vllm.model_executor.parameter import (GroupQuantScaleParameter, | ||
ModelWeightParameter, | ||
PerTensorScaleParameter) | ||
from vllm.platforms import current_platform | ||
|
||
logger = init_logger(__name__) | ||
|
||
__all__ = ["CompressedTensorsW4A4Fp4"] | ||
|
||
|
||
def cutlass_fp4_supported() -> bool: | ||
if not current_platform.is_cuda(): | ||
return False | ||
capability_tuple = current_platform.get_device_capability() | ||
capability = -1 if capability_tuple is None else capability_tuple.to_int() | ||
return cutlass_scaled_mm_supports_fp4(capability) | ||
|
||
|
||
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): | ||
|
||
def __init__(self): | ||
self.group_size = 16 | ||
self.cutlass_nvfp4_supported = cutlass_fp4_supported() | ||
if not self.cutlass_nvfp4_supported: | ||
logger.warning("Current platform does not support cutlass NVFP4." | ||
" Running emulations.") | ||
|
||
@classmethod | ||
def get_min_capability(cls) -> int: | ||
# dont restrict as emulations | ||
return 80 | ||
|
||
def run_nvfp4_emulations(self, x: torch.Tensor, layer): | ||
x_m, x_k = x.shape | ||
output_dtype = x.dtype | ||
|
||
# quantize input to (FP4 and interleaved block scale) | ||
x_fp4, x_blockscale = ref_nvfp4_quant(x, layer.input_global_scale, | ||
self.group_size) | ||
|
||
# dequantize input | ||
x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size) | ||
x_blockscale = x_blockscale.unsqueeze(-1) / layer.input_global_scale | ||
x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype) | ||
del x_fp4, x_blockscale | ||
|
||
# dequantize weight | ||
w_fp4 = layer.weight.data.view(torch.uint8) | ||
w_blockscale = layer.weight_scale_swizzled.data | ||
w_global_scale = layer.weight_global_scale | ||
w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale, | ||
output_dtype, x.device, self.group_size) | ||
|
||
# matmul | ||
out = torch.matmul(x_dq, w_dq.t()) | ||
del w_dq, x_dq | ||
return out | ||
|
||
def create_weights(self, layer: torch.nn.Module, | ||
output_partition_sizes: list[int], | ||
input_size_per_partition: int, | ||
params_dtype: torch.dtype, weight_loader: Callable, | ||
**kwargs): | ||
output_size_per_partition = sum(output_partition_sizes) | ||
layer.logical_widths = output_partition_sizes | ||
layer.input_size_per_partition = input_size_per_partition | ||
layer.output_size_per_partition = output_size_per_partition | ||
|
||
# Weight | ||
weight = ModelWeightParameter(data=torch.empty( | ||
sum(output_partition_sizes), | ||
input_size_per_partition // 2, | ||
dtype=torch.uint8), | ||
input_dim=1, | ||
output_dim=0, | ||
weight_loader=weight_loader) | ||
layer.register_parameter("weight_packed", weight) | ||
|
||
# Global Weight Scale | ||
weight_global_scale = PerTensorScaleParameter( | ||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32), | ||
weight_loader=weight_loader) | ||
layer.register_parameter("weight_global_scale", weight_global_scale) | ||
|
||
# Per Group Weight Scale | ||
weight_scale = GroupQuantScaleParameter(data=torch.empty( | ||
sum(output_partition_sizes), | ||
input_size_per_partition // self.group_size, | ||
dtype=torch.float8_e4m3fn, | ||
), | ||
input_dim=1, | ||
output_dim=0, | ||
weight_loader=weight_loader) | ||
|
||
layer.register_parameter("weight_scale", weight_scale) | ||
|
||
input_global_scale = PerTensorScaleParameter( | ||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32), | ||
weight_loader=weight_loader) | ||
layer.register_parameter("input_global_scale", input_global_scale) | ||
|
||
def swizzle_blockscale(self, scale: torch.tensor): | ||
assert (scale.dtype == torch.float8_e4m3fn) | ||
# Pad and blockwise interleave weight_scale | ||
scale_ndim = scale.ndim | ||
if scale.ndim == 2: | ||
scale = scale.unsqueeze(0) | ||
assert scale.ndim == 3 | ||
B, M, K = scale.shape | ||
round_up_multiple = lambda x, m: (x + m - 1) // m * m | ||
M_padded = round_up_multiple(M, 128) | ||
K_padded = round_up_multiple(K, 4) | ||
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) | ||
padded_scale[:B, :M, :K] = scale | ||
batches, rows, cols = padded_scale.shape | ||
assert rows % 128 == 0 | ||
assert cols % 4 == 0 | ||
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, | ||
cols // 4, 4) | ||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) | ||
swizzled_scale = swizzled_scale.contiguous().cuda() | ||
return (swizzled_scale.reshape(M, K) | ||
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) | ||
|
||
def process_weights_after_loading(self, layer) -> None: | ||
|
||
global_input_scale = layer.input_global_scale.max().to(torch.float32) | ||
layer.input_global_scale = Parameter(global_input_scale, | ||
requires_grad=False) | ||
|
||
layer.weight_global_scale = Parameter( | ||
layer.weight_global_scale.max().to(torch.float32), | ||
requires_grad=False) | ||
|
||
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) | ||
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, | ||
requires_grad=False) | ||
|
||
# required by cutlass kernel; need Parameter, not ModelWeightParameter | ||
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) | ||
|
||
if self.cutlass_nvfp4_supported: | ||
layer.alpha = Parameter(layer.input_global_scale * | ||
layer.weight_global_scale, | ||
requires_grad=False) | ||
|
||
def apply_weights(self, | ||
layer: torch.nn.Module, | ||
x: torch.Tensor, | ||
bias: Optional[torch.Tensor] = None) -> torch.Tensor: | ||
|
||
if self.cutlass_nvfp4_supported: | ||
output_dtype = x.dtype | ||
output_shape = [x.shape[0], layer.weight.shape[0]] | ||
|
||
# quantize BF16 or FP16 to (FP4 and interleaved block scale) | ||
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) | ||
|
||
out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale, | ||
layer.weight_scale_swizzled, | ||
1 / layer.alpha, output_dtype) | ||
if bias is not None: | ||
out = out + bias | ||
return out.view(*output_shape) | ||
return self.run_nvfp4_emulations(x, layer) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.