Skip to content

Commit 23724b7

Browse files
committed
[backends] Add functionality to TRT backend
- Add argument parsing for backend arguments to pass to TRT - Add capability to specify IR via the command line - Add functionality to compilation path and clean up code
1 parent 3df1cc9 commit 23724b7

File tree

1 file changed

+43
-9
lines changed
  • torchbenchmark/util/backends

1 file changed

+43
-9
lines changed

torchbenchmark/util/backends/trt.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,24 @@
11
from typing import List
22
import torch
3+
import argparse
34

45
from torchbenchmark.util.backends import create_backend
56
from torchbenchmark.util.env_check import is_hf_model
67

8+
def parse_torch_trt_args(backend_args: List[str]):
    """Parse CLI-provided backend arguments to extract Torch-TRT keywords.

    Args:
        backend_args: Raw backend argument strings taken from the command line.

    Returns:
        A ``(kwargs, unknown)`` tuple. ``kwargs`` maps each recognized option
        name to its parsed value; options that have no default and were not
        supplied are left out entirely. ``unknown`` is the list of arguments
        this parser did not recognize.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--truncate_long_and_double",
        default=False,
        action="store_true",
        help="Whether to automatically truncate long and double operations",
    )
    arg_parser.add_argument(
        "--workspace_size",
        type=int,
        help="Size of workspace allotted to TensorRT",
    )
    arg_parser.add_argument(
        "--min_block_size",
        type=int,
        help="Minimum number of operations in an accelerated TRT block",
    )
    arg_parser.add_argument(
        "--ir",
        type=str,
        default="ts",
        help="Which internal representation to use",
    )
    args, unknown = arg_parser.parse_known_args(backend_args)

    # Options without a default parse to None when absent. Drop them so that a
    # caller splatting the result into a function call does not override that
    # function's own defaults with an explicit None.
    specified_kwargs = {
        name: value for name, value in vars(args).items() if value is not None
    }

    return specified_kwargs, unknown
21+
722
@create_backend
823
def fx2trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
924
FP16 = True if model.dargs.precision == "fp16" else False
@@ -40,18 +55,37 @@ def _fx2trt():
4055

4156
@create_backend
def torch_trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
    """Torch-TRT backend.

    Example command-line invocation:
        python run.py resnet18 -d cuda -t eval --backend torch_trt --precision fp32 --truncate_long_and_double

    Recognized backend options:
        --truncate_long_and_double: Whether to automatically truncate long and double operations
        --min_block_size: Minimum number of operations in an accelerated TRT block
        --workspace_size: Size of workspace allotted to TensorRT
        --ir: Which internal representation to use: {"ts", "dynamo_compile", "fx_ts_compat", ...}

    Returns the deferred compile function together with the backend
    arguments that the Torch-TRT argument parser did not recognize.
    """
    use_fp16 = model.dargs.precision == "fp16"
    assert model.device == "cuda" and model.test == "eval", "Torch-TRT only works on CUDA inference tests."

    # Split the CLI arguments into Torch-TRT keyword arguments and the rest.
    compile_kwargs, remaining_args = parse_torch_trt_args(backend_args)

    def _torch_trt():
        """Compile the model's module with Torch-TRT and install the result."""
        import torch_tensorrt

        module, example_inputs = model.get_module()

        if use_fp16:
            precision = torch.half
        else:
            precision = torch.float32

        # One input spec per example input, mirroring its shape and dtype.
        input_specs = []
        for example in example_inputs:
            input_specs.append(
                torch_tensorrt.Input(shape=example.shape, dtype=example.dtype)
            )

        compiled_module = torch_tensorrt.compile(
            module,
            inputs=input_specs,
            enabled_precisions={precision},
            **compile_kwargs,
        )
        model.set_module(compiled_module)

    return _torch_trt, remaining_args

0 commit comments

Comments
 (0)