From 0e758f75ce1ed294b51d8c9a3de1761e0f0a2761 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Mon, 17 Jun 2024 10:52:30 -0700 Subject: [PATCH] Use at least one cache set per TBE (#2116) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2116 I added an assertion to check for cache sets > 0, since it would fail without error when cache sets = 0. Fixing in this diff to use at least one cache set, so we don't have to worry about rounding errors anymore. Example: the table size is 30, which means the number of cache sets is int(30 * 0.2 / 32) = 0. Differential Revision: D58626864 --- torchrec/distributed/batched_embedding_kernel.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index b556430c4..fc2a514b5 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -112,7 +112,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: cache_load_factor = 0.2 local_rows_sum: int = sum(table.local_rows for table in config.embedding_tables) - ssd_tbe_params["cache_sets"] = int(cache_load_factor * local_rows_sum / ASSOC) + ssd_tbe_params["cache_sets"] = max( + int(cache_load_factor * local_rows_sum / ASSOC), 1 + ) # populate init min and max if (