|
| 1 | +import logging |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch.fx.node import _get_qualified_name |
| 5 | +from torch_tensorrt.dynamo._settings import CompilationSettings |
| 6 | +from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check |
| 7 | + |
| 8 | +# dead-code elimination, linting, and recompilation for graph, in-place |
| 9 | +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( |
| 10 | + clean_up_graph_after_modifications, |
| 11 | +) |
| 12 | + |
| 13 | +logger = logging.getLogger(__name__) |
| 14 | + |
| 15 | +# for now creating this node, but mostly will want to modify this in input |
| 16 | + |
| 17 | + |
| 18 | +def replace_complex_placeholder_to_tuple( |
| 19 | + gm: torch.fx.GraphModule, inputListindices |
| 20 | +) -> torch.fx.GraphModule: |
| 21 | + modified_graph = False |
| 22 | + input_arg_list = [f"arg{inputListIndex}_1" for inputListIndex in inputListindices] |
| 23 | + for node in gm.graph.nodes: |
| 24 | + if node.op == "placeholder" and node.target in input_arg_list: |
| 25 | + from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode |
| 26 | + |
| 27 | + node_shape = node.meta["val"].size() |
| 28 | + new_node_shape = node_shape + (2,) |
| 29 | + new_node_dtype = None |
| 30 | + if node.meta["val"].dtype == torch.complex64: |
| 31 | + new_node_dtype = torch.float32 |
| 32 | + else: |
| 33 | + new_node_dtype = torch.float64 |
| 34 | + fake_mode = FakeTensorMode() |
| 35 | + |
| 36 | + real_tensor = torch.empty(new_node_shape, dtype=new_node_dtype) |
| 37 | + with FakeTensorMode() as fake_mode: |
| 38 | + new_placeholder_tuple = fake_mode.from_tensor(real_tensor) |
| 39 | + node.meta["val"] = new_placeholder_tuple |
| 40 | + modified_graph = True |
| 41 | + # propagate the meta data change for the downstream ops |
| 42 | + # TODO:to check if this is required in all cases |
| 43 | + propogate_shape_change(gm, node, fake_mode) |
| 44 | + |
| 45 | + # If graph was modified, clean it up |
| 46 | + if modified_graph: |
| 47 | + gm = clean_up_graph_after_modifications(gm) |
| 48 | + logger.debug( |
| 49 | + f"Graph after fusing wait_tensor and distributed op tensor:\n{gm.graph}" |
| 50 | + ) |
| 51 | + |
| 52 | + return gm |
| 53 | + |
| 54 | + |
| 55 | +def infer_slice_shape(node): |
| 56 | + input_shape = node.args[0].meta["val"].shape |
| 57 | + slice_args = node.args |
| 58 | + dim = slice_args[1] |
| 59 | + start = slice_args[2] |
| 60 | + end = slice_args[3] |
| 61 | + step = args_bounds_check(slice_args, 4, replacement=1) |
| 62 | + new_shape = list(input_shape) |
| 63 | + new_shape[dim] = (end - start + step - 1) // step |
| 64 | + return tuple(new_shape) |
| 65 | + |
| 66 | + |
| 67 | +def infer_reshape_shape(node): |
| 68 | + return node.args[1] |
| 69 | + |
| 70 | + |
| 71 | +shape_inference_funcs = { |
| 72 | + "torch.ops.aten.slice.Tensor": infer_slice_shape, |
| 73 | + "torch.ops.aten.reshape.default": infer_reshape_shape, |
| 74 | +} |
| 75 | + |
| 76 | +shape_inference_funcs = { |
| 77 | + "torch.ops.aten.slice.Tensor": infer_slice_shape, |
| 78 | + "torch.ops.aten.reshape.default": infer_reshape_shape, |
| 79 | +} |
| 80 | + |
| 81 | + |
| 82 | +def propogate_shape_change(node, start_node, fake_mode): |
| 83 | + visited_nodes = set() |
| 84 | + stack = [start_node] |
| 85 | + while stack: |
| 86 | + node = stack.pop() |
| 87 | + if node in visited_nodes: |
| 88 | + continue |
| 89 | + visited_nodes.add(node) |
| 90 | + update_node_meta(node, fake_mode) |
| 91 | + for user in node.users: |
| 92 | + if ( |
| 93 | + user.op == "call_function" |
| 94 | + and _get_qualified_name(user.target) == "torch.ops.aten.mul.Tensor" |
| 95 | + ): |
| 96 | + continue |
| 97 | + stack.append(user) |
| 98 | + |
| 99 | + |
| 100 | +def update_node_meta(node, fake_mode): |
| 101 | + op_name = node.name |
| 102 | + op_target = node.target |
| 103 | + |
| 104 | + if node.op == "call_function": |
| 105 | + op_target = _get_qualified_name(node.target) |
| 106 | + |
| 107 | + if op_target in shape_inference_funcs: |
| 108 | + new_shape = shape_inference_funcs[op_target](node) |
| 109 | + real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype) |
| 110 | + node.meta["val"] = fake_mode.from_tensor(real_tensor) |
| 111 | + else: |
| 112 | + print("No shape for the inference function", {op_name}) |
0 commit comments