Closed
Description
When we implemented feature #1031, we raised PR #1140.
In that PR, we used DFS for re-segmentation. However, consider this function:
bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
  if (fallback_nodes.count(n)) {
    if (fallback_nodes.at(n) == 0) {
      LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
    } else if (fallback_nodes.at(n) == 1) {
      LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
    } else if (fallback_nodes.at(n) == 2) {
      LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
    } else {
      LOG_GRAPH(
          "Node fallback to Torch because the NonTensor dependencies with other fallback nodes: "
          << util::node_info(n));
    }
    return false;
  }
  LOG_GRAPH("Node is going to run in TensorRT: " << util::node_info(n));
  return true;
}
This function does not cover nodes that fall back because of min_block_size. If we run DFS first and then segment the graph according to min_block_size, NonTensor inputs and outputs can appear again.
So the plan is to pseudo-segment the graph with min_block_size first, then use DFS to determine the fallback nodes.
Segmentation needs at least 2 passes, and the implementation has many factors to consider.
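
To make the ordering concrete, here is a minimal sketch of the idea on a toy graph. It is not the Torch-TensorRT implementation: `ToyNode`, `pseudo_segment`, `propagate_fallback`, and `resolve_fallback` are hypothetical names, the graph is assumed to be a flat, topologically ordered list, and a NonTensor dependency is assumed to be recorded on both of its endpoints. The two passes repeat until nothing changes, because a node pulled into fallback by the DFS can shrink a neighboring block below min_block_size.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, simplified stand-in for a torch::jit::Node.
struct ToyNode {
  bool convertible;  // false: the node must run in Torch
  // Indices of nodes linked to this one through a NonTensor value
  // (assumption: each link is listed on both endpoints).
  std::vector<std::size_t> nontensor_neighbors;
};

// Pass 1: mark every run of consecutive TensorRT candidates that is shorter
// than min_block_size as fallback. Returns true if any node changed.
bool pseudo_segment(std::vector<bool>& fallback, std::size_t min_block_size) {
  bool changed = false;
  std::size_t i = 0;
  while (i < fallback.size()) {
    if (fallback[i]) {
      ++i;
      continue;
    }
    std::size_t j = i;
    while (j < fallback.size() && !fallback[j]) {
      ++j;
    }
    if (j - i < min_block_size) {
      for (std::size_t k = i; k < j; ++k) {
        fallback[k] = true;
      }
      changed = true;
    }
    i = j;
  }
  return changed;
}

// Pass 2: iterative DFS from every fallback node along NonTensor links, so no
// NonTensor value crosses a Torch/TensorRT boundary. Returns true on change.
bool propagate_fallback(const std::vector<ToyNode>& nodes, std::vector<bool>& fallback) {
  bool changed = false;
  std::vector<std::size_t> stack;
  for (std::size_t i = 0; i < nodes.size(); ++i) {
    if (fallback[i]) {
      stack.push_back(i);
    }
  }
  while (!stack.empty()) {
    std::size_t cur = stack.back();
    stack.pop_back();
    for (std::size_t next : nodes[cur].nontensor_neighbors) {
      assert(next < nodes.size());
      if (!fallback[next]) {
        fallback[next] = true;
        changed = true;
        stack.push_back(next);
      }
    }
  }
  return changed;
}

// Alternate the two passes until a fixed point is reached.
std::vector<bool> resolve_fallback(const std::vector<ToyNode>& nodes, std::size_t min_block_size) {
  std::vector<bool> fallback(nodes.size());
  for (std::size_t i = 0; i < nodes.size(); ++i) {
    fallback[i] = !nodes[i].convertible;
  }
  bool changed = true;
  while (changed) {
    changed = pseudo_segment(fallback, min_block_size);
    changed = propagate_fallback(nodes, fallback) || changed;
  }
  return fallback;
}
```

Each iteration that reports a change marks at least one more node as fallback, so the loop terminates after at most one iteration per node plus a final one that confirms the fixed point.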