from typing import List
import torch
+import argparse

from torchbenchmark.util.backends import create_backend
from torchbenchmark.util.env_check import is_hf_model

+def parse_torch_trt_args(backend_args: List[str]):
+    """Parse CLI-provided backend arguments to extract Torch-TRT keyword arguments.
+
+    Returns a kwargs dictionary and the remaining arguments that were not recognized.
+    """
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument("--truncate_long_and_double", default=False, action="store_true")
+    arg_parser.add_argument("--workspace_size", type=int)
+    arg_parser.add_argument("--min_block_size", type=int)
+    arg_parser.add_argument("--ir", type=str, default="ts")
+    args, unknown = arg_parser.parse_known_args(backend_args)
+
+    return vars(args), unknown
+
@create_backend
def fx2trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
    FP16 = True if model.dargs.precision == "fp16" else False
@@ -40,18 +55,37 @@ def _fx2trt():

@create_backend
def torch_trt(model: 'torchbenchmark.util.model.BenchmarkModel', backend_args: List[str]):
+    """Backend for Torch-TRT
+
+    Can be invoked directly from the command line, for example via:
+    python run.py resnet18 -d cuda -t eval --backend torch_trt --precision fp32 --truncate_long_and_double
+
+    Options include:
+    --truncate_long_and_double: Whether to automatically truncate long and double operations
+    --min_block_size: Minimum number of operations in an accelerated TRT block
+    --workspace_size: Size of the workspace allotted to TensorRT
+    --ir: Which intermediate representation to use: {"ts", "dynamo_compile", "fx_ts_compat", ...}
+    """
    FP16 = True if model.dargs.precision == "fp16" else False
-    assert model.device == "cuda" and model.test == "eval", f"fx2trt only works on CUDA inference tests."
+    assert model.device == "cuda" and model.test == "eval", "Torch-TRT only works on CUDA inference tests."
+
+    # Extract the relevant Torch-TRT arguments from the provided CLI arguments
+    torch_trt_kwargs, backend_args = parse_torch_trt_args(backend_args)
+
    def _torch_trt():
+        """Helper function for invoking Torch-TRT"""
        import torch_tensorrt
        module, example_inputs = model.get_module()
-        if FP16:
-            torchtrt_dtype = torch_tensorrt.dtype.half
-            torch_dtype = torch.half
-        else:
-            torchtrt_dtype = torch_tensorrt.dtype.float
-            torch_dtype = torch.float32
-        trt_input = [torch_tensorrt.Input(shape=example_inputs[0].shape, dtype=torch_dtype)]
-        trt_module = torch_tensorrt.compile(module, inputs=trt_input, enabled_precisions=torchtrt_dtype)
+        torch_dtype_precision = torch.half if FP16 else torch.float32
+
+        # Build one TRT input spec per example input, preserving its shape and dtype
+        trt_input = [torch_tensorrt.Input(shape=input_.shape, dtype=input_.dtype)
+                     for input_ in example_inputs]
+
+        trt_module = torch_tensorrt.compile(module,
+                                            inputs=trt_input,
+                                            enabled_precisions={torch_dtype_precision},
+                                            **torch_trt_kwargs)
        model.set_module(trt_module)
+
    return _torch_trt, backend_args
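A quick illustration of the argument handling introduced above (a self-contained sketch using only the standard library; the argument list and the unrelated flag are made-up inputs, not part of the PR). parse_known_args splits the backend arguments into recognized Torch-TRT keywords and an unrecognized remainder, which the backend hands back to the rest of the harness:

import argparse

arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--truncate_long_and_double", default=False, action="store_true")
arg_parser.add_argument("--workspace_size", type=int)
arg_parser.add_argument("--min_block_size", type=int)
arg_parser.add_argument("--ir", type=str, default="ts")

# Recognized flags land in args; anything else is returned untouched
args, unknown = arg_parser.parse_known_args(
    ["--truncate_long_and_double", "--ir", "ts", "--unrelated_flag"])

print(vars(args))
# {'truncate_long_and_double': True, 'workspace_size': None, 'min_block_size': None, 'ir': 'ts'}
print(unknown)
# ['--unrelated_flag']

Note that options left unset (workspace_size and min_block_size here) remain None in the kwargs dictionary and are forwarded to torch_tensorrt.compile as-is via **torch_trt_kwargs.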
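For context on the compile call itself, here is a minimal, hypothetical sketch of the same flow outside the benchmark harness (assumes a CUDA GPU with torch and torch_tensorrt installed; the toy module, input, and chosen keyword values are placeholders mirroring the diff, not part of the PR):

import torch
import torch_tensorrt

# Placeholder module and example inputs standing in for model.get_module()
module = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval().cuda()
example_inputs = (torch.randn(4, 8, device="cuda"),)

# One Input spec per example input, preserving shape and dtype as in the diff
trt_input = [torch_tensorrt.Input(shape=input_.shape, dtype=input_.dtype)
             for input_ in example_inputs]

trt_module = torch_tensorrt.compile(
    module,
    inputs=trt_input,
    enabled_precisions={torch.float32},  # {torch.half} when --precision fp16
    ir="ts",                             # from --ir; the TorchScript frontend
    truncate_long_and_double=True,       # from --truncate_long_and_double
)
print(trt_module(*example_inputs).shape)  # torch.Size([4, 8])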