Commit 6980284

fix conditions; add test models
Signed-off-by: Dipika <[email protected]>
1 parent ae335f7 commit 6980284

2 files changed (+28, -17 lines)

tests/quantization/test_compressed_tensors.py

Lines changed: 12 additions & 7 deletions
@@ -13,9 +13,10 @@
 from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensors24, CompressedTensorsLinearMethod,
-    CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
-    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
+    CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
+    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
+    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
+    CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     sparse_cutlass_supported)
 from vllm.platforms import current_platform
@@ -650,9 +651,13 @@ def check_model(model):
         assert output


-def test_compressed_tensors_nvfp4a16(vllm_runner):
-    # run weight only example
-    model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP4"
+# TODO: update model configs with next ct release
+@pytest.mark.parametrize("args", [
+    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP4", CompressedTensorsW4A16Fp4),
+    ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A4", CompressedTensorsW4A4Fp4)
+])
+def test_compressed_tensors_nvfp4(vllm_runner, args):
+    model, scheme = args
     with vllm_runner(model, enforce_eager=True) as llm:

         def check_model(model):
@@ -661,7 +666,7 @@ def check_model(model):
             qkv_proj = layer.self_attn.qkv_proj
             assert isinstance(qkv_proj.quant_method,
                               CompressedTensorsLinearMethod)
-            assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
+            assert isinstance(qkv_proj.scheme, scheme)
             assert qkv_proj.scheme.group_size == 16

         llm.apply_model(check_model)
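
For readers unfamiliar with the idiom, the parametrization above makes pytest generate one test case per (model, scheme) tuple, so the weight-only FP4 checkpoint and the NVFP4 weight-plus-activation checkpoint exercise the same assertions; both cases can be selected with, for example, pytest tests/quantization/test_compressed_tensors.py -k test_compressed_tensors_nvfp4. A minimal, self-contained sketch of the pattern (toy names only, not part of the vLLM test suite):

# Toy illustration of @pytest.mark.parametrize: one test case is generated
# per tuple in CASES; the names here are hypothetical.
import pytest

CASES = [
    ("checkpoint-a", "SchemeA"),
    ("checkpoint-b", "SchemeB"),
]

@pytest.mark.parametrize("args", CASES)
def test_scheme_selection(args):
    model, scheme = args
    # Each generated case receives its own (model, scheme) pair.
    assert isinstance(model, str)
    assert isinstance(scheme, str)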

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 16 additions & 10 deletions
@@ -219,8 +219,9 @@ def _check_scheme_supported(self,

     def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):

-        is_weight_act_quant = (weight_quant is not None
-                               and input_quant is not None)
+        if weight_quant is None or input_quant is None:
+            return False
+
         is_group_quant = (
             weight_quant.strategy == QuantizationStrategy.GROUP.value)
         is_symmetric = weight_quant.symmetric and input_quant.symmetric
@@ -231,13 +232,18 @@ def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
                         and input_quant.type == QuantizationType.FLOAT.value)
         is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4

-        return (is_weight_act_quant and is_group_quant and is_float_type
-                and is_4_bits and is_group_size_16 and is_symmetric)
+        return (is_group_quant and is_float_type and is_4_bits
+                and is_group_size_16 and is_symmetric)

     def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
                          input_quant: BaseModel):

-        is_weight_only = weight_quant is not None and input_quant is None
+        if weight_quant is None:
+            return False
+
+        if input_quant is not None:
+            return False
+
         is_group_quant = (
             weight_quant.strategy == QuantizationStrategy.GROUP.value)
         is_symmetric = weight_quant.symmetric
@@ -246,8 +252,8 @@ def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
         is_float_type = weight_quant.type == QuantizationType.FLOAT
         is_4_bits = weight_quant.num_bits == 4

-        return (is_weight_only and is_group_quant and is_float_type
-                and is_4_bits and is_group_size_16 and is_symmetric)
+        return (is_group_quant and is_float_type and is_4_bits
+                and is_group_size_16 and is_symmetric)

     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
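
Taken together, the two rewritten predicates replace the eagerly computed is_weight_act_quant / is_weight_only flags with early returns. In the pre-fix form every sub-condition was evaluated before the final "and", so a missing weight_quant or input_quant config was dereferenced anyway (e.g. weight_quant.strategy on None) instead of the method simply returning False. A minimal, runnable sketch of the guard-clause pattern, using stand-in objects rather than the real compressed-tensors BaseModel configs:

# Stand-in configs for illustration only; the real checks live in
# CompressedTensorsConfig._is_fp4a4_nvfp4 / _is_fp4a16_nvfp4.
from types import SimpleNamespace


def is_fp4a4_like(weight_quant, input_quant):
    # Guard clauses mirror the fixed condition: bail out before touching
    # attributes on a missing config.
    if weight_quant is None or input_quant is None:
        return False
    is_group_quant = weight_quant.strategy == "group"
    is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4
    return is_group_quant and is_4_bits


weights = SimpleNamespace(strategy="group", num_bits=4)
acts = SimpleNamespace(strategy="group", num_bits=4)
assert is_fp4a4_like(weights, acts)
assert not is_fp4a4_like(weights, None)  # the pre-fix version would have
                                         # evaluated input_quant.num_bits here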
@@ -351,9 +357,6 @@ def _get_scheme_from_parts(
         if self._is_fp4a16_nvfp4(weight_quant, input_quant):
             return CompressedTensorsW4A16Fp4()

-        if self._is_fp4a4_nvfp4(weight_quant, input_quant):
-            return CompressedTensorsW4A4Fp4()
-
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
@@ -372,6 +375,9 @@ def _get_scheme_from_parts(
                     actorder=weight_quant.actorder)

         if is_activation_quantization_format(self.quant_format):
+            if self._is_fp4a4_nvfp4(weight_quant, input_quant):
+                return CompressedTensorsW4A4Fp4()
+
             if self._is_fp8_w8a8(weight_quant, input_quant):
                 is_fp8_w8a8_supported = self._check_scheme_supported(
                     CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
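
The last two hunks move rather than change the W4A4 dispatch: _is_fp4a4_nvfp4 is now consulted inside the is_activation_quantization_format branch, so the activation-quantized NVFP4 scheme is only returned for checkpoints whose compression format actually carries activation quantization, while the weight-only W4A16 FP4 check keeps its earlier, format-independent position. A compressed sketch of the resulting ordering, with hypothetical boolean arguments standing in for the real predicate calls in _get_scheme_from_parts:

# Simplified dispatch order after this commit; the booleans stand in for the
# self._is_*() predicate calls and the format check shown in the hunks above.
def pick_scheme(is_fp4a16: bool, is_activation_format: bool,
                is_fp4a4: bool) -> str:
    if is_fp4a16:
        return "W4A16Fp4"  # weight-only FP4, matched before the format check
    if is_activation_format:
        if is_fp4a4:
            return "W4A4Fp4"  # NVFP4 weight + activation, gated on the format
        # ... the FP8/INT8 W8A8 checks continue here in the real method ...
    return "other"


assert pick_scheme(True, False, False) == "W4A16Fp4"
assert pick_scheme(False, True, True) == "W4A4Fp4"
assert pick_scheme(False, False, True) == "other"  # FP4A4-shaped config, but
                                                    # not an activation-quant format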
