
fix int8/fp8 constant folding issue #3543

Merged · 9 commits · Jun 6, 2025
4 changes: 4 additions & 0 deletions .github/workflows/build_wheels_linux_aarch64.yml
@@ -257,6 +257,10 @@ jobs:
export PYTORCH_VERSION="$(${CONDA_RUN} pip show torch | grep ^Version: | sed 's/Version: *//' | sed 's/+.\+//')"
${CONDA_RUN} python setup.py clean
echo "Successfully ran `python setup.py clean`"
if [[ "$BUILD_VERSION" != *"+"${CU_VERSION} ]]; then
BUILD_VERSION="${BUILD_VERSION}+${CU_VERSION}"
fi
echo "BUILD_VERSION=$BUILD_VERSION"
if [[ ${{ inputs.is-jetpack }} == false ]]; then
${CONDA_RUN} python setup.py bdist_wheel
else
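
The new guard only appends the CUDA suffix when it is not already present, so BUILD_VERSION stays idempotent if the step runs more than once. A minimal Python sketch of the same check (BUILD_VERSION and CU_VERSION come from the workflow environment; the function name is illustrative, not part of the PR):

def with_cu_suffix(build_version: str, cu_version: str) -> str:
    """Append "+<cu_version>" only if the version does not already end with it."""
    suffix = f"+{cu_version}"
    if not build_version.endswith(suffix):
        build_version += suffix
    return build_version

# with_cu_suffix("2.8.0.dev20250606", "cu128") -> "2.8.0.dev20250606+cu128"
# Calling it again on the result is a no-op.
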
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -597,7 +597,9 @@ def aten_ops_neg(
)
else:

@dynamo_tensorrt_converter(torch.ops.tensorrt.quantize_op.default)
@dynamo_tensorrt_converter(
torch.ops.tensorrt.quantize_op.default, supports_dynamic_shapes=True
)
def aten_ops_quantize_op(
ctx: ConversionContext,
target: Target,
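
The only functional change here is the supports_dynamic_shapes=True flag, which tells the converter registry that this converter may be used when the quantize_op node carries dynamic input dimensions. A toy sketch of that registration pattern (the registry and decorator below are illustrative stand-ins, not the actual dynamo_tensorrt_converter internals):

from typing import Any, Callable, Dict, Tuple

# Hypothetical registry: op target -> (converter function, supports_dynamic_shapes)
CONVERTERS: Dict[Any, Tuple[Callable, bool]] = {}

def register_converter(target: Any, supports_dynamic_shapes: bool = False) -> Callable:
    """Toy decorator mirroring the registration pattern shown in the diff above."""
    def wrapper(fn: Callable) -> Callable:
        CONVERTERS[target] = (fn, supports_dynamic_shapes)
        return fn
    return wrapper

@register_converter("tensorrt::quantize_op", supports_dynamic_shapes=True)
def convert_quantize(ctx, target, args, kwargs, name):
    ...  # the real converter dispatches to impl.quantize(...)
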
68 changes: 49 additions & 19 deletions py/torch_tensorrt/dynamo/conversion/impl/quantize.py
@@ -28,42 +28,72 @@ def quantize(
"""

with unset_fake_temporarily():
if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in (
trt.float32,
trt.float16,
):
raise ValueError(
f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
)
if isinstance(input_tensor, (torch.Tensor, TRTTensor)):
if input_tensor.dtype not in (
trt.float32,
trt.float16,
trt.bfloat16,
torch.bfloat16,
torch.float16,
torch.float32,
):
raise ValueError(
f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16 | bfloat16"
)
if num_bits != 8 or exponent_bits not in (0, 4):
raise ValueError(
f"quantize converter currently only accept INT8 or FP8 based quantize, got {num_bits=}, {exponent_bits=}"
)
else:
raise ValueError(
f"quantize converter received an input of {type(input_tensor)} type. Supported types: torch.Tensor | TRTTensor"
)

if num_bits == 8 and exponent_bits == 0:
dtype = trt.DataType.INT8
max_bound = 127
elif num_bits == 8 and exponent_bits == 4:
dtype = trt.DataType.FP8
max_bound = 448

amax = to_torch(amax, None)
axis = None
# int8 weight quantization is per-channel quantization(it can have one or multiple amax values)
if dtype == trt.DataType.INT8 and amax.numel() > 1:
# if the amax has more than one element, calculate the axis, otherwise axis value will be ignored
amax_init_shape = amax.shape
amax = amax.squeeze().data
assert (
len(amax.shape) == 1
), f"TensorRT does not support multi-axis quantization. {name=} {amax_init_shape=} {amax.shape=} "
axis = list(amax_init_shape).index(list(amax.shape)[0])
assert (
axis == 0
), f"{name=} {amax=} is per-channel quantization, expected axis to be 0, but got {axis=}"
else:
# int8 activation and fp8 weight/activation quantization is per-tensor quantization, it can only have single amax value
assert (
amax.numel() == 1
), f"{name=} is per-tensor quantization, expected amax is a singular value, but got {amax.shape=}"
scale = torch.divide(amax, max_bound)
scale.masked_fill_(scale == 0, 1.0)
scale = get_trt_tensor(ctx, scale, name + "_scale")
# Add Q node
quantize_layer = ctx.net.add_quantize(input_tensor, scale)
if num_bits == 8 and exponent_bits == 0:
quantize_layer.set_output_type(0, trt.DataType.INT8)
elif num_bits == 8 and exponent_bits == 4:
quantize_layer.set_output_type(0, trt.DataType.FP8)
input_tensor = get_trt_tensor(ctx, input_tensor, name)

# Add Q node
quantize_layer = ctx.net.add_quantize(input_tensor, scale, dtype)
if axis is not None:
quantize_layer.axis = axis
set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
q_output = quantize_layer.get_output(0)
# Add DQ node
dequantize_layer = ctx.net.add_dequantize(q_output, scale)
dequantize_layer = ctx.net.add_dequantize(
q_output, scale, output_type=input_tensor.dtype
)
dequantize_layer.to_type = input_tensor.dtype
if axis is not None:
dequantize_layer.axis = axis
set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
if num_bits == 8 and exponent_bits == 0:
dequantize_layer.precision = trt.DataType.INT8
elif num_bits == 8 and exponent_bits == 4:
# Set DQ layer precision to FP8
dequantize_layer.precision = trt.DataType.FP8
dq_output = dequantize_layer.get_output(0)

return dq_output
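
The heart of the new converter logic is the amax handling: a multi-element amax for an INT8 weight means per-channel quantization along axis 0, a single-element amax means per-tensor quantization, and in both cases the scale is amax divided by the dtype's max bound (127 for INT8, 448 for FP8), with zero scales clamped to 1.0. A standalone sketch of that scale/axis computation in plain PyTorch (no TensorRT objects; the function name is illustrative, not part of the module):

import torch

def compute_scale_and_axis(amax: torch.Tensor, max_bound: float):
    """Derive the quantization scale and optional per-channel axis from amax."""
    axis = None
    if amax.numel() > 1:
        # Per-channel: squeeze to a 1-D vector and locate that dimension in the original shape.
        init_shape = amax.shape
        amax = amax.squeeze()
        assert amax.dim() == 1, "multi-axis quantization is not supported"
        axis = list(init_shape).index(amax.shape[0])
    scale = amax / max_bound
    scale = scale.masked_fill(scale == 0, 1.0)  # avoid dividing by zero when dequantizing
    return scale, axis

# Per-channel INT8 weight example: amax of shape (out_channels, 1, 1, 1)
scale, axis = compute_scale_and_axis(torch.tensor([[[[3.0]]], [[[6.0]]]]), max_bound=127)
# scale ~= tensor([0.0236, 0.0472]), axis == 0
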
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -101,4 +101,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

# TODO: Update this function when quantization is added
def is_impure(self, node: torch.fx.node.Node) -> bool:
if node.target in (torch.ops.tensorrt.quantize_op.default,):
return True
return False
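
Reporting the quantize op as impure is what actually fixes the issue: the constant folder no longer pre-evaluates INT8/FP8 Q/DQ chains over frozen weights, so the Q/DQ nodes survive lowering and TensorRT can fuse them. A minimal sketch of how an is_impure hook is typically consulted during folding (generic FX-style illustration, not the actual folder implementation):

import torch

def fold_constants(gm: torch.fx.GraphModule, is_impure) -> None:
    """Illustrative only: skip folding any node the hook reports as impure."""
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        if is_impure(node):
            # quantize_op nodes land here after the change above, so their
            # weight Q/DQ patterns are left intact for the TensorRT converter.
            continue
        # ... otherwise, if every input is a constant, the node could be
        # evaluated eagerly and replaced by a folded constant.
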
56 changes: 56 additions & 0 deletions tests/py/dynamo/models/test_models_export.py
@@ -302,3 +302,59 @@ def calibrate_loop(model):
)
outputs_trt = trt_model(input_tensor)
assert torch.allclose(output_pyt, outputs_trt, rtol=5e-3, atol=1e-2)


@unittest.skipIf(
platform.system() != "Linux"
or not importlib.util.find_spec("modelopt")
or Version(metadata.version("nvidia-modelopt")) < Version("0.17.0"),
"modelopt 0.17.0 or later is required, Int8 quantization is supported in modelopt since 0.17.0 or later for linux",
)
@pytest.mark.unit
def test_base_int8_dynamic_shape(ir):
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode

dtype = torch.bfloat16

class SimpleNetwork(torch.nn.Module):
def __init__(self):
super(SimpleNetwork, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 3, dtype=dtype)
self.linear = torch.nn.Linear(222, 222, dtype=dtype)

def forward(self, x):
return self.linear(self.conv(x))

def calibrate_loop(model):
"""Simple calibration function for testing."""
model(input_tensor)

BATCH_SIZE = torch.export.Dim("BATCH_SIZE", min=2, max=16)
batch_size = 8
input_tensor = torch.randn(batch_size, 3, 224, 224, dtype=dtype).cuda()
model = SimpleNetwork().eval().cuda()

quant_cfg = mtq.INT8_DEFAULT_CFG
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)

# model has INT8 qdq nodes at this point
output_pyt = model(input_tensor)

with torch.no_grad():
with export_torch_mode():
exp_program = torch.export.export(
model, (input_tensor,), strict=False, dynamic_shapes=({0: BATCH_SIZE},)
)
trt_model = torchtrt.dynamo.compile(
exp_program,
inputs=[input_tensor],
enabled_precisions={torch.int8, dtype},
min_block_size=1,
debug=True,
cache_built_engines=False,
reuse_cached_engines=False,
truncate_double=True,
)
outputs_trt = trt_model(input_tensor)
assert torch.allclose(output_pyt, outputs_trt, rtol=5e-2, atol=5e-2)
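
Because the export declares BATCH_SIZE as a dynamic dimension, the compiled engine could also be checked at batch sizes other than the calibration batch of 8. A possible follow-up assertion (not part of the committed test; batch size chosen within the exported [2, 16] range) might look like:

other_input = torch.randn(4, 3, 224, 224, dtype=dtype).cuda()
assert torch.allclose(model(other_input), trt_model(other_input), rtol=5e-2, atol=5e-2)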