From 91d679c4a290bb7c29368a2574e0a8b7a4de7f74 Mon Sep 17 00:00:00 2001
From: Henry Tsang
Date: Tue, 23 Jan 2024 16:43:49 -0800
Subject: [PATCH] Take cache_load_factor from sharder if sharding option
 doesn't have it (#1644)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1644

Sometimes cache_load_factor is passed through sharders. This is not the ideal way of passing cache_load_factor, but for the time being, we still allow it. So we should reflect that in the stats printout.

Reviewed By: ge0405

Differential Revision: D52921387

fbshipit-source-id: f7c70aa9136f71148f21c63ec3257ca18cdc1b60
---
 torchrec/distributed/planner/stats.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/torchrec/distributed/planner/stats.py b/torchrec/distributed/planner/stats.py
index e2cbdc90f..1cd637149 100644
--- a/torchrec/distributed/planner/stats.py
+++ b/torchrec/distributed/planner/stats.py
@@ -329,7 +329,16 @@ def log(
                 or so.sharding_type == ShardingType.TABLE_COLUMN_WISE.value
                 else f"{so.tensor.shape[1]}"
             )
-            cache_load_factor = str(so.cache_load_factor)
+            sharder_cache_load_factor = (
+                sharder.fused_params.get("cache_load_factor")  # pyre-ignore[16]
+                if hasattr(sharder, "fused_params") and sharder.fused_params
+                else None
+            )
+            cache_load_factor = str(
+                so.cache_load_factor
+                if so.cache_load_factor is not None
+                else sharder_cache_load_factor
+            )
             hash_size = so.tensor.shape[0]
             param_table.append(
                 [
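
Illustrative note (not part of the patch): a minimal, self-contained sketch of
the fallback order the stats printout now uses. ToySharder and ToyShardingOption
are hypothetical stand-ins, not torchrec classes; only the precedence logic
mirrors the change to stats.py above.

    from typing import Any, Dict, Optional


    class ToySharder:
        # Stand-in for a torchrec sharder that may carry fused_params.
        def __init__(self, fused_params: Optional[Dict[str, Any]] = None) -> None:
            self.fused_params = fused_params


    class ToyShardingOption:
        # Stand-in for a ShardingOption that may or may not set cache_load_factor.
        def __init__(self, cache_load_factor: Optional[float] = None) -> None:
            self.cache_load_factor = cache_load_factor


    def resolve_cache_load_factor(
        so: ToyShardingOption, sharder: ToySharder
    ) -> Optional[float]:
        # Prefer the value on the sharding option; otherwise fall back to the
        # sharder's fused_params, as the patched stats printout does.
        sharder_clf = (
            sharder.fused_params.get("cache_load_factor")
            if getattr(sharder, "fused_params", None)
            else None
        )
        return so.cache_load_factor if so.cache_load_factor is not None else sharder_clf


    # The sharding option wins when both are set; the sharder is only a fallback.
    assert resolve_cache_load_factor(
        ToyShardingOption(0.2), ToySharder({"cache_load_factor": 0.5})
    ) == 0.2
    assert resolve_cache_load_factor(
        ToyShardingOption(None), ToySharder({"cache_load_factor": 0.5})
    ) == 0.5
    assert resolve_cache_load_factor(ToyShardingOption(None), ToySharder(None)) is None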