Skip to content

Commit a59ef93

Browse files
levythufacebook-github-bot
authored andcommitted
Make promote_high_prefetch_overheaad_table_to_hbm a static function so other proposer can reuse (#2191)
Summary: Pull Request resolved: #2191 As titled. This function is proven to be useful across all embedding offloading related planning. Make it a static function so other proposer can use it. Reviewed By: henrylhtsang Differential Revision: D59136171 fbshipit-source-id: 0f0fec8895711e1c84990bc1181e571e4185bee1
1 parent d34462d commit a59ef93

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

torchrec/distributed/planner/proposers.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -309,11 +309,14 @@ def load(
309309
]
310310
# deepcopy so it won't affect other proposers
311311
self.starting_proposal = copy.deepcopy(proposal)
312-
self.promote_high_prefetch_overheaad_table_to_hbm(self.starting_proposal)
312+
self.promote_high_prefetch_overheaad_table_to_hbm(
313+
self.enumerator, self.starting_proposal
314+
)
313315
self.proposal = copy.deepcopy(self.starting_proposal)
314316

317+
@staticmethod
315318
def promote_high_prefetch_overheaad_table_to_hbm(
316-
self, proposal: List[ShardingOption]
319+
enumerator: Optional[Enumerator], proposal: List[ShardingOption]
317320
) -> None:
318321
"""
319322
Prefetch overhead is related to IO. When it's larger than saved memory from
@@ -322,7 +325,7 @@ def promote_high_prefetch_overheaad_table_to_hbm(
322325
323326
This function will end up updating proposal.
324327
"""
325-
if not self.enumerator:
328+
if not enumerator:
326329
return
327330
what_if_hbm_proposal = copy.deepcopy(proposal)
328331
what_if_hbm_cached_tables = [
@@ -345,8 +348,8 @@ def promote_high_prefetch_overheaad_table_to_hbm(
345348
sharding_option.compute_kernel = EmbeddingComputeKernel.FUSED.value
346349

347350
# appease pyre
348-
assert self.enumerator
349-
self.enumerator.populate_estimates(what_if_hbm_cached_tables)
351+
assert enumerator
352+
enumerator.populate_estimates(what_if_hbm_cached_tables)
350353

351354
# Now what_if_hbm_proposal contain estimated storage for all HBM case. If
352355
# it's even smaller than offloaded case, we promote it to HBM
@@ -368,8 +371,8 @@ def promote_high_prefetch_overheaad_table_to_hbm(
368371
# In the end, update the storage cost for new proposal
369372

370373
# appease pyre
371-
assert self.enumerator
372-
self.enumerator.populate_estimates(original_cached_tables)
374+
assert enumerator
375+
enumerator.populate_estimates(original_cached_tables)
373376

374377
def propose(self) -> Optional[List[ShardingOption]]:
375378
return self.proposal
@@ -426,7 +429,9 @@ def feedback(
426429
self.starting_proposal, budget, self.enumerator
427430
)
428431
if self.proposal is not None:
429-
self.promote_high_prefetch_overheaad_table_to_hbm(self.proposal)
432+
self.promote_high_prefetch_overheaad_table_to_hbm(
433+
self.enumerator, self.proposal
434+
)
430435

431436
@staticmethod
432437
def get_budget(proposal: List[ShardingOption], storage_constraint: Topology) -> int:

0 commit comments

Comments
 (0)