
Commit 515b9b9

Author: Wei
Changes done internally at Facebook (#1208)
6703b98dff0695d91026f057b951dba1355825fa Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.prod
c822345d6d673e1653c2208435e34ab400bada3d Jason Park <[email protected]> Add support for generic torch ops to be used in training.
e5758602a0592d6c2b71d6d66a0398c4dd9b5e20 Shreyansh Prajapati <[email protected]> Test dynamic shape support for repeat interleave
c13c633f04df162500eed477c0569eb2b81eb070 Shreyansh Prajapati <[email protected]> Test dynamic shape support for reduce ops
863476cf43b210922b88585b8f196dd84fbebb56 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.convolution
68dff39793e5c30c20010919a855bb3d984015d7 Ruichao Xiao <[email protected]> [fbcode][GPU][DHEN]fuse split squeeze cat as reshape
f8b920769507ebd2ff02419b4aece25451298a95 Ruichao Xiao <[email protected]> [fbcode][DHEN][GPU] reorder and merge cats whose input is a sublist of another cat
5b6a8d2d6be979983a52ac96225fefb510c3817c Andrew Or <[email protected]> [Quant][fx] Rename convert_to_reference to convert_to_reference_fx
996a0e080b8a8bc0b292a7c2ac92f41f6db33a2e Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.expand
084631fe74b304fbb9481ca15fd452a3714fb1b8 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.to_dtype
b3195e76329ccddbb5c4640cfa884d0e457d2d34 Shreyansh Prajapati <[email protected]> Test dynamic shape support for std
a5d964e62bdf769cf8c2e67321138b33e1f524a7 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.tile
3d33d45b2fc7f10f25c22946ba474b227e4b6529 Shreyansh Prajapati <[email protected]> Test dynamic shape support for squeeze
09085abf63d7e7732e2cd66e600e8afc6d58964f Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_op.topk
65edc7ea12899e9bd2af42c890a64de853d9b7fe Huamin Li <[email protected]> temporarily skip gelu tests
d11e521f9b90554ca86912a49920afa4406bb40d Shirong Wu <[email protected]> Suppress accuracy check for remove_reshape_with_batch_size_change
6d948298b2327d229e010a34f1c221b11d2eb504 Ankur Singla <[email protected]> [GPULowering] Suppress accuracy check for fuse_unsqueeze_cat_sum
e780b647fc9571b77d9f41c963041a6ac3d66f33 Janet Yang <[email protected]> Lower xrayvideo2022 to fx2trt
433c7207fef16b1fdff985546ea969c39fa83e7c generatedunixname89002005287564 <[email protected]> [Codemod][Remove @noautodeps and @autodeps-skip tags] deeplearning/trt 1/2
66fdb65cffa925660c77b4758388399db3cbfe48 Scott Wolchok <[email protected]> [fx2ait] Minor Python cleanup in acc_ops_getitem
188132ecb2c19bcbf83cb2dc381f6e3798629f87 generatedunixname89002005324833 <[email protected]> [AutoAccept][Codemod][FBSourceBuckFormatLinter] Daily `arc lint --take BUCKFORMAT`
4536bae4686dd01f2149541ea7fb330e178a4969 Wei Wei <[email protected]> [fx2trt] support sub
064602e666f86c110d931cd90a8536112a19b4ad Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.interpolate
9dfd0ee0cecb1975e3f53c44de237d67ca443ec5 Shreyansh Prajapati <[email protected]> Test dynamic shape support for unary_ops
39b9efad8d5d82463a2016d135c0cf277de1c3c6 Shreyansh Prajapati <[email protected]> Test dynamic shape support for unsqueeze
2bb17667d1dabc95391950426fc1f921eb3d0959 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.split
64dfb7b096686cb2fd33197340dc72f30d525456 Shirong Wu <[email protected]> Group LN trt plugin
438f670e28df59b0734baa092a514fba3d75eb4f Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.avgpool
df0fe32dae4343827bd9b37b72daae761b02f228 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops masked fill
44fe735d3493ea2d05a56b49093e4a23dd63a98e Shreyansh Prajapati <[email protected]> Test dynamic shaope support for acc_ops.pad
4f931acca706d8ce79045ceafef2ea0486609149 Wei Wei <[email protected]> [fx2trt] torch.max dynamic shape test
bf6f6cbe217d26a95ca9122574adf7de3966db9e Shreyansh Prajapati <[email protected]> Change the name of the test from full_reduce to dim_reduce
1c5680ed107d9206f3514eff4069a3f6c870ba8c Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.type_as
33e4c175a4f5fec78ac0b1c8eb262ca777c7aaba Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.min
f37be34bcef9716080b8bafbd1f4ad72e412c44c Wei Wei <[email protected]> [fx2trt] plugin for grid_sample
57b5cc6a0f4839686ae360361a3a13b424794ee7 generatedunixname89002005367269 <[email protected]> [AutoAccept][Codemod][FBSourceBlackLinter] Daily `arc lint --take BLACK`
eb741cc5e5a7babdc94e72d411670905f54da3e0 Shreyansh Prajapati <[email protected]> Updated the dynamic shape support for narrow op
521c36b96a14741ae89d7af6cbb658120bcec2ea Shreyansh Prajapati <[email protected]> Removing the comment for 4 dims dynamic shape support after analysis
e947343375967fe9efb0a16fdb9f63bff1449328 Shreyansh Prajapati <[email protected]> Updated the pad test for dynamic batch for analysis
3d64087014e91bc301a315eae43683b1aa2b66bc Oleg Khabinov <[email protected]> [trt_bc] Some improvements
dfd937a56fa01aca88a89b46176befdac4c202c4 Shreyansh Prajapati <[email protected]> Updated the test for as_strided op for analysis
11d76d0420dcaa4bb8890dcdeb86b6e534af831c Bangsheng Tang <[email protected]> [gpu][infer] replace fx2trt_layer_norm with fbgemm layer_norm
932046ff6ea6dead114c0222b23ca3854690cffa Wei Wei <[email protected]> [fx2trt] bridge the dynamic batch and fixed shape
f911463393d8a671cfee6de6d1b5ef4d4f3991a6 Shirong Wu <[email protected]> group swish LN plugin
ea65970f23dd7a468e5bc43240f2a9bfa07c9b3b Shirong Wu <[email protected]> Create backend specific lower pass
38183e4a724e5514db2be7193cf4897b59759252 Alex Beloi <[email protected]> [fx] run acc_linter.lint in acc_tracer.trace
088abb6a790a62ca9f8515298a54117cc7fa31d4 Alex Beloi <[email protected]> [fx] re-add pointwise property to acc_ops.clamp
9905c34f2bd28e9b64f10336f9ac326cc39eb60d Oleg Khabinov <[email protected]> [trt] Comment out torch.ops.fbgemm dependency in TRT converters
8252e779476d2ff22ad78185af97a526b2f70fe3 Alex Beloi <[email protected]> [fx] add operator test suite to test_acc_tracer.py
7b93a89c903bc0b6c59efb73a510c3dce8ef793a Shirong Wu <[email protected]> Add option for lower and trt_splitter
e08dabcbcd8c3e8ae92484e14cf07bb26993a8d6 Wei Wei <[email protected]> [fx2trt] convert print to logging
3d61dc169b8a7dd1aecad35891a628e44e2c5a02 Shreyansh Prajapati <[email protected]> Readme.md file for dynamic shape support
1 parent 2f896b3 commit 515b9b9

19 files changed: +276, -41 lines changed
@@ -0,0 +1,137 @@

# PyTorch Operations Dynamic Shape Support Summary

| Operation | Test Method | Supports Dynamic Shape | Shape | Num of dynamic dimensions | Reason |
| --- | --- | --- | --- | --- | --- |
| adaptive_avgpool | | partially | (-1, -1, 256, 256) | 2 | AdaptiveAvgPool2d and AdaptiveAvgPool3d currently don't support dynamic shapes for the last two dims. |
| any | | no | | | torch.zeros(tuple(\[*input_t.shape\])). Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| as_strided | | no | | | RuntimeError: setStorage: sizes \[2, 3\], strides \[1, 2\], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 16 |
| avg_pool | avg_pool2d | yes | (-1, -1, -1, -1) | 4 | |
| | avg_pool1d | partially | (-1, 3, 3) | 1 | |
| batchnorm | | partially | (-1, 3, -1, -1) | 3 | "Channel dim can't be dynamic for batch norm." |
| binary_ops | | yes | (-1, -1, -1, -1) | 4 | |
| cat | | yes | (-1, -1, -1, -1) | 4 | |
| chunk | | partially | (-1, 1, 3, -1) | any (not chunk dim) | AssertionError: Can't chunk on dynamic shape dimension! |
| clamp | | yes | (-1, -1, -1, -1) | | |
| convolution | conv2d | partially | (-1, 3, -1, -1) | 3 | AssertionError: Channel dim can't be dynamic for convolution. |
| | conv1d | partially | (-1, 3, 3) | 1 | |
| | conv3d | partially | (-1, -1, -1, -1) | 4 | AssertionError: Channel dim can't be dynamic for convolution. |
| dequantize | | yes | (-1, -1, -1, -1) | 4 | |
| einsum | | yes | (-1, -1, -1, -1) | 4 | |
| elu | | yes | (-1, -1, -1, -1) | 4 | |
| embedding | | yes | (-1, -1, -1, -1) | 4 | |
| eq | SimpleConverter | yes | (-1, -1, -1, -1) | 4 | |
| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | |
| | EqMethodConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | EqOperatorConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | EqOperatorConstant | partially | (3, -1) | 1 | |
| | EqConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| expand | | no | | | Dynamic shape is not suitable for the expand operation. |
| flatten | | yes | (-1, -1, -1, -1, -1) | 5 | |
| gelu | | yes | (-1, -1, -1, -1) | 4 | |
| getitem | | yes | (-1, -1, -1, -1) | 4 | |
| gt | EqOperatorSimpleConverter | yes | (-1, -1, -1, -1) | 4 | |
| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | |
| | GtConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | GtMethodConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | GtOperator | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | EqOperator | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| hardsigmoid | | yes | (-1, -1, -1, -1) | 4 | |
| hardtanh | | yes | (-1, -1, -1, -1) | 4 | |
| interpolate | | yes | (-1, -1, -1, -1) | 4 | |
| isinf | | yes | (-1, -1, -1, -1) | 4 | |
| leaky_relu | | yes | (-1, -1, -1, -1) | 4 | |
| linear | | partially | (-1, 3, 5) | 1 | AssertionError: Currently we only support one dynmaic dim for linear and it can't be the last dim. |
| logical_and | | yes | (-1, -1, -1, -1) | 4 | |
| logical_or | | yes | (-1, -1, -1, -1) | 4 | |
| logical_xor | | yes | (-1, -1, -1, -1) | 4 | |
| lt | | yes | (-1, -1, -1, -1) | 4 | |
| masked_fill | | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| mat_mul | | yes | batch dim | | |
| max | MaxFullReduce | yes | (-1, -1, -1, -1) | 4 | |
| | MaxDimReduce | yes | (-1, -1, -1, -1) | 4 | |
| | MaxMethod | yes | (-1, -1, -1, -1) | 4 | |
| maximum | | yes | (-1, -1, -1, -1) | 4 | |
| maxpool | max_pool1d | partially | (1, 1, -1) | 1 | Shape is not set to (-1, -1, -1) because a reshape dimension with more than one -1 wildcard is not allowed when adding the unsqueeze layer. |
| | max_pool2d | yes | (-1, -1, -1, -1) | 4 | |
| | max_pool3d | yes | (-1, -1, -1, -1, -1) | 5 | |
| min | MinFullReduce | yes | (-1, -1, -1, -1) | 4 | |
| | MinDimReduce | yes | (-1, -1, -1, -1) | 4 | |
| | MinMethod | yes | (-1, -1, -1, -1) | 4 | |
| minimum | | yes | (-1, -1, -1, -1) | 4 | |
| narrow | | partially | (-1, 3, -1, -1) | 3 | AssertionError: Can't chunk on dynamic shape dimension! |
| ne | NeFunctionConverter | yes | (-1, -1, -1, -1) | 4 | |
| | NeMethodConverter | yes | (-1, -1, -1, -1) | 4 | |
| | NeOperatorConverter | yes | (-1, -1, -1, -1) | 4 | |
| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | |
| | NeOperatorConstantConverter | partially | (3, -1) | 1 | |
| new_ones | | yes | (-1, -1, -1, -1) | 4 | |
| numel | | no | limitation in converter | | RuntimeError: numel does not support dynamic shapes. |
| pad | | no | limitation in converter | | test\_pad\_with\_dynamic\_shape\_four\_dimensions\_0\_2d (deeplearning.trt.torch\_tensorrt.py.torch\_tensorrt.fx.test.converters.acc\_op.test\_pad.TestPadConverter) ... \[07/15/2022-09:23:18\] \[TRT\] \[E\] 2: \[intInterval.cpp::max::26\] Error Code 2: Internal Error (Assertion !empty() failed. |
| permute | | yes | (-1, -1, -1, -1) | 4 | |
| prod | | yes | (-1, -1, -1, -1) | 4 | |
| quantize\_per\_tensor | | yes | (-1, -1, -1, -1) | 4 | |
| reduce op | | yes | (-1, -1, -1, -1) | 4 | |
| relu | | yes | (-1, -1, -1, -1) | 4 | |
| repeat interleave | | partially | (-1, 3, 2) | 1 | AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. |
| reshape | | yes | (-1, -1, -1, -1) | 4 | |
| selu | | yes | (-1, -1, -1, -1) | 4 | |
| sigmoid | | yes | (-1, -1, -1, -1) | 4 | |
| silu | | yes | (-1, -1, -1, -1) | 4 | |
| size | | yes | (-1, -1, -1, -1) | 4 | |
| softmax | | yes | (-1, -1, -1, -1) | 4 | |
| softsign | | yes | (-1, -1, -1, -1) | 4 | |
| split | | partially | (-1, 10, -1) | 2 | AssertionError: Can't chunk on dynamic shape dimension! |
| squeeze | | partially | (1, -1, 2) | 1 | AssertionError: Currently more than one dynamic dim for input to squeeze is not supported. |
| std | | yes | (-1, -1, -1, -1) | 4 | |
| tanh | | yes | (-1, -1, -1, -1) | 4 | |
| tile | | yes | (-1, -1, -1, -1) | 4 | |
| to_dtype | int | yes | (-1, -1, -1, -1) | 4 | |
| | float | yes | (-1, -1, -1, -1) | 4 | |
| topk | | yes | (-1, -1, -1, -1) | 4 | |
| transpose_convolution | conv_transpose2d | partially | (-1, 3, -1, -1) | 3 | |
| | conv_transpose3d | partially | (-1, 3, -1, -1, -1) | 4 | |
| type_as | | yes | (-1, -1, -1, -1) | 4 | RuntimeError: ShapeProp error for: node=%type\_1 : \[#users=1\] = call\_method\[target=type\](args = (%input_1,), kwargs = {dtype: torch.float32}) with meta={} |
| unary ops | | yes | (-1, -1, -1, -1) | 4 | |
| unsqueeze | | partially | (-1, 2, 3) | 1 | AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. |
| where | | no | limitation in converter | | torch.broadcast_shape can not handle -1 dimension in shape \[-1, 2, 2\] |

Binary ops include the following operations:

| Binary Ops |
| --- |
| add |
| sub |
| div |
| mul |
| floor_div |
| fmod |
| floor_divide |
| pow |

Unary ops include the following operations:

| Unary Ops |
| --- |
| rsqrt |
| sin |
| cos |
| tan |
| sinh |
| cosh |
| asin |
| acos |
| atan |
| abs |
| neg |
| reciprocal |
| sqrt |
| log |
| exp |
| floor |
| ceil |
| sign |

Note: For more information about the test methods, please refer to the operation test files. The test files also include details about the errors encountered during dynamic shape testing.
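
As an illustration of how the table above was exercised, here is a minimal sketch of declaring a fully dynamic 4-D input in torch_tensorrt.fx; the InputTensorSpec import path, field names, and the example ranges are assumptions based on the referenced test files rather than part of this commit.

# Hypothetical sketch (not part of this commit): a dynamic-shape test input.
# shape uses -1 for dynamic dims; shape_ranges gives (min, opt, max) shapes.
import torch
from torch_tensorrt.fx import InputTensorSpec  # assumed import path

input_specs = [
    InputTensorSpec(
        shape=(-1, -1, -1, -1),
        dtype=torch.float32,
        shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (4, 6, 8, 10))],
    )
]
# A test would then pass input_specs to its run_test_with_dynamic_shape helper.

Rows marked "partially" correspond to specs where only some dimensions may be -1, for example (-1, 3, -1, -1), which keeps the channel dimension static.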

py/torch_tensorrt/fx/converters/acc_ops_converters.py (+8 -2)

@@ -1,4 +1,5 @@
 # flake8: noqa
+import logging
 import math
 import operator
 import warnings
@@ -22,6 +23,9 @@
 from .converter_utils import * # noqa: F403


+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
 @tensorrt_converter(acc_ops.conv1d)
 def acc_ops_conv1d(
     network: TRTNetwork,
@@ -641,7 +645,7 @@ def acc_ops_layer_norm(network, target, args, kwargs, name):
     try:
         normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32)
     except TypeError:
-        print("Unable to convert normalized_shape to a field, fall back to []")
+        _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
         normalized_shape = np.array([], dtype=np.int32)

     normalized_shape_filed = trt.PluginField(
@@ -657,7 +661,9 @@ def acc_ops_layer_norm(network, target, args, kwargs, name):
         else:
             plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
     except AssertionError:
-        print("Unable to find layer norm plugin, fall back to TensorRT implementation.")
+        _LOGGER.error(
+            "Unable to find layer norm plugin, fall back to TensorRT implementation."
+        )
         return layer_norm(network, target, args, kwargs, name)
     layer = network.add_plugin_v2([input_val], plugin)
     layer.name = name
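
Since these converters now report through a module-level logger instead of print, the messages only appear once the host application configures logging. A minimal sketch using only the standard library; the logger name simply mirrors the module paths touched in this commit:

import logging

# Show INFO/ERROR output from all torch_tensorrt.fx modules, including the
# layer_norm fallback messages added above.
logging.basicConfig(level=logging.INFO)
logging.getLogger("torch_tensorrt.fx").setLevel(logging.DEBUG)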

py/torch_tensorrt/fx/lower.py (+2 -1)

@@ -197,6 +197,7 @@ def create(
         cls,
         lower_setting: LowerSetting,
         interpreter_builder: Callable = create_lower_trt_interpreter,
+        split_func: Callable = default_split_function,
     ) -> "Lowerer":
         """Instantiate a `Lowerer` instance."""

@@ -209,7 +210,7 @@ def create(
                     ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list,
                     leaf_module_list=lower_setting.leaf_module_list,
                 ),
-                split_func=default_split_function,
+                split_func=split_func,
                 lower_func=default_lower_pass(interpreter_builder),
             )
         )
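
The new split_func parameter lets a caller substitute its own graph-splitting step when constructing a Lowerer. A rough sketch of how that might be used; the import paths, the no-argument LowerSetting(), and the (model, inputs, lower_setting) signature are assumptions inferred from this diff rather than a documented API:

from torch_tensorrt.fx.lower import Lowerer, default_split_function
from torch_tensorrt.fx.lower_setting import LowerSetting

def my_split_func(model, inputs, lower_setting):  # assumed signature
    # Delegate to the default splitter; a backend-specific build could adjust
    # splitter settings here before (or instead of) delegating.
    return default_split_function(model, inputs, lower_setting)

lowerer = Lowerer.create(
    lower_setting=LowerSetting(),
    split_func=my_split_func,  # previously hard-wired to default_split_function
)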

py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py (+9 -4)

@@ -1,4 +1,5 @@
 import datetime
+import logging
 from functools import partial, wraps
 from typing import Any, Callable, Optional, Sequence

@@ -17,6 +18,10 @@

 from .lower_basic_pass import run_const_fold

+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
 Input = Sequence[Any]


@@ -143,7 +148,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:

            # Only acc submodules will be lowered.
            if not submod_name.startswith(split_result.non_acc_submodule_prefix):
-                print("Now lowering submodule", submod_name)
+                _LOGGER.info("Now lowering submodule", submod_name)
                lowering_start_time = datetime.datetime.now()

                self.lower_setting.input_specs = generate_input_specs(
@@ -160,7 +165,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:
                LOWER_SPLIT_POST_OBSERVER.observe(
                    submod_name, lowered_module, submod_inputs
                )
-                print(
+                _LOGGER.info(
                    f"Lowering submodule {submod_name} elapsed time",
                    datetime.datetime.now() - lowering_start_time,
                )
@@ -179,7 +184,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:

            # Only acc submodules will be lowered.
            if not submod_name.startswith(split_result.non_acc_submodule_prefix):
-                print("Now lowering submodule", submod_name)
+                _LOGGER.info("Now lowering submodule", submod_name)
                lowering_start_time = datetime.datetime.now()

                lowered_module = self._lower_func(
@@ -189,7 +194,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:
                LOWER_SPLIT_POST_OBSERVER.observe(
                    submod_name, lowered_module, submod_inputs
                )
-                print(
+                _LOGGER.info(
                    f"Lowering submodule {submod_name} elapsed time",
                    datetime.datetime.now() - lowering_start_time,
                )
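
One caveat with these print-to-logger conversions: logging formats its message lazily with %-style placeholders, so bare extra arguments such as submod_name or the elapsed time are treated as format arguments rather than being joined onto the message the way print joins them. A standalone sketch of the placeholder form; the submodule name below is just an example value:

import datetime
import logging

logging.basicConfig(level=logging.INFO)
_LOGGER = logging.getLogger(__name__)

submod_name = "_run_on_acc_0"  # example value, not taken from the commit
lowering_start_time = datetime.datetime.now()

# With %s placeholders the extra arguments are rendered into the message.
_LOGGER.info("Now lowering submodule %s", submod_name)
_LOGGER.info(
    "Lowering submodule %s elapsed time %s",
    submod_name,
    datetime.datetime.now() - lowering_start_time,
)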

py/torch_tensorrt/fx/passes/pass_utils.py (+1 -1)

@@ -116,7 +116,7 @@ def pass_with_before_after_log(
            encoding="utf-8",
            delete=False,
        ) as f:
-            print(f"== Log pass {pass_} before/after graph to {f.name}")
+            _LOGGER.info(f"== Log pass {pass_} before/after graph to {f.name}")
            print(f"[{pass_}] Before:\n{module.graph}", file=f)
            module = pass_(module, input)
            print(f"[{pass_}] After:\n{module.graph}", file=f)

py/torch_tensorrt/fx/test/passes/test_graph_opts.py (+8 -4)

@@ -1,3 +1,4 @@
+import logging
 import unittest
 from collections import Counter
 from typing import Callable, Dict, List
@@ -8,13 +9,16 @@
 from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination


+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
 def debug_print_graph_module(mod_graph: torch.fx.GraphModule) -> None:
     """
     Helper func to print model's graph in plain and tabular format, also print code.
     """
-    print(mod_graph.graph)
+    _LOGGER.info(mod_graph.graph)
     mod_graph.graph.print_tabular()
-    print(mod_graph.code)
+    _LOGGER.info(mod_graph.code)


 @torch.fx.wrap
@@ -46,7 +50,7 @@ def _test_opt_with_module(
         before_results = module(*inputs)
         mod_traced = acc_tracer.trace(module, inputs)
         before_node_list = list(mod_traced.graph.nodes)
-        print("Model before opt.")
+        _LOGGER.info("Model before opt.")
         debug_print_graph_module(mod_traced)

         # Apply Opt
@@ -55,7 +59,7 @@ def _test_opt_with_module(
         # After Opt
         after_results = mod_traced(*inputs)
         after_node_list = list(mod_traced.graph.nodes)
-        print("Model after opt.")
+        _LOGGER.info("Model after opt.")
         mod_traced.recompile()
         debug_print_graph_module(mod_traced)

py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py (+5 -3)

@@ -1,5 +1,5 @@
 # Owner(s): ["oncall: fx"]
-
+import logging
 import unittest
 from typing import Callable, List

@@ -16,6 +16,8 @@

 torch.manual_seed(0)

+_LOGGER: logging.Logger = logging.getLogger(__name__)
+

 class AccTracerTest(unittest.TestCase):
     def _make_model_unit_test(
@@ -258,7 +260,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
             torch.randn(1, 3, 1, 1), scale=0.01, zero_point=3, dtype=torch.quint8
         )
         traced = acc_tracer.trace(m, [input])
-        print(traced.graph)
+        _LOGGER.info(traced.graph)
         ph = weight_attr = bias_attr = conv = None
         for node in traced.graph.nodes:
             if node.op == "placeholder":
@@ -626,7 +628,7 @@ def run_embedding_bag_test(is_4bit, use_weights):
         )

         traced = acc_tracer.trace(m, inputs)
-        print(traced.graph)
+        _LOGGER.info(traced.graph)

         expected_target = (
             acc_ops.embedding_bag_4bit_rowwise_offsets

py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py (+7 -3)

@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: gpu_enablement"]
 import functools
 import glob
+import logging
 import os
 import shutil
 import tempfile
@@ -10,6 +11,9 @@
 import torch_tensorrt.fx.diagnostics as diag


+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
 def reset_diag(fn):
     @functools.wraps(fn)
     def reset(*a, **kw):
@@ -53,7 +57,7 @@ def boom() -> str:
         zip_fn = collector._last_zip_path_for_test
         assert os.path.exists(zip_fn)
         with tempfile.TemporaryDirectory() as tempdir:
-            print(f"Unpacking into {tempdir}")
+            _LOGGER.info(f"Unpacking into {tempdir}")
             shutil.unpack_archive(zip_fn, tempdir)
             _check_file(tempdir, "aaa", "hello")
             _check_file(tempdir, "bbb", "world")
@@ -78,7 +82,7 @@ def test_condition_func_name(self):
         zip_fn = collector._last_zip_path_for_test
         assert os.path.exists(zip_fn)
         with tempfile.TemporaryDirectory() as tempdir:
-            print(f"Unpacking into {tempdir}")
+            _LOGGER.info(f"Unpacking into {tempdir}")
             shutil.unpack_archive(zip_fn, tempdir)
             _check_file(tempdir, "aaa", "hello")

@@ -160,7 +164,7 @@ def _test_cond(
         if should_collect:
             assert os.path.exists(zip_fn)
             with tempfile.TemporaryDirectory() as tempdir:
-                print(f"Unpacking into {tempdir}")
+                _LOGGER.info(f"Unpacking into {tempdir}")
                 shutil.unpack_archive(zip_fn, tempdir)
                 _check_file(tempdir, "aaa", "hello")
         else: