Skip to content

Commit 8554221

Browse files
authored
Merge pull request #1032 from NVIDIA/resolve_inputs_fix
Fix: Resolve the conflicts for ResolveNonTensorInputs
2 parents c872a49 + 7059eba commit 8554221

File tree

3 files changed

+184
-32
lines changed

3 files changed

+184
-32
lines changed

core/lowering/passes/exception_elimination.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ struct ExceptionOrPassPatternElimination {
9090
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
9191
auto n = *it;
9292
if (n->kind() == prim::If && isExceptionOrPassNode(n)) {
93-
LOG_ERROR("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl);
93+
LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl);
9494
it.destroyCurrent();
9595
}
9696
}
@@ -104,7 +104,7 @@ void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {
104104
ExceptionOrPassPatternElimination eppe(std::move(graph));
105105
eppe.run();
106106
if (graph) {
107-
LOG_ERROR("Post Eliminate Exception or Pass Patterns: " << *graph);
107+
LOG_GRAPH("Post Eliminate Exception or Pass Patterns: " << *graph);
108108
}
109109
}
110110

core/partitioning/partitioning.cpp

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ bool containNonTensorOutputs(torch::jit::Node* n) {
5757
return false;
5858
}
5959

60-
std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*>& vals) {
60+
std::vector<torch::jit::Node*> getDependencyNodes(const std::vector<torch::jit::Value*>& vals) {
6161
// use bfs to get the DAG dependency nodes for input value
6262
std::queue<torch::jit::Value*, std::deque<torch::jit::Value*>> q(
6363
std::deque<torch::jit::Value*>(vals.begin(), vals.end()));
@@ -169,17 +169,10 @@ std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock
169169
return std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock>(append_blocks, trt_block);
170170
}
171171

172-
PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
173-
// reconstruct segmented_block if this block requires nonTensor input
174-
std::vector<torch::jit::Value*> nontensor_inputs;
175-
// Gather all non-tensor inputs for this seg_block
176-
for (auto input : seg_block.raw_inputs()) {
177-
if (!isTensorOrTensorList(input)) {
178-
nontensor_inputs.push_back(input);
179-
}
180-
}
181-
182-
std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(nontensor_inputs);
172+
PartitionedGraph segmentBlocksWithSpecifiedInputs(
173+
SegmentedBlock& seg_block,
174+
std::vector<torch::jit::Value*>& inputs_to_resolve) {
175+
std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(inputs_to_resolve);
183176
PartitionedGraph new_seg_blocks;
184177
// if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the
185178
// dependency nodes at the beginning of the current segmented_block and return this merged segmented_block
@@ -194,7 +187,7 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
194187
}
195188
} else {
196189
// if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
197-
std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
190+
std::unordered_set<torch::jit::Value*> inputs_to_resolve_set(inputs_to_resolve.begin(), inputs_to_resolve.end());
198191
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
199192

200193
// take all nodes with non_tensor_inputs as initial dirty nodes (nodes that should be in PyTorch block), then we use
@@ -205,7 +198,7 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
205198
seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
206199

207200
for (auto n : seg_block.raw_nodes()) {
208-
if (containTargetInputs(n, nontensor_inputs_set)) {
201+
if (containTargetInputs(n, inputs_to_resolve_set)) {
209202
dirty_nodes.insert(n);
210203
}
211204
}
@@ -237,6 +230,18 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
237230
return new_seg_blocks;
238231
}
239232

233+
PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
234+
// reconstruct segmented_block if this block requires nonTensor input
235+
std::vector<torch::jit::Value*> inputs_to_resolve;
236+
// Gather all non-tensor inputs for this block
237+
for (auto input : seg_block.raw_inputs()) {
238+
if (!isTensorOrTensorList(input)) {
239+
inputs_to_resolve.push_back(input);
240+
}
241+
}
242+
return segmentBlocksWithSpecifiedInputs(seg_block, inputs_to_resolve);
243+
}
244+
240245
std::unordered_map<torch::jit::Value*, usage_info> getInputUsageCounts(
241246
const PartitionedGraph& segmented_blocks,
242247
const std::function<bool(torch::jit::Value*)>& condition) {
@@ -284,6 +289,10 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
284289
segmented_blocks, [](torch::jit::Value* input) -> bool { return !isTensorOrTensorList(input); });
285290
auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list);
286291

