Skip to content

Fix: Resolve the conflicts for ResolveNonTensorInputs #1032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 8, 2022
8 changes: 3 additions & 5 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,11 +353,9 @@ auto aten_registrations TORCHTRT_UNUSED =
return {};
}
},
EvalOptions().validSchemas({
"aten::add.int(int a, int b) -> (int)",
"aten::add.float(float a, float b) -> (float)",
"aten::add.str(str a, str b) -> (str)"
})})
EvalOptions().validSchemas({"aten::add.int(int a, int b) -> (int)",
"aten::add.float(float a, float b) -> (float)",
"aten::add.str(str a, str b) -> (str)"})})
.evaluator({c10::Symbol::fromQualString("aten::add_"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
if (args.at(n->input(0)).IValue()->isList()) {
Expand Down
10 changes: 5 additions & 5 deletions core/lowering/passes/exception_elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ struct ExceptionOrPassPatternElimination {
bool arm1_starts_with_exception = (*arm1_start)->kind() == prim::RaiseException;
bool arm2_starts_with_exception = (*arm2_start)->kind() == prim::RaiseException;

//if (!arm1_starts_with_exception && !arm2_starts_with_exception) {
// Neither arm matches the pattern
// return false;
// if (!arm1_starts_with_exception && !arm2_starts_with_exception) {
// Neither arm matches the pattern
// return false;
//}

/// Check if this Node hosts a pattern like so:
Expand Down Expand Up @@ -90,7 +90,7 @@ struct ExceptionOrPassPatternElimination {
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
auto n = *it;
if (n->kind() == prim::If && isExceptionOrPassNode(n)) {
LOG_ERROR("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl);
LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl);
it.destroyCurrent();
}
}
Expand All @@ -104,7 +104,7 @@ void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {
ExceptionOrPassPatternElimination eppe(std::move(graph));
eppe.run();
if (graph) {
LOG_ERROR("Post Eliminate Exception or Pass Patterns: " << *graph);
LOG_GRAPH("Post Eliminate Exception or Pass Patterns: " << *graph);
}
}

Expand Down
5 changes: 4 additions & 1 deletion core/lowering/register_trt_placeholder_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() {
RegisterOperators trt_placeholder_ops_reg({
/// Op marks a Tensor to be converted from a Torch Tensor
/// to a TRT constant Tensor
Operator("trt::const(Tensor val) -> Tensor", [](Stack& stack) { /*noop*/ }, aliasAnalysisFromSchema()),
Operator(
"trt::const(Tensor val) -> Tensor",
[](Stack& stack) { /*noop*/ },
aliasAnalysisFromSchema()),
});

} // namespace jit
Expand Down
68 changes: 38 additions & 30 deletions core/partitioning/partitioning.cpp
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ bool containNonTensorOutputs(torch::jit::Node* n) {
return false;
}

std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*>& vals) {
std::vector<torch::jit::Node*> getDependencyNodes(const std::vector<torch::jit::Value*>& vals) {
// use bfs to get the DAG dependency nodes for input value
std::queue<torch::jit::Value*, std::deque<torch::jit::Value*>> q(
std::deque<torch::jit::Value*>(vals.begin(), vals.end()));
Expand Down Expand Up @@ -169,17 +169,10 @@ std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock
return std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock>(append_blocks, trt_block);
}

PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
// reconstruct segmented_block if this block requires nonTensor input
std::vector<torch::jit::Value*> nontensor_inputs;
// Gather all non-tensor inputs for this seg_block
for (auto input : seg_block.raw_inputs()) {
if (!isTensorOrTensorList(input)) {
nontensor_inputs.push_back(input);
}
}

std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(nontensor_inputs);
PartitionedGraph segmentBlocksWithSpecifiedInputs(
SegmentedBlock& seg_block,
std::vector<torch::jit::Value*>& inputs_to_resolve) {
std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(inputs_to_resolve);
PartitionedGraph new_seg_blocks;
// if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the
// dependency nodes at the beginning of the current segmented_block and return this merged segmented_block
Expand All @@ -194,7 +187,7 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
}
} else {
// if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
std::unordered_set<torch::jit::Value*> inputs_to_resolve_set(inputs_to_resolve.begin(), inputs_to_resolve.end());
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;

// take all nodes with non_tensor_inputs as initial dirty nodes (nodes that should be in PyTorch block), then we use
Expand All @@ -205,7 +198,7 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());

for (auto n : seg_block.raw_nodes()) {
if (containTargetInputs(n, nontensor_inputs_set)) {
if (containTargetInputs(n, inputs_to_resolve_set)) {
dirty_nodes.insert(n);
}
}
Expand Down Expand Up @@ -237,6 +230,18 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
return new_seg_blocks;
}

PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
  // Re-segment this block so that every non-tensor input it consumes gets
  // resolved. Collect each raw input that is neither a Tensor nor a
  // TensorList, then hand the full set to the generic resolver.
  std::vector<torch::jit::Value*> unresolved_inputs;
  for (auto* raw_in : seg_block.raw_inputs()) {
    if (!isTensorOrTensorList(raw_in)) {
      unresolved_inputs.push_back(raw_in);
    }
  }
  return segmentBlocksWithSpecifiedInputs(seg_block, unresolved_inputs);
}

std::unordered_map<torch::jit::Value*, usage_info> getInputUsageCounts(
const PartitionedGraph& segmented_blocks,
const std::function<bool(torch::jit::Value*)>& condition) {
Expand Down Expand Up @@ -284,6 +289,10 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
segmented_blocks, [](torch::jit::Value* input) -> bool { return !isTensorOrTensorList(input); });
auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list);

std::map<int, std::vector<torch::jit::Value*>>
torch_values_to_fix; // Only need to resolve values generated by tensorrt
std::set<int> tensorrt_blocks_to_fix; // Need to resolve ALL non-tensor inputs

// update blocks_list
std::unordered_set<int> updated_segments;
for (auto& use : usage_counts) {
Expand All @@ -292,27 +301,26 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
// kTorch segment.
if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
auto first_torch_id = use_info.torch_use_id.back();
if (!updated_segments.count(first_torch_id)) {
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
// Torch-TensorRT doesn't support non-tensor inputs for a module.
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]);
auto next_iter = segmented_blocks_list.erase(idx_to_iter[first_torch_id]);
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
updated_segments.insert(first_torch_id);
}
torch_values_to_fix[first_torch_id].push_back(use.first);
}
// kTensorRT segments always need to inject nodes for the nonTensor inputs
for (auto i : use_info.tensorrt_use_id) {
if (!updated_segments.count(i)) {
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
// Torch-TensorRT doesn't support non-tensor inputs for a module.
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]);
auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
updated_segments.insert(i);
}
tensorrt_blocks_to_fix.insert(i);
}
}
for (auto torch_block_pair : torch_values_to_fix) {
auto to_inject_blocks =
segmentBlocksWithSpecifiedInputs(segmented_blocks[torch_block_pair.first], torch_block_pair.second);
auto next_iter = segmented_blocks_list.erase(idx_to_iter[torch_block_pair.first]);
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
}

