Description
Is your feature request related to a problem? Please describe.
Currently, the `forward` function only supports tensor input types when compiling. However, sometimes we wish to supply many tensors into the `forward` function at once (say, more than 10); this results in a very long `forward` call where every tensor has to be listed individually. It would be helpful if we could instead pass in a single container holding all of these tensors, which would make the call much cleaner.
For this specific request, I focus on the `list` and `namedtuple` input types first, since these should cover most basic use cases (a `namedtuple` should also functionally cover inputs given as named key-value pairs of tensors).
Describe the solution you'd like
Instead of supporting only the following, where we need to supply `torch.Tensor`s into `forward`:

    import torch
    import torch_tensorrt

    DEVICE = torch.device("cuda:0")
    SHAPE = (1, 1)
    torch.manual_seed(0)

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, a, b):
            # Every tensor is a separate positional argument.
            return a - b

    if __name__ == "__main__":
        tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)
        model = Model().eval().to(DEVICE)
        out = model(tensor, tensor)
        model_trt = torch_tensorrt.compile(
            model,
            inputs=[
                torch_tensorrt.Input(shape=SHAPE),
                torch_tensorrt.Input(shape=SHAPE),
            ],
            enabled_precisions={torch.float},
        )
        # Run the compiled module (not the original) so the comparison is meaningful.
        out_trt = model_trt(tensor, tensor)
        assert torch.max(torch.abs(out - out_trt)) < 1e-6

Also support passing a `namedtuple` or `list` into `forward`:

    from collections import namedtuple

    import torch
    import torch_tensorrt

    DEVICE = torch.device("cuda:0")
    SHAPE = (1, 1)
    torch.manual_seed(0)
    Input = namedtuple('Input', ['t1', 't2'])

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, input_: Input):
            # All tensors arrive in a single container.
            return input_.t1 - input_.t2

    if __name__ == "__main__":
        tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)
        input_ = Input(tensor, tensor)
        model = Model().eval().to(DEVICE)
        out = model(input_)
        model_trt = torch_tensorrt.compile(
            model,
            # Requested behavior: one Input spec per tensor in the container, in field order.
            inputs=[
                torch_tensorrt.Input(shape=SHAPE),
                torch_tensorrt.Input(shape=SHAPE),
            ],
            enabled_precisions={torch.float},
        )
        # Run the compiled module (not the original) so the comparison is meaningful.
        out_trt = model_trt(input_)
        assert torch.max(torch.abs(out - out_trt)) < 1e-6

Describe alternatives you've considered
Currently the only alternative is to supply tensors directly into the `forward` function; supplying `namedtuple`s causes the compilation to segfault, and supplying lists causes the compilation to fail to recognize the input altogether.
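To illustrate the alternative described above: the tensors can still be kept in a container on the caller's side and unpacked with `*` at the call site, while `forward` itself keeps a flat signature. The sketch below is plain Python under stated assumptions: the free function `forward` stands in for a module's `forward(self, a, b)`, and floats stand in for tensors.

```python
from collections import namedtuple

# Same container shape as in the request above.
Input = namedtuple("Input", ["t1", "t2"])


def forward(a, b):
    # Stand-in for a module's flat forward(self, a, b); floats replace tensors here.
    return a - b


input_ = Input(t1=5.0, t2=3.0)

# A namedtuple unpacks in field order, so t1 -> a and t2 -> b.
out_from_namedtuple = forward(*input_)

# A plain list unpacks in index order.
out_from_list = forward(*[5.0, 3.0])
```

This keeps the call site short, but the `forward` signature and the list of `Input` sizes still have to enumerate every tensor, which is what this request aims to avoid.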
Additional context
- For simplicity, the input containers should contain ONLY tensors (implying that we disallow nested containers). Containers with mixed input types are ignored.
- Furthermore, there must be a bijection between the tensors in the container and the sizes provided to the `compile` call; i.e., there must be one `Input` size for each tensor in the container, and both are taken in the same order.
- We can mix tensors and containers in the `forward` call (e.g., `forward(x: torch.Tensor, y: List[torch.Tensor], z: namedtuple[torch.Tensor])`). Any other types are treated as they are currently when input.
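The rules above could be checked roughly as follows. This is only a sketch of the intended semantics, not how the compiler implements or would implement them. The helper names `flatten_forward_args` and `pair_with_specs` are hypothetical, the tensor check is passed in as a predicate so the snippet runs without PyTorch, and raising on a nested or mixed container is an assumption (the request only says such containers are ignored).

```python
from typing import Any, Callable, List, Sequence, Tuple


def flatten_forward_args(
    args: Sequence[Any], is_tensor: Callable[[Any], bool]
) -> List[Any]:
    """Flatten forward() arguments into one ordered list of tensors (hypothetical helper)."""
    flat: List[Any] = []
    for position, arg in enumerate(args):
        if is_tensor(arg):
            # A bare tensor argument is kept as-is.
            flat.append(arg)
        elif isinstance(arg, (list, tuple)):
            # namedtuple instances are tuples, so they are handled here too.
            if not all(is_tensor(item) for item in arg):
                # Assumption: reject nested or mixed containers loudly.
                raise TypeError(
                    f"argument {position}: containers must hold only tensors"
                )
            flat.extend(arg)
        # Any other argument type is outside this request and is skipped here.
    return flat


def pair_with_specs(
    flat_tensors: Sequence[Any], specs: Sequence[Any]
) -> List[Tuple[Any, Any]]:
    """Pair each flattened tensor with its size spec, in order (one-to-one)."""
    if len(flat_tensors) != len(specs):
        raise ValueError(
            f"expected {len(specs)} tensors, got {len(flat_tensors)}"
        )
    return list(zip(flat_tensors, specs))
```

With PyTorch available, `torch.is_tensor` could serve as the `is_tensor` predicate, and `specs` would be the `torch_tensorrt.Input` objects passed to `compile`.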