from typing import List
import torch
+import argparse

from torchbenchmark.util.backends import create_backend
from torchbenchmark.util.env_check import is_hf_model

+def parse_torch_trt_args(backend_args: List[str]):
+    """Parses CLI-provided backend arguments to extract Torch-TRT keywords
+
+    Returns a kwargs dictionary and the remaining unrecognized arguments
+    """
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument("--truncate_long_and_double", default=None, action="store_true")
+    arg_parser.add_argument("--workspace_size", type=int)
+    arg_parser.add_argument("--min_block_size", type=int)
+    arg_parser.add_argument("--ir", type=str)
+    args, unknown = arg_parser.parse_known_args(backend_args)
+
+    # Remove unspecified arguments from the args dictionary
+    # (only pass through user-specified args)
+    parsed_args = vars(args)
+    for key in list(parsed_args.keys()):
+        if parsed_args[key] is None:
+            del parsed_args[key]
+
+    return parsed_args, unknown
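+
+# Illustrative example (hypothetical flags, relying on standard
+# argparse.parse_known_args behavior):
+#   parse_torch_trt_args(["--ir", "ts", "--unknown_flag"])
+# returns ({"ir": "ts"}, ["--unknown_flag"]); options left unspecified are dropped.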
+
@create_backend
def fx2trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
    FP16 = True if model.dargs.precision == "fp16" else False
@@ -40,18 +62,37 @@ def _fx2trt():

@create_backend
def torch_trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
| 65 | + """Backend for Torch-TRT |
| 66 | +
|
| 67 | + Can be directly invoked from the command line, for example via: |
| 68 | + python run.py resnet18 -d cuda -t eval --backend torch_trt --precision fp32 --truncate_long_and_double |
| 69 | +
|
| 70 | + Options include: |
| 71 | + --truncate_long_and_double: Whether to automatically truncate long and double operations |
| 72 | + --min_block_size: Minimum number of operations in an accelerated TRT block |
| 73 | + --workspace_size: Size of workspace allotted to TensorRT |
| 74 | + --ir: Which internal representation to use: {"ts", "dynamo_compile", "fx_ts_compat", ...} |
| 75 | + """ |
    FP16 = True if model.dargs.precision == "fp16" else False
-    assert model.device == "cuda" and model.test == "eval", f"fx2trt only works on CUDA inference tests."
+    assert model.device == "cuda" and model.test == "eval", "Torch-TRT only works on CUDA inference tests."
+
+    # Extract relevant Torch-TRT arguments from the provided CLI arguments
+    torch_trt_kwargs, backend_args = parse_torch_trt_args(backend_args)
+
    def _torch_trt():
+        """Helper function for invoking Torch-TRT"""
        import torch_tensorrt
        module, example_inputs = model.get_module()
-        if FP16:
-            torchtrt_dtype = torch_tensorrt.dtype.half
-            torch_dtype = torch.half
-        else:
-            torchtrt_dtype = torch_tensorrt.dtype.float
-            torch_dtype = torch.float32
-        trt_input = [torch_tensorrt.Input(shape=example_inputs[0].shape, dtype=torch_dtype)]
-        trt_module = torch_tensorrt.compile(module, inputs=trt_input, enabled_precisions=torchtrt_dtype)
+        torch_dtype_precision = torch.half if FP16 else torch.float32
+
+        # Build one torch_tensorrt.Input spec per example input, preserving
+        # each tensor's shape and dtype (rather than only the first input)
+        trt_input = [torch_tensorrt.Input(shape=input_.shape, dtype=input_.dtype)
+                     for input_ in example_inputs]
+
+        trt_module = torch_tensorrt.compile(module,
+                                            inputs=trt_input,
+                                            enabled_precisions={torch_dtype_precision},
+                                            **torch_trt_kwargs)
        model.set_module(trt_module)
+
    return _torch_trt, backend_args
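+
+# Minimal standalone sketch of what _torch_trt does (illustrative; assumes
+# torch_tensorrt is installed and `mod`/`x` are a hypothetical CUDA-resident
+# model and example input):
+#
+#   import torch, torch_tensorrt
+#   trt_mod = torch_tensorrt.compile(
+#       mod,
+#       inputs=[torch_tensorrt.Input(shape=x.shape, dtype=x.dtype)],
+#       enabled_precisions={torch.float32},
+#       truncate_long_and_double=True,
+#   )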