Commit 9c1b504

Internal change

PiperOrigin-RevId: 708388937

Grain Team authored and copybara-github committed Jan 17, 2025
1 parent f3bec91 commit 9c1b504
Showing 3 changed files with 104 additions and 9 deletions.
107 changes: 100 additions & 7 deletions grain/_src/python/dataset/stats.py
@@ -68,6 +68,7 @@
_AVG_PROCESSING_TIME_COLUMN_NAME = "avg processing time"

_COLUMN_NAME_OVERRIDES = types.MappingProxyType({
    "wait_time_ratio": "percent wait time",
    "min_processing_time_ns": "min processing time",
    "max_processing_time_ns": "max processing time",
    "total_processing_time_ns": "total processing time",
@@ -79,6 +80,10 @@
_MAX_ROW_LINES = 5


def _format_ratio_as_percent(value: float) -> str:
  return f"{value*100:.2f}%"
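
# Illustrative usage (a sketch, not part of the commit): the helper renders a
# fractional ratio as a percentage with two decimal places, e.g.
#   _format_ratio_as_percent(0.1234)  # -> "12.34%"
#   _format_ratio_as_percent(1.0)     # -> "100.00%"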


def _pretty_format_ns(value: int) -> str:
  """Pretty formats a time value in nanoseconds to human readable value."""
  if value < 1000:
@@ -104,6 +109,89 @@ def _get_avg_processing_time_ns(
  )


def _get_nodes_before_prefetch(
    node: int, summary: execution_summary_pb2.ExecutionSummary
) -> list[int]:
  """Returns nodes in the path from a given node to a prefetch node."""
  child_nodes = []
  nodes_to_visit = [node]
  while nodes_to_visit:
    node_id = nodes_to_visit.pop()
    node = summary.nodes[node_id]
    child_nodes.append(node_id)
    if node.is_prefetch:
      continue  # Skip adding inputs for the prefetch node.
    nodes_to_visit.extend(node.inputs)
  return child_nodes


def _find_aggregated_processing_time(
    summary: execution_summary_pb2.ExecutionSummary,
    node_ids: list[int],
) -> int:
  """Finds aggregated processing time for the given node IDs."""
  return sum(
      summary.nodes[node_id].total_processing_time_ns for node_id in node_ids
  )
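
# A minimal sketch of how the two helpers above compose (illustrative only,
# assuming just the proto fields used in this file: `nodes`, `inputs`,
# `is_prefetch`, `total_processing_time_ns`). Node 0 reads from prefetch
# node 1, which reads from node 2; traversal from the root stops at the
# prefetch node:
#
#   summary = execution_summary_pb2.ExecutionSummary()
#   summary.nodes[0].inputs.append(1)
#   summary.nodes[0].total_processing_time_ns = 60
#   summary.nodes[1].is_prefetch = True
#   summary.nodes[1].inputs.append(2)
#   summary.nodes[1].total_processing_time_ns = 40
#   summary.nodes[2].total_processing_time_ns = 10
#
#   _get_nodes_before_prefetch(0, summary)             # -> [0, 1]
#   _find_aggregated_processing_time(summary, [0, 1])  # -> 100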


def _compute_wait_time_ratio(
    summary: execution_summary_pb2.ExecutionSummary,
    node_id: int,
    aggregated_wait_time_ns: int,
    prefetch_factor: float = 1.0,
) -> None:
  """Computes the wait time ratio for all the nodes in the execution summary.

  Args:
    summary: The execution summary to update the `wait_time_ratio` for.
    node_id: The current node for which to compute the `wait_time_ratio`.
    aggregated_wait_time_ns: The aggregated wait time of the nodes running
      under prefetch.
    prefetch_factor: The factor by which to multiply the node's
      `total_processing_time` to get its wait time ratio.
  """
  if aggregated_wait_time_ns == 0:
    return
  node = summary.nodes[node_id]
  node_wait_ratio = prefetch_factor * (
      node.total_processing_time_ns / aggregated_wait_time_ns
  )
  node.wait_time_ratio = round(node_wait_ratio, 4)
  for input_node_id in node.inputs:
    # If the node is executed in multiple threads, the iterator's wait time
    # ratio attributed to the prefetch node is distributed among these nodes
    # proportionally to their total processing time.
    if node.is_prefetch:
      prefetch_factor = node.wait_time_ratio
      prefetch_child_nodes = _get_nodes_before_prefetch(input_node_id, summary)
      aggregated_wait_time_ns = _find_aggregated_processing_time(
          summary, prefetch_child_nodes
      )
      # The `wait_time_ratio` of the prefetch node is the sum of the
      # `wait_time_ratio` of all the nodes running under it. Here we set it
      # to 0, as it is already accounted for in the ancestor nodes and the
      # sum of `wait_time_ratio` over all the nodes in a pipeline should
      # be 1.
      node.wait_time_ratio = 0
    _compute_wait_time_ratio(
        summary,
        input_node_id,
        aggregated_wait_time_ns,
        prefetch_factor,
    )
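
# Worked example of the proportional split above (illustrative): if a prefetch
# node carries wait_time_ratio = 0.6 and its subtree up to the next prefetch
# holds two nodes with processing times 30ms and 10ms, those nodes receive
# 0.6 * 30/40 = 0.45 and 0.6 * 10/40 = 0.15, while the prefetch node itself is
# reset to 0 so that the ratios across the pipeline still sum to 1.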


def _populate_wait_time_ratio(
    summary: execution_summary_pb2.ExecutionSummary,
) -> None:
  """Populates the `wait_time_ratio` for all the nodes in the execution summary."""
  iterator_nodes = _get_nodes_before_prefetch(0, summary)
  aggregated_wait_time_ns = _find_aggregated_processing_time(
      summary, iterator_nodes
  )
  _compute_wait_time_ratio(summary, 0, aggregated_wait_time_ns)
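
# End-to-end sketch of the population pass (illustrative, reusing the toy
# summary from the earlier sketch: node 0 -> prefetch node 1 -> node 2, with
# processing times of 60, 40 and 10):
#
#   _populate_wait_time_ratio(summary)
#   summary.nodes[0].wait_time_ratio  # 0.6 (= 60 / (60 + 40))
#   summary.nodes[1].wait_time_ratio  # 0.0 (prefetch; rendered as "N/A")
#   summary.nodes[2].wait_time_ratio  # 0.4 (the prefetch node's 0.4, as
#                                     #      node 2 is its only child)
#
# The ratios over the whole pipeline sum to 1, matching the invariant stated
# in `_compute_wait_time_ratio`.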


def _pretty_format_summary(
    summary: execution_summary_pb2.ExecutionSummary,
) -> str:
@@ -114,9 +202,6 @@ def _pretty_format_summary(
  # the visualization graph.
  col_names.remove("output_spec")
  col_names.remove("is_output")
-  # TODO: Add a column for `is_prefetch` in the logged execution
-  # summary.
-  col_names.remove("wait_time_ratio")
  col_names.remove("is_prefetch")
  # Insert the average processing time column after the max processing time
  # column.
@@ -141,16 +226,23 @@

  for node_id in sorted(summary.nodes, reverse=True):
    row_values = []
    node = summary.nodes[node_id]
    for name in col_names:
-      is_total_processing_time_zero = (
-          summary.nodes[node_id].total_processing_time_ns == 0
-      )
      is_total_processing_time_zero = node.total_processing_time_ns == 0
      if name == _AVG_PROCESSING_TIME_COLUMN_NAME:
        value = _get_avg_processing_time_ns(summary, node_id)
      else:
        value = getattr(summary.nodes[node_id], name)

-      if name in (
      if name == "wait_time_ratio":
        if node.is_prefetch:
          # The iterator wait time of the prefetch node is distributed across
          # its child nodes.
          col_value = "N/A"
        else:
          col_value = _format_ratio_as_percent(value)

      elif name in (
          "min_processing_time_ns",
          "max_processing_time_ns",
          "total_processing_time_ns",
@@ -561,6 +653,7 @@ def _get_execution_summary(self) -> execution_summary_pb2.ExecutionSummary:
"""Returns ExecutionStats Summary for the dataset pipeline."""
execution_summary = execution_summary_pb2.ExecutionSummary()
result, _ = self._build_execution_summary(execution_summary, 0)
_populate_wait_time_ratio(result)
return result

@contextlib.contextmanager
2 changes: 1 addition & 1 deletion grain/_src/python/dataset/stats_test.py
@@ -100,7 +100,7 @@
"<class 'int'>[]"
││
-││ MapDatasetIterator(transform=<lambda> @ .../python/dataset/stats_test.py:462)
││ MapDatasetIterator(transform=<lambda> @ .../python/dataset/stats_test.py:525)
││
╲╱
{'data': "<class 'int'>[]",
Expand Down
4 changes: 3 additions & 1 deletion grain/_src/python/dataset/transformations/prefetch.py
@@ -136,9 +136,11 @@ def _threshold_checker(self):
    )

  def __next__(self) -> T:
    # The time recorded here is the time spent in the prefetch node to return
    # an element, including the time spent in the parent node.
    timer = dataset_stats.Timer()
    # We loop here to skip all None elements (in case the underlying dataset
    # is sparse) when self._allow_nones = False; else we return Nones too.
-    timer = dataset_stats.Timer()
    while True:
      if self._next_index == self._dataset_length:
        break
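
The moved timer can be illustrated with a small standalone sketch (not Grain's
implementation; `dataset_stats.Timer` is internal to Grain, so this uses
`time.monotonic_ns` instead). Starting the clock before the skip-None loop
means the reported duration includes time spent waiting on the parent
iterator:

import time

class TimedPrefetchLikeIterator:
  """Toy iterator timing __next__ the way the diff above does."""

  def __init__(self, parent, allow_nones=False):
    self._parent = parent
    self._allow_nones = allow_nones
    self.last_elapsed_ns = 0

  def __next__(self):
    start = time.monotonic_ns()  # started before the loop, as in the commit
    while True:
      element = next(self._parent)  # parent time is inside the measurement
      if element is not None or self._allow_nones:
        self.last_elapsed_ns = time.monotonic_ns() - start
        return element

# Example: next(TimedPrefetchLikeIterator(iter([None, 1])))  # -> 1, None skipped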
