backends.py
from __future__ import annotations

import functools
import logging
import unittest.mock
from typing import Any, Callable, Sequence

import torch
import torch._dynamo as td
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import aot_export_joint_simple
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo._compiler import compile_module
from torch_tensorrt.dynamo.lowering import (
    get_decompositions,
    modify_reshape_complex_nodes,
    post_lowering,
    remove_detach,
    remove_sym_nodes,
    repair_input_aliasing,
)
from torch_tensorrt.dynamo.utils import (
    parse_dynamo_kwargs,
    prepare_inputs,
    set_log_level,
)

logger = logging.getLogger(__name__)


@td.register_backend(name="tensorrt")  # type: ignore[misc]
@td.register_backend(name="torch_tensorrt")  # type: ignore[misc]
def torch_tensorrt_backend(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
) -> torch.nn.Module:
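    """Entry point for the ``tensorrt`` / ``torch_tensorrt`` torch.compile backends.

    Illustrative usage (a minimal sketch; ``model`` stands in for any
    ``torch.nn.Module``, and the option shown is the ``debug`` flag handled
    just below):

        trt_model = torch.compile(model, backend="tensorrt", options={"debug": True})
    """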
    # Set log level at the top of compilation (torch_tensorrt.dynamo)
    if (
        "options" in kwargs
        and "debug" in kwargs["options"]
        and kwargs["options"]["debug"]
    ) or ("debug" in kwargs and kwargs["debug"]):
        set_log_level(logger.parent, logging.DEBUG)

    DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)


@td.register_backend(name="aot_torch_tensorrt_aten")  # type: ignore[misc]
def aot_torch_tensorrt_aten_backend(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
) -> torch.nn.Module:
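    """AOT backend registered as ``aot_torch_tensorrt_aten``.

    When ``settings.use_aot_joint_export`` is enabled, the graph goes straight
    to ``_pretraced_backend``, which runs ``aot_export_joint_simple`` itself;
    otherwise this backend is wrapped with ``aot_autograd`` so decompositions
    are applied before ``_pretraced_backend`` receives the graph.
    """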
    settings, engine_cache = parse_dynamo_kwargs(kwargs)

    if settings.use_aot_joint_export:
        return _pretraced_backend(gm, sample_inputs, settings, engine_cache)

    logger.debug("Wrapping the backend with aot_autograd\n")
    _pretraced_backend_autograd = functools.partial(
        _pretraced_backend, settings=settings, engine_cache=engine_cache
    )

    settings_aot_autograd = {}
    settings_aot_autograd["decompositions"] = get_decompositions(
        settings.enable_experimental_decompositions
    )
    # The detach decomposition is removed since lowering detach produces alias
    # nodes, which fail with "View operation returned a tensor that is the
    # same as the input base tensor" (detach is one of the nop decompositions
    # in torch/_decomp/decompositions.py).
    detach_op = torch.ops.aten.detach.default
    if detach_op in settings_aot_autograd["decompositions"]:
        del settings_aot_autograd["decompositions"][detach_op]

    # Pass the filtered table so the detach removal above takes effect.
    return aot_autograd(
        fw_compiler=_pretraced_backend_autograd,
        decompositions=settings_aot_autograd["decompositions"],
    )(gm, sample_inputs)


def _pretraced_backend(
    gm: torch.fx.GraphModule,
    sample_inputs: Sequence[Any],
    settings: CompilationSettings = CompilationSettings(),
    engine_cache: Any = None,
) -> torch.fx.GraphModule | Callable[..., Any]:
    """Helper function to manage translation of traced FX module to TRT engines

    Args:
        gm: FX GraphModule to convert
        sample_inputs: Inputs to the module
        settings: Compilation settings
        engine_cache: Engine cache instance

    Returns:
        Compiled FX GraphModule, or the unmodified input module if compilation
        fails and pass_through_build_failures is disabled
    """
    try:
        logger.debug("Pre-AOT Autograd graph:\n" + str(gm.graph))

        fake_mode = detect_fake_mode(sample_inputs)

        # Place backend tracing within FakeTensor context allowing nonfake Tensors
        with unittest.mock.patch.object(
            fake_mode, "allow_non_fake_inputs", True
        ), fake_mode:
            repair_input_aliasing(gm, settings)

            # Remove sym_int placeholders and inputs
            remove_sym_nodes(gm, sample_inputs, settings)
            torch_inputs = [
                input for input in sample_inputs if isinstance(input, torch.Tensor)
            ]

            # Remove detach nodes
            remove_detach(gm, settings)
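            # Re-encode each complex64 input as a real tensor with its real and
            # imaginary parts stacked along a new trailing dimension, recording
            # the input's index so modify_reshape_complex_nodes can patch the
            # corresponding graph nodes further below.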
            complexInputIndices = []
            for i, torch_input in enumerate(torch_inputs):
                if torch_input.dtype == torch.complex64:
                    complexInputIndices.append(i)
                    torch_inputs[i] = torch.stack(
                        (torch_input.real, torch_input.imag), dim=-1
                    )
            # Invoke AOTAutograd to translate operators to aten
            if settings.use_aot_joint_export:
                gm = aot_export_joint_simple(
                    gm,
                    sample_inputs,
                    trace_joint=False,
                    decompositions=get_decompositions(
                        settings.enable_experimental_decompositions
                    ),
                )

            logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

            gm = post_lowering(gm, settings)
            logger.debug("Lowered Input graph:\n " + str(gm.graph))

            if complexInputIndices:
                modify_reshape_complex_nodes(gm, complexInputIndices)
                logger.debug(
                    "Input graph after modifying complex nodes:\n " + str(gm.graph)
                )
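            # Wrap the (possibly re-encoded) sample tensors as Torch-TensorRT
            # inputs for compile_module.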
            torchtrt_inputs = prepare_inputs(
                torch_inputs, disable_memory_format_check=True
            )

            if settings.require_full_compilation:
                logger.warning(
                    "require_full_compilation arg is not applicable for "
                    "torch.compile with backend='torch_tensorrt'"
                )
            if settings.strip_engine_weights:
                logger.error(
                    "strip_engine_weights arg is not supported for torch.compile()"
                )
            trt_compiled = compile_module(
                gm,
                torchtrt_inputs,
                settings=settings,
                engine_cache=engine_cache,
            )
            return trt_compiled
    except (AssertionError, RuntimeError):
        if not settings.pass_through_build_failures:
            logger.warning(
                "TRT conversion failed on the subgraph. See trace above. "
                + "Returning GraphModule forward instead.",
                exc_info=True,
            )
            return gm
        else:
            logger.critical(
                "Halting compilation on build failure since "
                + "pass_through_build_failures was specified as True. "
                + "To return the default Torch implementation and avoid "
                + "halting compilation on engine build failures, "
                + "specify pass_through_build_failures=False."
            )
            raise