Commit 5a45f6b

Merge pull request #1656 from gs-olive/input_signature_full_compilation
fix: Allow full model compilation with collection inputs (`input_signature`)
2 parents: 149b2b2 + 985f6a2

File tree: 9 files changed, +199 -62 lines changed

README.md (+7)
@@ -73,6 +73,7 @@ import torch_tensorrt
 ...

 trt_ts_module = torch_tensorrt.compile(torch_script_module,
+    # If the inputs to the module are plain Tensors, specify them via the `inputs` argument:
     inputs = [example_tensor, # Provide example tensor for input shape or...
         torch_tensorrt.Input( # Specify input object with shape and dtype
             min_shape=[1, 3, 224, 224],
@@ -81,6 +82,12 @@ trt_ts_module = torch_tensorrt.compile(torch_script_module,
             # For static size shape=[1, 3, 224, 224]
             dtype=torch.half) # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
     ],
+
+    # For inputs containing tuples or lists of tensors, use the `input_signature` argument:
+    # Below, we have an input consisting of a Tuple of two Tensors (Tuple[Tensor, Tensor])
+    # input_signature = ( (torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
+    #                      torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half)), ),
+
     enabled_precisions = {torch.half}, # Run with FP16
 )
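
For reference, the commented-out `input_signature` usage above corresponds to a call like the following sketch (not part of the commit). `tuple_module` is a stand-in for any scripted module whose `forward` takes a single `Tuple[Tensor, Tensor]` argument, mirroring the pattern exercised in `tests/py/api/test_collections.py`:

    import torch
    import torch_tensorrt

    # Hypothetical scripted module whose forward() takes one Tuple[Tensor, Tensor] argument
    tuple_module = torch.jit.load("tuple_module_scripted.jit.pt").eval().to("cuda")

    trt_ts_module = torch_tensorrt.compile(
        tuple_module,
        # One entry per forward() argument; this single argument is a tuple of two Tensors
        input_signature=(
            (
                torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
                torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
            ),
        ),
        enabled_precisions={torch.half},  # Run with FP16
    )

    x = torch.randn(1, 3, 224, 224, device="cuda").half()
    result = trt_ts_module((x, x))  # Collection inputs are passed as ordinary Python collections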

core/compiler.cpp (+2 -1)
@@ -352,8 +352,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Determine if the block is convertible/has collection output, and based on the result,
   // whether full compilation can be expected
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto inputIsCollection = conversion::InputIsCollection(g->block());
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
+  auto requires_collection_handling = (isBlockConvertible && (inputIsCollection || outputIsCollection));

   // Determine whether user specifications necessitate partitioning
   auto isFallbackRequested = userRequestedFallback(cfg);
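
Restated as a sketch (a hypothetical Python mirror of the C++ predicate above, not part of the commit): a block now qualifies for single-engine collection handling when every op in it is convertible and either its inputs or its outputs are collections; previously only collection outputs qualified.

    def requires_collection_handling(is_block_convertible: bool,
                                     input_is_collection: bool,
                                     output_is_collection: bool) -> bool:
        # Before this commit the predicate was: is_block_convertible and output_is_collection
        return is_block_convertible and (input_is_collection or output_is_collection)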

core/conversion/conversion.cpp (+11 -1)
@@ -556,10 +556,20 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
   return convertable_ops;
 }

+bool InputIsCollection(const torch::jit::Block* b) {
+  for (auto in : b->inputs()) {
+    if (in->type()->kind() == torch::jit::TypeKind::TupleType || in->type()->kind() == torch::jit::TypeKind::ListType) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool OutputIsCollection(const torch::jit::Block* b) {
   for (auto out : b->outputs()) {
     if (out->type()->kind() == torch::jit::TypeKind::TupleType ||
-        out->type()->kind() == torch::jit::TypeKind::ListType) {
+        out->type()->kind() == torch::jit::TypeKind::ListType ||
+        out->type()->kind() == torch::jit::TypeKind::DictType) {
       return true;
     }
   }
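
The condition `InputIsCollection` tests for is visible from Python as well: scripting a module whose `forward` takes a tuple or list produces a graph whose corresponding input value reports `TupleType` or `ListType`. A small illustrative sketch (not part of the commit; the module is made up):

    import torch
    from typing import Tuple

    class TupleIn(torch.nn.Module):
        def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
            return xs[0] + xs[1]

    scripted = torch.jit.script(TupleIn().eval())
    for inp in scripted.graph.inputs():
        # Prints e.g. "self ClassType" then "xs TupleType"; the new check returns
        # true as soon as any block input reports TupleType or ListType (the
        # module's `self` input is a ClassType, so it never matches).
        print(inp.debugName(), inp.type().kind())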

core/conversion/conversion.h (+2)
@@ -26,6 +26,8 @@ std::string ConvertBlockToEngine(

 bool OpSupported(const torch::jit::Node* n);

+bool InputIsCollection(const torch::jit::Block* b);
+
 bool OutputIsCollection(const torch::jit::Block* b);

 bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false);

cpp/src/compile_spec.cpp (-21)
@@ -74,27 +74,6 @@ torchtrt::core::CompileSpec init_compile_spec(CompileSpec& external) {
     LOG_WARNING("Input signature parsing is an experimental feature, behavior and APIs may change");
     to_internal_input_signature(external.graph_inputs.input_signature, converted_input_signature);
     torchtrt::core::CompileSpec internal(converted_input_signature);
-
-    TORCHTRT_CHECK(
-        !external.require_full_compilation,
-        "Grouped inputs currently requires partial compilation to be enabled, \
-        this restriction will be relaxed in a future release");
-
-    LOG_DEBUG("Grouped inputs currently requires additional settings to enable the feature");
-    LOG_DEBUG(
-        "Adding the following ops to torch_executed_ops:" << std::endl
-        << " - aten::__getitem__" << std::endl
-        << " - prim::ListConstruct" << std::endl
-        << " - prim::ListUnpack" << std::endl
-        << " - prim::TupleIndex" << std::endl
-        << " - prim::TupleConstruct" << std::endl
-        << " - prim::TupleUnpack");
-    external.torch_executed_ops.push_back("aten::__getitem__");
-    external.torch_executed_ops.push_back("prim::ListConstruct");
-    external.torch_executed_ops.push_back("prim::ListUnpack");
-    external.torch_executed_ops.push_back("prim::TupleIndex");
-    external.torch_executed_ops.push_back("prim::TupleConstruct");
-    external.torch_executed_ops.push_back("prim::TupleUnpack");
     return internal;
   }
 }

docsrc/getting_started/getting_started_with_python_api.rst (+27 -3)
@@ -14,8 +14,7 @@ If given a ``torch.nn.Module`` and the ``ir`` flag is set to either ``default``

 To compile your input ``torch.nn.Module`` with Torch-TensorRT, all you need to do is provide the module and inputs
 to Torch-TensorRT and you will be returned an optimized TorchScript module to run or add into another PyTorch module. Inputs
-is a list of ``torch_tensorrt.Input`` classes which define input's shape, datatype and memory format. You can also specify settings such as
-operating precision for the engine or target device. After compilation you can save the module just like any other module
+is a list of ``torch_tensorrt.Input`` classes which define input Tensors' shape, datatype and memory format. Alternatively, if your input is a more complex data type, such as a tuple or list of Tensors, you can use the ``input_signature`` argument to specify a collection-based input, such as ``(List[Tensor], Tuple[Tensor, Tensor])``. See the second sample below for an example. You can also specify settings such as operating precision for the engine or target device. After compilation you can save the module just like any other module
 to load in a deployment application. In order to load a TensorRT/TorchScript module, make sure you first import ``torch_tensorrt``.

 .. code-block:: python
@@ -44,6 +43,32 @@ to load in a deployment application. In order to load a TensorRT/TorchScript mod
     result = trt_ts_module(input_data)
     torch.jit.save(trt_ts_module, "trt_ts_module.ts")

+.. code-block:: python
+
+    # Sample using collection-based inputs via the input_signature argument
+    import torch_tensorrt
+
+    ...
+
+    model = MyModel().eval()
+
+    # input_signature expects a tuple of individual input arguments to the module
+    # The module below, for example, would have a docstring of the form:
+    # def forward(self, input0: List[torch.Tensor], input1: Tuple[torch.Tensor, torch.Tensor])
+    input_signature = (
+        [torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)],
+        (torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)),
+    )
+    enabled_precisions = {torch.float, torch.half}
+
+    trt_ts_module = torch_tensorrt.compile(
+        model, input_signature=input_signature, enabled_precisions=enabled_precisions
+    )
+
+    input_data = input_data.to("cuda").half()
+    result = trt_ts_module(input_data)
+    torch.jit.save(trt_ts_module, "trt_ts_module.ts")
+
 .. code-block:: python

     # Deployment application
@@ -55,4 +80,3 @@ to load in a deployment application. In order to load a TensorRT/TorchScript mod
     result = trt_ts_module(input_data)

 Torch-TensorRT Python API also provides ``torch_tensorrt.ts.compile`` which accepts a TorchScript module as input and ``torch_tensorrt.fx.compile`` which accepts a FX GraphModule as input.
-
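
One caveat in the new sample above: `input_data` is never defined, and the compiled module is invoked with a single argument even though its signature takes a list and a tuple. A runnable continuation would look like the following sketch (it assumes the `model`, `input_signature`, and `trt_ts_module` from the sample, plus `import torch`):

    # Construct inputs matching the (List[Tensor], Tuple[Tensor, Tensor]) signature above
    input0 = [torch.randn(64, 64).half().to("cuda"), torch.randn(64, 64).half().to("cuda")]
    input1 = (torch.randn(64, 64).half().to("cuda"), torch.randn(64, 64).half().to("cuda"))

    result = trt_ts_module(input0, input1)
    torch.jit.save(trt_ts_module, "trt_ts_module.ts")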

py/torch_tensorrt/ts/_compile_spec.py (+1 -36)
@@ -268,42 +268,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
             "Input signature parsing is an experimental feature, behavior and APIs may change",
         )
         signature = _parse_input_signature(compile_spec["input_signature"])
-        info.input_signature = _C.InputSignature(signature)  # py_object
-
-        if not compile_spec["torch_fallback"]["enabled"]:
-            raise ValueError(
-                "Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release"
-            )
-
-        log(
-            Level.Debug,
-            "Grouped inputs currently requires additional settings to enable the feature",
-        )
-        log(
-            Level.Debug,
-            """Adding the following ops to torch_executed_ops:
-    - aten::__getitem__
-    - prim::ListConstruct
-    - prim::ListUnpack
-    - prim::TupleIndex
-    - prim::TupleConstruct
-    - prim::TupleUnpack
-            """,
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "aten::__getitem__"
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "prim::ListConstruct"
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack")
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex")
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "prim::TupleConstruct"
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "prim::TupleUnpack"
-        )
+        info.input_signature = _C.InputSignature(signature)

     else:
         raise KeyError(
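
The deleted branch is what previously made `input_signature` incompatible with full compilation: with `torch_fallback` disabled it raised `ValueError` ("Grouped inputs currently requires partial compilation to be enabled..."), and otherwise it forced the collection-packing ops (`aten::__getitem__`, `prim::TupleConstruct`, `prim::ListUnpack`, and so on) to run in Torch. Together with the matching removal in `cpp/src/compile_spec.cpp` above, a spec like the following sketch (paths and shapes are illustrative; see the tests below for the real usage) is now accepted as-is:

    import torch
    import torch_tensorrt as torchtrt

    model = torch.jit.load("tuple_input_output_scripted.jit.pt").eval().to("cuda")
    compile_spec = {
        "input_signature": (
            (torchtrt.Input(shape=[1, 3, 224, 224]), torchtrt.Input(shape=[1, 3, 224, 224])),
        ),
        "enabled_precisions": {torch.float},
        "min_block_size": 1,
        "require_full_compilation": True,  # Previously rejected when combined with input_signature
    }
    trt_mod = torchtrt.ts.compile(model, **compile_spec)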

tests/cpp/test_collections.cpp (+62)
@@ -404,3 +404,65 @@ TEST(CppAPITests, TestCollectionComplexModel) {
   ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
       out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor()));
 }
+
+TEST(CppAPITests, TestCollectionFullCompilationComplexModel) {
+  std::string path = "tests/modules/list_input_tuple_output_scripted.jit.pt";
+  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
+  std::vector<at::Tensor> inputs;
+  inputs.push_back(in0);
+
+  torch::jit::Module mod;
+  try {
+    // Deserialize the ScriptModule from a file using torch::jit::load().
+    mod = torch::jit::load(path);
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+  }
+  mod.eval();
+  mod.to(torch::kCUDA);
+
+  std::vector<torch::jit::IValue> inputs_;
+
+  for (auto in : inputs) {
+    inputs_.push_back(torch::jit::IValue(in.clone()));
+  }
+
+  std::vector<torch::jit::IValue> complex_inputs;
+  auto input_list = c10::impl::GenericList(c10::TensorType::get());
+  input_list.push_back(inputs_[0]);
+  input_list.push_back(inputs_[0]);
+
+  torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
+
+  complex_inputs.push_back(input_list_ivalue);
+
+  auto out = mod.forward(complex_inputs);
+
+  auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);
+
+  auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
+
+  c10::TypePtr elementType = input_shape_ivalue.type();
+  auto list = c10::impl::GenericList(elementType);
+  list.push_back(input_shape_ivalue);
+  list.push_back(input_shape_ivalue);
+
+  torch::jit::IValue complex_input_shape(list);
+  std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
+  torch::jit::IValue complex_input_shape2(input_tuple2);
+
+  auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
+  compile_settings.min_block_size = 1;
+  compile_settings.require_full_compilation = true;
+
+  // FP16 execution
+  compile_settings.enabled_precisions = {torch::kHalf};
+  // Compile module
+  auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
+  auto trt_out = trt_mod.forward(complex_inputs);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
+      out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor()));
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
+      out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor()));
+}

tests/py/api/test_collections.py (+87)
@@ -194,6 +194,34 @@ def test_compile(self):
             msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
         )

+    def test_compile_full_compilation(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.model = (
+            torch.jit.load(MODULE_DIR + "/tuple_input_output_scripted.jit.pt")
+            .eval()
+            .to("cuda")
+        )
+
+        compile_spec = {
+            "input_signature": (
+                (torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),
+            ),
+            "device": torchtrt.Device("gpu:0"),
+            "enabled_precisions": {torch.float},
+            "min_block_size": 1,
+            "require_full_compilation": True,
+        }
+
+        trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
+        trt_out = trt_mod((self.input, self.input))
+        pyt_out = self.model((self.input, self.input))
+        for (t, p) in zip(trt_out, pyt_out):
+            cos_sim = cosine_similarity(t, p)
+            self.assertTrue(
+                cos_sim > COSINE_THRESHOLD,
+                msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            )
+

 class TestListInputOutput(unittest.TestCase):
     def test_compile(self):
@@ -225,6 +253,36 @@ def test_compile(self):
             msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
         )

+    def test_compile_full_compilation(self):
+
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.model = (
+            torch.jit.load(MODULE_DIR + "/list_input_output_scripted.jit.pt")
+            .eval()
+            .to("cuda")
+        )
+
+        compile_spec = {
+            "input_signature": (
+                [torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],
+            ),
+            "device": torchtrt.Device("gpu:0"),
+            "enabled_precisions": {torch.float},
+            "min_block_size": 1,
+            "require_full_compilation": True,
+        }
+
+        trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
+        trt_out = trt_mod((self.input, self.input))
+        pyt_out = self.model((self.input, self.input))
+
+        for (t, p) in zip(trt_out, pyt_out):
+            cos_sim = cosine_similarity(t, p)
+            self.assertTrue(
+                cos_sim > COSINE_THRESHOLD,
+                msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            )
+

 class TestListInputTupleOutput(unittest.TestCase):
     def test_compile(self):
@@ -255,6 +313,35 @@ def test_compile(self):
             msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
         )

+    def test_compile_full_compilation(self):
+
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.model = (
+            torch.jit.load(MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt")
+            .eval()
+            .to("cuda")
+        )
+
+        compile_spec = {
+            "input_signature": (
+                [torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],
+            ),
+            "device": torchtrt.Device("gpu:0"),
+            "enabled_precisions": {torch.float},
+            "min_block_size": 1,
+            "require_full_compilation": True,
+        }
+
+        trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
+        trt_out = trt_mod((self.input, self.input))
+        pyt_out = self.model((self.input, self.input))
+        for (t, p) in zip(trt_out, pyt_out):
+            cos_sim = cosine_similarity(t, p)
+            self.assertTrue(
+                cos_sim > COSINE_THRESHOLD,
+                msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            )
+

 if __name__ == "__main__":
     unittest.main()
