@@ -36,13 +36,15 @@ CompileSpec::CompileSpec(torch::jit::IValue input_signature) {
36
36
graph_inputs.input_signature = input_signature;
37
37
}
38
38
39
- void to_internal_input_signature (torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue) {
39
+ void to_internal_input_signature (torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue, int depth = 0 ) {
40
+ TORCHTRT_CHECK (
41
+ depth <= 2 , " Input nesting depth exceeds max supported depth, use 1 level: [A, B], or 2 level: [A, (B, C)]" )
40
42
if (input_ivalue.isTuple ()) {
41
43
auto input_tuple = input_ivalue.toTuple ();
42
44
std::vector<torch::jit::IValue> converted_elements;
43
45
for (auto item : input_tuple->elements ()) {
44
46
torch::jit::IValue converted_item;
45
- to_internal_input_signature (item, converted_item);
47
+ to_internal_input_signature (item, converted_item, depth++ );
46
48
converted_elements.push_back (converted_item);
47
49
auto tuple_ptr = c10::ivalue::Tuple::create (converted_elements);
48
50
converted_ivalue = torch::jit::IValue (tuple_ptr);
@@ -53,7 +55,7 @@ void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IV
53
55
auto converted_elements = c10::impl::GenericList (type);
54
56
for (auto item : input_list) {
55
57
torch::jit::IValue converted_item;
56
- to_internal_input_signature (item, converted_item);
58
+ to_internal_input_signature (item, converted_item, depth++ );
57
59
converted_elements.push_back (converted_item);
58
60
}
59
61
converted_ivalue = torch::jit::IValue (converted_elements);
0 commit comments