@@ -257,3 +257,113 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
257
257
int count = count_trt_engines (fallback_g);
258
258
ASSERT_TRUE (count == 2 );
259
259
}
260
+
261
+ TEST (Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
262
+ /* parseIR does not support "= aten::_set_item" so we will build this graph manually
263
+ const auto graph = R"IR(
264
+ graph(%x : Tensor,
265
+ %y : Tensor):
266
+ %2 : str = prim::Constant[value="INS"]()
267
+ %3 : str = prim::Constant[value="OUTS"]()
268
+ %4 : bool = prim::Constant[value=0]()
269
+ %5 : int = prim::Constant[value=-1]()
270
+ %6 : Dict(str, Tensor) = prim::DictConstruct()
271
+ = aten::_set_item(%6, %2, %x)
272
+ %7 : Tensor = aten::__getitem__(%6, %2)
273
+ %8 : Tensor = aten::lt(%7, %y)
274
+ %9 : Tensor?[] = prim::ListConstruct(%8)
275
+ %10 : int = prim::dtype(%7)
276
+ %11 : Device = prim::device(%7)
277
+ %12 : Tensor = aten::tensor(%5, %10, %11, %4)
278
+ %13 : Tensor = aten::index_put_(%7, %9, %12, %4)
279
+ = aten::_set_item(%6, %3, %7)
280
+ %14 : Tensor = aten::__getitem__(%6, %2)
281
+ %15 : Tensor = aten::__getitem__(%6, %3)
282
+ return (%14, %15))IR";
283
+ */
284
+ auto g = std::make_shared<torch::jit::Graph>();
285
+ auto x = g->insertInput (0 , " x" );
286
+ auto y = g->insertInput (1 , " y" );
287
+ torch::jit::IValue ins_key (" INS" );
288
+ auto ins_key_val = g->insertConstant (ins_key);
289
+ torch::jit::IValue outs_key (" OUTS" );
290
+ auto outs_key_val = g->insertConstant (outs_key);
291
+ torch::jit::IValue zero (0 );
292
+ auto false_const_val = g->insertConstant (zero);
293
+ false_const_val->setType (c10::BoolType::get ());
294
+ torch::jit::IValue neg_one (-1 );
295
+ auto neg_one_const_val = g->insertConstant (neg_one);
296
+ auto dict_node = g->createDict (ins_key_val->type (), x->type (), torch::jit::ArrayRef<torch::jit::Value*>(), torch::jit::ArrayRef<torch::jit::Value*>());
297
+ g->insertNode (dict_node);
298
+ auto set_node = g->create (torch::jit::Symbol::fromQualString (" aten::_set_item" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val, x}, 0 );
299
+ g->insertNode (set_node);
300
+ auto get_node = g->create (torch::jit::Symbol::fromQualString (" aten::__getitem__" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val}, 1 );
301
+ g->insertNode (get_node);
302
+ auto lt_node = g->create (torch::jit::Symbol::fromQualString (" aten::lt" ), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output (), y}, 1 );
303
+ g->insertNode (lt_node);
304
+ auto list_node = g->createList (at::OptionalType::create (lt_node->output ()->type ()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output ()});
305
+ g->insertNode (list_node);
306
+ auto dtype_node = g->create (torch::jit::Symbol::fromQualString (" prim::dtype" ), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()}, 1 );
307
+ dtype_node->output ()->setType (neg_one_const_val->type ());
308
+ g->insertNode (dtype_node);
309
+ auto device_node = g->create (torch::jit::Symbol::fromQualString (" prim::device" ), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()}, 1 );
310
+ device_node->output ()->setType (c10::DeviceObjType::get ());
311
+ g->insertNode (device_node);
312
+ auto tensor_node = g->create (torch::jit::Symbol::fromQualString (" aten::tensor" ), torch::jit::ArrayRef<torch::jit::Value*>{neg_one_const_val, dtype_node->output (), device_node->output (), false_const_val}, 1 );
313
+ g->insertNode (tensor_node);
314
+ auto index_put_node = g->create (torch::jit::Symbol::fromQualString (" aten::index_put_" ),
315
+ torch::jit::ArrayRef<torch::jit::Value*>{get_node->output (), list_node->output (), tensor_node->output (), false_const_val}, 1 );
316
+ g->insertNode (index_put_node);
317
+ auto out_set_node = g->create (torch::jit::Symbol::fromQualString (" aten::_set_item" ),
318
+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val, get_node->output ()}, 0 );
319
+ g->insertNode (out_set_node);
320
+ auto get_ins_node = g->create (torch::jit::Symbol::fromQualString (" aten::__getitem__" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val}, 1 );
321
+ g->insertNode (get_ins_node);
322
+ auto get_outs_node = g->create (torch::jit::Symbol::fromQualString (" aten::__getitem__" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val}, 1 );
323
+ g->insertNode (get_outs_node);
324
+ g->registerOutput (get_ins_node->output ());
325
+ g->registerOutput (get_outs_node->output ());
326
+
327
+ torch_tensorrt::core::partitioning::PartitionInfo partition_info;
328
+ partition_info.enabled = true ;
329
+ std::vector<torch_tensorrt::core::ir::Input> inputs;
330
+ inputs.push_back (torch_tensorrt::core::ir::Input ({4 , 4 }));
331
+ inputs.push_back (torch_tensorrt::core::ir::Input ({4 , 4 }));
332
+
333
+ std::unordered_map<const torch::jit::Value*, torch_tensorrt::core::ir::Input> inputs_map;
334
+ std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> input_types;
335
+ for (size_t i = 0 ; i < g->inputs ().size (); ++i) {
336
+ inputs_map.insert ({g->inputs ()[i], inputs[i]});
337
+ input_types.insert ({g->inputs ()[i], {at::kFloat }});
338
+ }
339
+ auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs (inputs_map, input_types);
340
+ auto segmented_blocks =
341
+ torch_tensorrt::core::partitioning::Partition (g->block (), input_ivalues_map, partition_info);
342
+
343
+ int torch_block_cnt = 0 , trt_block_cnt = 0 ;
344
+ for (const auto & segmented_block : segmented_blocks) {
345
+ if (segmented_block.target () == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT ) {
346
+ ++trt_block_cnt;
347
+ ASSERT_TRUE (checkSegmentedBlockInputType (segmented_block, [](torch::jit::TypePtr type_ptr) {
348
+ return type_ptr->isSubtypeOf (torch::jit::TensorType::get ());
349
+ }));
350
+ } else {
351
+ ++torch_block_cnt;
352
+ bool output_dict = false ;
353
+ bool input_dict = false ;
354
+ auto dict_type = dict_node->output ()->type ();
355
+ for (auto in : segmented_block.raw_inputs ()) {
356
+ if (in->type ()->isSubtypeOf (dict_type)){
357
+ input_dict = true ;
358
+ }
359
+ }
360
+ for (auto out : segmented_block.raw_outputs ()) {
361
+ if (out->type ()->isSubtypeOf (dict_type)){
362
+ output_dict = true ;
363
+ }
364
+ }
365
+ EXPECT_TRUE (output_dict ^ input_dict);
366
+ }
367
+ }
368
+ ASSERT_TRUE (trt_block_cnt == 1 && torch_block_cnt == 2 );
369
+ }
0 commit comments