Fix hbm-promoted table prefetch estimate. (#1620)
Summary:
Pull Request resolved: #1620

Fixes a bug where promoted tables were erroneously more expensive in
the planner.

Previously, when an embedding offload table was scaled up to use the
memory budget, there was enough budget to promote it to dedicated
HBM, and a fused-param default for cache_load_factor was set, the
prefetch_compute delay was incorrectly calculated as if the table
were still uvm_caching with the given default CLF. HBM-only tables
should always have a prefetch_compute cost of 0.
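
In other words, the prefetch estimate must be gated on the compute
kernel itself, not just on the presence of cache parameters. A minimal
sketch of the intended gating (the function name, cost formula, and
bandwidth parameter are illustrative stand-ins, not the planner's
actual cost model):

    # Illustrative sketch only; names and the cost formula are hypothetical
    # stand-ins for the planner's real prefetch estimate.
    def prefetch_compute_delay(
        compute_kernel: str,
        cache_load_factor: float,
        expected_lookups: int,
        hbm_to_ddr_bw: float,
    ) -> float:
        # Tables promoted to dedicated HBM never prefetch from DDR, so they
        # must cost 0 regardless of any leftover cache_load_factor default.
        if compute_kernel != "fused_uvm_caching":
            return 0.0
        # For uvm_caching, misses scale with (1 - CLF); each miss is a DDR fetch.
        expected_cache_fetches = expected_lookups * (1.0 - cache_load_factor)
        return expected_cache_fetches / hbm_to_ddr_bw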

Kudos to Henry, who quickly identified this problem and proposed the
solution implemented here.

Reviewed By: henrylhtsang

Differential Revision: D52649994

fbshipit-source-id: 99ce7a9ccceb78eba76fbcbf54e4678eab676af0
Damian Reeves authored and facebook-github-bot committed Jan 10, 2024
1 parent 7351978 commit 813671c
Showing 2 changed files with 47 additions and 63 deletions.
2 changes: 2 additions & 0 deletions torchrec/distributed/planner/shard_estimators.py
@@ -166,6 +166,8 @@ def estimate(
             caching_ratio is not None
             and sharding_option.cache_params is not None
             and sharding_option.cache_params.stats is not None
+            and sharding_option.compute_kernel
+            == EmbeddingComputeKernel.FUSED_UVM_CACHING.value
         ):
             _stats = sharding_option.cache_params.stats
             expected_cache_fetches = (
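
For contrast, a hypothetical sketch of the scenario the new clause
guards against (SimpleNamespace objects stand in for ShardingOption
and CacheParams here; field names follow the diff, and the kernel
value strings match those in the test below):

    from types import SimpleNamespace

    # After scale-up promotion, the sharding option keeps CacheParams populated
    # (including the fused-param default cache_load_factor and its stats), but
    # its compute kernel is plain "fused" (dedicated HBM).
    promoted = SimpleNamespace(
        compute_kernel="fused",
        cache_params=SimpleNamespace(
            load_factor=0.2,  # leftover fused-param default
            stats=SimpleNamespace(expected_lookups=200_000),
        ),
    )

    # Before the fix, the estimator only checked cache_params/stats for None and
    # therefore charged prefetch_compute; the added kernel check now skips it.
    needs_prefetch_estimate = (
        promoted.cache_params is not None
        and promoted.cache_params.stats is not None
        and promoted.compute_kernel == "fused_uvm_caching"  # new clause -> False
    )
    assert not needs_prefetch_estimate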
108 changes: 45 additions & 63 deletions torchrec/distributed/planner/tests/test_shard_estimators.py
@@ -437,7 +437,13 @@ def cacheability(self) -> float:
                 embedding_dim=10,
                 name="table_0",
                 feature_names=["feature_0"],
-            )
+            ),
+            EmbeddingBagConfig(
+                num_embeddings=100,
+                embedding_dim=10,
+                name="table_1",
+                feature_names=["feature_1"],
+            ),
         ]
         constraints = {
             "table_0": ParameterConstraints(
@@ -447,6 +453,14 @@ def cacheability(self) -> float:
                     stats=MyCacheStatistics(expected_lookups=200_000, cacheability=0.2),
                 ),
             ),
+            # simulate promoting a uvm caching table to HBM during scaleup.
+            "table_1": ParameterConstraints(
+                compute_kernels=[EmbeddingComputeKernel.FUSED.value],
+                cache_params=CacheParams(
+                    load_factor=None,
+                    stats=MyCacheStatistics(expected_lookups=200_000, cacheability=0.2),
+                ),
+            ),
         }
         enumerator = EmbeddingEnumerator(
             topology=self.topology,
@@ -458,80 +472,48 @@ def cacheability(self) -> float:
         sharding_options = enumerator.enumerate(
             module=model,
             sharders=[
-                cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())
+                cast(
+                    ModuleSharder[torch.nn.Module],
+                    EmbeddingBagCollectionSharder(
+                        fused_params={"cache_load_factor": 0.2}
+                    ),
+                )
             ],
         )
 
-        expected_perfs = {
-            ("fused_uvm_caching", "column_wise"): [
-                Perf(
-                    fwd_compute=0.021661629015717183,
-                    fwd_comms=6.357828776041667e-05,
-                    bwd_compute=0.043323258031434365,
-                    bwd_comms=6.357828776041667e-05,
-                    prefetch_compute=0.014608981562595743,
-                )
-            ],
-            ("fused_uvm_caching", "row_wise"): [
-                Perf(
-                    fwd_compute=0.004501117717551623,
-                    fwd_comms=6.357828776041667e-05,
-                    bwd_compute=0.009002235435103246,
-                    bwd_comms=0.006969980785628689,
-                    prefetch_compute=0.007304490781297871,
-                ),
-                Perf(
-                    fwd_compute=0.004501117717551623,
-                    fwd_comms=6.357828776041667e-05,
-                    bwd_compute=0.009002235435103246,
-                    bwd_comms=0.006969980785628689,
-                    prefetch_compute=0.007304490781297871,
-                ),
-            ],
-            ("fused_uvm_caching", "table_column_wise"): [
-                Perf(
-                    fwd_compute=0.021661629015717183,
-                    fwd_comms=6.357828776041667e-05,
-                    bwd_compute=0.043323258031434365,
-                    bwd_comms=6.357828776041667e-05,
-                    prefetch_compute=0.014608981562595743,
-                )
-            ],
-            ("fused_uvm_caching", "table_row_wise"): [
-                Perf(
-                    fwd_compute=0.004501117717551623,
-                    fwd_comms=6.357828776041667e-05,
-                    bwd_compute=0.009002235435103246,
-                    bwd_comms=0.006969980785628689,
-                    prefetch_compute=0.007304490781297871,
-                ),
-                Perf(
-                    fwd_compute=0.004501117717551623,
-                    fwd_comms=6.357828776041667e-05,
-                    bwd_compute=0.009002235435103246,
-                    bwd_comms=0.006969980785628689,
-                    prefetch_compute=0.007304490781297871,
-                ),
-            ],
-            ("fused_uvm_caching", "table_wise"): [
-                Perf(
-                    fwd_compute=0.021661629015717183,
-                    fwd_comms=6.357828776041667e-05,
-                    bwd_compute=0.043323258031434365,
-                    bwd_comms=6.357828776041667e-05,
-                    prefetch_compute=0.014608981562595743,
-                )
-            ],
-        }
+        expected_prefetch_computes = {
+            ("table_0", "fused_uvm_caching", "column_wise"): [0.014608981562595743],
+            ("table_0", "fused_uvm_caching", "row_wise"): [
+                0.007304490781297871,
+                0.007304490781297871,
+            ],
+            ("table_0", "fused_uvm_caching", "table_column_wise"): [
+                0.014608981562595743
+            ],
+            ("table_0", "fused_uvm_caching", "table_row_wise"): [
+                0.007304490781297871,
+                0.007304490781297871,
+            ],
+            ("table_0", "fused_uvm_caching", "table_wise"): [0.014608981562595743],
+            ("table_1", "fused", "column_wise"): [0.0],
+            ("table_1", "fused", "row_wise"): [0.0, 0.0],
+            ("table_1", "fused", "table_column_wise"): [0.0],
+            ("table_1", "fused", "table_row_wise"): [0.0, 0.0],
+            ("table_1", "fused", "table_wise"): [0.0],
+        }
 
-        perfs = {
+        prefetch_computes = {
             (
                 sharding_option.name,
                 sharding_option.compute_kernel,
                 sharding_option.sharding_type,
-            ): [shard.perf for shard in sharding_option.shards]
+            ): [
+                shard.perf.prefetch_compute if shard.perf else -1
+                for shard in sharding_option.shards
+            ]
             for sharding_option in sharding_options
         }
 
-        self.assertEqual(expected_perfs, perfs)
+        self.assertEqual(expected_prefetch_computes, prefetch_computes)
 
 
     # pyre-ignore[3]
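As a usage note, the test's check can be condensed as follows
(assuming sharding_options was produced by the enumerator call in the
diff above):

    # Collect per-shard prefetch costs keyed by (table, kernel, sharding type).
    prefetch_computes = {
        (opt.name, opt.compute_kernel, opt.sharding_type): [
            shard.perf.prefetch_compute if shard.perf else -1
            for shard in opt.shards
        ]
        for opt in sharding_options
    }

    # With the fix, every HBM-only ("fused") option should cost nothing to prefetch.
    assert all(
        cost == 0.0
        for (_, kernel, _), costs in prefetch_computes.items()
        if kernel == "fused"
        for cost in costs
    )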