
[fx2trt] Modify lower setting class #1107


Merged: 9 commits, Jun 9, 2022
65 changes: 30 additions & 35 deletions py/torch_tensorrt/fx/lower_setting.py
@@ -10,65 +10,64 @@


@dc.dataclass
class LowerSetting:
class LowerSettingBasic:
"""
Basic class for lowering.
max_batch_size: The maximum batch size for the lowering job.
If run with TensorRT lowering, this is the maximum
batch size which can be used at execution time,
and also the batch size for which the ICudaEngine
will be optimized.
If run with AITemplate lowering, this is the max batch_size
for the model.
lower_precision: lower precision dtype during lowering.
min_acc_module_size (int): minimum number of nodes for an accelerated submodule.
ast_rewriter_allow_list (Optional[Set[nn.Module]]): Optional allow list of
modules that need AST rewriting. This aims to eliminate input variables
involved in exception-checking control flow.
leaf_module_list (Optional[Set[nn.Module]]): Optional leaf module list;
modules in it will not be traced into.
verbose_profile (bool): verbosity of profiler, defaults to False.
"""
"""
Basic configuration for lowering stack.

Args:
max_batch_size: The maximum batch size which can be used at execution time,
and also the batch size for which the ICudaEngine will be optimized.
"""
max_batch_size: int = 2048
lower_precision: LowerPrecision = LowerPrecision.FP32
min_acc_module_size: int = 10
ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None
leaf_module_list: Optional[Set[Type[nn.Module]]] = None
verbose_profile: bool = False


@dc.dataclass
class LowerSetting(LowerSettingBasic):
"""
Basic configuration for lowering stack.
Args:
input_specs: Specs for inputs to the engine; each can either be a single size
or a range defined by Min, Optimal, and Max sizes.

explicit_batch_dimension: Use explicit batch dimension during lowering.

explicit_precision: Use explicit precision during lowering.

lower_precision: lower precision dtype during lowering.

max_workspace_size: The maximum workspace size. The maximum GPU temporary
memory which the TensorRT engine can use at execution time.

strict_type_constraints: Require TensorRT engine to strictly follow data type
setting at execution time.

customized_fuse_pass: List of customized passes to apply during the lowering process.

lower_basic_fuse_pass: Enable basic fusion passes during lowering, i.e. fuse multiple
operations as (a->b->c->d)=>(e). Current basic fuse patterns are:
permute->linear
permute->matmul

verbose_log: Enable TensorRT engine verbose log mode.

algo_selector: Enable TensorRT algorithm selector at execution time.

timing_cache_prefix: TensorRT timing cache file path. The TensorRT engine will use
the timing cache file at execution time if a valid one is provided.

save_timing_cache: Save updated timing cache data into the timing cache file
if one is provided.

ast_rewriter_allow_list (Optional[Set[nn.Module]]): Optional allow list of
modules that need AST rewriting. This aims to eliminate input variables
involved in exception-checking control flow.

leaf_module_list (Optional[Set[nn.Module]]): Optional leaf module list;
modules in it will not be traced into.

cuda_graph_batch_size (int): CUDA graph batch size, defaults to -1.

verbose_profile (bool): verbosity of profiler, defaults to False.

min_acc_module_size (int): minimum number of nodes for an accelerated submodule.
"""

max_batch_size: int = 2048
input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
explicit_batch_dimension: bool = True
explicit_precision: bool = False
lower_precision: LowerPrecision = LowerPrecision.FP32
max_workspace_size: int = 1 << 30
strict_type_constraints: bool = False
customized_fuse_pass: PassManager = PassManager.build_from_passlist([])
@@ -79,8 +78,4 @@ class LowerSetting:
algo_selector = None
timing_cache_prefix: str = ""
save_timing_cache: bool = False
ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None
leaf_module_list: Optional[Set[Type[nn.Module]]] = None
cuda_graph_batch_size: int = -1
verbose_profile: bool = False
min_acc_module_size: int = 10
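The split above follows the standard dataclass-inheritance pattern: every field on the base class has a default, so the subclass can legally append its own defaulted fields. A minimal self-contained sketch of the resulting shape (TensorRT-specific types such as `LowerPrecision`, `InputTensorSpec`, and `PassManager` are replaced with stdlib placeholders, so this is an illustration of the pattern, not the library code):

```python
import dataclasses as dc
from typing import Any, List, Optional, Set


@dc.dataclass
class LowerSettingBasic:
    """Fields shared by all lowering backends (per the PR's base class)."""
    max_batch_size: int = 2048
    lower_precision: str = "fp32"  # placeholder for the LowerPrecision enum
    min_acc_module_size: int = 10
    ast_rewriter_allow_list: Optional[Set[type]] = None
    leaf_module_list: Optional[Set[type]] = None
    verbose_profile: bool = False


@dc.dataclass
class LowerSetting(LowerSettingBasic):
    """TensorRT-specific settings; inherits every basic field above."""
    input_specs: List[Any] = dc.field(default_factory=list)  # placeholder for InputTensorSpec
    explicit_batch_dimension: bool = True
    explicit_precision: bool = False
    max_workspace_size: int = 1 << 30
    strict_type_constraints: bool = False
    timing_cache_prefix: str = ""
    save_timing_cache: bool = False
    cuda_graph_batch_size: int = -1


# A subclass instance carries both inherited and subclass fields:
s = LowerSetting(max_batch_size=512, explicit_precision=True)
print(s.max_batch_size, s.min_acc_module_size, s.explicit_precision)  # → 512 10 True
```

Note the `dc.field(default_factory=list)` on `input_specs`: a plain `= []` default would be shared across all instances, which is why dataclasses reject mutable defaults and the PR uses a factory there as well.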