@@ -62,6 +62,14 @@ void registerSegmentInOutShape(SegmentedBlock &seg_block, std::unordered_map<tor
62
62
// create a module to run the graph
63
63
auto g = seg_block.g ();
64
64
auto copy_g = g->copy ();
65
+ if (seg_block.raw_outputs ().size () > 1 ) {
66
+ auto new_output_node = copy_g->appendNode (copy_g->createTuple (copy_g->outputs ()));
67
+ for (int idx = copy_g->outputs ().size () - 1 ; idx >= 0 ; --idx) {
68
+ copy_g->eraseOutput (idx);
69
+ }
70
+ copy_g->registerOutput (new_output_node->outputs ()[0 ]);
71
+ }
72
+
65
73
torch::jit::script::Module cur_mod (c10::QualifiedName (" module" ));
66
74
67
75
auto self = copy_g->insertInput (0 , " self_1" );
@@ -140,25 +148,45 @@ void registerSegmentsInputsOutputs(std::vector<SegmentedBlock> &segmented_blocks
140
148
return ;
141
149
}
142
150
143
- std::vector<SegmentedBlock> segment_graph (std::shared_ptr<torch::jit::Graph> g, std::vector<conversion::InputRange>& input_ranges) {
151
+ void merge_nodes (std::vector<torch::jit::Node*> &pytorch_nodes, std::vector<torch::jit::Node*> &tensorrt_nodes,
152
+ std::vector<SegmentedBlock> &segmented_blocks, size_t min_block_size) {
153
+ if (!tensorrt_nodes.empty ()) {
154
+ if (tensorrt_nodes.size () < min_block_size) {
155
+ pytorch_nodes.insert (pytorch_nodes.end (), tensorrt_nodes.begin (), tensorrt_nodes.end ());
156
+ } else {
157
+ if (!pytorch_nodes.empty ()) segmented_blocks.emplace_back (SegmentedBlock::kTorch , pytorch_nodes);
158
+ segmented_blocks.emplace_back (SegmentedBlock::kTensorRT , tensorrt_nodes);
159
+ pytorch_nodes.clear ();
160
+ }
161
+ tensorrt_nodes.clear ();
162
+ }
163
+ }
164
+
165
+ std::vector<SegmentedBlock> segment_graph (std::shared_ptr<torch::jit::Graph> g,
166
+ std::vector<conversion::InputRange>& input_ranges,
167
+ const conversion::TorchFallback &fallback_info) {
168
+ auto min_block_size = fallback_info.min_block_size ;
169
+ std::unordered_set<std::string> forced_fallback_operators (fallback_info.forced_fallback_operators .begin (), fallback_info.forced_fallback_operators .end ());
144
170
std::vector<SegmentedBlock> segmented_blocks;
145
171
146
172
auto nodes = g->block ()->nodes ();
147
173
148
174
// segment the nodes
175
+ std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
176
+
149
177
for (const auto n : nodes) {
150
178
if (n->kind () == torch::jit::prim::Constant) continue ;
179
+ std::string node_string (n->kind ().toQualString ());
151
180
152
- auto block_target = conversion::OpSupported (n) ? SegmentedBlock::kTensorRT : SegmentedBlock::kTorch ;
153
-
154
- if (segmented_blocks.empty () || block_target != segmented_blocks.back ().target ()) {
155
- SegmentedBlock cur_block (block_target);
156
- cur_block.appendNode (n);
157
- segmented_blocks.push_back (cur_block);
181
+ if (conversion::OpSupported (n) && !forced_fallback_operators.count (node_string)) {
182
+ tensorrt_nodes.push_back (n);
158
183
} else {
159
- segmented_blocks.back ().appendNode (n);
184
+ merge_nodes (pytorch_nodes, tensorrt_nodes, segmented_blocks, min_block_size);
185
+ pytorch_nodes.push_back (n);
160
186
}
161
187
}
188
+ merge_nodes (pytorch_nodes, tensorrt_nodes, segmented_blocks, min_block_size);
189
+ if (!pytorch_nodes.empty ()) segmented_blocks.emplace_back (SegmentedBlock::kTorch , pytorch_nodes);
162
190
163
191
registerSegmentsInputsOutputs (segmented_blocks, g);
164
192
0 commit comments