Bug Description
Torch-TensorRT attempts to resolve all non-tensor inputs of a Torch block if any of those inputs is generated by a TensorRT block. This leads to a failed attempt to resolve a dictionary input to a Torch block that is generated by another Torch block: `getDependencyNodes` fails to identify the `aten::_set_item` node as a dependency, which results in a `KeyError`.
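As a rough illustration of the failure mode (this is a toy model, not the actual `getDependencyNodes` implementation), a dependency walk that only follows the *producers* of values can never visit a node like `aten::_set_item`, which mutates its input but produces no output:

```python
# Hypothetical, simplified IR: (node name, input values, output values)
nodes = [
    ("prim::DictConstruct", [], ["out_dict"]),
    ("aten::_set_item", ["out_dict", "x"], []),  # mutates the dict, no output
    ("aten::__getitem__", ["out_dict"], ["z"]),
]

# Map each value to the node that produces it; mutating nodes never appear here
producer = {out: n for n in nodes for out in n[2]}

def get_dependency_nodes(value):
    """Collect transitive producers of `value` (a producer-only walk)."""
    deps, stack = [], [value]
    while stack:
        n = producer.get(stack.pop())
        if n and n not in deps:
            deps.append(n)
            stack.extend(n[1])
    return [n[0] for n in deps]

print(get_dependency_nodes("out_dict"))
# The walk finds prim::DictConstruct but misses aten::_set_item entirely
```

A producer-only traversal is sound for pure SSA values, but in-place mutators like `aten::_set_item` are only reachable through the *uses* of a value, not its definition.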
Below is the original graph; it is a small artificial test case intended only to reproduce this issue.
```
graph(%x.1 : Tensor,
      %y.1 : Tensor):
  %3 : str = prim::Constant[value="INS"]()
  %4 : int = prim::Constant[value=-1]()
  %5 : bool = prim::Constant[value=0]()
  %6 : str = prim::Constant[value="OUTS"]()
  %out_dict.1 : Dict(str, Tensor) = prim::DictConstruct()
   = aten::_set_item(%out_dict.1, %3, %x.1)
  %z.1 : Tensor = aten::__getitem__(%out_dict.1, %3)
  %9 : Tensor = aten::lt(%z.1, %y.1)
  %13 : Tensor?[] = prim::ListConstruct(%9)
  %45 : int = prim::dtype(%z.1)
  %46 : Device = prim::device(%z.1)
  %49 : Tensor = aten::tensor(%4, %45, %46, %5)
  %14 : Tensor = aten::index_put_(%z.1, %13, %49, %5)
   = aten::_set_item(%out_dict.1, %6, %z.1)
  %15 : Tensor = aten::__getitem__(%out_dict.1, %3)
  %16 : Tensor = aten::__getitem__(%out_dict.1, %6)
  return (%15, %16)
```
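For reference, an IR dump like the one above can be obtained by scripting the model directly (a sketch; the module body matches the reproducer further below):

```python
import torch
import torch.nn as nn

class Reproducer(nn.Module):
    def forward(self, x, y):
        # TorchScript infers an empty dict literal as Dict[str, Tensor]
        out_dict = {}
        out_dict["INS"] = x
        z = out_dict["INS"]
        z[z < y] = -1
        out_dict["OUTS"] = z
        return out_dict["INS"], out_dict["OUTS"]

# torch.jit.script produces the TorchScript IR shown above
scripted = torch.jit.script(Reproducer())
print(scripted.graph)
```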
It is segmented as follows. The `Tensor?[]` input to @2 from @1 needs to be resolved, which triggers resolution of all of @2's inputs, including `%out_dict.1`, a dictionary created in @0.
```
DEBUG: [Torch-TensorRT - Debug Build] - Segment Block @0:
    Target: Torch
    Graph: graph(%x.1 : Tensor):
      %1 : str = prim::Constant[value="INS"]()
      %out_dict.1 : Dict(str, Tensor) = prim::DictConstruct()
       = aten::_set_item(%out_dict.1, %1, %x.1)
      %z.1 : Tensor = aten::__getitem__(%out_dict.1, %1)
      return ()

DEBUG: [Torch-TensorRT - Debug Build] - Segment Block @1:
    Target: TensorRT
    Graph: graph(%z.1 : Tensor,
          %y.1 : Tensor):
      %0 : Tensor = aten::lt(%z.1, %y.1)
      %3 : Tensor?[] = prim::ListConstruct(%0)
      %4 : int = prim::dtype(%z.1)
      return ()

DEBUG: [Torch-TensorRT - Debug Build] - Segment Block @2:
    Target: Torch
    Graph: graph(%z.1 : Tensor,
          %4 : int,
          %7 : Tensor?[],
          %out_dict.1 : Dict(str, Tensor)):
      %11 : str = prim::Constant[value="INS"]()
      %9 : str = prim::Constant[value="OUTS"]()
      %5 : bool = prim::Constant[value=0]()
      %3 : int = prim::Constant[value=-1]()
      %0 : Device = prim::device(%z.1)
      %2 : Tensor = aten::tensor(%3, %4, %0, %5)
      %6 : Tensor = aten::index_put_(%z.1, %7, %2, %5)
       = aten::_set_item(%out_dict.1, %9, %z.1)
      %10 : Tensor = aten::__getitem__(%out_dict.1, %11)
      %12 : Tensor = aten::__getitem__(%out_dict.1, %9)
      return ()
```
After `resolveNonTensorInputs`, we can see that the `prim::DictConstruct()` node for `%out_dict.1` is copied into @2 without the following `aten::_set_item` node.
```
Segment Block @0:
    Target: Torch
    Graph: graph(%x.1 : Tensor):
      %1 : str = prim::Constant[value="INS"]()
      %out_dict.1 : Dict(str, Tensor) = prim::DictConstruct()
       = aten::_set_item(%out_dict.1, %1, %x.1)
      %z.1 : Tensor = aten::__getitem__(%out_dict.1, %1)
      return ()

Segment Block @1:
    Target: TensorRT
    Graph: graph(%z.1 : Tensor,
          %y.1 : Tensor):
      %0 : Tensor = aten::lt(%z.1, %y.1)
      %3 : Tensor?[] = prim::ListConstruct(%0)
      %4 : int = prim::dtype(%z.1)
      return ()

Segment Block @2:
    Target: Torch
    Graph: graph(%2 : Tensor,
          %z.1 : Tensor):
      %12 : str = prim::Constant[value="INS"]()
      %10 : str = prim::Constant[value="OUTS"]()
      %8 : bool = prim::Constant[value=0]()
      %7 : int = prim::Constant[value=-1]()
      %out_dict.1 : Dict(str, Tensor) = prim::DictConstruct()
      %1 : Tensor?[] = prim::ListConstruct(%2)
      %3 : int = prim::dtype(%z.1)
      %5 : Device = prim::device(%z.1)
      %6 : Tensor = aten::tensor(%7, %3, %5, %8)
      %9 : Tensor = aten::index_put_(%z.1, %1, %6, %8)
       = aten::_set_item(%out_dict.1, %10, %z.1)
      %11 : Tensor = aten::__getitem__(%out_dict.1, %12)
      %13 : Tensor = aten::__getitem__(%out_dict.1, %10)
      return ()
```
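The effect of the missing `aten::_set_item` node can be mimicked in plain Python (an analogy, not Torch-TensorRT code): @2 now operates on a freshly constructed, empty dict, so the later `aten::__getitem__` lookup fails:

```python
# Segment @0 builds the dict AND inserts the "INS" key:
seg0_dict = {}
seg0_dict["INS"] = "x"

# In segment @2 only the prim::DictConstruct was copied; the aten::_set_item
# that inserts "INS" was not, so @2 sees a fresh, empty dict:
seg2_dict = {}
try:
    seg2_dict["INS"]  # aten::__getitem__(%out_dict.1, "INS")
except KeyError as err:
    print(f"KeyError: {err}")  # the lookup that surfaces as the reported error
```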
To Reproduce
Steps to reproduce the behavior:
- Run the Python script below with the latest version of Torch-TensorRT.
```python
# Third-party imports
import torch
import torch.nn as nn
import torch_tensorrt


class Reproducer(nn.Module):
    def __init__(self):
        super(Reproducer, self).__init__()

    def forward(self, x, y):
        out_dict = {}
        out_dict["INS"] = x
        z = out_dict["INS"]
        z[z < y] = -1
        out_dict["OUTS"] = z
        return out_dict["INS"], out_dict["OUTS"]


def reproduce_error():
    torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Graph)
    model = Reproducer().eval().cuda()
    x = torch.randn(20, 16, 50, 32).cuda()
    y = torch.randn(20, 16, 50, 32).cuda()
    trt_model = torch_tensorrt.compile(model, inputs=[x, y], **{
        "truncate_long_and_double": True,
    })
    # print(trt_model.forward(x, y))


reproduce_error()
```
- Note the error `RuntimeError: KeyError: INS`.
Expected behavior
Torch-TensorRT should not attempt to resolve non-tensor inputs of Torch blocks when those inputs are generated by other Torch blocks. If it does choose to resolve a dictionary input, it should include the corresponding `aten::_set_item` nodes as dependencies.
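Until this is fixed, one possible workaround (a sketch, under the assumption that the dictionary is not otherwise required) is to rewrite the forward pass without the TorchScript dict. Since `out_dict["INS"]` aliases `x` and is mutated in place, both returned values in the original model are the same tensor:

```python
import torch
import torch.nn as nn

class ReproducerNoDict(nn.Module):
    # Hypothetical rewrite of Reproducer.forward avoiding the dict
    def forward(self, x, y):
        z = x            # out_dict["INS"] was an alias of x
        z[z < y] = -1    # in-place masked assignment, as in the original
        return z, z      # "INS" and "OUTS" alias the same mutated tensor
```

With no `prim::DictConstruct` / `aten::_set_item` in the graph, the partitioner has no dictionary input to resolve.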