from dataclasses import dataclass, field
from typing import Collection, Optional, Set, Tuple, Union
from torch.fx.node import Target
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt.dynamo._defaults import (
    ASSUME_DYNAMIC_SHAPE_SUPPORT,
    CACHE_BUILT_ENGINES,
    DEBUG,
    DISABLE_TF32,
    DLA_GLOBAL_DRAM_SIZE,
    DLA_LOCAL_DRAM_SIZE,
    DLA_SRAM_SIZE,
    DRYRUN,
    ENABLE_CROSS_COMPILE_FOR_WINDOWS,
    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
    ENABLE_WEIGHT_STREAMING,
    ENABLED_PRECISIONS,
    ENGINE_CAPABILITY,
    HARDWARE_COMPATIBLE,
    IMMUTABLE_WEIGHTS,
    LAZY_ENGINE_INIT,
    MAX_AUX_STREAMS,
    MIN_BLOCK_SIZE,
    NUM_AVG_TIMING_ITERS,
    OPTIMIZATION_LEVEL,
    PASS_THROUGH_BUILD_FAILURES,
    REFIT_IDENTICAL_ENGINE_WEIGHTS,
    REQUIRE_FULL_COMPILATION,
    REUSE_CACHED_ENGINES,
    SPARSE_WEIGHTS,
    STRIP_ENGINE_WEIGHTS,
    TIMING_CACHE_PATH,
    TRUNCATE_DOUBLE,
    USE_AOT_JOINT_EXPORT,
    USE_EXPLICIT_TYPING,
    USE_FAST_PARTITIONER,
    USE_FP32_ACC,
    USE_PYTHON_RUNTIME,
    VERSION_COMPATIBLE,
    WORKSPACE_SIZE,
    default_device,
)


@dataclass
class CompilationSettings:
"""Compilation settings for Torch-TensorRT Dynamo Paths
Args:
enabled_precisions (Set[dtype]): Available kernel dtype precisions
debug (bool): Whether to print out verbose debugging information
workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
min_block_size (int): Minimum number of operators per TRT-Engine Block
torch_executed_ops (Collection[Target]): Collection of operations to run in Torch, regardless of converter coverage
pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
version_compatible (bool): Provide version forward-compatibility for engine plan files
optimization_level (Optional[int]): Builder optimization 0-5, higher levels imply longer build time,
searching for more optimization options. TRT defaults to 3
use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
argument as None
truncate_double (bool): Whether to truncate float64 TRT engine inputs or weights to float32
use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
or only a selected subset of them
device (Device): GPU to compile the model on
require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
sparse_weights (bool): Whether to allow the builder to use sparse weights
engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
dryrun (Union[bool, str]): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to
TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the
output to a file if a string path is specified
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
refit_identical_engine_weights (bool): Whether to refit the engine with identical weights
strip_engine_weights (bool): Whether to strip the engine weights
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored
enable_weight_streaming (bool): Enable weight streaming.
enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built.
True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
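
    Example (an illustrative sketch; the field values below are arbitrary)::

        settings = CompilationSettings(
            enabled_precisions={dtype.f16, dtype.f32},
            min_block_size=5,
            truncate_double=True,
        )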
"""
    enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
    debug: bool = DEBUG
    workspace_size: int = WORKSPACE_SIZE
    min_block_size: int = MIN_BLOCK_SIZE
    torch_executed_ops: Collection[Target] = field(default_factory=set)
    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
    max_aux_streams: Optional[int] = MAX_AUX_STREAMS
    version_compatible: bool = VERSION_COMPATIBLE
    optimization_level: Optional[int] = OPTIMIZATION_LEVEL
    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
    truncate_double: bool = TRUNCATE_DOUBLE
    use_fast_partitioner: bool = USE_FAST_PARTITIONER
    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
    device: Device = field(default_factory=default_device)
    require_full_compilation: bool = REQUIRE_FULL_COMPILATION
    disable_tf32: bool = DISABLE_TF32
    assume_dynamic_shape_support: bool = ASSUME_DYNAMIC_SHAPE_SUPPORT
    sparse_weights: bool = SPARSE_WEIGHTS
    engine_capability: EngineCapability = field(
        default_factory=lambda: ENGINE_CAPABILITY
    )
    num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS
    dla_sram_size: int = DLA_SRAM_SIZE
    dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE
    dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE
    dryrun: Union[bool, str] = DRYRUN
    hardware_compatible: bool = HARDWARE_COMPATIBLE
    timing_cache_path: str = TIMING_CACHE_PATH
    lazy_engine_init: bool = LAZY_ENGINE_INIT
    cache_built_engines: bool = CACHE_BUILT_ENGINES
    reuse_cached_engines: bool = REUSE_CACHED_ENGINES
    use_explicit_typing: bool = USE_EXPLICIT_TYPING
    use_fp32_acc: bool = USE_FP32_ACC
    refit_identical_engine_weights: bool = REFIT_IDENTICAL_ENGINE_WEIGHTS
    strip_engine_weights: bool = STRIP_ENGINE_WEIGHTS
    immutable_weights: bool = IMMUTABLE_WEIGHTS
    enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
    enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
    use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT


# Settings in this tuple must match between two CompilationSettings instances
# for their cached TRT engines to be considered interchangeable.
_SETTINGS_TO_BE_ENGINE_INVARIANT = (
    "enabled_precisions",
    "max_aux_streams",
    "version_compatible",
    "optimization_level",
    "disable_tf32",
    "sparse_weights",
    "engine_capability",
    "hardware_compatible",
    "refit_identical_engine_weights",
    "strip_engine_weights",  # TODO: @Evan to remove this after implementing caching weight-stripped engines as default?
    "immutable_weights",
    "enable_weight_streaming",
)


def settings_are_compatible(
    set_a: CompilationSettings, set_b: CompilationSettings
) -> Tuple[bool, Set[str]]:
    """Check whether two settings objects are compatible for engine reuse.

    Returns a tuple of (are_compatible, names of the incompatible settings).
    """
    incompatible_settings: Set[str] = set()
    for f in _SETTINGS_TO_BE_ENGINE_INVARIANT:
        if getattr(set_a, f) != getattr(set_b, f):
            incompatible_settings.add(f)
    if len(incompatible_settings) == 0:
        return True, set()
    else:
        return False, incompatible_settings
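

# A minimal usage sketch (illustrative only, not part of the library API):
# build two settings objects that differ in one engine-invariant field and
# check whether their engines would be interchangeable.
if __name__ == "__main__":
    settings_a = CompilationSettings(optimization_level=3)
    settings_b = CompilationSettings(optimization_level=5)
    compatible, mismatched = settings_are_compatible(settings_a, settings_b)
    print(compatible, mismatched)  # -> False {'optimization_level'}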