Skip to content

Commit

Permalink
add support to bw calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
sanshang-nv committed Jan 10, 2025
1 parent 32a023a commit 7af31e6
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 3 deletions.
252 changes: 252 additions & 0 deletions et_replay/comm/profiler_trace_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
import logging
import os
import json
import pathlib
from typing import Dict, Any, Optional, Callable
from collections import defaultdict
import ast
import numpy as np
from intervaltree import Interval, IntervalTree

# Module-level logger; INFO level is forced so bw-calculation errors surface
# even if the application has not configured logging for this package.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# refer to:
# https://github.com/pytorch/pytorch/blob/2cc01cc6d3ad2aff47e8460667ba654b2e4c9f21/c10/core/ScalarType.h#L61
# Size in bytes of one element for each PyTorch scalar-type name as it appears
# in the kineto trace's "dtype" arg. Sub-byte packed types (e.g. QUInt4x2,
# Bits1x8) are mapped to 1 here.
_dtype_size_map: Dict[str, int] = {
    'Byte' : 1,
    'Char' : 1,
    'Short' : 2,
    'Int' : 4,
    'Long' : 8,
    'Half' : 2,
    'Float' : 4,
    'Double' : 8,
    'ComplexHalf' : 4,
    'ComplexFloat' : 8,
    'ComplexDouble' : 16,
    'Bool' : 1,
    'QInt8' : 1,
    'QUInt8' : 1,
    'QInt32' : 4,
    'BFloat16' : 2,
    'QUInt4x2' : 1,
    'QUInt2x4' : 1,
    'Bits1x8' : 1,
    'Bits2x4' : 1,
    'Bits4x2' : 1,
    'Bits8' : 1,
    'Bits16' : 2,
    'Float8_e5m2' : 1,
    'Float8_e4m3fn' : 1,
    'Float8_e5m2fnuz' : 1,
    'Float8_e4m3fnuz' : 1,
    'UInt16' : 2,
    'UInt32' : 4,
    'UInt64' : 8,
}

# refer to: https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md
_busbw_correction_factors_tbl: Dict[str, Callable[[int], float]] = {
'all_reduce' : (lambda n: 2 * (n - 1) / n),
'all_gather' : (lambda n: (n - 1) / n),
'all_to_all' : (lambda n: (n - 1) / n),
'reduce_scatter' : (lambda n: (n - 1) / n),
'reduce' : (lambda n: 1),
'scatter' : (lambda n: (n - 1)/n),
'gather' : (lambda n: (n - 1)/n),
'broadcast' : (lambda n: 1),
'send' : (lambda n: 1),
'recv' : (lambda n: 1),
}

# Map the "Collective name" string found in trace events to the correction
# factor function of the corresponding entry in _busbw_correction_factors_tbl.
_collname_to_busbw_corr_factor_func: Dict[str, Callable[[int], float]] = {
    event_name: _busbw_correction_factors_tbl[table_key]
    for event_name, table_key in [
        ('allreduce', 'all_reduce'),
        ('all_gather', 'all_gather'),
        ('_allgather_base', 'all_gather'),
        ('reduce_scatter', 'reduce_scatter'),
        ('_reduce_scatter_base', 'reduce_scatter'),
        ('all_to_all', 'all_to_all'),
        ('all_to_allv', 'all_to_all'),
        ('broadcast', 'broadcast'),
        ('reduce', 'reduce'),
        ('gather', 'gather'),
        ('scatter', 'scatter'),
        ('send', 'send'),
        ('recv', 'recv'),
    ]
}

def _get_dict_value(d, k, err_msg):
if k not in d:
raise ValueError(err_msg)
return d.get(k)

def _calculate_event_data_size(evt):
    """Return the payload size in bytes: max(in, out) element count times element size."""
    args = evt['args']
    nelems = max(args['In msg nelems'], args['Out msg nelems'])
    return nelems * _dtype_size_map[args['dtype']]