for (auto i : tensorrt_blocks_to_fix) {
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]);
auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
}

segmented_blocks.clear();
segmented_blocks.insert(segmented_blocks.begin(), segmented_blocks_list.begin(), segmented_blocks_list.end());
return;
Expand Down
6 changes: 3 additions & 3 deletions tests/core/lowering/test_exception_elimination_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) {
auto if_block0 = if_node->addBlock();
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
if_block0->appendNode(exception_node);
/*auto if_block1 =*/ if_node->addBlock();
/*auto if_block1 =*/if_node->addBlock();
g->insertNode(if_node);
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
g->insertNode(cat_node);
Expand Down Expand Up @@ -97,7 +97,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block1) {
bool_node->output()->setType(torch::jit::BoolType::get());
g->insertNode(bool_node);
auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0);
/*auto if_block0 = */if_node->addBlock();
/*auto if_block0 = */ if_node->addBlock();
auto if_block1 = if_node->addBlock();
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
if_block1->appendNode(exception_node);
Expand Down Expand Up @@ -154,7 +154,7 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) {
auto if_block0 = if_node->addBlock();
auto append_node = g->create(torch::jit::aten::append, {list_node->output(), y});
if_block0->appendNode(append_node);
/*auto if_block1 = */if_node->addBlock();
/*auto if_block1 = */ if_node->addBlock();
g->insertNode(if_node);
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
g->insertNode(cat_node);
Expand Down
144 changes: 144 additions & 0 deletions tests/core/partitioning/test_resolve_nontensor_inputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,147 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
int count = count_trt_engines(fallback_g);
ASSERT_TRUE(count == 2);
}

