
✨[Feature] Support list and namedtuple input types to forward function #798

Closed
@chaoz-dev

Description


Is your feature request related to a problem? Please describe.

Currently, the forward function only supports tensor input types when compiling. However, sometimes we wish to supply many tensors to the forward function at once (say, more than 10); this results in a very long forward API call where every tensor must be listed individually. It would be helpful if we could instead pass a single container holding all of these tensors, which makes for a much cleaner API call.

For this specific request, I focus on the list and namedtuple input types first, since these should cover the most basic use cases (and should functionally satisfy named, key-value-pair style tensor inputs).

Describe the solution you'd like

Instead of supporting only the following, where individual torch.Tensors must be supplied to forward:

  import torch
  import torch_tensorrt

  DEVICE = torch.device("cuda:0")
  SHAPE = (1, 1)

  torch.manual_seed(0)

  class Model(torch.nn.Module):
      def __init__(self):
          super().__init__()

      def forward(self, a, b):
          return a - b

  if __name__ == "__main__":
      tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)

      model = Model().eval().to(DEVICE)
      out = model(tensor, tensor)

      model_trt = torch_tensorrt.compile(
          model,
          inputs=[
              torch_tensorrt.Input(shape=SHAPE),
              torch_tensorrt.Input(shape=SHAPE),
          ],
          enabled_precisions={torch.float},
      )
      out_trt = model_trt(tensor, tensor)

      assert torch.max(torch.abs(out - out_trt)) < 1e-6

also support passing a namedtuple or a list into forward:

  import torch
  import torch_tensorrt
  from collections import namedtuple

  DEVICE = torch.device("cuda:0")
  SHAPE = (1, 1)

  torch.manual_seed(0)

  Input = namedtuple('Input', ['t1', 't2'])

  class Model(torch.nn.Module):
      def __init__(self):
          super().__init__()

      def forward(self, input_: Input):
          return input_.t1 - input_.t2

  if __name__ == "__main__":
      tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)
      input_ = Input(tensor, tensor)

      model = Model().eval().to(DEVICE)
      out = model(input_)

      model_trt = torch_tensorrt.compile(
          model,
          inputs=[
              torch_tensorrt.Input(shape=SHAPE),
              torch_tensorrt.Input(shape=SHAPE),
          ],
          enabled_precisions={torch.float},
      )
      out_trt = model_trt(input_)

      assert torch.max(torch.abs(out - out_trt)) < 1e-6

Describe alternatives you've considered

Currently the only alternative is to supply tensors directly to the forward function; supplying namedtuples causes compilation to segfault, and supplying lists causes compilation to fail to recognize the input altogether. A sketch of this tensor-only workaround follows.
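For concreteness, a minimal sketch of that workaround, assuming the tensor-only Model from the first example and the Input namedtuple, SHAPE, DEVICE, and tensor from the second: the module is compiled against plain tensors, and the caller flattens the container by hand at every call site.

  # Sketch of the current workaround (reusing names from the examples above):
  # compile against a tensor-only forward and flatten the container manually.
  model = Model().eval().to(DEVICE)          # tensor-only forward(a, b)
  model_trt = torch_tensorrt.compile(
      model,
      inputs=[
          torch_tensorrt.Input(shape=SHAPE),
          torch_tensorrt.Input(shape=SHAPE),
      ],
      enabled_precisions={torch.float},
  )

  input_ = Input(tensor, tensor)
  out_trt = model_trt(*input_)               # caller unpacks the container by hand

With many tensors this still means maintaining one torch_tensorrt.Input per element and unpacking at every call site, which is exactly the verbosity this request is trying to remove.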

Additional context

  • For simplicity, the input containers should contain ONLY tensors (implying that nested containers are disallowed). Containers with mixed input types are ignored.
  • Furthermore, there must be a bijection between the tensors in the container and the Input sizes provided to the compile call; i.e. there must be one Input size for each tensor in the container, and both are taken in the same order.
  • Tensors and containers can be mixed in the forward call (e.g. forward(x: torch.Tensor, y: List[torch.Tensor], z: namedtuple[torch.Tensor])); any other input types are treated exactly as they are today. A rough sketch of these rules follows this list.
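To make the rules above concrete, here is a rough sketch of the kind of pre-pass I have in mind; flatten_forward_args and check_bijection are hypothetical names for illustration only, not part of torch_tensorrt, and I am assuming "ignored" means a mixed or nested container is passed through unflattened.

  import torch
  from typing import Any, List, Sequence

  def flatten_forward_args(args: Sequence[Any]) -> List[Any]:
      """Hypothetical pre-pass: expand list/namedtuple arguments into a flat
      argument list following the rules above. Not part of torch_tensorrt."""
      flat: List[Any] = []
      for arg in args:
          is_namedtuple = isinstance(arg, tuple) and hasattr(arg, "_fields")
          if isinstance(arg, list) or is_namedtuple:
              elems = list(arg)
              # Rule 1: a container is flattened only if it holds ONLY tensors
              # (no nesting, no mixed types); otherwise it is passed through.
              if elems and all(isinstance(e, torch.Tensor) for e in elems):
                  flat.extend(elems)
                  continue
          flat.append(arg)
      return flat

  def check_bijection(flat_args: Sequence[Any], input_specs: Sequence[Any]) -> None:
      # Rule 2: the flattened tensors must pair up one-to-one, in order, with
      # the torch_tensorrt.Input specs supplied to compile().
      n_tensors = sum(isinstance(a, torch.Tensor) for a in flat_args)
      if n_tensors != len(input_specs):
          raise ValueError(
              f"expected {len(input_specs)} Input specs, got {n_tensors} tensors"
          )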
