Skip to content

Commit 6d3064a

Browse files
committed
feat: allow users to set fallback block size and ops
Signed-off-by: Bo Wang <[email protected]>
1 parent f4c29b4 commit 6d3064a

File tree

4 files changed

+62
-14
lines changed

4 files changed

+62
-14
lines changed

core/compiler.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitionin
184184
old_to_new_g[seg.raw_outputs()[i]] = old_to_new_g[seg.outputs()[i]];
185185
}
186186

187-
LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n");
187+
// LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n");
188188
return;
189189
}
190190

@@ -208,7 +208,10 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
208208

209209

210210
// segment the graph and convert segmented TensorRT block
211-
auto segmented_blocks = partitioning::segment_graph(g, convert_cfg.input_ranges);
211+
auto segmented_blocks = partitioning::segment_graph(g, convert_cfg.input_ranges, convert_cfg.engine_settings.torch_fallback);
212+
if (segmented_blocks.size() == 1 && segmented_blocks[0].target() == partitioning::SegmentedBlock::kTorch) {
213+
return mod;
214+
}
212215

213216
int trt_engine_id = 0;
214217
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
@@ -233,7 +236,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
233236
new_g->registerOutput(old_to_new_g[output]);
234237
}
235238

236-
LOG_INFO(*new_g << "(After CompileGraph)\n");
239+
LOG_INFO(*new_g << "(StitchSegmentedGraph)\n");
237240

238241
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
239242
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,13 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
3939

4040
os << "\n Torch Fallback: " << s.torch_fallback.enabled;
4141
if (s.torch_fallback.enabled) {
42-
os << "\n Fallback min block size: " << s.torch_fallback.min_block_size;
42+
os << "\n Fallback Min Block Size: " << s.torch_fallback.min_block_size;
43+
if (!s.torch_fallback.forced_fallback_operators.empty()) {
44+
os << "\n Forced Fallback Operators:";
45+
for (auto it = s.torch_fallback.forced_fallback_operators.begin(); it != s.torch_fallback.forced_fallback_operators.end(); ++it) {
46+
os << " " << *it;
47+
}
48+
}
4349
}
4450
return os;
4551
}

