Commit d0af394

fix: Allow full model compilation with collection outputs (#1599)
1 parent eef06c9 commit d0af394

6 files changed, +204 −19 lines

core/compiler.cpp

Lines changed: 67 additions & 16 deletions
@@ -138,7 +138,8 @@ partitioning::GraphAndMapping BuildHybridGraph(
     torch::jit::Block* block,
     CompileSpec cfg,
     ir::StaticParams static_params,
-    ir::CollectionTypeMap first_use_types) {
+    ir::CollectionTypeMap first_use_types,
+    bool expect_full_compilation = false) {
   auto convert_info = cfg.convert_info;
   auto partitioning_info = cfg.partitioning_info;
 
@@ -149,17 +150,20 @@ partitioning::GraphAndMapping BuildHybridGraph(
   // TODO: Combine this within partition call
   partitioning::populateInputIValues(&partitioning_ctx);
 
-  partitioning::partition(&partitioning_ctx);
+  partitioning::partition(&partitioning_ctx, expect_full_compilation);
 
   for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
     partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
+    int num_torch_segments = 0;
+    int num_trt_segments = 0;
 
     for (auto& seg_block : segmented_blocks) {
       LOG_INFO("Block segment:" << seg_block);
       std::ostringstream trt_engine_id;
       trt_engine_id << reinterpret_cast<const int*>(&seg_block);
 
       if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+        num_trt_segments++;
        auto inputs = seg_block.construct_inputs_spec();
        // update the input ranges for each segments
        convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
@@ -180,8 +184,32 @@ partitioning::GraphAndMapping BuildHybridGraph(
            true);
 
        seg_block.update_graph(temp_g);
+      } else {
+        num_torch_segments++;
+
+        // If full compilation is expected, ensure that all operators in Torch blocks are
+        // for collections processing
+        if (expect_full_compilation) {
+          for (auto torch_node : seg_block.block()->nodes()) {
+            if (partitioning::CollectionNodeKinds.find(torch_node->kind()) == partitioning::CollectionNodeKinds.end()) {
+              TORCHTRT_THROW_ERROR(
+                  "Full compilation specified but node "
+                  << *torch_node
+                  << " is set to run in PyTorch due to either lack of support in TensorRT or graph partitioning rules."
+                  << " Try recompiling with require_full_compilation=False.");
+            }
+          }
+        }
      }
    }
+
+    // If full compilation is expected, cannot have more than 2 Torch segments
+    // (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
+    if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
+      TORCHTRT_THROW_ERROR(
+          "Full compilation was requested but unable to convert all operations to TensorRT."
+          << " Try recompiling with require_full_compilation=False.");
+    }
  }
 
  return partitioning::stitch(&partitioning_ctx, block);
@@ -191,7 +219,8 @@ ir::TypeMap MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
-    ir::CollectionTypeMap& first_use_type_map) {
+    ir::CollectionTypeMap& first_use_type_map,
+    bool requires_collection_handling = false) {
   cfg.convert_info.collection_input_spec_map =
       std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
   cfg.partitioning_info.collection_input_spec_map =
@@ -226,7 +255,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
            "Cannot infer input type from calcuations in graph for input "
            << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
        spec[i].dtype = at::kFloat;
-    } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
+    } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info.enabled || requires_collection_handling)) {
      if (!est_type_opt[i]) {
        LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
        std::stringstream ss;
@@ -297,6 +326,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   return engine;
 }
 
+bool userRequestedFallback(CompileSpec& cfg) {
+  return cfg.lower_info.forced_fallback_modules.size() != 0 ||
+      cfg.partitioning_info.forced_fallback_operators.size() != 0;
+}
+
 torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
   torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");
 
@@ -315,8 +349,17 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
+  // Determine if the block is convertible/has collection output, and based on the result,
+  // whether full compilation can be expected
+  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto outputIsCollection = conversion::OutputIsCollection(g->block());
+  auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
+
+  // Determine whether user specifications necessitate partitioning
+  auto isFallbackRequested = userRequestedFallback(cfg);
+
   // Extract map of IValue to DType
-  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, requires_collection_handling);
 
   // Check whether any of the input types are Long
   bool user_requested_long = false;
@@ -330,20 +373,28 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
     user_requested_long &= (casts_inserted > 0);
   }
 
-  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
-  auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  if (cfg.partitioning_info.enabled && !user_requested_long &&
-      (cfg.lower_info.forced_fallback_modules.size() == 0 &&
-       cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
-      !outputIsCollection) {
+  // Partitioning is required if:
+  // 1. User requested some modules/operators fallback
+  // 2. The block (graph) cannot be converted due to operator coverage
+  // 3. The output of the graph is a collection
+  // 4. The user requested a non-TRT data type input
+  auto isPartitioningRequired =
+      (isFallbackRequested || !isBlockConvertible || outputIsCollection || user_requested_long);
+
+  // The user did not require full compilation, but the model can be fully compiled
+  if (cfg.partitioning_info.enabled && !isPartitioningRequired) {
     LOG_INFO("Skipping partitioning since model is fully supported");
   }
 
-  if (cfg.partitioning_info.enabled &&
-      (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-         cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
-       outputIsCollection || user_requested_long)) {
-    auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
+  // The user did not require full compilation, and the model can be fully compiled
+  // or, the user required full compilation but the I/O of the graph use collections
+  if ((cfg.partitioning_info.enabled && isPartitioningRequired) || requires_collection_handling) {
+    // If the model is fully-compilable and the user has specified full compilation, run partitioning
+    // to generate collection-processing code in Torch
+    auto expect_full_compilation = (requires_collection_handling && !cfg.partitioning_info.enabled);
+
+    auto graph_and_mapping =
+        BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types, expect_full_compilation);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
     for (size_t i = 0; i < new_g->inputs().size(); ++i) {
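
Taken together, these changes allow a TorchScript module whose outputs are packed into tuples or lists to be compiled with require_full_compilation=True: partitioning still runs, but only to generate the Torch-side collection packing/unpacking code, and any other operator left in Torch now raises the new error. A minimal usage sketch (the toy module, shapes, and precision set are illustrative, modeled on the tests added in this commit):

import torch
import torch_tensorrt as torchtrt

class TupleListModel(torch.nn.Module):
    # Toy module returning a nested (Tensor, [Tensor, Tensor]) collection
    def forward(self, x, y):
        a = x + y
        b = 2.0 * a
        return (a, [b, a + b])

scripted_mod = torch.jit.script(TupleListModel().eval().to("cuda"))

# Previously, collection outputs like this could not be combined with
# require_full_compilation=True; with this commit, the only segments left
# in Torch are the collection constructors themselves.
trt_mod = torchtrt.ts.compile(
    scripted_mod,
    inputs=[
        torchtrt.Input((5, 5), dtype=torch.float),
        torchtrt.Input((5, 5), dtype=torch.float),
    ],
    require_full_compilation=True,
    enabled_precisions={torch.float},
)

out = trt_mod(torch.ones(5, 5, device="cuda"), torch.zeros(5, 5, device="cuda"))

The returned structure should mirror the Torch module's output format, which is exactly what the new same_output_format test helper below verifies.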

core/lowering/lowering.cpp

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ int AutocastLongInputs(
     std::string target_device_name) {
   int num_autocasts = 0;
   // For each graph input, determine if it can be autocasted
-  for (int i = 0; i < g->inputs().size(); i++) {
+  for (size_t i = 0; i < g->inputs().size(); i++) {
     auto input = g->inputs()[i];
 
     // Autocasted inputs must be Tensor-type

core/partitioning/partitioning.cpp

Lines changed: 15 additions & 1 deletion
@@ -564,7 +564,21 @@ void populateInputIValues(PartitioningCtx* ctx) {
   }
 }
 
-void partition(PartitioningCtx* ctx) {
+void partition(PartitioningCtx* ctx, bool expect_full_compilation) {
+  // If full compilation is expected, overwrite minimum block size
+  // Any nonzero block size is valid if full compilation to TRT is desired
+  // Override the default min_block_size to ensure all TRT-supported operations are
+  // executed in TRT, regardless of the size of the graph
+  if (expect_full_compilation) {
+    // If minimum block size is different from the default, the user must have specified it
+    if (ctx->settings.min_block_size != 3) {
+      LOG_WARNING(
+          "Detected user-specified min_block_size with require_full_compilation=True "
+          << "disregarding min_block_size.");
+    }
+    ctx->settings.min_block_size = 1;
+  }
+
   LOG_DEBUG(ctx->settings);
 
   // Go through all the blocks to do the partitioning
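
In practice this means any user-supplied min_block_size is superseded when full compilation is expected, with only a warning emitted. A sketch of how that interaction might surface at the Python API level, assuming a module with collection outputs so that this code path is actually reached (module and values are illustrative):

import torch
import torch_tensorrt as torchtrt

class PairModel(torch.nn.Module):
    # Returns a tuple, so the collection-handling path (and the override above) applies
    def forward(self, x, y):
        return (x + y, x - y)

scripted_mod = torch.jit.script(PairModel().eval().to("cuda"))

trt_mod = torchtrt.ts.compile(
    scripted_mod,
    inputs=[
        torchtrt.Input((5, 5), dtype=torch.float),
        torchtrt.Input((5, 5), dtype=torch.float),
    ],
    require_full_compilation=True,
    # Non-default value: expect the "Detected user-specified min_block_size" warning,
    # after which the effective block size is reset to 1.
    min_block_size=5,
)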

core/partitioning/partitioning.h

Lines changed: 14 additions & 1 deletion
@@ -18,6 +18,19 @@ typedef std::unordered_map<const torch::jit::Value*, torch::jit::IValue> Example
 typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
     GraphAndMapping;
 
+// Set of schemas allowed to be executed in Torch, even with require_full_compilation=true,
+// as necessary for returning collections of Tensors or other complex constructs, and for
+// processing inputs to TRT engines
+const std::unordered_set<c10::Symbol> CollectionNodeKinds = {
+    c10::Symbol::fromQualString("prim::Constant"),
+    c10::Symbol::fromQualString("aten::__getitem__"),
+    c10::Symbol::fromQualString("prim::ListConstruct"),
+    c10::Symbol::fromQualString("prim::ListUnpack"),
+    c10::Symbol::fromQualString("prim::TupleIndex"),
+    c10::Symbol::fromQualString("prim::TupleConstruct"),
+    c10::Symbol::fromQualString("prim::TupleUnpack"),
+};
+
 ExampleIValues generateRandomInputs(
     ir::CollectionInputSpecMap& input_ranges,
     ir::CollectionTypeMap& input_types,
@@ -35,7 +48,7 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block);
 
 GraphAndMapping stitch(PartitioningCtx* ctx, torch::jit::Block* block);
 
-void partition(PartitioningCtx* ctx);
+void partition(PartitioningCtx* ctx, bool expect_full_compilation = false);
 
 } // namespace partitioning
 } // namespace core
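
The node kinds in this allowlist are what TorchScript typically emits when packing or unpacking collections. One way to see this is to inspect the graph of a collection-returning forward (a quick sketch; the exact node set can vary with the PyTorch version):

import torch

class ReturnsCollections(torch.nn.Module):
    def forward(self, x, y):
        return (x + y, [x - y])

graph = torch.jit.script(ReturnsCollections().eval()).graph

# Packing the outputs appears as prim::ListConstruct / prim::TupleConstruct nodes
# (plus prim::Constant), i.e. the kinds CollectionNodeKinds permits to remain in Torch.
print({node.kind() for node in graph.nodes()})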

tests/py/api/test_e2e_behavior.py

Lines changed: 68 additions & 0 deletions
@@ -4,6 +4,7 @@
 import torchvision.models as models
 import copy
 from typing import Dict
+from utils import same_output_format
 
 
 class TestInputTypeDefaultsFP32Model(unittest.TestCase):
@@ -109,6 +110,73 @@ def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor(self):
         )
         trt_mod(self.input)
 
+    def test_nested_combination_tuple_list_output_with_full_compilation(self):
+        class Sample(torch.nn.Module):
+            def __init__(self):
+                super(Sample, self).__init__()
+
+            def forward(self, x, y, z):
+                c = 1.0
+                b = x + 2.0 * z
+                b = y + b
+                a = b + c
+                return (a, [b, c])
+
+        self.model = Sample().eval().to("cuda")
+        self.input_1 = torch.zeros((5, 5), dtype=torch.float, device="cuda:0")
+        self.input_2 = torch.ones((5, 5), dtype=torch.float, device="cuda:0")
+        self.input_3 = torch.ones((5, 5), dtype=torch.float, device="cuda:0")
+        scripted_mod = torch.jit.script(self.model)
+
+        inputs = [
+            torchtrt.Input((5, 5), dtype=torch.float),
+            torchtrt.Input((5, 5), dtype=torch.float),
+            torchtrt.Input((5, 5), dtype=torch.float),
+        ]
+
+        trt_mod = torchtrt.ts.compile(
+            scripted_mod,
+            inputs=inputs,
+            require_full_compilation=True,
+            enabled_precisions={torch.float, torch.half},
+        )
+        trt_output = trt_mod(self.input_1, self.input_2, self.input_3)
+        torch_output = self.model(self.input_1, self.input_2, self.input_3)
+        assert same_output_format(
+            trt_output, torch_output
+        ), "Found differing output formatting between Torch-TRT and Torch"
+
+    def test_tuple_output_with_full_compilation(self):
+        class Sample(torch.nn.Module):
+            def __init__(self):
+                super(Sample, self).__init__()
+
+            def forward(self, x, y):
+                a = x + y
+                return (a,)
+
+        self.model = Sample().eval().to("cuda")
+        self.input_1 = torch.zeros((5, 5), dtype=torch.float, device="cuda:0")
+        self.input_2 = torch.ones((5, 5), dtype=torch.float, device="cuda:0")
+        scripted_mod = torch.jit.script(self.model)
+
+        inputs = [
+            torchtrt.Input((5, 5), dtype=torch.float),
+            torchtrt.Input((5, 5), dtype=torch.float),
+        ]
+
+        trt_mod = torchtrt.ts.compile(
+            scripted_mod,
+            inputs=inputs,
+            require_full_compilation=True,
+            enabled_precisions={torch.float, torch.half},
+        )
+        trt_output = trt_mod(self.input_1, self.input_2)
+        torch_output = self.model(self.input_1, self.input_2)
+        assert same_output_format(
+            trt_output, torch_output
+        ), "Found differing output formatting between Torch-TRT and Torch"
+
 
 if __name__ == "__main__":
     unittest.main()

tests/py/api/utils.py

Lines changed: 39 additions & 0 deletions
@@ -13,3 +13,42 @@ def cosine_similarity(gt_tensor, pred_tensor):
     res = res.cpu().detach().item()
 
     return res
+
+
+def same_output_format(trt_output, torch_output):
+    # For each encountered collection type, ensure the torch and trt outputs agree
+    # on type and size, checking recursively through all member elements.
+    if isinstance(trt_output, tuple):
+        return (
+            isinstance(torch_output, tuple)
+            and (len(trt_output) == len(torch_output))
+            and all(
+                same_output_format(trt_entry, torch_entry)
+                for trt_entry, torch_entry in zip(trt_output, torch_output)
+            )
+        )
+    elif isinstance(trt_output, list):
+        return (
+            isinstance(torch_output, list)
+            and (len(trt_output) == len(torch_output))
+            and all(
+                same_output_format(trt_entry, torch_entry)
+                for trt_entry, torch_entry in zip(trt_output, torch_output)
+            )
+        )
+    elif isinstance(trt_output, dict):
+        return (
+            isinstance(torch_output, dict)
+            and (len(trt_output) == len(torch_output))
+            and (trt_output.keys() == torch_output.keys())
+            and all(
+                same_output_format(trt_output[key], torch_output[key])
+                for key in trt_output.keys()
+            )
+        )
+    elif isinstance(trt_output, set) or isinstance(trt_output, frozenset):
+        raise AssertionError(
+            "Unsupported output type 'set' encountered in output format check."
+        )
+    else:
+        return type(trt_output) is type(torch_output)
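
For quick reference, the helper compares collection structure, lengths, and element types, not tensor values or shapes. A standalone illustration (assuming it is run from tests/py/api so utils is importable):

import torch
from utils import same_output_format

a, b = torch.ones(2), torch.zeros(3)

# Same nesting (tuple containing a list) and same element types -> True
assert same_output_format((a, [b, 1.0]), (b, [a, 2.0]))

# Tuple vs. list at the top level -> False
assert not same_output_format((a, b), [a, b])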
