diff --git a/et_replay/et_replay_utils.py b/et_replay/et_replay_utils.py
index dd625ee9..c2964f82 100644
--- a/et_replay/et_replay_utils.py
+++ b/et_replay/et_replay_utils.py
@@ -16,7 +16,7 @@
 TORCH_DTYPES_RNG = {
     "bool": (torch.bool, torch.ones),
     "int8": (torch.int8, torch.ones),
-    "half": (torch.half, torch.ones),
+    "half": (torch.half, torch.randn),
     "int": (torch.int, torch.ones),
     "long": (torch.int64, torch.ones),
     "long int": (torch.int64, torch.ones),
@@ -24,14 +24,15 @@
     "double": (torch.float64, torch.randn),
     "signed char": (torch.int8, torch.ones),
     "unsigned char": (torch.uint8, torch.ones),
-    "c10::Half": (torch.half, torch.ones),
-    "c10::BFloat16": (torch.bfloat16, torch.ones),
+    "c10::Half": (torch.half, torch.randn),
+    "c10::BFloat16": (torch.bfloat16, torch.randn),
+    "c10::complex": (torch.complex32, torch.randn),
 }
 
 TORCH_DTYPES_RNG_str = {
     "bool": ("torch.bool", "torch.ones"),
     "int8": ("torch.int8", "torch.ones"),
-    "half": ("torch.half", "torch.ones"),
+    "half": ("torch.half", "torch.randn"),
     "int": ("torch.int", "torch.ones"),
     "long": ("torch.int64", "torch.ones"),
     "long int": ("torch.int64", "torch.ones"),
@@ -39,8 +40,9 @@
     "double": ("torch.float64", "torch.randn"),
     "signed char": ("torch.int8", "torch.ones"),
     "unsigned char": ("torch.uint8", "torch.ones"),
-    "c10::Half": ("torch.half", "torch.ones"),
-    "c10::BFloat16": ("torch.bfloat16", "torch.ones"),
+    "c10::Half": ("torch.half", "torch.randn"),
+    "c10::BFloat16": ("torch.bfloat16", "torch.randn"),
+    "c10::complex": ("torch.complex32", "torch.randn"),
 }
 
 TORCH_DTYPES_BYTES = {
@@ -56,6 +58,7 @@
     "unsigned char": 1,
     "c10::Half": 2,
     "c10::BFloat16": 2,
+    "c10::complex": 8,
 }
diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py
index e020590e..a210a14e 100644
--- a/et_replay/tools/comm_replay.py
+++ b/et_replay/tools/comm_replay.py
@@ -1153,7 +1153,7 @@ def replaySingle(
 
         if self.backendFuncs.get_global_rank() == 0:
             logger.info(
-                f"{logLable}[{cnt+1} / {self.max_msg_cnt}] Replayed {recordName} in block [{curBlockStack}]... {global_latency:.2f} us"
+                f"{logLable}[{cnt+1} / {self.max_msg_cnt}] Replayed {recordName} with id={curComm.id} in block [{curBlockStack}]... {global_latency:.2f} us"
             )
 
     def benchTime(self, commsParams: commsParamsHolderBase) -> None:
diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py
index f680ffe1..0dcc9f4f 100644
--- a/et_replay/tools/et_replay.py
+++ b/et_replay/tools/et_replay.py
@@ -7,6 +7,7 @@
 import time
 from collections import defaultdict
 from datetime import datetime
+from enum import Enum
 
 import numpy as np
 import torch
@@ -93,6 +94,19 @@ def extract_tensor_from_list(tensor_list):
     """
 
 
+class TensorAllocationMode(Enum):
+    """
+    Enum representing the tensor allocation modes.
+    """
+
+    # Allocate, at the beginning of the replay, the input tensors that cannot
+    # be generated while replaying the trace, and reuse them for all iterations.
+    PRE_ALLOCATE = 1
+
+    # Allocate tensors on the fly and free them once they go out of scope.
+    LAZY_ALLOCATE = 2
+
+
 class ExgrReplayManager:
     def __init__(self):
         self.numWarmupIters = 1
@@ -147,7 +161,7 @@ def __init__(self):
         self.tensor_shapes = defaultdict(set)
         # Dict that maps tensor storage id to its size, and a map for {device, torch.Tensor}.
         # The tensor with the same storage id may located on different devices.
-        self.tensor_storage_map: dict[int, []] = defaultdict(set)
+        self.tensor_storage_map: dict[int, list] = defaultdict(list)
         # Mark those tensors that occur first as an input in the original et as needing to be instantiated in replay
         # at the very beginning.
         self.instantiate = set()
@@ -200,6 +214,12 @@
         # of the replay_tensor_id is the same as or greater than the current node id, that tensor is deleted
         self.replay_tensor_id_to_last_node_id_map = {}
 
+        self.tensor_storage_id_to_last_node_id_map: dict[int, int] = defaultdict(int)
+
+        self.tensor_allocate_mode: TensorAllocationMode = (
+            TensorAllocationMode.PRE_ALLOCATE
+        )
+
         # Unrecognized nodes that are neither operators nor predefined label nodes.
         self.exceptional_nodes = set()
@@ -219,6 +239,8 @@
         # Replay on CPU.
         self.cpu = False
 
+        self.available_memory = 0
+
     def initBench(self):
         self.numWarmupIters = self.args.warmup_iter
         self.numIters = self.args.iter
@@ -234,6 +256,8 @@
         self.wait_delay = self.args.delay
         self.cpu = self.args.cpu
         self.tf32 = self.args.tf32
+        if self.args.enable_lazy_tensor_allocation:
+            self.tensor_allocate_mode = TensorAllocationMode.LAZY_ALLOCATE
 
         # Single trace.
         if not self.args.trace_path:
@@ -318,6 +342,14 @@
         else:
             self.device = torch.device(self.cuda)
 
+        # Total memory of the device.
+        total_memory = torch.cuda.get_device_properties(self.device).total_memory
+        # Memory currently reserved by PyTorch's caching allocator.
+        reserved_memory = torch.cuda.memory_reserved(self.device)
+        # Approximate available memory: total minus the memory already reserved.
+        self.available_memory = (total_memory - reserved_memory) / (1024**3)
+        logger.info(f"Available memory: {self.available_memory} GB")
+
     def detect_tensor_device(self, root):
         # Automatically detect whether the captured tensor information includes device.
         # Just a temporary utility to accommodate old and new versions et and should be removed later.
@@ -481,13 +513,16 @@ def has_parallel_parent(node):
         assert len(self.parallel_nodes_ids) == len(set(self.parallel_nodes_ids))
 
     def analyze_tensors(self):
-        def add_storage_tensor(t_id, device):
+        def add_storage_tensor(node_id, t_id, device):
             # t_id is a tupe of (tensor_id, storage_id, offset, number of element,
             # number of bytes for each element, device)
 
             # ET does not save the size of the tensor storage, so we iterate over all the
             # tensors to find the maximum size of the storage.
             storage_id = t_id[1]
+            self.tensor_storage_id_to_last_node_id_map[storage_id] = max(
+                node_id, self.tensor_storage_id_to_last_node_id_map[storage_id]
+            )
             if storage_id not in self.tensor_storage_map:
                 # the storage size for this tensor is the sum of the storage offset and
                 # number of elements * number of bytes per element.
@@ -501,7 +536,7 @@ def add_storage_tensor(t_id, device):
             )
 
         def add_unique_tensor(node_name, node_id, t_id, shape, input, device=-1):
-            add_storage_tensor(t_id, device)
+            add_storage_tensor(node_id, t_id, device)
             # If we did not see this tensor before, add it as a unique tensor.
             if t_id not in self.original_unique_tensors:
                 self.original_unique_tensors.add(t_id)
@@ -687,7 +722,16 @@ def allocate_comp_tensors(self, node):  # noqa: C901
                 device = t_id[5]
                 t_id = tuple(list(t_id)[:5])
             replay_t_id = self.tensors_mapping[(node.id, t_id, True)]
-            if t_id in self.input_tensor_ids:
+            if t_id not in self.input_tensor_ids:
+                continue
+            found_tensor = False
+            if self.tensor_allocate_mode == TensorAllocationMode.PRE_ALLOCATE:
+                if replay_t_id in self.tensor_registry_permanent.keys():
+                    found_tensor = True
+            else:
+                if replay_t_id in self.tensor_registry.keys():
+                    found_tensor = True
+            if not found_tensor:
                 try:
                     dtype, _ = TORCH_DTYPES_RNG[data_type.lstrip("Tensor(").rstrip(")")]
@@ -712,11 +756,8 @@
                     )
                     if tensor is not None:
                         node.pre_load_tensors[idx] = tensor
-                    if (
-                        tensor is None
-                        and replay_t_id in self.instantiate
-                        and replay_t_id not in self.tensor_registry_permanent.keys()
-                    ):
+
+                    if tensor is None:
                         tensor = self.get_tensor_from_storage(
                             t_id[1],  # storage_id
                             t_id[2],  # offset
@@ -727,12 +768,22 @@
                             strides,
                             node.get_input_tensor_range(idx),
                         )
-                        self.tensor_registry_permanent[replay_t_id] = tensor
+
+                        if (
+                            self.tensor_allocate_mode
+                            == TensorAllocationMode.PRE_ALLOCATE
+                        ):
+                            self.tensor_registry_permanent[replay_t_id] = tensor
+                        else:
+                            self.tensor_registry[replay_t_id] = tensor
                 except KeyError:
                     if data_type != "Tensor(nullptr (uninitialized))":
                         logger.info(f"KeyError: {node.id}, {t_id}, {data_type}")
-                    self.tensor_registry_permanent[replay_t_id] = None
+                    if self.tensor_allocate_mode == TensorAllocationMode.PRE_ALLOCATE:
+                        self.tensor_registry_permanent[replay_t_id] = None
+                    else:
+                        self.tensor_registry[replay_t_id] = None
 
     def build_func(self, node):
         if node.kernel_backend == "triton":
@@ -1182,7 +1233,8 @@ def get_tensor_from_storage(
         strides,
         tensor_range,
     ):
-        assert storage_id in self.tensor_storage_map
+        if storage_id not in self.tensor_storage_map:
+            return None
         tensor_data = self.tensor_storage_map[storage_id]
         if device not in tensor_data[1]:
             if (
@@ -1246,10 +1298,20 @@
 
         return x
 
+    def free_tensor_in_storage(self, storage_id, node_id):
+        if (
+            self.tensor_allocate_mode == TensorAllocationMode.LAZY_ALLOCATE
+            and storage_id in self.tensor_storage_id_to_last_node_id_map
+            and node_id >= self.tensor_storage_id_to_last_node_id_map[storage_id]
+        ):
+            self.tensor_storage_map[storage_id][1] = {}
+
     def get_data(self, node, is_input):
         try:
             if is_input:
                 data_in = node.inputs
+                if self.tensor_allocate_mode == TensorAllocationMode.LAZY_ALLOCATE:
+                    self.allocate_comp_tensors(node)
             else:
                 data_in = node.outputs
             data_out = []
@@ -1363,7 +1425,24 @@ def get_comm_outputs(self, node):
         except Exception as e:
             logger.info(f"Outputs error: {e} at node: {node.id}")
 
+    def free_device_memory(self, force: bool = False):
+        free_memory = force
+        allocated_memory = torch.cuda.memory_allocated(self.device) / 1024 / 1024 / 1024
+        if allocated_memory / self.available_memory > self.args.device_memory_threshold:
+            free_memory = True
+
+        if free_memory:
+            for _, v in self.tensor_storage_map.items():
+                if len(v) > 1:
+                    v[1] = {}
+            self.tensor_registry = {}
+
     def run_op(self, node, iter, cnt):  # noqa: C901
+        if (
+            self.tensor_allocate_mode == TensorAllocationMode.LAZY_ALLOCATE
+            and self.args.device_memory_threshold != 1.0
+        ):
+            self.free_device_memory()
         if isinstance(node, commsArgs):
             if self.debug and iter >= self.numWarmupIters:
                 start_ns = time.time_ns()
@@ -1399,6 +1478,7 @@ def run_op(self, node, iter, cnt):  # noqa: C901
         inputs, msg = self.get_data(node, True)
         if msg != "":
+            logger.info(f"Failed to get input data for node {node.id}: {msg}")
             return False, msg
         # TODO: why need this hack?
         # Workaround to eliminate the "strides() called on undefined Tensor" error.
@@ -1443,9 +1523,10 @@
                 if (
                     node.id >= self.replay_tensor_id_to_last_node_id_map[replay_t_id]
-                    and replay_t_id not in self.instantiate
+                    and replay_t_id in self.tensor_registry
                 ):
                     del self.tensor_registry[replay_t_id]
+                self.free_tensor_in_storage(t_id[1], node.id)
 
         for (_, t_id, _), output in zip(get_output_tensors(node), outputs):
             if self.tensor_with_device:
@@ -1454,18 +1535,16 @@
                 t_id = tuple(list(t_id)[:5])
             if t_id in self.input_tensor_ids:
                 replay_t_id = self.tensors_mapping[(node.id, t_id, False)]
                 if (
-                    replay_t_id not in self.unchangeable_intermediate_tensors
-                    and replay_t_id not in self.instantiate
+                    node.id
+                    < self.replay_tensor_id_to_last_node_id_map[replay_t_id]
+                    and replay_t_id not in self.tensor_registry
                 ):
-                    if (
-                        node.id
-                        < self.replay_tensor_id_to_last_node_id_map[replay_t_id]
-                    ):
-                        self.tensor_registry[replay_t_id] = output
-                    else:
-                        del output
+                    self.tensor_registry[replay_t_id] = output
+                else:
+                    del output
             else:
                 del output
+            self.free_tensor_in_storage(t_id[1], node.id)
 
         except Exception as e:
             msg = f"Run op exception Error: {e}, node id: {node.id}, node name: {node.name}"
@@ -1528,7 +1607,6 @@ def remove_op_with_runtime_error(self):
             success, msg = self.run_op(node, 0, cnt)
             if success:
                 continue
-
            if (
                 msg.find("RuntimeError: CUDA error") != -1
                 or msg.find("torch.OutOfMemoryError") != -1
@@ -1600,9 +1678,8 @@ def preprocess_graph(self):
         if self.generator:
             self.generate_code()
-        else:
+        elif self.tensor_allocate_mode == TensorAllocationMode.PRE_ALLOCATE:
             self.allocate_tensors()
-            self.reset_registry()
 
     def benchTime(self):
         # A dictionary to save the benchmark result.
@@ -1644,7 +1721,7 @@ def run_ops(event_1, event_2, iter):
             for cnt, node in enumerate(self.sorted_nodes):
                 success, _ = self.run_op(node, iter, cnt)
                 if not success:
-                    break
+                    return False
             event_2.record()
             if not (self.compute_only or self.args.separate):
                 self.commsBench.resetComms()
@@ -1663,6 +1740,8 @@ def run_ops(event_1, event_2, iter):
             gc.collect()
             torch.cuda.empty_cache()
 
+            return True
+
         # log real time qps every # iterations.
         qps_print_interval = 10
 
@@ -1691,10 +1770,17 @@ def run_iter(iter):
                 )
                 prev_iter = iter
                 start_ns = time.time_ns()
-            run_ops(event_1, event_2, iter)
+
+            if self.tensor_allocate_mode == TensorAllocationMode.LAZY_ALLOCATE:
+                self.free_device_memory(force=True)
+            else:
+                self.reset_registry()
+            ret = run_ops(event_1, event_2, iter)
             if iter >= self.numWarmupIters:
                 total_time += event_1.elapsed_time(event_2)
+
+            return ret
 
         if self.et_profile:
             et_file = "/tmp/replay_et.json"
             et = ExecutionTraceObserver()
@@ -1723,6 +1809,8 @@ def run_iter(iter):
         self.sorted_nodes.sort(key=lambda x: x.id)
         self.commsBench.replay_start_time = time.monotonic_ns()
 
+        torch.cuda.set_device(self.cuda_id)
+
         if self.profile_replay:
             try:
                 from aiplatform.monitoring.unitrace.upload_manifold import (
@@ -1749,16 +1837,23 @@ def run_iter(iter):
                 record_shapes=True,
                 on_trace_ready=on_trace_ready,
             ) as prof:
+                success = True
                 for iter in range(self.numWarmupIters + self.numIters):
-                    run_iter(iter)
+                    if not run_iter(iter):
+                        success = False
+                        break
                     prof.step()
-                benchmark_result["execution finished"] = True
-                logger.info("Execution finished!")
+                benchmark_result["execution finished"] = success
         else:
+            success = True
             for iter in range(self.numWarmupIters + self.numIters):
-                run_iter(iter)
-            benchmark_result["execution finished"] = True
-            logger.info("Execution finished!")
+                if not run_iter(iter):
+                    success = False
+                    break
+            benchmark_result["execution finished"] = success
+
+        if not benchmark_result["execution finished"]:
+            return benchmark_result
 
         if self.profile_memory:
             logger.info("Allocated GPU memory(B):")
@@ -1937,6 +2032,20 @@ def readComputeArgs(self, check_args: bool = True):
             default=False,
             help="When true, the node skip list will be updated with the nodes that are skipped during replay.",
         )
+
+        parser.add_argument(
+            "--enable-lazy-tensor-allocation",
+            action="store_true",
+            default=False,
+            help="When set, input tensors are allocated lazily during replay instead of being pre-allocated.",
+        )
+
+        parser.add_argument(
+            "--device-memory-threshold",
+            type=float,
+            default=1.0,
+            help="When the ratio of allocated to available device memory exceeds this threshold, the replay releases allocated tensors; 1.0 disables this per-op check.",
+        )
         self.args, _ = parser.parse_known_args()
 
         # Check if both 'input' and 'trace_path' are not provided
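
Reviewer note, not part of the patch: the lazy-allocation path above is built from three cooperating pieces. A pre-pass records the last node touching each storage (`tensor_storage_id_to_last_node_id_map`, filled in `add_storage_tensor`); tensors are then allocated on demand when `get_data` pulls a node's inputs; and memory is freed either per storage once its last consumer has run (`free_tensor_in_storage`) or wholesale when usage crosses `--device-memory-threshold` (`free_device_memory`). The sketch below distills that lifecycle; `LazyRegistry`, `record_use`, `get`, and `release_if_done` are illustrative names, not the tool's API, and a `bytearray` stands in for a device tensor.

```python
from collections import defaultdict


class LazyRegistry:
    """Minimal sketch of lazy tensor allocation with last-use tracking."""

    def __init__(self, storage_sizes, memory_budget_bytes):
        self.storage_sizes = storage_sizes  # {storage_id: size in bytes}, known from the trace
        self.memory_budget_bytes = memory_budget_bytes
        self.last_use = defaultdict(int)  # storage_id -> id of the last node using it
        self.live = {}  # storage_id -> live buffer (stand-in for a device tensor)
        self.allocated_bytes = 0

    def record_use(self, storage_id, node_id):
        # Pre-pass over the trace: remember the last node that touches each storage.
        self.last_use[storage_id] = max(node_id, self.last_use[storage_id])

    def get(self, storage_id):
        # Allocate on first use instead of pre-allocating every input up front.
        if storage_id not in self.live:
            size = self.storage_sizes[storage_id]
            if self.allocated_bytes + size > self.memory_budget_bytes:
                self.free_all()  # analogous to free_device_memory(force=True)
            self.live[storage_id] = bytearray(size)
            self.allocated_bytes += size
        return self.live[storage_id]

    def release_if_done(self, storage_id, node_id):
        # Free a storage once replay has passed its last consumer,
        # mirroring free_tensor_in_storage in the patch.
        if node_id >= self.last_use[storage_id] and storage_id in self.live:
            self.allocated_bytes -= self.storage_sizes[storage_id]
            del self.live[storage_id]

    def free_all(self):
        self.live.clear()
        self.allocated_bytes = 0


if __name__ == "__main__":
    reg = LazyRegistry({0: 1024, 1: 2048}, memory_budget_bytes=4096)
    trace = [0, 1, 0]  # toy "trace": storage id used by nodes 0, 1, 2
    for node_id, sid in enumerate(trace):
        reg.record_use(sid, node_id)
    for node_id, sid in enumerate(trace):
        buf = reg.get(sid)  # allocated on first use, reused afterwards
        reg.release_if_done(sid, node_id)  # freed once past its last consumer
    assert not reg.live
```

Freeing at storage granularity rather than per tensor matches the patch's `tensor_storage_map[storage_id][1] = {}`: several replayed tensors can be views over one storage, so the per-device tensor dict is dropped as a unit, letting the memory return to PyTorch's caching allocator (and, after `torch.cuda.empty_cache()`, to the device).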