From 03307152e4099c0773296e5fe249c97018a25853 Mon Sep 17 00:00:00 2001 From: Qiang Zhang Date: Fri, 28 Jun 2024 15:25:14 -0700 Subject: [PATCH] Give user option for multiplier to control preference for rank 0 or rank 1-N for Heteroplanner uneven sharding (#2138) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2138 This diff add a multiplier to penalize non zero shards (parameter server shard) commum cost. This essentially give user a toggle to prefer rank 0 vs rank 1-N Reviewed By: jingsh Differential Revision: D58755047 fbshipit-source-id: fa44d4825f439c9765da9c22a26c67b401a0cd88 --- torchrec/distributed/planner/shard_estimators.py | 2 ++ torchrec/distributed/planner/types.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index 967a6fb8d..529f168f3 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -223,6 +223,7 @@ def estimate( caching_ratio=caching_ratio, prefetch_pipeline=prefetch_pipeline, expected_cache_fetches=expected_cache_fetches, + uneven_sharding_perf_multiplier=self._topology.uneven_sharding_perf_multiplier, ) for shard, perf in zip(sharding_option.shards, shard_perfs): @@ -259,6 +260,7 @@ def perf_func_emb_wall_time( is_inference: bool = False, prefetch_pipeline: bool = False, expected_cache_fetches: float = 0, + uneven_sharding_perf_multiplier: float = 1.0, ) -> List[Perf]: """ Attempts to model perfs as a function of relative wall times. diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index 21e205d15..79311dbdc 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -189,6 +189,7 @@ def __init__( bwd_compute_multiplier: float = BWD_COMPUTE_MULTIPLIER, custom_topology_data: Optional[CustomTopologyData] = None, weighted_feature_bwd_compute_multiplier: float = WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER, + uneven_sharding_perf_multiplier: float = 1.0, ) -> None: """ Representation of a network of devices in a cluster. @@ -244,6 +245,7 @@ def __init__( self._weighted_feature_bwd_compute_multiplier = ( weighted_feature_bwd_compute_multiplier ) + self._uneven_sharding_perf_multiplier = uneven_sharding_perf_multiplier @property def compute_device(self) -> str: @@ -285,6 +287,10 @@ def bwd_compute_multiplier(self) -> float: def weighted_feature_bwd_compute_multiplier(self) -> float: return self._weighted_feature_bwd_compute_multiplier + @property + def uneven_sharding_perf_multiplier(self) -> float: + return self._uneven_sharding_perf_multiplier + def __repr__(self) -> str: topology_repr: str = f"world_size={self._world_size} \n" topology_repr += f"compute_device={self._compute_device}\n"