
Commit 2f896b3

Merge pull request #1195 from pytorch/support_min_block_size
feat: support min_block_size != 1 caused fallback nodes re-segmentation
2 parents e07687d + 52abece commit 2f896b3

4 files changed: +129 −32 lines

core/compiler.cpp

+1 −16
@@ -240,7 +240,7 @@ GraphAndMapping ConstructFallbackGraph(
   }
 
   for (auto& seg_block : segmented_blocks) {
-    LOG_INFO(*seg_block.g() << "(GraphInSegmentedBlock)\n");
+    LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n");
     std::ostringstream trt_engine_id;
     trt_engine_id << reinterpret_cast<const int*>(&seg_block);
 
@@ -372,15 +372,6 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
 
-  // // GPU default WS size : 1 GB
-  // // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
-  // auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
-  // auto device_spec = cfg.convert_info.engine_settings.device;
-  // auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-  // if (workspace_size == 0) {
-  // cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
-  // }
-
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
   auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
@@ -391,14 +382,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
   torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");
 
-  // // GPU default WS size : 1 GB
-  // // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
-  // auto workspace_size = cfg.convert_info.engine_settings.workspace_size;
   auto device_spec = cfg.convert_info.engine_settings.device;
   auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-  // if (workspace_size == 0) {
-  // cfg.convert_info.engine_settings.workspace_size = GetRecommendedWorkspaceSize(cuda_device);
-  // }
 
   for (const torch::jit::Method& method : mod.get_methods()) {
     if (method.name().compare("forward") == 0) {
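
The only functional change in this file swaps logging the block's raw graph (*seg_block.g()) for the block itself, which relies on SegmentedBlock being streamable into the log message. The following is a generic sketch of that ostream pattern with a hypothetical stand-in type, not the actual SegmentedBlock definition:

#include <iostream>
#include <sstream>
#include <string>

// Hypothetical stand-in; the real SegmentedBlock lives in core/partitioning.
struct Block {
  std::string target;
  std::string graph_ir;
};

// Streaming the object itself lets the type decide how to render its target and
// graph, instead of the caller dereferencing an internal graph pointer.
std::ostream& operator<<(std::ostream& os, const Block& b) {
  return os << "Target: " << b.target << "\n" << b.graph_ir << "\n";
}

int main() {
  Block b{"TensorRT", "graph(%x : Tensor):\n  return (%x)"};
  std::ostringstream msg;
  msg << b << "(GraphInSegmentedBlock)\n"; // mirrors LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n")
  std::cout << msg.str();
  return 0;
}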

core/partitioning/partitioning.cpp

+74 −16
@@ -98,9 +98,13 @@ std::vector<torch::jit::Node*> getDependencyNodes(
   return stk;
 }
 
-void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+void find_all_fallback_nodes(
+    std::unordered_map<torch::jit::Node*, int>& initial_fallback_nodes,
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
+  // initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
+  // global_fallback_nodes are the fallback nodes that we maintain globally
   std::queue<torch::jit::Node*> q;
-  for (auto& node : fallback_nodes) {
+  for (auto& node : initial_fallback_nodes) {
     q.push(node.first);
   }
 
@@ -111,7 +115,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallbac
     // for every node that produces this fallback node's NonTensor input, they should fallback too
     for (auto input : cur_node->inputs()) {
       if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
-          fallback_nodes.insert({input->node(), 4}).second) {
+          global_fallback_nodes.insert({input->node(), FallbackNodeType::kNON_TENSOR}).second) {
         q.push(input->node());
       }
     }
@@ -120,7 +124,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallbac
     if (!isTensor(output)) {
       for (auto use : output->uses()) {
         auto node = use.user;
-        if (node->kind() != torch::jit::prim::Constant && fallback_nodes.insert({node, 4}).second) {
+        if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
           q.push(node);
         }
       }
@@ -225,12 +229,14 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
 
 bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
   if (fallback_nodes.count(n)) {
-    if (fallback_nodes.at(n) == 0) {
+    if (fallback_nodes.at(n) == FallbackNodeType::kUNSUPPORTED) {
       LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == 1) {
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kOPERATOR_FALLBACK) {
       LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == 2) {
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kMODULE_FALLBACK) {
       LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kMIN_BLOCK_FALLBACK) {
+      LOG_GRAPH("Node fallback to Torch because of min_block_size" << util::node_info(n));
     } else {
       LOG_GRAPH(
           "Node fallback to Torch because the NonTensor dependencies with other fallback nodes: "
@@ -267,39 +273,91 @@ void get_fallback_nodes(
 
     // If the op is not supported by the conversion phase it should run in PyTorch
     if (!conversion::OpSupported(n)) {
-      fallback_nodes.insert({n, 0});
+      fallback_nodes.insert({n, FallbackNodeType::kUNSUPPORTED});
     }
 
     // If the user specifies the op to run in Torch it should run in PyTorch
     if (forced_fallback_ops.find(n->kind().toQualString()) != forced_fallback_ops.end()) {
-      fallback_nodes.insert({n, 1});
+      fallback_nodes.insert({n, FallbackNodeType::kOPERATOR_FALLBACK});
     }
 
     // If the user specifies the module containing this op to run in torch it should run in PyTorch
     const auto to_compile_sym = c10::Symbol::attr("to_compile");
     if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
-      fallback_nodes.insert({n, 2});
+      fallback_nodes.insert({n, FallbackNodeType::kMODULE_FALLBACK});
     }
   }
   return;
 }
 
+std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size(
+    torch::jit::Block* block,
+    const std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
+    size_t min_block_size) {
+  auto nodes = block->nodes();
+  std::vector<torch::jit::Node*> cur_trt_nodes;
+  std::vector<torch::jit::Node*> min_block_fallback_nodes;
+  for (const auto n : nodes) {
+    if (n->kind() == torch::jit::prim::Constant)
+      continue;
+
+    // check if current node fallback or not
+    if (!global_fallback_nodes.count(n)) {
+      // if this node is not in fallback nodes, then it's in trt segments
+      cur_trt_nodes.push_back(n);
+    } else {
+      if (cur_trt_nodes.size() < min_block_size) {
+        min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+      }
+      cur_trt_nodes.clear();
+    }
+  }
+  if (cur_trt_nodes.size() < min_block_size) {
+    min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+  }
+  return min_block_fallback_nodes;
+}
+
+void find_min_block_size_fallback_nodes(
+    torch::jit::Block* block,
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
+    size_t min_block_size) {
+  // first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
+  auto min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
+  std::unordered_map<torch::jit::Node*, int> initial_fallback_nodes;
+
+  // keep fallback until all segments meet the min_block_size requirement
+  while (!min_block_fallback_nodes.empty()) {
+    for (const auto i : min_block_fallback_nodes) {
+      initial_fallback_nodes.insert({i, FallbackNodeType::kMIN_BLOCK_FALLBACK});
+    }
+    global_fallback_nodes.insert(initial_fallback_nodes.begin(), initial_fallback_nodes.end());
+    // find the fallback nodes because of dependency with min_block_size caused fallback nodes
+    find_all_fallback_nodes(initial_fallback_nodes, global_fallback_nodes);
+    // keep traverse the graph until there is no node fallback because of min_block_size
+    min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
+  }
+}
+
 PartitionedGraph segment_graph(
     torch::jit::Block* block,
     const PartitionInfo& partition_info,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
   auto min_block_size = partition_info.min_block_size;
   std::unordered_set<std::string> forced_fallback_ops(
       partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());
 
   // get the initial fallback nodes (nodes that are unsupported or forced fallback)
-  get_fallback_nodes(block, forced_fallback_ops, fallback_nodes);
+  get_fallback_nodes(block, forced_fallback_ops, global_fallback_nodes);
 
   // For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this
   // input should also fallback Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node
   // that produces this input should also fallback
   // TODO: don't need to fallback the TensorList related nodes once the collection feature is supported
-  find_all_fallback_nodes(fallback_nodes);
+  find_all_fallback_nodes(global_fallback_nodes, global_fallback_nodes);
+
+  // find all fallback nodes because of the min_block_size requirement
+  find_min_block_size_fallback_nodes(block, global_fallback_nodes, min_block_size);
 
   auto nodes = block->nodes();
 
@@ -313,7 +371,7 @@ PartitionedGraph segment_graph(
       continue;
     }
 
-    if (check_node_fallback(n, fallback_nodes)) {
+    if (check_node_fallback(n, global_fallback_nodes)) {
       in_prog_trt_blk_nodes.push_back(n);
 
       // If there is an active PyTorch block and we have passed the threshold for a valid TRT
@@ -379,11 +437,11 @@ PartitionedGraph Partition(
     torch::jit::Block* block,
     std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
     const PartitionInfo& partition_info,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
   LOG_DEBUG(partition_info);
   // segment lowering global graph into blocks
   LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
-  PartitionedGraph segmented_blocks = segment_graph(block, partition_info, fallback_nodes);
+  PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes);
 
   // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks
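
For orientation, here is a minimal standalone sketch of the fixed-point loop this file introduces: any contiguous run of non-fallback nodes shorter than min_block_size is itself demoted to fallback, and the scan repeats until no new fallback nodes appear. Plain ints stand in for torch::jit::Node*, and the NonTensor dependency BFS is only noted in a comment; this is a simplified illustration, not the actual Torch-TensorRT API.

#include <cstddef>
#include <iostream>
#include <set>
#include <vector>

// Collect nodes that sit in a candidate TRT run shorter than min_block_size.
std::vector<int> undersized_runs(const std::vector<int>& nodes, const std::set<int>& fallback, size_t min_block_size) {
  std::vector<int> run, result;
  for (int n : nodes) {
    if (!fallback.count(n)) {
      run.push_back(n); // candidate TRT node
    } else {
      if (run.size() < min_block_size)
        result.insert(result.end(), run.begin(), run.end());
      run.clear();
    }
  }
  if (run.size() < min_block_size)
    result.insert(result.end(), run.begin(), run.end());
  return result;
}

int main() {
  std::vector<int> nodes = {0, 1, 2, 3, 4, 5, 6, 7};
  std::set<int> fallback = {2, 6}; // nodes already forced to run in Torch
  size_t min_block_size = 3;

  // Fixed-point loop: keep demoting undersized TRT runs until none remain.
  auto undersized = undersized_runs(nodes, fallback, min_block_size);
  while (!undersized.empty()) {
    fallback.insert(undersized.begin(), undersized.end());
    // The real implementation also re-runs find_all_fallback_nodes here to
    // propagate NonTensor dependencies before rescanning.
    undersized = undersized_runs(nodes, fallback, min_block_size);
  }

  for (int n : fallback)
    std::cout << n << " "; // prints: 0 1 2 6 7
  std::cout << "\n";
  return 0;
}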

core/partitioning/partitioning.h

+14 −0
@@ -16,6 +16,20 @@ namespace partitioning {
 
 typedef std::vector<SegmentedBlock> PartitionedGraph;
 
+enum FallbackNodeType {
+  /// Node is not supported by TensorRT
+  kUNSUPPORTED,
+  /// Node is explicitly forced to fallback to Pytorch due to operator fallback
+  kOPERATOR_FALLBACK,
+  /// Node is explicitly forced to fallback to Pytorch due to module fallback
+  kMODULE_FALLBACK,
+  /// This node is in a TRT segment which does not satisfy min_block_size
+  /// and hence is forced to fallback.
+  kMIN_BLOCK_FALLBACK,
+  /// This node produces/consumes non-tensor inputs
+  kNON_TENSOR,
+};
+
 PartitionedGraph segment_graph(
     torch::jit::Block* block,
     const PartitionInfo& partition_info,
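
The enum gives names to the integer codes that were previously hard-coded in partitioning.cpp (0, 1, 2, 4), while the fallback map stays typed as std::unordered_map<torch::jit::Node*, int>, so the unscoped enum values convert implicitly and no signatures change. A small self-contained sketch of that usage, with a string key standing in for torch::jit::Node* (the keys below are hypothetical):

#include <iostream>
#include <string>
#include <unordered_map>

// Mirrors the enum added above; values are implicitly 0..4.
enum FallbackNodeType {
  kUNSUPPORTED,
  kOPERATOR_FALLBACK,
  kMODULE_FALLBACK,
  kMIN_BLOCK_FALLBACK,
  kNON_TENSOR,
};

int main() {
  // The real map keys on torch::jit::Node*; strings stand in here.
  std::unordered_map<std::string, int> fallback_nodes;
  fallback_nodes.insert({"node_a", kUNSUPPORTED});
  fallback_nodes.insert({"node_b", kMIN_BLOCK_FALLBACK});

  for (const auto& kv : fallback_nodes) {
    if (kv.second == kMIN_BLOCK_FALLBACK) {
      std::cout << kv.first << ": demoted because its TRT segment was smaller than min_block_size\n";
    } else {
      std::cout << kv.first << ": other fallback reason (" << kv.second << ")\n";
    }
  }
  return 0;
}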

tests/core/partitioning/test_segmentation.cpp

+40 −0
@@ -120,6 +120,46 @@ TEST(Partitioning, SegmentSequentialModelWithMinBlockSizeCorrectly) {
   ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2}, {3, 4}}));
 }
 
+TEST(Partitioning, SegmentModelWithMinBlockSizeCausedFallbackCorrectly) {
+  const auto graph = R"IR(
+        graph(%0 : Tensor,
+              %1 : Tensor,
+              %2 : Tensor):
+          %3 : int[] = prim::Constant[value=[-1, 5]]()
+          %4 : int[] = prim::Constant[value=[-1]]()
+          %5 : int = prim::Constant[value=2]()
+          %6 : int = prim::Constant[value=4]()
+          %7 : int = prim::Constant[value=5]()
+          %8 : int = prim::Constant[value=0]()
+          %9 : bool = prim::Constant[value=0]()
+          %10 : NoneType = prim::Constant()
+          %11 : int = prim::Constant[value=1]()
+          %12: Tensor = aten::reshape(%1, %4)
+          %13: Tensor = aten::reshape(%2, %3)
+          %14: Tensor = aten::reshape(%1, %3)
+          %15 : Tensor = aten::to(%12, %6, %9, %9, %10)
+          %16 : int = aten::size(%1, %8)
+          %17 : int[] = prim::ListConstruct(%16, %6, %5, %7)
+          %18 : Tensor = aten::index_add_(%14, %8, %15, %13, %11)
+          %20 : Tensor = aten::reshape(%18, %17)
+          return (%20))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
+  partition_info.enabled = true;
+  partition_info.min_block_size = 3;
+  std::unordered_map<torch::jit::Node*, int> fallback_nodes;
+  std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
+      torch_tensorrt::core::partitioning::segment_graph(g->block(), partition_info, fallback_nodes);
+  ASSERT_TRUE(
+      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT, 1));
+  ASSERT_TRUE(
+      checkSegmentedBlockNumber(segmented_blocks, torch_tensorrt::core::partitioning::SegmentedBlock::kTorch, 1));
+  ASSERT_TRUE(checkSegmentedBlockNodesMapping(segmented_blocks, g, {{0, 1, 2, 3}, {4, 5, 6, 7}}));
+}
+
 TEST(Partitioning, SegmentSequentialModelWithForcedOPCorrectly) {
   const auto graph = R"IR(
         graph(%0 : Tensor,
