Skip to content

Commit 485fb03

Browse files
mfeliz-cruise
authored and
Allison Autonomous
committed
fix: Avoid resolving non-tensor inputs to torch segment_blocks unnecessarily
Signed-off-by: Michael Feliz <[email protected]>
1 parent 10b55d4 commit 485fb03

File tree

2 files changed

+143
-29
lines changed

2 files changed

+143
-29
lines changed

core/partitioning/partitioning.cpp

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -137,17 +137,8 @@ std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock
137137
return std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock>(append_blocks, trt_block);
138138
}
139139

140-
PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
141-
// reconstruct segmented_block if this block requires nonTensor input
142-
std::vector<torch::jit::Value*> nontensor_inputs;
143-
// Gather all non-tensor inputs for this seg_block
144-
for (auto input : seg_block.raw_inputs()) {
145-
if (!isTensorOrTensorList(input)) {
146-
nontensor_inputs.push_back(input);
147-
}
148-
}
149-
150-
std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(nontensor_inputs);
140+
PartitionedGraph segmentBlocksWithSpecifiedInputs(SegmentedBlock& seg_block, std::vector<torch::jit::Value*> inputs_to_resolve){
141+
std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(inputs_to_resolve);
151142
PartitionedGraph new_seg_blocks;
152143
// if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the
153144
// dependency nodes at the beginning of the current segmented_block and return this merged segmented_block
@@ -162,15 +153,15 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
162153
}
163154
} else {
164155
// if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
165-
std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
156+
std::unordered_set<torch::jit::Value*> inputs_to_resolve_set(inputs_to_resolve.begin(), inputs_to_resolve.end());
166157
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes(dependency_nodes.begin(), dependency_nodes.end());
167158

168159
bool prev_non_tensor_outputs = false;
169160
for (auto n : seg_block.raw_nodes()) {
170161
// Check if the node has non-tensor inputs or if it consumes non-tensor outputs of previous node.
171162
// In these cases, these nodes are placed into a new Pytorch SegmentedBlock. Else, they form a new TensorRT
172163
// SegmentedBlock.
173-
if (containTargetInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
164+
if (containTargetInputs(n, inputs_to_resolve_set) || prev_non_tensor_outputs) {
174165
// If tensorrt_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a
175166
// TensorRT segmented_block and clear the tensorrt_nodes list to be later used for new TRT segments.
176167
if (!tensorrt_nodes.empty()) {
@@ -201,6 +192,18 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
201192
return new_seg_blocks;
202193
}
203194

195+
PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
  // Re-segment this block when it consumes non-tensor inputs: collect every
  // input that is not a Tensor/TensorList and resolve exactly that set by
  // delegating to segmentBlocksWithSpecifiedInputs.
  std::vector<torch::jit::Value*> inputs_to_resolve;
  for (auto blk_input : seg_block.raw_inputs()) {
    if (isTensorOrTensorList(blk_input)) {
      continue; // tensor inputs are supported directly; nothing to resolve
    }
    inputs_to_resolve.push_back(blk_input);
  }
  return segmentBlocksWithSpecifiedInputs(seg_block, inputs_to_resolve);
}
206+
204207
std::unordered_map<torch::jit::Value*, usage_info> getInputUsageCounts(
205208
const PartitionedGraph& segmented_blocks,
206209
const std::function<bool(torch::jit::Value*)>& condition) {
@@ -248,6 +251,9 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
248251
segmented_blocks, [](torch::jit::Value* input) -> bool { return !isTensorOrTensorList(input); });
249252
auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list);
250253

254+
std::map<int, std::vector<torch::jit::Value*>> torch_values_to_fix; //Only need to resolve values generated by tensorrt
255+
std::set<int> tensorrt_blocks_to_fix; //Need to resolve ALL non-tensor inputs
256+
251257
// update blocks_list
252258
std::unordered_set<int> updated_segments;
253259
for (auto& use : usage_counts) {
@@ -256,27 +262,25 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
256262
// kTorch segment.
257263
if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
258264
auto first_torch_id = use_info.torch_use_id.back();
259-
if (!updated_segments.count(first_torch_id)) {
260-
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
261-
// Torch-TensorRT doesn't support non-tensor inputs for a module.
262-
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]);
263-
auto next_iter = segmented_blocks_list.erase(idx_to_iter[first_torch_id]);
264-
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
265-
updated_segments.insert(first_torch_id);
266-
}
265+
torch_values_to_fix[first_torch_id].push_back(use.first);
267266
}
268267
// kTensorRT segments always need to inject nodes for the nonTensor inputs
269268
for (auto i : use_info.tensorrt_use_id) {
270-
if (!updated_segments.count(i)) {
271-
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
272-
// Torch-TensorRT doesn't support non-tensor inputs for a module.
273-
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]);
274-
auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
275-
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
276-
updated_segments.insert(i);
277-
}
269+
tensorrt_blocks_to_fix.insert(i);
278270
}
279271
}
272+
for(auto torch_block_pair : torch_values_to_fix){
273+
auto to_inject_blocks = segmentBlocksWithSpecifiedInputs(segmented_blocks[torch_block_pair.first], torch_block_pair.second);
274+
auto next_iter = segmented_blocks_list.erase(idx_to_iter[torch_block_pair.first]);
275+
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
276+
}
277+
278+
for(auto i : tensorrt_blocks_to_fix){
279+
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]);
280+
auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
281+
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
282+
}
283+
280284
segmented_blocks.clear();
281285
segmented_blocks.insert(segmented_blocks.begin(), segmented_blocks_list.begin(), segmented_blocks_list.end());
282286
return;

