[Quantization] Add compressed-tensors NVFP4 support #18312


Draft · wants to merge 7 commits into base: main
19 changes: 12 additions & 7 deletions tests/quantization/test_compressed_tensors.py
@@ -13,9 +13,10 @@
from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensors24, CompressedTensorsLinearMethod,
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported)
from vllm.platforms import current_platform
@@ -650,9 +651,13 @@ def check_model(model):
assert output


def test_compressed_tensors_nvfp4a16(vllm_runner):
# run weight only example
model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP4"
# TODO: update model configs with next ct release
@pytest.mark.parametrize("args", [
("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP4", CompressedTensorsW4A16Fp4),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A4", CompressedTensorsW4A4Fp4)
])
def test_compressed_tensors_nvfp4(vllm_runner, args):
model, scheme = args
with vllm_runner(model, enforce_eager=True) as llm:

def check_model(model):
@@ -661,7 +666,7 @@ def check_model(model):
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
assert isinstance(qkv_proj.scheme, scheme)
assert qkv_proj.scheme.group_size == 16

llm.apply_model(check_model)
@@ -23,10 +23,10 @@
CompressedTensorsMoEMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24,
CompressedTensorsScheme, CompressedTensorsW4A16Fp4,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
CompressedTensorsScheme, CompressedTensorsW4A4Fp4,
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
@@ -217,10 +217,33 @@ def _check_scheme_supported(self,
else:
return False

def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):

if weight_quant is None or input_quant is None:
return False

is_group_quant = (
weight_quant.strategy == QuantizationStrategy.GROUP.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric

is_group_size_16 = (weight_quant.group_size == 16
and input_quant.group_size == 16)
is_float_type = (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT)
is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4

Review comment (Member): We should check for dynamic input as well

return (is_group_quant and is_float_type and is_4_bits
and is_group_size_16 and is_symmetric)
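
A minimal sketch of the check the reviewer asks for above, written as a standalone helper rather than the PR's actual method. It assumes the compressed-tensors `QuantizationArgs` config exposes a `dynamic` flag on the activation side:

```python
from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy,
                                              QuantizationType)


def is_fp4a4_nvfp4_dynamic(weight_quant: QuantizationArgs,
                           input_quant: QuantizationArgs) -> bool:
    """NVFP4 W4A4 check extended with a dynamic-activation condition (sketch)."""
    if weight_quant is None or input_quant is None:
        return False
    is_group_quant = (
        weight_quant.strategy == QuantizationStrategy.GROUP.value)
    is_symmetric = weight_quant.symmetric and input_quant.symmetric
    is_group_size_16 = (weight_quant.group_size == 16
                        and input_quant.group_size == 16)
    is_float_type = (weight_quant.type == QuantizationType.FLOAT
                     and input_quant.type == QuantizationType.FLOAT)
    is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4
    # NVFP4 activations are quantized on the fly, so require a dynamic config.
    is_dynamic_input = input_quant.dynamic
    return (is_group_quant and is_float_type and is_4_bits
            and is_group_size_16 and is_symmetric and is_dynamic_input)
```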

def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
input_quant: BaseModel):

is_weight_only = weight_quant is not None and input_quant is None
if weight_quant is None:
return False

if input_quant is not None:
return False

is_group_quant = (
weight_quant.strategy == QuantizationStrategy.GROUP.value)
is_symmetric = weight_quant.symmetric
Expand All @@ -229,8 +252,8 @@ def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
is_float_type = weight_quant.type == QuantizationType.FLOAT
is_4_bits = weight_quant.num_bits == 4

return (is_weight_only and is_group_quant and is_float_type
and is_4_bits and is_group_size_16 and is_symmetric)
return (is_group_quant and is_float_type and is_4_bits
and is_group_size_16 and is_symmetric)

def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
@@ -352,6 +375,9 @@ def _get_scheme_from_parts(
actorder=weight_quant.actorder)

if is_activation_quantization_format(self.quant_format):
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4Fp4()

if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
Expand All @@ -17,5 +18,6 @@
"CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS",
"CompressedTensors24", "CompressedTensorsW4A16Fp4"
"CompressedTensors24", "CompressedTensorsW4A16Fp4",
"CompressedTensorsW4A4Fp4"
]
@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional

import torch
from torch.nn.parameter import Parameter

from vllm._custom_ops import (cutlass_scaled_fp4_mm,
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
dequantize_to_dtype, ref_nvfp4_quant)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform

logger = init_logger(__name__)

__all__ = ["CompressedTensorsW4A4Fp4"]


def cutlass_fp4_supported() -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability)


class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):

def __init__(self):
self.group_size = 16
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
if not self.cutlass_nvfp4_supported:
logger.warning("Current platform does not support cutlass NVFP4."
" Running emulations.")

@classmethod
def get_min_capability(cls) -> int:
# don't restrict capability; emulations cover GPUs without native FP4 support
return 80

def run_nvfp4_emulations(self, x: torch.Tensor, layer):
x_m, x_k = x.shape
output_dtype = x.dtype

# quantize input to (FP4 and interleaved block scale)
x_fp4, x_blockscale = ref_nvfp4_quant(x, layer.input_global_scale,
self.group_size)

# dequantize input
x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size)
x_blockscale = x_blockscale.unsqueeze(-1) / layer.input_global_scale
x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
del x_fp4, x_blockscale

# dequantize weight
w_fp4 = layer.weight.data.view(torch.uint8)
w_blockscale = layer.weight_scale_swizzled.data
w_global_scale = layer.weight_global_scale
w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
output_dtype, x.device, self.group_size)

# matmul
out = torch.matmul(x_dq, w_dq.t())
del w_dq, x_dq
return out

def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition

# Weight
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 2,
dtype=torch.uint8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_packed", weight)

# Global Weight Scale
weight_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("weight_global_scale", weight_global_scale)

# Per Group Weight Scale
weight_scale = GroupQuantScaleParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // self.group_size,
dtype=torch.float8_e4m3fn,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)

layer.register_parameter("weight_scale", weight_scale)

input_global_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_global_scale", input_global_scale)

def swizzle_blockscale(self, scale: torch.Tensor):
assert (scale.dtype == torch.float8_e4m3fn)
# Pad and blockwise interleave weight_scale
scale_ndim = scale.ndim
if scale.ndim == 2:
scale = scale.unsqueeze(0)
assert scale.ndim == 3
B, M, K = scale.shape
round_up_multiple = lambda x, m: (x + m - 1) // m * m
M_padded = round_up_multiple(M, 128)
K_padded = round_up_multiple(K, 4)
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
padded_scale[:B, :M, :K] = scale
batches, rows, cols = padded_scale.shape
assert rows % 128 == 0
assert cols % 4 == 0
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
cols // 4, 4)
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
swizzled_scale = swizzled_scale.contiguous().cuda()
return (swizzled_scale.reshape(M, K)
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))

def process_weights_after_loading(self, layer) -> None:

global_input_scale = layer.input_global_scale.max().to(torch.float32)
layer.input_global_scale = Parameter(global_input_scale,
requires_grad=False)

layer.weight_global_scale = Parameter(
layer.weight_global_scale.max().to(torch.float32),
requires_grad=False)

swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)

# required by cutlass kernel; need Parameter, not ModelWeightParameter
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)

if self.cutlass_nvfp4_supported:
layer.alpha = Parameter(layer.input_global_scale *
layer.weight_global_scale,
requires_grad=False)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

if self.cutlass_nvfp4_supported:
output_dtype = x.dtype
output_shape = [x.shape[0], layer.weight.shape[0]]

# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)

out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
layer.weight_scale_swizzled,
1 / layer.alpha, output_dtype)
if bias is not None:
out = out + bias
return out.view(*output_shape)
return self.run_nvfp4_emulations(x, layer)
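
For orientation, a hypothetical shape check (not part of the PR) for the `swizzle_blockscale` helper above. It assumes a CUDA device, since the method moves the swizzled scales to the GPU, and picks dimensions already aligned to the 128x4 scale tile so no padding is added:

```python
import torch

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsW4A4Fp4)

scheme = CompressedTensorsW4A4Fp4()
# 256 output rows x 8 block scales (128 input features / group_size 16)
scale = torch.randn(256, 8).to(torch.float8_e4m3fn)
swizzled = scheme.swizzle_blockscale(scale)
assert swizzled.shape == scale.shape  # interleaved layout, same logical shape
assert swizzled.is_cuda
```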
@@ -14,6 +14,7 @@ def is_activation_quantization_format(format: str) -> bool:
CompressionFormat.naive_quantized.value,
CompressionFormat.int_quantized.value,
CompressionFormat.float_quantized.value,
"nvfp4-pack-quantized" # TODO: use enum value after next release
]
return format in _ACTIVATION_QUANTIZATION_FORMATS
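
A quick, hypothetical usage check (not part of the diff) showing the effect of this addition: the hard-coded NVFP4 pack format string is now treated as an activation-quantization format:

```python
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    is_activation_quantization_format)

# the temporary string literal added above, pending the next compressed-tensors release
assert is_activation_quantization_format("nvfp4-pack-quantized")
```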

@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
import torch

__all__ = [
"break_fp4_bytes",
"dequantize_to_dtype",
]
from vllm.scalar_type import scalar_types

__all__ = ["break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant"]

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()

kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
dtype=torch.float32)
@@ -59,3 +60,44 @@ def dequantize_to_dtype(tensor_fp4,
# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
return out.to(dtype)


def get_reciprocal(x):
if isinstance(x, torch.Tensor):
return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
elif isinstance(x, (float, int)):
return 0.0 if x == 0 else 1.0 / x
else:
raise TypeError("Input must be a float, int, or a torch.Tensor.")


def cast_to_fp4(x):
sign = torch.sign(x)
x = torch.abs(x)
x[(x >= 0.0) & (x <= 0.25)] = 0.0
x[(x > 0.25) & (x < 0.75)] = 0.5
x[(x >= 0.75) & (x <= 1.25)] = 1.0
x[(x > 1.25) & (x < 1.75)] = 1.5
x[(x >= 1.75) & (x <= 2.5)] = 2.0
x[(x > 2.5) & (x < 3.5)] = 3.0
x[(x >= 3.5) & (x <= 5.0)] = 4.0
x[x > 5.0] = 6.0
return x * sign


def ref_nvfp4_quant(x, global_scale, block_size):
assert global_scale.dtype == torch.float32
assert x.ndim == 2
m, n = x.shape
x = torch.reshape(x, (m, n // block_size, block_size))
vec_max = torch.max(torch.abs(x), dim=-1,
keepdim=True)[0].to(torch.float32)
scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
scale = torch.clamp(scale, max=448, min=-448)
scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
output_scale = get_reciprocal(scale * get_reciprocal(global_scale))

scaled_x = x.to(torch.float32) * output_scale
clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
# both outputs are float32
return cast_to_fp4(clipped_x), scale.squeeze(-1)
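
Finally, a hypothetical smoke test (not part of the PR) for `ref_nvfp4_quant`, checking output shapes for a small activation tensor with the 16-element block size used by NVFP4; the unit global scale is a stand-in for a calibrated value:

```python
import torch

from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
    ref_nvfp4_quant)

x = torch.randn(4, 64, dtype=torch.float32)             # 4 tokens x 64 features
global_scale = torch.tensor(1.0, dtype=torch.float32)   # stand-in for a calibrated scale

x_fp4, blockscale = ref_nvfp4_quant(x, global_scale, block_size=16)
assert x_fp4.shape == (4, 64)      # values snapped to the E2M1 grid, stored as float32
assert blockscale.shape == (4, 4)  # one scale per 16-element block
```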