def _calculate_algbw(evt: Dict[str, Any]) -> float:
    """Return the algorithm bandwidth of a kernel event in GB/s, rounded to 2 decimals.

    Raises ValueError if the event carries no 'dur' field.
    """
    dur_us = _get_dict_value(evt, 'dur', f'Missing "dur" in event: {evt}')
    nbytes = _calculate_event_data_size(evt)

    # NCCL tests use 1024^3 to convert B to GB (but not 1e9)
    # https://github.com/NVIDIA/nccl-tests/blob/8dfeab9eb9bdfdf13503e71e1f33e7f8a208b540/src/common.cu#L102
    # but it uses 1e9 to convert bw from B/s to GB/s
    # https://github.com/NVIDIA/nccl-tests/blob/8dfeab9eb9bdfdf13503e71e1f33e7f8a208b540/src/all_gather.cu#L41
    # bytes/us = 1e6 B/s; a further /1e3 expresses the result in GB/s (1e9 B/s).
    return round((nbytes / dur_us) / 1e3, 2)

def _get_event_busbw_factor(evt):
    """Return the busbw correction factor for a collective kernel event.

    Returns 0 for barrier events; raises ValueError for events missing the
    required args or using an unsupported collective.
    """
    args = evt['args']
    coll_name = _get_dict_value(args, 'Collective name', f'Missing "Collective name" in event: {evt}')

    # barrier is implemented using AllReduce
    if coll_name == 'barrier':
        return 0

    group_size = _get_dict_value(args, 'Group size', f'Missing "Group size" in event: {evt}')
    factor_fn = _get_dict_value(
        _collname_to_busbw_corr_factor_func,
        coll_name,
        f'Unsupported collective op for busbw calculation: {coll_name}',
    )
    return factor_fn(group_size)

def calculate_bw_(trace_data):
    """Annotate NCCL device-kernel events of trace_data in-place.

    Adds 'algbw (GB/sec)', 'busbw (GB/sec)' and 'busbw_factor' to the args of
    every 'ncclDevKernel_*' kernel event. Barrier events are skipped; events
    with missing fields are logged and left unannotated.
    """
    for evt in trace_data['traceEvents']:
        if evt.get('cat', '') != 'kernel' or not evt['name'].startswith('ncclDevKernel_'):
            continue
        try:
            coll_name = _get_dict_value(evt['args'], 'Collective name', f'Missing "Collective name" in event: {evt}')

            # barrier is implemented using AllReduce
            if coll_name == 'barrier':
                continue

            algbw = _calculate_algbw(evt)
            busbw_factor = _get_event_busbw_factor(evt)

            evt['args']['algbw (GB/sec)'] = algbw
            evt['args']['busbw (GB/sec)'] = round(algbw * busbw_factor, 2)
            evt['args']['busbw_factor'] = busbw_factor
        except ValueError as e:
            logger.error('Error processing event: %s', e)

def calculate_sbw(trace_data):
    # calculate shared bw per rank:
    #   sum(data_size * busbw_factor) / total GPU communication busy time,
    # where busy time is the union of all NCCL kernel intervals (overlaps
    # merged, idle gaps excluded).
    # Only kernels already annotated by calculate_bw_ participate; barrier
    # events carry no 'busbw_factor' and are therefore excluded.
    nccl_events = [i for i in trace_data['traceEvents'] if i.get('cat', '') == 'kernel' \
        and i['name'].startswith('ncclDevKernel_') \
        and 'busbw_factor' in i['args']]

    # no communication kernels -> no shared bw to report
    if not len(nccl_events):
        return 0

    # effective bytes on the bus per event: payload scaled by the collective's factor
    total_data_size = sum([_calculate_event_data_size(evt) * _get_event_busbw_factor(evt) for evt in nccl_events])

    # merge overlapping kernel intervals so concurrent kernels are not double-counted
    time_range_tree = IntervalTree([Interval(evt['ts'], evt['ts'] + evt['dur']) for evt in nccl_events])
    time_range_tree.merge_overlaps()

    begin_time_point = min([i.begin for i in time_range_tree])
    end_time_point = max([i.end for i in time_range_tree])

    # idle time = sum of gaps between the merged, sorted intervals
    sorted_tr = sorted(time_range_tree)
    total_idle_time = \
        sum([sorted_tr[i + 1].begin - sorted_tr[i].end for i in range(len(sorted_tr) - 1)]) \
        if len(sorted_tr) > 1 else 0

    # assumes ts/dur are in microseconds (consistent with _calculate_algbw),
    # so bytes/us / 1e3 yields GB/s
    return total_data_size / (end_time_point - begin_time_point - total_idle_time) / 1e3