tests/core/partitioning/test_resolve_nontensor_inputs.cpp

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,113 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
257257
int count = count_trt_engines(fallback_g);
258258
ASSERT_TRUE(count == 2);
259259
}
260+
261+
// Regression test for partitioning: only the non-tensor values a Torch
// segment actually consumes should be resolved into it. The graph routes
// tensors through a Dict(str, Tensor), which TensorRT cannot consume
// directly, so the partitioner must split around the dict ops. Expected
// result: one TensorRT block (tensor-only inputs) and two Torch blocks.
// NOTE(review): "Neccessary" is a typo in the test name, kept as-is since
// renaming would change the test identifier.
TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
  /* parseIR does not support "= aten::_set_item" so we will build this graph manually
  const auto graph = R"IR(
      graph(%x : Tensor,
            %y : Tensor):
        %2 : str = prim::Constant[value="INS"]()
        %3 : str = prim::Constant[value="OUTS"]()
        %4 : bool = prim::Constant[value=0]()
        %5 : int = prim::Constant[value=-1]()
        %6 : Dict(str, Tensor) = prim::DictConstruct()
         = aten::_set_item(%6, %2, %x)
        %7 : Tensor = aten::__getitem__(%6, %2)
        %8 : Tensor = aten::lt(%7, %y)
        %9 : Tensor?[] = prim::ListConstruct(%8)
        %10 : int = prim::dtype(%7)
        %11 : Device = prim::device(%7)
        %12 : Tensor = aten::tensor(%5, %10, %11, %4)
        %13 : Tensor = aten::index_put_(%7, %9, %12, %4)
         = aten::_set_item(%6, %3, %7)
        %14 : Tensor = aten::__getitem__(%6, %2)
        %15 : Tensor = aten::__getitem__(%6, %3)
        return (%14, %15))IR";
  */
  // Build the graph above node-by-node with the JIT graph API.
  auto g = std::make_shared<torch::jit::Graph>();
  auto x = g->insertInput(0, "x");
  auto y = g->insertInput(1, "y");
  torch::jit::IValue ins_key("INS");
  auto ins_key_val = g->insertConstant(ins_key);
  torch::jit::IValue outs_key("OUTS");
  auto outs_key_val = g->insertConstant(outs_key);
  // Constant 0 retyped to bool to stand in for a `prim::Constant[value=0]` bool.
  torch::jit::IValue zero(0);
  auto false_const_val = g->insertConstant(zero);
  false_const_val->setType(c10::BoolType::get());
  torch::jit::IValue neg_one(-1);
  auto neg_one_const_val = g->insertConstant(neg_one);
  // %6: the Dict(str, Tensor) whose uses make several values non-tensor inputs.
  auto dict_node = g->createDict(ins_key_val->type(), x->type(), torch::jit::ArrayRef<torch::jit::Value*>(), torch::jit::ArrayRef<torch::jit::Value*>());
  g->insertNode(dict_node);
  auto set_node = g->create(torch::jit::Symbol::fromQualString("aten::_set_item"), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val, x}, 0);
  g->insertNode(set_node);
  auto get_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val}, 1);
  g->insertNode(get_node);
  auto lt_node = g->create(torch::jit::Symbol::fromQualString("aten::lt"), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output(), y}, 1);
  g->insertNode(lt_node);
  auto list_node = g->createList(at::OptionalType::create(lt_node->output()->type()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output()});
  g->insertNode(list_node);
  // prim::dtype / prim::device outputs need explicit types; g->create does not
  // infer them for these ops.
  auto dtype_node = g->create(torch::jit::Symbol::fromQualString("prim::dtype"), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()}, 1);
  dtype_node->output()->setType(neg_one_const_val->type());
  g->insertNode(dtype_node);
  auto device_node = g->create(torch::jit::Symbol::fromQualString("prim::device"), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()}, 1);
  device_node->output()->setType(c10::DeviceObjType::get());
  g->insertNode(device_node);
  auto tensor_node = g->create(torch::jit::Symbol::fromQualString("aten::tensor"), torch::jit::ArrayRef<torch::jit::Value*>{neg_one_const_val, dtype_node->output(), device_node->output(), false_const_val}, 1);
  g->insertNode(tensor_node);
  auto index_put_node = g->create(torch::jit::Symbol::fromQualString("aten::index_put_"),
      torch::jit::ArrayRef<torch::jit::Value*>{get_node->output(), list_node->output(), tensor_node->output(), false_const_val}, 1);
  g->insertNode(index_put_node);
  auto out_set_node = g->create(torch::jit::Symbol::fromQualString("aten::_set_item"),
      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val, get_node->output()}, 0);
  g->insertNode(out_set_node);
  auto get_ins_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val}, 1);
  g->insertNode(get_ins_node);
  auto get_outs_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val}, 1);
  g->insertNode(get_outs_node);
  g->registerOutput(get_ins_node->output());
  g->registerOutput(get_outs_node->output());

  // Partition with fallback enabled and two random (4, 4) float inputs.
  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
  partition_info.enabled = true;
  std::vector<torch_tensorrt::core::ir::Input> inputs;
  inputs.push_back(torch_tensorrt::core::ir::Input({4, 4}));
  inputs.push_back(torch_tensorrt::core::ir::Input({4, 4}));

  std::unordered_map<const torch::jit::Value*, torch_tensorrt::core::ir::Input> inputs_map;
  std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> input_types;
  for (size_t i = 0; i < g->inputs().size(); ++i) {
    inputs_map.insert({g->inputs()[i], inputs[i]});
    input_types.insert({g->inputs()[i], {at::kFloat}});
  }
  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
  auto segmented_blocks =
      torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);

  // Count blocks by target. TensorRT blocks must have only Tensor-typed
  // inputs; each Torch block must either produce the dict or consume it,
  // never both and never neither (the XOR below).
  int torch_block_cnt = 0, trt_block_cnt = 0;
  for (const auto& segmented_block : segmented_blocks) {
    if (segmented_block.target() == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT) {
      ++trt_block_cnt;
      ASSERT_TRUE(checkSegmentedBlockInputType(segmented_block, [](torch::jit::TypePtr type_ptr) {
        return type_ptr->isSubtypeOf(torch::jit::TensorType::get());
      }));
    } else {
      ++torch_block_cnt;
      bool output_dict = false;
      bool input_dict = false;
      auto dict_type = dict_node->output()->type();
      for (auto in : segmented_block.raw_inputs()) {
        if(in->type()->isSubtypeOf(dict_type)){
          input_dict = true;
        }
      }
      for (auto out : segmented_block.raw_outputs()) {
        if(out->type()->isSubtypeOf(dict_type)){
          output_dict = true;
        }
      }
      EXPECT_TRUE(output_dict ^ input_dict);
    }
  }
  // The fix should yield exactly one TRT segment and two Torch segments.
  ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 2);
}

0 commit comments

Comments
 (0)