Skip to content

Commit fc74cff

Browse files
committed
changes to make Llama example work
1 parent bba4153 commit fc74cff

File tree

7 files changed

+282
-9
lines changed

7 files changed

+282
-9
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414
from torch_tensorrt.dynamo._compiler import compile_module
1515
from torch_tensorrt.dynamo.lowering import (
1616
get_decompositions,
17+
modify_complex_nodes,
1718
post_lowering,
1819
remove_detach,
1920
remove_sym_nodes,
2021
repair_input_aliasing,
22+
replace_complex_placeholder_to_tuple,
2123
)
2224
from torch_tensorrt.dynamo.utils import (
25+
find_complex_nodes,
2326
parse_dynamo_kwargs,
2427
prepare_inputs,
2528
set_log_level,
@@ -61,9 +64,15 @@ def aot_torch_tensorrt_aten_backend(
6164
settings_aot_autograd["decompostions"] = get_decompositions(
6265
settings.enable_experimental_decompositions
6366
)
64-
return aot_autograd(fw_compiler=_pretraced_backend_autograd)(
65-
gm, sample_inputs, **settings_aot_autograd
66-
)
67+
# This is added since detach lowering leads to alias nodes
68+
# Error - View operation returned a tensor that is the same as the input base tensor
69+
# torch nop_decompositions in torch/_decomp/decompositions.py
70+
if aten.detach in settings_aot_autograd["decompositions"]:
71+
del settings_aot_autograd["decompositions"][aten.detach]
72+
return aot_autograd(
73+
fw_compiler=_pretraced_backend_autograd,
74+
decompositions=get_decompositions(settings.enable_experimental_decompositions),
75+
)(gm, sample_inputs)
6776

6877

6978
def _pretraced_backend(
@@ -103,6 +112,16 @@ def _pretraced_backend(
103112
# Remove detach nodes
104113
remove_detach(gm, settings)
105114

115+
complexInputIndices = []
116+
for i, torch_input in enumerate(torch_inputs):
117+
if torch_inputs[i].dtype == torch.complex64:
118+
complexInputIndices.append(i)
119+
torch_input_real = torch_inputs[i].real
120+
torch_input_imaginary = torch_inputs[i].imag
121+
torch_inputs[i] = torch.stack(
122+
(torch_input_real, torch_input_imaginary), dim=-1
123+
)
124+
106125
# Invoke AOTAutograd to translate operators to aten
107126
if settings.use_aot_joint_export:
108127
gm = aot_export_joint_simple(
@@ -120,6 +139,11 @@ def _pretraced_backend(
120139

121140
logger.debug("Lowered Input graph:\n " + str(gm.graph))
122141

142+
complex_nodes = find_complex_nodes(gm)
143+
if complex_nodes:
144+
replace_complex_placeholder_to_tuple(gm, complexInputIndices)
145+
modify_complex_nodes(gm, complex_nodes)
146+
123147
torchtrt_inputs = prepare_inputs(
124148
torch_inputs, disable_memory_format_check=True
125149
)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
has_static_shapes_in_args,
1717
)
1818
from torch_tensorrt.dynamo.conversion.converter_utils import (
19+
args_bounds_check,
1920
enforce_tensor_types,
2021
get_positive_dim,
2122
is_only_operator_on_placeholder,
@@ -25,12 +26,6 @@
2526
_LOGGER: logging.Logger = logging.getLogger(__name__)
2627

2728

28-
def args_bounds_check(
29-
args: Tuple[Argument, ...], i: int, replacement: Optional[Any] = None
30-
) -> Any:
31-
return args[i] if len(args) > i and args[i] is not None else replacement
32-
33-
3429
def get_ir(target: Target) -> SourceIR:
3530
target_module = getattr(target, "__module__", "None")
3631
if any(

py/torch_tensorrt/dynamo/conversion/converter_utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -913,3 +913,9 @@ def set_layer_name(
913913
else f"{source_ir}_ops.{target.__name__}"
914914
)
915915
layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"
916+
917+
918+
def args_bounds_check(
919+
args: Tuple[Argument, ...], i: int, replacement: Optional[Any] = None
920+
) -> Any:
921+
return args[i] if len(args) > i and args[i] is not None else replacement
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
from ._aten_lowering_pass import *
2+
from ._modify_complex_nodes import modify_complex_nodes
3+
from ._replace_complex_placeholder_to_tuple import replace_complex_placeholder_to_tuple
24
from .remove_sym_nodes import remove_sym_nodes
35
from .repair_input_aliasing import repair_input_aliasing
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import logging
2+
3+
import torch
4+
5+
logger = logging.getLogger(__name__)
6+
7+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
8+
clean_up_graph_after_modifications,
9+
)
10+
11+
12+
def tensorrt_complex_mul(args0, args1):
13+
args0_real, args0_imag = torch.ops.aten.split.Tensor(args0, 1, -1)
14+
args1_real, args1_imag = torch.ops.aten.split.Tensor(args1, 1, -1)
15+
16+
args0_real = torch.ops.aten.squeeze.dim(args0_real, -1)
17+
args0_imag = torch.ops.aten.squeeze.dim(args0_imag, -1)
18+
args1_real = torch.ops.aten.squeeze.dim(args1_real, -1)
19+
args1_imag = torch.ops.aten.squeeze.dim(args1_imag, -1)
20+
21+
complex_mul_real = torch.ops.aten.sub(
22+
torch.ops.aten.mul(args0_real, args1_real),
23+
torch.ops.aten.mul(args0_imag, args1_imag),
24+
)
25+
complex_mul_imag = torch.ops.aten.add(
26+
torch.ops.aten.mul(args0_real, args1_imag),
27+
torch.ops.aten.mul(args0_imag, args1_real),
28+
)
29+
30+
return torch.ops.aten.stack((complex_mul_real, complex_mul_imag), -1)
31+
32+
33+
def remove_complex_real_view_nodes(gm: torch.fx.GraphModule):
34+
modified_graph = False
35+
nodes_to_remove = []
36+
for node in gm.graph.nodes:
37+
if "view_as_complex" in node.name or "view_as_real" in node.name:
38+
nodes_to_remove.append(node)
39+
40+
for node in nodes_to_remove:
41+
input_node = node.args[0] if node.args else None
42+
43+
for other_node in gm.graph.nodes:
44+
new_args = tuple(
45+
input_node if arg is node else arg for arg in other_node.args
46+
)
47+
other_node.args = new_args
48+
49+
gm.graph.erase_node(node)
50+
modified_graph = True
51+
52+
if modified_graph:
53+
gm = clean_up_graph_after_modifications(gm)
54+
logger.debug(
55+
f"Graph after removing view_as_complex nodes and view_as_real nodes:\n{gm.graph}"
56+
)
57+
58+
59+
def modify_reshape_nodes(gm: torch.fx.GraphModule, complex_nodes):
60+
for node in gm.graph.nodes:
61+
if node in complex_nodes:
62+
# slice and transpose will remain same
63+
if "reshape" in node.name:
64+
new_shape = list(node.args[1]) + [2]
65+
node.args = (node.args[0], tuple(new_shape))
66+
67+
68+
def modify_mul_nodes(gm: torch.fx.GraphModule, complex_nodes):
69+
modified_graph = False
70+
for node in gm.graph.nodes:
71+
if node in complex_nodes:
72+
if "mul" in node.name:
73+
complex_mul_args = (node.args[0], node.args[1])
74+
with gm.graph.inserting_after(node):
75+
replacement_node = gm.graph.create_node(
76+
op="call_function",
77+
target=tensorrt_complex_mul,
78+
args=complex_mul_args,
79+
)
80+
node.replace_all_uses_with(replacement_node)
81+
replacement_node.meta.update(node.meta)
82+
modified_graph = True
83+
gm.graph.erase_node(node)
84+
85+
if modified_graph:
86+
gm = clean_up_graph_after_modifications(gm)
87+
logger.debug(
88+
f"Graph after custom complex mul nodes is applied to the graph:\n{gm.graph}"
89+
)
90+
91+
92+
def modify_complex_nodes(gm: torch.fx.GraphModule, complex_nodes):
93+
modify_reshape_nodes(gm, complex_nodes)
94+
remove_complex_real_view_nodes(gm)
95+
modify_mul_nodes(gm, complex_nodes)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import logging
2+
3+
import torch
4+
from torch.fx.node import _get_qualified_name
5+
from torch_tensorrt.dynamo._settings import CompilationSettings
6+
from torch_tensorrt.dynamo.conversion.converter_utils import args_bounds_check
7+
8+
# dead-code elimination, linting, and recompilation for graph, in-place
9+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
10+
clean_up_graph_after_modifications,
11+
)
12+
13+
logger = logging.getLogger(__name__)
14+
15+
# for now creating this node, but mostly will want to modify this in input
16+
17+
18+
def replace_complex_placeholder_to_tuple(
19+
gm: torch.fx.GraphModule, inputListindices
20+
) -> torch.fx.GraphModule:
21+
modified_graph = False
22+
input_arg_list = [f"arg{inputListIndex}_1" for inputListIndex in inputListindices]
23+
for node in gm.graph.nodes:
24+
if node.op == "placeholder" and node.target in input_arg_list:
25+
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
26+
27+
node_shape = node.meta["val"].size()
28+
new_node_shape = node_shape + (2,)
29+
new_node_dtype = None
30+
if node.meta["val"].dtype == torch.complex64:
31+
new_node_dtype = torch.float32
32+
else:
33+
new_node_dtype = torch.float64
34+
fake_mode = FakeTensorMode()
35+
36+
real_tensor = torch.empty(new_node_shape, dtype=new_node_dtype)
37+
with FakeTensorMode() as fake_mode:
38+
new_placeholder_tuple = fake_mode.from_tensor(real_tensor)
39+
node.meta["val"] = new_placeholder_tuple
40+
modified_graph = True
41+
# propagate the meta data change for the downstream ops
42+
# TODO:to check if this is required in all cases
43+
propogate_shape_change(gm, node, fake_mode)
44+
45+
# If graph was modified, clean it up
46+
if modified_graph:
47+
gm = clean_up_graph_after_modifications(gm)
48+
logger.debug(
49+
f"Graph after fusing wait_tensor and distributed op tensor:\n{gm.graph}"
50+
)
51+
52+
return gm
53+
54+
55+
def infer_slice_shape(node):
56+
input_shape = node.args[0].meta["val"].shape
57+
slice_args = node.args
58+
dim = slice_args[1]
59+
start = slice_args[2]
60+
end = slice_args[3]
61+
step = args_bounds_check(slice_args, 4, replacement=1)
62+
new_shape = list(input_shape)
63+
new_shape[dim] = (end - start + step - 1) // step
64+
return tuple(new_shape)
65+
66+
67+
def infer_reshape_shape(node):
68+
return node.args[1]
69+
70+
71+
shape_inference_funcs = {
72+
"torch.ops.aten.slice.Tensor": infer_slice_shape,
73+
"torch.ops.aten.reshape.default": infer_reshape_shape,
74+
}
75+
76+
shape_inference_funcs = {
77+
"torch.ops.aten.slice.Tensor": infer_slice_shape,
78+
"torch.ops.aten.reshape.default": infer_reshape_shape,
79+
}
80+
81+
82+
def propogate_shape_change(node, start_node, fake_mode):
83+
visited_nodes = set()
84+
stack = [start_node]
85+
while stack:
86+
node = stack.pop()
87+
if node in visited_nodes:
88+
continue
89+
visited_nodes.add(node)
90+
update_node_meta(node, fake_mode)
91+
for user in node.users:
92+
if (
93+
user.op == "call_function"
94+
and _get_qualified_name(user.target) == "torch.ops.aten.mul.Tensor"
95+
):
96+
continue
97+
stack.append(user)
98+
99+
100+
def update_node_meta(node, fake_mode):
101+
op_name = node.name
102+
op_target = node.target
103+
104+
if node.op == "call_function":
105+
op_target = _get_qualified_name(node.target)
106+
107+
if op_target in shape_inference_funcs:
108+
new_shape = shape_inference_funcs[op_target](node)
109+
real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype)
110+
node.meta["val"] = fake_mode.from_tensor(real_tensor)
111+
else:
112+
print("No shape for the inference function", {op_name})

py/torch_tensorrt/dynamo/utils.py

+39
Original file line numberDiff line numberDiff line change
@@ -780,3 +780,42 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
780780
f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node"
781781
)
782782
return output_dtypes
783+
784+
785+
def find_complex_nodes(gm: torch.fx.GraphModule):
786+
complex_nodes = []
787+
complexNodes = {}
788+
for node in gm.graph.nodes:
789+
if is_node_complex(node, complexNodes):
790+
complex_nodes.append(node)
791+
return complex_nodes
792+
793+
794+
def is_node_complex(node: torch.fx.Node, complexNodes):
795+
if not isinstance(node, torch.fx.Node):
796+
return False
797+
if node.name in complexNodes:
798+
return True
799+
if node.op == "call_function" and node.args is not None:
800+
for arg in node.args:
801+
if isinstance(arg, int):
802+
continue
803+
elif isinstance(arg, (list, tuple)):
804+
for eachNode in arg:
805+
if is_node_complex(eachNode, complexNodes):
806+
complexNodes[node.name] = True
807+
return True
808+
809+
elif hasattr(arg, "meta") and "val" in arg.meta:
810+
if isinstance(arg.meta["val"], (list, tuple)):
811+
for eachFakeTensorMeta in arg.meta["val"]:
812+
if eachFakeTensorMeta.dtype in (
813+
torch.complex64,
814+
torch.complex128,
815+
):
816+
complexNodes[node.name] = True
817+
return True
818+
elif arg.meta["val"].dtype in (torch.complex64, torch.complex128):
819+
complexNodes[node.name] = True
820+
return True
821+
return False

0 commit comments

Comments
 (0)