 
 import torch
 from torch.nn.parameter import Parameter
-
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm._custom_ops import (cutlass_scaled_fp4_mm,
+                              cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
     dequantize_to_dtype, ref_nvfp4_quant)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
 
+logger = init_logger(__name__)
+
 __all__ = ["CompressedTensorsW4A4Fp4"]
 
 
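+# Feature probe for the fast path: reports whether the CUTLASS FP4 GEMM is
+# usable for the current device's compute capability (always False off CUDA).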
+def cutlass_fp4_supported() -> bool:
+    if not current_platform.is_cuda():
+        return False
+    capability_tuple = current_platform.get_device_capability()
+    capability = -1 if capability_tuple is None else capability_tuple.to_int()
+    return cutlass_scaled_mm_supports_fp4(capability)
+
 class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
 
     def __init__(self):
         self.group_size = 16
+        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
+        if not self.cutlass_nvfp4_supported:
+            logger.warning("Current platform does not support cutlass NVFP4."
+                           " Running emulations.")
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -101,37 +117,60 @@ def process_weights_after_loading(self, layer) -> None:
             layer.weight_global_scale.max().to(torch.float32),
             requires_grad=False)
 
+
+
         swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
         layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                 requires_grad=False)
 
+        # Required by the cutlass kernel: it expects a plain Parameter input,
+        # not a ModelWeightParameter
+        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
+
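+        # alpha folds the input and weight per-tensor global scales into one
+        # scalar; apply_weights passes 1 / alpha to the cutlass FP4 GEMM.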
+        if self.cutlass_nvfp4_supported:
+            layer.alpha = Parameter(layer.input_global_scale * layer.weight_global_scale,
+                                    requires_grad=False)
+
     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-        x_m, x_k = x.shape
+        if not self.cutlass_nvfp4_supported:
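+            # Emulation fallback: fake-quantize the activations to NVFP4,
+            # dequantize both operands to output_dtype, then run a dense matmul.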
+            x_m, x_k = x.shape
+            output_dtype = x.dtype
+
+            # quantize input to (FP4 and interleaved block scale)
+            x_global_scale = layer.input_global_scale
+            x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale,
+                                                  self.group_size)
+
+            # dequantize input
+            x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size)
+            x_blockscale = x_blockscale.unsqueeze(-1) / x_global_scale
+            x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
+            del x_fp4, x_blockscale
+
+            # dequantize weight
+            w_fp4 = layer.weight.data.view(torch.uint8)
+            w_blockscale = layer.weight_scale_swizzled.data
+            w_global_scale = layer.weight_global_scale
+            w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
+                                       output_dtype, x.device, self.group_size)
+
+            # matmul
+            out = torch.matmul(x_dq, w_dq.t())
+            del w_dq, x_dq
+            return out
+
         output_dtype = x.dtype
+        output_shape = [x.shape[0], layer.weight.shape[0]]
+
+        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
+        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
 
-        # quantize input to (FP4 and interleaved block scale)
-        x_global_scale = layer.input_global_scale
-        x_fp4, x_blockscale = ref_nvfp4_quant(x, x_global_scale,
-                                              self.group_size)
-
-        # dequantize input
-        x_fp4 = x_fp4.reshape(x_m, x_k // self.group_size, self.group_size)
-        x_blockscale = x_blockscale.unsqueeze(-1) / x_global_scale
-        x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
-        del x_fp4, x_blockscale
-
-        # dequantize weight
-        w_fp4 = layer.weight_packed.data.view(torch.uint8)
-        w_blockscale = layer.weight_scale_swizzled.data
-        w_global_scale = layer.weight_global_scale
-        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
-                                   output_dtype, x.device, self.group_size)
-
-        # matmul
-        out = torch.matmul(x_dq, w_dq.t())
-        del w_dq, x_dq
-        return out
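+        # GEMM directly on the packed FP4 weight with its swizzled block
+        # scales; 1 / alpha (precomputed in process_weights_after_loading)
+        # undoes the input and weight global scales.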
+        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
+                                    layer.weight_scale_swizzled, 1 / layer.alpha,
+                                    output_dtype)
+        if bias is not None:
+            out = out + bias
+        return out.view(*output_shape)