diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index 967a6fb8d..529f168f3 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -223,6 +223,7 @@ def estimate( caching_ratio=caching_ratio, prefetch_pipeline=prefetch_pipeline, expected_cache_fetches=expected_cache_fetches, + uneven_sharding_perf_multiplier=self._topology.uneven_sharding_perf_multiplier, ) for shard, perf in zip(sharding_option.shards, shard_perfs): @@ -259,6 +260,7 @@ def perf_func_emb_wall_time( is_inference: bool = False, prefetch_pipeline: bool = False, expected_cache_fetches: float = 0, + uneven_sharding_perf_multiplier: float = 1.0, ) -> List[Perf]: """ Attempts to model perfs as a function of relative wall times. diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index 21e205d15..79311dbdc 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -189,6 +189,7 @@ def __init__( bwd_compute_multiplier: float = BWD_COMPUTE_MULTIPLIER, custom_topology_data: Optional[CustomTopologyData] = None, weighted_feature_bwd_compute_multiplier: float = WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER, + uneven_sharding_perf_multiplier: float = 1.0, ) -> None: """ Representation of a network of devices in a cluster. @@ -244,6 +245,7 @@ def __init__( self._weighted_feature_bwd_compute_multiplier = ( weighted_feature_bwd_compute_multiplier ) + self._uneven_sharding_perf_multiplier = uneven_sharding_perf_multiplier @property def compute_device(self) -> str: @@ -285,6 +287,10 @@ def bwd_compute_multiplier(self) -> float: def weighted_feature_bwd_compute_multiplier(self) -> float: return self._weighted_feature_bwd_compute_multiplier + @property + def uneven_sharding_perf_multiplier(self) -> float: + return self._uneven_sharding_perf_multiplier + def __repr__(self) -> str: topology_repr: str = f"world_size={self._world_size} \n" topology_repr += f"compute_device={self._compute_device}\n"