We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 493230e commit 8ccf8e6Copy full SHA for 8ccf8e6
torchbenchmark/util/backends/trt.py
@@ -114,19 +114,14 @@ def _torch_trt():
114
module, example_inputs = model.get_module()
115
torch_dtype_precision = torch.half if FP16 else torch.float32
116
117
- trt_input = [
118
- torch_tensorrt.Input(shape=input_.shape, dtype=input_.dtype)
119
- for input_ in example_inputs
120
- ]
121
-
122
print(
123
f"Compiling {model.name} with batch size {model.batch_size}, precision {model.dargs.precision}, "
124
+ f"and {'default' if 'ir' not in torch_trt_kwargs else torch_trt_kwargs['ir']} IR"
125
)
126
127
trt_module = torch_tensorrt.compile(
128
module,
129
- inputs=trt_input,
+ inputs=example_inputs,
130
enabled_precisions={torch_dtype_precision},
131
**torch_trt_kwargs,
132
0 commit comments