🐛 [Bug] Compilation causes error: RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:66] Expected ivalues_maps.count(input) to be true but got false Could not find torch::jit::Value* 47 produced from %47 : int = prim::dtype(%52) in lowering graph for mini graph input.
#922
Closed
Bug Description
Compiling the graph throws the following error:
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:66] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 47 produced from %47 : int = prim::dtype(%52) in lowering graph for mini graph input.
In the TorchScript graph output, %47 is defined by an earlier node; however, it does not appear to be visible in the current node.
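You can confirm which node produces the int-typed `prim::dtype` value by scripting the module with plain PyTorch, without Torch-TensorRT, and listing its `prim::dtype` nodes. The sketch below is a minimal example that runs the reproduction's module on CPU. `find_dtype_nodes` is a hypothetical helper. The value numbering (%47, %52) in this scripted graph will generally differ from the numbering in Torch-TensorRT's lowering graph.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    # Same module as in the reproduction, run on CPU for inspection only.
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 1)

    def forward(self, a):
        out = self.lin(a)
        tril = torch.zeros(1, 1, 1, device=a.device, dtype=out.dtype)
        indices = torch.tril_indices(1, 1)
        tril[:, indices[0], indices[1]] = out
        return tril


def find_dtype_nodes(module: nn.Module) -> list:
    """Script `module` and return the prim::dtype nodes in its forward graph."""
    scripted = torch.jit.script(module)
    return [n for n in scripted.graph.nodes() if n.kind() == "prim::dtype"]


if __name__ == "__main__":
    # Each printed line is one prim::dtype node, e.g. the `out.dtype` query.
    for node in find_dtype_nodes(Model().eval()):
        print(node)
```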
To Reproduce
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
import torch_tensorrt.logging as logging

# Print the lowered graphs so the failing value can be traced
logging.set_reportable_log_level(logging.Level.Graph)

torch.manual_seed(0)

DEVICE = torch.device("cuda:0")
SHAPE = (1, 1)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 1)

    def forward(self, a):
        out = self.lin(a)
        # `out.dtype` is scripted as a prim::dtype node (an int value)
        tril = torch.zeros(1, 1, 1, device=a.device, dtype=out.dtype)
        indices = torch.tril_indices(1, 1)
        tril[:, indices[0], indices[1]] = out
        return tril


if __name__ == "__main__":
    tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)
    model = Model().eval().to(DEVICE)
    out = model(tensor)
    print(f"Model: {out}")

    model_trt = torchtrt.compile(
        model,
        inputs=[
            torchtrt.Input(shape=SHAPE),
        ],
        enabled_precisions={torch.float},
        truncate_long_and_double=True,
    )
    # Run the TensorRT-compiled module and compare against eager PyTorch
    out_trt = model_trt(tensor)
    print(f"Model TRT: {out_trt}")
    assert torch.max(torch.abs(out - out_trt)) < 1e-6
```
Running the script throws the following error:
```
Traceback (most recent call last):
  File "/scripts/tril.py", line 39, in <module>
    model_trt = torchtrt.compile(
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 97, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 119, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:66] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 47 produced from %47 : int = prim::dtype(%52) in lowering graph for mini graph input.
```
Expected behavior
Compilation should not fail, and should produce the following output when run:
Model: tensor([[[0.5434]]], device='cuda:0', grad_fn=<CopySlices>)
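The original report does not propose a workaround. One way to test whether the dtype query triggers the partitioning error is to build the zeros tensor with `Tensor.new_zeros`. That method inherits `out`'s dtype and device, so `forward` never reads `out.dtype`. The sketch below shows this variant. The class name `ModelNoDtypeQuery` is hypothetical. It has not been tested with `torchtrt.compile`, so it may or may not avoid the error. It only shows that the scripted graph no longer contains `prim::dtype` and that the output is unchanged.

```python
import torch
import torch.nn as nn


class ModelNoDtypeQuery(nn.Module):
    """Hypothetical variant of the reproduction's Model without a prim::dtype node."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 1)

    def forward(self, a):
        out = self.lin(a)
        # new_zeros copies dtype and device from `out`, so forward never
        # reads `out.dtype` (which TorchScript scripts as prim::dtype).
        tril = out.new_zeros([1, 1, 1])
        indices = torch.tril_indices(1, 1)
        tril[:, indices[0], indices[1]] = out
        return tril


if __name__ == "__main__":
    model = ModelNoDtypeQuery().eval()
    scripted = torch.jit.script(model)
    kinds = [n.kind() for n in scripted.graph.nodes()]
    print("prim::dtype present:", "prim::dtype" in kinds)
    print(model(torch.randn(1, 1)))
```

To try it against the bug, pass this module to `torchtrt.compile` with the same settings as the reproduction.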
Environment
Ubuntu 18.04 x86-64, run with the NGC 21.11-py3 and 22.02-py3 containers.
Additional context
See output.txt for the full TorchScript graph output.