Switch EmbeddingOffloadScaleupProposer to use Luus Jaakola search (#1691)

Summary:
Pull Request resolved: #1691

Replace the binary search with Luus Jaakola search. It works better
at navigating the non-smooth planner cost surface when exploring cost
vs. cache memory (a minimal sketch of the search appears after the
commit metadata below).

Reviewed By: henrylhtsang

Differential Revision: D53296719

fbshipit-source-id: e5a43f55204b7e59b0341bab273f86116aa4a630
Damian Reeves authored and facebook-github-bot committed Feb 8, 2024
1 parent 4696b8d commit 8d93bc2
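
For context on the summary above: Luus Jaakola is a simple randomized local search. It samples a neighbour of the current best point inside a sampling window, keeps the neighbour if its cost is lower, and shrinks the window otherwise; because it only compares sampled costs, it does not assume the cost surface is smooth or monotonic the way the previous binary search over budgets effectively did. Below is a minimal, self-contained sketch (illustrative only, not torchrec's implementation; all names are made up):

```python
import random


def luus_jaakola(cost, lo, hi, max_iterations=16, shrink=0.8, seed=0):
    """Minimal Luus Jaakola sketch (illustrative; not torchrec's LuusJaakolaSearch).

    cost: maps a point in [lo, hi] to a rating, lower is better.
    Returns the best point found within max_iterations samples.
    """
    rng = random.Random(seed)
    x = rng.uniform(lo, hi)              # current best point
    fx = cost(x)                         # its rating
    d = hi - lo                          # sampling window, shrinks on failures
    for _ in range(max_iterations):
        y = min(max(x + rng.uniform(-d, d), lo), hi)   # random neighbour, clamped
        fy = cost(y)
        if fy < fx:                      # keep improvements
            x, fx = y, fy
        else:                            # otherwise narrow the window
            d *= shrink
    return x


# Example on a deliberately non-smooth cost whose minimum is near 0.6:
best = luus_jaakola(lambda b: abs(b - 0.6) + 0.05 * (round(10 * b) % 2), 0.0, 1.0)
```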
Showing 3 changed files with 62 additions and 178 deletions.
6 changes: 4 additions & 2 deletions torchrec/distributed/planner/planners.py
@@ -283,9 +283,11 @@ def plan(
)
if current_storage < lowest_storage:
lowest_storage = current_storage
proposal_cache[proposal_key] = (False, None, None)
proposal_cache[proposal_key] = (False, proposal, None)
proposer.feedback(
partitionable=False, storage_constraint=storage_constraint
partitionable=False,
plan=proposal,
storage_constraint=storage_constraint,
)

# clear shard.rank for each sharding_option
44 changes: 11 additions & 33 deletions torchrec/distributed/planner/proposers.py
@@ -22,7 +22,7 @@
ShardingOption,
Topology,
)
from torchrec.distributed.planner.utils import BinarySearchPredicate, bytes_to_gb, prod
from torchrec.distributed.planner.utils import bytes_to_gb, LuusJaakolaSearch, prod

logger: logging.Logger = logging.getLogger(__name__)

@@ -283,8 +283,7 @@ def __init__(self, use_depth: bool = True) -> None:
self.enumerator: Optional[Enumerator] = None
self.starting_proposal: List[ShardingOption] = []
self.proposal: Optional[List[ShardingOption]] = None
self.search: Optional[BinarySearchPredicate] = None
self.previous_plan_perf_rating: float = 0.0
self.search: Optional[LuusJaakolaSearch] = None

def load(
self,
Expand Down Expand Up @@ -320,7 +319,7 @@ def feedback(
perf_rating: Optional[float] = None,
storage_constraint: Optional[Topology] = None,
) -> None:
if not self.enumerator or not plan or not storage_constraint:
if not self.enumerator or plan is None:
self.proposal = None
return

@@ -329,47 +328,26 @@
)

if self.search is None:
# Determine how much extra HBM memory is available for scaling our caches
# beyond the baseline-min-working-set plan. We may not be able to find a
# partitionable plan that uses all this budget, or we may find a plan that
# uses only a portion of this budget enables a layout that reduces overall
# cost at the expense of larger prefetch delay. So we perform a binary
# search to sample plans with different budgets to discover a good
# configuration.
if not partitionable or storage_constraint is None:
self.proposal = None
return

hbm_available = EmbeddingOffloadScaleupProposer.get_budget(
plan, storage_constraint
)
logger.info(
f"EmbeddingOffloadScaleupProposer - cache scale up budget={round(bytes_to_gb(hbm_available), 2)} GB, exploring [{round(bytes_to_gb(hbm_used_previously), 2)}, {round(bytes_to_gb(hbm_used_previously + hbm_available), 2)}] GB"
)
# Partitioning proposals is quite expensive when there are a lot of tables,
# so we reduce the number probes the binary search uses to find the max
# cache sizes that fit inside the budget by specifying a tolerance. Once
# we've less than tolerance bytes left of unused budget we stop searching.
# We set tolerance to try to waste less than 3% of budget. For 100TB budget,
# this reduces number of proposals from 47 to 6.
tolerance = round(hbm_available * 0.03)
self.search = BinarySearchPredicate(0, hbm_available, tolerance)
self.search = LuusJaakolaSearch(0, hbm_available, max_iterations=16)

logger.info(
f"EmbeddingOffloadScaleupProposer - proposed size={round(bytes_to_gb(hbm_used_previously), 2)} GB, score={perf_rating}"
)

# Guide the binary search. We assume the partitioned perf model cost is
# monotonic with respect to CLF, so if the feedback from our prior attempt was
# worse than previous one, we reduce the memory for our next proposal, else we
# try using more. This allows us to focus the budget allocation search into the
# productive region where plans are still getting better.
warmer = partitionable and (
self.previous_plan_perf_rating == 0.0
or (
perf_rating is not None and perf_rating < self.previous_plan_perf_rating
)
)
self.previous_plan_perf_rating = perf_rating or 0.0

assert self.search is not None # keep pyre happy
budget = self.search.next(warmer)
budget = self.search.next(perf_rating or 1e99)
if budget is not None:
budget = int(budget)
self.proposal = EmbeddingOffloadScaleupProposer.next_plan(
self.starting_proposal, budget, self.enumerator
)
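
To illustrate the control flow in the hunk above: the search object is created lazily on the first feedback call with the available HBM budget as its range, and each subsequent next(perf_rating) call consumes the rating of the previous proposal and returns the next cache budget to try. The sketch below mirrors that stateful interface; the names are illustrative and the internals of torchrec's actual LuusJaakolaSearch will differ.

```python
import random
from typing import Optional


class SimpleLJSearch:
    """Simplified, stateful Luus Jaakola search (illustrative only).

    next(fy) takes the rating observed for the previously returned point and
    yields the next point to evaluate, or None once iterations are exhausted.
    """

    def __init__(self, lo: float, hi: float, max_iterations: int = 16, seed: int = 0) -> None:
        self.lo, self.hi = lo, hi
        self.d = hi - lo                 # sampling window, shrinks on failures
        self.rng = random.Random(seed)
        self.left = max_iterations
        self.x: Optional[float] = None   # best point seen so far
        self.fx = float("inf")           # best rating seen so far
        self.y: Optional[float] = None   # point handed out most recently

    def next(self, fy: float) -> Optional[float]:
        if self.left == 0:
            return None
        self.left -= 1
        if self.y is None:               # first call: start anywhere in range
            self.y = self.rng.uniform(self.lo, self.hi)
            return self.y
        if fy < self.fx:                 # previous sample improved: adopt it
            self.x, self.fx = self.y, fy
        else:                            # otherwise shrink the sampling window
            self.d *= 0.8
        base = self.x if self.x is not None else self.y
        self.y = min(max(base + self.rng.uniform(-self.d, self.d), self.lo), self.hi)
        return self.y


# Tiny driver mirroring the feedback loop: rate each proposed budget, feed the
# rating back, stop when the search returns None. The perf model here is a
# stand-in (lowest score near 150), similar in spirit to the test file below.
search = SimpleLJSearch(0.0, 200.0, max_iterations=16)
budget = search.next(1e99)               # no rating yet for the first proposal
while budget is not None:
    perf_rating = abs(budget - 150.0)
    budget = search.next(perf_rating)
```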
190 changes: 47 additions & 143 deletions torchrec/distributed/planner/tests/test_proposers.py
@@ -422,20 +422,21 @@ def test_allocate_budget(self) -> None:
self.assertEqual(increase, budget)

def test_scaleup(self) -> None:

tables = [
EmbeddingBagConfig(
num_embeddings=2_000_000,
embedding_dim=10,
embedding_dim=10000,
name=f"table_{i}",
feature_names=[f"feature_{i}"],
)
for i in range(3)
for i in range(4)
]

# Place first two tables into cache, 3rd table leave on hbm. table_1 has a
# Place first three tables into cache, 4th table leave on hbm. table_1 has a
# larger cacheability score so budget should be skewed to scaling table_1 more
# than table_0.
# than table_0. table_2 is a deprecated feature we have no stats for (so
# expected_lookups 0), we want to see that left at its original load factor,
# i.e. doesn't participate in scaleup.
constraints = {
"table_0": ParameterConstraints(
compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
@@ -451,126 +452,22 @@ def test_scaleup(self) -> None:
stats=MockCacheStatistics(expected_lookups=2, cacheability=0.5),
),
),
}

MB = 1024 * 1024
storage_constraint = Topology(
world_size=2, compute_device="cuda", hbm_cap=100 * MB, ddr_cap=1000 * MB
)

model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
enumerator = EmbeddingEnumerator(
topology=storage_constraint, batch_size=BATCH_SIZE, constraints=constraints
)
search_space = enumerator.enumerate(
module=model,
sharders=[
cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())
],
)
proposer = EmbeddingOffloadScaleupProposer()
proposer.load(search_space, enumerator=enumerator)

output = []
proposal = proposer.propose()
while proposal is not None:
output.append(
[
(
candidate.name,
candidate.compute_kernel,
candidate.cache_params.load_factor
if candidate.cache_params
else None,
)
for candidate in proposal
]
)
proposer.feedback(
partitionable=True,
plan=proposal,
storage_constraint=storage_constraint,
)
proposal = proposer.propose()

# Expected output (name, kernel clf).
# First attempt uses the mins supplied, then as we apply increasing budget
# clfs increase, with the later attempts enough to promote table_3 into hbm.
expected_output = [
[
("table_0", "fused_uvm_caching", 0.1),
("table_1", "fused_uvm_caching", 0.1),
("table_2", "fused", None),
],
[
("table_0", "fused_uvm_caching", 0.3025801181793213),
("table_1", "fused_uvm_caching", 0.6064502596855164),
("table_2", "fused", None),
],
[
("table_0", "fused_uvm_caching", 0.403870165348053),
("table_1", "fused_uvm_caching", 0.859675407409668),
("table_2", "fused", None),
],
[
("table_0", "fused_uvm_caching", 0.4545151889324188),
("table_1", "fused_uvm_caching", 0.9862880110740662),
("table_2", "fused", None),
],
[
("table_0", "fused_uvm_caching", 0.5294319987297058),
("table_1", "fused", None),
("table_2", "fused", None),
],
[
("table_0", "fused_uvm_caching", 0.573746383190155),
("table_1", "fused", None),
("table_2", "fused", None),
],
[
("table_0", "fused_uvm_caching", 0.5959035754203796),
("table_1", "fused", None),
("table_2", "fused", None),
],
]

self.assertEqual(output, expected_output)

def test_scaleup_ample_budget_and_deprecated_feature(self) -> None:
tables = [
EmbeddingBagConfig(
num_embeddings=2_000_000,
embedding_dim=10,
name=f"table_{i}",
feature_names=[f"feature_{i}"],
)
for i in range(3)
]

# Place first two tables into cache, 3rd table leave on hbm. table_1 has an
# expected lookup of 0 (deprecated feature).
constraints = {
"table_0": ParameterConstraints(
compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
cache_params=CacheParams(
load_factor=0.1,
stats=MockCacheStatistics(expected_lookups=2, cacheability=0.2),
),
),
"table_1": ParameterConstraints(
"table_2": ParameterConstraints(
compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
cache_params=CacheParams(
load_factor=0.1,
stats=MockCacheStatistics(expected_lookups=0, cacheability=0),
load_factor=0.002,
stats=MockCacheStatistics(expected_lookups=0, cacheability=0.5),
),
),
}

MB = 1024 * 1024
GB = 1024 * 1024 * 1024
storage_constraint = Topology(
world_size=2, compute_device="cuda", hbm_cap=100 * MB, ddr_cap=1000 * MB
world_size=2, compute_device="cuda", hbm_cap=100 * GB, ddr_cap=1000 * GB
)

# Ignoring table_2, the remainder require 224GB if all placed on HBM. We only
# have 200GB, so we can't promote both uvm tables. Initial plan uses 90GB,
# with 110GB of available budget.
model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
enumerator = EmbeddingEnumerator(
topology=storage_constraint, batch_size=BATCH_SIZE, constraints=constraints
@@ -584,11 +481,18 @@ def test_scaleup_ample_budget_and_deprecated_feature(self) -> None:
proposer = EmbeddingOffloadScaleupProposer()
proposer.load(search_space, enumerator=enumerator)

output = []
proposal = proposer.propose()
best_plan = None
best_perf = 1e99
proposals = -1
while proposal is not None:
output.append(
[
proposals += 1
mem = sum(so.total_storage.hbm for so in proposal)
# simple perf model: assume the partitioner gives its lowest score around 150GB of memory.
perf = abs(mem - (150 * GB))
plan = {
"mem": mem,
"proposal": [
(
candidate.name,
candidate.compute_kernel,
@@ -597,36 +501,36 @@
else None,
)
for candidate in proposal
]
)
],
}
if perf < best_perf:
best_plan = plan
best_perf = perf
proposer.feedback(
partitionable=True,
plan=proposal,
perf_rating=perf,
storage_constraint=storage_constraint,
)
proposal = proposer.propose()

# Expected output (name, kernel clf).
# First attempt uses the mins supplied, then as we apply increasing budget
# clfs increase, table 0 gets promoted, table 1 left as original minimum.
expected_output = [
[
("table_0", "fused_uvm_caching", 0.1),
("table_1", "fused_uvm_caching", 0.1),
("table_2", "fused", None),
],
[
("table_0", "fused_uvm_caching", 0.8090304136276245),
("table_1", "fused_uvm_caching", 0.1),
("table_2", "fused", None),
],
[
("table_0", "fused", None),
("table_1", "fused_uvm_caching", 0.1),
("table_2", "fused", None),
],
]
self.assertEqual(output[0:3], expected_output)
self.assertEqual(proposals, 16)
self.assertEqual(
best_plan,
{
# 146GB, close to target of 150GB
"mem": 157178896800,
# table 1 has been scaled up 2.5x more than table 0 (vs original 0.1)
# which aligns with their different cacheability scores
# table_2 has been left alone (deprecated feature, expected zero lookups in stats)
"proposal": [
("table_0", "fused_uvm_caching", 0.3173336386680603),
("table_1", "fused_uvm_caching", 0.6433340907096863),
("table_2", "fused_uvm_caching", 0.002),
("table_3", "fused", None),
],
},
)

def test_proposers_to_proposals_list(self) -> None:
def make_mock_proposal(name: str) -> List[ShardingOption]: