[FX] Fix clamping float32 boundary values, aten2trt init check-in, fix slice issues #1415

Merged (1 commit) on Oct 22, 2022
6 changes: 3 additions & 3 deletions docs/_sources/tutorials/ptq.rst.txt
@@ -136,7 +136,7 @@ Then all thats required to setup the module for INT8 calibration is to set the f
If you have an existing Calibrator implementation for TensorRT you may directly set the ``ptq_calibrator`` field with a pointer to your calibrator and it will work as well.
From here not much changes in terms of how to execution works. You are still able to fully use LibTorch as the sole interface for inference. Data should remain
in FP32 precision when it's passed into `trt_mod.forward`. There exists an example application in the Torch-TensorRT demo that takes you from training a VGG16 network on
CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/cpp/ptq
CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/examples/int8/ptq

.. _writing_ptq_python:

@@ -194,8 +194,8 @@ to use ``CacheCalibrator`` to use in INT8 mode.
calibrator=calibrator)

If you already have an existing calibrator class (implemented directly using TensorRT API), you can directly set the calibrator field to your class which can be very convenient.
For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_dataloader_calibrator.py
and https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_trt_calibrator.py
For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_dataloader_calibrator.py
and https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_trt_calibrator.py
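For readers skimming this diff without the linked tutorial open, here is a condensed sketch of the Python PTQ flow those pages describe. The data and model below are placeholders, and the exact argument names should be checked against the Torch-TensorRT release you are using:

```python
import torch
import torch_tensorrt
import torchvision

# Placeholder calibration data; in practice use a representative slice of the
# real training/validation set.
images = torch.randn(64, 3, 224, 224)
labels = torch.zeros(64, dtype=torch.long)
calib_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(images, labels), batch_size=8
)

calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(
    calib_loader,
    cache_file="./calibration.cache",
    use_cache=False,
    algo_type=torch_tensorrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
    device=torch.device("cuda:0"),
)

model = torchvision.models.vgg16(weights=None).eval().cuda()
trt_mod = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((8, 3, 224, 224))],
    enabled_precisions={torch.int8},
    calibrator=calibrator,
)

# As the docs above note, inputs stay in FP32; quantization happens inside the engine.
out = trt_mod(torch.randn(8, 3, 224, 224).cuda())
```

An existing TensorRT calibrator class can be passed in the same `calibrator` slot, matching the "directly set the calibrator field" note above.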

Citations
^^^^^^^^^^^
8 changes: 6 additions & 2 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -2854,8 +2854,12 @@ def add_clamp(network, input, val, op, name):
else:
acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions
acc_ops_clamp_tensor = (
val
* torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype))
(
val
* torch.ones(
acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)
)
)
.cpu()
.numpy()
)
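For context on why the clamp tensor construction above matters for the PR title (and for the new `float32Boundary` case added to test_clamp.py below), here is a small plain-PyTorch sketch of what a clamp-with-`min` lowering amounts to numerically: a constant tensor broadcastable against the input, combined with it elementwise. This is an illustration only, not the TensorRT network code:

```python
import torch

# Lowest finite float32 value, the bound exercised by the new "float32Boundary" test.
FLOAT32_MIN = torch.finfo(torch.float32).min  # -3.4028234663852886e38

x = torch.randn(2, 3)

# A constant tensor shaped so it broadcasts against the input, like acc_ops_clamp_tensor.
clamp_shape = (1,) * x.dim()
clamp_tensor = FLOAT32_MIN * torch.ones(clamp_shape, dtype=x.dtype)

# clamp with only `min` is an elementwise maximum against that constant...
assert torch.equal(torch.clamp(x, min=FLOAT32_MIN), torch.maximum(x, clamp_tensor))
# ...and clamping at the float32 lower bound leaves any finite input unchanged.
assert torch.equal(torch.clamp(x, min=FLOAT32_MIN), x)
```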
150 changes: 149 additions & 1 deletion py/torch_tensorrt/fx/passes/lower_basic_pass.py
@@ -1,10 +1,13 @@
import copy
import logging
import operator
import warnings
from typing import Any
from typing import Any, Optional

import torch
import torch.fx
import torch.fx as fx
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
from torch.fx.experimental.const_fold import split_const_subgraphs

from ..observer import observable
@@ -13,6 +16,8 @@
from ..tracer.acc_tracer.acc_utils import get_attr
from .pass_utils import log_before_after, validate_inference

_LOGGER = logging.getLogger(__name__)

# Create an alias for module input type to avoid littering pyre-ignore for Any
# throughout the file.
Input = Any
@@ -460,3 +465,146 @@ def transform_setitem(gm: torch.fx.GraphModule, input: Input):
gm.graph.lint()
gm.recompile()
return gm


def fix_reshape_batch_dim(mod: fx.GraphModule) -> fx.GraphModule:
"""\
TRT cannot reason about shape patterns like x.reshape(y.size(0), -1, 256),
since the dynamic shape of the reshape comes from the dynamic shape of
another node (y). The compilation will fail with various memory related
errors, depending on the size of the input tensor.

This pass fixes the issue by finding this reshape pattern, checking that:

x.size(0) == y.size(0)

And then replaces reshape's batch size from y.size(0) to x.size(0).
"""

def get_reshape_batch_size_as_node(maybe_reshape: fx.Node) -> Optional[fx.Node]:
"""\
Try to find the reshape op's batch size as an input node.

Match below graph structure and return `node_y`:
node_x.reshape({"acc_out_ty": {"shape": (node_y, ...)}})
"""
if (
maybe_reshape.op != "call_function"
or maybe_reshape.target != acc_ops.reshape
):
return None
shape = getattr(maybe_reshape.kwargs["acc_out_ty"], "shape", None)
if not shape:
return None
batch_size = shape[0]
if isinstance(batch_size, fx.Node):
return batch_size
return None

def get_reshape_batch_size_inferred_source(
batch_size_node: fx.Node,
) -> Optional[fx.Node]:
"""\
Given a node representing the batch size used for reshape op, we want
to know if it is coming from below pattern:

batch_size_node = src.size()[0]

or in IR graph:

src -> size(input=_) -> getitem(input=_, idx=0)
^ ~~~ batch_size_node

If so, return `src`. Otherwise, return `None`.
"""
if (
batch_size_node.op != "call_function"
or batch_size_node.target != acc_ops.getitem
or batch_size_node.kwargs["idx"] != 0
):
return None
maybe_size: fx.Node = batch_size_node.all_input_nodes[0]
if maybe_size.op != "call_function" or maybe_size.target != acc_ops.size:
return None
return maybe_size.all_input_nodes[0]

maybe_reshape: fx.Node
for maybe_reshape in mod.graph.nodes:
reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
maybe_reshape
)
if not reshape_batch_size:
continue
reshape_batch_size_inferred_source: Optional[
fx.Node
] = get_reshape_batch_size_inferred_source(reshape_batch_size)
if not reshape_batch_size_inferred_source:
continue

reshape_input: fx.Node = maybe_reshape.kwargs["input"]
if reshape_input == reshape_batch_size_inferred_source:
continue

if not _is_batch_size_equal(reshape_input, reshape_batch_size_inferred_source):
continue

_LOGGER.info(
f"{fix_reshape_batch_dim}: Found bad pattern: y.reshape((x, ...)). Reshape node: {maybe_reshape}"
)

# Step 1: create a node to compute batch size, using the tensor which
# is being reshaped: reshape_input.size()[0]. This batch size is now
# derived from reshape_input, the same node as the reshape op's input.
with mod.graph.inserting_before(maybe_reshape):
reshape_batch_size_2: fx.Node = maybe_reshape.graph.call_function(
acc_ops.getitem,
kwargs={
"idx": 0,
"input": maybe_reshape.graph.call_function(
acc_ops.size,
kwargs={
"input": reshape_input,
},
),
},
)

# Step 2: update `maybe_reshape`'s shape argument to be
# (reshape_batch_size_2, *DONT_CARE_JUST_COPY_OVER)
maybe_reshape.kwargs = {
**maybe_reshape.kwargs,
"acc_out_ty": acc_utils.build_raw_tensor_meta(
shape=(
reshape_batch_size_2,
*(maybe_reshape.kwargs["acc_out_ty"].shape[1:]),
)
),
}

mod.graph.eliminate_dead_code()
mod.recompile()
return mod


def _is_batch_size_equal(x: fx.Node, y: fx.Node) -> bool:
"""\
Check that x.size(0) == y.size(0)
"""
x_size, y_size = _get_shape(x), _get_shape(y)
return (
x_size
and y_size
# now both are non-empty
and x_size[0] == y_size[0]
)


def _get_shape(node: fx.Node) -> Optional[torch.Size]:
if (
not getattr(node, "meta", None)
or not node.meta.get("tensor_meta", None)
or not getattr(node.meta["tensor_meta"], "shape", None)
):
# shape info not available
return None
return node.meta["tensor_meta"].shape
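To make the pattern this pass targets concrete, here is a small eager-mode sketch. The module names are illustrative; the shapes are arbitrary but satisfy the `x.size(0) == y.size(0)` precondition checked by `_is_batch_size_equal`:

```python
import torch
import torch.nn as nn


class Before(nn.Module):
    # Problematic for TRT: the reshape's batch dim is taken from a *different*
    # input (x), so the dynamic shape appears to come from another node.
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return y.reshape(x.size(0), -1, 256)


class After(nn.Module):
    # What fix_reshape_batch_dim effectively rewrites it to: the batch size is
    # re-derived from the tensor actually being reshaped (y.size()[0]).
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return y.reshape(y.size(0), -1, 256)


x, y = torch.rand(8, 16), torch.rand(8, 512)
# The rewrite is value-preserving because the pass only fires when the two
# batch sizes match.
assert torch.equal(Before()(x, y), After()(x, y))
```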
2 changes: 2 additions & 0 deletions py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -17,6 +17,7 @@
from .graph_opts import common_subexpression_elimination

from .lower_basic_pass import (
fix_reshape_batch_dim,
replace_mutable_op,
replace_op_with_indices,
run_const_fold,
@@ -112,6 +113,7 @@ def graph_optimization_pass(self) -> PassManager:
passes.append(
inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input))
)
passes.append(fix_reshape_batch_dim)

return PassManager.build_from_passlist(passes)
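For orientation, every entry appended to `passes` is a callable from `fx.GraphModule` to `fx.GraphModule`, and `build_from_passlist` bundles the list into a single callable pipeline. A minimal sketch of that pattern, assuming `PassManager` is the one this file imports from `torch.fx.passes.pass_manager` (treat the toy pass as a placeholder):

```python
import torch
import torch.fx as fx
from torch.fx.passes.pass_manager import PassManager


def noop_pass(gm: fx.GraphModule) -> fx.GraphModule:
    # A graph pass inspects/rewrites gm.graph and returns the (same or new) module.
    return gm


pm = PassManager.build_from_passlist([noop_pass])

traced = fx.symbolic_trace(torch.nn.Linear(4, 4))
optimized = pm(traced)  # applies every pass in order
```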

11 changes: 10 additions & 1 deletion py/torch_tensorrt/fx/passes/pass_utils.py
@@ -1,3 +1,4 @@
import io
import logging
import tempfile
from functools import wraps
@@ -233,15 +234,23 @@ def log_before_after(pass_: PassFunc) -> PassFunc:
def pass_with_before_after_log(
module: fx.GraphModule, input: Input
) -> fx.GraphModule:
before_io = io.StringIO()
after_io = io.StringIO()
with tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
delete=False,
) as f:
_LOGGER.info(f"== Log pass {pass_} before/after graph to {f.name}")
print(f"[{pass_}] Before:\n{module.graph}", file=f)
print(module.graph, file=before_io)

module = pass_(module, input)
print(f"[{pass_}] After:\n{module.graph}", file=f)
print(module.graph, file=after_io)
t = before_io.getvalue() == after_io.getvalue()
_LOGGER.info(
f"== Log pass {pass_} before/after graph to {f.name}, before/after are the same = {t}"
)
return module

return pass_with_before_after_log
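A stripped-down sketch of the instrumentation added above, keeping only the `StringIO` capture so the before/after comparison is easy to see in isolation; the decorator name here is illustrative, not the actual `log_before_after`:

```python
import io
import logging
from functools import wraps

import torch.fx as fx

_LOGGER = logging.getLogger(__name__)


def log_graph_change(pass_):
    @wraps(pass_)
    def wrapped(module: fx.GraphModule, input):
        before_io, after_io = io.StringIO(), io.StringIO()
        print(module.graph, file=before_io)
        module = pass_(module, input)
        print(module.graph, file=after_io)
        # Log whether the pass actually changed the printed graph.
        unchanged = before_io.getvalue() == after_io.getvalue()
        _LOGGER.info("pass %s left the graph unchanged: %s", pass_, unchanged)
        return module

    return wrapped
```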
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py
@@ -12,6 +12,7 @@ class TestClampConverter(AccTestCase):
param("min", min=0.5),
param("max", max=0.5),
param("minBiggerThanMax", min=1, max=0),
param("float32Boundary", min=-3.4028234663852886e38),
]
)
def test_clamp(
51 changes: 51 additions & 0 deletions py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py
@@ -0,0 +1,51 @@
# Owner(s): ["oncall: gpu_enablement"]

import logging
from copy import deepcopy

import torch
import torch.fx as fx
import torch.nn as nn

from torch.testing._internal.common_utils import run_tests, TestCase
from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim
from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer

_LOGGER = logging.getLogger(__name__)


class TestFixReshapeBatchDim(TestCase):
def test_fix_reshape_batch_dim(self):
class Repro(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return y.view(x.size(0), -1, 3)

mod = Repro()
modt = fx.symbolic_trace(mod)
inp = [
torch.rand([10, 60]),
torch.rand([10, 60]),
]
mod(*inp)
mod_acc_traced = acc_tracer.trace(modt, inp)
mod_fixed = fix_reshape_batch_dim(deepcopy(mod_acc_traced))

expected_graph = r"""
graph():
%x : [#users=0] = placeholder[target=x]
%y : [#users=2] = placeholder[target=y]
%size : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.size](args = (), kwargs = {input: %y})
%getitem_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.getitem](args = (), kwargs = {idx: 0, input: %size})
%reshape : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)})
return reshape
"""
assert (
str(mod_fixed.graph).strip() == expected_graph.strip()
), f"Unexpected fixed graph. \nActual: {str(mod_fixed.graph)} \nExpected: {expected_graph}"


if __name__ == "__main__":
run_tests()
25 changes: 25 additions & 0 deletions py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py
@@ -2566,6 +2566,31 @@ def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
# Make sure we didn't convert to the acc version
self.assertEqual(node.target, operator.getitem)

def test_detach(self):
class TestModule(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.detach(x)

m = TestModule()
sample_inputs = [torch.randn(8)]
traced = acc_tracer.trace(m, sample_inputs)

placeholder = output = None
for node in traced.graph.nodes:
if node.op == "placeholder":
assert placeholder is None
placeholder = node
elif node.op == "output":
assert output is None
output = node
else:
raise RuntimeError(f"Unexpected Node {node.format_node()}")

self.assertIsNotNone(placeholder)
self.assertIsNotNone(output)

self.assertTrue(torch.equal(m(*sample_inputs), traced(*sample_inputs)))

def test_all_acc_ops_registered(self):
self.assertEqual(
acc_normalizer._acc_ops,
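Context for the new `test_detach` above: for inference-only lowering, `detach` preserves values and only drops the autograd edge, which is why the traced graph is expected to contain nothing but the placeholder and the output. A quick eager-mode illustration:

```python
import torch

x = torch.randn(8, requires_grad=True)
y = torch.detach(x)  # same spelling the test module uses

assert torch.equal(x, y)                        # values untouched
assert x.requires_grad and not y.requires_grad  # only the grad edge is dropped
# For an inference graph, removing the node therefore changes nothing, matching
# the test's final check that m(*sample_inputs) == traced(*sample_inputs).
```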
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py
@@ -162,7 +162,7 @@ def f(x):
inputs = [torch.ones(32, 3, 224, 224)]
inputs = [i.cuda().half() for i in inputs]
torchdynamo.reset()
dynamo_aten_mod = torchdynamo.optimize(backends.fx2trt_aten_compiler_fp16)(mod)
dynamo_aten_mod = torchdynamo.optimize(backends.fx2trt_compiler_fp16)(mod)
dynamo_aten_output = dynamo_aten_mod(*inputs)

torchdynamo.reset()
6 changes: 6 additions & 0 deletions py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
@@ -582,7 +582,9 @@ def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:

@register_acc_op_properties(AccOpProperty.pointwise)
@register_acc_op_mapping(op_and_target=("call_function", torch.clamp))
@register_acc_op_mapping(op_and_target=("call_function", torch.clip))
@register_acc_op_mapping(op_and_target=("call_method", "clamp"))
@register_acc_op_mapping(op_and_target=("call_method", "clip"))
@register_acc_op
def clamp(*, input, min=None, max=None):
return torch.clamp(input=input, min=min, max=max)
@@ -818,6 +820,10 @@ def matmul(*, input, other):
@register_custom_acc_mapper_fn(
op_and_target=("call_method", "detach"), arg_replacement_tuples=[("input", "input")]
)
@register_custom_acc_mapper_fn(
op_and_target=("call_function", torch.detach),
arg_replacement_tuples=[("input", "input")],
)
def dropout_mapper(node: torch.fx.Node, mod: nn.Module):
"""
Remove dropout node and directly map its input to output.
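One last note on the acc-op additions above: `torch.clip` and `Tensor.clip` are aliases of `torch.clamp` / `Tensor.clamp`, so the new mappings simply normalize all four spellings onto the single `clamp` acc op. A quick check of the alias claim:

```python
import torch

x = torch.tensor([-2.0, 0.5, 2.0])
assert torch.equal(torch.clip(x, -1.0, 1.0), torch.clamp(x, min=-1.0, max=1.0))
assert torch.equal(x.clip(-1.0, 1.0), x.clamp(-1.0, 1.0))
```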