
Commit f6c7914

add cutlass support

1 parent b30168b commit f6c7914

File tree

1 file changed: +63 -24 lines changed


vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py

Lines changed: 63 additions & 24 deletions
@@ -3,7 +3,10 @@
 
 import torch
 from torch.nn.parameter import Parameter
-
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm._custom_ops import (cutlass_scaled_fp4_mm,
+                              cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
@@ -12,13 +15,26 @@
     ModelWeightParameter,
     PerTensorScaleParameter)
 
+logger = init_logger(__name__)
+
 __all__ = ["CompressedTensorsW4A4Fp4"]
 
 
+def cutlass_fp4_supported() -> bool:
+    if not current_platform.is_cuda():
+        return False
+    capability_tuple = current_platform.get_device_capability()
+    capability = -1 if capability_tuple is None else capability_tuple.to_int()
+    return cutlass_scaled_mm_supports_fp4(capability)
+
 class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
 
     def __init__(self):
         self.group_size = 16
+        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
+        if not self.cutlass_nvfp4_supported:
+            logger.warning("Current platform does not support cutlass NVFP4."
+                           " Running emulations.")
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -101,37 +117,60 @@ def process_weights_after_loading(self, layer) -> None:
             layer.weight_global_scale.max().to(torch.float32),
             requires_grad=False)
 
+
+
         swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
         layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                 requires_grad=False)
 
+        # Required by the cutlass kernel - need parameter input, not ModelWeightParameter
+        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
+
+        if self.cutlass_nvfp4_supported:
+            layer.alpha = Parameter(layer.input_global_scale * layer.weight_global_scale,
+                                    requires_grad=False)
+
     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-        x_m, x_k = x.shape
+        if not self.cutlass_nvfp4_supported:
+            x_m, x_k = x.shape
+            output_dtype = x.dtype
+
+            # quantize input to (FP4 and interleaved block scale)
+            x_global_scale = layer.input_global_scale
+            x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale,
+                                                  self.group_size)
+
+            # dequantize input
+            x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size)
+            x_blockscale = x_blockscale.unsqueeze(-1) / x_global_scale
+            x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
+            del x_fp4, x_blockscale
+
+            # dequantize weight
+            w_fp4 = layer.weight.data.view(torch.uint8)
+            w_blockscale = layer.weight_scale_swizzled.data
+            w_global_scale = layer.weight_global_scale
+            w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
+                                       output_dtype, x.device, self.group_size)
+
+            # matmul
+            out = torch.matmul(x_dq, w_dq.t())
+            del w_dq, x_dq
+            return out
+
         output_dtype = x.dtype
+        output_shape = [x.shape[0], layer.weight.shape[0]]
+
+        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
+        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
 
-        # quantize input to (FP4 and interleaved block scale)
-        x_global_scale = layer.input_global_scale
-        x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale,
-                                              self.group_size)
-
-        # dequantize input
-        x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size)
-        x_blockscale = x_blockscale.unsqueeze(-1) / x_global_scale
-        x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
-        del x_fp4, x_blockscale
-
-        # dequantize weight
-        w_fp4 = layer.weight_packed.data.view(torch.uint8)
-        w_blockscale = layer.weight_scale_swizzled.data
-        w_global_scale = layer.weight_global_scale
-        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
-                                   output_dtype, x.device, self.group_size)
-
-        # matmul
-        out = torch.matmul(x_dq, w_dq.t())
-        del w_dq, x_dq
-        return out
+        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
+                                    layer.weight_scale_swizzled, 1 / layer.alpha,
+                                    output_dtype)
+        if bias is not None:
+            out = out + bias
+        return out.view(*output_shape)
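Note on the 1 / layer.alpha factor passed to cutlass_scaled_fp4_mm: in the emulated path the input dequantization divides the block scales by layer.input_global_scale, and dequantize_to_dtype is handed layer.weight_global_scale for the weight, whereas the cutlass path folds both per-tensor scales into layer.alpha = input_global_scale * weight_global_scale and rescales once after the matmul. The snippet below is a minimal torch-only sketch of why that folding is equivalent; it assumes the kernel applies the supplied factor multiplicatively to the accumulated product, and a, b, ga and gb are made-up stand-ins rather than vLLM objects.

import torch

torch.manual_seed(0)

# Stand-ins for the already block-scaled operands (x_fp4 * x_blockscale and
# w_fp4 * w_blockscale in the emulation path above); not real FP4 data.
a = torch.randn(4, 32)
b = torch.randn(8, 32)

# Hypothetical per-tensor global scales, playing the role of
# layer.input_global_scale and layer.weight_global_scale.
ga = torch.tensor(7.0)
gb = torch.tensor(3.0)

# Emulation-style: divide each operand by its own global scale, then matmul.
out_emulated = (a / ga) @ (b / gb).t()

# Folded: matmul first, then rescale once by 1 / (ga * gb), i.e. the
# 1 / layer.alpha factor handed to the cutlass kernel.
alpha = ga * gb
out_folded = (a @ b.t()) / alpha

torch.testing.assert_close(out_emulated, out_folded)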

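Usage note: the new cutlass_fp4_supported() gate requires a CUDA platform and a device compute capability accepted by cutlass_scaled_mm_supports_fp4; on anything else the scheme logs the warning above and keeps using the emulation path. A minimal standalone probe of that same check, assuming a local vLLM build that exposes the helpers imported in this commit, could look like:

from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
from vllm.platforms import current_platform


def probe_cutlass_fp4() -> bool:
    # Mirror of the check added in this commit: CUDA only, and the device
    # compute capability must be accepted by the CUTLASS FP4 scaled-mm kernel.
    if not current_platform.is_cuda():
        return False
    capability_tuple = current_platform.get_device_capability()
    capability = -1 if capability_tuple is None else capability_tuple.to_int()
    return cutlass_scaled_mm_supports_fp4(capability)


if __name__ == "__main__":
    path = "cutlass NVFP4" if probe_cutlass_fp4() else "emulation"
    print(f"W4A4 NVFP4 scheme would use the {path} path")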