@@ -98,9 +98,13 @@ std::vector<torch::jit::Node*> getDependencyNodes(
  return stk;
}

-void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+void find_all_fallback_nodes(
+    std::unordered_map<torch::jit::Node*, int>& initial_fallback_nodes,
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
+  // initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
+  // global_fallback_nodes are the fallback nodes that we maintain globally
  std::queue<torch::jit::Node*> q;
-  for (auto& node : fallback_nodes) {
+  for (auto& node : initial_fallback_nodes) {
    q.push(node.first);
  }

@@ -111,7 +115,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallba
    // every node that produces one of this fallback node's NonTensor inputs should fall back too
    for (auto input : cur_node->inputs()) {
      if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
-          fallback_nodes.insert({input->node(), 4}).second) {
+          global_fallback_nodes.insert({input->node(), FallbackNodeType::kNON_TENSOR}).second) {
        q.push(input->node());
      }
    }
@@ -120,7 +124,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallba
    if (!isTensor(output)) {
      for (auto use : output->uses()) {
        auto node = use.user;
-        if (node->kind() != torch::jit::prim::Constant && fallback_nodes.insert({node, 4}).second) {
+        if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
          q.push(node);
        }
      }
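
Both propagation hunks rely on the same idiom: `std::unordered_map::insert` returns a pair whose `.second` member is `true` only when the key was newly inserted, so the global fallback map doubles as the BFS visited set and no node is enqueued twice. A minimal standalone demonstration (toy keys and values, not the patch's types):

```cpp
#include <iostream>
#include <unordered_map>

int main() {
  std::unordered_map<int, int> marked;
  // First insert succeeds: in the BFS above, this node would be enqueued.
  std::cout << marked.insert({7, 4}).second << '\n'; // prints 1
  // Repeat insert is a no-op: the node is already marked, so it is skipped.
  std::cout << marked.insert({7, 4}).second << '\n'; // prints 0
}
```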
@@ -225,12 +229,14 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {

bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
  if (fallback_nodes.count(n)) {
-    if (fallback_nodes.at(n) == 0) {
+    if (fallback_nodes.at(n) == FallbackNodeType::kUNSUPPORTED) {
      LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == 1) {
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kOPERATOR_FALLBACK) {
      LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == 2) {
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kMODULE_FALLBACK) {
      LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kMIN_BLOCK_FALLBACK) {
+      LOG_GRAPH("Node falls back to Torch because of min_block_size: " << util::node_info(n));
    } else {
      LOG_GRAPH(
          "Node falls back to Torch because of NonTensor dependencies on other fallback nodes: "
@@ -267,39 +273,91 @@ void get_fallback_nodes(

    // If the op is not supported by the conversion phase it should run in PyTorch
    if (!conversion::OpSupported(n)) {
-      fallback_nodes.insert({n, 0});
+      fallback_nodes.insert({n, FallbackNodeType::kUNSUPPORTED});
    }

    // If the user specifies the op to run in Torch it should run in PyTorch
    if (forced_fallback_ops.find(n->kind().toQualString()) != forced_fallback_ops.end()) {
-      fallback_nodes.insert({n, 1});
+      fallback_nodes.insert({n, FallbackNodeType::kOPERATOR_FALLBACK});
    }

    // If the user specifies the module containing this op to run in torch it should run in PyTorch
    const auto to_compile_sym = c10::Symbol::attr("to_compile");
    if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
-      fallback_nodes.insert({n, 2});
+      fallback_nodes.insert({n, FallbackNodeType::kMODULE_FALLBACK});
    }
  }
  return;
}

+std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size(
+    torch::jit::Block* block,
+    const std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
+    size_t min_block_size) {
+  auto nodes = block->nodes();
+  std::vector<torch::jit::Node*> cur_trt_nodes;
+  std::vector<torch::jit::Node*> min_block_fallback_nodes;
+  for (const auto n : nodes) {
+    if (n->kind() == torch::jit::prim::Constant)
+      continue;
+
+    // check whether the current node falls back or not
+    if (!global_fallback_nodes.count(n)) {
+      // if this node is not in the fallback nodes, then it belongs to a TRT segment
+      cur_trt_nodes.push_back(n);
+    } else {
+      if (cur_trt_nodes.size() < min_block_size) {
+        min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+      }
+      cur_trt_nodes.clear();
+    }
+  }
+  if (cur_trt_nodes.size() < min_block_size) {
+    min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+  }
+  return min_block_fallback_nodes;
+}
+
+void find_min_block_size_fallback_nodes(
+    torch::jit::Block* block,
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
+    size_t min_block_size) {
+  // first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
+  auto min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
+  std::unordered_map<torch::jit::Node*, int> initial_fallback_nodes;
+
+  // keep falling back until all segments meet the min_block_size requirement
+  while (!min_block_fallback_nodes.empty()) {
+    for (const auto i : min_block_fallback_nodes) {
+      initial_fallback_nodes.insert({i, FallbackNodeType::kMIN_BLOCK_FALLBACK});
+    }
+    global_fallback_nodes.insert(initial_fallback_nodes.begin(), initial_fallback_nodes.end());
+    // find the nodes that fall back because of dependencies on the min_block_size-caused fallback nodes
+    find_all_fallback_nodes(initial_fallback_nodes, global_fallback_nodes);
+    // keep traversing the graph until no more nodes fall back because of min_block_size
+    min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
+  }
+}
+
PartitionedGraph segment_graph(
    torch::jit::Block* block,
    const PartitionInfo& partition_info,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
  auto min_block_size = partition_info.min_block_size;
  std::unordered_set<std::string> forced_fallback_ops(
      partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());

  // get the initial fallback nodes (nodes that are unsupported or forced fallback)
-  get_fallback_nodes(block, forced_fallback_ops, fallback_nodes);
+  get_fallback_nodes(block, forced_fallback_ops, global_fallback_nodes);

  // For fallback nodes, if one consumes any NonTensor or TensorList inputs, then the node that produces that
  // input should also fall back. Similarly, if it produces any NonTensor or TensorList outputs, then the nodes
  // that consume those outputs should also fall back.
  // TODO: don't need to fallback the TensorList related nodes once the collection feature is supported
-  find_all_fallback_nodes(fallback_nodes);
+  find_all_fallback_nodes(global_fallback_nodes, global_fallback_nodes);
+
+  // find all fallback nodes because of the min_block_size requirement
+  find_min_block_size_fallback_nodes(block, global_fallback_nodes, min_block_size);

  auto nodes = block->nodes();

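
To make the scan in `traverse_nodes_for_min_block_size` concrete, here is a self-contained toy version that swaps `torch::jit::Node*` for plain indices and the fallback map for a vector of flags; every name in it is illustrative rather than part of the patch:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Collect indices of TRT nodes whose contiguous TRT segment is shorter than
// min_block_size; a fallback node (flag == true) closes the current segment.
std::vector<size_t> short_trt_segments(const std::vector<bool>& falls_back, size_t min_block_size) {
  std::vector<size_t> cur, flagged;
  auto flush = [&] {
    if (cur.size() < min_block_size)
      flagged.insert(flagged.end(), cur.begin(), cur.end());
    cur.clear();
  };
  for (size_t i = 0; i < falls_back.size(); ++i) {
    if (!falls_back[i])
      cur.push_back(i); // still inside a TRT segment
    else
      flush(); // fallback node ends the current TRT segment
  }
  flush(); // handle a trailing TRT segment, as the patch does after its loop
  return flagged;
}

int main() {
  // TRT segments are [0,1] (length 2) and [3] (length 1); with
  // min_block_size = 3 both are too short, so all three nodes are flagged.
  for (size_t i : short_trt_segments({false, false, true, false}, 3))
    std::cout << i << ' '; // prints: 0 1 3
}
```

The `while` loop in `find_min_block_size_fallback_nodes` is then a fixed-point iteration: flagging these nodes can, through the NonTensor BFS in `find_all_fallback_nodes`, shrink other TRT segments below the threshold, so the scan repeats until it returns empty.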
@@ -313,7 +371,7 @@ PartitionedGraph segment_graph(
      continue;
    }

-    if (check_node_fallback(n, fallback_nodes)) {
+    if (check_node_fallback(n, global_fallback_nodes)) {
      in_prog_trt_blk_nodes.push_back(n);

      // If there is an active PyTorch block and we have passed the threshold for a valid TRT
@@ -379,11 +437,11 @@ PartitionedGraph Partition(
    torch::jit::Block* block,
    std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
    const PartitionInfo& partition_info,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
  LOG_DEBUG(partition_info);
  // segment lowering global graph into blocks
  LOG_DEBUG("Partitioning source module into PyTorch and TensorRT sub blocks");
-  PartitionedGraph segmented_blocks = segment_graph(block, partition_info, fallback_nodes);
+  PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes);

  // It's possible that some TensorRT blocks have NonTensor inputs/outputs because they are interleaved by Torch blocks
