From d7d5449bef8805598a83bcf4a56cac06b23da0c0 Mon Sep 17 00:00:00 2001 From: Sevin Varoglu Date: Fri, 10 Jan 2025 16:54:12 +0000 Subject: [PATCH 1/6] Add ratio of hidden communication time to total communication time --- .../nsys_jax/analyses/communication.py | 28 +++++++++++++++++++ .../container/nsys_jax/nsys_jax/analysis.py | 3 ++ 2 files changed, 31 insertions(+) diff --git a/.github/container/nsys_jax/nsys_jax/analyses/communication.py b/.github/container/nsys_jax/nsys_jax/analyses/communication.py index 5388a1f84..b02e4af01 100644 --- a/.github/container/nsys_jax/nsys_jax/analyses/communication.py +++ b/.github/container/nsys_jax/nsys_jax/analyses/communication.py @@ -8,6 +8,7 @@ load_profiler_data, ) from math import sqrt +from statistics import mean import pathlib from uncertainties import ufloat # type: ignore @@ -95,6 +96,33 @@ def format_bandwidth(data, collective): ) ) + collective_types = set() + summary_data = defaultdict(dict) + for collective, df in steady_state.communication.groupby( + ["Collective"] + ): + collective_types.add(collective) + summary_data[collective] = df["DurHiddenMsToDurMs"].mean() + + collective_width = max(len("Collective"), max(len(f"{collective}") for collective in collective_types)) + ratio_width = len("Mean HiddenToTotalMs") + + print() + print(f"{'Collective':<{collective_width}} | {'Mean HiddenToTotalMs':<{ratio_width}}") + print(f"{'-' * collective_width} | {'-' * ratio_width}") + + for collective in collective_types: + mean_value = summary_data[collective] + collective_str = str(collective[0]) + print(f"{collective_str:<{collective_width}} | {mean_value:>{ratio_width}}") + + overall_hidden_ms_to_total_ms = ( + steady_state.communication["ProjDurHiddenMs"].sum() / + (steady_state.communication["ProjDurMs"] + steady_state.communication["ProjDurHiddenMs"]).sum() + ) + + print() + print(f"Overall HiddenMs to TotalMs: {overall_hidden_ms_to_total_ms:>{ratio_width}}") if __name__ == "__main__": main() diff --git a/.github/container/nsys_jax/nsys_jax/analysis.py b/.github/container/nsys_jax/nsys_jax/analysis.py index c4e37fdf9..31c16cbb7 100644 --- a/.github/container/nsys_jax/nsys_jax/analysis.py +++ b/.github/container/nsys_jax/nsys_jax/analysis.py @@ -331,6 +331,9 @@ def calculate_collective_metrics( comm_df["BusBandwidthGBPerSec"] = ( comm_df["AlgorithmBandwidthGBPerSec"] * comm_df["BusBandwidthCorrection"] ) + comm_df["DurHiddenMsToDurMs"] = ( + comm_df["ProjDurHiddenMs"] / (comm_df["ProjDurMs"] + comm_df["ProjDurHiddenMs"]) + ) return comm_df.drop(columns=["BandwidthCorrection", "BusBandwidthCorrection"]) From 14a59d640caa073b43728ce6188c9dcf93253868 Mon Sep 17 00:00:00 2001 From: Sevin Varoglu Date: Mon, 13 Jan 2025 18:13:18 +0000 Subject: [PATCH 2/6] Switch to PrettyTable --- .../nsys_jax/nsys_jax/analyses/communication.py | 17 ++++++----------- .github/container/nsys_jax/pyproject.toml | 1 + 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/.github/container/nsys_jax/nsys_jax/analyses/communication.py b/.github/container/nsys_jax/nsys_jax/analyses/communication.py index b02e4af01..4f07124a1 100644 --- a/.github/container/nsys_jax/nsys_jax/analyses/communication.py +++ b/.github/container/nsys_jax/nsys_jax/analyses/communication.py @@ -8,6 +8,7 @@ load_profiler_data, ) from math import sqrt +from prettytable import PrettyTable from statistics import mean import pathlib from uncertainties import ufloat # type: ignore @@ -104,25 +105,19 @@ def format_bandwidth(data, collective): collective_types.add(collective) summary_data[collective] = df["DurHiddenMsToDurMs"].mean() - collective_width = max(len("Collective"), max(len(f"{collective}") for collective in collective_types)) - ratio_width = len("Mean HiddenToTotalMs") - - print() - print(f"{'Collective':<{collective_width}} | {'Mean HiddenToTotalMs':<{ratio_width}}") - print(f"{'-' * collective_width} | {'-' * ratio_width}") + table = PrettyTable() + table.field_names = ["Collective", "Mean HiddenToTotalMs"] for collective in collective_types: mean_value = summary_data[collective] - collective_str = str(collective[0]) - print(f"{collective_str:<{collective_width}} | {mean_value:>{ratio_width}}") + table.add_row([collective[0], mean_value]) + print(table) overall_hidden_ms_to_total_ms = ( steady_state.communication["ProjDurHiddenMs"].sum() / (steady_state.communication["ProjDurMs"] + steady_state.communication["ProjDurHiddenMs"]).sum() ) - - print() - print(f"Overall HiddenMs to TotalMs: {overall_hidden_ms_to_total_ms:>{ratio_width}}") + print(f"Overall HiddenMs to TotalMs: {overall_hidden_ms_to_total_ms}") if __name__ == "__main__": main() diff --git a/.github/container/nsys_jax/pyproject.toml b/.github/container/nsys_jax/pyproject.toml index 95bdffd4c..d0c79ad43 100644 --- a/.github/container/nsys_jax/pyproject.toml +++ b/.github/container/nsys_jax/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "pyarrow", "requests", # for install-protoc "uncertainties", # communication analysis recipe + "prettytable", ] requires-python = ">= 3.10" From f22aef3c5eeb6e1c02a5e5cd795ed55f467fc6f0 Mon Sep 17 00:00:00 2001 From: Sevin Varoglu Date: Mon, 13 Jan 2025 18:28:17 +0000 Subject: [PATCH 3/6] Remove mean --- .github/container/nsys_jax/nsys_jax/analyses/communication.py | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/container/nsys_jax/nsys_jax/analyses/communication.py b/.github/container/nsys_jax/nsys_jax/analyses/communication.py index 4f07124a1..e1688f9a8 100644 --- a/.github/container/nsys_jax/nsys_jax/analyses/communication.py +++ b/.github/container/nsys_jax/nsys_jax/analyses/communication.py @@ -9,7 +9,6 @@ ) from math import sqrt from prettytable import PrettyTable -from statistics import mean import pathlib from uncertainties import ufloat # type: ignore From 110314ec7e59d510262fac84a7df77964027e713 Mon Sep 17 00:00:00 2001 From: Sevin Varoglu Date: Mon, 13 Jan 2025 18:41:30 +0000 Subject: [PATCH 4/6] Format --- .../nsys_jax/nsys_jax/analyses/communication.py | 13 +++++++------ .github/container/nsys_jax/nsys_jax/analysis.py | 4 ++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/.github/container/nsys_jax/nsys_jax/analyses/communication.py b/.github/container/nsys_jax/nsys_jax/analyses/communication.py index e1688f9a8..39c3a3292 100644 --- a/.github/container/nsys_jax/nsys_jax/analyses/communication.py +++ b/.github/container/nsys_jax/nsys_jax/analyses/communication.py @@ -98,9 +98,7 @@ def format_bandwidth(data, collective): collective_types = set() summary_data = defaultdict(dict) - for collective, df in steady_state.communication.groupby( - ["Collective"] - ): + for collective, df in steady_state.communication.groupby(["Collective"]): collective_types.add(collective) summary_data[collective] = df["DurHiddenMsToDurMs"].mean() @@ -113,9 +111,12 @@ def format_bandwidth(data, collective): print(table) overall_hidden_ms_to_total_ms = ( - steady_state.communication["ProjDurHiddenMs"].sum() / - (steady_state.communication["ProjDurMs"] + steady_state.communication["ProjDurHiddenMs"]).sum() - ) + steady_state.communication["ProjDurHiddenMs"].sum() + / ( + steady_state.communication["ProjDurMs"] + + steady_state.communication["ProjDurHiddenMs"] + ).sum() + ) print(f"Overall HiddenMs to TotalMs: {overall_hidden_ms_to_total_ms}") if __name__ == "__main__": diff --git a/.github/container/nsys_jax/nsys_jax/analysis.py b/.github/container/nsys_jax/nsys_jax/analysis.py index 31c16cbb7..fba5e9cab 100644 --- a/.github/container/nsys_jax/nsys_jax/analysis.py +++ b/.github/container/nsys_jax/nsys_jax/analysis.py @@ -331,8 +331,8 @@ def calculate_collective_metrics( comm_df["BusBandwidthGBPerSec"] = ( comm_df["AlgorithmBandwidthGBPerSec"] * comm_df["BusBandwidthCorrection"] ) - comm_df["DurHiddenMsToDurMs"] = ( - comm_df["ProjDurHiddenMs"] / (comm_df["ProjDurMs"] + comm_df["ProjDurHiddenMs"]) + comm_df["DurHiddenMsToDurMs"] = comm_df["ProjDurHiddenMs"] / ( + comm_df["ProjDurMs"] + comm_df["ProjDurHiddenMs"] ) return comm_df.drop(columns=["BandwidthCorrection", "BusBandwidthCorrection"]) From af2aee6b184ebb8428bf0117d21c43cfdeec22c3 Mon Sep 17 00:00:00 2001 From: Sevin Varoglu Date: Mon, 13 Jan 2025 18:45:08 +0000 Subject: [PATCH 5/6] Format --- .github/container/nsys_jax/nsys_jax/analyses/communication.py | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/container/nsys_jax/nsys_jax/analyses/communication.py b/.github/container/nsys_jax/nsys_jax/analyses/communication.py index 39c3a3292..0e5f40a08 100644 --- a/.github/container/nsys_jax/nsys_jax/analyses/communication.py +++ b/.github/container/nsys_jax/nsys_jax/analyses/communication.py @@ -119,5 +119,6 @@ def format_bandwidth(data, collective): ) print(f"Overall HiddenMs to TotalMs: {overall_hidden_ms_to_total_ms}") + if __name__ == "__main__": main() From 08364a6d2ced5f8e14fd2fe2b02674327a65df6a Mon Sep 17 00:00:00 2001 From: Sevin Varoglu Date: Wed, 15 Jan 2025 22:46:53 +0000 Subject: [PATCH 6/6] Add review feedback --- .../nsys_jax/analyses/communication.py | 142 +++++++++++++++--- .../container/nsys_jax/nsys_jax/analysis.py | 3 - 2 files changed, 118 insertions(+), 27 deletions(-) mode change 100644 => 100755 .github/container/nsys_jax/nsys_jax/analyses/communication.py mode change 100644 => 100755 .github/container/nsys_jax/nsys_jax/analysis.py diff --git a/.github/container/nsys_jax/nsys_jax/analyses/communication.py b/.github/container/nsys_jax/nsys_jax/analyses/communication.py old mode 100644 new mode 100755 index 0e5f40a08..fe54ede30 --- a/.github/container/nsys_jax/nsys_jax/analyses/communication.py +++ b/.github/container/nsys_jax/nsys_jax/analyses/communication.py @@ -1,6 +1,8 @@ #!/usr/bin/env python import argparse +import csv from collections import defaultdict + from nsys_jax import ( align_profiler_data_timestamps, apply_warmup_heuristics, @@ -13,27 +15,7 @@ from uncertainties import ufloat # type: ignore -def main(): - parser = argparse.ArgumentParser( - description="Summarise communication in an nsys-jax report" - ) - parser.add_argument("prefix", type=pathlib.Path) - args = parser.parse_args() - # Make sure that the .proto files under protos/ have been compiled to .py, and - # that those generated .py files are importable. - ensure_compiled_protos_are_importable(prefix=args.prefix) - # Load the profiler data; the compilation part is needed for the warmup heuristics - all_data = load_profiler_data(args.prefix, frames={"communication", "compile"}) - # Align timestamps - all_data, alignment_metadata = align_profiler_data_timestamps(all_data) - # TODO: make this pretty - # print(alignment_metadata) - # Partition the profile data into initialisation and steady-state running - _, steady_state = apply_warmup_heuristics(all_data) - assert len(steady_state.communication), ( - "Communication summary was requested but no steady-state communication was " - "identified." - ) +def process_communication_data(steady_state): collective_types = set() summary_data = defaultdict(dict) for (collective, message_size), df in steady_state.communication.groupby( @@ -53,7 +35,10 @@ def main(): summary_data[message_size][collective] = ufloat( bandwidth.mean(), bandwidth.std() / sqrt(len(bandwidth)) ) - collective_types = sorted(collective_types) + return sorted(collective_types), summary_data + + +def print_bandwidth_table(collective_types, summary_data): collective_widths = { collective: max( len(collective), @@ -96,19 +81,39 @@ def format_bandwidth(data, collective): ) ) + +def process_hidden_ms_to_total_ms(steady_state): + if steady_state.communication["ProjDurHiddenMs"].sum() == 0: + return None, None + collective_types = set() summary_data = defaultdict(dict) for collective, df in steady_state.communication.groupby(["Collective"]): collective_types.add(collective) - summary_data[collective] = df["DurHiddenMsToDurMs"].mean() + mean_dur_hidden_ms_to_total_ms = ( + df["ProjDurHiddenMs"] / (df["ProjDurMs"] + df["ProjDurHiddenMs"]) + ).mean() + summary_data[collective] = mean_dur_hidden_ms_to_total_ms + return collective_types, summary_data + +def print_hidden_ms_to_total_ms_table( + collective_types, summary_data, overall_hidden_ms_to_total_ms +): table = PrettyTable() table.field_names = ["Collective", "Mean HiddenToTotalMs"] for collective in collective_types: mean_value = summary_data[collective] table.add_row([collective[0], mean_value]) + print(table) + print("Overall HiddenMs to TotalMs:", overall_hidden_ms_to_total_ms) + + +def calculate_overall_hidden_ms_to_total_ms(steady_state): + if steady_state.communication["ProjDurHiddenMs"].sum() == 0: + return None overall_hidden_ms_to_total_ms = ( steady_state.communication["ProjDurHiddenMs"].sum() @@ -117,7 +122,96 @@ def format_bandwidth(data, collective): + steady_state.communication["ProjDurHiddenMs"] ).sum() ) - print(f"Overall HiddenMs to TotalMs: {overall_hidden_ms_to_total_ms}") + return overall_hidden_ms_to_total_ms + + +def write_to_csv( + collective_types, + bandwidth_summary, + hidden_to_total_summary, + overall_hidden_ms_to_total_ms, + output_file, +): + with open(output_file, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + + # Write bandwidth table + writer.writerow(["Bandwidth Table"]) + writer.writerow(["Size [B]"] + list(collective_types)) + for message_size in sorted(bandwidth_summary.keys()): + row = [message_size] + for collective in collective_types: + if collective in bandwidth_summary[message_size]: + row.append(f"{bandwidth_summary[message_size][collective]:S}") + else: + row.append("-") + writer.writerow(row) + + writer.writerow([]) # Empty row for separation + + # Write hidden to total table if data is available + if hidden_to_total_summary is not None: + writer.writerow(["HiddenMs to TotalMs Table"]) + writer.writerow(["Collective", "Mean HiddenToTotalMs"]) + for collective in hidden_to_total_summary: + writer.writerow([collective[0], hidden_to_total_summary[collective]]) + + writer.writerow([]) # Empty row for separation + + if overall_hidden_ms_to_total_ms is not None: + writer.writerow( + ["Overall HiddenMs to TotalMs", overall_hidden_ms_to_total_ms] + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Summarise communication in an nsys-jax report" + ) + parser.add_argument("prefix", type=pathlib.Path) + args = parser.parse_args() + + # Make sure that the .proto files under protos/ have been compiled to .py, and + # that those generated .py files are importable. + ensure_compiled_protos_are_importable(prefix=args.prefix) + # Load the profiler data; the compilation part is needed for the warmup heuristics + all_data = load_profiler_data(args.prefix, frames={"communication", "compile"}) + # Align timestamps + all_data, alignment_metadata = align_profiler_data_timestamps(all_data) + # TODO: make this pretty + # print(alignment_metadata) + # Partition the profile data into initialisation and steady-state running + _, steady_state = apply_warmup_heuristics(all_data) + + assert len(steady_state.communication), ( + "Communication summary was requested but no steady-state communication was " + "identified." + ) + + collective_types, bandwidth_summary = process_communication_data(steady_state) + print_bandwidth_table(collective_types, bandwidth_summary) + + hidden_to_total_collective_types, hidden_to_total_summary = ( + process_hidden_ms_to_total_ms(steady_state) + ) + if hidden_to_total_summary is not None: + overall_hidden_ms_to_total_ms = calculate_overall_hidden_ms_to_total_ms( + steady_state + ) + print_hidden_ms_to_total_ms_table( + hidden_to_total_collective_types, + hidden_to_total_summary, + overall_hidden_ms_to_total_ms, + ) + + # Write all tables to a single CSV file + write_to_csv( + collective_types, + bandwidth_summary, + hidden_to_total_summary, + overall_hidden_ms_to_total_ms, + "communication_summary.csv", + ) if __name__ == "__main__": diff --git a/.github/container/nsys_jax/nsys_jax/analysis.py b/.github/container/nsys_jax/nsys_jax/analysis.py old mode 100644 new mode 100755 index fba5e9cab..c4e37fdf9 --- a/.github/container/nsys_jax/nsys_jax/analysis.py +++ b/.github/container/nsys_jax/nsys_jax/analysis.py @@ -331,9 +331,6 @@ def calculate_collective_metrics( comm_df["BusBandwidthGBPerSec"] = ( comm_df["AlgorithmBandwidthGBPerSec"] * comm_df["BusBandwidthCorrection"] ) - comm_df["DurHiddenMsToDurMs"] = comm_df["ProjDurHiddenMs"] / ( - comm_df["ProjDurMs"] + comm_df["ProjDurHiddenMs"] - ) return comm_df.drop(columns=["BandwidthCorrection", "BusBandwidthCorrection"])