Skip to content

Commit

Permalink
Upgrade Python to 3.10 and allow user to specify the number of warmup…
Browse files Browse the repository at this point in the history
… iterations (facebookresearch#189)

Summary:

Meta production is on Python 3.10, but PARAM lint is still on 3.8. I saw that Lint reported some issues related to type hint syntax. To fix them, I upgraded Python to 3.10.

Also changed the option to allow the user to specify the number of warmup iterations. Also print out the wall clock time for each iteration and the average wall clock time for the whole run.

Reviewed By: kingchc

Differential Revision: D66801912
  • Loading branch information
shengfukevin authored and facebook-github-bot committed Dec 10, 2024
1 parent 7e5ac05 commit 06f95dd
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 40 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/python_lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ jobs:
- name: Setup Python Environment
uses: actions/setup-python@v2
with:
python-version: '3.8'
python-version: '3.10'

- name: Install Dependencies
run: |
pip install black
- name: Run Black
run: black . --check
# - name: Run Black
# run: black . --check

- name: Run tests
run: |
Expand Down
1 change: 0 additions & 1 deletion et_replay/comm/backend/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

import torch

Expand Down
1 change: 0 additions & 1 deletion et_replay/comm/backend/pytorch_dist_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from itertools import cycle
from time import sleep
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
Expand Down
1 change: 0 additions & 1 deletion et_replay/comm/commsTraceParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json

import logging
from typing import List, Tuple

from et_replay import ExecutionTrace
from et_replay.comm import comms_utils
Expand Down
2 changes: 1 addition & 1 deletion et_replay/comm/comms_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections.abc import Callable
from contextlib import ContextDecorator
from io import StringIO
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any

try:
from param_bench.train.comms.pt.fb.internals import (
Expand Down
2 changes: 1 addition & 1 deletion et_replay/execution_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections.abc import Iterable
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Set, TextIO, Tuple
from typing import Any, TextIO

import pydot

Expand Down
55 changes: 27 additions & 28 deletions et_replay/tools/comm_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import logging
import os
import time
from typing import Dict, List, Set, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -114,7 +113,7 @@ def __init__(self):
self.max_msg_cnt = 0 # 0 means no limit
self.num_msg = 0
self.is_blocking = False
self.do_warm_up = False
self.warmup_iter = 5
self.reuse_tensors = False

self.allowList = ""
Expand Down Expand Up @@ -215,10 +214,10 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
help="Only replay first N operations (0 means no limit)",
)
parser.add_argument(
"--do-warm-up",
action="store_true",
default=self.do_warm_up,
help="Toggle to enable performing extra replaying for warm-up",
"--warmup-iter",
type=int,
default=self.warmup_iter,
help="Number of warmup iterations",
)
parser.add_argument(
"--reuse-tensors",
Expand Down Expand Up @@ -275,7 +274,7 @@ def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
"--profiler-num-replays-start",
type=int,
default=self.profiler_num_replays_start,
help=f"Replay iteration to start collecting profiler after warmup (if --do-warm-up is True). Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True",
help=f"Replay iteration to start collecting profiler after warmup runs. Default start from {self.profiler_num_replays_start} replay if --enables-profiler is True",
)
parser.add_argument(
"--profiler-num-replays",
Expand Down Expand Up @@ -393,12 +392,19 @@ def reportBenchTime(self):
if not self.is_dry_run:
print("\n{} Performance of replayed comms {}".format("=" * 20, "=" * 20))
print(
"{}\n Total latency (us) of comms in trace {}: \n{}".format(
"{}\n Total latency (us) of comms in trace: {}. \n{}".format(
"-" * 50,
self.totalTraceLatency,
"-" * 50,
)
)
print(
"{}\n Average latency (us) of comms in trace: {}. \n{}".format(
"-" * 50,
self.totalTraceLatency / self.num_replays,
"-" * 50,
)
)
for coll, lats in self.collLat.items():
if len(lats) == 0:
continue
Expand Down Expand Up @@ -1199,9 +1205,7 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
"""
if commsParams.enable_profiler:
# num of iterations to skip
numWarmupIters = (
1 if self.do_warm_up else 0
) + self.profiler_num_replays_start
numWarmupIters = self.warmup_iter + self.profiler_num_replays_start
# num of iterations to profile, at most num_replays iterations
numProfileIters = (
self.profiler_num_replays
Expand All @@ -1215,29 +1219,17 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
numIters=numProfileIters,
)

# warm-up
if self.do_warm_up:
if self.collectiveArgs.enable_profiler:
comms_utils.sampleProfiler()
self.replayIter = -1
self.replayTrace(commsParams=commsParams, warmup=True)
self.resetComms()

# sync everything before starting real runs
with paramProfile(description="# PARAM replay warmup post-replay global sync"):
self.backendFuncs.sync_barrier(self.collectiveArgs)

if self.backendFuncs.get_global_rank() == 0:
logger.info(
f"\n+ {self.max_msg_cnt} messages in the trace...replaying (if present) {list(self.allowList)}"
)
for coll, sizes in self.collInMsgBytes.items():
logger.info(f"\t{coll}: {len(sizes)}")

traceStartTime = time.monotonic_ns()
for i in range(self.num_replays):
if self.backendFuncs.get_global_rank() == 0:
logger.info(f"Replay #{i}")
traceStartTime = 0
for i in range(self.warmup_iter + self.num_replays):
if i == self.warmup_iter:
traceStartTime = time.monotonic_ns()

if self.collectiveArgs.enable_profiler:
comms_utils.sampleProfiler()
Expand All @@ -1248,6 +1240,9 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
except NameError:
pass

if self.backendFuncs.get_global_rank() == 0:
s = time.monotonic_ns()

# replay comms trace
self.replayIter = i
self.replayTrace(commsParams=commsParams, warmup=False)
Expand All @@ -1259,6 +1254,10 @@ def benchTime(self, commsParams: commsParamsHolderBase) -> None:
):
self.backendFuncs.sync_barrier(self.collectiveArgs)

if self.backendFuncs.get_global_rank() == 0:
e = time.monotonic_ns()
logger.info(f"Replay #{i} took {(e-s)/1e3:.2f} us")

# record how long it took for trace-replay to complete
traceEndTime = time.monotonic_ns()
self.totalTraceLatency = (traceEndTime - traceStartTime) / 1e3 # make it us
Expand Down Expand Up @@ -1457,7 +1456,7 @@ def initBench(
self.shrink = args.auto_shrink
self.max_msg_cnt = args.max_msg_cnt
self.is_blocking = args.blocking
self.do_warm_up = args.do_warm_up
self.warmup_iter = args.warmup_iter
self.reuse_tensors = args.reuse_tensors
self.allowList = args.allow_ops
if args.output_ranks == "all":
Expand Down
1 change: 0 additions & 1 deletion et_replay/tools/et_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import time
from collections import defaultdict
from datetime import datetime
from typing import Dict

import numpy as np
import torch
Expand Down
4 changes: 1 addition & 3 deletions train/comms/pt/commsTraceParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

import json

from typing import List, Tuple

from et_replay import ExecutionTrace

from param_bench.train.comms.pt import comms_utils
Expand Down Expand Up @@ -188,7 +186,7 @@ def _parseKinetoUnitrace(in_trace: list, target_rank: int) -> list:

def _getTensorInfoFromPyTorchETEntry(
tensor_container: list, container_type: str
) -> tuple[int, int, str]:
) -> tuple[int, str]:
"""
Extract message size, tensor count, type from PyTorch ET entry inputs/outputs field.
NOTE: This format can be changed at anytime. TODO: When an extract/parsing tool is available in ATC, switch to it.
Expand Down

0 comments on commit 06f95dd

Please sign in to comment.