Skip to content

Commit 12db9e8

Browse files
authored
fix: Improve input handling for input_signature (#1698)
1 parent d35fe2a commit 12db9e8

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

cpp/src/compile_spec.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,15 @@ CompileSpec::CompileSpec(torch::jit::IValue input_signature) {
3636
graph_inputs.input_signature = input_signature;
3737
}
3838

39-
void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue) {
39+
void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue, int depth = 0) {
40+
TORCHTRT_CHECK(
41+
depth <= 2, "Input nesting depth exceeds max supported depth, use 1 level: [A, B], or 2 level: [A, (B, C)]")
4042
if (input_ivalue.isTuple()) {
4143
auto input_tuple = input_ivalue.toTuple();
4244
std::vector<torch::jit::IValue> converted_elements;
4345
for (auto item : input_tuple->elements()) {
4446
torch::jit::IValue converted_item;
45-
to_internal_input_signature(item, converted_item);
47+
to_internal_input_signature(item, converted_item, depth++);
4648
converted_elements.push_back(converted_item);
4749
auto tuple_ptr = c10::ivalue::Tuple::create(converted_elements);
4850
converted_ivalue = torch::jit::IValue(tuple_ptr);
@@ -53,7 +55,7 @@ void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IV
5355
auto converted_elements = c10::impl::GenericList(type);
5456
for (auto item : input_list) {
5557
torch::jit::IValue converted_item;
56-
to_internal_input_signature(item, converted_item);
58+
to_internal_input_signature(item, converted_item, depth++);
5759
converted_elements.push_back(converted_item);
5860
}
5961
converted_ivalue = torch::jit::IValue(converted_elements);

py/torch_tensorrt/ts/_compile_spec.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,17 +194,22 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> _ts_C.TorchFallback:
194194
return info
195195

196196

197-
def _parse_input_signature(input_signature: Any):
197+
def _parse_input_signature(input_signature: Any, depth: int = 0):
198+
if depth > 2:
199+
raise AssertionError(
200+
"Input nesting depth exceeds max supported depth, use 1 level: [A, B], or 2 level: [A, (B, C)]"
201+
)
202+
198203
if isinstance(input_signature, tuple):
199204
input_list = []
200205
for item in input_signature:
201-
input = _parse_input_signature(item)
206+
input = _parse_input_signature(item, depth + 1)
202207
input_list.append(input)
203208
return tuple(input_list)
204209
elif isinstance(input_signature, list):
205210
input_list = []
206211
for item in input_signature:
207-
input = _parse_input_signature(item)
212+
input = _parse_input_signature(item, depth + 1)
208213
input_list.append(input)
209214
return input_list
210215
elif isinstance(input_signature, Input) or isinstance(

0 commit comments

Comments
 (0)