Skip to content

Commit fa33de2

Browse files
author
Wei
authored
Merge pull request #1104 from pytorch/fb-sync-2-wwei6
Refactor the internal codebase from fx2trt_oss to torch_tensorrt
2 parents 916c3de + 7618ac5 commit fa33de2

File tree

118 files changed

+245
-203
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

118 files changed

+245
-203
lines changed

docs/_modules/torch_tensorrt/_compile.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ <h1>Source code for torch_tensorrt._compile</h1><div class="highlight"><pre>
518518
<span class="c1"># profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile</span>
519519
<span class="p">)</span>
520520
<span class="c1"># For profile</span>
521-
<span class="c1"># from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module</span>
521+
<span class="c1"># from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module</span>
522522
<span class="c1"># profile_trt_module(&quot;&quot;, trt_mod, acc_inputs)</span>
523523
<span class="n">trt_mod</span> <span class="o">=</span> <span class="n">TRTModule</span><span class="p">(</span><span class="o">*</span><span class="n">r</span><span class="p">)</span>
524524

py/torch_tensorrt/_compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def get_input(self, inputs):
172172
# profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile
173173
)
174174
# For profile
175-
# from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module
175+
# from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module
176176
# profile_trt_module("", trt_mod, acc_inputs)
177177
trt_mod = TRTModule(*r)
178178

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,21 @@
44
import warnings
55
from typing import cast, Dict, Optional, Sequence, Tuple, Union
66

7-
from ..tracer.acc_tracer import acc_ops
87
import numpy as np
98

109
# @manual=//deeplearning/trt/python:py_tensorrt
1110
import tensorrt as trt
1211
import torch
12+
1313
from ..converter_registry import tensorrt_converter
14+
15+
from ..tracer.acc_tracer import acc_ops
1416
from ..types import * # noqa: F403
15-
from ..utils import (
16-
get_dynamic_dims,
17-
torch_dtype_from_trt,
18-
torch_dtype_to_trt,
19-
)
2017
from torch.fx.immutable_collections import immutable_list
2118
from torch.fx.node import Argument, Target
2219

20+
from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
21+
2322
from .converter_utils import * # noqa: F403
2423

2524

py/torch_tensorrt/fx/converters/activation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# @manual=//deeplearning/trt/python:py_tensorrt
44
import tensorrt as trt
55
import torch
6+
67
from ..converter_registry import tensorrt_converter
78

89
from .converter_utils import mark_as_int8_layer

py/torch_tensorrt/fx/converters/adaptive_avgpool.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# @manual=//deeplearning/trt/python:py_tensorrt
22
import tensorrt as trt
33
import torch
4+
45
from ..converter_registry import tensorrt_converter
56

67
from .converter_utils import extend_mod_attr_to_tuple, mark_as_int8_layer

py/torch_tensorrt/fx/converters/add.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# @manual=//deeplearning/trt/python:py_tensorrt
44
import tensorrt as trt
55
import torch
6+
67
from ..converter_registry import tensorrt_converter
78

89
from .converter_utils import get_dyn_range, mark_as_int8_layer

py/torch_tensorrt/fx/converters/batchnorm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# @manual=//deeplearning/trt/python:py_tensorrt
44
import tensorrt as trt
55
import torch
6+
67
from ..converter_registry import tensorrt_converter
78

89
from .converter_utils import get_dyn_range, mark_as_int8_layer, to_numpy

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
# @manual=//deeplearning/trt/python:py_tensorrt
88
import tensorrt as trt
99
import torch
10+
from torch.fx.node import Argument, Target
11+
1012
from ..types import (
1113
Shape,
1214
TRTDataType,
@@ -18,7 +20,6 @@
1820
TRTTensor,
1921
)
2022
from ..utils import torch_dtype_from_trt
21-
from torch.fx.node import Argument, Target
2223

2324

