Commit

Take cache_load_factor from sharder if sharding option doesn't have it (#1644)

Summary:
Pull Request resolved: #1644

Sometimes cache_load_factor is passed through sharders rather than through the sharding options. This is not the ideal way to pass cache_load_factor, but for the time being we still allow it, so the stats printout should reflect it.
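For illustration only (not part of this diff), a minimal sketch of how a cache_load_factor supplied through a sharder's fused_params can be read back, mirroring the fallback added below. It assumes torchrec's EmbeddingBagCollectionSharder, which accepts a fused_params dict; the 0.2 value and variable names are made up for the example.

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

# cache_load_factor handed to the sharder via fused_params rather than per table
sharder = EmbeddingBagCollectionSharder(
    fused_params={"cache_load_factor": 0.2},
)

# Fallback in the spirit of this diff: consult the sharder's fused_params
# only when the sharding option itself carries no cache_load_factor.
fused_params = getattr(sharder, "fused_params", None)
sharder_clf = fused_params.get("cache_load_factor") if fused_params else None
print(sharder_clf)  # 0.2

A cache_load_factor set on the sharding option itself still takes precedence; the sharder's value is only the fallback.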

Reviewed By: ge0405

Differential Revision: D52921387

fbshipit-source-id: f7c70aa9136f71148f21c63ec3257ca18cdc1b60
henrylhtsang authored and facebook-github-bot committed Jan 24, 2024
1 parent 98a28ad commit 91d679c
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion torchrec/distributed/planner/stats.py
@@ -329,7 +329,16 @@ def log(
                 or so.sharding_type == ShardingType.TABLE_COLUMN_WISE.value
                 else f"{so.tensor.shape[1]}"
             )
-            cache_load_factor = str(so.cache_load_factor)
+            sharder_cache_load_factor = (
+                sharder.fused_params.get("cache_load_factor")  # pyre-ignore[16]
+                if hasattr(sharder, "fused_params") and sharder.fused_params
+                else None
+            )
+            cache_load_factor = str(
+                so.cache_load_factor
+                if so.cache_load_factor is not None
+                else sharder_cache_load_factor
+            )
             hash_size = so.tensor.shape[0]
             param_table.append(
                 [
