
🐛 [Bug] Error while loading Torch-TensorRT model (torch.jit.load) #973

Closed
@pauline6


Bug Description

The model below is compiled into a Torch-TensorRT module, with the sub_function submodule excluded from conversion. When the saved module is loaded with torch.jit.load, the following error is raised:

Traceback (most recent call last):
    model = torch.jit.load('model_trt.ts')
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_serialization.py", line 161, in load
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
RuntimeError: expected ) but found 'number' here:
Serialized   File "code/__torch__.py", line 6
  __torch___function_trt_engine_0x77ea7da0 : __torch__.torch.classes.tensorrt.Engine
  def forward(self_1: __torch__.function_trt,
    x.1: Tensor,
     ~~ <--- HERE
    kernel.1: Tensor) -> Tensor:
    __torch___function_trt_engine_0x77ea7da0 = self_1.__torch___function_trt_engine_0x77ea7da0
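
From the serialized snippet, the parser appears to choke on the parameter name x.1: the compiler seems to have emitted the graph's SSA debug name (which contains a dot) as a forward argument name, and .1 is not a valid identifier when the serialized TorchScript source is re-parsed, hence the "expected ) but found 'number'" message.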

To Reproduce

import torch
import torch_tensorrt
import torch.nn as nn
import torch.nn.functional as F

class function(nn.Module):
    def __init__(self):
        super(function, self).__init__()
        self.conv_kernel = nn.Sequential(
            nn.Conv2d(256, 256, 3, bias=False),
            nn.BatchNorm2d(256),
        )
        self.sub_function = sub_function()

    def forward(self, x, kernel):
        # type: (Tensor, Tensor) -> Tensor
        kernel = self.conv_kernel(kernel)
        x = x.view(1, 256, x.size(2), x.size(3))
        kernel = kernel.view(256, 1, kernel.size(2), kernel.size(3))
        out = self.sub_function(x, kernel)
        return out
    
class sub_function(nn.Module):
    def __init__(self):
        super(sub_function, self).__init__()

    def forward(self, x, kernel):
        # type: (Tensor, Tensor) -> Tensor
        # Depthwise convolution: each of the 256 channels gets its own kernel.
        out = F.conv2d(x, kernel, groups=256)
        return out


model = function()
model_script = torch.jit.script(model)
model_script.cuda().eval()

compile_settings = {
    "inputs": [
        torch_tensorrt.Input([1, 256, 29, 29], dtype=torch.float32),
        torch_tensorrt.Input([1, 256, 7, 7], dtype=torch.float32),
    ],
    "enabled_precisions": {torch.float32},
}

model_trt = torch_tensorrt.ts.compile(
    model_script,
    **compile_settings,
    require_full_compilation=False,
    torch_executed_modules=['sub_function'],
)

torch.jit.save(model_trt, 'model_trt.ts')

model = torch.jit.load('model_trt.ts')
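
A minimal in-process sanity check before the save/load round trip (a sketch; the shapes are taken from the compile settings above, and the expected output shape follows from the 3x3 conv on the 7x7 kernel followed by the depthwise conv):

# Sketch: exercise the compiled module in the same session, before serialization.
x = torch.randn(1, 256, 29, 29, device='cuda')
kernel = torch.randn(1, 256, 7, 7, device='cuda')
out = model_trt(x, kernel)
print(out.shape)  # expected: torch.Size([1, 256, 25, 25])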

Expected behavior

The Torch-TensorRT model should load back into model so it can be used for inference.
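
For clarity, the intended round trip looks like the following sketch (assuming the deserialization bug is fixed; input shapes are the ones declared in the compile settings):

model = torch.jit.load('model_trt.ts').cuda().eval()
x = torch.randn(1, 256, 29, 29, device='cuda')
kernel = torch.randn(1, 256, 7, 7, device='cuda')
out = model(x, kernel)  # TensorRT engine for the converted part, Torch fallback for sub_function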

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 1.0.0
  • PyTorch Version (e.g. 1.10.0+cu113): 1.10.0+cu113
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.8.10
  • CUDA version: 11.6
  • GPU models and configuration:
  • Any other relevant information:

Metadata

Labels

bug (Something isn't working), component: core (Issues re: The core compiler)