def pick_iter_e2e_time_(trace_data, tl):
    """Append the duration of every 'ProfilerStep#*' user annotation to tl (in-place)."""
    for evt in trace_data['traceEvents']:
        if evt.get('cat', '') == 'user_annotation' and evt['name'].startswith('ProfilerStep#'):
            tl.append(evt['dur'])

def pick_comm_bw_(trace_data, comm_bw_data):
    """Accumulate per-kernel bandwidth samples into comm_bw_data (in-place).

    Key: (kernel_name, data_size_bytes, group_size).
    Value appended: [dur, algbw, busbw, pg], where pg identifies the process
    group and is non-None only on the lowest rank of the group, so each group
    is counted exactly once across ranks.
    """
    rank = trace_data['distributedInfo']['rank']
    for evt in trace_data['traceEvents']:
        if evt.get('cat', '') != 'kernel' \
                or not evt['name'].startswith('ncclDevKernel_') \
                or 'algbw (GB/sec)' not in evt['args']:
            continue

        args = evt['args']
        kernel_name = evt['name'][:evt['name'].index('(')]
        data_size = _calculate_event_data_size(evt)
        group_size = args['Group size']

        pg_ranks = ast.literal_eval(args['Process Group Ranks'])
        pg_name = int(args['Process Group Name'])
        pg = (*pg_ranks, pg_name) if rank == min(pg_ranks) else None

        comm_bw_data[(kernel_name, data_size, group_size)].append(
            [evt['dur'], args['algbw (GB/sec)'], args['busbw (GB/sec)'], pg]
        )

