
Commit 30a3c4a

gs-olive authored and facebook-github-bot committed
Add functionality to TRT backend (#1753)
Summary:
- Add argument parsing for backend arguments to be passed to Torch-TRT
- Add capability to specify IR and other Torch-TRT fields via the command-line interface
- Add functionality to the compilation path and clean up code

cc: narendasan

Pull Request resolved: #1753
Reviewed By: FindHao
Differential Revision: D47682302
Pulled By: xuzhao9
fbshipit-source-id: 4f1a17db1b44908aa564099c8afcc55fb7fdb0df
1 parent 36e1ed1 commit 30a3c4a
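As the new backend docstring below notes, these options ride along on TorchBench's normal CLI. A representative invocation (the model choice and the extra flag values beyond the docstring's example are illustrative only) might look like:

    python run.py resnet18 -d cuda -t eval --backend torch_trt --precision fp32 --truncate_long_and_double --ir dynamo_compile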

File tree

3 files changed: +316 additions, -20 deletions


torchbenchmark/util/backends/trt.py

Lines changed: 93 additions & 20 deletions
@@ -1,26 +1,68 @@
 from typing import List
 import torch
+import argparse
 
-from torchbenchmark.util.backends import create_backend
+from torchbenchmark.util.backends import create_backend
 from torchbenchmark.util.env_check import is_hf_model
 
+
+def parse_torch_trt_args(backend_args: List[str]):
+    """Parses CLI-provided backend arguments to extract Torch-TRT keywords
+
+    Returns kwargs dictionary and remainder arguments which were unrecognized
+    """
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument(
+        "--truncate_long_and_double",
+        default=None,
+        action="store_true",
+        help="Whether to automatically truncate long and double operations",
+    )
+    arg_parser.add_argument(
+        "--workspace_size", type=int, help="Size of workspace allotted to TensorRT"
+    )
+    arg_parser.add_argument(
+        "--min_block_size",
+        type=int,
+        help="Minimum number of operations in an accelerated TRT block",
+    )
+    arg_parser.add_argument(
+        "--ir",
+        type=str,
+        help="Which internal representation to use: {'ts', 'dynamo_compile', 'fx_ts_compat', ...}",
+    )
+    args, unknown = arg_parser.parse_known_args(backend_args)
+
+    # Remove unspecified arguments from the args dictionary
+    # (Only pass through user-specified args)
+    parsed_args = vars(args)
+    for key in list(parsed_args.keys()):
+        if parsed_args[key] is None:
+            del parsed_args[key]
+
+    return parsed_args, unknown
+
+
 @create_backend
-def fx2trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
+def fx2trt(model: "torchbenchmark.util.model.BenchmarkModel", backend_args: List[str]):
     FP16 = True if model.dargs.precision == "fp16" else False
     HF_MODEL = True if is_hf_model(model) else False
-    assert model.device == "cuda" and model.test == "eval", f"fx2trt only works on CUDA inference tests."
+    assert (
+        model.device == "cuda" and model.test == "eval"
+    ), f"fx2trt only works on CUDA inference tests."
+
     def _fx2trt():
         from torch_tensorrt.fx import compile
         from torch_tensorrt.fx.utils import LowerPrecision
+
         module, example_inputs = model.get_module()
         precision = LowerPrecision.FP16 if FP16 else LowerPrecision.FP32
 
         if HF_MODEL:
             from transformers.utils.fx import symbolic_trace as hf_symbolic_trace
+
             traced_model = hf_symbolic_trace(
-                module,
-                batch_size = model.batch_size,
-                sequence_lenghth = model.max_length
+                module, batch_size=model.batch_size, sequence_lenghth=model.max_length
             )
             trt_model = compile(
                 traced_model,
@@ -31,27 +73,58 @@ def _fx2trt():
                 max_workspace_size=20 << 30,
             )
         else:
-            trt_model = compile(module=module,
-                                input=example_inputs,
-                                max_batch_size=model.batch_size,
-                                lower_precision=precision)
+            trt_model = compile(
+                module=module,
+                input=example_inputs,
+                max_batch_size=model.batch_size,
+                lower_precision=precision,
+            )
         model.set_module(trt_model)
+
     return _fx2trt, backend_args
 
+
 @create_backend
-def torch_trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
+def torch_trt(
+    model: "torchbenchmark.util.model.BenchmarkModel", backend_args: List[str]
+):
+    """Backend for Torch-TRT
+
+    Can be directly invoked from the command line, for example via:
+    python run.py resnet18 -d cuda -t eval --backend torch_trt --precision fp32 --truncate_long_and_double
+
+    Options include:
+        --truncate_long_and_double: Whether to automatically truncate long and double operations
+        --min_block_size: Minimum number of operations in an accelerated TRT block
+        --workspace_size: Size of workspace allotted to TensorRT
+        --ir: Which internal representation to use: {"ts", "dynamo_compile", "fx_ts_compat", ...}
+    """
     FP16 = True if model.dargs.precision == "fp16" else False
-    assert model.device == "cuda" and model.test == "eval", f"fx2trt only works on CUDA inference tests."
+    assert (
+        model.device == "cuda" and model.test == "eval"
+    ), f"Torch-TRT only works on CUDA inference tests."
+
+    # Extract relevant Torch-TRT arguments from the provided CLI arguments
+    torch_trt_kwargs, backend_args = parse_torch_trt_args(backend_args)
+
     def _torch_trt():
+        """Helper function for invoking Torch-TRT"""
         import torch_tensorrt
+
         module, example_inputs = model.get_module()
-        if FP16:
-            torchtrt_dtype = torch_tensorrt.dtype.half
-            torch_dtype = torch.half
-        else:
-            torchtrt_dtype = torch_tensorrt.dtype.float
-            torch_dtype = torch.float32
-        trt_input = [torch_tensorrt.Input(shape=example_inputs[0].shape, dtype=torch_dtype)]
-        trt_module = torch_tensorrt.compile(module, inputs=trt_input, enabled_precisions=torchtrt_dtype)
+        torch_dtype_precision = torch.half if FP16 else torch.float32
+
+        print(
+            f"Compiling {model.name} with batch size {model.batch_size}, precision {model.dargs.precision}, "
+            + f"and {'default' if 'ir' not in torch_trt_kwargs else torch_trt_kwargs['ir']} IR"
+        )
+
+        trt_module = torch_tensorrt.compile(
+            module,
+            inputs=example_inputs,
+            enabled_precisions={torch_dtype_precision},
+            **torch_trt_kwargs,
+        )
         model.set_module(trt_module)
+
     return _torch_trt, backend_args
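To illustrate how parse_torch_trt_args splits the recognized Torch-TRT options out of the backend argument list, here is a minimal sketch (the argument values and the extra flag are made up; running it requires a TorchBench checkout that provides this module):

    from torchbenchmark.util.backends.trt import parse_torch_trt_args

    # Hypothetical remainder of the CLI handed to the backend
    sample_args = ["--ir", "dynamo_compile", "--min_block_size", "3", "--some_other_flag"]

    kwargs, remainder = parse_torch_trt_args(sample_args)
    # kwargs    -> {"ir": "dynamo_compile", "min_block_size": 3}
    #              (options left unspecified, e.g. --workspace_size, are dropped rather than passed as None)
    # remainder -> ["--some_other_flag"], left for other argument handlers to consume
    print(kwargs, remainder)

The resulting kwargs dictionary is forwarded directly to torch_tensorrt.compile via **torch_trt_kwargs, so only user-specified fields override Torch-TRT's defaults.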

userbenchmark/torch_trt/__init__.py

Whitespace-only changes.

userbenchmark/torch_trt/run.py

Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
import argparse
import traceback
import torch

import numpy as np

import json
import os
import time
from datetime import datetime
from typing import List

from torchbenchmark import (
    load_canary_model_by_name,
    load_model_by_name,
    list_models,
    ModelNotFoundError,
)


def cli(args: List[str]):
    """Parse input arguments, extracting model specification and batch size"""
    arg_parser = argparse.ArgumentParser(args)
    arg_parser.add_argument(
        "--model",
        help="Full or partial name of a model to run. If partial, picks the first match.",
        default="",
        type=str,
    )
    arg_parser.add_argument(
        "--bs",
        help="Input batch size to test.",
        default=1,
        type=int,
    )
    arg_parser.add_argument(
        "--num_warmup",
        help="Number of inference warmup iterations.",
        default=10,
        type=int,
    )
    arg_parser.add_argument(
        "--num_iter",
        help="Number of inference iterations for benchmarking.",
        default=100,
        type=int,
    )
    parsed_args, unknown = arg_parser.parse_known_args()

    return vars(parsed_args), unknown


def save_metrics(metrics):
    """Save metrics to a JSON file with formatted filename"""
    metrics_json = {
        "name": "torch_trt",
        "environ": {
            "metrics_version": "v0.1",
            "pytorch_git_version": torch.version.git_version,
        },
        "metrics": metrics,
    }

    # Obtain target save directory for JSON metrics from current save directory
    current_dir = os.path.dirname(os.path.abspath(__file__))
    target_dir = os.path.normpath(
        os.path.join(current_dir, "../../.userbenchmark/torch_trt/")
    )

    os.makedirs(target_dir, exist_ok=True)

    # Format filename and path to save metrics
    metrics_file = "metrics-{}.json".format(
        datetime.fromtimestamp(time.time()).strftime("%Y%m%d%H%M%S")
    )
    metrics_save_path = os.path.join(target_dir, metrics_file)

    with open(metrics_save_path, "w") as f:
        json.dump(metrics_json, f, indent=4)


def run_single_model(
    Model,
    batch_size: int,
    extra_args: List[str],
    selected_ir: str,
    num_warmup: int,
    num_iter: int,
):
    """Run inference benchmarking on a single model"""
    # Build TorchBench model instance, with backend having the userbenchmark name
    # This invokes the torch_trt backend functionality directly
    model = Model(
        device="cuda",
        test="eval",
        jit=False,
        batch_size=batch_size,
        extra_args=[
            "--backend",
        ]
        + extra_args,
    )

    metrics = run_one_step(model.invoke, model, num_warmup, num_iter, selected_ir)

    # Print dynamo compilation metrics, if there are any.
    try:
        if model.pt2_compilation_time:
            metrics[
                f"{model.name}.bs_{model.batch_size}.precision_{model.dargs.precision}."
                + f"ir_{selected_ir}.pt2_compilation_time"
            ] = model.pt2_compilation_time
    except:
        pass

    return metrics


def run_one_step(func, model, num_warmup, num_iter, selected_ir):
    # Warmup model inference
    for _ in range(num_warmup):
        func()

    result_summary = []

    # Run inference for the specified number of iterations
    for _ in range(num_iter):
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        # Collect time_ns() instead of time() which does not provide better precision than 1
        # second according to https://docs.python.org/3/library/time.html#time.time.
        t0 = time.time_ns()
        start_event.record()
        func()
        end_event.record()
        torch.cuda.synchronize()
        t1 = time.time_ns()
        result_summary.append(
            (start_event.elapsed_time(end_event), (t1 - t0) / 1_000_000)
        )

    # Get median times for GPU and CPU Walltime
    gpu_time = np.median(list(map(lambda x: x[0], result_summary)))
    cpu_walltime = np.median(list(map(lambda x: x[1], result_summary)))

    if hasattr(model, "NUM_BATCHES"):
        median_gpu_time_per_batch = gpu_time / model.NUM_BATCHES
        median_cpu_walltime_per_batch = cpu_walltime / model.NUM_BATCHES
    else:
        median_gpu_time_per_batch = gpu_time
        median_cpu_walltime_per_batch = cpu_walltime

    metrics = {
        f"{model.name}.bs_{model.batch_size}.precision_{model.dargs.precision}."
        + f"ir_{selected_ir}.median_gpu_time_per_batch": median_gpu_time_per_batch,
        f"{model.name}.bs_{model.batch_size}.precision_{model.dargs.precision}."
        + f"ir_{selected_ir}.median_cpu_walltime_per_batch": median_cpu_walltime_per_batch,
    }

    return metrics


def run(args: List[str]):
    """Run inference and extract requested metrics"""
    parsed_args, unknown_args = cli(args)

    # Attempt to extract specified IR for logging purposes
    try:
        ir_idx = unknown_args.index("--ir")
        selected_ir = unknown_args[ir_idx + 1]
    except (ValueError, IndexError):
        selected_ir = "default"

    # Parse model string if specified, otherwise run all models
    # Adapted from benchmark/run.py
    if parsed_args["model"]:
        try:
            Model = load_model_by_name(parsed_args["model"])
        except ModuleNotFoundError:
            traceback.print_exc()
            exit(-1)
        except ModelNotFoundError:
            print(
                f"Warning: The model {parsed_args['model']} cannot be found at core set."
            )
        if not Model:
            try:
                Model = load_canary_model_by_name(parsed_args["model"])
            except ModuleNotFoundError:
                traceback.print_exc()
                exit(-1)
            except ModelNotFoundError:
                print(
                    f"Error: The model {parsed_args['model']} cannot be found at either core or canary model set."
                )
                exit(-1)

        all_metrics = run_single_model(
            Model,
            parsed_args["bs"],
            unknown_args,
            selected_ir,
            parsed_args["num_warmup"],
            parsed_args["num_iter"],
        )

    else:
        all_metrics = {}

        for Model in list_models():
            metrics = run_single_model(
                Model,
                parsed_args["bs"],
                unknown_args,
                selected_ir,
                parsed_args["num_warmup"],
                parsed_args["num_iter"],
            )
            all_metrics = {**all_metrics, **metrics}

    save_metrics(all_metrics)
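Assuming TorchBench's standard run_benchmark.py entry point dispatches to this userbenchmark's run() function (the entry-point name and the flag values below are illustrative, not taken from this diff), a run might look like:

    # Benchmark a single model with the Torch-TRT backend; flags that cli() does not
    # recognize, such as --ir, are forwarded to the backend via the unknown-args passthrough
    python run_benchmark.py torch_trt --model resnet18 --bs 4 --num_warmup 10 --num_iter 100 --ir dynamo_compile

save_metrics() then writes the median per-batch GPU time and CPU wall time (and, when available, pt2_compilation_time) to a timestamped metrics-<YYYYmmddHHMMSS>.json file under .userbenchmark/torch_trt/.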

0 commit comments
