From 7191ce5ec4a60398406baa387a2e2f26581875b1 Mon Sep 17 00:00:00 2001 From: Damian Reeves Date: Wed, 7 Feb 2024 13:48:39 -0800 Subject: [PATCH] Switch EmbeddingOffloadScaleupProposer to use Luus Jaakola search Summary: Replace binary-search with Luus Jaakola. This works better at navigating the non-smooth planner cost surface when exploring cost vs cache memory. Reviewed By: henrylhtsang Differential Revision: D53296719 --- torchrec/distributed/planner/planners.py | 6 +- torchrec/distributed/planner/proposers.py | 44 +--- .../planner/tests/test_proposers.py | 190 +++++------------- 3 files changed, 62 insertions(+), 178 deletions(-) diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py index 137bc7ad1..af174dd75 100644 --- a/torchrec/distributed/planner/planners.py +++ b/torchrec/distributed/planner/planners.py @@ -283,9 +283,11 @@ def plan( ) if current_storage < lowest_storage: lowest_storage = current_storage - proposal_cache[proposal_key] = (False, None, None) + proposal_cache[proposal_key] = (False, proposal, None) proposer.feedback( - partitionable=False, storage_constraint=storage_constraint + partitionable=False, + plan=proposal, + storage_constraint=storage_constraint, ) # clear shard.rank for each sharding_option diff --git a/torchrec/distributed/planner/proposers.py b/torchrec/distributed/planner/proposers.py index 342c618c1..b89eb767a 100644 --- a/torchrec/distributed/planner/proposers.py +++ b/torchrec/distributed/planner/proposers.py @@ -22,7 +22,7 @@ ShardingOption, Topology, ) -from torchrec.distributed.planner.utils import BinarySearchPredicate, bytes_to_gb, prod +from torchrec.distributed.planner.utils import bytes_to_gb, LuusJaakolaSearch, prod logger: logging.Logger = logging.getLogger(__name__) @@ -283,8 +283,7 @@ def __init__(self, use_depth: bool = True) -> None: self.enumerator: Optional[Enumerator] = None self.starting_proposal: List[ShardingOption] = [] self.proposal: Optional[List[ShardingOption]] = None - self.search: Optional[BinarySearchPredicate] = None - self.previous_plan_perf_rating: float = 0.0 + self.search: Optional[LuusJaakolaSearch] = None def load( self, @@ -320,7 +319,7 @@ def feedback( perf_rating: Optional[float] = None, storage_constraint: Optional[Topology] = None, ) -> None: - if not self.enumerator or not plan or not storage_constraint: + if not self.enumerator or plan is None: self.proposal = None return @@ -329,47 +328,26 @@ def feedback( ) if self.search is None: - # Determine how much extra HBM memory is available for scaling our caches - # beyond the baseline-min-working-set plan. We may not be able to find a - # partitionable plan that uses all this budget, or we may find a plan that - # uses only a portion of this budget enables a layout that reduces overall - # cost at the expense of larger prefetch delay. So we perform a binary - # search to sample plans with different budgets to discover a good - # configuration. 
+            if not partitionable or storage_constraint is None:
+                self.proposal = None
+                return
+
             hbm_available = EmbeddingOffloadScaleupProposer.get_budget(
                 plan, storage_constraint
             )
             logger.info(
                 f"EmbeddingOffloadScaleupProposer - cache scale up budget={round(bytes_to_gb(hbm_available), 2)} GB, exploring [{round(bytes_to_gb(hbm_used_previously), 2)}, {round(bytes_to_gb(hbm_used_previously + hbm_available), 2)}] GB"
             )
-            # Partitioning proposals is quite expensive when there are a lot of tables,
-            # so we reduce the number probes the binary search uses to find the max
-            # cache sizes that fit inside the budget by specifying a tolerance. Once
-            # we've less than tolerance bytes left of unused budget we stop searching.
-            # We set tolerance to try to waste less than 3% of budget. For 100TB budget,
-            # this reduces number of proposals from 47 to 6.
-            tolerance = round(hbm_available * 0.03)
-            self.search = BinarySearchPredicate(0, hbm_available, tolerance)
+            self.search = LuusJaakolaSearch(0, hbm_available, max_iterations=16)
 
         logger.info(
             f"EmbeddingOffloadScaleupProposer - proposed size={round(bytes_to_gb(hbm_used_previously), 2)} GB, score={perf_rating}"
         )
 
-        # Guide the binary search. We assume the partitioned perf model cost is
-        # monotonic with respect to CLF, so if the feedback from our prior attempt was
-        # worse than previous one, we reduce the memory for our next proposal, else we
-        # try using more. This allows us to focus the budget allocation search into the
-        # productive region where plans are still getting better.
-        warmer = partitionable and (
-            self.previous_plan_perf_rating == 0.0
-            or (
-                perf_rating is not None and perf_rating < self.previous_plan_perf_rating
-            )
-        )
-        self.previous_plan_perf_rating = perf_rating or 0.0
-        assert self.search is not None  # keep pyre happy
-        budget = self.search.next(warmer)
+        budget = self.search.next(perf_rating or 1e99)
+        if budget is not None:
+            budget = int(budget)
         self.proposal = EmbeddingOffloadScaleupProposer.next_plan(
             self.starting_proposal, budget, self.enumerator
         )
diff --git a/torchrec/distributed/planner/tests/test_proposers.py b/torchrec/distributed/planner/tests/test_proposers.py
index 89574032c..817d90c70 100644
--- a/torchrec/distributed/planner/tests/test_proposers.py
+++ b/torchrec/distributed/planner/tests/test_proposers.py
@@ -422,20 +422,21 @@ def test_allocate_budget(self) -> None:
         self.assertEqual(increase, budget)
 
     def test_scaleup(self) -> None:
-
         tables = [
            EmbeddingBagConfig(
                num_embeddings=2_000_000,
-                embedding_dim=10,
+                embedding_dim=10000,
                name=f"table_{i}",
                feature_names=[f"feature_{i}"],
            )
-            for i in range(3)
+            for i in range(4)
         ]
 
-        # Place first two tables into cache, 3rd table leave on hbm. table_1 has a
+        # Place first three tables into cache, leave the 4th on hbm. table_1 has a
         # larger cacheability score so budget should be skewed to scaling table_1 more
-        # than table_0.
+        # than table_0. table_2 is a deprecated feature with no stats (expected_lookups
+        # is 0); we want to see it left at its original load factor, i.e. it does not
+        # participate in scaleup.
constraints = { "table_0": ParameterConstraints( compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value], @@ -451,126 +452,22 @@ def test_scaleup(self) -> None: stats=MockCacheStatistics(expected_lookups=2, cacheability=0.5), ), ), - } - - MB = 1024 * 1024 - storage_constraint = Topology( - world_size=2, compute_device="cuda", hbm_cap=100 * MB, ddr_cap=1000 * MB - ) - - model = TestSparseNN(tables=tables, sparse_device=torch.device("meta")) - enumerator = EmbeddingEnumerator( - topology=storage_constraint, batch_size=BATCH_SIZE, constraints=constraints - ) - search_space = enumerator.enumerate( - module=model, - sharders=[ - cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()) - ], - ) - proposer = EmbeddingOffloadScaleupProposer() - proposer.load(search_space, enumerator=enumerator) - - output = [] - proposal = proposer.propose() - while proposal is not None: - output.append( - [ - ( - candidate.name, - candidate.compute_kernel, - candidate.cache_params.load_factor - if candidate.cache_params - else None, - ) - for candidate in proposal - ] - ) - proposer.feedback( - partitionable=True, - plan=proposal, - storage_constraint=storage_constraint, - ) - proposal = proposer.propose() - - # Expected output (name, kernel clf). - # First attempt uses the mins supplied, then as we apply increasing budget - # clfs increase, with the later attempts enough to promote table_3 into hbm. - expected_output = [ - [ - ("table_0", "fused_uvm_caching", 0.1), - ("table_1", "fused_uvm_caching", 0.1), - ("table_2", "fused", None), - ], - [ - ("table_0", "fused_uvm_caching", 0.3025801181793213), - ("table_1", "fused_uvm_caching", 0.6064502596855164), - ("table_2", "fused", None), - ], - [ - ("table_0", "fused_uvm_caching", 0.403870165348053), - ("table_1", "fused_uvm_caching", 0.859675407409668), - ("table_2", "fused", None), - ], - [ - ("table_0", "fused_uvm_caching", 0.4545151889324188), - ("table_1", "fused_uvm_caching", 0.9862880110740662), - ("table_2", "fused", None), - ], - [ - ("table_0", "fused_uvm_caching", 0.5294319987297058), - ("table_1", "fused", None), - ("table_2", "fused", None), - ], - [ - ("table_0", "fused_uvm_caching", 0.573746383190155), - ("table_1", "fused", None), - ("table_2", "fused", None), - ], - [ - ("table_0", "fused_uvm_caching", 0.5959035754203796), - ("table_1", "fused", None), - ("table_2", "fused", None), - ], - ] - - self.assertEqual(output, expected_output) - - def test_scaleup_ample_budget_and_deprecated_feature(self) -> None: - tables = [ - EmbeddingBagConfig( - num_embeddings=2_000_000, - embedding_dim=10, - name=f"table_{i}", - feature_names=[f"feature_{i}"], - ) - for i in range(3) - ] - - # Place first two tables into cache, 3rd table leave on hbm. table_1 has an - # expected lookup of 0 (deprecated feature). 
-        constraints = {
-            "table_0": ParameterConstraints(
-                compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
-                cache_params=CacheParams(
-                    load_factor=0.1,
-                    stats=MockCacheStatistics(expected_lookups=2, cacheability=0.2),
-                ),
-            ),
-            "table_1": ParameterConstraints(
+            "table_2": ParameterConstraints(
                 compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
                 cache_params=CacheParams(
-                    load_factor=0.1,
-                    stats=MockCacheStatistics(expected_lookups=0, cacheability=0),
+                    load_factor=0.002,
+                    stats=MockCacheStatistics(expected_lookups=0, cacheability=0.5),
                 ),
             ),
         }
 
-        MB = 1024 * 1024
+        GB = 1024 * 1024 * 1024
         storage_constraint = Topology(
-            world_size=2, compute_device="cuda", hbm_cap=100 * MB, ddr_cap=1000 * MB
+            world_size=2, compute_device="cuda", hbm_cap=100 * GB, ddr_cap=1000 * GB
         )
-
+        # Ignoring table_2, the remaining tables require 224GB if all placed on HBM. We
+        # only have 200GB, so we can't promote both uvm tables. The initial plan uses
+        # 90GB, leaving 110GB of available budget.
         model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
         enumerator = EmbeddingEnumerator(
             topology=storage_constraint, batch_size=BATCH_SIZE, constraints=constraints
@@ -584,11 +481,18 @@ def test_scaleup_ample_budget_and_deprecated_feature(self) -> None:
         proposer = EmbeddingOffloadScaleupProposer()
         proposer.load(search_space, enumerator=enumerator)
 
-        output = []
         proposal = proposer.propose()
+        best_plan = None
+        best_perf = 1e99
+        proposals = -1
         while proposal is not None:
-            output.append(
-                [
+            proposals += 1
+            mem = sum(so.total_storage.hbm for so in proposal)
+            # Simple perf model: assume the partitioner's lowest score is near 150GB of memory.
+            perf = abs(mem - (150 * GB))
+            plan = {
+                "mem": mem,
+                "proposal": [
                     (
                         candidate.name,
                         candidate.compute_kernel,
                         candidate.cache_params.load_factor
@@ -597,36 +501,36 @@ def test_scaleup_ample_budget_and_deprecated_feature(self) -> None:
                         else None,
                     )
                     for candidate in proposal
-                ]
-            )
+                ],
+            }
+            if perf < best_perf:
+                best_plan = plan
+                best_perf = perf
             proposer.feedback(
                 partitionable=True,
                 plan=proposal,
+                perf_rating=perf,
                 storage_constraint=storage_constraint,
             )
             proposal = proposer.propose()
 
-        # Expected output (name, kernel clf).
-        # First attempt uses the mins supplied, then as we apply increasing budget
-        # clfs increase, table 0 gets promoted, table 1 left as original minimum.
-        expected_output = [
-            [
-                ("table_0", "fused_uvm_caching", 0.1),
-                ("table_1", "fused_uvm_caching", 0.1),
-                ("table_2", "fused", None),
-            ],
-            [
-                ("table_0", "fused_uvm_caching", 0.8090304136276245),
-                ("table_1", "fused_uvm_caching", 0.1),
-                ("table_2", "fused", None),
-            ],
-            [
-                ("table_0", "fused", None),
-                ("table_1", "fused_uvm_caching", 0.1),
-                ("table_2", "fused", None),
-            ],
-        ]
-        self.assertEqual(output[0:3], expected_output)
+        self.assertEqual(proposals, 16)
+        self.assertEqual(
+            best_plan,
+            {
+                # 146GB, close to the 150GB target
+                "mem": 157178896800,
+                # table_1 has been scaled up 2.5x more than table_0 (vs the original 0.1),
+                # which aligns with their different cacheability scores.
+                # table_2 has been left alone (deprecated feature, expected zero lookups in stats).
+                "proposal": [
+                    ("table_0", "fused_uvm_caching", 0.3173336386680603),
+                    ("table_1", "fused_uvm_caching", 0.6433340907096863),
+                    ("table_2", "fused_uvm_caching", 0.002),
+                    ("table_3", "fused", None),
+                ],
+            },
+        )
 
     def test_proposers_to_proposals_list(self) -> None:
         def make_mock_proposal(name: str) -> List[ShardingOption]:
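
Note for reviewers unfamiliar with the method: Luus-Jaakola is a stochastic local search that samples each new candidate uniformly from a window around the best point seen so far and shrinks the window every iteration, so it does not rely on the cost surface being monotonic the way the previous binary search did. The sketch below only illustrates the call pattern the proposer uses above (construct over a [0, budget] range with max_iterations, then repeatedly call next(prior_cost) until it returns None). The class name LuusJaakolaSketch, the uniform random start, the 0.95 shrink rate, and evaluate_plan are assumptions for illustration, not the actual LuusJaakolaSearch implementation in torchrec.distributed.planner.utils.

import random
from typing import Callable, Optional


class LuusJaakolaSketch:
    """Minimize an expensive black-box cost over [lo, hi] with a shrinking search window."""

    def __init__(self, lo: float, hi: float, max_iterations: int, shrink: float = 0.95) -> None:
        self.lo, self.hi = lo, hi
        self.max_iterations = max_iterations
        self.shrink = shrink
        self.window = hi - lo
        self.iteration = 0
        self.best_x = lo
        self.best_cost = float("inf")
        self.pending_x: Optional[float] = None  # point whose cost arrives on the next call

    def next(self, prior_cost: float) -> Optional[float]:
        # Fold in the cost observed for the previously returned point.
        if self.pending_x is not None and prior_cost < self.best_cost:
            self.best_x, self.best_cost = self.pending_x, prior_cost
        if self.iteration >= self.max_iterations:
            return None  # iteration budget exhausted; caller stops proposing
        self.iteration += 1
        if self.pending_x is None:
            x = random.uniform(self.lo, self.hi)  # first probe anywhere in the range
        else:
            # Probe around the best point so far, clamped to the bounds, then shrink
            # the window so later probes refine the promising region.
            x = min(self.hi, max(self.lo, self.best_x + random.uniform(-self.window, self.window)))
            self.window *= self.shrink
        self.pending_x = x
        return x


# Mirrors how feedback() drives the search: the first call passes a sentinel cost because
# no plan has been rated yet; evaluate_plan stands in for enumerate + partition + perf model.
def run_search(hbm_available: int, evaluate_plan: Callable[[int], float]) -> None:
    search = LuusJaakolaSketch(0, hbm_available, max_iterations=16)
    budget = search.next(1e99)
    while budget is not None:
        cost = evaluate_plan(int(budget))
        budget = search.next(cost)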