2425
def get_trt_plugin(

py/torch_tensorrt/fx/converters/convolution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
import tensorrt as trt
44
import torch
5+
56
from ..converter_registry import tensorrt_converter
67

78
from .converter_utils import (

py/torch_tensorrt/fx/converters/linear.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# @manual=//deeplearning/trt/python:py_tensorrt
22
import tensorrt as trt
33
import torch
4+
45
from ..converter_registry import tensorrt_converter
56

67
from .converter_utils import get_dyn_range, mark_as_int8_layer, to_numpy

py/torch_tensorrt/fx/converters/maxpool.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# @manual=//deeplearning/trt/python:py_tensorrt
22
import tensorrt as trt
33
import torch
4+
45
from ..converter_registry import tensorrt_converter
56

67
from .converter_utils import extend_mod_attr_to_tuple, mark_as_int8_layer

py/torch_tensorrt/fx/converters/mul.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# @manual=//deeplearning/trt/python:py_tensorrt
44
import tensorrt as trt
55
import torch
6+
67
from ..converter_registry import tensorrt_converter
78

89
from .converter_utils import get_dyn_range, mark_as_int8_layer

py/torch_tensorrt/fx/converters/quantization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# @manual=//deeplearning/trt/python:py_tensorrt
22
import tensorrt as trt
33
import torch
4+
45
from ..converter_registry import tensorrt_converter
56

67
from .converter_utils import get_dyn_range, get_inputs_from_args_and_kwargs

py/torch_tensorrt/fx/converters/transformation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# @manual=//deeplearning/trt/python:py_tensorrt
22
import tensorrt as trt
33
import torch
4+
45
from ..converter_registry import tensorrt_converter
56

67
from .converter_utils import mark_as_int8_layer

py/torch_tensorrt/fx/example/fx2trt_example.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# type: ignore[]
22

3-
import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer
43
import torch
54
import torch.fx
65
import torch.nn as nn
7-
from fx2trt_oss.fx import InputTensorSpec, TRTInterpreter, TRTModule
8-
from fx2trt_oss.fx.tools.trt_splitter import TRTSplitter
6+
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
7+
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
8+
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
99

1010

1111
# The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
@@ -83,12 +83,12 @@ def forward(self, x):
8383
%x : [#users=1] = placeholder[target=x]
8484
%linear_weight : [#users=1] = get_attr[target=linear.weight]
8585
%linear_bias : [#users=1] = get_attr[target=linear.bias]
86-
%linear_1 : [#users=1] = call_function[target=fx2trt_oss.tracer.acc_tracer.acc_ops.linear](args = (), ...
87-
%relu_1 : [#users=1] = call_function[target=fx2trt_oss.tracer.acc_tracer.acc_ops.relu](args = (), ...
86+
%linear_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linear](args = (), ...
87+
%relu_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), ...
8888
return relu_1
8989
graph():
9090
%relu_1 : [#users=1] = placeholder[target=relu_1]
91-
%linalg_norm_1 : [#users=1] = call_function[target=fx2trt_oss.tracer.acc_tracer.acc_ops.linalg_norm](args = (), ...
91+
%linalg_norm_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linalg_norm](args = (), ...
9292
return linalg_norm_1
9393
"""
9494

py/torch_tensorrt/fx/example/lower_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
import torch
66
import torchvision
7-
from fx2trt_oss.fx.lower import lower_to_trt
8-
from fx2trt_oss.fx.utils import LowerPrecision
7+
from torch_tensorrt.fx.lower import lower_to_trt
8+
from torch_tensorrt.fx.utils import LowerPrecision
99

1010

1111
"""

py/torch_tensorrt/fx/example/quantized_resnet_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import copy
22

3-
import fx2trt_oss.tracer.acc_tracer.acc_tracer as acc_tracer
4-
53
# @manual=//deeplearning/trt/python:py_tensorrt
64
import tensorrt as trt
75
import torch.fx
6+
7+
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
88
import torchvision.models as models
9-
from fx2trt_oss.fx import InputTensorSpec, TRTInterpreter, TRTModule
10-
from fx2trt_oss.fx.utils import LowerPrecision
119
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
1210
from torch.fx.experimental.normalize import NormalizeArgs
1311
from torch.fx.passes import shape_prop
12+
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
13+
from torch_tensorrt.fx.utils import LowerPrecision
1414

1515
rn18 = models.resnet18().eval()
1616

py/torch_tensorrt/fx/example/test_fx2trt.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,35 @@
1-
import torch_tensorrt
21
import torch
2+
import torch_tensorrt
3+
34

45
class MyModel(torch.nn.Module):
56
def __init__(self):
67
super().__init__()
7-
self.linear = torch.nn.Linear(5,3)
8+
self.linear = torch.nn.Linear(5, 3)
89
self.relu = torch.nn.functional.relu
9-
def forward(self,x):
10+
11+
def forward(self, x):
1012
x = self.linear(x)
1113
x = self.relu(x)
1214
return x
1315

1416

15-
model = MyModel().eval() # torch module needs to be in eval (not training) mode
17+
model = MyModel().eval() # torch module needs to be in eval (not training) mode
1618

1719
# torch tensorrt
18-
inputs = [torch_tensorrt.Input(
19-
(2,5),
20-
dtype=torch.half,
21-
)]
22-
enabled_precisions = {torch.float, torch.half} # Run with fp16
23-
24-
trt_ts_module = torch_tensorrt.compile(model, inputs=inputs, enabled_precisions=enabled_precisions)
25-
26-
inputs_ts = [torch.ones(2,5)]
20+
inputs = [
21+
torch_tensorrt.Input(
22+
(2, 5),
23+
dtype=torch.half,
24+
)
25+
]
26+
enabled_precisions = {torch.float, torch.half} # Run with fp16
27+
28+
trt_ts_module = torch_tensorrt.compile(
29+
model, inputs=inputs, enabled_precisions=enabled_precisions
30+
)
31+
32+
inputs_ts = [torch.ones(2, 5)]
2733
inputs_ts = [i.cuda().half() for i in inputs_ts]
2834
result = trt_ts_module(*inputs_ts)
2935
print(result)
@@ -33,12 +39,14 @@ def forward(self,x):
3339
print(ref)
3440

3541
# fx2trt
36-
inputs_fx = [torch.ones((2,5))]
42+
inputs_fx = [torch.ones((2, 5))]
3743

3844
model.cuda().half()
3945
inputs_fx = [i.cuda().half() for i in inputs_fx]
4046

41-
trt_fx_module = torch_tensorrt.compile(model, ir="fx", inputs=inputs_fx, enabled_precisions={torch.half})
47+
trt_fx_module = torch_tensorrt.compile(
48+
model, ir="fx", inputs=inputs_fx, enabled_precisions={torch.half}
49+
)
4250
result = trt_fx_module(*inputs_fx)
4351
print(result)
4452

py/torch_tensorrt/fx/example/torchdynamo_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import torch
66
import torchdynamo
77
import torchvision
8-
from fx2trt_oss.fx.lower import lower_to_trt
9-
from fx2trt_oss.fx.utils import LowerPrecision
8+
from torch_tensorrt.fx.lower import lower_to_trt
9+
from torch_tensorrt.fx.utils import LowerPrecision
1010
from torchdynamo.optimizations import backends
1111

1212
"""

py/torch_tensorrt/fx/fx2trt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
import tensorrt as trt
1010
import torch
1111
import torch.fx
12-
from .observer import Observer
1312
from torch.fx.node import _get_qualified_name
1413
from torch.fx.passes.shape_prop import TensorMetadata
1514

1615
from .converter_registry import CONVERTERS
1716
from .input_tensor_spec import InputTensorSpec
17+
from .observer import Observer
1818
from .utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt
1919

2020

py/torch_tensorrt/fx/lower.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,27 @@
22
import logging
33
from typing import Any, Callable, Sequence
44

5-
from .tracer.acc_tracer import acc_tracer
6-
75
# @manual=//deeplearning/trt/python:py_tensorrt
86
import tensorrt as trt
97
import torch
108
import torch.fx as fx
119
import torch.nn as nn
12-
from .lower_setting import LowerSetting
13-
from .passes.pass_utils import decorate_method, validate_inference
14-
from .passes.splitter_base import SplitResult
10+
from torch.fx.passes.splitter_base import SplitResult
1511

1612
from .fx2trt import TRTInterpreter, TRTInterpreterResult
1713
from .input_tensor_spec import InputTensorSpec
14+
from .lower_setting import LowerSetting
1815
from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
19-
from .passes.pass_utils import chain_passes, PassFunc
16+
from .passes.pass_utils import (
17+
chain_passes,
18+
decorate_method,
19+
PassFunc,
20+
validate_inference,
21+
)
2022
from .tools.timing_cache_utils import TimingCacheManager
2123
from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting
24+
25+
from .tracer.acc_tracer import acc_tracer
2226
from .trt_module import TRTModule
2327
from .utils import LowerPrecision
2428

py/torch_tensorrt/fx/lower_setting.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
import dataclasses as dc
22
from typing import List, Optional, Sequence, Set, Type
33

4-
from .input_tensor_spec import InputTensorSpec
5-
from .passes.lower_basic_pass import (
6-
fuse_permute_linear,
7-
fuse_permute_matmul,
8-
)
9-
from .utils import LowerPrecision
104
from torch import nn
115
from torch.fx.passes.pass_manager import PassManager
126

7+
from .input_tensor_spec import InputTensorSpec
8+
from .passes.lower_basic_pass import fuse_permute_linear, fuse_permute_matmul
9+
from .utils import LowerPrecision
10+
1311

1412
@dc.dataclass
1513
class LowerSetting:

py/torch_tensorrt/fx/passes/lower_basic_pass.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import warnings
44
from typing import Any
55

6-
from ..tracer.acc_tracer import acc_ops
76
import torch
87
import torch.fx
8+
from torch.fx.experimental.const_fold import split_const_subgraphs
9+
910
from ..observer import observable
10-
from .pass_utils import log_before_after, validate_inference
11+
12+
from ..tracer.acc_tracer import acc_ops
1113
from ..tracer.acc_tracer.acc_utils import get_attr
12-
from torch.fx.experimental.const_fold import split_const_subgraphs
14+
from .pass_utils import log_before_after, validate_inference
1315

1416
# Create an alias for module input type to avoid littering pyre-ignore for Any
1517
# throughout the file.
@@ -46,15 +48,15 @@ def fuse_sparse_matmul_add(gm: torch.fx.GraphModule, input: Input):
4648
def forward(self, x):
4749
a = self.a
4850
b = self.b
49-
addmm_mm = fx2trt_oss.tracer.acc_tracer.acc_ops.matmul(input = a, other = b); a = b = None
50-
addmm_add = fx2trt_oss.tracer.acc_tracer.acc_ops.add(input = addmm_mm, other = x); addmm_mm = x = None
51+
addmm_mm = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.matmul(input = a, other = b); a = b = None
52+
addmm_add = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.add(input = addmm_mm, other = x); addmm_mm = x = None
5153
return addmm_add
5254
5355
After:
5456
def forward(self, x):
5557
a = self.a
5658
b = self.b
57-
linear_1 = fx2trt_oss.tracer.acc_tracer.acc_ops.linear(input = a, weight = b, bias = x); a = b = x = None
59+
linear_1 = torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linear(input = a, weight = b, bias = x); a = b = x = None
5860
return linear_1
5961
"""
6062
counter = 0
@@ -198,8 +200,8 @@ def fuse_permute_matmul(gm: torch.fx.GraphModule, input: Input):
198200
try:
199201
# @manual=//deeplearning/trt/python:py_tensorrt
200202
import tensorrt as trt
201-
from fx2trt_oss.fx.converter_registry import tensorrt_converter
202-
from fx2trt_oss.fx.converters.converter_utils import (
203+
from torch_tensorrt.fx.converter_registry import tensorrt_converter
204+
from torch_tensorrt.fx.converters.converter_utils import (
203205
add_binary_elementwise_layer,
204206
broadcast,
205207
get_trt_tensor,

0 commit comments

Comments
 (0)