
Commit 6e79f1f

Authored by Wei
Changes done internally at Facebook (#1415)
afdc533da031a64e162bb08c8629ff38739e24f8 Wei Wei <[email protected]> [fx2trt] disable dispatch trace leaf node test
c22f691e6eae1b06ecd301eb6285b32d5dc9717c Mike Iovine <[email protected]> [fx2trt] Support dict inputs in acc tracer
8c05a3c57b1f5c63108b979ef8c61411525d0b1f Mike Iovine <[email protected]> [fx2trt] Support namedtuple access in acc tracer getattr
1580805d827eb40c941e769b0b99e7c6a3ed6f89 Wei Wei <[email protected]> [fx2trt] add reshape unit test
baab27b81b1275de92fdaf760a158ce951564d33 Donglin Xia <[email protected]> Register avg_pool3d for acc_op in acc_op.py
ae4c4e2c3c18d78542140fcc30e1c24f7c647ef3 Wei Wei <[email protected]> [aten2trt] init check-in
87ef03338c9a25c5a610a2eb590345e8935f8d75 Wei Wei <[email protected]> [aten2trt] add binary ops
2bb168517ace7e638cffc7a241b1cbf528790b92 Mike Iovine <[email protected]> [fx2trt] Add acc normalization blocklist
137a3977ffeb03d0387e8a95ff2f32f3d15b3de8 Wei Wei <[email protected]> [aten2trt] resnet support
fef54c237589a70c007c861e2d59c4052e3de054 Kefei Lu <[email protected]> [easy] fx2xxx: fix fuse_parallel_linear which changes getitem slices from tuple to list
4b062ef361cd7797e72c51bb4dc41766aca7b6db Kefei Lu <[email protected]> fx2trt: fix bad reshape pattern x.reshape(y.size(0), ...)
49573920892bb2fe75fe011a8cad9887bdc8bd04 Alex Beloi <[email protected]> [FX] add tracing for torch.detach
42c54d69c68dc58ac348647acada88b1e5634b40 Fei Kou <[email protected]> Fix clamping float32 boundary values
e013621dedf5960f81b915cef8d2ce19ca349a7a Kefei Lu <[email protected]> trt lower: change preset application logic to in-place instead of immutable update
adc9f8ff48c01a0ce70080c930221ac81f048563 Kefei Lu <[email protected]> [easy]: fix another instance of [slice(), ...] to (slice(), ...)
e9cc5f4b676df502a80a4b85586096e4a3e6a9d6 Charles David Hernandez <[email protected]> [docs] fix broken links
f06174dbb190df4ea488ca99a81d4884b5ed3aa2 wwei6 <[email protected]> [fx2trt] compile
1 parent a9a4bb2 commit 6e79f1f

File tree: 10 files changed (+254, -8 lines)


docs/_sources/tutorials/ptq.rst.txt

Lines changed: 3 additions & 3 deletions

@@ -136,7 +136,7 @@ Then all thats required to setup the module for INT8 calibration is to set the f
 If you have an existing Calibrator implementation for TensorRT you may directly set the ``ptq_calibrator`` field with a pointer to your calibrator and it will work as well.
 From here not much changes in terms of how to execution works. You are still able to fully use LibTorch as the sole interface for inference. Data should remain
 in FP32 precision when it's passed into `trt_mod.forward`. There exists an example application in the Torch-TensorRT demo that takes you from training a VGG16 network on
-CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/cpp/ptq
+CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/examples/int8/ptq

 .. _writing_ptq_python:

@@ -194,8 +194,8 @@ to use ``CacheCalibrator`` to use in INT8 mode.
                          calibrator=calibrator)

 If you already have an existing calibrator class (implemented directly using TensorRT API), you can directly set the calibrator field to your class which can be very convenient.
-For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_dataloader_calibrator.py
-and https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_trt_calibrator.py
+For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_dataloader_calibrator.py
+and https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_trt_calibrator.py

 Citations
 ^^^^^^^^^^^
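
Note: the linked demos wire a DataLoader-backed calibrator into compilation. A rough sketch of that flow follows; the names track the torch_tensorrt Python PTQ API these docs describe, but the model, dataset, and settings below are illustrative placeholders, not part of this commit.

    import torch
    import torch_tensorrt

    # Placeholder model and calibration data; a real workflow would use
    # e.g. a trained VGG16 and a CIFAR10 dataloader as in the linked demo.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU(), torch.nn.Flatten()
    ).eval().cuda()
    dataset = torch.utils.data.TensorDataset(torch.randn(64, 3, 32, 32))
    testing_dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

    calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(
        testing_dataloader,
        cache_file="./calibration.cache",
        use_cache=False,
        algo_type=torch_tensorrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
        device=torch.device("cuda:0"),
    )

    trt_mod = torch_tensorrt.compile(
        model,
        inputs=[torch_tensorrt.Input((1, 3, 32, 32))],
        enabled_precisions={torch.int8},
        calibrator=calibrator,
    )

As the surrounding docs note, inputs passed to trt_mod.forward stay in FP32; the calibrator handles INT8 scale selection internally.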

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 6 additions & 2 deletions

@@ -2854,8 +2854,12 @@ def add_clamp(network, input, val, op, name):
     else:
         acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
         acc_ops_clamp_tensor = (
-            val
-            * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype))
+            (
+                val
+                * torch.ones(
+                    acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)
+                )
+            )
             .cpu()
             .numpy()
         )
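
Note: the added parentheses change which object the conversion chain applies to. A minimal precedence sketch (illustrative, not the converter code itself):

    # Method calls bind tighter than `*`, so without the extra parentheses
    # `.cpu().numpy()` applies only to torch.ones(...) and the multiply runs
    # on a NumPy array; with them, the product is formed as a torch tensor
    # in the requested dtype first, then converted to NumPy.
    import torch

    val = -3.4028234663852886e38  # -FLT_MAX, the new boundary test value
    shape = (1, 1)

    a = val * torch.ones(shape, dtype=torch.float32).cpu().numpy()    # NumPy-side multiply
    b = (val * torch.ones(shape, dtype=torch.float32)).cpu().numpy()  # torch-side multiply

    print(a.dtype, b.dtype)  # result dtypes can differ across NumPy versions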

py/torch_tensorrt/fx/passes/lower_basic_pass.py

Lines changed: 149 additions & 1 deletion

@@ -1,10 +1,13 @@
 import copy
+import logging
 import operator
 import warnings
-from typing import Any
+from typing import Any, Optional

 import torch
 import torch.fx
+import torch.fx as fx
+import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch.fx.experimental.const_fold import split_const_subgraphs

 from ..observer import observable
@@ -13,6 +16,8 @@
 from ..tracer.acc_tracer.acc_utils import get_attr
 from .pass_utils import log_before_after, validate_inference

+_LOGGER = logging.getLogger(__name__)
+
 # Create an alias for module input type to avoid littering pyre-ignore for Any
 # throughout the file.
 Input = Any
@@ -460,3 +465,146 @@ def transform_setitem(gm: torch.fx.GraphModule, input: Input):
     gm.graph.lint()
     gm.recompile()
     return gm
+
+
+def fix_reshape_batch_dim(mod: fx.GraphModule) -> fx.GraphModule:
+    """\
+    TRT cannot reason about shape patterns like x.reshape(y.size(0), -1, 256),
+    since the dynamic shape of the reshape comes from the dynamic shape of
+    another node (y). The compilation will fail with various memory related
+    errors, depending on the size of the input tensor.
+
+    This pass fixes the issue by finding this reshape pattern, checking that:
+
+        x.size(0) == y.size(0)
+
+    And then replaces reshape's batch size from y.size(0) to x.size(0).
+    """
+
+    def get_reshape_batch_size_as_node(maybe_reshape: fx.Node) -> Optional[fx.Node]:
+        """\
+        Try to find the reshape op's batch size as an input node.
+
+        Match below graph structure and return `node_y`:
+            node_x.reshape({"acc_out_ty": {"shape": (node_y, ...)}})
+        """
+        if (
+            maybe_reshape.op != "call_function"
+            or maybe_reshape.target != acc_ops.reshape
+        ):
+            return None
+        shape = getattr(maybe_reshape.kwargs["acc_out_ty"], "shape", None)
+        if not shape:
+            return None
+        batch_size = shape[0]
+        if isinstance(batch_size, fx.Node):
+            return batch_size
+        return None
+
+    def get_reshape_batch_size_inferred_source(
+        batch_size_node: fx.Node,
+    ) -> Optional[fx.Node]:
+        """\
+        Given a node representing the batch size used for reshape op, we want
+        to know if it is coming from below pattern:
+
+            batch_size_node = src.size()[0]
+
+        or in IR graph:
+
+            src -> size(input=_) -> getitem(input=_, idx=0)
+                                    ^ ~~~ batch_size_node
+
+        If so, return `src`. Otherwise, return `None`.
+        """
+        if (
+            batch_size_node.op != "call_function"
+            or batch_size_node.target != acc_ops.getitem
+            or batch_size_node.kwargs["idx"] != 0
+        ):
+            return None
+        maybe_size: fx.Node = batch_size_node.all_input_nodes[0]
+        if maybe_size.op != "call_function" or maybe_size.target != acc_ops.size:
+            return None
+        return maybe_size.all_input_nodes[0]
+
+    maybe_reshape: fx.Node
+    for maybe_reshape in mod.graph.nodes:
+        reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
+            maybe_reshape
+        )
+        if not reshape_batch_size:
+            continue
+        reshape_batch_size_inferred_source: Optional[
+            fx.Node
+        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+        if not reshape_batch_size_inferred_source:
+            continue
+
+        reshape_input: fx.Node = maybe_reshape.kwargs["input"]
+        if reshape_input == reshape_batch_size_inferred_source:
+            continue
+
+        if not _is_batch_size_equal(reshape_input, reshape_batch_size_inferred_source):
+            continue
+
+        _LOGGER.info(
+            f"{fix_reshape_batch_dim}: Found bad pattern: y.reshape((x, ...)). Reshape node: {maybe_reshape}"
+        )
+
+        # Step 1: create a node to compute batch size, using the tensor which
+        # is being reshaped: reshape_input.size()[0]. This batch size is now
+        # derived from reshape_input, the same node as the reshape op's input.
+        with mod.graph.inserting_before(maybe_reshape):
+            reshape_batch_size_2: fx.Node = maybe_reshape.graph.call_function(
+                acc_ops.getitem,
+                kwargs={
+                    "idx": 0,
+                    "input": maybe_reshape.graph.call_function(
+                        acc_ops.size,
+                        kwargs={
+                            "input": reshape_input,
+                        },
+                    ),
+                },
+            )

+        # Step 2: update `maybe_reshape`'s shape argument to be
+        # (reshape_batch_size_2, *DONT_CARE_JUST_COPY_OVER)
+        maybe_reshape.kwargs = {
+            **maybe_reshape.kwargs,
+            "acc_out_ty": acc_utils.build_raw_tensor_meta(
+                shape=(
+                    reshape_batch_size_2,
+                    *(maybe_reshape.kwargs["acc_out_ty"].shape[1:]),
+                )
+            ),
+        }
+
+    mod.graph.eliminate_dead_code()
+    mod.recompile()
+    return mod
+
+
+def _is_batch_size_equal(x: fx.Node, y: fx.Node) -> bool:
+    """\
+    Check that x.size(0) == y.size(0)
+    """
+    x_size, y_size = _get_shape(x), _get_shape(y)
+    return (
+        x_size
+        and y_size
+        # now both are non-empty
+        and x_size[0] == y_size[0]
+    )
+
+
+def _get_shape(node: fx.Node) -> Optional[torch.Size]:
+    if (
+        not getattr(node, "meta", None)
+        or not node.meta.get("tensor_meta", None)
+        or not getattr(node.meta["tensor_meta"], "shape", None)
+    ):
+        # shape info not available
+        return None
+    return node.meta["tensor_meta"].shape
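
Note: in plain PyTorch terms, the rewrite this pass performs corresponds to the sketch below (the pass itself operates on acc-traced FX graphs, not on source code):

    import torch

    def before(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Bad pattern: the batch dim comes from another tensor's dynamic shape.
        return y.reshape(x.size(0), -1, 3)

    def after(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Post-pass: batch dim derived from the tensor being reshaped; valid
        # only because the pass verified x.size(0) == y.size(0) beforehand.
        return y.reshape(y.size(0), -1, 3)

    x, y = torch.rand(10, 60), torch.rand(10, 60)
    assert torch.equal(before(x, y), after(x, y))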

py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py

Lines changed: 2 additions & 0 deletions

@@ -17,6 +17,7 @@
 from .graph_opts import common_subexpression_elimination

 from .lower_basic_pass import (
+    fix_reshape_batch_dim,
     replace_mutable_op,
     replace_op_with_indices,
     run_const_fold,
@@ -112,6 +113,7 @@ def graph_optimization_pass(self) -> PassManager:
         passes.append(
             inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input))
         )
+        passes.append(fix_reshape_batch_dim)

         return PassManager.build_from_passlist(passes)

py/torch_tensorrt/fx/passes/pass_utils.py

Lines changed: 10 additions & 1 deletion

@@ -1,3 +1,4 @@
+import io
 import logging
 import tempfile
 from functools import wraps
@@ -233,15 +234,23 @@ def log_before_after(pass_: PassFunc) -> PassFunc:
     def pass_with_before_after_log(
         module: fx.GraphModule, input: Input
     ) -> fx.GraphModule:
+        before_io = io.StringIO()
+        after_io = io.StringIO()
         with tempfile.NamedTemporaryFile(
             mode="w",
             encoding="utf-8",
             delete=False,
         ) as f:
-            _LOGGER.info(f"== Log pass {pass_} before/after graph to {f.name}")
             print(f"[{pass_}] Before:\n{module.graph}", file=f)
+            print(module.graph, file=before_io)
+
             module = pass_(module, input)
             print(f"[{pass_}] After:\n{module.graph}", file=f)
+            print(module.graph, file=after_io)
+            t = before_io.getvalue() == after_io.getvalue()
+            _LOGGER.info(
+                f"== Log pass {pass_} before/after graph to {f.name}, before/after are the same = {t}"
+            )
         return module

     return pass_with_before_after_log
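
Note: the updated wrapper buffers both graph dumps in memory so the log line can also report whether the pass was a no-op. A standalone sketch of the same idiom, with generic names in place of the fx types:

    import io

    def log_before_after_sketch(pass_fn):
        # Wrap a transformation; record printable state before/after and compare.
        def wrapped(module):
            before_io, after_io = io.StringIO(), io.StringIO()
            print(module, file=before_io)
            module = pass_fn(module)
            print(module, file=after_io)
            unchanged = before_io.getvalue() == after_io.getvalue()
            print(f"pass {pass_fn.__name__}: before/after are the same = {unchanged}")
            return module
        return wrapped

    @log_before_after_sketch
    def no_op_pass(module):
        return module

    no_op_pass([1, 2, 3])  # prints: pass no_op_pass: before/after are the same = True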

py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@ class TestClampConverter(AccTestCase):
             param("min", min=0.5),
             param("max", max=0.5),
             param("minBiggerThanMax", min=1, max=0),
+            param("float32Boundary", min=-3.4028234663852886e38),
         ]
     )
     def test_clamp(
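
Note: the new test value is the most negative finite float32 (-FLT_MAX), so the converter must build the clamp constant in float32 without rounding it out of range. A quick check of that identity:

    import torch

    assert torch.finfo(torch.float32).min == -3.4028234663852886e38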
py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+# Owner(s): ["oncall: gpu_enablement"]
+
+import logging
+from copy import deepcopy
+
+import torch
+import torch.fx as fx
+import torch.nn as nn
+
+from torch.testing._internal.common_utils import run_tests, TestCase
+from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim
+from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class TestFixReshapeBatchDim(TestCase):
+    def test_fix_reshape_batch_dim(self):
+        class Repro(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return y.view(x.size(0), -1, 3)
+
+        mod = Repro()
+        modt = fx.symbolic_trace(mod)
+        inp = [
+            torch.rand([10, 60]),
+            torch.rand([10, 60]),
+        ]
+        mod(*inp)
+        mod_acc_traced = acc_tracer.trace(modt, inp)
+        mod_fixed = fix_reshape_batch_dim(deepcopy(mod_acc_traced))
+
+        expected_graph = r"""
+graph():
+    %x : [#users=0] = placeholder[target=x]
+    %y : [#users=2] = placeholder[target=y]
+    %size : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.size](args = (), kwargs = {input: %y})
+    %getitem_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.getitem](args = (), kwargs = {idx: 0, input: %size})
+    %reshape : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)})
+    return reshape
+"""
+        assert (
+            str(mod_fixed.graph).strip() == expected_graph.strip()
+        ), f"Unexpected fixed graph. \nActual: {str(mod_fixed.graph)} \nExpected: {expected_graph}"
+
+
+if __name__ == "__main__":
+    run_tests()

py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py

Lines changed: 25 additions & 0 deletions

@@ -2566,6 +2566,31 @@ def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
         # Make sure we didn't convert to the acc version
         self.assertEqual(node.target, operator.getitem)

+    def test_detach(self):
+        class TestModule(nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return torch.detach(x)
+
+        m = TestModule()
+        sample_inputs = [torch.randn(8)]
+        traced = acc_tracer.trace(m, sample_inputs)
+
+        placeholder = output = None
+        for node in traced.graph.nodes:
+            if node.op == "placeholder":
+                assert placeholder is None
+                placeholder = node
+            elif node.op == "output":
+                assert output is None
+                output = node
+            else:
+                raise RuntimeError(f"Unexpected Node {node.format_node()}")
+
+        self.assertIsNotNone(placeholder)
+        self.assertIsNotNone(output)
+
+        self.assertTrue(torch.equal(m(*sample_inputs), traced(*sample_inputs)))
+
     def test_all_acc_ops_registered(self):
         self.assertEqual(
             acc_normalizer._acc_ops,

py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py

Lines changed: 1 addition & 1 deletion

@@ -162,7 +162,7 @@ def f(x):
     inputs = [torch.ones(32, 3, 224, 224)]
     inputs = [i.cuda().half() for i in inputs]
     torchdynamo.reset()
-    dynamo_aten_mod = torchdynamo.optimize(backends.fx2trt_aten_compiler_fp16)(mod)
+    dynamo_aten_mod = torchdynamo.optimize(backends.fx2trt_compiler_fp16)(mod)
     dynamo_aten_output = dynamo_aten_mod(*inputs)

     torchdynamo.reset()

py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py

Lines changed: 6 additions & 0 deletions

@@ -582,7 +582,9 @@ def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:

 @register_acc_op_properties(AccOpProperty.pointwise)
 @register_acc_op_mapping(op_and_target=("call_function", torch.clamp))
+@register_acc_op_mapping(op_and_target=("call_function", torch.clip))
 @register_acc_op_mapping(op_and_target=("call_method", "clamp"))
+@register_acc_op_mapping(op_and_target=("call_method", "clip"))
 @register_acc_op
 def clamp(*, input, min=None, max=None):
     return torch.clamp(input=input, min=min, max=max)
@@ -818,6 +820,10 @@ def matmul(*, input, other):
 @register_custom_acc_mapper_fn(
     op_and_target=("call_method", "detach"), arg_replacement_tuples=[("input", "input")]
 )
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.detach),
+    arg_replacement_tuples=[("input", "input")],
+)
 def dropout_mapper(node: torch.fx.Node, mod: nn.Module):
     """
     Remove dropout node and directly map its input to output.
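
Note: both new mappings reuse existing handlers, which is sound because torch.clip is an alias of torch.clamp, and detach does not change values in an inference-only graph. A quick check of those identities:

    import torch

    x = torch.tensor([-2.0, 0.5, 2.0])

    # clip aliases clamp, so both can map to the same acc op.
    assert torch.equal(torch.clip(x, -1, 1), torch.clamp(x, -1, 1))

    # detach only severs autograd history; values are identical, which is
    # why the tracer can drop the node (dropout_mapper maps input -> output).
    assert torch.equal(torch.detach(x), x)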
