diff --git a/torchrec/distributed/mc_embedding_modules.py b/torchrec/distributed/mc_embedding_modules.py index ff9f71e25..b33287be7 100644 --- a/torchrec/distributed/mc_embedding_modules.py +++ b/torchrec/distributed/mc_embedding_modules.py @@ -264,7 +264,10 @@ def compute_kernels( sharding_type: str, compute_device_type: str, ) -> List[str]: - return [EmbeddingComputeKernel.FUSED.value] + return [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + ] def sharding_types(self, compute_device_type: str) -> List[str]: return list(