From 2cf4284a5b8007fb66ca903c079d38fd9d5f0e44 Mon Sep 17 00:00:00 2001
From: Sheng Fu
Date: Tue, 10 Dec 2024 13:36:08 -0800
Subject: [PATCH] Upgrade Python to 3.10 and allow user to specify the number
 of warmup iterations (#189)

Summary:
Meta production is on Python 3.10, while PARAM lint is still on 3.8, and Lint
reported some issues related to type hint syntax. To fix this, upgrade Python
to 3.10.

Also change the warm-up option so the user can specify the number of warmup
iterations, and print out the wall clock time for each iteration as well as
the average wall clock time for the whole run.

Reviewed By: kingchc

Differential Revision: D66801912
---
 .github/workflows/python_lint.yml        | 10 ++--
 et_replay/comm/backend/base_backend.py   |  1 -
 .../comm/backend/pytorch_dist_backend.py |  1 -
 et_replay/comm/commsTraceParser.py       |  1 -
 et_replay/comm/comms_utils.py            |  2 +-
 et_replay/execution_trace.py             |  2 +-
 et_replay/tools/comm_replay.py           | 55 +++++++++----------
 et_replay/tools/et_replay.py             |  1 -
 train/comms/pt/commsTraceParser.py       |  4 +-
 9 files changed, 36 insertions(+), 41 deletions(-)

diff --git a/.github/workflows/python_lint.yml b/.github/workflows/python_lint.yml
index 00d218e9..d2706a9f 100644
--- a/.github/workflows/python_lint.yml
+++ b/.github/workflows/python_lint.yml
@@ -13,14 +13,16 @@ jobs:
       - name: Setup Python Environment
         uses: actions/setup-python@v2
         with:
-          python-version: '3.8'
+          python-version: '3.10'
 
       - name: Install Dependencies
         run: |
           pip install black
-
-      - name: Run Black
-        run: black . --check
+      # OpenSource Black seems not match Meta internal version
+      # temporarily disable it until we figure out how to make
+      # them consistent
+      # - name: Run Black
+      #   run: black . --check
 
       - name: Run tests
         run: |
diff --git a/et_replay/comm/backend/base_backend.py b/et_replay/comm/backend/base_backend.py
index 6ab271e8..2f421703 100644
--- a/et_replay/comm/backend/base_backend.py
+++ b/et_replay/comm/backend/base_backend.py
@@ -5,7 +5,6 @@
 
 import logging
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional
 
 import torch
 
diff --git a/et_replay/comm/backend/pytorch_dist_backend.py b/et_replay/comm/backend/pytorch_dist_backend.py
index 09097d74..397b4c0c 100644
--- a/et_replay/comm/backend/pytorch_dist_backend.py
+++ b/et_replay/comm/backend/pytorch_dist_backend.py
@@ -10,7 +10,6 @@
 
 from itertools import cycle
 from time import sleep
-from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import torch
diff --git a/et_replay/comm/commsTraceParser.py b/et_replay/comm/commsTraceParser.py
index 0a2385d5..2644d80d 100644
--- a/et_replay/comm/commsTraceParser.py
+++ b/et_replay/comm/commsTraceParser.py
@@ -4,7 +4,6 @@
 
 import json
 import logging
-from typing import List, Tuple
 
 from et_replay import ExecutionTrace
 from et_replay.comm import comms_utils
diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py
index 54e7ff1d..c9c29886 100644
--- a/et_replay/comm/comms_utils.py
+++ b/et_replay/comm/comms_utils.py
@@ -17,7 +17,7 @@
 from collections.abc import Callable
 from contextlib import ContextDecorator
 from io import StringIO
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any
 
 try:
     from param_bench.train.comms.pt.fb.internals import (
diff --git a/et_replay/execution_trace.py b/et_replay/execution_trace.py
index 67409953..c71d0185 100644
--- a/et_replay/execution_trace.py
+++ b/et_replay/execution_trace.py
@@ -9,7 +9,7 @@
 from collections.abc import Iterable
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, List, Optional, Set, TextIO, Tuple
+from typing import Any
 
 import pydot
 
diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py
index bf349666..d6a2601b 100644
--- a/et_replay/tools/comm_replay.py
+++ b/et_replay/tools/comm_replay.py
@@ -11,7 +11,6 @@
 import logging
 import os
 import time
-from typing import Dict, List, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -114,7 +113,7 @@ def __init__(self):
         self.max_msg_cnt = 0  # 0 means no limit
         self.num_msg = 0
         self.is_blocking = False
-        self.do_warm_up = False
+        self.warmup_iter = 5
         self.reuse_tensors = False
         self.allowList = ""
 
@@ -215,10 +214,10 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
             help="Only replay first N operations (0 means no limit)",
         )
         parser.add_argument(
-            "--do-warm-up",
-            action="store_true",
-            default=self.do_warm_up,
-            help="Toggle to enable performing extra replaying for warm-up",
+            "--warmup-iter",
+            type=int,
+            default=self.warmup_iter,
+            help="Number of warmup iterations",
         )
         parser.add_argument(
             "--reuse-tensors",
@@ -275,7 +274,7 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
             "--profiler-num-replays-start",
             type=int,
             default=self.profiler_num_replays_start,
-            help=f"Replay iteration to start collecting profiler after warmup (if --do-warm-up is True). Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True",
+            help=f"Replay iteration to start collecting profiler after warmup runs. Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True",
        )
         parser.add_argument(
             "--profiler-num-replays",
@@ -393,12 +392,19 @@ def reportBenchTime(self):
         if not self.is_dry_run:
             print("\n{} Performance of replayed comms {}".format("=" * 20, "=" * 20))
             print(
-                "{}\n Total latency (us) of comms in trace {}: \n{}".format(
+                "{}\n Total latency (us) of comms in trace: {}. \n{}".format(
                     "-" * 50,
                     self.totalTraceLatency,
                     "-" * 50,
                 )
             )
+            print(
+                "{}\n Average latency (us) of comms in trace: {}. \n{}".format(
\n{}".format( + "-" * 50, + self.totalTraceLatency / self.num_replays, + "-" * 50, + ) + ) for coll, lats in self.collLat.items(): if len(lats) == 0: continue @@ -1199,9 +1205,7 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: """ if commsParams.enable_profiler: # num of iterations to skip - numWarmupIters = ( - 1 if self.do_warm_up else 0 - ) + self.profiler_num_replays_start + numWarmupIters = self.warmup_iter + self.profiler_num_replays_start # num of iterations to profile, at most num_replays iterations numProfileIters = ( self.profiler_num_replays @@ -1215,18 +1219,6 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: numIters=numProfileIters, ) - # warm-up - if self.do_warm_up: - if self.collectiveArgs.enable_profiler: - comms_utils.sampleProfiler() - self.replayIter = -1 - self.replayTrace(commsParams=commsParams, warmup=True) - self.resetComms() - - # sync everything before starting real runs - with paramProfile(description="# PARAM replay warmup post-replay global sync"): - self.backendFuncs.sync_barrier(self.collectiveArgs) - if self.backendFuncs.get_global_rank() == 0: logger.info( f"\n+ {self.max_msg_cnt} messages in the trace...replaying (if present) {list(self.allowList)}" @@ -1234,10 +1226,10 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: for coll, sizes in self.collInMsgBytes.items(): logger.info(f"\t{coll}: {len(sizes)}") - traceStartTime = time.monotonic_ns() - for i in range(self.num_replays): - if self.backendFuncs.get_global_rank() == 0: - logger.info(f"Replay #{i}") + traceStartTime = 0 + for i in range(self.warmup_iter + self.num_replays): + if i == self.warmup_iter: + traceStartTime = time.monotonic_ns() if self.collectiveArgs.enable_profiler: comms_utils.sampleProfiler() @@ -1248,6 +1240,9 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: except NameError: pass + if self.backendFuncs.get_global_rank() == 0: + s = time.monotonic_ns() + # replay comms trace self.replayIter = i self.replayTrace(commsParams=commsParams, warmup=False) @@ -1259,6 +1254,10 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: ): self.backendFuncs.sync_barrier(self.collectiveArgs) + if self.backendFuncs.get_global_rank() == 0: + e = time.monotonic_ns() + logger.info(f"Replay #{i} took {(e-s)/1e3:.2f} us") + # record how long it took for trace-replay to complete traceEndTime = time.monotonic_ns() self.totalTraceLatency = (traceEndTime - traceStartTime) / 1e3 # make it us @@ -1457,7 +1456,7 @@ def initBench( self.shrink = args.auto_shrink self.max_msg_cnt = args.max_msg_cnt self.is_blocking = args.blocking - self.do_warm_up = args.do_warm_up + self.warmup_iter = args.warmup_iter self.reuse_tensors = args.reuse_tensors self.allowList = args.allow_ops if args.output_ranks == "all": diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py index 0d115f40..ca7b4a5b 100644 --- a/et_replay/tools/et_replay.py +++ b/et_replay/tools/et_replay.py @@ -7,7 +7,6 @@ import time from collections import defaultdict from datetime import datetime -from typing import Dict import numpy as np import torch diff --git a/train/comms/pt/commsTraceParser.py b/train/comms/pt/commsTraceParser.py index 929097b8..a0471250 100644 --- a/train/comms/pt/commsTraceParser.py +++ b/train/comms/pt/commsTraceParser.py @@ -3,8 +3,6 @@ import json -from typing import List, Tuple - from et_replay import ExecutionTrace from param_bench.train.comms.pt import comms_utils @@ -188,7 +186,7 @@ def _parseKinetoUnitrace(in_trace: list, 
 
 def _getTensorInfoFromPyTorchETEntry(
     tensor_container: list, container_type: str
-) -> tuple[int, int, str]:
+) -> tuple[int, str]:
     """
     Extract message size, tensor count, type from PyTorch ET entry inputs/outputs field.
     NOTE: This format can be changed at anytime. TODO: When an extract/parsing tool is available in ATC, switch to it.
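
For reference, the warmup-then-measure scheme that benchTime adopts above boils down to the standalone sketch below. It is illustrative only: timed_replays and run_one_replay are placeholder names that do not exist in the repository, and the rank-0 gating around the per-iteration logging is omitted.

import time


def timed_replays(run_one_replay, warmup_iter: int = 5, num_replays: int = 1) -> None:
    # Replay warmup_iter + num_replays times; only the last num_replays
    # iterations count toward the total and the derived average.
    trace_start_ns = 0
    for i in range(warmup_iter + num_replays):
        if i == warmup_iter:
            # The trace timer starts only once warmup is done.
            trace_start_ns = time.monotonic_ns()

        iter_start_ns = time.monotonic_ns()
        run_one_replay(i)  # stand-in for replaying the comms trace once
        iter_end_ns = time.monotonic_ns()
        print(f"Replay #{i} took {(iter_end_ns - iter_start_ns) / 1e3:.2f} us")

    total_us = (time.monotonic_ns() - trace_start_ns) / 1e3
    print(f"Total latency (us) of comms in trace: {total_us:.2f}")
    print(f"Average latency (us) of comms in trace: {total_us / num_replays:.2f}")

On the command line, the boolean --do-warm-up toggle is replaced by --warmup-iter N (default 5); warmup replays are excluded from totalTraceLatency, so the reported average is taken over the measured replays only.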