From 83932024bc303bfd4e21ce267a51a706be2246c4 Mon Sep 17 00:00:00 2001
From: Sen Yang
Date: Wed, 5 Jun 2024 22:40:30 -0700
Subject: [PATCH] Support weighted_bwd_compute_multiplier in sharding estimators (#2068)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
# Context
This diff stack is a quick mitigation for sharding imbalance on weighted
features.

The TorchRec sharding planner focuses mostly on forward cost; e.g., it is not
aware of differences like:
- for unweighted lookups, the backward kernel is, say, 2x the forward kernel;
- but for weighted lookups, the backward kernel may be 4x the forward kernel.

See
https://docs.google.com/document/d/1o-lB6veGVIZFO148ljSuhVr8sA9fm2EGwOfhX6TSdrg/edit?usp=sharing
for more context.

# This Diff
- Enable the shard estimator to apply weighted_feature_bwd_compute_multiplier
  when computing the bwd cost of weighted features.

Pull Request resolved: https://github.com/pytorch/torchrec/pull/2068

Reviewed By: xush6528, sarckk

Differential Revision: D53550851

fbshipit-source-id: bd14c9b8dc01d47802741978288a002cc58e85ee
---
 torchrec/distributed/planner/constants.py    |  1 +
 .../distributed/planner/shard_estimators.py  | 18 ++++
 .../planner/tests/test_shard_estimators.py   | 91 ++++++++++++++++++-
 torchrec/distributed/planner/types.py        |  9 ++
 4 files changed, 118 insertions(+), 1 deletion(-)
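
Note: a minimal usage sketch of the new knob (not part of the patch). Using
the numbers from the summary above: if unweighted bwd is 2x fwd (the generic
BWD_COMPUTE_MULTIPLIER) but weighted bwd is 4x fwd, the weighted-feature
multiplier would be 4/2 = 2. The value 2.0 below is that illustrative ratio,
not a measured constant:

    from torchrec.distributed.planner.types import Topology

    # Tell the planner that weighted-feature backward kernels cost ~2x more
    # than the generic bwd multiplier alone would predict.
    topology = Topology(
        world_size=2,
        compute_device="cuda",
        weighted_feature_bwd_compute_multiplier=2.0,
    )
    assert topology.weighted_feature_bwd_compute_multiplier == 2.0
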
diff --git a/torchrec/distributed/planner/constants.py b/torchrec/distributed/planner/constants.py
index c00192193..423d08aa3 100644
--- a/torchrec/distributed/planner/constants.py
+++ b/torchrec/distributed/planner/constants.py
@@ -34,6 +34,7 @@
 HALF_BLOCK_PENALTY: float = 1.15  # empirical studies
 QUARTER_BLOCK_PENALTY: float = 1.75  # empirical studies
 BWD_COMPUTE_MULTIPLIER: float = 2  # empirical studies
+WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER: float = 1  # empirical studies
 
 WEIGHTED_KERNEL_MULTIPLIER: float = 1.1  # empirical studies
 DP_ELEMENTWISE_KERNELS_PERF_FACTOR: float = 9.22  # empirical studies
diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py
index cdfd25219..487502ebc 100644
--- a/torchrec/distributed/planner/shard_estimators.py
+++ b/torchrec/distributed/planner/shard_estimators.py
@@ -216,6 +216,7 @@ def estimate(
                 intra_host_bw=self._topology.intra_host_bw,
                 inter_host_bw=self._topology.inter_host_bw,
                 bwd_compute_multiplier=self._topology.bwd_compute_multiplier,
+                weighted_feature_bwd_compute_multiplier=self._topology.weighted_feature_bwd_compute_multiplier,
                 is_pooled=sharding_option.is_pooled,
                 is_weighted=is_weighted,
                 is_inference=self._is_inference,
@@ -251,6 +252,7 @@ def perf_func_emb_wall_time(
         intra_host_bw: float,
         inter_host_bw: float,
         bwd_compute_multiplier: float,
+        weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
         is_weighted: bool = False,
         caching_ratio: Optional[float] = None,
@@ -336,6 +338,7 @@ def perf_func_emb_wall_time(
                 inter_host_bw=inter_host_bw,
                 intra_host_bw=intra_host_bw,
                 bwd_compute_multiplier=bwd_compute_multiplier,
+                weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
                 is_pooled=is_pooled,
                 is_weighted=is_weighted,
                 is_inference=is_inference,
@@ -361,6 +364,7 @@ def perf_func_emb_wall_time(
                 inter_host_bw=inter_host_bw,
                 intra_host_bw=intra_host_bw,
                 bwd_compute_multiplier=bwd_compute_multiplier,
+                weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
                 is_pooled=is_pooled,
                 is_weighted=is_weighted,
                 expected_cache_fetches=expected_cache_fetches,
@@ -386,6 +390,7 @@ def perf_func_emb_wall_time(
                 inter_host_bw=inter_host_bw,
                 intra_host_bw=intra_host_bw,
                 bwd_compute_multiplier=bwd_compute_multiplier,
+                weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
                 is_pooled=is_pooled,
                 is_weighted=is_weighted,
                 expected_cache_fetches=expected_cache_fetches,
@@ -405,6 +410,7 @@ def perf_func_emb_wall_time(
                 device_bw=device_bw,
                 inter_host_bw=inter_host_bw,
                 bwd_compute_multiplier=bwd_compute_multiplier,
+                weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
                 is_pooled=is_pooled,
                 is_weighted=is_weighted,
             )
@@ -447,6 +453,7 @@ def _get_tw_sharding_perf(
         inter_host_bw: float,
         intra_host_bw: float,
         bwd_compute_multiplier: float,
+        weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
         is_weighted: bool = False,
         is_inference: bool = False,
@@ -507,6 +514,8 @@ def _get_tw_sharding_perf(
 
         # includes fused optimizers
         bwd_compute = fwd_compute * bwd_compute_multiplier
+        if is_weighted:
+            bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier
 
         prefetch_compute = cls._get_expected_cache_prefetch_time(
             ddr_mem_bw, expected_cache_fetches, emb_dim, table_data_type_size
@@ -543,6 +552,7 @@ def _get_rw_sharding_perf(
         inter_host_bw: float,
         intra_host_bw: float,
         bwd_compute_multiplier: float,
+        weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
         is_weighted: bool = False,
         expected_cache_fetches: float = 0,
@@ -601,6 +611,8 @@ def _get_rw_sharding_perf(
         )
 
         bwd_compute = fwd_compute * bwd_compute_multiplier
+        if is_weighted:
+            bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier
 
         # for row-wise, expected_cache_fetches per shard is / world_size
         prefetch_compute = cls._get_expected_cache_prefetch_time(
@@ -639,6 +651,7 @@ def _get_twrw_sharding_perf(
         inter_host_bw: float,
         intra_host_bw: float,
         bwd_compute_multiplier: float,
+        weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
         is_weighted: bool = False,
         expected_cache_fetches: float = 0,
@@ -697,6 +710,8 @@ def _get_twrw_sharding_perf(
         bwd_batched_copy = bwd_output_write_size * BATCHED_COPY_PERF_FACTOR / device_bw
 
         bwd_compute = fwd_compute * bwd_compute_multiplier
+        if is_weighted:
+            bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier
 
         # for table-wise-row-wise, expected_cache_fetches per shard is / local_world_size
         prefetch_compute = cls._get_expected_cache_prefetch_time(
@@ -730,6 +745,7 @@ def _get_dp_sharding_perf(
         device_bw: float,
         inter_host_bw: float,
         bwd_compute_multiplier: float,
+        weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
         is_weighted: bool = False,
     ) -> Perf:
@@ -772,6 +788,8 @@ def _get_dp_sharding_perf(
         optimizer_kernels = table_size * DP_ELEMENTWISE_KERNELS_PERF_FACTOR / device_bw
 
         bwd_compute = fwd_compute * bwd_compute_multiplier
+        if is_weighted:
+            bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier
 
         bwd_grad_indice_weights_kernel = (
             fwd_compute * WEIGHTED_KERNEL_MULTIPLIER if is_weighted else 0
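
Note: the same two added lines appear in all four sharding-perf helpers
(tw/rw/twrw/dp), so the estimator's backward cost model is now, in effect
(a sketch, with names mirroring the diff; the separate weighted-only
bwd_grad_indice_weights_kernel term visible at the end of the dp hunk is
intentionally not scaled by the new multiplier):

    def estimated_bwd_compute(
        fwd_compute: float,
        bwd_compute_multiplier: float,
        weighted_feature_bwd_compute_multiplier: float,
        is_weighted: bool,
    ) -> float:
        # base backward cost, proportional to forward cost
        bwd_compute = fwd_compute * bwd_compute_multiplier
        # weighted lookups get an extra, independently tunable factor
        if is_weighted:
            bwd_compute *= weighted_feature_bwd_compute_multiplier
        return bwd_compute

With the default multiplier of 1, this reduces exactly to the previous
behavior.
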
diff --git a/torchrec/distributed/planner/tests/test_shard_estimators.py b/torchrec/distributed/planner/tests/test_shard_estimators.py
index 0e92f1a39..b60a01114 100644
--- a/torchrec/distributed/planner/tests/test_shard_estimators.py
+++ b/torchrec/distributed/planner/tests/test_shard_estimators.py
@@ -8,7 +8,7 @@
 # pyre-strict
 
 import unittest
-from typing import cast
+from typing import cast, Dict, List, Tuple
 
 from unittest.mock import Mock, patch
 
@@ -526,6 +526,95 @@ def cacheability(self) -> float:
         }
         self.assertEqual(expected_prefetch_computes, prefetch_computes)
 
+    def test_weighted_feature_bwd_compute_multiplier(self) -> None:
+        def _get_bwd_computes(
+            model: torch.nn.Module,
+            weighted_feature_bwd_compute_multiplier: float,
+        ) -> Dict[Tuple[str, str, str], List[float]]:
+            topology = Topology(
+                world_size=2,
+                compute_device="cuda",
+                weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
+            )
+            estimator = EmbeddingPerfEstimator(topology=topology)
+            enumerator = EmbeddingEnumerator(
+                topology=topology, batch_size=BATCH_SIZE, estimator=estimator
+            )
+            sharding_options = enumerator.enumerate(
+                module=model,
+                sharders=[
+                    cast(
+                        ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()
+                    )
+                ],
+            )
+            bwd_computes = {
+                (
+                    sharding_option.name,
+                    sharding_option.compute_kernel,
+                    sharding_option.sharding_type,
+                ): [
+                    shard.perf.bwd_compute if shard.perf else -1
+                    for shard in sharding_option.shards
+                ]
+                for sharding_option in sharding_options
+            }
+            return bwd_computes
+
+        tables = [
+            EmbeddingBagConfig(
+                num_embeddings=100,
+                embedding_dim=10,
+                name="table_0",
+                feature_names=["feature_0"],
+            )
+        ]
+        weighted_tables = [
+            EmbeddingBagConfig(
+                num_embeddings=100,
+                embedding_dim=10,
+                name="weighted_table_0",
+                feature_names=["weighted_feature_0"],
+            )
+        ]
+        model = TestSparseNN(tables=tables, weighted_tables=weighted_tables)
+
+        MULTIPLIER = 7
+        bwd_computes_1 = _get_bwd_computes(
+            model, weighted_feature_bwd_compute_multiplier=1
+        )
+        bwd_computes_2 = _get_bwd_computes(
+            model,
+            weighted_feature_bwd_compute_multiplier=2,
+        )
+        bwd_computes_n = _get_bwd_computes(
+            model,
+            weighted_feature_bwd_compute_multiplier=MULTIPLIER,
+        )
+        self.assertEqual(bwd_computes_1.keys(), bwd_computes_2.keys())
+        self.assertEqual(bwd_computes_1.keys(), bwd_computes_n.keys())
+        for key in bwd_computes_1.keys():
+            table_name, _, sharding_type = key
+            if table_name.startswith("weighted"):
+                self.assertEqual(len(bwd_computes_1), len(bwd_computes_2))
+                self.assertEqual(len(bwd_computes_1), len(bwd_computes_n))
+                for bwd_compute_1, bwd_compute_2, bwd_compute_n in zip(
+                    bwd_computes_1[key], bwd_computes_2[key], bwd_computes_n[key]
+                ):
+                    # bwd_compute_1 = base_bwd_compute + offset
+                    # bwd_compute_2 = base_bwd_compute * 2 + offset
+                    # bwd_compute_n = base_bwd_compute * MULTIPLIER + offset
+                    # (where offset = bwd_grad_indice_weights_kernel in production
+                    # https://fburl.com/code/u9hq6vhf)
+                    base_bwd_compute = bwd_compute_2 - bwd_compute_1
+                    offset = bwd_compute_1 - base_bwd_compute
+                    self.assertAlmostEqual(
+                        base_bwd_compute * MULTIPLIER,
+                        bwd_compute_n - offset,
+                    )
+            else:
+                self.assertEqual(bwd_computes_1[key], bwd_computes_2[key])
+
 
 # pyre-ignore[3]
 def calculate_storage_specific_size_data_provider():
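
Note: the test's assertions lean on bwd_compute being affine in the
multiplier m, i.e. bwd(m) = base * m + offset, where offset is the
weighted-only bwd_grad_indice_weights_kernel term that the multiplier does
not scale. Sampling at m = 1 and m = 2 recovers base and offset, and
m = MULTIPLIER checks the extrapolation. A standalone sketch of the same
algebra, with made-up values for base and offset:

    # Hypothetical base/offset; any positive values behave the same way.
    base, offset = 0.5, 0.1

    def bwd(m: float) -> float:
        return base * m + offset

    recovered_base = bwd(2) - bwd(1)            # == base
    recovered_offset = bwd(1) - recovered_base  # == offset
    assert abs(recovered_base * 7 - (bwd(7) - recovered_offset)) < 1e-9
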
diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py
index 64bb23183..b20d96673 100644
--- a/torchrec/distributed/planner/types.py
+++ b/torchrec/distributed/planner/types.py
@@ -25,6 +25,7 @@
     HBM_MEM_BW,
     INTRA_NODE_BANDWIDTH,
     POOLING_FACTOR,
+    WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER,
 )
 from torchrec.distributed.types import (
     BoundsCheckMode,
@@ -186,6 +187,7 @@ def __init__(
         inter_host_bw: float = CROSS_NODE_BANDWIDTH,
         bwd_compute_multiplier: float = BWD_COMPUTE_MULTIPLIER,
         custom_topology_data: Optional[CustomTopologyData] = None,
+        weighted_feature_bwd_compute_multiplier: float = WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER,
     ) -> None:
         """
         Representation of a network of devices in a cluster.
@@ -238,6 +240,9 @@ def __init__(
         self._inter_host_bw = inter_host_bw
         self._bwd_compute_multiplier = bwd_compute_multiplier
         self._custom_topology_data = custom_topology_data
+        self._weighted_feature_bwd_compute_multiplier = (
+            weighted_feature_bwd_compute_multiplier
+        )
 
     @property
     def compute_device(self) -> str:
@@ -275,6 +280,10 @@ def inter_host_bw(self) -> float:
     def bwd_compute_multiplier(self) -> float:
         return self._bwd_compute_multiplier
 
+    @property
+    def weighted_feature_bwd_compute_multiplier(self) -> float:
+        return self._weighted_feature_bwd_compute_multiplier
+
     def __repr__(self) -> str:
         topology_repr: str = f"world_size={self._world_size} \n"
         topology_repr += f"compute_device={self._compute_device}\n"
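
Note: since WEIGHTED_FEATURE_BWD_COMPUTE_MULTIPLIER defaults to 1, existing
plans are unchanged unless the knob is set explicitly. An end-to-end sketch
(planner construction follows the standard TorchRec planner API, which is
outside this patch; the model and the multiplier value are illustrative):

    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology

    topology = Topology(
        world_size=2,
        compute_device="cuda",
        weighted_feature_bwd_compute_multiplier=2.0,  # illustrative value
    )
    planner = EmbeddingShardingPlanner(topology=topology)
    # With a model in hand, planning would look like:
    # plan = planner.plan(module=model, sharders=[EmbeddingBagCollectionSharder()])
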