Skip to content

Commit

Permalink
Make promote_high_prefetch_overheaad_table_to_hbm a static function s…
Browse files Browse the repository at this point in the history
…o other proposer can reuse (#2191)

Summary:
Pull Request resolved: #2191

As titled. This function is proven to be useful across all embedding offloading related planning. Make it a static function so other proposer can use it.

Reviewed By: henrylhtsang

Differential Revision: D59136171

fbshipit-source-id: 0f0fec8895711e1c84990bc1181e571e4185bee1
  • Loading branch information
levythu authored and facebook-github-bot committed Jun 28, 2024
1 parent d34462d commit a59ef93
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions torchrec/distributed/planner/proposers.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,14 @@ def load(
]
# deepcopy so it won't affect other proposers
self.starting_proposal = copy.deepcopy(proposal)
self.promote_high_prefetch_overheaad_table_to_hbm(self.starting_proposal)
self.promote_high_prefetch_overheaad_table_to_hbm(
self.enumerator, self.starting_proposal
)
self.proposal = copy.deepcopy(self.starting_proposal)

@staticmethod
def promote_high_prefetch_overheaad_table_to_hbm(
self, proposal: List[ShardingOption]
enumerator: Optional[Enumerator], proposal: List[ShardingOption]
) -> None:
"""
Prefetch overhead is related to IO. When it's larger than saved memory from
Expand All @@ -322,7 +325,7 @@ def promote_high_prefetch_overheaad_table_to_hbm(
This function will end up updating proposal.
"""
if not self.enumerator:
if not enumerator:
return
what_if_hbm_proposal = copy.deepcopy(proposal)
what_if_hbm_cached_tables = [
Expand All @@ -345,8 +348,8 @@ def promote_high_prefetch_overheaad_table_to_hbm(
sharding_option.compute_kernel = EmbeddingComputeKernel.FUSED.value

# appease pyre
assert self.enumerator
self.enumerator.populate_estimates(what_if_hbm_cached_tables)
assert enumerator
enumerator.populate_estimates(what_if_hbm_cached_tables)

# Now what_if_hbm_proposal contain estimated storage for all HBM case. If
# it's even smaller than offloaded case, we promote it to HBM
Expand All @@ -368,8 +371,8 @@ def promote_high_prefetch_overheaad_table_to_hbm(
# In the end, update the storage cost for new proposal

# appease pyre
assert self.enumerator
self.enumerator.populate_estimates(original_cached_tables)
assert enumerator
enumerator.populate_estimates(original_cached_tables)

def propose(self) -> Optional[List[ShardingOption]]:
return self.proposal
Expand Down Expand Up @@ -426,7 +429,9 @@ def feedback(
self.starting_proposal, budget, self.enumerator
)
if self.proposal is not None:
self.promote_high_prefetch_overheaad_table_to_hbm(self.proposal)
self.promote_high_prefetch_overheaad_table_to_hbm(
self.enumerator, self.proposal
)

@staticmethod
def get_budget(proposal: List[ShardingOption], storage_constraint: Topology) -> int:
Expand Down

0 comments on commit a59ef93

Please sign in to comment.