Description
Is your feature request related to a problem? Please describe.
Graph segmentation in partitioning currently walks the nodes one by one and splits them into 2 groups: TensorRT supported and TensorRT not supported. This naive segmentation causes many small problems, such as the ones reported in these 2 issues:
#1018
#1024
Although these bugs could be fixed by patching the NonTensorInputs function in partitioning, the naive segmentation has many other side effects:
Target: TensorRT
Graph: graph(%z.1 : Tensor,
%y.1 : Tensor):
%0 : Tensor = aten::lt(%z.1, %y.1) # test_resolve.py:12:10
%3 : Tensor?[] = prim::ListConstruct(%0)
%4 : int = prim::dtype(%z.1)
return ()
For example, the graph above is a TensorRT segment. However, 2 of its 3 nodes are useless in this segment: they produce non-Tensor outputs, which later segments cannot use because TensorRT engines can't have non-Tensor outputs. As a result, the computation done for these 2 nodes is wasted.
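To make the problem concrete, here is a minimal Python sketch of the naive one-by-one segmentation and a check for the wasted nodes. It is not the Torch-TensorRT partitioning code: `Value`, `Node`, `naive_segments`, `dead_non_tensor_nodes`, and the `trt_supported` flags are all hypothetical names and assumptions made for illustration. A node is flagged here when all of its outputs are non-Tensor and none of them is consumed inside the same segment.

```python
from dataclasses import dataclass
from itertools import groupby


@dataclass
class Value:
    name: str
    type: str  # e.g. "Tensor", "int", "Tensor?[]"


@dataclass
class Node:
    kind: str
    inputs: list
    outputs: list
    trt_supported: bool  # assumed result of a per-node support check


def naive_segments(nodes):
    """Group consecutive nodes by target, mimicking the one-by-one pass."""
    return [
        ("TensorRT" if supported else "Torch", list(run))
        for supported, run in groupby(nodes, key=lambda n: n.trt_supported)
    ]


def dead_non_tensor_nodes(segment):
    """Nodes whose outputs are all non-Tensor and unused inside the segment."""
    consumed = {v.name for n in segment for v in n.inputs}
    return [
        n for n in segment
        if all(v.type != "Tensor" for v in n.outputs)
        and not any(v.name in consumed for v in n.outputs)
    ]


# Hand-built model of the graph shown above.
z = Value("%z.1", "Tensor")
y = Value("%y.1", "Tensor")
v0 = Value("%0", "Tensor")
v3 = Value("%3", "Tensor?[]")
v4 = Value("%4", "int")
graph = [
    Node("aten::lt", [z, y], [v0], True),
    Node("prim::ListConstruct", [v0], [v3], True),
    Node("prim::dtype", [z], [v4], True),
]

segments = naive_segments(graph)
target, seg = segments[0]
useless = [n.kind for n in dead_non_tensor_nodes(seg)]
print(target, useless)
```

On this model, the whole graph lands in a single TensorRT segment and the check flags `prim::ListConstruct` and `prim::dtype`, matching the 2 useless nodes described above. A real pass would also need to account for values consumed by later segments and for graph outputs.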
Describe the solution you'd like
We want to apply a more advanced graph algorithm to produce a better segmentation. In other words, we should find a better interface for segmenting the graph. This could reduce the number of issues related to resolving non-Tensor inputs, and it would let us optimize each segment so that it contains no useless nodes.
Additional context
Looking into this now. The change is not complex, but we might need to do some testing after the refactoring.