// Verifies that partitioning resolves only the necessary non-tensor inputs:
// after Partition() runs on a graph that routes tensors through a
// Dict(str, Tensor), exactly one TensorRT block (with tensor-only inputs)
// and two Torch blocks are expected, and each Torch block either consumes
// or produces the dict — never both.
TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
  /* parseIR does not support "= aten::_set_item" so we will build this graph manually
    const auto graph = R"IR(
  graph(%x : Tensor,
        %y : Tensor):
    %2 : str = prim::Constant[value="INS"]()
    %3 : str = prim::Constant[value="OUTS"]()
    %4 : bool = prim::Constant[value=0]()
    %5 : int = prim::Constant[value=-1]()
    %6 : Dict(str, Tensor) = prim::DictConstruct()
     = aten::_set_item(%6, %2, %x)
    %7 : Tensor = aten::__getitem__(%6, %2)
    %8 : Tensor = aten::lt(%7, %y)
    %9 : Tensor?[] = prim::ListConstruct(%8)
    %10 : int = prim::dtype(%7)
    %11 : Device = prim::device(%7)
    %12 : Tensor = aten::tensor(%5, %10, %11, %4)
    %13 : Tensor = aten::index_put_(%7, %9, %12, %4)
     = aten::_set_item(%6, %3, %7)
    %14 : Tensor = aten::__getitem__(%6, %2)
    %15 : Tensor = aten::__getitem__(%6, %3)
    return (%14, %15))IR";
  */
  // --- Build the graph by hand (see the IR sketch above) ---
  auto g = std::make_shared<torch::jit::Graph>();
  auto x = g->insertInput(0, "x");
  auto y = g->insertInput(1, "y");
  // String constants used as dict keys.
  torch::jit::IValue ins_key("INS");
  auto ins_key_val = g->insertConstant(ins_key);
  torch::jit::IValue outs_key("OUTS");
  auto outs_key_val = g->insertConstant(outs_key);
  // Bool `false` constant: inserted as int 0, then retyped to Bool.
  torch::jit::IValue zero(0);
  auto false_const_val = g->insertConstant(zero);
  false_const_val->setType(c10::BoolType::get());
  torch::jit::IValue neg_one(-1);
  auto neg_one_const_val = g->insertConstant(neg_one);
  // %6: empty Dict(str, Tensor) — the non-tensor value the partitioner
  // must resolve across block boundaries.
  auto dict_node = g->createDict(
      ins_key_val->type(),
      x->type(),
      torch::jit::ArrayRef<torch::jit::Value*>(),
      torch::jit::ArrayRef<torch::jit::Value*>());
  g->insertNode(dict_node);
  // dict["INS"] = x  (aten::_set_item has no outputs, hence n_outputs = 0)
  auto set_node = g->create(
      torch::jit::Symbol::fromQualString("aten::_set_item"),
      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val, x},
      0);
  g->insertNode(set_node);
  // %7 = dict["INS"]
  auto get_node = g->create(
      torch::jit::Symbol::fromQualString("aten::__getitem__"),
      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val},
      1);
  g->insertNode(get_node);
  // %8 = %7 < y — a convertible tensor op that should land in a TRT block.
  auto lt_node = g->create(
      torch::jit::Symbol::fromQualString("aten::lt"),
      torch::jit::ArrayRef<torch::jit::Value*>{get_node->output(), y},
      1);
  g->insertNode(lt_node);
  // %9 = [%8] as Tensor?[] for index_put_.
  auto list_node = g->createList(
      at::OptionalType::create(lt_node->output()->type()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output()});
  g->insertNode(list_node);
  // %10 = dtype(%7) — produces an int (typed from the -1 constant's type).
  auto dtype_node = g->create(
      torch::jit::Symbol::fromQualString("prim::dtype"),
      torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()},
      1);
  dtype_node->output()->setType(neg_one_const_val->type());
  g->insertNode(dtype_node);
  // %11 = device(%7) — another non-tensor intermediate.
  auto device_node = g->create(
      torch::jit::Symbol::fromQualString("prim::device"),
      torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()},
      1);
  device_node->output()->setType(c10::DeviceObjType::get());
  g->insertNode(device_node);
  // %12 = aten::tensor(-1, dtype, device, False)
  auto tensor_node = g->create(
      torch::jit::Symbol::fromQualString("aten::tensor"),
      torch::jit::ArrayRef<torch::jit::Value*>{
          neg_one_const_val, dtype_node->output(), device_node->output(), false_const_val},
      1);
  g->insertNode(tensor_node);
  // %13 = index_put_(%7, %9, %12, False) — in-place masked write.
  auto index_put_node = g->create(
      torch::jit::Symbol::fromQualString("aten::index_put_"),
      torch::jit::ArrayRef<torch::jit::Value*>{
          get_node->output(), list_node->output(), tensor_node->output(), false_const_val},
      1);
  g->insertNode(index_put_node);
  // dict["OUTS"] = %7
  auto out_set_node = g->create(
      torch::jit::Symbol::fromQualString("aten::_set_item"),
      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val, get_node->output()},
      0);
  g->insertNode(out_set_node);
  // %14 = dict["INS"], %15 = dict["OUTS"] — both are graph outputs.
  auto get_ins_node = g->create(
      torch::jit::Symbol::fromQualString("aten::__getitem__"),
      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val},
      1);
  g->insertNode(get_ins_node);
  auto get_outs_node = g->create(
      torch::jit::Symbol::fromQualString("aten::__getitem__"),
      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val},
      1);
  g->insertNode(get_outs_node);
  g->registerOutput(get_ins_node->output());
  g->registerOutput(get_outs_node->output());

  // --- Run partitioning with fallback enabled and two 4x4 inputs ---
  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
  partition_info.enabled = true;
  std::vector<torch_tensorrt::core::ir::Input> inputs;
  inputs.push_back(torch_tensorrt::core::ir::Input({4, 4}));
  inputs.push_back(torch_tensorrt::core::ir::Input({4, 4}));

  std::unordered_map<const torch::jit::Value*, torch_tensorrt::core::ir::Input> inputs_map;
  std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> input_types;
  for (size_t i = 0; i < g->inputs().size(); ++i) {
    inputs_map.insert({g->inputs()[i], inputs[i]});
    input_types.insert({g->inputs()[i], {at::kFloat}});
  }
  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
  auto segmented_blocks = torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);

  // --- Validate the segmentation ---
  int torch_block_cnt = 0, trt_block_cnt = 0;
  for (const auto& segmented_block : segmented_blocks) {
    if (segmented_block.target() == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT) {
      // Every input of a TRT block must be a Tensor: non-tensor inputs
      // are supposed to have been resolved away.
      ++trt_block_cnt;
      ASSERT_TRUE(checkSegmentedBlockInputType(segmented_block, [](torch::jit::TypePtr type_ptr) {
        return type_ptr->isSubtypeOf(torch::jit::TensorType::get());
      }));
    } else {
      ++torch_block_cnt;
      bool output_dict = false;
      bool input_dict = false;
      auto dict_type = dict_node->output()->type();
      for (auto in : segmented_block.raw_inputs()) {
        if (in->type()->isSubtypeOf(dict_type)) {
          input_dict = true;
        }
      }
      for (auto out : segmented_block.raw_outputs()) {
        if (out->type()->isSubtypeOf(dict_type)) {
          output_dict = true;
        }
      }
      // XOR: each Torch block either produces the dict or consumes it,
      // never both — i.e. only the necessary dict uses were resolved.
      EXPECT_TRUE(output_dict ^ input_dict);
    }
  }
  // Expect exactly one TRT block and two Torch blocks.
  ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 2);
}