@@ -219,8 +219,9 @@ def _check_scheme_supported(self,
 
     def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
 
-        is_weight_act_quant = (weight_quant is not None
-                               and input_quant is not None)
+        if weight_quant is None or input_quant is None:
+            return False
+
         is_group_quant = (
             weight_quant.strategy == QuantizationStrategy.GROUP.value)
         is_symmetric = weight_quant.symmetric and input_quant.symmetric
@@ -231,13 +232,18 @@ def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel):
                          and input_quant.type == QuantizationType.FLOAT.value)
         is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4
 
-        return (is_weight_act_quant and is_group_quant and is_float_type
-                and is_4_bits and is_group_size_16 and is_symmetric)
+        return (is_group_quant and is_float_type and is_4_bits
+                and is_group_size_16 and is_symmetric)
 
     def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
                          input_quant: BaseModel):
 
-        is_weight_only = weight_quant is not None and input_quant is None
+        if weight_quant is None:
+            return False
+
+        if input_quant is not None:
+            return False
+
         is_group_quant = (
             weight_quant.strategy == QuantizationStrategy.GROUP.value)
         is_symmetric = weight_quant.symmetric
@@ -246,8 +252,8 @@ def _is_fp4a16_nvfp4(self, weight_quant: BaseModel,
         is_float_type = weight_quant.type == QuantizationType.FLOAT
         is_4_bits = weight_quant.num_bits == 4
 
-        return (is_weight_only and is_group_quant and is_float_type
-                and is_4_bits and is_group_size_16 and is_symmetric)
+        return (is_group_quant and is_float_type and is_4_bits
+                and is_group_size_16 and is_symmetric)
 
     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
@@ -351,9 +357,6 @@ def _get_scheme_from_parts(
         if self._is_fp4a16_nvfp4(weight_quant, input_quant):
            return CompressedTensorsW4A16Fp4()
 
-        if self._is_fp4a4_nvfp4(weight_quant, input_quant):
-            return CompressedTensorsW4A4Fp4()
-
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             if (self.quant_format == CompressionFormat.marlin_24.value
                     and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
@@ -372,6 +375,9 @@ def _get_scheme_from_parts(
                     actorder=weight_quant.actorder)
 
         if is_activation_quantization_format(self.quant_format):
+            if self._is_fp4a4_nvfp4(weight_quant, input_quant):
+                return CompressedTensorsW4A4Fp4()
+
             if self._is_fp8_w8a8(weight_quant, input_quant):
                 is_fp8_w8a8_supported = self._check_scheme_supported(
                     CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
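Both predicate refactors above replace a leading boolean (`is_weight_act_quant` / `is_weight_only`) with guard clauses, so the attribute accesses that follow can never run against a `None` config. A minimal, self-contained sketch of the resulting check, not the vLLM implementation: `SimpleNamespace` stands in for the compressed-tensors `BaseModel` configs, and the `group_size == 16` condition is assumed from the surrounding code.

```python
from types import SimpleNamespace

GROUP = "group"  # stand-in for QuantizationStrategy.GROUP.value
FLOAT = "float"  # stand-in for QuantizationType.FLOAT.value


def is_fp4a4_nvfp4(weight_quant, input_quant):
    # Guard clause: bail out before touching attributes on None.
    if weight_quant is None or input_quant is None:
        return False

    is_group_quant = weight_quant.strategy == GROUP
    is_symmetric = weight_quant.symmetric and input_quant.symmetric
    is_group_size_16 = (weight_quant.group_size == 16
                        and input_quant.group_size == 16)
    is_float_type = (weight_quant.type == FLOAT
                     and input_quant.type == FLOAT)
    is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4
    return (is_group_quant and is_float_type and is_4_bits
            and is_group_size_16 and is_symmetric)


nvfp4 = SimpleNamespace(strategy=GROUP, symmetric=True, group_size=16,
                        type=FLOAT, num_bits=4)
assert is_fp4a4_nvfp4(nvfp4, nvfp4)      # full W4A4 NVFP4 config matches
assert not is_fp4a4_nvfp4(nvfp4, None)   # weight-only config is rejected
```

The last two hunks also make a behavioral change: the `_is_fp4a4_nvfp4` dispatch moves from the top of `_get_scheme_from_parts` into the `is_activation_quantization_format(self.quant_format)` branch, so `CompressedTensorsW4A4Fp4` is only selected when the compression format is an activation-quantization format.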