core/partitioning/partitioning.cpp

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ void registerSegmentInOutShape(SegmentedBlock &seg_block, std::unordered_map<tor
6262
// create a module to run the graph
6363
auto g = seg_block.g();
6464
auto copy_g = g->copy();
65+
if (seg_block.raw_outputs().size() > 1) {
66+
auto new_output_node = copy_g->appendNode(copy_g->createTuple(copy_g->outputs()));
67+
for (int idx = copy_g->outputs().size() - 1; idx >= 0; --idx) {
68+
copy_g->eraseOutput(idx);
69+
}
70+
copy_g->registerOutput(new_output_node->outputs()[0]);
71+
}
72+
6573
torch::jit::script::Module cur_mod(c10::QualifiedName("module"));
6674

6775
auto self = copy_g->insertInput(0, "self_1");
@@ -140,25 +148,45 @@ void registerSegmentsInputsOutputs(std::vector<SegmentedBlock> &segmented_blocks
140148
return;
141149
}
142150

143-
std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g, std::vector<conversion::InputRange>& input_ranges) {
151+
// Flushes the pending run of TensorRT-eligible nodes held in tensorrt_nodes.
//
// - If tensorrt_nodes is empty, nothing is modified.
// - If the run holds fewer than min_block_size nodes, the nodes are appended
//   to pytorch_nodes and no block is emitted.
// - Otherwise any accumulated pytorch_nodes are emitted as a kTorch block,
//   the run is emitted as a kTensorRT block right after it, and pytorch_nodes
//   is cleared.
// In both non-empty cases tensorrt_nodes is cleared before returning.
void merge_nodes(std::vector<torch::jit::Node*> &pytorch_nodes, std::vector<torch::jit::Node*> &tensorrt_nodes,
                 std::vector<SegmentedBlock> &segmented_blocks, size_t min_block_size) {
  // No pending TensorRT run: leave all three containers untouched.
  if (tensorrt_nodes.empty()) {
    return;
  }

  const bool run_meets_min_size = !(tensorrt_nodes.size() < min_block_size);
  if (run_meets_min_size) {
    // Emit the Torch nodes gathered so far first so block order follows node order.
    if (!pytorch_nodes.empty()) {
      segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
    }
    segmented_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
    pytorch_nodes.clear();
  } else {
    // Run is too short for its own block: fold it into the pending Torch nodes.
    pytorch_nodes.insert(pytorch_nodes.end(), tensorrt_nodes.begin(), tensorrt_nodes.end());
  }

  tensorrt_nodes.clear();
}
164+
165+
std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
166+
std::vector<conversion::InputRange>& input_ranges,
167+
const conversion::TorchFallback &fallback_info) {
168+
auto min_block_size = fallback_info.min_block_size;
169+
std::unordered_set<std::string> forced_fallback_operators(fallback_info.forced_fallback_operators.begin(), fallback_info.forced_fallback_operators.end());
144170
std::vector<SegmentedBlock> segmented_blocks;
145171

146172
auto nodes = g->block()->nodes();
147173

148174
// segment the nodes
175+
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
176+
149177
for (const auto n : nodes) {
150178
if (n->kind() == torch::jit::prim::Constant) continue;
179+
std::string node_string(n->kind().toQualString());
151180

152-
auto block_target = conversion::OpSupported(n) ? SegmentedBlock::kTensorRT : SegmentedBlock::kTorch;
153-
154-
if (segmented_blocks.empty() || block_target != segmented_blocks.back().target()) {
155-
SegmentedBlock cur_block(block_target);
156-
cur_block.appendNode(n);
157-
segmented_blocks.push_back(cur_block);
181+
if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string)) {
182+
tensorrt_nodes.push_back(n);
158183
} else {
159-
segmented_blocks.back().appendNode(n);
184+
merge_nodes(pytorch_nodes, tensorrt_nodes, segmented_blocks, min_block_size);
185+
pytorch_nodes.push_back(n);
160186
}
161187
}
188+
merge_nodes(pytorch_nodes, tensorrt_nodes, segmented_blocks, min_block_size);
189+
if (!pytorch_nodes.empty()) segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
162190

163191
registerSegmentsInputsOutputs(segmented_blocks, g);
164192

core/partitioning/partitioning.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,17 @@ struct SegmentedBlock {
2020
kTensorRT,
2121
};
2222

23+
// Compiler-generated default constructor; no target or graph is supplied here.
SegmentedBlock() = default;
24+
2325
// Creates a block for blk_target backed by a newly allocated, empty torch::jit::Graph.
SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}
2426

27+
// Creates a block for blk_target on a newly allocated torch::jit::Graph and
// then passes every entry of nodes, in order, to appendNode.
SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*> &nodes)
    : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
  for (auto it = nodes.begin(); it != nodes.end(); ++it) {
    appendNode(*it);
  }
}
33+
2534
// Creates a block for blk_target that shares ownership of the caller-provided graph g.
SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}
2635

2736
enum SegmentedBlockTarget target() {
@@ -40,6 +49,7 @@ struct SegmentedBlock {
4049

4150
// Records raw_input in outputs_, then registers the value that old_to_new_
// associates with raw_input as an output of this block's graph.
void registerOutput(torch::jit::Value* raw_input) {
  outputs_.push_back(raw_input);
  auto&& mapped_output = old_to_new_[raw_input];
  g_->registerOutput(mapped_output);
}
4555

@@ -108,8 +118,9 @@ struct SegmentedBlock {
108118

109119
};
110120

111-
std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g, std::vector<conversion::InputRange>& input_ranges);
112-
121+
std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
122+
std::vector<conversion::InputRange>& input_ranges,
123+
const conversion::TorchFallback &fallback_info);
113124
}
114125
}
115126
}

0 commit comments

Comments
 (0)