Skip to content

Commit e4e02e1

Browse files
author
Wei
authored
Merge pull request #1107 from pytorch/fb-sync-wwei6
[fx2trt] Modify lower setting class to accommodate AIT lowering
2 parents 2e09ce5 + 4df1d24 commit e4e02e1

File tree

1 file changed

+30
-35
lines changed

1 file changed

+30
-35
lines changed

py/torch_tensorrt/fx/lower_setting.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,65 +10,64 @@
1010

1111

1212
@dc.dataclass
13-
class LowerSetting:
13+
class LowerSettingBasic:
14+
"""
15+
Basic class for lowering.
16+
max_batch_size: The maximum batch size for lowering job.
17+
If run with TensorRT lowering, this is the maximum
18+
batch size which can be used at execution time,
19+
and also the batch size for which the ICudaEngine
20+
will be optimized.
21+
If run with AITemplate lowering, this is the max batch_size
22+
for the model.
23+
lower_precision: lower precision dtype during lowering.
24+
min_acc_module_size(int): minimal number of nodes for an accelerated submodule.
25+
ast_rewriter_allow_list (Optional[Set[nn.Module]]): Optional allow list of
26+
modules that need AST rewriting. This aims to eliminate input variables involved in
27+
exception checking control flow.
28+
leaf_module_list (Optional[Set[nn.Module]]): Optional leaf module list where
29+
modules will not be traced into.
30+
verbose_profile (bool): verbosity of profiler, default to False.
1431
"""
15-
Basic configuration for lowering stack.
1632

17-
Args:
18-
max_batch_size: The maximum batch size which can be used at execution time,
19-
and also the batch size for which the ICudaEngine will be optimized.
33+
max_batch_size: int = 2048
34+
lower_precision: LowerPrecision = LowerPrecision.FP32
35+
min_acc_module_size: int = 10
36+
ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None
37+
leaf_module_list: Optional[Set[Type[nn.Module]]] = None
38+
verbose_profile: bool = False
2039

40+
41+
@dc.dataclass
42+
class LowerSetting(LowerSettingBasic):
43+
"""
44+
Basic configuration for lowering stack.
45+
Args:
2146
input_specs: Specs for inputs to engine, can either be a single size or a
2247
range defined by Min, Optimal, Max sizes.
23-
2448
explicit_batch_dimension: Use explicit batch dimension during lowering.
25-
2649
explicit_precision: Use explicit precision during lowering.
27-
28-
lower_precision: lower precision dtype during lowering.
29-
3050
max_workspace_size: The maximum workspace size. The maximum GPU temporary
3151
memory which the TensorRT engine can use at execution time.
32-
3352
strict_type_constraints: Require TensorRT engine to strictly follow data type
3453
setting at execution time.
35-
3654
customized_fuse_pass: List of customized passes to apply during the lowering process.
37-
3855
lower_basic_fuse_pass: Enable basic pass fuse during lowering, i.e. fuse multiple operations
3956
as (a->b->c->d)=>(e). Current basic fuse patterns are:
4057
permute->linear
4158
permute->matmul
42-
4359
verbose_log: Enable TensorRT engine verbose log mode.
44-
4560
algo_selector: Enable TensorRT algorithm selector at execution time.
46-
4761
timing_cache_prefix: TensorRT timing cache file path. TensorRT engine will use timing
4862
cache file at execution time if valid timing cache file is provided.
49-
5063
save_timing_cache: Save updated timing cache data into timing cache file if the timing
5164
cache file is provided.
52-
53-
ast_rewriter_allow_list (Optional[Set[nn.Module]]): Optional allow list of
54-
modules that need AST rewriting. This aims to eliminate input variables involved in
55-
exception checking control flow.
56-
57-
leaf_module_list (Optional[Set[nn.Module]]): Optional leaf module list where
58-
modules will not be traced into.
59-
6065
cuda_graph_batch_size (int): Cuda graph batch size, default to be -1.
61-
62-
verbose_profile (bool): verbosity of profiler, default to False.
63-
64-
min_acc_module_size(int): minimal number of nodes for an accelerated submodule.
6566
"""
6667

67-
max_batch_size: int = 2048
6868
input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
6969
explicit_batch_dimension: bool = True
7070
explicit_precision: bool = False
71-
lower_precision: LowerPrecision = LowerPrecision.FP32
7271
max_workspace_size: int = 1 << 30
7372
strict_type_constraints: bool = False
7473
customized_fuse_pass: PassManager = PassManager.build_from_passlist([])
@@ -79,8 +78,4 @@ class LowerSetting:
7978
algo_selector = None
8079
timing_cache_prefix: str = ""
8180
save_timing_cache: bool = False
82-
ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None
83-
leaf_module_list: Optional[Set[Type[nn.Module]]] = None
8481
cuda_graph_batch_size: int = -1
85-
verbose_profile: bool = False
86-
min_acc_module_size: int = 10

0 commit comments

Comments
 (0)