diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py
index 68dc6787..e29525b8 100644
--- a/grain/_src/python/dataset/dataset.py
+++ b/grain/_src/python/dataset/dataset.py
@@ -698,6 +698,7 @@ def _initialize_stats(
     parents_stats = []
     if hasattr(self, "_parents"):
       for p in self._parents:
+        p._stats._is_multithread = True  # pylint: disable=protected-access
         parents_stats.append(p._initialize_stats(execution_tracking_mode))  # pylint: disable=protected-access
     self._stats = dataset_stats.make_stats(
         dataset_stats.StatsConfig(
diff --git a/grain/_src/python/dataset/stats.py b/grain/_src/python/dataset/stats.py
index 469dfe4c..2da24493 100644
--- a/grain/_src/python/dataset/stats.py
+++ b/grain/_src/python/dataset/stats.py
@@ -68,6 +68,7 @@
 _AVG_PROCESSING_TIME_COLUMN_NAME = "avg processing time"
 
 _COLUMN_NAME_OVERRIDES = types.MappingProxyType({
+    "percent_iterator_wait_time": "percent wait time",
     "min_processing_time_ns": "min processing time",
     "max_processing_time_ns": "max processing time",
     "total_processing_time_ns": "total processing time",
@@ -79,6 +80,10 @@
 _MAX_ROW_LINES = 5
 
 
+def _pretty_format_percent(value: float) -> str:
+  return f"{value:.2f}%"
+
+
 def _pretty_format_ns(value: int) -> str:
   """Pretty formats a time value in nanoseconds to human readable value."""
   if value < 1000:
@@ -104,6 +109,70 @@ def _get_avg_processing_time_ns(
   )
 
 
+def _compute_percent_iterator_wait_time(
+    summary: execution_summary_pb2.ExecutionSummary,
+) -> None:
+  """Computes the percent of the iterator wait time spent in each node.
+
+  This method aggregates the following quantities over the nodes in the
+  provided execution summary:
+
+  - Total iterator wait time: the time the iterator spends waiting for the
+    next element to be produced, summed over single-threaded nodes.
+  - Prefetch iterator wait time: the time spent in the prefetch node waiting
+    for the next element to be produced.
+  - Total processing time in multithreaded nodes: the total processing time
+    consumed by all multithreaded nodes.
+
+  Args:
+    summary: The execution summary whose nodes are updated in place.
+  """
+
+  def _get_aggregated_metrics(summary: execution_summary_pb2.ExecutionSummary):
+    total_iterator_wait_time = 0
+    prefetch_iterator_wait_time = 0
+    total_processing_time_in_multithread = 0
+    for node_id in summary.nodes:
+      node = summary.nodes[node_id]
+      if node.name.startswith("PrefetchDatasetIterator"):
+        prefetch_iterator_wait_time = summary.nodes[
+            node_id
+        ].total_processing_time_ns
+      if node.is_multithread:
+        total_processing_time_in_multithread += node.total_processing_time_ns
+      else:
+        total_iterator_wait_time += node.total_processing_time_ns
+    return (
+        total_iterator_wait_time,
+        prefetch_iterator_wait_time,
+        total_processing_time_in_multithread,
+    )
+
+  (
+      total_iterator_wait_time,
+      prefetch_iterator_wait_time,
+      total_processing_time_in_multithread,
+  ) = _get_aggregated_metrics(summary)
+
+  prefetch_wait_time_percent = (
+      prefetch_iterator_wait_time / total_iterator_wait_time
+  ) * 100
+
+  for node_id in summary.nodes:
+    node = summary.nodes[node_id]
+    if node.is_multithread:
+      # If the node is executed in multiple threads, the percent of the
+      # iterator wait time attributed to the prefetch node is distributed
+      # among these nodes proportionally to their total processing time.
+      node.percent_iterator_wait_time = prefetch_wait_time_percent * (
+          node.total_processing_time_ns / total_processing_time_in_multithread
+      )
+    else:
+      node.percent_iterator_wait_time = (
+          node.total_processing_time_ns / total_iterator_wait_time * 100
+      )
+
+
 def _pretty_format_summary(
     summary: execution_summary_pb2.ExecutionSummary,
 ) -> str:
@@ -114,6 +183,7 @@ def _pretty_format_summary(
   # the visualization graph.
   col_names.remove("output_spec")
   col_names.remove("is_output")
+  col_names.remove("is_multithread")
   # Insert the average processing time column after the max processing time
   # column.
   index = col_names.index("max_processing_time_ns")
@@ -137,16 +207,21 @@
 
   for node_id in sorted(summary.nodes, reverse=True):
     row_values = []
+    node = summary.nodes[node_id]
     for name in col_names:
-      is_total_processing_time_zero = (
-          summary.nodes[node_id].total_processing_time_ns == 0
-      )
+      is_total_processing_time_zero = node.total_processing_time_ns == 0
       if name == _AVG_PROCESSING_TIME_COLUMN_NAME:
         value = _get_avg_processing_time_ns(summary, node_id)
       else:
         value = getattr(summary.nodes[node_id], name)
-      if name in (
+      if name == "percent_iterator_wait_time":
+        if node.name.startswith("PrefetchDatasetIterator"):
+          col_value = "N/A"
+        else:
+          col_value = _pretty_format_percent(value)
+
+      elif name in (
           "min_processing_time_ns",
           "max_processing_time_ns",
           "total_processing_time_ns",
@@ -307,6 +382,7 @@ def __init__(self, config: StatsConfig, parents: Sequence[Stats]):
     self._config = config
     self._self_output_spec = None
     self._parents = parents
+    self._is_multithread = False
     # Mark parent nodes as non-outputs. Nodes that are not updated are the
     # output nodes.
     self._is_output = True
@@ -548,6 +624,7 @@ def _build_execution_summary(
     self._summary.name = self._config.name
     self._summary.output_spec = str(self.output_spec)
     self._summary.is_output = self._is_output
+    self._summary.is_multithread = self._is_multithread
     execution_summary.nodes.get_or_create(node_id)
     execution_summary.nodes[node_id].CopyFrom(self._summary)
     current_node_id = node_id
@@ -563,6 +640,7 @@ def _get_execution_summary(self) -> execution_summary_pb2.ExecutionSummary:
     """Returns ExecutionStats Summary for the dataset pipeline."""
     execution_summary = execution_summary_pb2.ExecutionSummary()
     result, _ = self._build_execution_summary(execution_summary, 0)
+    _compute_percent_iterator_wait_time(result)
     return result
 
   @contextlib.contextmanager
diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py
index 5353451a..8a28cc2a 100644
--- a/grain/_src/python/dataset/transformations/prefetch.py
+++ b/grain/_src/python/dataset/transformations/prefetch.py
@@ -108,6 +108,7 @@ def __init__(
 
   @functools.cached_property
   def _stats(self):
+    self._map_parent._stats._is_multithread = True  # pylint: disable=protected-access
     execution_tracking_mode = self._options_with_default.execution_tracking_mode
     parent_stats = self._map_parent._initialize_stats(  # pylint: disable=protected-access
         execution_tracking_mode
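
Notes on the wait-time attribution:

The proportional redistribution in _compute_percent_iterator_wait_time is
easier to see on a toy example. Below is a minimal, self-contained sketch of
the same logic, using a plain dataclass as a stand-in for the
ExecutionSummary proto node. The Node class, attribute_wait_time helper, node
names, and timings are all illustrative assumptions, not part of this change.

from dataclasses import dataclass

@dataclass
class Node:
  name: str
  total_processing_time_ns: int
  is_multithread: bool = False
  percent_iterator_wait_time: float = 0.0

def attribute_wait_time(nodes: list[Node]) -> None:
  """Mirrors the attribution logic above on plain dataclasses."""
  # Wait time the iterator spends in the prefetch node specifically.
  prefetch_wait = sum(
      n.total_processing_time_ns
      for n in nodes
      if n.name.startswith("PrefetchDatasetIterator")
  )
  # Single-threaded nodes account for the total iterator wait time.
  total_wait = sum(
      n.total_processing_time_ns for n in nodes if not n.is_multithread
  )
  # Combined processing time of all nodes running in prefetch threads.
  multithread_total = sum(
      n.total_processing_time_ns for n in nodes if n.is_multithread
  )
  prefetch_percent = prefetch_wait / total_wait * 100
  for n in nodes:
    if n.is_multithread:
      # Multithreaded nodes split the prefetch node's share proportionally
      # to their own processing time.
      n.percent_iterator_wait_time = prefetch_percent * (
          n.total_processing_time_ns / multithread_total
      )
    else:
      n.percent_iterator_wait_time = (
          n.total_processing_time_ns / total_wait * 100
      )

nodes = [
    Node("PrefetchDatasetIterator", 800),
    Node("MapMapDataset", 600, is_multithread=True),
    Node("BatchMapDataset", 200, is_multithread=True),
    Node("MapDatasetIterator", 200),
]
attribute_wait_time(nodes)
for n in nodes:
  print(f"{n.name}: {n.percent_iterator_wait_time:.2f}%")
# PrefetchDatasetIterator: 80.00%  (rendered as "N/A" in the summary table)
# MapMapDataset: 60.00%
# BatchMapDataset: 20.00%
# MapDatasetIterator: 20.00%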
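
Two properties of this scheme are worth noting. First, with a single prefetch
node the non-prefetch rows sum to 100%: single-threaded nodes report their
own share of the wait time directly, and the multithreaded nodes collectively
absorb the prefetch node's share (the 80% split 60/20 in the sketch above).
Second, the prefetch row itself is rendered as "N/A" by
_pretty_format_summary, presumably because its wait time has already been
redistributed to the multithreaded nodes it drives, so printing it as well
would double-count that share.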