Allow user defined warm-up iterations (#189)
Summary:

Allow the user to specify the number of warm-up iterations. Also print the wall-clock time for each replay iteration and the average wall-clock time over the whole run.

Differential Revision: D66801912
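
For illustration, the replay-loop pattern this change introduces in comm_replay.py can be sketched in isolation: the first warmup_iter iterations are untimed, the trace-wide clock starts at the first measured iteration, and each iteration's wall-clock time is logged. This is a simplified sketch, not code from the commit; run_replays and replay_once are hypothetical stand-ins for benchTime and replayTrace.

import time


def run_replays(warmup_iter: int, num_replays: int, replay_once) -> float:
    """Run warmup_iter untimed iterations, then num_replays timed iterations."""
    trace_start_ns = 0
    for i in range(warmup_iter + num_replays):
        if i == warmup_iter:
            # start the trace-wide clock only after the warmup runs finish
            trace_start_ns = time.monotonic_ns()
        iter_start_ns = time.monotonic_ns()
        replay_once(i)  # replay the whole comms trace once
        iter_end_ns = time.monotonic_ns()
        print(f"Replay #{i} took {(iter_end_ns - iter_start_ns) / 1e3:.2f} us")
    total_us = (time.monotonic_ns() - trace_start_ns) / 1e3
    print(f"Total latency (us) of comms in trace: {total_us:.2f}")
    print(f"Average latency (us) of comms in trace: {total_us / num_replays:.2f}")
    return total_us


if __name__ == "__main__":
    # toy usage: a 1 ms sleep stands in for replaying a real comms trace
    run_replays(warmup_iter=5, num_replays=10, replay_once=lambda i: time.sleep(0.001))
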
shengfukevin authored and facebook-github-bot committed Dec 5, 2024
1 parent f3b1968 commit d1a97b9
Showing 7 changed files with 30 additions and 34 deletions.
et_replay/comm/backend/base_backend.py (1 change: 0 additions & 1 deletion)
@@ -5,7 +5,6 @@

 import logging
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional

 import torch

et_replay/comm/backend/pytorch_dist_backend.py (1 change: 0 additions & 1 deletion)
@@ -10,7 +10,6 @@

 from itertools import cycle
 from time import sleep
-from typing import Dict, List, Optional, Tuple

 import numpy as np
 import torch
et_replay/comm/commsTraceParser.py (1 change: 0 additions & 1 deletion)
@@ -4,7 +4,6 @@
 import json

 import logging
-from typing import List, Tuple

 from et_replay import ExecutionTrace
 from et_replay.comm import comms_utils
et_replay/comm/comms_utils.py (2 changes: 1 addition & 1 deletion)
@@ -17,7 +17,7 @@
 from collections.abc import Callable
 from contextlib import ContextDecorator
 from io import StringIO
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any

 try:
     from param_bench.train.comms.pt.fb.internals import (
et_replay/execution_trace.py (2 changes: 1 addition & 1 deletion)
@@ -9,7 +9,7 @@
 from collections.abc import Iterable
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, List, Optional, Set, TextIO, Tuple
+from typing import Any, TextIO

 import pydot

et_replay/tools/comm_replay.py (56 changes: 28 additions & 28 deletions)
@@ -11,7 +11,6 @@
 import logging
 import os
 import time
-from typing import Dict, List, Set, Tuple, Union

 import numpy as np
 import torch
@@ -114,7 +113,7 @@ def __init__(self):
         self.max_msg_cnt = 0  # 0 means no limit
         self.num_msg = 0
         self.is_blocking = False
-        self.do_warm_up = False
+        self.warmup_iter = 5
         self.reuse_tensors = False

         self.allowList = ""
@@ -215,10 +214,10 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
             help="Only replay first N operations (0 means no limit)",
         )
         parser.add_argument(
-            "--do-warm-up",
-            action="store_true",
-            default=self.do_warm_up,
-            help="Toggle to enable performing extra replaying for warm-up",
+            "--warmup-iter",
+            type=int,
+            default=self.warmup_iter,
+            help="Number of warmup iterations",
         )
         parser.add_argument(
             "--reuse-tensors",
@@ -275,7 +274,7 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
             "--profiler-num-replays-start",
             type=int,
             default=self.profiler_num_replays_start,
-            help=f"Replay iteration to start collecting profiler after warmup (if --do-warm-up is True). Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True",
+            help=f"Replay iteration to start collecting profiler after warmup runs. Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True",
         )
         parser.add_argument(
             "--profiler-num-replays",
@@ -393,12 +392,19 @@ def reportBenchTime(self):
         if not self.is_dry_run:
             print("\n{} Performance of replayed comms {}".format("=" * 20, "=" * 20))
             print(
-                "{}\n Total latency (us) of comms in trace {}: \n{}".format(
+                "{}\n Total latency (us) of comms in trace: {}. \n{}".format(
                     "-" * 50,
                     self.totalTraceLatency,
                     "-" * 50,
                 )
             )
+            print(
+                "{}\n Average latency (us) of comms in trace: {}. \n{}".format(
+                    "-" * 50,
+                    self.totalTraceLatency / self.num_replays,
+                    "-" * 50,
+                )
+            )
             for coll, lats in self.collLat.items():
                 if len(lats) == 0:
                     continue
@@ -1199,9 +1205,7 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
         """
         if commsParams.enable_profiler:
             # num of iterations to skip
-            numWarmupIters = (
-                1 if self.do_warm_up else 0
-            ) + self.profiler_num_replays_start
+            numWarmupIters = self.warmup_iter + self.profiler_num_replays_start
             # num of iterations to profile, at most num_replays iterations
             numProfileIters = (
                 self.profiler_num_replays
@@ -1215,29 +1219,18 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
                 numIters=numProfileIters,
             )

-        # warm-up
-        if self.do_warm_up:
-            if self.collectiveArgs.enable_profiler:
-                comms_utils.sampleProfiler()
-            self.replayIter = -1
-            self.replayTrace(commsParams=commsParams, warmup=True)
-        self.resetComms()
-
-        # sync everything before starting real runs
-        with paramProfile(description="# PARAM replay warmup post-replay global sync"):
-            self.backendFuncs.sync_barrier(self.collectiveArgs)
-
         if self.backendFuncs.get_global_rank() == 0:
             logger.info(
                 f"\n+ {self.max_msg_cnt} messages in the trace...replaying (if present) {list(self.allowList)}"
             )
             for coll, sizes in self.collInMsgBytes.items():
                 logger.info(f"\t{coll}: {len(sizes)}")

-        traceStartTime = time.monotonic_ns()
-        for i in range(self.num_replays):
-            if self.backendFuncs.get_global_rank() == 0:
-                logger.info(f"Replay #{i}")
+        traceStartTime = 0
+        for i in range(self.warmup_iter + self.num_replays):
+
+            if i == self.warmup_iter:
+                traceStartTime = time.monotonic_ns()

             if self.collectiveArgs.enable_profiler:
                 comms_utils.sampleProfiler()
@@ -1248,6 +1241,9 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
             except NameError:
                 pass

+            if self.backendFuncs.get_global_rank() == 0:
+                s = time.monotonic_ns()
+
             # replay comms trace
             self.replayIter = i
             self.replayTrace(commsParams=commsParams, warmup=False)
@@ -1259,6 +1255,10 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
             ):
                 self.backendFuncs.sync_barrier(self.collectiveArgs)

+            if self.backendFuncs.get_global_rank() == 0:
+                e = time.monotonic_ns()
+                logger.info(f"Replay #{i} took {(e-s)/1e3:.2f} us")
+
         # record how long it took for trace-replay to complete
         traceEndTime = time.monotonic_ns()
         self.totalTraceLatency = (traceEndTime - traceStartTime) / 1e3  # make it us
@@ -1457,7 +1457,7 @@ def initBench(
         self.shrink = args.auto_shrink
         self.max_msg_cnt = args.max_msg_cnt
         self.is_blocking = args.blocking
-        self.do_warm_up = args.do_warm_up
+        self.warmup_iter = args.warmup_iter
         self.reuse_tensors = args.reuse_tensors
         self.allowList = args.allow_ops
         if args.output_ranks == "all":
et_replay/tools/et_replay.py (1 change: 0 additions & 1 deletion)
@@ -7,7 +7,6 @@
 import time
 from collections import defaultdict
 from datetime import datetime
-from typing import Dict

 import numpy as np
 import torch
