
Commit d1a97b9

shengfukevin authored and facebook-github-bot committed
Allow user defined warm-up iterations (#189)
Summary: Allow the user to specify the number of warm-up iteration runs. Also print the wall-clock time for each iteration and the average wall-clock time for the whole run.

Differential Revision: D66801912
1 parent f3b1968 commit d1a97b9

File tree

7 files changed (+30, -34 lines changed)

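The heart of the change is the measurement flow in `benchTime`: warm-up iterations now run inside the main replay loop, and the trace-level wall clock starts only once they finish. Below is a minimal, self-contained sketch of that pattern; the `run_once` callable and the iteration counts are illustrative stand-ins, not code from this commit.

import time

def replay_with_warmup(run_once, warmup_iter=5, num_replays=10):
    # Warm-up and measured iterations share one loop; the trace-level
    # wall clock starts only when the first measured iteration begins.
    trace_start = 0
    for i in range(warmup_iter + num_replays):
        if i == warmup_iter:
            trace_start = time.monotonic_ns()
        s = time.monotonic_ns()
        run_once(i)
        e = time.monotonic_ns()
        # per-iteration wall-clock time, printed for warm-up runs too
        print(f"Replay #{i} took {(e - s) / 1e3:.2f} us")
    total_us = (time.monotonic_ns() - trace_start) / 1e3
    print(f"Total latency (us): {total_us:.2f}")
    print(f"Average latency (us): {total_us / num_replays:.2f}")

replay_with_warmup(lambda i: time.sleep(0.001), warmup_iter=2, num_replays=3)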

et_replay/comm/backend/base_backend.py

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@
 
 import logging
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional
 
 import torch
 
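This and the similar hunks below delete `typing` imports that are no longer referenced. A plausible reason they became dead code is a move to builtin generics (PEP 585) and union syntax (PEP 604), available since Python 3.9/3.10; a hedged before/after illustration, not code from this repository:

from collections.abc import Sequence

# before: from typing import Dict, List, Optional
# def group_ranks(ranks: List[int]) -> Dict[str, Optional[int]]: ...

def group_ranks(ranks: Sequence[int]) -> dict[str, int | None]:
    # builtin dict generics and `int | None` need no typing import
    return {"first": ranks[0] if ranks else None}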

et_replay/comm/backend/pytorch_dist_backend.py

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@
 
 from itertools import cycle
 from time import sleep
-from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import torch

et_replay/comm/commsTraceParser.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
 import json
 
 import logging
-from typing import List, Tuple
 
 from et_replay import ExecutionTrace
 from et_replay.comm import comms_utils

et_replay/comm/comms_utils.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 from collections.abc import Callable
 from contextlib import ContextDecorator
 from io import StringIO
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any
 
 try:
     from param_bench.train.comms.pt.fb.internals import (

et_replay/execution_trace.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 from collections.abc import Iterable
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, List, Optional, Set, TextIO, Tuple
+from typing import Any, TextIO
 
 import pydot
 

et_replay/tools/comm_replay.py

Lines changed: 28 additions & 28 deletions
@@ -11,7 +11,6 @@
 import logging
 import os
 import time
-from typing import Dict, List, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -114,7 +113,7 @@ def __init__(self):
         self.max_msg_cnt = 0  # 0 means no limit
         self.num_msg = 0
         self.is_blocking = False
-        self.do_warm_up = False
+        self.warmup_iter = 5
         self.reuse_tensors = False
 
         self.allowList = ""
@@ -215,10 +214,10 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
             help="Only replay first N operations (0 means no limit)",
         )
         parser.add_argument(
-            "--do-warm-up",
-            action="store_true",
-            default=self.do_warm_up,
-            help="Toggle to enable performing extra replaying for warm-up",
+            "--warmup-iter",
+            type=int,
+            default=self.warmup_iter,
+            help="Number of warmup iterations",
         )
         parser.add_argument(
             "--reuse-tensors",
@@ -275,7 +274,7 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
             "--profiler-num-replays-start",
             type=int,
             default=self.profiler_num_replays_start,
-            help=f"Replay iteration to start collecting profiler after warmup (if --do-warm-up is True). Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True",
+            help=f"Replay iteration to start collecting profiler after warmup runs. Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True",
         )
         parser.add_argument(
             "--profiler-num-replays",
@@ -393,12 +392,19 @@ def reportBenchTime(self):
         if not self.is_dry_run:
             print("\n{} Performance of replayed comms {}".format("=" * 20, "=" * 20))
             print(
-                "{}\n Total latency (us) of comms in trace {}: \n{}".format(
+                "{}\n Total latency (us) of comms in trace: {}. \n{}".format(
                     "-" * 50,
                     self.totalTraceLatency,
                     "-" * 50,
                 )
             )
+            print(
+                "{}\n Average latency (us) of comms in trace: {}. \n{}".format(
+                    "-" * 50,
+                    self.totalTraceLatency / self.num_replays,
+                    "-" * 50,
+                )
+            )
             for coll, lats in self.collLat.items():
                 if len(lats) == 0:
                     continue
@@ -1199,9 +1205,7 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
         """
         if commsParams.enable_profiler:
             # num of iterations to skip
-            numWarmupIters = (
-                1 if self.do_warm_up else 0
-            ) + self.profiler_num_replays_start
+            numWarmupIters = self.warmup_iter + self.profiler_num_replays_start
             # num of iterations to profile, at most num_replays iterations
             numProfileIters = (
                 self.profiler_num_replays
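Because every warm-up pass is now a full loop iteration, the profiler skips all of them instead of at most one. A worked example under the new default (only `warmup_iter = 5` is fixed by this commit; the other value is a hypothetical user setting):

warmup_iter = 5                 # new default from __init__
profiler_num_replays_start = 2  # hypothetical user setting
numWarmupIters = warmup_iter + profiler_num_replays_start
print(numWarmupIters)  # 7 iterations are skipped before profiling starts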
@@ -1215,29 +1219,18 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
                 numIters=numProfileIters,
             )
 
-        # warm-up
-        if self.do_warm_up:
-            if self.collectiveArgs.enable_profiler:
-                comms_utils.sampleProfiler()
-            self.replayIter = -1
-            self.replayTrace(commsParams=commsParams, warmup=True)
-            self.resetComms()
-
-            # sync everything before starting real runs
-            with paramProfile(description="# PARAM replay warmup post-replay global sync"):
-                self.backendFuncs.sync_barrier(self.collectiveArgs)
-
         if self.backendFuncs.get_global_rank() == 0:
             logger.info(
                 f"\n+ {self.max_msg_cnt} messages in the trace...replaying (if present) {list(self.allowList)}"
             )
             for coll, sizes in self.collInMsgBytes.items():
                 logger.info(f"\t{coll}: {len(sizes)}")
 
-        traceStartTime = time.monotonic_ns()
-        for i in range(self.num_replays):
-            if self.backendFuncs.get_global_rank() == 0:
-                logger.info(f"Replay #{i}")
+        traceStartTime = 0
+        for i in range(self.warmup_iter + self.num_replays):
+
+            if i == self.warmup_iter:
+                traceStartTime = time.monotonic_ns()
 
             if self.collectiveArgs.enable_profiler:
                 comms_utils.sampleProfiler()
@@ -1248,6 +1241,9 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
             except NameError:
                 pass
 
+            if self.backendFuncs.get_global_rank() == 0:
+                s = time.monotonic_ns()
+
             # replay comms trace
             self.replayIter = i
             self.replayTrace(commsParams=commsParams, warmup=False)
@@ -1259,6 +1255,10 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
             ):
                 self.backendFuncs.sync_barrier(self.collectiveArgs)
 
+            if self.backendFuncs.get_global_rank() == 0:
+                e = time.monotonic_ns()
+                logger.info(f"Replay #{i} took {(e-s)/1e3:.2f} us")
+
         # record how long it took for trace-replay to complete
         traceEndTime = time.monotonic_ns()
         self.totalTraceLatency = (traceEndTime - traceStartTime) / 1e3  # make it us
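Note that `s` and `e` are read only on global rank 0, and `e` is taken after the post-replay global sync, so the logged per-iteration time includes that synchronization. The conversion mirrors the existing `/ 1e3` idiom, since `time.monotonic_ns()` returns nanoseconds; a trivial standalone check:

import time

s = time.monotonic_ns()
time.sleep(0.002)  # stand-in for one replayed iteration plus the sync
e = time.monotonic_ns()
print(f"took {(e - s) / 1e3:.2f} us")  # ns / 1e3 -> microseconds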
@@ -1457,7 +1457,7 @@ def initBench(
         self.shrink = args.auto_shrink
         self.max_msg_cnt = args.max_msg_cnt
         self.is_blocking = args.blocking
-        self.do_warm_up = args.do_warm_up
+        self.warmup_iter = args.warmup_iter
         self.reuse_tensors = args.reuse_tensors
         self.allowList = args.allow_ops
         if args.output_ranks == "all":

et_replay/tools/et_replay.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import time
88
from collections import defaultdict
99
from datetime import datetime
10-
from typing import Dict
1110

1211
import numpy as np
1312
import torch
