Fix various issues in comm replay (#158)
Summary:
Pull Request resolved: #158

Updated trace file naming convention to rank{rank}.json when in directory mode.
Disabled warm-up by default and updated the help text for the do_warm_up toggle.
Corrected a typo in the regenerateTensors description.
Fixed a few syntax errors.
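
For illustration, a minimal sketch of the directory-mode naming described above; the load_rank_trace helper and trace_path argument are hypothetical, only the rank{rank}.json pattern comes from this summary.

import json
import os


def load_rank_trace(trace_path: str, rank: int) -> dict:
    # Hypothetical helper: when trace_path is a directory, pick up the
    # per-rank file named rank{rank}.json (the convention this PR adopts);
    # otherwise treat trace_path as a single trace file.
    if os.path.isdir(trace_path):
        trace_file = os.path.join(trace_path, f"rank{rank}.json")
    else:
        trace_file = trace_path
    with open(trace_file) as f:
        return json.load(f)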

Reviewed By: briancoutinho

Differential Revision: D61218264
shengfukevin authored and facebook-github-bot committed Aug 19, 2024
1 parent 7f45eb8 commit d2bb29c
Showing 6 changed files with 90 additions and 80 deletions.
28 changes: 14 additions & 14 deletions et_replay/comm/backend/base_backend.py
@@ -131,29 +131,29 @@ class BaseBackend(ABC):
     def __init__(self) -> None:
         self.tcp_store = None
         self.collectiveFunc = {
-            "all_to_all_single": self.all_to_all_single,
+            "all_to_all_single": self.all_to_all_single,  # pyre-ignore[16]:
             "all_to_all": self.all_to_all,
             "all_to_allv": self.all_to_allv,
             "all_reduce": self.all_reduce,
-            "broadcast": self.broadcast,
-            "gather": self.gather,
-            "all_gather": self.all_gather,
-            "all_gather_base": self.all_gather_base,
+            "broadcast": self.broadcast,  # pyre-ignore[16]:
+            "gather": self.gather,  # pyre-ignore[16]:
+            "all_gather": self.all_gather,  # pyre-ignore[16]:
+            "all_gather_base": self.all_gather_base,  # pyre-ignore[16]:
             "reduce": self.reduce,
-            "reduce_scatter": self.reduce_scatter,
-            "reduce_scatter_base": self.reduce_scatter_base,
-            "scatter": self.scatter,
+            "reduce_scatter": self.reduce_scatter,  # pyre-ignore[16]:
+            "reduce_scatter_base": self.reduce_scatter_base,  # pyre-ignore[16]:
+            "scatter": self.scatter,  # pyre-ignore[16]:
             "barrier": self.barrier,
-            "incast": self.incast,
-            "multicast": self.multicast,
+            "incast": self.incast,  # pyre-ignore[16]:
+            "multicast": self.multicast,  # pyre-ignore[16]:
             "noop": self.noop,
         }

         self.computeFunc = {"gemm": self.gemm}

     def alloc_ones(
         self,
-        sizeArr: int,
+        sizeArr: List[int],
         curRankDevice: str = "cuda",
         dtype: torch.dtype = torch.float32,
         scaleFactor: float = 1.0,
@@ -176,7 +176,7 @@ def alloc_ones(

     def noop(
         self,
-        collectiveArgs: collectiveArgsHolder = None,
+        collectiveArgs: collectiveArgsHolder,
         retFlag: bool = False,
         pair: bool = False,
     ) -> None:
@@ -236,7 +236,7 @@ def get_mem_size(self, collectiveArgs: collectiveArgsHolder) -> int:
     @abstractmethod
     def alloc_random(
         self,
-        sizeArr: int,
+        sizeArr: List[int],
         curRankDevice: str,
         dtype: torch.dtype,
         scaleFactor: float = 1.0,
@@ -253,7 +253,7 @@ def alloc_embedding_tables(

     @abstractmethod
     def alloc_empty(
-        self, sizeArr: int, dtype: torch.dtype, curRankDevice: str
+        self, sizeArr: List[int], dtype: torch.dtype, curRankDevice: str
     ) -> torch.Tensor:
         """Allocate tensor with uninitialized data based on parameters."""
         pass
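
As a reading aid, a minimal standalone sketch of what implementations of the corrected signatures might look like now that sizeArr is a shape list rather than an int; this is an assumption for illustration, not the repository's backend code.

from typing import List

import torch


def alloc_empty(sizeArr: List[int], dtype: torch.dtype, curRankDevice: str) -> torch.Tensor:
    # sizeArr is a full shape, e.g. [1024, 1024], matching the List[int] annotation.
    return torch.empty(sizeArr, dtype=dtype, device=curRankDevice)


def alloc_ones(
    sizeArr: List[int],
    curRankDevice: str = "cuda",
    dtype: torch.dtype = torch.float32,
    scaleFactor: float = 1.0,
) -> torch.Tensor:
    # Ones tensor, optionally scaled, mirroring the base-class signature above.
    t = torch.ones(sizeArr, dtype=dtype, device=curRankDevice)
    return t * scaleFactor if scaleFactor != 1.0 else t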
12 changes: 8 additions & 4 deletions et_replay/comm/backend/pytorch_dist_backend.py
@@ -30,7 +30,7 @@
 except ImportError:
     try:
         # Open-source extend_distributed.py can be found in https://github.com/facebookresearch/dlrm
-        import extend_distributed
+        import extend_distributed  # pyre-ignore[21]:

         has_ext_dist = True
     except ImportError:
@@ -783,7 +783,11 @@ def get_mem_size(self, collectiveArgs, pair=False):
         return _sizeBytes

     def alloc_random(
-        self, sizeArr, curRankDevice="cuda", dtype=torch.float32, scaleFactor=1.0
+        self,
+        sizeArr: List[int],
+        curRankDevice="cuda",
+        dtype=torch.float32,
+        scaleFactor=1.0,
     ):
         if dtype in (
             torch.int8,
@@ -954,7 +958,7 @@ def sync_stream(
         device: Optional[torch.device] = None,
     ):
         """Synchronize a stream with its associated device"""
-        if device.type == "cuda":
+        if device is not None and device.type == "cuda":
             # if the stream is None, sync on the current default stream
             cur_stream = (
                 stream
@@ -1000,7 +1004,7 @@ def __init__(self, bootstrap_info, commsParams):
         # Import Fairring
         if backend == "fairring":
             try:
-                import fairring  # noqa
+                import fairring  # pyre-ignore[21]:
             except ImportError:
                 raise RuntimeError("Unable to import Fairring")

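The sync_stream change above guards against a None device before reading device.type; a minimal standalone sketch of that guard pattern (the function body here is illustrative, the original's is only partially shown).

from typing import Optional

import torch


def sync_stream(
    stream: Optional[torch.cuda.Stream] = None,
    device: Optional[torch.device] = None,
) -> None:
    # Synchronize only when a CUDA device is actually provided; with the old
    # code a None device raised AttributeError on device.type.
    if device is not None and device.type == "cuda":
        cur_stream = stream if stream is not None else torch.cuda.current_stream(device=device)
        cur_stream.synchronize()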
4 changes: 2 additions & 2 deletions et_replay/comm/backend/pytorch_tpu_backend.py
@@ -4,8 +4,8 @@
 import numpy as np
 import torch
 import torch.nn as nn
-import torch_xla.core.xla_model as xm  # @manual
-import torch_xla.distributed.xla_multiprocessing as xmp  # @manual
+import torch_xla.core.xla_model as xm  # pyre-ignore[21]:
+import torch_xla.distributed.xla_multiprocessing as xmp  # pyre-ignore[21]:

 from et_replay.comm.backend.base_backend import BaseBackend

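The pyre-ignore[21] comments added in these files ask Pyre to skip its undefined-import diagnostic for modules it cannot resolve, such as torch_xla, fairring, and extend_distributed. A minimal sketch of how such an optional dependency is often guarded; the try/except and the HAS_XLA flag are assumptions for illustration, since the file above imports torch_xla unconditionally.

try:
    import torch_xla.core.xla_model as xm  # pyre-ignore[21]: unresolved optional import
    HAS_XLA = True
except ImportError:
    xm = None
    HAS_XLA = False

# Callers can then branch on HAS_XLA instead of failing at import time.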