Bug Description
The model below is converted into a Torch-TensorRT model, with the sub_function module excluded from conversion. Loading the saved module with torch.jit.load raises this error:
Traceback (most recent call last):
    model = torch.jit.load('model_trt.ts')
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_serialization.py", line 161, in load
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
RuntimeError: expected ) but found 'number' here:
Serialized File "code/__torch__.py", line 6
  __torch___function_trt_engine_0x77ea7da0 : __torch__.torch.classes.tensorrt.Engine
  def forward(self_1: __torch__.function_trt,
    x.1: Tensor,
     ~~ <--- HERE
    kernel.1: Tensor) -> Tensor:
    __torch___function_trt_engine_0x77ea7da0 = self_1.__torch___function_trt_engine_0x77ea7da0
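The caret in the traceback points at the `.1` suffix of `x.1`: the serialized forward signature uses what look like TorchScript value debug names (`x.1`, `kernel.1`) as parameter names, and these are not valid identifiers. The minimal stdlib sketch below uses Python's own `ast` and `tokenize` modules as a stand-in for the TorchScript parser. This is an assumption: TorchScript has its own C++ lexer, but its error message ("found 'number'") suggests it splits the name the same way. The helpers `parses` and `tokens_of`, and the `_1` renaming, are hypothetical and for illustration only.

```python
import ast
import io
import tokenize

# Signature copied from the serialized code/__torch__.py shown in the traceback
# (the engine attribute line and the real body are omitted).
BROKEN = (
    "def forward(self_1: __torch__.function_trt,\n"
    "    x.1: Tensor,\n"
    "    kernel.1: Tensor) -> Tensor:\n"
    "    pass\n"
)
# Same signature with the '.1' suffixes replaced by '_1' (hypothetical renaming).
RENAMED = BROKEN.replace("x.1", "x_1").replace("kernel.1", "kernel_1")


def parses(src: str) -> bool:
    """Return True if Python's parser (used here as a stand-in) accepts src."""
    try:
        ast.parse(src)
        return True
    except SyntaxError:
        return False


def tokens_of(src: str) -> list:
    """Return (token type name, text) pairs for a single line of source."""
    return [
        (tokenize.tok_name[tok.type], tok.string)
        for tok in tokenize.generate_tokens(io.StringIO(src).readline)
    ]


print(parses(BROKEN))   # False: after 'x' the parser meets the number '.1'
print(parses(RENAMED))  # True
print(tokens_of("x.1: Tensor")[:2])  # [('NAME', 'x'), ('NUMBER', '.1')]
```

The tokenizer reads `.1` as a float literal, which matches the "expected ) but found 'number'" message.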
To Reproduce
import torch
import torch_tensorrt
import torch.nn as nn
import torch.nn.functional as F


class function(nn.Module):
    def __init__(self):
        super(function, self).__init__()
        self.conv_kernel = nn.Sequential(
            nn.Conv2d(256, 256, 3, bias=False),
            nn.BatchNorm2d(256),
        )
        self.sub_function = sub_function()

    def forward(self, x, kernel):
        # type: (Tensor, Tensor) -> Tensor
        kernel = self.conv_kernel(kernel)
        x = x.view(1, 256, x.size(2), x.size(3))
        kernel = kernel.view(256, 1, kernel.size(2), kernel.size(3))
        out = self.sub_function(x, kernel)
        return out


class sub_function(nn.Module):
    def __init__(self):
        super(sub_function, self).__init__()

    def forward(self, x, kernel):
        # type: (Tensor, Tensor) -> Tensor
        # Depthwise convolution: each of the 256 channels of x gets its own kernel
        out = F.conv2d(x, kernel, groups=256)
        return out

model = function()
model_script = torch.jit.script(model)
model_script.cuda().eval()

compile_settings = {
    "inputs": [
        torch_tensorrt.Input([1, 256, 29, 29], dtype=torch.float32),
        torch_tensorrt.Input([1, 256, 7, 7], dtype=torch.float32),
    ],
    "enabled_precisions": {torch.float32},
}

# Partial compilation: sub_function stays in Torch, the rest is converted to TensorRT
model_trt = torch_tensorrt.ts.compile(
    model_script,
    **compile_settings,
    require_full_compilation=False,
    torch_executed_modules=['sub_function'],
)

torch.jit.save(model_trt, 'model_trt.ts')
model = torch.jit.load('model_trt.ts')  # raises the RuntimeError above
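The input shapes given in compile_settings can be checked by hand with the output-size formula from the PyTorch nn.Conv2d documentation (BatchNorm2d leaves the shape unchanged). The pure-Python sketch below traces the repro's shapes; `conv2d_out_hw` is a hypothetical helper, not part of PyTorch or Torch-TensorRT.

```python
from typing import Tuple


def conv2d_out_hw(
    hw: Tuple[int, int],
    kernel: Tuple[int, int],
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
) -> Tuple[int, int]:
    """Spatial output size of a 2-D convolution (formula from the nn.Conv2d docs)."""
    return tuple(
        (size + 2 * padding - dilation * (k - 1) - 1) // stride + 1
        for size, k in zip(hw, kernel)
    )


# conv_kernel: Conv2d(256, 256, 3, bias=False) applied to the [1, 256, 7, 7] input
kernel_hw = conv2d_out_hw((7, 7), (3, 3))
# kernel.view(256, 1, kH, kW) is a depthwise weight: (out_ch, in_ch / groups, kH, kW)
weight_shape = (256, 256 // 256) + kernel_hw
# sub_function: F.conv2d(x, weight, groups=256) applied to the [1, 256, 29, 29] input
out_hw = conv2d_out_hw((29, 29), kernel_hw)
out_shape = (1, weight_shape[0]) + out_hw
print(weight_shape, out_shape)  # (256, 1, 5, 5) (1, 256, 25, 25)
```

The shapes are consistent end to end, which fits the report: compilation and saving succeed, and only loading fails.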
Expected behavior
The saved Torch-TensorRT model should load with torch.jit.load and be usable for inference.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 1.0.0
- PyTorch Version (e.g. 1.10.0+cu113): 1.10.0+cu113
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: python 3.8.10
- CUDA version: 11.6
- GPU models and configuration:
- Any other relevant information: