diff --git a/et_replay/comm/backend/base_backend.py b/et_replay/comm/backend/base_backend.py
index 1e367dca..488c42de 100644
--- a/et_replay/comm/backend/base_backend.py
+++ b/et_replay/comm/backend/base_backend.py
@@ -131,21 +131,21 @@ class BaseBackend(ABC):
     def __init__(self) -> None:
         self.tcp_store = None
         self.collectiveFunc = {
-            "all_to_all_single": self.all_to_all_single,
+            "all_to_all_single": self.all_to_all_single,  # pyre-ignore[16]:
             "all_to_all": self.all_to_all,
             "all_to_allv": self.all_to_allv,
             "all_reduce": self.all_reduce,
-            "broadcast": self.broadcast,
-            "gather": self.gather,
-            "all_gather": self.all_gather,
-            "all_gather_base": self.all_gather_base,
+            "broadcast": self.broadcast,  # pyre-ignore[16]:
+            "gather": self.gather,  # pyre-ignore[16]:
+            "all_gather": self.all_gather,  # pyre-ignore[16]:
+            "all_gather_base": self.all_gather_base,  # pyre-ignore[16]:
             "reduce": self.reduce,
-            "reduce_scatter": self.reduce_scatter,
-            "reduce_scatter_base": self.reduce_scatter_base,
-            "scatter": self.scatter,
+            "reduce_scatter": self.reduce_scatter,  # pyre-ignore[16]:
+            "reduce_scatter_base": self.reduce_scatter_base,  # pyre-ignore[16]:
+            "scatter": self.scatter,  # pyre-ignore[16]:
             "barrier": self.barrier,
-            "incast": self.incast,
-            "multicast": self.multicast,
+            "incast": self.incast,  # pyre-ignore[16]:
+            "multicast": self.multicast,  # pyre-ignore[16]:
             "noop": self.noop,
         }
@@ -153,7 +153,7 @@ def __init__(self) -> None:
     def alloc_ones(
         self,
-        sizeArr: int,
+        sizeArr: List[int],
         curRankDevice: str = "cuda",
         dtype: torch.dtype = torch.float32,
         scaleFactor: float = 1.0,
@@ -176,7 +176,7 @@ def alloc_ones(
     def noop(
         self,
-        collectiveArgs: collectiveArgsHolder = None,
+        collectiveArgs: collectiveArgsHolder,
         retFlag: bool = False,
         pair: bool = False,
     ) -> None:
@@ -236,7 +236,7 @@ def get_mem_size(self, collectiveArgs: collectiveArgsHolder) -> int:
     @abstractmethod
     def alloc_random(
         self,
-        sizeArr: int,
+        sizeArr: List[int],
         curRankDevice: str,
         dtype: torch.dtype,
         scaleFactor: float = 1.0,
@@ -253,7 +253,7 @@ def alloc_embedding_tables(
     @abstractmethod
     def alloc_empty(
-        self, sizeArr: int, dtype: torch.dtype, curRankDevice: str
+        self, sizeArr: List[int], dtype: torch.dtype, curRankDevice: str
     ) -> torch.Tensor:
         """Allocate tensor with uninitialized data based on parameters."""
         pass
diff --git a/et_replay/comm/backend/pytorch_dist_backend.py b/et_replay/comm/backend/pytorch_dist_backend.py
index 7bf17810..765b4337 100644
--- a/et_replay/comm/backend/pytorch_dist_backend.py
+++ b/et_replay/comm/backend/pytorch_dist_backend.py
@@ -30,7 +30,7 @@ except ImportError:
 try:
     # Open-source extend_distributed.py can be found in https://github.com/facebookresearch/dlrm
-    import extend_distributed
+    import extend_distributed  # pyre-ignore[21]:

     has_ext_dist = True
 except ImportError:
@@ -783,7 +783,11 @@ def get_mem_size(self, collectiveArgs, pair=False):
         return _sizeBytes

     def alloc_random(
-        self, sizeArr, curRankDevice="cuda", dtype=torch.float32, scaleFactor=1.0
+        self,
+        sizeArr: List[int],
+        curRankDevice="cuda",
+        dtype=torch.float32,
+        scaleFactor=1.0,
     ):
         if dtype in (
             torch.int8,
@@ -954,7 +958,7 @@ def sync_stream(
         device: Optional[torch.device] = None,
     ):
         """Synchronize a stream with its associated device"""
-        if device.type == "cuda":
+        if device is not None and device.type == "cuda":
             # if the stream is None, sync on the current default stream
             cur_stream = (
                 stream
@@ -1000,7 +1004,7 @@ def __init__(self, bootstrap_info, commsParams):
         # Import Fairring
         if backend == "fairring":
             try:
-                import fairring  # noqa
+                import fairring  # pyre-ignore[21]:
             except ImportError:
                 raise RuntimeError("Unable to import Fairring")
diff --git a/et_replay/comm/backend/pytorch_tpu_backend.py b/et_replay/comm/backend/pytorch_tpu_backend.py
index 1c07b629..96c0009b 100644
--- a/et_replay/comm/backend/pytorch_tpu_backend.py
+++ b/et_replay/comm/backend/pytorch_tpu_backend.py
@@ -4,8 +4,8 @@
 import numpy as np
 import torch
 import torch.nn as nn
-import torch_xla.core.xla_model as xm  # @manual
-import torch_xla.distributed.xla_multiprocessing as xmp  # @manual
+import torch_xla.core.xla_model as xm  # pyre-ignore[21]:
+import torch_xla.distributed.xla_multiprocessing as xmp  # pyre-ignore[21]:

 from et_replay.comm.backend.base_backend import BaseBackend
diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py
index b08e48c7..b9901068 100644
--- a/et_replay/comm/comms_utils.py
+++ b/et_replay/comm/comms_utils.py
@@ -11,6 +11,8 @@
     unicode_literals,
 )

+import argparse
+
 import logging
 import os
 import random
@@ -333,7 +335,7 @@ def clearQuantCommCtx(collectiveArgs: collectiveArgsHolder) -> None:
     remove_quantization_handlers(collectiveArgs)


-def paramToCommName(name: str, supported_comms: List[str] = None) -> str:
+def paramToCommName(name: str, supported_comms: Optional[List[str]] = None) -> str:
     """
     Map any possible creative collective names to the internal name.
     Validate the `name` if `supported_comms` is provided.
@@ -374,7 +376,7 @@ def paramToCommName(name: str, supported_comms: List[str] = None) -> str:
     return new_name


-def ensureTensorFlush(tensors: Union[List[torch.Tensor], torch.Tensor]) -> float:
+def ensureTensorFlush(tensors: Union[List[torch.Tensor], torch.Tensor]) -> Any:
     """
     Use this to flush non-blocking ops to ensure they are really complete.
@@ -645,9 +647,6 @@ def __init__(self, name: str, backendFuncs: BaseBackend) -> None:
         self.start_event = backendFuncs.get_new_event(enable_timing=True)
         self.end_event = backendFuncs.get_new_event(enable_timing=True)

-    def reset(self) -> None:
-        self.elapsedTimeNS = 0.0
-
     def start(self, stream=None) -> None:
         self.start_event.record(stream)
@@ -771,7 +770,7 @@ def __init__(
 class paramCommsBench(ABC):
     """Abstract class for any param comms benchmark."""

-    def __init__(self, supportedNwstacks: List[str] = None) -> None:
+    def __init__(self, supportedNwstacks: List[str]) -> None:
         self.supportedNwstacks = supportedNwstacks
         self.supported_tpu_core_valuses = [1, 8]
         self.dtypeMap = {
@@ -918,7 +917,7 @@ def _prep_all_to_allv(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Prepare the all_to_allv mode"""
-        opTensor = []
+        opTensor = torch.Tensor()
         if allocate:
             # all_to_allv requires two tensors
             opTensor = self.backendFuncs.alloc_random(
@@ -950,8 +949,8 @@ def _prep_all_to_all_single(
         scaleFactor: float,
         allocate: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        ipTensor = None
-        opTensor = None
+        ipTensor = torch.Tensor()
+        opTensor = torch.Tensor()
         if allocate:
             if commsParams.dcheck == 1:
                 ipTensor = self.backendFuncs.alloc_ones(
@@ -974,7 +973,7 @@ def _prep_all_to_all_single(
     def _prep_all_to_all(
         self,
-        ipTensor: torch.Tensor,
+        ipTensor: List[torch.Tensor],
         curComm: commsArgs,
         commsParams: commsParamsHolderBase,
         numElementsIn: int,
@@ -984,7 +983,7 @@ def _prep_all_to_all(
         dtype: torch.dtype,
         scaleFactor: float,
         allocate: bool = True,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         # all_to_all requires two tensor lists, e.g., List[torch.Tensor]
         ipTensor = []
@@ -1020,7 +1019,7 @@ def _prep_all_gather(
         self,
-        ipTensor: torch.tensor,
+        ipTensor: torch.Tensor,
         curComm: commsArgs,
         commsParams: commsParamsHolderBase,
         numElementsIn: int,
@@ -1030,7 +1029,7 @@ def _prep_all_gather(
         dtype: torch.dtype,
         scaleFactor: float,
         allocate: bool = True,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         opTensor = []
         if not commsParams.size_from_trace:
@@ -1059,7 +1058,7 @@ def _prep_all_gather(
     def _prep_all_gather_base(
         self,
-        ipTensor: torch.tensor,
+        ipTensor: torch.Tensor,
         curComm: commsArgs,
         commsParams: commsParamsHolderBase,
         numElementsIn: int,
@@ -1070,14 +1069,14 @@ def _prep_all_gather_base(
         scaleFactor: float,
         allocate: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        opTensor = []
+        opTensor = torch.Tensor()
         if not commsParams.size_from_trace:
             numElementsOut = numElementsIn
             numElementsIn = numElementsIn // world_size
         if allocate:
             if commsParams.dcheck == 1:
                 ipTensor = self.backendFuncs.alloc_ones(
-                    numElementsIn,
+                    [numElementsIn],
                     curDevice,
                     dtype,
                     scaleFactor=self.initVal,
@@ -1097,7 +1096,7 @@ def _prep_incast(
         self,
-        ipTensor: torch.tensor,
+        ipTensor: torch.Tensor,
         curComm: commsArgs,
         commsParams: commsParamsHolderBase,
         numElementsIn: int,
@@ -1107,7 +1106,7 @@ def _prep_incast(
         dtype: torch.dtype,
         scaleFactor: float,
         allocate: bool = True,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         # incast requires a tensor list with length of src_ranks, e.g., List[torch.Tensor]
         opTensor = []
@@ -1122,7 +1121,7 @@ def _prep_reduce_scatter(
         self,
-        ipTensor: torch.tensor,
+        ipTensor: List[torch.Tensor],
         curComm: commsArgs,
         commsParams: commsParamsHolderBase,
         numElementsIn: int,
@@ -1132,9 +1131,9 @@ def _prep_reduce_scatter(
         dtype: torch.dtype,
         scaleFactor: float,
         allocate: bool = True,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
         ipTensor = []
-        opTensor = []
+        opTensor = torch.Tensor()
         if not commsParams.size_from_trace:
             numElementsIn = numElementsOut // world_size
             numElementsOut = numElementsOut // world_size
@@ -1168,7 +1167,7 @@ def _prep_reduce_scatter(
     def _prep_reduce_scatter_base(
         self,
-        ipTensor: torch.tensor,
+        ipTensor: torch.Tensor,
         curComm: commsArgs,
         commsParams: commsParamsHolderBase,
         numElementsIn: int,
@@ -1179,15 +1178,15 @@ def _prep_reduce_scatter_base(
         scaleFactor: float,
         allocate: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        ipTensor = []
-        opTensor = []
+        ipTensor = torch.Tensor()
+        opTensor = torch.Tensor()
         if not commsParams.size_from_trace:
             numElementsIn = numElementsOut
             numElementsOut = numElementsOut // world_size
         if allocate:
             if commsParams.dcheck == 1:
                 ipTensor = self.backendFuncs.alloc_ones(
-                    numElementsIn,
+                    [numElementsIn],
                     curDevice,
                     commsParams.dtype,
                     self.initVal,
@@ -1206,7 +1205,7 @@ def _prep_pt2pt(
         self,
-        ipTensor: torch.tensor,
+        ipTensor: torch.Tensor,
         curComm: commsArgs,
         commsParams: commsParamsHolderBase,
         numElementsIn: int,
@@ -1218,7 +1217,7 @@ def _prep_pt2pt(
         allocate: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # pt2pt or out-of-place collectives
-        opTensor = []
+        opTensor = torch.Tensor()
         if allocate:
             opTensor = self.backendFuncs.alloc_random(
                 [numElementsOut],
@@ -1234,9 +1233,9 @@ def prepGemmNotSquare(
         mm0_dim1: int,
         mm1_dim0: int,
         mm1_dim1: int,
-        dtype: str,
+        dtype: torch.dtype,
         curDevice: str,
-        gemmTensor: torch.tensor = None,
+        gemmTensor: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if gemmTensor is None:
             in1 = np.random.rand(mm0_dim0, mm0_dim1)
@@ -1245,7 +1244,7 @@ def prepGemmNotSquare(
             MMin1 = torch.FloatTensor(in1).to(curDevice)
             MMin2 = torch.FloatTensor(in2).to(curDevice)
             MMout = self.backendFuncs.alloc_empty(
-                (mm0_dim0, mm1_dim1), dtype, curDevice
+                [mm0_dim0, mm1_dim1], dtype, curDevice
             )
         else:
             mm_size0 = mm0_dim0 * mm0_dim1
@@ -1263,8 +1262,8 @@ def prepGemmNotSquare(

     # Prepare generic compute operations that uses 1 or 2 input tensors, and 1 output tensor
     def prepComp(
-        self, mm_dim: int, dtype: str, curDevice: str, kernel: str
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        self, mm_dim: int, dtype: torch.dtype, curDevice: str, kernel: str
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         compIn1 = self.backendFuncs.alloc_random([mm_dim, mm_dim], curDevice, dtype)
         compOut = self.backendFuncs.alloc_empty([mm_dim, mm_dim], dtype, curDevice)
         compIn2 = None
@@ -1273,7 +1272,7 @@ def prepComp(
         return (compOut, compIn1, compIn2)

     def prepGemm(
-        self, mm_dim: int, dtype: str, curDevice: str, gemmTensor: torch.tensor = None
+        self, mm_dim: int, dtype: torch.dtype, curDevice: str, gemmTensor: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         return self.prepGemmNotSquare(
             mm_dim, mm_dim, mm_dim, mm_dim, dtype, curDevice, gemmTensor
@@ -1284,7 +1283,7 @@ def prepComm(
         curComm: commsArgs,
         commsParams: commsParamsHolderBase,
         allocate: bool = True,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, Union[List[torch.Tensor], torch.Tensor]]:
         """
         Allocate the tensors for collective.
@@ -1300,7 +1299,7 @@ def prepComm(
         )

         if commOp in ("wait", "barrier"):
-            return ([], [])
+            return (torch.Tensor(), torch.Tensor())

         numElementsIn = curComm.inMsgSize
         # numElementsOut is only meaningful for out-of-place collectives and pt2pt
@@ -1310,7 +1309,7 @@ def prepComm(
         curDevice = commsParams.device
         # seed to generate random value; let's use a small value to avoid potential "overflow when unpacking long"
         scaleFactor = world_size
-        opTensor = []
+        opTensor = torch.Tensor()

         if allocate:
             if commsParams.dcheck == 1:
@@ -1323,7 +1322,7 @@ def prepComm(
                     [numElementsIn], curDevice, dtype, scaleFactor
                 )
         else:
-            ipTensor = []
+            ipTensor = torch.Tensor()

         # TODO: consider using this dictionary to check valid keywords rather than silently defaulting
         dispatchDict = {
@@ -1361,12 +1360,12 @@ def prepComm(
         return (ipTensor, opTensor)

     @abstractmethod
-    def runBench(self, *args, **kwargs) -> None:
+    def runBench(self, commsParams: commsParamsHolderBase) -> None:
         """Must override to start the desired benchmarking"""
         pass

     @abstractmethod
-    def benchTime(self, *args, **kwargs) -> None:
+    def benchTime(self, commsParams: commsParamsHolderBase) -> None:
         """Must override to run the desired benchmarking"""
         pass
@@ -1376,7 +1375,7 @@ def reportBenchTime(self, *args, **kwargs) -> None:
         pass

     @abstractmethod
-    def readArgs(self, parser: ArgumentParser) -> None:
+    def readArgs(self, parser: ArgumentParser) -> argparse.Namespace:
         """Basic/Common arguments for all PARAM-Comm benchmarks"""
         parser.add_argument(
             "--master-ip",
diff --git a/et_replay/comm/param_profile.py b/et_replay/comm/param_profile.py
index d9a9acc2..a436590e 100644
--- a/et_replay/comm/param_profile.py
+++ b/et_replay/comm/param_profile.py
@@ -8,7 +8,7 @@
 import logging
 import time
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Optional

 from torch.autograd.profiler import record_function
@@ -18,10 +18,15 @@ class paramProfile(record_function):
     """Inherit from PyTorch profiler to enable autoguard profiling while measuring the time interval in PARAM"""

-    def __init__(self, timer: paramTimer = None, description: str = "") -> None:
+    def __init__(
+        self, timer: Optional[paramTimer] = None, description: str = ""
+    ) -> None:
+        super().__init__(name=description)
         self.description = description
         self.timer = timer
-        super().__init__(name=description)
+        self.start = 0.0
+        self.end = 0.0
+        self.intervalNS = 0.0

     def __enter__(self) -> paramProfile:
         super().__enter__()
diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py
index ff4f59a5..384f690f 100644
--- a/et_replay/tools/comm_replay.py
+++ b/et_replay/tools/comm_replay.py
@@ -12,7 +12,6 @@
 import logging
 import os
 import time
-from os import path
 from typing import Dict, List, Set

 import numpy as np
@@ -60,7 +59,7 @@ def writeCommDetails(commsTracePerf: List, rank: int, folder: str = "./") -> Non
     if len(folder) == 0:
         # skip output if the path is explicitly set to ""
         return
-    comms_file = folder + f"/replayedCommsPerf.rank{rank}.json"
+    comms_file = folder + f"/replayedCommsPerf.rank-{rank}.json"
     logger.info(f"[Rank {rank:3}] Writing comms details to {comms_file}")

     saveToLocal = True
@@ -115,7 +114,7 @@ def __init__(self):
         self.max_msg_cnt = 0  # 0 means no limit
         self.num_msg = 0
         self.is_blocking = True
-        self.do_warm_up = True
+        self.do_warm_up = False
         self.reuse_tensors = False
         self.allowList = ""
@@ -157,7 +156,7 @@ def __init__(self):
         self.embLookupReuse = {}

-    def readArgs(self, parser: argparse.ArgumentParser) -> None:
+    def readArgs(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
         """
         Reads command line args to set runtime parameters for replay.
@@ -217,7 +216,7 @@ def readArgs(self, parser: argparse.ArgumentParser) -> None:
             "--do-warm-up",
             action="store_true",
             default=self.do_warm_up,
-            help="Toggle to disable performing extra replaying for warm-up",
+            help="Toggle to enable performing extra replaying for warm-up",
         )
         parser.add_argument(
             "--reuse-tensors",
@@ -559,7 +558,7 @@ def resetComms(self):
     def getCommGroupInfo(
         self, curComm: commsArgs, commsParams: commsParamsHolderBase
-    ) -> (int, str):
+    ) -> tuple[int, str]:
         """
         Return the group infomation of the current process group including group rank of the local process,
         and a description string for logging purpose.
@@ -616,20 +615,20 @@ def prepComms(
         curComm: commsArgs,
         commsParams: commsParamsHolderBase,
         regenerateTensors: bool = True,
-    ) -> (torch.Tensor, torch.Tensor):
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Prepares the appropriate tensors for the current collective communication.
         Args:
             curComm: The current communication that we are preparing the correct tensor for.
             commsParams: Holds the comms param arguments that will determine tensor attributes.
-            regenerateTensors: when an id is being replayed multiple times, setting this to false will use temsors from previous runs
+            regenerateTensors: when an id is being replayed multiple times, setting this to false will use tensors from previous runs
         Returns:
             (ipTensor, opTensor) if the current communication requires tensors, None otherwise.
         """
         commOp = paramToCommName(curComm.comms)
         if commOp in ("wait", "barrier", "batch_isend_irecv"):
-            return ([], [])
+            return (torch.Tensor(), torch.Tensor())

         # prep process group for hard-coded traces
         if curComm.pgId is not None and not self.shrink:
@@ -729,7 +728,7 @@ def commRebalance(self, curComm: commsArgs) -> None:
         # Pass in curComm to modify it in the trace
         self.rebalanceSplit(curComm)

-    def runCompute(self, func, curBlockStack: str) -> float:
+    def runCompute(self, func, curBlockStack: str) -> tuple[float, float]:
         """
         Replays a specified compute operation and records metrics for benchmarking.
@@ -764,7 +763,7 @@ def runCompute(self, func, curBlockStack: str) -> float:
     def runComms(
         self, collName: str, curComm: commsArgs, curBlockStack: str
-    ) -> (float, float):
+    ) -> tuple[float, float]:
         """
         Replays collective communication operation and records metrics for benchmarking.
@@ -816,6 +815,7 @@ def runComms(
             logger.warn(
                 f"Unsupported collective name: {collName}. Skipping replaying the collective"
             )
+            retObj = None

         # if blocking, post outstanding ops and wait for them to complete. if nonblocking, just post op
         if self.is_blocking:
@@ -1115,8 +1115,11 @@ def replayTrace(
             )

     def replaySingle(
-        self, commsParams: commsParamsHolderBase, id: int, regenerateTensors: True
-    ) -> torch.tensor:
+        self,
+        commsParams: commsParamsHolderBase,
+        id: int,
+        regenerateTensors: bool = True,
+    ) -> torch.Tensor:
         """
         Replay comms trace.
         Args:
@@ -1129,7 +1132,7 @@ def replaySingle(
             if curComm.id == id:
                 collName = paramToCommName(curComm.comms)
                 if collName not in self.allowList:
-                    return
+                    return torch.Tensor()

                 curBlocks = (
                     curComm.markerStack if curComm.markerStack is not None else []
@@ -1540,13 +1543,12 @@ def readRawTrace(self, remotePath: str, rank: int) -> None:
                 full_trace_path=self.use_one_trace,
                 trace_type=self.trace_type,
             )
-            self.comms_trace = json.load(raw_comms_trace)
         else:
             # Check if self.trace_file is a directory or a single file
             if os.path.isdir(self.trace_file):
                 # Directory mode: construct the path to the rank-specific file
-                trace_file_path = f"{self.trace_file}/{rank}.json"
+                trace_file_path = f"{self.trace_file}/rank-{rank}.json"
             else:
                 # Single file mode: use self.trace_file as is
                 trace_file_path = self.trace_file
@@ -1599,7 +1601,7 @@ def readTrace(self, remotePath: str, rank: int) -> None:
             (
                 self.trace_file
                 if not os.path.isdir(self.trace_file)
-                else f"{self.trace_file}/{rank}.json"
+                else f"{self.trace_file}/rank-{rank}.json"
             ),
             rank,
             self.backendFuncs.get_world_size(),