Skip to content

Commit

Permalink
add max L1 cache size in torchrec
Browse files Browse the repository at this point in the history
Summary:
Currently we use the cache load factor, a ratio relative to the embedding table size, to decide the L1 cache size.
However, when using SSD offloading we usually see extremely large embedding tables, so keeping ratio-based sizing can be meaningless and not straightforward. This diff provides a new way to cap the L1 cache size directly instead.

Reviewed By: jiayulu

Differential Revision: D67071717
  • Loading branch information
duduyi2013 authored and facebook-github-bot committed Dec 12, 2024
1 parent 575e081 commit a1bf330
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
16 changes: 16 additions & 0 deletions torchrec/distributed/batched_embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
SSDTableBatchedEmbeddingBags.__init__
).parameters.keys()
invalid_keys: List[str] = []

for key, value in fused_params.items():
if key not in ssd_tbe_signature:
invalid_keys.append(key)
Expand Down Expand Up @@ -151,6 +152,21 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
weights_precision = data_type_to_sparse_type(config.data_type)
ssd_tbe_params["weights_precision"] = weights_precision

if "max_l1_cache_size" in fused_params:
l1_cache_size = fused_params.get("max_l1_cache_size") * 1024 * 1024
max_dim: int = max(table.local_cols for table in config.embedding_tables)
weight_precision_bytes = ssd_tbe_params["weights_precision"].bit_rate() / 8
max_cache_sets = (
l1_cache_size / ASSOC / weight_precision_bytes / max_dim
) # 100MB

if ssd_tbe_params["cache_sets"] > int(max_cache_sets):
logger.warning(
f"cache_sets {ssd_tbe_params['cache_sets']} is larger than max_cache_sets {max_cache_sets} calculated "
"by max_l1_cache_size, cap at max_cache_sets instead"
)
ssd_tbe_params["cache_sets"] = int(max_cache_sets)

return ssd_tbe_params


Expand Down
4 changes: 3 additions & 1 deletion torchrec/distributed/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,8 @@ class KeyValueParams:
gather_ssd_cache_stats: Optional[bool] = None
stats_reporter_config: Optional[TBEStatsReporterConfig] = None
use_passed_in_path: bool = True
l2_cache_size: Optional[int] = None
l2_cache_size: Optional[int] = None # size in GB
max_l1_cache_size: Optional[int] = None # size in MB
enable_async_update: Optional[bool] = None

# Parameter Server (PS) Attributes
Expand All @@ -673,6 +674,7 @@ def __hash__(self) -> int:
self.gather_ssd_cache_stats,
self.stats_reporter_config,
self.l2_cache_size,
self.max_l1_cache_size,
self.enable_async_update,
)
)
Expand Down

0 comments on commit a1bf330

Please sign in to comment.