From 9c1b504745928e6c746616ec606f8e845d45cf3c Mon Sep 17 00:00:00 2001
From: Grain Team
Date: Fri, 20 Dec 2024 12:02:12 -0800
Subject: [PATCH] Internal change

PiperOrigin-RevId: 708388937
---
 grain/_src/python/dataset/stats.py       | 107 ++++++++++++++++--
 grain/_src/python/dataset/stats_test.py  |   2 +-
 .../dataset/transformations/prefetch.py  |   4 +-
 3 files changed, 104 insertions(+), 9 deletions(-)

diff --git a/grain/_src/python/dataset/stats.py b/grain/_src/python/dataset/stats.py
index b0ae0372..94dc6bf6 100644
--- a/grain/_src/python/dataset/stats.py
+++ b/grain/_src/python/dataset/stats.py
@@ -68,6 +68,7 @@ _AVG_PROCESSING_TIME_COLUMN_NAME = "avg processing time"
 
 
 _COLUMN_NAME_OVERRIDES = types.MappingProxyType({
+    "wait_time_ratio": "percent wait time",
     "min_processing_time_ns": "min processing time",
     "max_processing_time_ns": "max processing time",
     "total_processing_time_ns": "total processing time",
@@ -79,6 +80,10 @@
 _MAX_ROW_LINES = 5
 
 
+def _format_ratio_as_percent(value: float) -> str:
+  return f"{value*100:.2f}%"
+
+
 def _pretty_format_ns(value: int) -> str:
   """Pretty formats a time value in nanoseconds to human readable value."""
   if value < 1000:
@@ -104,6 +109,89 @@ def _get_avg_processing_time_ns(
   )
 
 
+def _get_nodes_before_prefetch(
+    node: int, summary: execution_summary_pb2.ExecutionSummary
+) -> list[int]:
+  """Returns nodes in the paths from the given node to prefetch nodes."""
+  child_nodes = []
+  nodes_to_visit = [node]
+  while nodes_to_visit:
+    node_id = nodes_to_visit.pop()
+    node = summary.nodes[node_id]
+    child_nodes.append(node_id)
+    if node.is_prefetch:
+      continue  # Skip adding inputs for the prefetch node.
+    nodes_to_visit.extend(node.inputs)
+  return child_nodes
+
+
+def _find_aggregated_processing_time(
+    summary: execution_summary_pb2.ExecutionSummary,
+    node_ids: list[int],
+) -> int:
+  """Finds the aggregated processing time of the given node IDs."""
+  return sum(
+      summary.nodes[node_id].total_processing_time_ns for node_id in node_ids
+  )
+
+
+def _compute_wait_time_ratio(
+    summary: execution_summary_pb2.ExecutionSummary,
+    node_id: int,
+    aggregated_wait_time_ns: int,
+    prefetch_factor: int = 1,
+) -> None:
+  """Computes the wait time ratio for all the nodes in the execution summary.
+
+  Args:
+    summary: The execution summary to update the `wait_time_ratio` for.
+    node_id: The current node for which to compute the `wait_time_ratio`.
+    aggregated_wait_time_ns: The aggregated wait time of the nodes running
+      under prefetch.
+    prefetch_factor: The factor by which to multiply the
+      `total_processing_time` of the node to get its wait time ratio.
+  """
+  if aggregated_wait_time_ns == 0:
+    return
+  node = summary.nodes[node_id]
+  node_wait_ratio = prefetch_factor * (
+      node.total_processing_time_ns / aggregated_wait_time_ns
+  )
+  node.wait_time_ratio = round(node_wait_ratio, 4)
+  for input_node_id in node.inputs:
+    # If the node is executed in multiple threads, the iterator's wait time
+    # ratio attributed to the prefetch node is distributed among these nodes
+    # proportionally to their total processing time.
+    if node.is_prefetch:
+      prefetch_factor = node.wait_time_ratio
+      prefetch_child_nodes = _get_nodes_before_prefetch(input_node_id, summary)
+      aggregated_wait_time_ns = _find_aggregated_processing_time(
+          summary, prefetch_child_nodes
+      )
+      # The `wait_time_ratio` of the prefetch node is the sum of the
+      # `wait_time_ratio` of all the nodes running under it. Here we set it
+      # to 0 as it is already accounted for in the ancestor nodes, and the
+      # sum of `wait_time_ratio` of all the nodes in a pipeline should be 1.
+      node.wait_time_ratio = 0
+    _compute_wait_time_ratio(
+        summary,
+        input_node_id,
+        aggregated_wait_time_ns,
+        prefetch_factor,
+    )
+
+
+def _populate_wait_time_ratio(
+    summary: execution_summary_pb2.ExecutionSummary,
+) -> None:
+  """Populates the `wait_time_ratio` of all the nodes in the execution summary."""
+  iterator_nodes = _get_nodes_before_prefetch(0, summary)
+  aggregated_wait_time_ns = _find_aggregated_processing_time(
+      summary, iterator_nodes
+  )
+  _compute_wait_time_ratio(summary, 0, aggregated_wait_time_ns)
+
+
 def _pretty_format_summary(
     summary: execution_summary_pb2.ExecutionSummary,
 ) -> str:
@@ -114,9 +202,6 @@ def _pretty_format_summary(
   # the visualization graph.
   col_names.remove("output_spec")
   col_names.remove("is_output")
-  # TODO: Add a column for `is_prefetch` in the logged execution
-  # summary.
-  col_names.remove("wait_time_ratio")
   col_names.remove("is_prefetch")
   # Insert the average processing time column after the max processing time
   # column.
@@ -141,16 +226,23 @@
 
 
   for node_id in sorted(summary.nodes, reverse=True):
     row_values = []
+    node = summary.nodes[node_id]
     for name in col_names:
-      is_total_processing_time_zero = (
-          summary.nodes[node_id].total_processing_time_ns == 0
-      )
+      is_total_processing_time_zero = node.total_processing_time_ns == 0
       if name == _AVG_PROCESSING_TIME_COLUMN_NAME:
         value = _get_avg_processing_time_ns(summary, node_id)
       else:
         value = getattr(summary.nodes[node_id], name)
-      if name in (
+      if name == "wait_time_ratio":
+        if node.is_prefetch:
+          # The iterator wait time of the prefetch node is distributed across
+          # its child nodes.
+          col_value = "N/A"
+        else:
+          col_value = _format_ratio_as_percent(value)
+
+      elif name in (
           "min_processing_time_ns",
           "max_processing_time_ns",
           "total_processing_time_ns",
@@ -561,6 +653,7 @@ def _get_execution_summary(self) -> execution_summary_pb2.ExecutionSummary:
     """Returns ExecutionStats Summary for the dataset pipeline."""
     execution_summary = execution_summary_pb2.ExecutionSummary()
     result, _ = self._build_execution_summary(execution_summary, 0)
+    _populate_wait_time_ratio(result)
     return result
 
   @contextlib.contextmanager
diff --git a/grain/_src/python/dataset/stats_test.py b/grain/_src/python/dataset/stats_test.py
index d4d5cdd1..803e0bd2 100644
--- a/grain/_src/python/dataset/stats_test.py
+++ b/grain/_src/python/dataset/stats_test.py
@@ -100,7 +100,7 @@
 
       "[]"
   ││
-  ││ MapDatasetIterator(transform=<lambda> @ .../python/dataset/stats_test.py:462)
+  ││ MapDatasetIterator(transform=<lambda> @ .../python/dataset/stats_test.py:525)
   ││
   ╲╱
 {'data': "[]",
diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py
index 477a85a2..21c9a838 100644
--- a/grain/_src/python/dataset/transformations/prefetch.py
+++ b/grain/_src/python/dataset/transformations/prefetch.py
@@ -136,9 +136,11 @@ def _threshold_checker(self):
     )
 
   def __next__(self) -> T:
+    # The time recorded here is the time spent in the prefetch node to return
+    # an element, including the time spent in the parent node.
+    timer = dataset_stats.Timer()
     # We loop here to skip all None elements (in case the underlying dataset
     # is sparse), if self._allow_nones = False, else we return Nones too.
-    timer = dataset_stats.Timer()
     while True:
       if self._next_index == self._dataset_length:
         break
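
Appendix (editorial): the wait-time attribution introduced by this patch can be hard to follow from the diff alone, so here is a minimal, self-contained sketch of the same algorithm. It assumes a hypothetical `Node` dataclass in place of the `execution_summary_pb2` proto (only the fields the pass touches are modeled), and it hoists the zeroing of the prefetch node's ratio out of the input loop, which is equivalent as long as a prefetch node has a single input. This is an illustration of the technique, not Grain's API.

import dataclasses


@dataclasses.dataclass
class Node:
  """Hypothetical stand-in for an ExecutionSummary node (illustration only)."""
  inputs: list[int]
  total_processing_time_ns: int
  is_prefetch: bool = False
  wait_time_ratio: float = 0.0


def subtree_until_prefetch(summary: dict[int, Node], root: int) -> list[int]:
  # Mirrors _get_nodes_before_prefetch: walk the inputs, but do not descend
  # past a prefetch node (its subtree runs in other threads).
  found, to_visit = [], [root]
  while to_visit:
    node_id = to_visit.pop()
    found.append(node_id)
    if not summary[node_id].is_prefetch:
      to_visit.extend(summary[node_id].inputs)
  return found


def compute_ratios(summary, node_id, aggregate_ns, factor=1.0):
  # A node's share of the iterator's wait time is its processing time as a
  # fraction of the aggregate for the thread(s) it runs in.
  if aggregate_ns == 0:
    return
  node = summary[node_id]
  node.wait_time_ratio = round(
      factor * node.total_processing_time_ns / aggregate_ns, 4)
  child_factor, child_aggregate = factor, aggregate_ns
  if node.is_prefetch:
    # The prefetch node's share is redistributed over the subtree it hides;
    # the node itself reports 0 so the ratios still sum to 1.
    child_factor, node.wait_time_ratio = node.wait_time_ratio, 0.0
  for input_id in node.inputs:
    if node.is_prefetch:
      subtree = subtree_until_prefetch(summary, input_id)
      child_aggregate = sum(
          summary[n].total_processing_time_ns for n in subtree)
    compute_ratios(summary, input_id, child_aggregate, child_factor)


def populate_wait_time_ratio(summary: dict[int, Node]) -> None:
  # Mirrors _populate_wait_time_ratio with node 0 as the pipeline root.
  root_subtree = subtree_until_prefetch(summary, 0)
  total_ns = sum(summary[n].total_processing_time_ns for n in root_subtree)
  compute_ratios(summary, 0, total_ns)


if __name__ == "__main__":
  # Toy pipeline: 0 map <- 1 prefetch <- 2 map <- 3 source.
  summary = {
      0: Node(inputs=[1], total_processing_time_ns=300),
      1: Node(inputs=[2], total_processing_time_ns=50, is_prefetch=True),
      2: Node(inputs=[3], total_processing_time_ns=600),
      3: Node(inputs=[], total_processing_time_ns=200),
  }
  populate_wait_time_ratio(summary)
  print({i: n.wait_time_ratio for i, n in summary.items()})
  # -> {0: 0.8571, 1: 0.0, 2: 0.1072, 3: 0.0357}; the ratios sum to ~1.0.

Running the sketch shows the intended invariant: the main-thread nodes split the iterator's wait time directly, the prefetch node's share is pushed down to the nodes running under it, and the prefetch node itself ends up at 0 (rendered as "N/A" by the pretty printer in this patch).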