292+
std::map<int, std::vector<torch::jit::Value*>>
293+
torch_values_to_fix; // Only need to resolve values generated by tensorrt
294+
std::set<int> tensorrt_blocks_to_fix; // Need to resolve ALL non-tensor inputs
295+
287296
// update blocks_list
288297
std::unordered_set<int> updated_segments;
289298
for (auto& use : usage_counts) {
@@ -292,27 +301,26 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
292301
// kTorch segment.
293302
if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
294303
auto first_torch_id = use_info.torch_use_id.back();
295-
if (!updated_segments.count(first_torch_id)) {
296-
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
297-
// Torch-TensorRT doesn't support non-tensor inputs for a module.
298-
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]);
299-
auto next_iter = segmented_blocks_list.erase(idx_to_iter[first_torch_id]);
300-
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
301-
updated_segments.insert(first_torch_id);
302-
}
304+
torch_values_to_fix[first_torch_id].push_back(use.first);
303305
}
304306
// kTensorRT segments always need to inject nodes for the nonTensor inputs
305307
for (auto i : use_info.tensorrt_use_id) {
306-
if (!updated_segments.count(i)) {
307-
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
308-
// Torch-TensorRT doesn't support non-tensor inputs for a module.
309-
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]);
310-
auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
311-
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
312-
updated_segments.insert(i);
313-
}
308+
tensorrt_blocks_to_fix.insert(i);
314309
}
315310
}
311+
for (auto torch_block_pair : torch_values_to_fix) {
312+
auto to_inject_blocks =
313+
segmentBlocksWithSpecifiedInputs(segmented_blocks[torch_block_pair.first], torch_block_pair.second);
314+
auto next_iter = segmented_blocks_list.erase(idx_to_iter[torch_block_pair.first]);
315+
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
316+
}
317+
318+
for (auto i : tensorrt_blocks_to_fix) {
319+
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]);
320+
auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
321+
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
322+
}
323+
316324
segmented_blocks.clear();
317325
segmented_blocks.insert(segmented_blocks.begin(), segmented_blocks_list.begin(), segmented_blocks_list.end());
318326
return;

