Skip to content

Commit 6707c6f

Browse files
committed
changes to make Llama example work
1 parent f6c73d2 commit 6707c6f

File tree

6 files changed

+268
-0
lines changed

6 files changed

+268
-0
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

+16
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
from torch_tensorrt.dynamo._compiler import compile_module
1515
from torch_tensorrt.dynamo.lowering import (
1616
get_decompositions,
17+
modify_complex_nodes,
1718
post_lowering,
1819
remove_detach,
1920
remove_sym_nodes,
2021
repair_input_aliasing,
22+
replace_complex_placeholder_to_tuple,
2123
)
2224
from torch_tensorrt.dynamo.utils import (
2325
parse_dynamo_kwargs,
@@ -103,6 +105,16 @@ def _pretraced_backend(
103105
# Remove detach nodes
104106
remove_detach(gm, settings)
105107

108+
complexInputIndices = []
109+
for i, torch_input in enumerate(torch_inputs):
110+
if torch_inputs[i].dtype == torch.complex64:
111+
complexInputIndices.append(i)
112+
torch_input_real = torch_inputs[i].real
113+
torch_input_imaginary = torch_inputs[i].imag
114+
torch_inputs[i] = torch.stack(
115+
(torch_input_real, torch_input_imaginary), dim=-1
116+
)
117+
106118
# Invoke AOTAutograd to translate operators to aten
107119
if settings.use_aot_joint_export:
108120
gm = aot_export_joint_simple(
@@ -120,6 +132,10 @@ def _pretraced_backend(
120132

121133
logger.debug("Lowered Input graph:\n " + str(gm.graph))
122134

135+
complex_nodes = find_complex_nodes(gm)
136+
replace_complex_placeholder_to_tuple(gm, complexInputIndices)
137+
modify_complex_nodes(gm, complex_nodes)
138+
123139
torchtrt_inputs = prepare_inputs(
124140
torch_inputs, disable_memory_format_check=True
125141
)
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
from ._aten_lowering_pass import *
2+
from ._modify_complex_nodes import modify_complex_nodes
3+
from ._replace_complex_placeholder_to_tuple import replace_complex_placeholder_to_tuple
24
from .remove_sym_nodes import remove_sym_nodes
35
from .repair_input_aliasing import repair_input_aliasing
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import logging
2+
3+
import torch
4+
5+
logger = logging.getLogger(__name__)
6+
7+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
8+
clean_up_graph_after_modifications,
9+
)
10+
11+
12+
def tensorrt_complex_mul(args0, args1):
    """Multiply two complex values stored as real tensors.

    Each operand is split along its last dimension into exactly two size-1
    chunks, interpreted as (real, imaginary). The product is computed as
    (a + bi)(c + di) = (ac - bd) + (ad + bc)i and the two parts are stacked
    back along a new last dimension.

    Args:
        args0: Left operand with a trailing (real, imag) axis.
        args1: Right operand with a trailing (real, imag) axis.

    Returns:
        The product, again with a trailing (real, imag) axis.
    """
    aten = torch.ops.aten

    def _real_imag(packed):
        # Split the last axis into size-1 chunks, then drop that singleton axis.
        real_part, imag_part = aten.split.Tensor(packed, 1, -1)
        return aten.squeeze.dim(real_part, -1), aten.squeeze.dim(imag_part, -1)

    lhs_real, lhs_imag = _real_imag(args0)
    rhs_real, rhs_imag = _real_imag(args1)

    product_real = aten.sub(
        aten.mul(lhs_real, rhs_real),
        aten.mul(lhs_imag, rhs_imag),
    )
    product_imag = aten.add(
        aten.mul(lhs_real, rhs_imag),
        aten.mul(lhs_imag, rhs_real),
    )

    return aten.stack((product_real, product_imag), -1)
31+
32+
33+
def remove_complex_real_view_nodes(gm: torch.fx.GraphModule):
    """Bypass and erase `view_as_complex` / `view_as_real` nodes in-place.

    Every node whose name contains ``view_as_complex`` or ``view_as_real`` is
    replaced, for all of its users, by its own first argument and then erased
    from the graph.

    Args:
        gm: Graph module whose graph is rewritten in-place.
    """
    view_nodes = [
        node
        for node in gm.graph.nodes
        if "view_as_complex" in node.name or "view_as_real" in node.name
    ]

    modified_graph = False
    for node in view_nodes:
        input_node = node.args[0] if node.args else None
        if not isinstance(input_node, torch.fx.Node):
            # Nothing to forward to the users; leave the node untouched
            # instead of wiring `None` into downstream arguments.
            continue

        # replace_all_uses_with also rewrites uses nested inside lists/tuples
        # (e.g. the output node's argument tuple) and in kwargs, which a
        # top-level scan of `args` misses; without it erase_node would fail
        # because the node would still have users.
        node.replace_all_uses_with(input_node)
        gm.graph.erase_node(node)
        modified_graph = True

    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(
            f"Graph after removing view_as_complex nodes and view_as_real nodes:\n{gm.graph}"
        )
57+
58+
59+
def modify_reshape_nodes(gm: torch.fx.GraphModule, complex_nodes):
    """Append a trailing dimension of 2 to the target shape of reshape nodes.

    Only nodes that are members of ``complex_nodes`` and whose name contains
    ``reshape`` are touched; their second positional argument (the target
    shape) gets an extra trailing ``2``.

    Args:
        gm: Graph module whose nodes are updated in-place.
        complex_nodes: Collection of nodes to consider.
    """
    for candidate in gm.graph.nodes:
        # slice and transpose nodes stay as they are; only reshape is rewritten
        if candidate not in complex_nodes or "reshape" not in candidate.name:
            continue
        source, target_shape = candidate.args[0], candidate.args[1]
        candidate.args = (source, (*target_shape, 2))
66+
67+
68+
def modify_mul_nodes(gm: torch.fx.GraphModule, complex_nodes):
    """Replace complex `mul` nodes with calls to `tensorrt_complex_mul`.

    A node is rewritten when it is listed in ``complex_nodes``, is a
    ``call_function`` node, has ``mul`` in its name and has at least two
    positional arguments. The replacement node takes the first two positional
    arguments, inherits the original node's metadata and takes over all of its
    users; the original node is then erased.

    Args:
        gm: Graph module whose graph is rewritten in-place.
        complex_nodes: Collection of complex nodes, given either as
            ``torch.fx.Node`` objects or as node names.
    """
    # Normalise to names so that both a list of Node objects and a collection
    # of node names are matched; comparing a name against Node objects never
    # succeeds.
    complex_names = {
        entry.name if isinstance(entry, torch.fx.Node) else entry
        for entry in complex_nodes
    }

    modified_graph = False
    # Iterate over a snapshot because nodes are inserted and erased below.
    for node in list(gm.graph.nodes):
        if node.name not in complex_names or "mul" not in node.name:
            continue
        if node.op != "call_function" or len(node.args) < 2:
            continue

        with gm.graph.inserting_after(node):
            replacement_node = gm.graph.create_node(
                op="call_function",
                target=tensorrt_complex_mul,
                args=(node.args[0], node.args[1]),
            )
        replacement_node.meta.update(node.meta)
        node.replace_all_uses_with(replacement_node)
        gm.graph.erase_node(node)
        modified_graph = True

    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(
            f"Graph after custom complex mul nodes is applied to the graph:\n{gm.graph}"
        )
90+
91+
92+
def modify_complex_nodes(gm: torch.fx.GraphModule, complex_nodes):
    """Run the complex-node rewrites on ``gm`` in a fixed order.

    Args:
        gm: Graph module that is rewritten in-place.
        complex_nodes: Collection of complex nodes handed to the reshape and
            mul rewrites.
    """
    # 1. Extend reshape target shapes for the listed nodes.
    modify_reshape_nodes(gm, complex_nodes)

    # 2. Drop the view_as_complex / view_as_real nodes.
    remove_complex_real_view_nodes(gm)

    # 3. Swap the listed mul nodes for the custom complex multiplication.
    modify_mul_nodes(gm, complex_nodes)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import logging
2+
3+
import torch
4+
from torch.fx.node import _get_qualified_name
5+
from torch_tensorrt.dynamo._settings import CompilationSettings
6+
from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
7+
8+
# dead-code elimination, linting, and recompilation for graph, in-place
9+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
10+
clean_up_graph_after_modifications,
11+
)
12+
13+
logger = logging.getLogger(__name__)
14+
15+
# for now creating this node, but mostly will want to modify this in input
16+
17+
18+
def replace_complex_placeholder_to_tuple(
    gm: torch.fx.GraphModule, inputListindices
) -> torch.fx.GraphModule:
    """Rewrite the metadata of selected placeholders to a real-valued layout.

    For every placeholder whose target is ``arg{i}_1`` with ``i`` taken from
    ``inputListindices``, ``node.meta["val"]`` is replaced by a fake tensor
    whose shape is the previous shape plus a trailing dimension of 2. The new
    dtype is ``float32`` when the previous dtype was ``complex64`` and
    ``float64`` otherwise. The change is then pushed to downstream nodes via
    ``propogate_shape_change``.

    Args:
        gm: Graph module to update in-place.
        inputListindices: Indices of the graph inputs to rewrite.

    Returns:
        The graph module, cleaned up if any placeholder was modified.
    """
    # Imported lazily, once per call rather than once per matching node.
    from torch._subclasses.fake_tensor import FakeTensorMode

    placeholder_names = {f"arg{index}_1" for index in inputListindices}

    modified_graph = False
    for node in gm.graph.nodes:
        if node.op != "placeholder" or node.target not in placeholder_names:
            continue

        previous_val = node.meta["val"]
        new_node_shape = tuple(previous_val.size()) + (2,)
        # NOTE(review): every dtype other than complex64 maps to float64,
        # not only complex128 — confirm this is intended.
        if previous_val.dtype == torch.complex64:
            new_node_dtype = torch.float32
        else:
            new_node_dtype = torch.float64

        real_tensor = torch.empty(new_node_shape, dtype=new_node_dtype)
        with FakeTensorMode() as fake_mode:
            node.meta["val"] = fake_mode.from_tensor(real_tensor)
        modified_graph = True

        # propagate the meta data change for the downstream ops
        # TODO: check whether this is required in all cases
        propogate_shape_change(gm, node, fake_mode)

    # If graph was modified, clean it up
    if modified_graph:
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(
            f"Graph after replacing complex placeholders with real tensors:\n{gm.graph}"
        )

    return gm
53+
54+
55+
def infer_slice_shape(node):
    """Compute the output shape of a slice node from its input metadata.

    The node's positional arguments are read as
    ``(input, dim, start, end, step)``. Missing or ``None`` values default to
    ``dim=0``, ``start=0``, ``end=<size of dim>`` and ``step=1``. Negative
    ``start``/``end`` are taken relative to the end of the dimension and both
    are clamped to ``[0, size]``, so an open-ended slice expressed with a very
    large ``end`` yields the real remaining length.

    Args:
        node: Node-like object with ``args``; ``args[0].meta["val"].shape``
            provides the input shape.

    Returns:
        The sliced shape as a tuple.
    """
    args = node.args
    input_shape = args[0].meta["val"].shape

    def _arg(index, default):
        # Treat both an absent and an explicit None argument as "use default".
        if len(args) > index and args[index] is not None:
            return args[index]
        return default

    dim = _arg(1, 0)
    size = input_shape[dim]
    start = _arg(2, 0)
    end = _arg(3, size)
    step = _arg(4, 1)

    # Normalise negative indices, then clamp into the valid range.
    if start < 0:
        start += size
    if end < 0:
        end += size
    start = min(max(start, 0), size)
    end = min(max(end, start), size)

    new_shape = list(input_shape)
    # Ceil division: number of indices visited by range(start, end, step).
    new_shape[dim] = (end - start + step - 1) // step
    return tuple(new_shape)
65+
66+
67+
def infer_reshape_shape(node):
    """Return the output shape of a reshape node.

    The shape is read directly from the node's second positional argument.
    """
    requested_shape = node.args[1]
    return requested_shape
69+
70+
71+
# Maps the qualified name of an op to the function that infers its output
# shape from the node's arguments. Defined once; a second identical
# definition was redundant.
shape_inference_funcs = {
    "torch.ops.aten.slice.Tensor": infer_slice_shape,
    "torch.ops.aten.reshape.default": infer_reshape_shape,
}
80+
81+
82+
def propogate_shape_change(node, start_node, fake_mode):
    """Refresh metadata for ``start_node`` and the nodes reachable through its users.

    The graph is walked depth-first along ``users`` edges, calling
    ``update_node_meta`` once per visited node. Users that are
    ``call_function`` nodes targeting ``torch.ops.aten.mul.Tensor`` are not
    followed.

    Args:
        node: Not read by the traversal; kept for the existing call signature.
        start_node: Node where the traversal begins.
        fake_mode: Fake tensor mode forwarded to ``update_node_meta``.
    """
    seen = set()
    pending = [start_node]
    while pending:
        current = pending.pop()
        if current in seen:
            continue
        seen.add(current)
        update_node_meta(current, fake_mode)
        # Stop at aten.mul.Tensor users; everything else is queued.
        pending.extend(
            consumer
            for consumer in current.users
            if not (
                consumer.op == "call_function"
                and _get_qualified_name(consumer.target)
                == "torch.ops.aten.mul.Tensor"
            )
        )
98+
99+
100+
def update_node_meta(node, fake_mode):
    """Recompute ``node.meta["val"]`` when a shape-inference function exists.

    For ``call_function`` nodes the target is resolved to its qualified name
    and looked up in ``shape_inference_funcs``. On a hit, a new fake tensor
    with the inferred shape and the node's current dtype replaces
    ``node.meta["val"]``. On a miss the node is left unchanged and the miss is
    logged at debug level.

    Args:
        node: Node whose metadata may be updated in-place.
        fake_mode: Fake tensor mode used to build the replacement value.
    """
    op_target = node.target
    if node.op == "call_function":
        op_target = _get_qualified_name(node.target)

    infer_shape = shape_inference_funcs.get(op_target)
    if infer_shape is None:
        # Use the module logger rather than printing to stdout from library code.
        logger.debug("No shape inference function registered for node %s", node.name)
        return

    new_shape = infer_shape(node)
    real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype)
    node.meta["val"] = fake_mode.from_tensor(real_tensor)

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

+8
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,14 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
240240
for i in inputs
241241
]
242242

