
🐛 [Bug] KeyError after resolveNonTensorInputs #1018

Closed
@mfeliz-cruise

Description

Bug Description

Torch-TensorRT attempts to resolve all non-tensor inputs of a Torch block if any of those inputs are generated by TensorRT blocks. This leads to a failed attempt to resolve a dictionary input to a Torch block that is generated by another Torch block: getDependencyNodes fails to identify the aten::_set_item node as a dependency, which results in a KeyError.
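A minimal sketch of why the mutation gets lost (a toy node model, not Torch-TensorRT's actual getDependencyNodes): a dependency walk that only follows producer (def-use) edges never visits aten::_set_item, because that node produces no output values.

```python
# Hypothetical sketch -- NOT Torch-TensorRT's actual implementation.
from dataclasses import dataclass, field

@dataclass
class Node:
    kind: str
    inputs: list = field(default_factory=list)   # names of values consumed
    outputs: list = field(default_factory=list)  # names of values produced

# Simplified stand-in for the nodes behind %out_dict.1 in segment @0.
nodes = [
    Node("prim::DictConstruct", outputs=["out_dict"]),
    Node("aten::_set_item", inputs=["out_dict", "INS", "x"]),  # mutates in place, no outputs
    Node("aten::__getitem__", inputs=["out_dict", "INS"], outputs=["z"]),
]

def producer_only_deps(value, nodes):
    """Collect dependencies by walking producers only (the buggy behavior)."""
    deps = []
    for n in nodes:
        if value in n.outputs:
            deps.append(n)
            for v in n.inputs:
                deps += producer_only_deps(v, nodes)
    return deps

deps = producer_only_deps("out_dict", nodes)
print([n.kind for n in deps])  # ['prim::DictConstruct'] -- aten::_set_item is dropped
```

Because aten::_set_item never appears as the producer of any value, a walk keyed purely on who produced what can never reach it.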

This is the original graph; it is a small, artificial test case intended only to reproduce this issue.

graph(%x.1 : Tensor,
      %y.1 : Tensor):
  %3 : str = prim::Constant[value="INS"]() 
  %4 : int = prim::Constant[value=-1]() 
  %5 : bool = prim::Constant[value=0]()
  %6 : str = prim::Constant[value="OUTS"]()
  %out_dict.1 : Dict(str, Tensor) = prim::DictConstruct()
   = aten::_set_item(%out_dict.1, %3, %x.1) 
  %z.1 : Tensor = aten::__getitem__(%out_dict.1, %3)
  %9 : Tensor = aten::lt(%z.1, %y.1) 
  %13 : Tensor?[] = prim::ListConstruct(%9)
  %45 : int = prim::dtype(%z.1)
  %46 : Device = prim::device(%z.1)
  %49 : Tensor = aten::tensor(%4, %45, %46, %5)
  %14 : Tensor = aten::index_put_(%z.1, %13, %49, %5) 
   = aten::_set_item(%out_dict.1, %6, %z.1) 
  %15 : Tensor = aten::__getitem__(%out_dict.1, %3) 
  %16 : Tensor = aten::__getitem__(%out_dict.1, %6) 
  return (%15, %16)

It is segmented as follows. The Tensor?[] input to @2 from @1 will need to be resolved, triggering resolution of all @2 inputs, including %out_dict.1, which is a dictionary created in @0.

DEBUG: [Torch-TensorRT - Debug Build] - Segment Block @0:
    Target: Torch

    Graph: graph(%x.1 : Tensor):
  %1 : str = prim::Constant[value="INS"]()
  %out_dict.1 : Dict(str, Tensor) = prim::DictConstruct()
   = aten::_set_item(%out_dict.1, %1, %x.1)
  %z.1 : Tensor = aten::__getitem__(%out_dict.1, %1)
  return ()

DEBUG: [Torch-TensorRT - Debug Build] - Segment Block @1:
    Target: TensorRT

    Graph: graph(%z.1 : Tensor,
      %y.1 : Tensor):
  %0 : Tensor = aten::lt(%z.1, %y.1)
  %3 : Tensor?[] = prim::ListConstruct(%0)
  %4 : int = prim::dtype(%z.1)
  return ()

DEBUG: [Torch-TensorRT - Debug Build] - Segment Block @2:
    Target: Torch

    Graph: graph(%z.1 : Tensor,
      %4 : int,
      %7 : Tensor?[],
      %out_dict.1 : Dict(str, Tensor)):
  %11 : str = prim::Constant[value="INS"]()
  %9 : str = prim::Constant[value="OUTS"]()
  %5 : bool = prim::Constant[value=0]()
  %3 : int = prim::Constant[value=-1]()
  %0 : Device = prim::device(%z.1)
  %2 : Tensor = aten::tensor(%3, %4, %0, %5)
  %6 : Tensor = aten::index_put_(%z.1, %7, %2, %5)
   = aten::_set_item(%out_dict.1, %9, %z.1)
  %10 : Tensor = aten::__getitem__(%out_dict.1, %11)
  %12 : Tensor = aten::__getitem__(%out_dict.1, %9)
  return ()

After resolveNonTensorInputs, we can see that the prim::DictConstruct() node for %out_dict.1 is copied into @2 without the aten::_set_item node that follows it.

Segment Block @0:
    Target: Torch

    Graph: graph(%x.1 : Tensor):
  %1 : str = prim::Constant[value="INS"]()
  %out_dict.1 : Dict(str, Tensor) = prim::DictConstruct()
   = aten::_set_item(%out_dict.1, %1, %x.1)
  %z.1 : Tensor = aten::__getitem__(%out_dict.1, %1)
  return ()


Segment Block @1:
    Target: TensorRT

    Graph: graph(%z.1 : Tensor,
      %y.1 : Tensor):
  %0 : Tensor = aten::lt(%z.1, %y.1)
  %3 : Tensor?[] = prim::ListConstruct(%0)
  %4 : int = prim::dtype(%z.1)
  return ()


Segment Block @2:
    Target: Torch

    Graph: graph(%2 : Tensor,
      %z.1 : Tensor):
  %12 : str = prim::Constant[value="INS"]() 
  %10 : str = prim::Constant[value="OUTS"]()
  %8 : bool = prim::Constant[value=0]()
  %7 : int = prim::Constant[value=-1]()
  %out_dict.1 : Dict(str, Tensor) = prim::DictConstruct()
  %1 : Tensor?[] = prim::ListConstruct(%2)
  %3 : int = prim::dtype(%z.1)
  %5 : Device = prim::device(%z.1)
  %6 : Tensor = aten::tensor(%7, %3, %5, %8)
  %9 : Tensor = aten::index_put_(%z.1, %1, %6, %8)
   = aten::_set_item(%out_dict.1, %10, %z.1)
  %11 : Tensor = aten::__getitem__(%out_dict.1, %12)
  %13 : Tensor = aten::__getitem__(%out_dict.1, %10)
  return ()
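In plain Python terms, the rewritten block @2 now behaves like the following sketch at runtime: the copied prim::DictConstruct yields a fresh, empty dict, and the aten::_set_item that stored "INS" stayed behind in block @0.

```python
# Plain-Python illustration of what the rewritten block @2 does at runtime.
out_dict = {}  # fresh, empty dict from the copied prim::DictConstruct
try:
    out_dict["INS"]  # aten::__getitem__(%out_dict.1, "INS") on the empty dict
except KeyError as err:
    print("KeyError:", err)  # TorchScript surfaces this as a RuntimeError
```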

To Reproduce

Steps to reproduce the behavior:

  1. Run the Python script below with the latest version of Torch-TensorRT.
# Third-party imports
import torch
import torch.nn as nn
import torch_tensorrt

class Reproducer(nn.Module):
    def __init__(self):
        super(Reproducer, self).__init__()
    def forward(self, x, y):
        out_dict = {}
        out_dict["INS"] = x
        z = out_dict["INS"]
        z[z < y] = -1
        out_dict["OUTS"] = z
        return out_dict["INS"], out_dict["OUTS"]

def reproduce_error():
    torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Graph)
    model = Reproducer().eval().cuda()

    x = torch.randn(20, 16, 50, 32).cuda()
    y = torch.randn(20, 16, 50, 32).cuda()
    
    trt_model = torch_tensorrt.compile(
        model,
        inputs=[x, y],
        truncate_long_and_double=True,
    )
    #print(trt_model.forward(x, y))


reproduce_error()
  2. Note the error "RuntimeError: KeyError: INS".

Expected behavior

Torch-TensorRT should not attempt to resolve non-tensor inputs of Torch blocks that are generated by other Torch blocks. If it does choose to resolve a dictionary input, it should include aten::_set_item as a dependency.
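One possible shape for the second fix, sketched over a toy node model (hypothetical code, not Torch-TensorRT's): treat nodes that mutate a value in place, such as aten::_set_item, as dependencies of that value even though they produce no outputs.

```python
# Hypothetical sketch of the proposed dependency fix -- NOT Torch-TensorRT's code.
from dataclasses import dataclass, field

@dataclass
class Node:
    kind: str
    inputs: list = field(default_factory=list)
    outputs: list = field(default_factory=list)

# Assumption: ops known to mutate their first input in place.
MUTATORS = {"aten::_set_item"}

nodes = [
    Node("prim::DictConstruct", outputs=["out_dict"]),
    Node("aten::_set_item", inputs=["out_dict", "INS", "x"]),
    Node("aten::__getitem__", inputs=["out_dict", "INS"], outputs=["z"]),
]

def deps_with_mutators(value, nodes):
    """Collect producers of `value` plus any in-place mutators of it."""
    deps = []
    for n in nodes:
        produces = value in n.outputs
        mutates = n.kind in MUTATORS and n.inputs[:1] == [value]
        if produces or mutates:
            deps.append(n)
            for v in n.inputs:
                if v != value:  # don't recurse back into the mutated value itself
                    deps += deps_with_mutators(v, nodes)
    return deps

deps = deps_with_mutators("out_dict", nodes)
print([n.kind for n in deps])
# ['prim::DictConstruct', 'aten::_set_item'] -- the mutation is now carried along
```

With the mutator included, copying %out_dict.1 into another block would also copy (or depend on) the aten::_set_item that populated it, so the later aten::__getitem__ lookups would find their keys.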