tests/core/partitioning/test_resolve_nontensor_inputs.cpp

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,147 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
257257
int count = count_trt_engines(fallback_g);
258258
ASSERT_TRUE(count == 2);
259259
}
260+
261+
TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
262+
/* parseIR does not support "= aten::_set_item" so we will build this graph manually
263+
const auto graph = R"IR(
264+
graph(%x : Tensor,
265+
%y : Tensor):
266+
%2 : str = prim::Constant[value="INS"]()
267+
%3 : str = prim::Constant[value="OUTS"]()
268+
%4 : bool = prim::Constant[value=0]()
269+
%5 : int = prim::Constant[value=-1]()
270+
%6 : Dict(str, Tensor) = prim::DictConstruct()
271+
= aten::_set_item(%6, %2, %x)
272+
%7 : Tensor = aten::__getitem__(%6, %2)
273+
%8 : Tensor = aten::lt(%7, %y)
274+
%9 : Tensor?[] = prim::ListConstruct(%8)
275+
%10 : int = prim::dtype(%7)
276+
%11 : Device = prim::device(%7)
277+
%12 : Tensor = aten::tensor(%5, %10, %11, %4)
278+
%13 : Tensor = aten::index_put_(%7, %9, %12, %4)
279+
= aten::_set_item(%6, %3, %7)
280+
%14 : Tensor = aten::__getitem__(%6, %2)
281+
%15 : Tensor = aten::__getitem__(%6, %3)
282+
return (%14, %15))IR";
283+
*/
284+
auto g = std::make_shared<torch::jit::Graph>();
285+
auto x = g->insertInput(0, "x");
286+
auto y = g->insertInput(1, "y");
287+
torch::jit::IValue ins_key("INS");
288+
auto ins_key_val = g->insertConstant(ins_key);
289+
torch::jit::IValue outs_key("OUTS");
290+
auto outs_key_val = g->insertConstant(outs_key);
291+
torch::jit::IValue zero(0);
292+
auto false_const_val = g->insertConstant(zero);
293+
false_const_val->setType(c10::BoolType::get());
294+
torch::jit::IValue neg_one(-1);
295+
auto neg_one_const_val = g->insertConstant(neg_one);
296+
auto dict_node = g->createDict(
297+
ins_key_val->type(),
298+
x->type(),
299+
torch::jit::ArrayRef<torch::jit::Value*>(),
300+
torch::jit::ArrayRef<torch::jit::Value*>());
301+
g->insertNode(dict_node);
302+
auto set_node = g->create(
303+
torch::jit::Symbol::fromQualString("aten::_set_item"),
304+
torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val, x},
305+
0);
306+
g->insertNode(set_node);
307+
auto get_node = g->create(
308+
torch::jit::Symbol::fromQualString("aten::__getitem__"),
309+
torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val},
310+
1);
311+
g->insertNode(get_node);
312+
auto lt_node = g->create(
313+
torch::jit::Symbol::fromQualString("aten::lt"),
314+
torch::jit::ArrayRef<torch::jit::Value*>{get_node->output(), y},
315+
1);
316+
g->insertNode(lt_node);
317+
auto list_node = g->createList(
318+
at::OptionalType::create(lt_node->output()->type()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output()});
319+
g->insertNode(list_node);
320+
auto dtype_node = g->create(
321+
torch::jit::Symbol::fromQualString("prim::dtype"),
322+
torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()},
323+
1);
324+
dtype_node->output()->setType(neg_one_const_val->type());
325+
g->insertNode(dtype_node);
326+
auto device_node = g->create(
327+
torch::jit::Symbol::fromQualString("prim::device"),
328+
torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()},
329+
1);
330+
device_node->output()->setType(c10::DeviceObjType::get());
331+
g->insertNode(device_node);
332+
auto tensor_node = g->create(
333+
torch::jit::Symbol::fromQualString("aten::tensor"),
334+
torch::jit::ArrayRef<torch::jit::Value*>{
335+
neg_one_const_val, dtype_node->output(), device_node->output(), false_const_val},
336+
1);
337+
g->insertNode(tensor_node);
338+
auto index_put_node = g->create(
339+
torch::jit::Symbol::fromQualString("aten::index_put_"),
340+
torch::jit::ArrayRef<torch::jit::Value*>{
341+
get_node->output(), list_node->output(), tensor_node->output(), false_const_val},
342+
1);
343+
g->insertNode(index_put_node);
344+
auto out_set_node = g->create(
345+
torch::jit::Symbol::fromQualString("aten::_set_item"),
346+
torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val, get_node->output()},
347+
0);
348+
g->insertNode(out_set_node);
349+
auto get_ins_node = g->create(
350+
torch::jit::Symbol::fromQualString("aten::__getitem__"),
351+
torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val},
352+
1);
353+
g->insertNode(get_ins_node);
354+
auto get_outs_node = g->create(
355+
torch::jit::Symbol::fromQualString("aten::__getitem__"),
356+
torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val},
357+
1);
358+
g->insertNode(get_outs_node);
359+
g->registerOutput(get_ins_node->output());
360+
g->registerOutput(get_outs_node->output());
361+
362+
torch_tensorrt::core::partitioning::PartitionInfo partition_info;
363+
partition_info.enabled = true;
364+
std::vector<torch_tensorrt::core::ir::Input> inputs;
365+
inputs.push_back(torch_tensorrt::core::ir::Input({4, 4}));
366+
inputs.push_back(torch_tensorrt::core::ir::Input({4, 4}));
367+
368+
std::unordered_map<const torch::jit::Value*, torch_tensorrt::core::ir::Input> inputs_map;
369+
std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> input_types;
370+
for (size_t i = 0; i < g->inputs().size(); ++i) {
371+
inputs_map.insert({g->inputs()[i], inputs[i]});
372+
input_types.insert({g->inputs()[i], {at::kFloat}});
373+
}
374+
auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
375+
auto segmented_blocks = torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);
376+
377+
int torch_block_cnt = 0, trt_block_cnt = 0;
378+
for (const auto& segmented_block : segmented_blocks) {
379+
if (segmented_block.target() == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT) {
380+
++trt_block_cnt;
381+
ASSERT_TRUE(checkSegmentedBlockInputType(segmented_block, [](torch::jit::TypePtr type_ptr) {
382+
return type_ptr->isSubtypeOf(torch::jit::TensorType::get());
383+
}));
384+
} else {
385+
++torch_block_cnt;
386+
bool output_dict = false;
387+
bool input_dict = false;
388+
auto dict_type = dict_node->output()->type();
389+
for (auto in : segmented_block.raw_inputs()) {
390+
if (in->type()->isSubtypeOf(dict_type)) {
391+
input_dict = true;
392+
}
393+
}
394+
for (auto out : segmented_block.raw_outputs()) {
395+
if (out->type()->isSubtypeOf(dict_type)) {
396+
output_dict = true;
397+
}
398+
}
399+
EXPECT_TRUE(output_dict ^ input_dict);
400+
}
401+
}
402+
ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 2);
403+
}

0 commit comments

Comments (0)