Skip to content

Commit 55e0510

Browse files
committed
refactored the new graph output registration
Signed-off-by: Bo Wang <[email protected]>
1 parent 0d28164 commit 55e0510

File tree

2 files changed

+7
-36
lines changed

2 files changed

+7
-36
lines changed

core/compiler.cpp

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -200,38 +200,18 @@ void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitionin
200200

201201
torch::jit::Node *node;
202202
for (const auto n : seg.nodes()) {
203-
node = partitioning::cloneNode(n, g, old_to_new_g);
203+
partitioning::cloneNode(n, g, old_to_new_g);
204204
}
205205

206206
// original graph value => new global graph value
207207
for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
208208
old_to_new_g[seg.raw_outputs()[i]] = old_to_new_g[seg.outputs()[i]];
209209
}
210210

211-
for (size_t i = 0; i < g->outputs().size(); ++i) {
212-
g->eraseOutput(i);
213-
}
214-
for (auto &value : node->outputs()) {
215-
g->registerOutput(value);
216-
}
217211
LOG_INFO(*g << "(AddSegmentedBlockToGraph)\n");
218212
return;
219213
}
220214

221-
//void print_type_dim(c10::TypePtr type) {
222-
// printf("type: %s\n", type->str().c_str());
223-
// auto tensor_type = type->cast<torch::jit::TensorType>();
224-
// auto optional_vec = tensor_type->sizes().sizes().value();
225-
// if (!tensor_type->isComplete()) {
226-
// printf("Not complete type\n");
227-
// return;
228-
// }
229-
// printf("dimension: %d\n", optional_vec.size());
230-
// for (int i = 0; i < optional_vec.size(); ++i) {
231-
// printf("dim(%d) : %d\n", i, optional_vec[i].value());
232-
// }
233-
//}
234-
235215

236216
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, CompileSpec cfg) {
237217
// TODO: Should be doing a functional transform but need PR #31978
@@ -255,10 +235,6 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
255235
// segment the graph and convert segmented TensorRT block
256236
auto segmented_blocks = partitioning::segment_graph(g, convert_cfg.input_ranges);
257237

258-
for (auto &seg_block : segmented_blocks) {
259-
LOG_INFO(*seg_block.g() << "SegmentedBlockGraph");
260-
}
261-
262238
int trt_engine_id = 0;
263239
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
264240
for (auto &seg_block : segmented_blocks) {
@@ -271,16 +247,19 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
271247
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
272248
auto temp_g = std::make_shared<torch::jit::Graph>();
273249
AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id++);
274-
// printf("type: %s\n", temp_g->inputs()[0]->type()->str().c_str());
275-
// auto temp_seg_block = partitioning::SegmentedBlock(partitioning::SegmentedBlock::kTensorRT, temp_g);
276-
// AddSegmentedBlockToGraph(new_g, temp_seg_block);
277250
seg_block.update_graph(temp_g);
278251
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
279252
} else {
280253
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
281254
}
282255
}
283256

257+
for (auto &output : g->outputs()) {
258+
new_g->registerOutput(old_to_new_g[output]);
259+
}
260+
261+
LOG_INFO(*new_g << "(After CompileGraph)\n");
262+
284263
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
285264
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
286265
new_mod.type()->addMethod(new_method);

core/partitioning/partitioning.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
164164
}
165165
}
166166

167-
printf("before register input\n");
168167
registerSegmentsInputsOutputs(segmented_blocks, g);
169168

170169
std::vector<nvinfer1::Dims> graph_inputs_shape = extractNvinfer1Dims(input_ranges);
@@ -175,13 +174,6 @@ std::vector<SegmentedBlock> segment_graph(std::shared_ptr<torch::jit::Graph> g,
175174
}
176175

177176
for (auto &seg_block : segmented_blocks) {
178-
LOG_INFO(*seg_block.g() << "In partitioning\n");
179-
}
180-
181-
printf("before register shapes\n");
182-
183-
for (auto &seg_block : segmented_blocks) {
184-
printf("h\n");
185177
registerSegmentInOutShape(seg_block, input_shape_map);
186178
}
187179

0 commit comments

Comments
 (0)