diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index ad2c83e70..5343c9f20 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -166,6 +166,8 @@ def estimate( caching_ratio is not None and sharding_option.cache_params is not None and sharding_option.cache_params.stats is not None + and sharding_option.compute_kernel + == EmbeddingComputeKernel.FUSED_UVM_CACHING.value ): _stats = sharding_option.cache_params.stats expected_cache_fetches = ( diff --git a/torchrec/distributed/planner/tests/test_shard_estimators.py b/torchrec/distributed/planner/tests/test_shard_estimators.py index dafb7550a..47587a3bf 100644 --- a/torchrec/distributed/planner/tests/test_shard_estimators.py +++ b/torchrec/distributed/planner/tests/test_shard_estimators.py @@ -437,7 +437,13 @@ def cacheability(self) -> float: embedding_dim=10, name="table_0", feature_names=["feature_0"], - ) + ), + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="table_1", + feature_names=["feature_1"], + ), ] constraints = { "table_0": ParameterConstraints( @@ -447,6 +453,14 @@ def cacheability(self) -> float: stats=MyCacheStatistics(expected_lookups=200_000, cacheability=0.2), ), ), + # simulate promoting a uvm caching table to HBM during scaleup. + "table_1": ParameterConstraints( + compute_kernels=[EmbeddingComputeKernel.FUSED.value], + cache_params=CacheParams( + load_factor=None, + stats=MyCacheStatistics(expected_lookups=200_000, cacheability=0.2), + ), + ), } enumerator = EmbeddingEnumerator( topology=self.topology, @@ -458,80 +472,48 @@ def cacheability(self) -> float: sharding_options = enumerator.enumerate( module=model, sharders=[ - cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()) + cast( + ModuleSharder[torch.nn.Module], + EmbeddingBagCollectionSharder( + fused_params={"cache_load_factor": 0.2} + ), + ) ], ) - expected_perfs = { - ("fused_uvm_caching", "column_wise"): [ - Perf( - fwd_compute=0.021661629015717183, - fwd_comms=6.357828776041667e-05, - bwd_compute=0.043323258031434365, - bwd_comms=6.357828776041667e-05, - prefetch_compute=0.014608981562595743, - ) - ], - ("fused_uvm_caching", "row_wise"): [ - Perf( - fwd_compute=0.004501117717551623, - fwd_comms=6.357828776041667e-05, - bwd_compute=0.009002235435103246, - bwd_comms=0.006969980785628689, - prefetch_compute=0.007304490781297871, - ), - Perf( - fwd_compute=0.004501117717551623, - fwd_comms=6.357828776041667e-05, - bwd_compute=0.009002235435103246, - bwd_comms=0.006969980785628689, - prefetch_compute=0.007304490781297871, - ), - ], - ("fused_uvm_caching", "table_column_wise"): [ - Perf( - fwd_compute=0.021661629015717183, - fwd_comms=6.357828776041667e-05, - bwd_compute=0.043323258031434365, - bwd_comms=6.357828776041667e-05, - prefetch_compute=0.014608981562595743, - ) + expected_prefetch_computes = { + ("table_0", "fused_uvm_caching", "column_wise"): [0.014608981562595743], + ("table_0", "fused_uvm_caching", "row_wise"): [ + 0.007304490781297871, + 0.007304490781297871, ], - ("fused_uvm_caching", "table_row_wise"): [ - Perf( - fwd_compute=0.004501117717551623, - fwd_comms=6.357828776041667e-05, - bwd_compute=0.009002235435103246, - bwd_comms=0.006969980785628689, - prefetch_compute=0.007304490781297871, - ), - Perf( - fwd_compute=0.004501117717551623, - fwd_comms=6.357828776041667e-05, - bwd_compute=0.009002235435103246, - bwd_comms=0.006969980785628689, - prefetch_compute=0.007304490781297871, - ), + ("table_0", "fused_uvm_caching", "table_column_wise"): [ + 0.014608981562595743 ], - ("fused_uvm_caching", "table_wise"): [ - Perf( - fwd_compute=0.021661629015717183, - fwd_comms=6.357828776041667e-05, - bwd_compute=0.043323258031434365, - bwd_comms=6.357828776041667e-05, - prefetch_compute=0.014608981562595743, - ) + ("table_0", "fused_uvm_caching", "table_row_wise"): [ + 0.007304490781297871, + 0.007304490781297871, ], + ("table_0", "fused_uvm_caching", "table_wise"): [0.014608981562595743], + ("table_1", "fused", "column_wise"): [0.0], + ("table_1", "fused", "row_wise"): [0.0, 0.0], + ("table_1", "fused", "table_column_wise"): [0.0], + ("table_1", "fused", "table_row_wise"): [0.0, 0.0], + ("table_1", "fused", "table_wise"): [0.0], } - perfs = { + + prefetch_computes = { ( + sharding_option.name, sharding_option.compute_kernel, sharding_option.sharding_type, - ): [shard.perf for shard in sharding_option.shards] + ): [ + shard.perf.prefetch_compute if shard.perf else -1 + for shard in sharding_option.shards + ] for sharding_option in sharding_options } - - self.assertEqual(expected_perfs, perfs) + self.assertEqual(expected_prefetch_computes, prefetch_computes) # pyre-ignore[3]