diff --git a/.github/workflows/python_lint.yml b/.github/workflows/python_lint.yml index 00d218e9..a28514a1 100644 --- a/.github/workflows/python_lint.yml +++ b/.github/workflows/python_lint.yml @@ -13,7 +13,7 @@ jobs: - name: Setup Python Environment uses: actions/setup-python@v2 with: - python-version: '3.8' + python-version: '3.10' - name: Install Dependencies run: | diff --git a/et_replay/comm/backend/base_backend.py b/et_replay/comm/backend/base_backend.py index 6ab271e8..2f0e91dc 100644 --- a/et_replay/comm/backend/base_backend.py +++ b/et_replay/comm/backend/base_backend.py @@ -5,7 +5,6 @@ import logging from abc import ABC, abstractmethod -from typing import Dict, List, Optional import torch @@ -100,7 +99,9 @@ def __init__(self) -> None: self.dataSize = 0 self.numElements = 0 self.waitObj = [] - self.waitObjIds = {} # mapping of (pg_id, req_id, is_p2p) to future of async collectives + self.waitObjIds = ( + {} + ) # mapping of (pg_id, req_id, is_p2p) to future of async collectives self.ipTensor_split_pair = [] self.opTensor_split_pair = [] diff --git a/et_replay/comm/backend/pytorch_dist_backend.py b/et_replay/comm/backend/pytorch_dist_backend.py index 09097d74..397b4c0c 100644 --- a/et_replay/comm/backend/pytorch_dist_backend.py +++ b/et_replay/comm/backend/pytorch_dist_backend.py @@ -10,7 +10,6 @@ from itertools import cycle from time import sleep -from typing import Dict, List, Optional, Tuple import numpy as np import torch diff --git a/et_replay/comm/commsTraceParser.py b/et_replay/comm/commsTraceParser.py index 0a2385d5..bade289e 100644 --- a/et_replay/comm/commsTraceParser.py +++ b/et_replay/comm/commsTraceParser.py @@ -4,7 +4,6 @@ import json import logging -from typing import List, Tuple from et_replay import ExecutionTrace from et_replay.comm import comms_utils @@ -96,7 +95,9 @@ def _parse_proc_group_info(in_trace: ExecutionTrace): ) pg_id = int(pg_id) pg_ranks_map[node.id][pg_id] = ( - ranks if len(ranks) > 0 else list(range(group_size)) + ranks + if len(ranks) > 0 + else list(range(group_size)) # rank list is empty when all ranks are in a pg ) break # only one process_group init node per trace diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py index 54e7ff1d..d336b4aa 100644 --- a/et_replay/comm/comms_utils.py +++ b/et_replay/comm/comms_utils.py @@ -17,7 +17,7 @@ from collections.abc import Callable from contextlib import ContextDecorator from io import StringIO -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any try: from param_bench.train.comms.pt.fb.internals import ( @@ -691,7 +691,9 @@ def __init__(self, args: Namespace) -> None: self.quant_a2a_embedding_dim = args.quant_a2a_embedding_dim self.quant_threshold = args.quant_threshold self.dcheck = args.c - self.groupRanks = {} # record what ranks each process group will work on {pg_id, ranks} + self.groupRanks = ( + {} + ) # record what ranks each process group will work on {pg_id, ranks} self.use_ext_dist = args.use_ext_dist self.size_from_trace = False self.init_method = args.init_method diff --git a/et_replay/execution_trace.py b/et_replay/execution_trace.py index 67409953..bc6e1b44 100644 --- a/et_replay/execution_trace.py +++ b/et_replay/execution_trace.py @@ -9,7 +9,7 @@ from collections.abc import Iterable from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional, Set, TextIO, Tuple +from typing import Any, TextIO import pydot @@ -66,8 +66,8 @@ def add_shape(self, shape: list[Any]): def is_leaf_tensor(self): return ( - (not self.sources) and self.sinks - ) # A tensor having no sources yet having some sinks is a leaf tensor + not self.sources + ) and self.sinks # A tensor having no sources yet having some sinks is a leaf tensor @dataclass diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py index bf349666..d6a2601b 100644 --- a/et_replay/tools/comm_replay.py +++ b/et_replay/tools/comm_replay.py @@ -11,7 +11,6 @@ import logging import os import time -from typing import Dict, List, Set, Tuple, Union import numpy as np import torch @@ -114,7 +113,7 @@ def __init__(self): self.max_msg_cnt = 0 # 0 means no limit self.num_msg = 0 self.is_blocking = False - self.do_warm_up = False + self.warmup_iter = 5 self.reuse_tensors = False self.allowList = "" @@ -215,10 +214,10 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace: help="Only replay first N operations (0 means no limit)", ) parser.add_argument( - "--do-warm-up", - action="store_true", - default=self.do_warm_up, - help="Toggle to enable performing extra replaying for warm-up", + "--warmup-iter", + type=int, + default=self.warmup_iter, + help="Number of warmup iterations", ) parser.add_argument( "--reuse-tensors", @@ -275,7 +274,7 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace: "--profiler-num-replays-start", type=int, default=self.profiler_num_replays_start, - help=f"Replay iteration to start collecting profiler after warmup (if --do-warm-up is True). Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True", + help=f"Replay iteration to start collecting profiler after warmup runs. Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True", ) parser.add_argument( "--profiler-num-replays", @@ -393,12 +392,19 @@ def reportBenchTime(self): if not self.is_dry_run: print("\n{} Performance of replayed comms {}".format("=" * 20, "=" * 20)) print( - "{}\n Total latency (us) of comms in trace {}: \n{}".format( + "{}\n Total latency (us) of comms in trace: {}. \n{}".format( "-" * 50, self.totalTraceLatency, "-" * 50, ) ) + print( + "{}\n Average latency (us) of comms in trace: {}. \n{}".format( + "-" * 50, + self.totalTraceLatency / self.num_replays, + "-" * 50, + ) + ) for coll, lats in self.collLat.items(): if len(lats) == 0: continue @@ -1199,9 +1205,7 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: """ if commsParams.enable_profiler: # num of iterations to skip - numWarmupIters = ( - 1 if self.do_warm_up else 0 - ) + self.profiler_num_replays_start + numWarmupIters = self.warmup_iter + self.profiler_num_replays_start # num of iterations to profile, at most num_replays iterations numProfileIters = ( self.profiler_num_replays @@ -1215,18 +1219,6 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: numIters=numProfileIters, ) - # warm-up - if self.do_warm_up: - if self.collectiveArgs.enable_profiler: - comms_utils.sampleProfiler() - self.replayIter = -1 - self.replayTrace(commsParams=commsParams, warmup=True) - self.resetComms() - - # sync everything before starting real runs - with paramProfile(description="# PARAM replay warmup post-replay global sync"): - self.backendFuncs.sync_barrier(self.collectiveArgs) - if self.backendFuncs.get_global_rank() == 0: logger.info( f"\n+ {self.max_msg_cnt} messages in the trace...replaying (if present) {list(self.allowList)}" @@ -1234,10 +1226,10 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: for coll, sizes in self.collInMsgBytes.items(): logger.info(f"\t{coll}: {len(sizes)}") - traceStartTime = time.monotonic_ns() - for i in range(self.num_replays): - if self.backendFuncs.get_global_rank() == 0: - logger.info(f"Replay #{i}") + traceStartTime = 0 + for i in range(self.warmup_iter + self.num_replays): + if i == self.warmup_iter: + traceStartTime = time.monotonic_ns() if self.collectiveArgs.enable_profiler: comms_utils.sampleProfiler() @@ -1248,6 +1240,9 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: except NameError: pass + if self.backendFuncs.get_global_rank() == 0: + s = time.monotonic_ns() + # replay comms trace self.replayIter = i self.replayTrace(commsParams=commsParams, warmup=False) @@ -1259,6 +1254,10 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: ): self.backendFuncs.sync_barrier(self.collectiveArgs) + if self.backendFuncs.get_global_rank() == 0: + e = time.monotonic_ns() + logger.info(f"Replay #{i} took {(e-s)/1e3:.2f} us") + # record how long it took for trace-replay to complete traceEndTime = time.monotonic_ns() self.totalTraceLatency = (traceEndTime - traceStartTime) / 1e3 # make it us @@ -1457,7 +1456,7 @@ def initBench( self.shrink = args.auto_shrink self.max_msg_cnt = args.max_msg_cnt self.is_blocking = args.blocking - self.do_warm_up = args.do_warm_up + self.warmup_iter = args.warmup_iter self.reuse_tensors = args.reuse_tensors self.allowList = args.allow_ops if args.output_ranks == "all": diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py index 0d115f40..ca7b4a5b 100644 --- a/et_replay/tools/et_replay.py +++ b/et_replay/tools/et_replay.py @@ -7,7 +7,6 @@ import time from collections import defaultdict from datetime import datetime -from typing import Dict import numpy as np import torch diff --git a/train/comms/pt/commsTraceParser.py b/train/comms/pt/commsTraceParser.py index 929097b8..2b947d5d 100644 --- a/train/comms/pt/commsTraceParser.py +++ b/train/comms/pt/commsTraceParser.py @@ -3,8 +3,6 @@ import json -from typing import List, Tuple - from et_replay import ExecutionTrace from param_bench.train.comms.pt import comms_utils @@ -188,7 +186,7 @@ def _parseKinetoUnitrace(in_trace: list, target_rank: int) -> list: def _getTensorInfoFromPyTorchETEntry( tensor_container: list, container_type: str -) -> tuple[int, int, str]: +) -> tuple[int, str]: """ Extract message size, tensor count, type from PyTorch ET entry inputs/outputs field. NOTE: This format can be changed at anytime. TODO: When an extract/parsing tool is available in ATC, switch to it. @@ -249,7 +247,9 @@ def _parseExecutionTrace( ranks = pg["ranks"] groupCnt = pg["group_count"] pgRanksMap[pgId] = ( - ranks if len(ranks) > 0 else list(range(pg["group_size"])) + ranks + if len(ranks) > 0 + else list(range(pg["group_size"])) # rank list is empty when all ranks are in a pg ) if ET_BACKENDID: diff --git a/train/comms/pt/comms_utils.py b/train/comms/pt/comms_utils.py index c2ef06cf..e4e89b85 100644 --- a/train/comms/pt/comms_utils.py +++ b/train/comms/pt/comms_utils.py @@ -793,7 +793,9 @@ def __init__(self, args: Namespace) -> None: self.quant_a2a_embedding_dim = args.quant_a2a_embedding_dim self.quant_threshold = args.quant_threshold self.dcheck = args.c - self.groupRanks = {} # record what ranks each process group will work on {pg_id, ranks} + self.groupRanks = ( + {} + ) # record what ranks each process group will work on {pg_id, ranks} self.use_ext_dist = args.use_ext_dist self.size_from_trace = False self.init_method = args.init_method