Skip to content

Commit

Permalink
Add uvm caching compute kernel to mch sharders
Browse files Browse the repository at this point in the history
Summary: Add uvm caching to allowed kernels.

Reviewed By: dstaay-fb

Differential Revision: D52919490
  • Loading branch information
henrylhtsang authored and facebook-github-bot committed Jan 19, 2024
1 parent a6bad42 commit 282e81d
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion torchrec/distributed/mc_embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,10 @@ def compute_kernels(
sharding_type: str,
compute_device_type: str,
) -> List[str]:
return [EmbeddingComputeKernel.FUSED.value]
return [
EmbeddingComputeKernel.FUSED.value,
EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
]

def sharding_types(self, compute_device_type: str) -> List[str]:
return list(
Expand Down

0 comments on commit 282e81d

Please sign in to comment.