@@ -257,3 +257,147 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
257
257
int count = count_trt_engines (fallback_g);
258
258
ASSERT_TRUE (count == 2 );
259
259
}
260
+
261
+ TEST (Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
262
+ /* parseIR does not support "= aten::_set_item" so we will build this graph manually
263
+ const auto graph = R"IR(
264
+ graph(%x : Tensor,
265
+ %y : Tensor):
266
+ %2 : str = prim::Constant[value="INS"]()
267
+ %3 : str = prim::Constant[value="OUTS"]()
268
+ %4 : bool = prim::Constant[value=0]()
269
+ %5 : int = prim::Constant[value=-1]()
270
+ %6 : Dict(str, Tensor) = prim::DictConstruct()
271
+ = aten::_set_item(%6, %2, %x)
272
+ %7 : Tensor = aten::__getitem__(%6, %2)
273
+ %8 : Tensor = aten::lt(%7, %y)
274
+ %9 : Tensor?[] = prim::ListConstruct(%8)
275
+ %10 : int = prim::dtype(%7)
276
+ %11 : Device = prim::device(%7)
277
+ %12 : Tensor = aten::tensor(%5, %10, %11, %4)
278
+ %13 : Tensor = aten::index_put_(%7, %9, %12, %4)
279
+ = aten::_set_item(%6, %3, %7)
280
+ %14 : Tensor = aten::__getitem__(%6, %2)
281
+ %15 : Tensor = aten::__getitem__(%6, %3)
282
+ return (%14, %15))IR";
283
+ */
284
+ auto g = std::make_shared<torch::jit::Graph>();
285
+ auto x = g->insertInput (0 , " x" );
286
+ auto y = g->insertInput (1 , " y" );
287
+ torch::jit::IValue ins_key (" INS" );
288
+ auto ins_key_val = g->insertConstant (ins_key);
289
+ torch::jit::IValue outs_key (" OUTS" );
290
+ auto outs_key_val = g->insertConstant (outs_key);
291
+ torch::jit::IValue zero (0 );
292
+ auto false_const_val = g->insertConstant (zero);
293
+ false_const_val->setType (c10::BoolType::get ());
294
+ torch::jit::IValue neg_one (-1 );
295
+ auto neg_one_const_val = g->insertConstant (neg_one);
296
+ auto dict_node = g->createDict (
297
+ ins_key_val->type (),
298
+ x->type (),
299
+ torch::jit::ArrayRef<torch::jit::Value*>(),
300
+ torch::jit::ArrayRef<torch::jit::Value*>());
301
+ g->insertNode (dict_node);
302
+ auto set_node = g->create (
303
+ torch::jit::Symbol::fromQualString (" aten::_set_item" ),
304
+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val, x},
305
+ 0 );
306
+ g->insertNode (set_node);
307
+ auto get_node = g->create (
308
+ torch::jit::Symbol::fromQualString (" aten::__getitem__" ),
309
+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val},
310
+ 1 );
311
+ g->insertNode (get_node);
312
+ auto lt_node = g->create (
313
+ torch::jit::Symbol::fromQualString (" aten::lt" ),
314
+ torch::jit::ArrayRef<torch::jit::Value*>{get_node->output (), y},
315
+ 1 );
316
+ g->insertNode (lt_node);
317
+ auto list_node = g->createList (
318
+ at::OptionalType::create (lt_node->output ()->type ()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output ()});
319
+ g->insertNode (list_node);
320
+ auto dtype_node = g->create (
321
+ torch::jit::Symbol::fromQualString (" prim::dtype" ),
322
+ torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()},
323
+ 1 );
324
+ dtype_node->output ()->setType (neg_one_const_val->type ());
325
+ g->insertNode (dtype_node);
326
+ auto device_node = g->create (
327
+ torch::jit::Symbol::fromQualString (" prim::device" ),
328
+ torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()},
329
+ 1 );
330
+ device_node->output ()->setType (c10::DeviceObjType::get ());
331
+ g->insertNode (device_node);
332
+ auto tensor_node = g->create (
333
+ torch::jit::Symbol::fromQualString (" aten::tensor" ),
334
+ torch::jit::ArrayRef<torch::jit::Value*>{
335
+ neg_one_const_val, dtype_node->output (), device_node->output (), false_const_val},
336
+ 1 );
337
+ g->insertNode (tensor_node);
338
+ auto index_put_node = g->create (
339
+ torch::jit::Symbol::fromQualString (" aten::index_put_" ),
340
+ torch::jit::ArrayRef<torch::jit::Value*>{
341
+ get_node->output (), list_node->output (), tensor_node->output (), false_const_val},
342
+ 1 );
343
+ g->insertNode (index_put_node);
344
+ auto out_set_node = g->create (
345
+ torch::jit::Symbol::fromQualString (" aten::_set_item" ),
346
+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val, get_node->output ()},
347
+ 0 );
348
+ g->insertNode (out_set_node);
349
+ auto get_ins_node = g->create (
350
+ torch::jit::Symbol::fromQualString (" aten::__getitem__" ),
351
+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val},
352
+ 1 );
353
+ g->insertNode (get_ins_node);
354
+ auto get_outs_node = g->create (
355
+ torch::jit::Symbol::fromQualString (" aten::__getitem__" ),
356
+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val},
357
+ 1 );
358
+ g->insertNode (get_outs_node);
359
+ g->registerOutput (get_ins_node->output ());
360
+ g->registerOutput (get_outs_node->output ());
361
+
362
+ torch_tensorrt::core::partitioning::PartitionInfo partition_info;
363
+ partition_info.enabled = true ;
364
+ std::vector<torch_tensorrt::core::ir::Input> inputs;
365
+ inputs.push_back (torch_tensorrt::core::ir::Input ({4 , 4 }));
366
+ inputs.push_back (torch_tensorrt::core::ir::Input ({4 , 4 }));
367
+
368
+ std::unordered_map<const torch::jit::Value*, torch_tensorrt::core::ir::Input> inputs_map;
369
+ std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> input_types;
370
+ for (size_t i = 0 ; i < g->inputs ().size (); ++i) {
371
+ inputs_map.insert ({g->inputs ()[i], inputs[i]});
372
+ input_types.insert ({g->inputs ()[i], {at::kFloat }});
373
+ }
374
+ auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs (inputs_map, input_types);
375
+ auto segmented_blocks = torch_tensorrt::core::partitioning::Partition (g->block (), input_ivalues_map, partition_info);
376
+
377
+ int torch_block_cnt = 0 , trt_block_cnt = 0 ;
378
+ for (const auto & segmented_block : segmented_blocks) {
379
+ if (segmented_block.target () == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT ) {
380
+ ++trt_block_cnt;
381
+ ASSERT_TRUE (checkSegmentedBlockInputType (segmented_block, [](torch::jit::TypePtr type_ptr) {
382
+ return type_ptr->isSubtypeOf (torch::jit::TensorType::get ());
383
+ }));
384
+ } else {
385
+ ++torch_block_cnt;
386
+ bool output_dict = false ;
387
+ bool input_dict = false ;
388
+ auto dict_type = dict_node->output ()->type ();
389
+ for (auto in : segmented_block.raw_inputs ()) {
390
+ if (in->type ()->isSubtypeOf (dict_type)) {
391
+ input_dict = true ;
392
+ }
393
+ }
394
+ for (auto out : segmented_block.raw_outputs ()) {
395
+ if (out->type ()->isSubtypeOf (dict_type)) {
396
+ output_dict = true ;
397
+ }
398
+ }
399
+ EXPECT_TRUE (output_dict ^ input_dict);
400
+ }
401
+ }
402
+ ASSERT_TRUE (trt_block_cnt == 1 && torch_block_cnt == 2 );
403
+ }
0 commit comments