243+
for i, contiguous_input in enumerate(contiguous_inputs):
244+
if contiguous_input.dtype == torch.complex64:
245+
contiguous_input_real = contiguous_input.real
246+
contiguous_input_imag = contiguous_input.imag
247+
contiguous_inputs[i] = torch.stack(
248+
(contiguous_input_real, contiguous_input_imag), dim=-1
249+
)
250+
243251
with (
244252
torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
245253
if self.profiling_enabled

py/torch_tensorrt/dynamo/utils.py

+35
Original file line numberDiff line numberDiff line change
@@ -780,3 +780,38 @@ def get_output_dtypes(output: Any, truncate_doulbe: bool = False) -> List[dtype]
780780
f"got unexpected type {type(output)}, expected type is a torch.fx.node.Node or a tuple/list of torch.fx.node.Node"
781781
)
782782
return output_dtypes
783+
784+
785+
def find_complex_nodes(gm: torch.fx.GraphModule):
    """Collect the nodes of ``gm`` that ``is_node_complex`` classifies as complex.

    Args:
        gm: Graph module whose graph is scanned in node order.

    Returns:
        List of ``torch.fx.Node`` objects, in graph order.
    """
    # Shared memo of node names already classified as complex; nodes are
    # visited in graph order, so entries for earlier nodes are available when
    # later nodes are checked.
    known_complex = {}
    return [node for node in gm.graph.nodes if is_node_complex(node, known_complex)]


