Switch EmbeddingOffloadScaleupProposer to use Luus Jaakola search #1691

Closed
wants to merge 2 commits into from
6 changes: 4 additions & 2 deletions torchrec/distributed/planner/planners.py
@@ -283,9 +283,11 @@ def plan(
)
if current_storage < lowest_storage:
lowest_storage = current_storage
proposal_cache[proposal_key] = (False, None, None)
proposal_cache[proposal_key] = (False, proposal, None)
proposer.feedback(
partitionable=False, storage_constraint=storage_constraint
partitionable=False,
plan=proposal,
storage_constraint=storage_constraint,
)

# clear shard.rank for each sharding_option
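For context on the change above: `feedback` is now given the proposal even when it could not be partitioned, so proposers that size their budget from the plan (such as `EmbeddingOffloadScaleupProposer`) still receive it. Below is a minimal sketch of the calling pattern, assuming a hypothetical `try_partition` helper that stands in for the partitioner plus perf-model evaluation; the real `plan()` method adds proposal caching and other bookkeeping not shown here.

```python
# Hypothetical driver loop illustrating the new feedback contract.
# try_partition is an assumed helper returning (partitionable, perf_rating).
def drive_proposer(proposer, storage_constraint, try_partition):
    proposal = proposer.propose()
    while proposal is not None:
        partitionable, perf_rating = try_partition(proposal, storage_constraint)
        proposer.feedback(
            partitionable=partitionable,
            plan=proposal,  # now passed back on failure as well, per this diff
            perf_rating=perf_rating if partitionable else None,
            storage_constraint=storage_constraint,
        )
        proposal = proposer.propose()
```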
44 changes: 11 additions & 33 deletions torchrec/distributed/planner/proposers.py
@@ -22,7 +22,7 @@
ShardingOption,
Topology,
)
from torchrec.distributed.planner.utils import BinarySearchPredicate, bytes_to_gb, prod
from torchrec.distributed.planner.utils import bytes_to_gb, LuusJaakolaSearch, prod

logger: logging.Logger = logging.getLogger(__name__)

@@ -283,8 +283,7 @@ def __init__(self, use_depth: bool = True) -> None:
self.enumerator: Optional[Enumerator] = None
self.starting_proposal: List[ShardingOption] = []
self.proposal: Optional[List[ShardingOption]] = None
self.search: Optional[BinarySearchPredicate] = None
self.previous_plan_perf_rating: float = 0.0
self.search: Optional[LuusJaakolaSearch] = None

def load(
self,
@@ -320,7 +319,7 @@ def feedback(
perf_rating: Optional[float] = None,
storage_constraint: Optional[Topology] = None,
) -> None:
if not self.enumerator or not plan or not storage_constraint:
if not self.enumerator or plan is None:
self.proposal = None
return

@@ -329,47 +328,26 @@
)

if self.search is None:
# Determine how much extra HBM memory is available for scaling our caches
# beyond the baseline-min-working-set plan. We may not be able to find a
# partitionable plan that uses all this budget, or we may find a plan that
# uses only a portion of this budget enables a layout that reduces overall
# cost at the expense of larger prefetch delay. So we perform a binary
# search to sample plans with different budgets to discover a good
# configuration.
if not partitionable or storage_constraint is None:
self.proposal = None
return

hbm_available = EmbeddingOffloadScaleupProposer.get_budget(
plan, storage_constraint
)
logger.info(
f"EmbeddingOffloadScaleupProposer - cache scale up budget={round(bytes_to_gb(hbm_available), 2)} GB, exploring [{round(bytes_to_gb(hbm_used_previously), 2)}, {round(bytes_to_gb(hbm_used_previously + hbm_available), 2)}] GB"
)
# Partitioning proposals is quite expensive when there are a lot of tables,
# so we reduce the number probes the binary search uses to find the max
# cache sizes that fit inside the budget by specifying a tolerance. Once
# we've less than tolerance bytes left of unused budget we stop searching.
# We set tolerance to try to waste less than 3% of budget. For 100TB budget,
# this reduces number of proposals from 47 to 6.
tolerance = round(hbm_available * 0.03)
self.search = BinarySearchPredicate(0, hbm_available, tolerance)
self.search = LuusJaakolaSearch(0, hbm_available, max_iterations=16)

logger.info(
f"EmbeddingOffloadScaleupProposer - proposed size={round(bytes_to_gb(hbm_used_previously), 2)} GB, score={perf_rating}"
)

# Guide the binary search. We assume the partitioned perf model cost is
# monotonic with respect to CLF, so if the feedback from our prior attempt was
# worse than previous one, we reduce the memory for our next proposal, else we
# try using more. This allows us to focus the budget allocation search into the
# productive region where plans are still getting better.
warmer = partitionable and (
self.previous_plan_perf_rating == 0.0
or (
perf_rating is not None and perf_rating < self.previous_plan_perf_rating
)
)
self.previous_plan_perf_rating = perf_rating or 0.0

assert self.search is not None # keep pyre happy
budget = self.search.next(warmer)
budget = self.search.next(perf_rating or 1e99)
if budget is not None:
budget = int(budget)
self.proposal = EmbeddingOffloadScaleupProposer.next_plan(
self.starting_proposal, budget, self.enumerator
)
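For context on the new search strategy: Luus-Jaakola is a simple stochastic local search. It keeps a best-so-far point, draws a candidate uniformly from a neighborhood around it, moves to the candidate when its cost improves on the best point, and shrinks the neighborhood when it does not. The sketch below is illustrative only, not the implementation in `torchrec.distributed.planner.utils` (which may differ in seeding, shrink rate, and edge-case handling); it mirrors the feedback-driven `next(cost)` interface used above, where each call receives the cost of the previously suggested budget and `1e99` stands in for "no rating yet".

```python
import random
from typing import Optional


class LuusJaakolaSketch:
    """Illustrative 1-D Luus-Jaakola random search (minimization).

    Each call to next(fy) receives the cost of the previously suggested
    point and returns the next candidate, or None once the iteration
    budget is exhausted.
    """

    def __init__(self, left: float, right: float, max_iterations: int, seed: int = 42) -> None:
        self.left = left
        self.right = right
        self.max_iterations = max_iterations
        self.iteration = -1
        self.rng = random.Random(seed)
        self.x = self.rng.uniform(left, right)  # best point found so far
        self.fx = 0.0  # cost of the best point (set on the second call)
        self.y = self.x  # most recently suggested candidate
        self.d = right - left  # current sampling radius

    def _clamp(self, v: float) -> float:
        return max(self.left, min(self.right, v))

    def next(self, fy: float) -> Optional[float]:
        """fy is the cost of the previously returned candidate."""
        self.iteration += 1
        if self.iteration == 0:
            # First call: nothing has been evaluated yet, suggest the initial point.
            return self.x
        if self.iteration == 1:
            # fy is the cost of the initial point.
            self.fx = fy
        elif fy < self.fx:
            # The last candidate improved on the best point: move to it.
            self.x, self.fx = self.y, fy
        else:
            # No improvement: shrink the sampling neighborhood.
            self.d *= 0.95
        if self.iteration >= self.max_iterations:
            return None
        # Sample a new candidate around the current best point.
        self.y = self._clamp(self.x + self.rng.uniform(-self.d, self.d))
        return self.y


# Usage mirroring the proposer loop: feed back the perf_rating of each
# suggested budget (1e99 on the first call, when there is no rating yet).
search = LuusJaakolaSketch(0, 110 * 1024**3, max_iterations=16)
budget = search.next(1e99)
while budget is not None:
    cost = abs(budget - 60 * 1024**3)  # stand-in for a real perf model
    budget = search.next(cost)
```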
190 changes: 47 additions & 143 deletions torchrec/distributed/planner/tests/test_proposers.py
@@ -422,20 +422,21 @@ def test_allocate_budget(self) -> None:
self.assertEqual(increase, budget)

def test_scaleup(self) -> None:

tables = [
EmbeddingBagConfig(
num_embeddings=2_000_000,
embedding_dim=10,
embedding_dim=10000,
name=f"table_{i}",
feature_names=[f"feature_{i}"],
)
for i in range(3)
for i in range(4)
]

# Place first two tables into cache, 3rd table leave on hbm. table_1 has a
# Place first three tables into cache, 4th table leave on hbm. table_1 has a
# larger cacheability score so budget should be skewed to scaling table_1 more
# than table_0.
# than table_0. table_2 is a deprecated feature we have no stats for (so
# expected_lookups 0), we want to see that left at its original load factor,
# i.e. doesn't participate in scaleup.
constraints = {
"table_0": ParameterConstraints(
compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
@@ -451,126 +452,22 @@ def test_scaleup(self) -> None:
stats=MockCacheStatistics(expected_lookups=2, cacheability=0.5),
),
),
}

MB = 1024 * 1024
storage_constraint = Topology(
world_size=2, compute_device="cuda", hbm_cap=100 * MB, ddr_cap=1000 * MB
)

model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
enumerator = EmbeddingEnumerator(
topology=storage_constraint, batch_size=BATCH_SIZE, constraints=constraints
)
search_space = enumerator.enumerate(
module=model,
sharders=[
cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())
],
)
proposer = EmbeddingOffloadScaleupProposer()
proposer.load(search_space, enumerator=enumerator)

output = []
proposal = proposer.propose()
while proposal is not None:
output.append(
[
(
candidate.name,
candidate.compute_kernel,
candidate.cache_params.load_factor
if candidate.cache_params
else None,
)
for candidate in proposal
]
)
proposer.feedback(
partitionable=True,
plan=proposal,
storage_constraint=storage_constraint,
)
proposal = proposer.propose()

# Expected output (name, kernel clf).
# First attempt uses the mins supplied, then as we apply increasing budget
# clfs increase, with the later attempts enough to promote table_3 into hbm.
expected_output = [
[
("table_0", "fused_uvm_caching", 0.1),
("table_1", "fused_uvm_caching", 0.1),
("table_2", "fused", None),
],
[
("table_0", "fused_uvm_caching", 0.3025801181793213),
("table_1", "fused_uvm_caching", 0.6064502596855164),
("table_2", "fused", None),
],
[
("table_0", "fused_uvm_caching", 0.403870165348053),
("table_1", "fused_uvm_caching", 0.859675407409668),
("table_2", "fused", None),
],
[
("table_0", "fused_uvm_caching", 0.4545151889324188),
("table_1", "fused_uvm_caching", 0.9862880110740662),
("table_2", "fused", None),
],
[
("table_0", "fused_uvm_caching", 0.5294319987297058),
("table_1", "fused", None),
("table_2", "fused", None),
],
[
("table_0", "fused_uvm_caching", 0.573746383190155),
("table_1", "fused", None),
("table_2", "fused", None),
],
[
("table_0", "fused_uvm_caching", 0.5959035754203796),
("table_1", "fused", None),
("table_2", "fused", None),
],
]

self.assertEqual(output, expected_output)

def test_scaleup_ample_budget_and_deprecated_feature(self) -> None:
tables = [
EmbeddingBagConfig(
num_embeddings=2_000_000,
embedding_dim=10,
name=f"table_{i}",
feature_names=[f"feature_{i}"],
)
for i in range(3)
]

# Place first two tables into cache, 3rd table leave on hbm. table_1 has an
# expected lookup of 0 (deprecated feature).
constraints = {
"table_0": ParameterConstraints(
compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
cache_params=CacheParams(
load_factor=0.1,
stats=MockCacheStatistics(expected_lookups=2, cacheability=0.2),
),
),
"table_1": ParameterConstraints(
"table_2": ParameterConstraints(
compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
cache_params=CacheParams(
load_factor=0.1,
stats=MockCacheStatistics(expected_lookups=0, cacheability=0),
load_factor=0.002,
stats=MockCacheStatistics(expected_lookups=0, cacheability=0.5),
),
),
}

MB = 1024 * 1024
GB = 1024 * 1024 * 1024
storage_constraint = Topology(
world_size=2, compute_device="cuda", hbm_cap=100 * MB, ddr_cap=1000 * MB
world_size=2, compute_device="cuda", hbm_cap=100 * GB, ddr_cap=1000 * GB
)

# Ignoring table_2, the remainder require 224GB if all placed on HBM. We only
# have 200GB, so we can't promote both uvm tables. Initial plan uses 90GB,
# with 110GB of available budget.
model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
enumerator = EmbeddingEnumerator(
topology=storage_constraint, batch_size=BATCH_SIZE, constraints=constraints
@@ -584,11 +481,18 @@ def test_scaleup_ample_budget_and_deprecated_feature(self) -> None:
proposer = EmbeddingOffloadScaleupProposer()
proposer.load(search_space, enumerator=enumerator)

output = []
proposal = proposer.propose()
best_plan = None
best_perf = 1e99
proposals = -1
while proposal is not None:
output.append(
[
proposals += 1
mem = sum(so.total_storage.hbm for so in proposal)
# simple perf model: assume the partitioner gives its lowest score around 150GB of memory.
perf = abs(mem - (150 * GB))
plan = {
"mem": mem,
"proposal": [
(
candidate.name,
candidate.compute_kernel,
Expand All @@ -597,36 +501,36 @@ def test_scaleup_ample_budget_and_deprecated_feature(self) -> None:
else None,
)
for candidate in proposal
]
)
],
}
if perf < best_perf:
best_plan = plan
best_perf = perf
proposer.feedback(
partitionable=True,
plan=proposal,
perf_rating=perf,
storage_constraint=storage_constraint,
)
proposal = proposer.propose()

# Expected output (name, kernel clf).
# First attempt uses the mins supplied, then as we apply increasing budget
# clfs increase, table 0 gets promoted, table 1 left as original minimum.
expected_output = [
[
("table_0", "fused_uvm_caching", 0.1),
("table_1", "fused_uvm_caching", 0.1),
("table_2", "fused", None),
],
[
("table_0", "fused_uvm_caching", 0.8090304136276245),
("table_1", "fused_uvm_caching", 0.1),
("table_2", "fused", None),
],
[
("table_0", "fused", None),
("table_1", "fused_uvm_caching", 0.1),
("table_2", "fused", None),
],
]
self.assertEqual(output[0:3], expected_output)
self.assertEqual(proposals, 16)
self.assertEqual(
best_plan,
{
# 146GB, close to target of 150GB
"mem": 157178896800,
# table 1 has been scaled up 2.5x more than table 0 (vs original 0.1)
# which aligns with their different cacheability scores
# table_2 has been left alone (deprecated feature, expected zero lookups in stats)
"proposal": [
("table_0", "fused_uvm_caching", 0.3173336386680603),
("table_1", "fused_uvm_caching", 0.6433340907096863),
("table_2", "fused_uvm_caching", 0.002),
("table_3", "fused", None),
],
},
)

def test_proposers_to_proposals_list(self) -> None:
def make_mock_proposal(name: str) -> List[ShardingOption]: