From 813671ce4b7a2b731414d6c36fc55e5dcd83c278 Mon Sep 17 00:00:00 2001 From: Damian Reeves Date: Wed, 10 Jan 2024 12:33:13 -0800 Subject: [PATCH] Fix hbm-promoted table prefetch estimate. (#1620) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1620 Fixes a bug where promoted tables were erroneously more expensive in the planner. Previously, when an embedding offload table was scaled up to use memory budget and there was enough budget to promote it to dedicated HBM, and there was a fused param default for cache_load_factor, prefetch_compute delay was incorrectly calculated to assume it was still uvm_caching with the given default CLF. HBM only tables should always have 0 prefetch_compute cost. Kudos to Henry who quickly identified this problem & proposed the solution implemented here. Reviewed By: henrylhtsang Differential Revision: D52649994 fbshipit-source-id: 99ce7a9ccceb78eba76fbcbf54e4678eab676af0 --- .../distributed/planner/shard_estimators.py | 2 + .../planner/tests/test_shard_estimators.py | 108 ++++++++---------- 2 files changed, 47 insertions(+), 63 deletions(-) diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index ad2c83e70..5343c9f20 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -166,6 +166,8 @@ def estimate( caching_ratio is not None and sharding_option.cache_params is not None and sharding_option.cache_params.stats is not None + and sharding_option.compute_kernel + == EmbeddingComputeKernel.FUSED_UVM_CACHING.value ): _stats = sharding_option.cache_params.stats expected_cache_fetches = ( diff --git a/torchrec/distributed/planner/tests/test_shard_estimators.py b/torchrec/distributed/planner/tests/test_shard_estimators.py index dafb7550a..47587a3bf 100644 --- a/torchrec/distributed/planner/tests/test_shard_estimators.py +++ b/torchrec/distributed/planner/tests/test_shard_estimators.py @@ -437,7 +437,13 @@ def cacheability(self) -> float: embedding_dim=10, name="table_0", feature_names=["feature_0"], - ) + ), + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_1", + feature_names=["feature_1"], + ), ] constraints = { "table_0": ParameterConstraints( @@ -447,6 +453,14 @@ def cacheability(self) -> float: stats=MyCacheStatistics(expected_lookups=200_000, cacheability=0.2), ), ), + # simulate promoting a uvm caching table to HBM during scaleup. + "table_1": ParameterConstraints( + compute_kernels=[EmbeddingComputeKernel.FUSED.value], + cache_params=CacheParams( + load_factor=None, + stats=MyCacheStatistics(expected_lookups=200_000, cacheability=0.2), + ), + ), } enumerator = EmbeddingEnumerator( topology=self.topology, @@ -458,80 +472,48 @@ def cacheability(self) -> float: sharding_options = enumerator.enumerate( module=model, sharders=[ - cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()) + cast( + ModuleSharder[torch.nn.Module], + EmbeddingBagCollectionSharder( + fused_params={"cache_load_factor": 0.2} + ), + ) ], ) - expected_perfs = { - ("fused_uvm_caching", "column_wise"): [ - Perf( - fwd_compute=0.021661629015717183, - fwd_comms=6.357828776041667e-05, - bwd_compute=0.043323258031434365, - bwd_comms=6.357828776041667e-05, - prefetch_compute=0.014608981562595743, - ) - ], - ("fused_uvm_caching", "row_wise"): [ - Perf( - fwd_compute=0.004501117717551623, - fwd_comms=6.357828776041667e-05, - bwd_compute=0.009002235435103246, - bwd_comms=0.006969980785628689, - prefetch_compute=0.007304490781297871, - ), - Perf( - fwd_compute=0.004501117717551623, - fwd_comms=6.357828776041667e-05, - bwd_compute=0.009002235435103246, - bwd_comms=0.006969980785628689, - prefetch_compute=0.007304490781297871, - ), - ], - ("fused_uvm_caching", "table_column_wise"): [ - Perf( - fwd_compute=0.021661629015717183, - fwd_comms=6.357828776041667e-05, - bwd_compute=0.043323258031434365, - bwd_comms=6.357828776041667e-05, - prefetch_compute=0.014608981562595743, - ) + expected_prefetch_computes = { + ("table_0", "fused_uvm_caching", "column_wise"): [0.014608981562595743], + ("table_0", "fused_uvm_caching", "row_wise"): [ + 0.007304490781297871, + 0.007304490781297871, ], - ("fused_uvm_caching", "table_row_wise"): [ - Perf( - fwd_compute=0.004501117717551623, - fwd_comms=6.357828776041667e-05, - bwd_compute=0.009002235435103246, - bwd_comms=0.006969980785628689, - prefetch_compute=0.007304490781297871, - ), - Perf( - fwd_compute=0.004501117717551623, - fwd_comms=6.357828776041667e-05, - bwd_compute=0.009002235435103246, - bwd_comms=0.006969980785628689, - prefetch_compute=0.007304490781297871, - ), + ("table_0", "fused_uvm_caching", "table_column_wise"): [ + 0.014608981562595743 ], - ("fused_uvm_caching", "table_wise"): [ - Perf( - fwd_compute=0.021661629015717183, - fwd_comms=6.357828776041667e-05, - bwd_compute=0.043323258031434365, - bwd_comms=6.357828776041667e-05, - prefetch_compute=0.014608981562595743, - ) + ("table_0", "fused_uvm_caching", "table_row_wise"): [ + 0.007304490781297871, + 0.007304490781297871, ], + ("table_0", "fused_uvm_caching", "table_wise"): [0.014608981562595743], + ("table_1", "fused", "column_wise"): [0.0], + ("table_1", "fused", "row_wise"): [0.0, 0.0], + ("table_1", "fused", "table_column_wise"): [0.0], + ("table_1", "fused", "table_row_wise"): [0.0, 0.0], + ("table_1", "fused", "table_wise"): [0.0], } - perfs = { + + prefetch_computes = { ( + sharding_option.name, sharding_option.compute_kernel, sharding_option.sharding_type, - ): [shard.perf for shard in sharding_option.shards] + ): [ + shard.perf.prefetch_compute if shard.perf else -1 + for shard in sharding_option.shards + ] for sharding_option in sharding_options } - - self.assertEqual(expected_perfs, perfs) + self.assertEqual(expected_prefetch_computes, prefetch_computes) # pyre-ignore[3]