Closed
Bug Description
When compiling a model with input type Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], compilation fails with the following error:
RuntimeError: [Error thrown at core/ir/ir.cpp:46] Expected vals.size() == specs.size() to be true but got false
Expected dimension specifications for all input tensors, but found 1 input tensors and 2 dimension specs
To Reproduce
Steps to reproduce the behavior:
- Define a Torch model with forward function having the following form:
def forward(self, input : Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]])
- Define the Torch-TRT Inputs and compilation settings using input_signature, then compile the scripted model:
compile_settings = {
    "input_signature": (
        torch_tensorrt.Input((5, 5), dtype=torch.float),
        (
            torch_tensorrt.Input((5, 5), dtype=torch.float),
            torch_tensorrt.Input((5, 5), dtype=torch.float),
        ),
    ),
    "enabled_precisions": {torch.float},
    "truncate_long_and_double": True,
}
trt_ts_module = torch_tensorrt.ts.compile(scripted_model, **compile_settings)
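One possible reading of the counts in the error message above ("1 input tensors and 2 dimension specs") is that forward takes a single tuple argument while the input_signature has two entries at its outermost level. This is an assumption, not a confirmed diagnosis. The sketch below only counts entries in a signature with the same nesting; Spec is a hypothetical stand-in for torch_tensorrt.Input so that the snippet runs without Torch-TensorRT, and it does not reproduce the library's parsing.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Spec:
    """Hypothetical stand-in for torch_tensorrt.Input; holds only a shape."""
    shape: Tuple[int, ...]


def count_leaves(sig) -> int:
    """Count the Spec leaves in an arbitrarily nested tuple signature."""
    if isinstance(sig, Spec):
        return 1
    return sum(count_leaves(item) for item in sig)


# Same nesting as the input_signature in the report.
signature = (Spec((5, 5)), (Spec((5, 5)), Spec((5, 5))))

forward_args = 1                      # forward(self, input) takes one tuple argument
top_level_entries = len(signature)    # entries at the outermost level
leaf_specs = count_leaves(signature)  # individual tensor specs

print(forward_args, top_level_entries, leaf_specs)
```

This prints 1 2 3. The first two values match the numbers in the error message, although whether the library derives its counts this way is not verified here.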
Expected behavior
Model should compile with input signature containing nested Tuple collection.
Environment
- Torch-TensorRT Version: 1.4.0.dev0+f43be5b6
- PyTorch Version: 1.14.0.dev20221114+cu116
- CPU Architecture: Intel Xeon CPU
- OS: Ubuntu 20.04
- How you installed PyTorch: pip
- Build command you used: python setup.py develop
- Are you using local sources or building from archives: local
- Python version: 3.8.13
- CUDA version: 11.6
Additional context
A nested Tuple containing a single Tensor followed by a Tuple of Tensors is not parsed as expected by the input signature interpreter, whereas regular (non-nested) Tuple inputs of Tensors are supported, as in TensorRT/tests/py/api/test_collections.py (lines 210 to 213 in b2a5da6).
Wrapping the input signature in another Tuple does not resolve the issue and instead raises a different error:
RuntimeError: forward() Expected a value of type 'Tuple[Tensor, Tuple[Tensor, Tensor]]' for argument 'input.1' but instead found type 'Tuple[Tensor]'.
Position: 1
Declaration: forward.forward(__torch__.module self_1, (Tensor, (Tensor, Tensor)) input.1) -> ((Tensor, Tensor, Tensor) 7)
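For reference, the wrapped form described above can be sketched as follows. The snippet is purely illustrative: Spec is again a hypothetical stand-in for torch_tensorrt.Input, and describe is a helper invented here to render a nested signature as a TorchScript-style type string. It shows the structure the wrapped signature carries on the Python side, not what Torch-TensorRT does with it internally.

```python
class Spec:
    """Hypothetical stand-in for torch_tensorrt.Input; carries only a shape."""

    def __init__(self, shape):
        self.shape = tuple(shape)


def describe(sig):
    """Render a nested tuple signature as a TorchScript-style type string."""
    if isinstance(sig, Spec):
        return "Tensor"
    return "Tuple[" + ", ".join(describe(item) for item in sig) + "]"


# Signature as written in the report, then wrapped in one more tuple
# so that the outermost level has a single entry for the single argument.
original = (Spec((5, 5)), (Spec((5, 5)), Spec((5, 5))))
wrapped = (original,)

print(len(wrapped))
print(describe(wrapped[0]))
```

The wrapped signature has one outermost entry, and that entry renders as Tuple[Tensor, Tuple[Tensor, Tensor]], which is the type forward expects. The error above reports Tuple[Tensor] instead, which suggests (without confirming) that the nesting is lost somewhere between the Python signature and the call into the scripted module.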