Skip to content

Commit 8ccf8e6

Browse files
committed
fix: Pass example inputs directly to Torch-TRT
1 parent 493230e commit 8ccf8e6

File tree

1 file changed

+1
-6
lines changed
  • torchbenchmark/util/backends

1 file changed

+1
-6
lines changed

torchbenchmark/util/backends/trt.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,19 +114,14 @@ def _torch_trt():
114114
module, example_inputs = model.get_module()
115115
torch_dtype_precision = torch.half if FP16 else torch.float32
116116

117-
trt_input = [
118-
torch_tensorrt.Input(shape=input_.shape, dtype=input_.dtype)
119-
for input_ in example_inputs
120-
]
121-
122117
print(
123118
f"Compiling {model.name} with batch size {model.batch_size}, precision {model.dargs.precision}, "
124119
+ f"and {'default' if 'ir' not in torch_trt_kwargs else torch_trt_kwargs['ir']} IR"
125120
)
126121

127122
trt_module = torch_tensorrt.compile(
128123
module,
129-
inputs=trt_input,
124+
inputs=example_inputs,
130125
enabled_precisions={torch_dtype_precision},
131126
**torch_trt_kwargs,
132127
)

0 commit comments

Comments
 (0)