Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Upgrade Python to 3.10 and allow user to specify the number of warmup…
Browse files Browse the repository at this point in the history
… iterations (facebookresearch#189)

Summary:

Meta production is on Python 3.10, but PARAM lint is still on 3.8, and lint reported some issues related to type hint syntax. To fix this, I upgraded Python to 3.10.

Also changed the warmup option to allow the user to specify the number of warmup iterations. Also print out the wall-clock time for each iteration and the average wall-clock time for the whole run.

Reviewed By: kingchc

Differential Revision: D66801912
shengfukevin authored and facebook-github-bot committed Dec 10, 2024
1 parent 7e5ac05 commit 4b5b050
Showing 8 changed files with 30 additions and 37 deletions.
1 change: 0 additions & 1 deletion et_replay/comm/backend/base_backend.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,6 @@

import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

import torch

1 change: 0 additions & 1 deletion et_replay/comm/backend/pytorch_dist_backend.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,6 @@

from itertools import cycle
from time import sleep
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
1 change: 0 additions & 1 deletion et_replay/comm/commsTraceParser.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@
import json

import logging
from typing import List, Tuple

from et_replay import ExecutionTrace
from et_replay.comm import comms_utils
2 changes: 1 addition & 1 deletion et_replay/comm/comms_utils.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
from collections.abc import Callable
from contextlib import ContextDecorator
from io import StringIO
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any

try:
from param_bench.train.comms.pt.fb.internals import (
2 changes: 1 addition & 1 deletion et_replay/execution_trace.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
from collections.abc import Iterable
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Set, TextIO, Tuple
from typing import Any, TextIO

import pydot

55 changes: 27 additions & 28 deletions et_replay/tools/comm_replay.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,6 @@
import logging
import os
import time
from typing import Dict, List, Set, Tuple, Union

import numpy as np
import torch
@@ -114,7 +113,7 @@ def __init__(self):
self.max_msg_cnt = 0 # 0 means no limit
self.num_msg = 0
self.is_blocking = False
self.do_warm_up = False
self.warmup_iter = 5
self.reuse_tensors = False

self.allowList = ""
@@ -215,10 +214,10 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
help="Only replay first N operations (0 means no limit)",
)
parser.add_argument(
"--do-warm-up",
action="store_true",
default=self.do_warm_up,
help="Toggle to enable performing extra replaying for warm-up",
"--warmup-iter",
type=int,
default=self.warmup_iter,
help="Number of warmup iterations",
)
parser.add_argument(
"--reuse-tensors",
@@ -275,7 +274,7 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
"--profiler-num-replays-start",
type=int,
default=self.profiler_num_replays_start,
help=f"Replay iteration to start collecting profiler after warmup (if --do-warm-up is True). Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True",
help=f"Replay iteration to start collecting profiler after warmup runs. Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True",
)
parser.add_argument(
"--profiler-num-replays",
@@ -393,12 +392,19 @@ def reportBenchTime(self):
if not self.is_dry_run:
print("\n{} Performance of replayed comms {}".format("=" * 20, "=" * 20))
print(
"{}\n Total latency (us) of comms in trace {}: \n{}".format(
"{}\n Total latency (us) of comms in trace: {}. \n{}".format(
"-" * 50,
self.totalTraceLatency,
"-" * 50,
)
)
print(
"{}\n Average latency (us) of comms in trace: {}. \n{}".format(
"-" * 50,
self.totalTraceLatency / self.num_replays,
"-" * 50,
)
)
for coll, lats in self.collLat.items():
if len(lats) == 0:
continue
@@ -1199,9 +1205,7 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
"""
if commsParams.enable_profiler:
# num of iterations to skip
numWarmupIters = (
1 if self.do_warm_up else 0
) + self.profiler_num_replays_start
numWarmupIters = self.warmup_iter + self.profiler_num_replays_start
# num of iterations to profile, at most num_replays iterations
numProfileIters = (
self.profiler_num_replays
@@ -1215,29 +1219,17 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
numIters=numProfileIters,
)

# warm-up
if self.do_warm_up:
if self.collectiveArgs.enable_profiler:
comms_utils.sampleProfiler()
self.replayIter = -1
self.replayTrace(commsParams=commsParams, warmup=True)
self.resetComms()

# sync everything before starting real runs
with paramProfile(description="# PARAM replay warmup post-replay global sync"):
self.backendFuncs.sync_barrier(self.collectiveArgs)

if self.backendFuncs.get_global_rank() == 0:
logger.info(
f"\n+ {self.max_msg_cnt} messages in the trace...replaying (if present) {list(self.allowList)}"
)
for coll, sizes in self.collInMsgBytes.items():
logger.info(f"\t{coll}: {len(sizes)}")

traceStartTime = time.monotonic_ns()
for i in range(self.num_replays):
if self.backendFuncs.get_global_rank() == 0:
logger.info(f"Replay #{i}")
traceStartTime = 0
for i in range(self.warmup_iter + self.num_replays):
if i == self.warmup_iter:
traceStartTime = time.monotonic_ns()

if self.collectiveArgs.enable_profiler:
comms_utils.sampleProfiler()
@@ -1248,6 +1240,9 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
except NameError:
pass

if self.backendFuncs.get_global_rank() == 0:
s = time.monotonic_ns()

# replay comms trace
self.replayIter = i
self.replayTrace(commsParams=commsParams, warmup=False)
@@ -1259,6 +1254,10 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
):
self.backendFuncs.sync_barrier(self.collectiveArgs)

if self.backendFuncs.get_global_rank() == 0:
e = time.monotonic_ns()
logger.info(f"Replay #{i} took {(e-s)/1e3:.2f} us")

# record how long it took for trace-replay to complete
traceEndTime = time.monotonic_ns()
self.totalTraceLatency = (traceEndTime - traceStartTime) / 1e3 # make it us
@@ -1457,7 +1456,7 @@ def initBench(
self.shrink = args.auto_shrink
self.max_msg_cnt = args.max_msg_cnt
self.is_blocking = args.blocking
self.do_warm_up = args.do_warm_up
self.warmup_iter = args.warmup_iter
self.reuse_tensors = args.reuse_tensors
self.allowList = args.allow_ops
if args.output_ranks == "all":
1 change: 0 additions & 1 deletion et_replay/tools/et_replay.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,6 @@
import time
from collections import defaultdict
from datetime import datetime
from typing import Dict

import numpy as np
import torch
4 changes: 1 addition & 3 deletions train/comms/pt/commsTraceParser.py
Original file line number Diff line number Diff line change
@@ -3,8 +3,6 @@

import json

from typing import List, Tuple

from et_replay import ExecutionTrace

from param_bench.train.comms.pt import comms_utils
@@ -188,7 +186,7 @@ def _parseKinetoUnitrace(in_trace: list, target_rank: int) -> list:

def _getTensorInfoFromPyTorchETEntry(
tensor_container: list, container_type: str
) -> tuple[int, int, str]:
) -> tuple[int, str]:
"""
Extract message size, tensor count, type from PyTorch ET entry inputs/outputs field.
NOTE: This format can be changed at anytime. TODO: When an extract/parsing tool is available in ATC, switch to it.

0 comments on commit 4b5b050

Please sign in to comment.