Skip to content

Commit

Permalink
Give user option for multiplier to control preference for rank 0 or r…
Browse files Browse the repository at this point in the history
…ank 1-N for Heteroplanner uneven sharding (#2138)

Summary:
Pull Request resolved: #2138

This diff add a multiplier to penalize non zero shards (parameter server shard) commum cost. This essentially give user a toggle to prefer rank 0 vs rank 1-N

Reviewed By: jingsh

Differential Revision: D58755047

fbshipit-source-id: fa44d4825f439c9765da9c22a26c67b401a0cd88
  • Loading branch information
gnahzg authored and facebook-github-bot committed Jun 28, 2024
1 parent a59ef93 commit 0330715
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torchrec/distributed/planner/shard_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def estimate(
caching_ratio=caching_ratio,
prefetch_pipeline=prefetch_pipeline,
expected_cache_fetches=expected_cache_fetches,
uneven_sharding_perf_multiplier=self._topology.uneven_sharding_perf_multiplier,
)

for shard, perf in zip(sharding_option.shards, shard_perfs):
Expand Down Expand Up @@ -259,6 +260,7 @@ def perf_func_emb_wall_time(
is_inference: bool = False,
prefetch_pipeline: bool = False,
expected_cache_fetches: float = 0,
uneven_sharding_perf_multiplier: float = 1.0,
) -> List[Perf]:
"""
Attempts to model perfs as a function of relative wall times.
Expand Down
6 changes: 6 additions & 0 deletions torchrec/distributed/planner/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def __init__(
bwd_compute_multiplier: float = BWD_COMPUTE_MULTIPLIER,
custom_topology_data: Optional[CustomTopologyData] = None,
weighted_feature_bwd_compute_multiplier: float = WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER,
uneven_sharding_perf_multiplier: float = 1.0,
) -> None:
"""
Representation of a network of devices in a cluster.
Expand Down Expand Up @@ -244,6 +245,7 @@ def __init__(
self._weighted_feature_bwd_compute_multiplier = (
weighted_feature_bwd_compute_multiplier
)
self._uneven_sharding_perf_multiplier = uneven_sharding_perf_multiplier

@property
def compute_device(self) -> str:
Expand Down Expand Up @@ -285,6 +287,10 @@ def bwd_compute_multiplier(self) -> float:
def weighted_feature_bwd_compute_multiplier(self) -> float:
return self._weighted_feature_bwd_compute_multiplier

@property
def uneven_sharding_perf_multiplier(self) -> float:
return self._uneven_sharding_perf_multiplier

def __repr__(self) -> str:
topology_repr: str = f"world_size={self._world_size} \n"
topology_repr += f"compute_device={self._compute_device}\n"
Expand Down

0 comments on commit 0330715

Please sign in to comment.