diff --git a/torchrec/distributed/planner/stats.py b/torchrec/distributed/planner/stats.py
index 9455b7549..bc3f090f9 100644
--- a/torchrec/distributed/planner/stats.py
+++ b/torchrec/distributed/planner/stats.py
@@ -16,6 +16,7 @@
 
 from torch import nn
 
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.planner.constants import BIGINT_DTYPE, NUM_POOLINGS
 from torchrec.distributed.planner.shard_estimators import _calculate_shard_io_sizes
 from torchrec.distributed.planner.storage_reservations import (
@@ -421,11 +422,14 @@ def log(
                 if hasattr(sharder, "fused_params") and sharder.fused_params
                 else None
             )
-            cache_load_factor = str(
-                so.cache_load_factor
-                if so.cache_load_factor is not None
-                else sharder_cache_load_factor
-            )
+            cache_load_factor = "None"
+            # Surfacing cache load factor does not make sense if not using uvm caching.
+            if so.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value:
+                cache_load_factor = str(
+                    so.cache_load_factor
+                    if so.cache_load_factor is not None
+                    else sharder_cache_load_factor
+                )
             hash_size = so.tensor.shape[0]
             param_table.append(
                 [
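
For reviewers skimming the hunk, below is a minimal, self-contained sketch of the new behavior. The helper name _display_cache_load_factor and the hard-coded "fused_uvm_caching" string are illustrative stand-ins only; in the patch the logic lives inline in the log method and compares against EmbeddingComputeKernel.FUSED_UVM_CACHING.value.

# Sketch (assumptions noted above): cache load factor is surfaced only for the
# UVM-caching compute kernel; all other kernels report the literal string "None".
from typing import Optional

FUSED_UVM_CACHING = "fused_uvm_caching"  # stand-in for EmbeddingComputeKernel.FUSED_UVM_CACHING.value


def _display_cache_load_factor(
    compute_kernel: str,
    shard_clf: Optional[float],
    sharder_clf: Optional[float],
) -> str:
    # Non-caching kernels suppress the value even if one was configured.
    if compute_kernel != FUSED_UVM_CACHING:
        return "None"
    # Per-sharding-option CLF wins, falling back to the sharder-level value.
    return str(shard_clf if shard_clf is not None else sharder_clf)


print(_display_cache_load_factor(FUSED_UVM_CACHING, 0.2, 0.5))   # -> 0.2
print(_display_cache_load_factor(FUSED_UVM_CACHING, None, 0.5))  # -> 0.5
print(_display_cache_load_factor("fused", 0.2, 0.5))             # -> None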