def analyze_profiler_trace(trace_dir: str, report_dir: str):
    '''
    Analyse input PyTorch profiler trace (i.e. Kineto trace) and generate report.

    Each trace file is annotated in-place with bandwidth fields (see
    calculate_bw_), a processed copy is dumped under
    "<report_dir>/profiler_trace_processed", and a plain-text summary is
    written to "<report_dir>/profiler_trace_summary_report.txt".

    Args:
        trace_dir (str): dir path of input traces, where trace name should be in "rank-n.json" format.
        report_dir (str): dir path for generated reports
    '''
    logger.info(f'Parse profiler trace from "{trace_dir}" and generate reports to "{report_dir}"')

    processed_trace_dir = os.path.join(report_dir, 'profiler_trace_processed')
    pathlib.Path(processed_trace_dir).mkdir(parents=True, exist_ok=True)

    # list of iteration time in all ranks
    iter_e2e_time = []

    # list of shared bw, one entry per rank/trace file
    sbw_lst = []

    # key is (kernel_name, data size, ranks number)
    # value is list of [dur, algbw, busbw, pg]
    comm_bw_data = defaultdict(list)

    for fpath in os.scandir(trace_dir):
        if not fpath.is_file():
            continue

        with open(fpath.path, 'r', encoding='utf-8') as f:
            trace = json.load(f)

        # annotate events, then persist the processed copy before deriving stats
        calculate_bw_(trace)
        with open(os.path.join(processed_trace_dir, fpath.name), 'w', encoding='utf-8') as f:
            json.dump(trace, f)

        sbw_lst.append(calculate_sbw(trace))

        pick_iter_e2e_time_(trace, iter_e2e_time)
        pick_comm_bw_(trace, comm_bw_data)

    # aggregate per-key samples into [#pgs, avg dur, avg busbw, p01, p50, p90, p99]
    comm_bw_summary = {}
    for k,v in comm_bw_data.items():
        t_lst = [i[0] for i in v]
        busbw_lst = [i[2] for i in v]
        # pg is non-None only on the lowest rank of a group, so this set
        # counts each process group exactly once
        pg_set = set([i[3] for i in v if i[3]])
        comm_bw_summary[k] = [len(pg_set),
                              np.average(t_lst),
                              np.average(busbw_lst),
                              np.percentile(busbw_lst, 1),
                              np.percentile(busbw_lst, 50),
                              np.percentile(busbw_lst, 90),
                              np.percentile(busbw_lst, 99)]
    comm_bw_summary = dict(sorted(comm_bw_summary.items()))

    # dump summary report
    # NOTE(review): raises ZeroDivisionError if no ProfilerStep# annotations or
    # no trace files were found — confirm inputs always contain them
    with open(os.path.join(report_dir, 'profiler_trace_summary_report.txt'), 'w', encoding='utf-8') as f:
        f.write(f'avg. E2ETime of iters among all ranks: {sum(iter_e2e_time) / len(iter_e2e_time) / 1e3 :.3f} ms\n')
        f.write(f'avg. SharedBW (i.e. sum(data_size * busbw_factor) / GPU_comm_busy_time per rank) among all ranks: {sum(sbw_lst) / len(sbw_lst) :.3f} GB/s\n')

        # table header: fixed-width columns matching the per-row format below
        f.write(f'\n{" ":>70s}|{" ":>5s}|{"AVG.":^19s}|{"p01":^8s}|{"p50":^8s}|{"p90":^8s}|{"p99":^8s}|\n')

        f.write(f'{"kernel":>50s} {"size":>12s} {"#rks":>6s}|{"#pgs":>5s}|{" dur":>10s} ')
        for i in range(5): # average, p01, p50, p90, p99
            f.write(f'{" busbw":>8s}|')
        f.write('\n')

        f.write(f'{" ":>50s} {" (B)":>12s} {" ":>6s}|{" ":>5s}|{" (ms)":>10s} ')
        for i in range(5): # average, p01, p50, p90, p99
            f.write(f'{"(GB/s)":>8s}|')
        f.write('\n')

        for k,v in comm_bw_summary.items():
            # dur is stored in us; /1e3 prints ms
            f.write(f'{k[0]:>50s} {k[1]:>12d} {k[2]:>6d}|{v[0]:>5d}|{v[1]/1e3:>10.3f} ')
            for i in range(2, len(v)):
                f.write(f'{v[i]:>8.2f}|')
            f.write('\n')

4 changes: 4 additions & 0 deletions et_replay/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ build-backend = "setuptools.build_meta"
[project]
name = "et_replay"
version = "0.5.0"
dependencies = [
"numpy",
"intervaltree",
]

[tool.setuptools.package-dir]
"et_replay" = "."
Expand Down
11 changes: 8 additions & 3 deletions et_replay/tools/comm_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch
from torch.profiler import profile, ProfilerActivity, schedule

from et_replay.comm import comms_utils, commsTraceParser
from et_replay.comm import comms_utils, commsTraceParser, profiler_trace_analysis
from et_replay.comm.backend.base_backend import supportedP2pOps
from et_replay.comm.comms_utils import (
bootstrap_info_holder,
Expand Down Expand Up @@ -1302,6 +1302,8 @@ def trace_handler(p):
# cleanup any memory left in use
self.backendFuncs.clear_memory(self.collectiveArgs)

self.backendFuncs.barrier_all_ranks()

def runBench(
self,
commsParams: commsParamsHolderBase,
Expand Down Expand Up @@ -1358,8 +1360,11 @@ def runBench(
rank=global_rank
)
# TODO: collect perf. from all ranks to rank 0 and detect any imbalanced perf?
self.backendFuncs.barrier(self.collectiveArgs)
self.backendFuncs.complete_accel_ops(self.collectiveArgs)

if commsParams.enable_profiler and not fb_internal.has_fb_internal_libs and self.backendFuncs.get_global_rank() == 0:
profiler_trace_analysis.analyze_profiler_trace(os.path.join(self.out_path, 'profiler_trace'), self.out_path)

self.backendFuncs.barrier_all_ranks()

def replayInit(
self,
Expand Down

0 comments on commit 7af31e6

Please sign in to comment.