
🐛 [Bug] Compilation causes error: RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:66] Expected ivalues_maps.count(input) to be true but got false Could not find torch::jit::Value* 47 produced from %47 : int = prim::dtype(%52) in lowering graph for mini graph input. #922

Closed
@chaoz-dev


Bug Description

Compiling the graph throws the following error:

RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:66] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 47 produced from %47 : int = prim::dtype(%52) in lowering graph for mini graph input.

Looking at the output TorchScript graph, %47 is defined in a prior node; however, it does not appear to be visible to the node that consumes it.
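
For reference, the prim::dtype node comes from the out.dtype access in forward. A quick way to confirm where %47 is produced (a minimal check, using the Model class from the repro below) is to script the model and print its graph:

  import torch

  # Scripting the model and printing its graph shows the
  # %47 : int = prim::dtype(...) node emitted for `out.dtype`.
  scripted = torch.jit.script(Model().eval())
  print(scripted.graph)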

To Reproduce

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  import torch_tensorrt as torchtrt
  import torch_tensorrt.logging as logging

  logging.set_reportable_log_level(logging.Level.Graph)

  torch.manual_seed(0)

  DEVICE = torch.device("cuda:0")
  SHAPE = (1, 1)


  class Model(torch.nn.Module):
      def __init__(self):
          super().__init__()
          self.lin = nn.Linear(1, 1)

      def forward(self, a):
          out = self.lin(a)

          # Build a lower-triangular tensor matching the dtype/device of `out`.
          tril = torch.zeros(1, 1, 1, device=a.device, dtype=out.dtype)
          indices = torch.tril_indices(1, 1)
          tril[:, indices[0], indices[1]] = out

          return tril


  if __name__ == "__main__":
      tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)

      model = Model().eval().to(DEVICE)
      out = model(tensor)
      print(f"Model: {out}")

      model_trt = torchtrt.compile(
          model,
          inputs=[
              torchtrt.Input(shape=SHAPE),
          ],
          enabled_precisions={torch.float},
          truncate_long_and_double=True,
      )
      out_trt = model_trt(tensor)  # run the compiled module, not the original
      print(f"Model TRT: {out_trt}")

      assert torch.max(torch.abs(out - out_trt)) < 1e-6

Running this script throws the following error:

Traceback (most recent call last):
  File "/scripts/tril.py", line 39, in <module>
    model_trt = torchtrt.compile(
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 97, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 119, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:66] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 47 produced from %47 : int = prim::dtype(%52) in lowering graph for mini graph input.
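
A possible workaround (an untested sketch on my end): avoid querying out.dtype inside forward so the graph no longer contains the prim::dtype node; Tensor.new_zeros inherits dtype and device from out directly:

  def forward(self, a):
      out = self.lin(a)

      # new_zeros inherits dtype/device from `out`, so no separate
      # prim::dtype node is emitted in the TorchScript graph.
      tril = out.new_zeros(1, 1, 1)
      indices = torch.tril_indices(1, 1)
      tril[:, indices[0], indices[1]] = out

      return tril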

Expected behavior

Compilation should not fail, and should produce the following output when run:

Model: tensor([[[0.5434]]], device='cuda:0', grad_fn=<CopySlices>)

Environment

Ubuntu 18.04 x86-64, reproduced with the NGC PyTorch containers 21.11-py3 and 22.02-py3.

Additional context

See output.txt for the full TorchScript graph output.
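
If it helps triage: the partitioning fallback options may also sidestep the failing mini graph by pinning the ops around the missing value to Torch. A sketch, assuming torch_executed_ops is honored for these ops here (untested):

  model_trt = torchtrt.compile(
      model,
      inputs=[torchtrt.Input(shape=SHAPE)],
      enabled_precisions={torch.float},
      truncate_long_and_double=True,
      # Force the ops surrounding the failing prim::dtype value to run in Torch.
      torch_executed_ops=["aten::tril_indices", "aten::copy_"],
  )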
