@@ -200,38 +200,18 @@ void AddSegmentedBlockToGraph(std::shared_ptr<torch::jit::Graph>& g, partitionin
200
200
201
201
torch::jit::Node *node;
202
202
for (const auto n : seg.nodes ()) {
203
- node = partitioning::cloneNode (n, g, old_to_new_g);
203
+ partitioning::cloneNode (n, g, old_to_new_g);
204
204
}
205
205
206
206
// original graph value => new global graph value
207
207
for (size_t i = 0 ; i < seg.raw_outputs ().size (); ++i) {
208
208
old_to_new_g[seg.raw_outputs ()[i]] = old_to_new_g[seg.outputs ()[i]];
209
209
}
210
210
211
- for (size_t i = 0 ; i < g->outputs ().size (); ++i) {
212
- g->eraseOutput (i);
213
- }
214
- for (auto &value : node->outputs ()) {
215
- g->registerOutput (value);
216
- }
217
211
LOG_INFO (*g << " (AddSegmentedBlockToGraph)\n " );
218
212
return ;
219
213
}
220
214
221
- // void print_type_dim(c10::TypePtr type) {
222
- // printf("type: %s\n", type->str().c_str());
223
- // auto tensor_type = type->cast<torch::jit::TensorType>();
224
- // auto optional_vec = tensor_type->sizes().sizes().value();
225
- // if (!tensor_type->isComplete()) {
226
- // printf("Not complete type\n");
227
- // return;
228
- // }
229
- // printf("dimension: %d\n", optional_vec.size());
230
- // for (int i = 0; i < optional_vec.size(); ++i) {
231
- // printf("dim(%d) : %d\n", i, optional_vec[i].value());
232
- // }
233
- // }
234
-
235
215
236
216
torch::jit::script::Module CompileGraph (const torch::jit::script::Module& mod, CompileSpec cfg) {
237
217
// TODO: Should be doing a functional transform but need PR #31978
@@ -255,10 +235,6 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
255
235
// segment the graph and convert segmented TensorRT block
256
236
auto segmented_blocks = partitioning::segment_graph (g, convert_cfg.input_ranges );
257
237
258
- for (auto &seg_block : segmented_blocks) {
259
- LOG_INFO (*seg_block.g () << " SegmentedBlockGraph" );
260
- }
261
-
262
238
int trt_engine_id = 0 ;
263
239
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
264
240
for (auto &seg_block : segmented_blocks) {
@@ -271,16 +247,19 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
271
247
auto engine = conversion::ConvertBlockToEngine (seg_block.block (), convert_cfg, named_params);
272
248
auto temp_g = std::make_shared<torch::jit::Graph>();
273
249
AddEngineToGraph (new_mod, temp_g, engine, trt_engine_id++);
274
- // printf("type: %s\n", temp_g->inputs()[0]->type()->str().c_str());
275
- // auto temp_seg_block = partitioning::SegmentedBlock(partitioning::SegmentedBlock::kTensorRT, temp_g);
276
- // AddSegmentedBlockToGraph(new_g, temp_seg_block);
277
250
seg_block.update_graph (temp_g);
278
251
AddSegmentedBlockToGraph (new_g, seg_block, old_to_new_g);
279
252
} else {
280
253
AddSegmentedBlockToGraph (new_g, seg_block, old_to_new_g);
281
254
}
282
255
}
283
256
257
+ for (auto &output : g->outputs ()) {
258
+ new_g->registerOutput (old_to_new_g[output]);
259
+ }
260
+
261
+ LOG_INFO (*new_g << " (After CompileGraph)\n " );
262
+
284
263
auto new_method = new_mod._ivalue ()->compilation_unit ()->create_function (method.name (), new_g);
285
264
auto schema = GenerateGraphSchema (new_mod, new_method->name (), new_g);
286
265
new_mod.type ()->addMethod (new_method);
0 commit comments