def is_node_complex(node: torch.fx.Node, complexNodes=None):
    """Decide whether a node operates on complex data, judged by its first argument.

    A ``call_function`` node is complex when:
      * its first argument is a list/tuple and any node in it is itself
        complex by this same test, or
      * its first argument is a node whose ``meta["val"]`` (or any element of
        it, when that is a list/tuple) has dtype ``complex64``/``complex128``.

    Args:
        node: Node to classify.
        complexNodes: Optional dict used as a memo; names of nodes found to be
            complex are recorded in it. A fresh dict is used when omitted.

    Returns:
        True if the node is classified as complex, else False.
    """
    if complexNodes is None:
        complexNodes = {}
    if node.name in complexNodes:
        return True
    # Nodes without positional arguments have no first argument to inspect.
    if node.op != "call_function" or not node.args:
        return False

    complex_dtypes = (torch.complex64, torch.complex128)
    first_arg = node.args[0]

    if isinstance(first_arg, (list, tuple)):
        # Containers may hold plain scalars (e.g. a shape); only nodes are checked.
        found = any(
            isinstance(item, torch.fx.Node) and is_node_complex(item, complexNodes)
            for item in first_arg
        )
    elif isinstance(first_arg, torch.fx.Node):
        val = first_arg.meta.get("val")
        candidates = val if isinstance(val, (list, tuple)) else (val,)
        found = any(
            getattr(candidate, "dtype", None) in complex_dtypes
            for candidate in candidates
        )
    else:
        # Scalar or other non-node first argument carries no dtype metadata.
        found = False

    if found:
        complexNodes[node.name] = True
    return found

0 commit comments

Comments
 (0)