From d1a97b974b8b10e9ea2e9cc3bb87cf7bccfad9d4 Mon Sep 17 00:00:00 2001 From: Sheng Fu Date: Thu, 5 Dec 2024 11:47:52 -0800 Subject: [PATCH] Allow user defined warm-up iterations (#189) Summary: Allow user to specify the number of warmup iteration runs. Also print out the wall clock time for each iteration and the average wall clock time for the whole run. Differential Revision: D66801912 --- et_replay/comm/backend/base_backend.py | 1 - .../comm/backend/pytorch_dist_backend.py | 1 - et_replay/comm/commsTraceParser.py | 1 - et_replay/comm/comms_utils.py | 2 +- et_replay/execution_trace.py | 2 +- et_replay/tools/comm_replay.py | 56 +++++++++---------- et_replay/tools/et_replay.py | 1 - 7 files changed, 30 insertions(+), 34 deletions(-) diff --git a/et_replay/comm/backend/base_backend.py b/et_replay/comm/backend/base_backend.py index 1d4b77f9..2f0e91dc 100644 --- a/et_replay/comm/backend/base_backend.py +++ b/et_replay/comm/backend/base_backend.py @@ -5,7 +5,6 @@ import logging from abc import ABC, abstractmethod -from typing import Dict, List, Optional import torch diff --git a/et_replay/comm/backend/pytorch_dist_backend.py b/et_replay/comm/backend/pytorch_dist_backend.py index 09097d74..397b4c0c 100644 --- a/et_replay/comm/backend/pytorch_dist_backend.py +++ b/et_replay/comm/backend/pytorch_dist_backend.py @@ -10,7 +10,6 @@ from itertools import cycle from time import sleep -from typing import Dict, List, Optional, Tuple import numpy as np import torch diff --git a/et_replay/comm/commsTraceParser.py b/et_replay/comm/commsTraceParser.py index 186a6912..bade289e 100644 --- a/et_replay/comm/commsTraceParser.py +++ b/et_replay/comm/commsTraceParser.py @@ -4,7 +4,6 @@ import json import logging -from typing import List, Tuple from et_replay import ExecutionTrace from et_replay.comm import comms_utils diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py index 5f4ac12b..d336b4aa 100644 --- a/et_replay/comm/comms_utils.py +++ b/et_replay/comm/comms_utils.py @@ -17,7 +17,7 @@ from collections.abc import Callable from contextlib import ContextDecorator from io import StringIO -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any try: from param_bench.train.comms.pt.fb.internals import ( diff --git a/et_replay/execution_trace.py b/et_replay/execution_trace.py index cd972d21..bc6e1b44 100644 --- a/et_replay/execution_trace.py +++ b/et_replay/execution_trace.py @@ -9,7 +9,7 @@ from collections.abc import Iterable from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional, Set, TextIO, Tuple +from typing import Any, TextIO import pydot diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py index bf349666..32c7665d 100644 --- a/et_replay/tools/comm_replay.py +++ b/et_replay/tools/comm_replay.py @@ -11,7 +11,6 @@ import logging import os import time -from typing import Dict, List, Set, Tuple, Union import numpy as np import torch @@ -114,7 +113,7 @@ def __init__(self): self.max_msg_cnt = 0 # 0 means no limit self.num_msg = 0 self.is_blocking = False - self.do_warm_up = False + self.warmup_iter = 5 self.reuse_tensors = False self.allowList = "" @@ -215,10 +214,10 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace: help="Only replay first N operations (0 means no limit)", ) parser.add_argument( - "--do-warm-up", - action="store_true", - default=self.do_warm_up, - help="Toggle to enable performing extra replaying for warm-up", + "--warmup-iter", + type=int, + default=self.warmup_iter, + help="Number of warmup iterations", ) parser.add_argument( "--reuse-tensors", @@ -275,7 +274,7 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace: "--profiler-num-replays-start", type=int, default=self.profiler_num_replays_start, - help=f"Replay iteration to start collecting profiler after warmup (if --do-warm-up is True). Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True", + help=f"Replay iteration to start collecting profiler after warmup runs. Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True", ) parser.add_argument( "--profiler-num-replays", @@ -393,12 +392,19 @@ def reportBenchTime(self): if not self.is_dry_run: print("\n{} Performance of replayed comms {}".format("=" * 20, "=" * 20)) print( - "{}\n Total latency (us) of comms in trace {}: \n{}".format( + "{}\n Total latency (us) of comms in trace: {}. \n{}".format( "-" * 50, self.totalTraceLatency, "-" * 50, ) ) + print( + "{}\n Average latency (us) of comms in trace: {}. \n{}".format( + "-" * 50, + self.totalTraceLatency / self.num_replays, + "-" * 50, + ) + ) for coll, lats in self.collLat.items(): if len(lats) == 0: continue @@ -1199,9 +1205,7 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: """ if commsParams.enable_profiler: # num of iterations to skip - numWarmupIters = ( - 1 if self.do_warm_up else 0 - ) + self.profiler_num_replays_start + numWarmupIters = self.warmup_iter + self.profiler_num_replays_start # num of iterations to profile, at most num_replays iterations numProfileIters = ( self.profiler_num_replays @@ -1215,18 +1219,6 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: numIters=numProfileIters, ) - # warm-up - if self.do_warm_up: - if self.collectiveArgs.enable_profiler: - comms_utils.sampleProfiler() - self.replayIter = -1 - self.replayTrace(commsParams=commsParams, warmup=True) - self.resetComms() - - # sync everything before starting real runs - with paramProfile(description="# PARAM replay warmup post-replay global sync"): - self.backendFuncs.sync_barrier(self.collectiveArgs) - if self.backendFuncs.get_global_rank() == 0: logger.info( f"\n+ {self.max_msg_cnt} messages in the trace...replaying (if present) {list(self.allowList)}" @@ -1234,10 +1226,11 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: for coll, sizes in self.collInMsgBytes.items(): logger.info(f"\t{coll}: {len(sizes)}") - traceStartTime = time.monotonic_ns() - for i in range(self.num_replays): - if self.backendFuncs.get_global_rank() == 0: - logger.info(f"Replay #{i}") + traceStartTime = 0 + for i in range(self.warmup_iter + self.num_replays): + + if i == self.warmup_iter: + traceStartTime = time.monotonic_ns() if self.collectiveArgs.enable_profiler: comms_utils.sampleProfiler() @@ -1248,6 +1241,9 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: except NameError: pass + if self.backendFuncs.get_global_rank() == 0: + s = time.monotonic_ns() + # replay comms trace self.replayIter = i self.replayTrace(commsParams=commsParams, warmup=False) @@ -1259,6 +1255,10 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None: ): self.backendFuncs.sync_barrier(self.collectiveArgs) + if self.backendFuncs.get_global_rank() == 0: + e = time.monotonic_ns() + logger.info(f"Replay #{i} took {(e-s)/1e3:.2f} us") + # record how long it took for trace-replay to complete traceEndTime = time.monotonic_ns() self.totalTraceLatency = (traceEndTime - traceStartTime) / 1e3 # make it us @@ -1457,7 +1457,7 @@ def initBench( self.shrink = args.auto_shrink self.max_msg_cnt = args.max_msg_cnt self.is_blocking = args.blocking - self.do_warm_up = args.do_warm_up + self.warmup_iter = args.warmup_iter self.reuse_tensors = args.reuse_tensors self.allowList = args.allow_ops if args.output_ranks == "all": diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py index 0d115f40..ca7b4a5b 100644 --- a/et_replay/tools/et_replay.py +++ b/et_replay/tools/et_replay.py @@ -7,7 +7,6 @@ import time from collections import defaultdict from datetime import datetime -from typing import Dict import numpy as np import torch