diff --git a/torchrec/distributed/planner/proposers.py b/torchrec/distributed/planner/proposers.py index 4848ce01b..d1a1542ad 100644 --- a/torchrec/distributed/planner/proposers.py +++ b/torchrec/distributed/planner/proposers.py @@ -309,11 +309,14 @@ def load( ] # deepcopy so it won't affect other proposers self.starting_proposal = copy.deepcopy(proposal) - self.promote_high_prefetch_overheaad_table_to_hbm(self.starting_proposal) + self.promote_high_prefetch_overheaad_table_to_hbm( + self.enumerator, self.starting_proposal + ) self.proposal = copy.deepcopy(self.starting_proposal) + @staticmethod def promote_high_prefetch_overheaad_table_to_hbm( - self, proposal: List[ShardingOption] + enumerator: Optional[Enumerator], proposal: List[ShardingOption] ) -> None: """ Prefetch overhead is related to IO. When it's larger than saved memory from @@ -322,7 +325,7 @@ def promote_high_prefetch_overheaad_table_to_hbm( This function will end up updating proposal. """ - if not self.enumerator: + if not enumerator: return what_if_hbm_proposal = copy.deepcopy(proposal) what_if_hbm_cached_tables = [ @@ -345,8 +348,8 @@ def promote_high_prefetch_overheaad_table_to_hbm( sharding_option.compute_kernel = EmbeddingComputeKernel.FUSED.value # appease pyre - assert self.enumerator - self.enumerator.populate_estimates(what_if_hbm_cached_tables) + assert enumerator + enumerator.populate_estimates(what_if_hbm_cached_tables) # Now what_if_hbm_proposal contain estimated storage for all HBM case. If # it's even smaller than offloaded case, we promote it to HBM @@ -368,8 +371,8 @@ def promote_high_prefetch_overheaad_table_to_hbm( # In the end, update the storage cost for new proposal # appease pyre - assert self.enumerator - self.enumerator.populate_estimates(original_cached_tables) + assert enumerator + enumerator.populate_estimates(original_cached_tables) def propose(self) -> Optional[List[ShardingOption]]: return self.proposal @@ -426,7 +429,9 @@ def feedback( self.starting_proposal, budget, self.enumerator ) if self.proposal is not None: - self.promote_high_prefetch_overheaad_table_to_hbm(self.proposal) + self.promote_high_prefetch_overheaad_table_to_hbm( + self.enumerator, self.proposal + ) @staticmethod def get_budget(proposal: List[ShardingOption], storage_constraint: Topology) -> int: