From eaf7f79197417e844a4c982bac345e51eb5accfa Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Fri, 9 Feb 2024 15:43:47 -0800 Subject: [PATCH] Add uvm to allowed compute kernel for zch (#1695) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1695 Allow zch to use uvm. Reviewed By: dstaay-fb Differential Revision: D53598269 fbshipit-source-id: 53951d8ef0998202026343590d096cc0ea4e415b --- torchrec/distributed/mc_embedding_modules.py | 1 + torchrec/distributed/planner/tests/test_enumerators.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/torchrec/distributed/mc_embedding_modules.py b/torchrec/distributed/mc_embedding_modules.py index fd4a77e5e..0d62241ab 100644 --- a/torchrec/distributed/mc_embedding_modules.py +++ b/torchrec/distributed/mc_embedding_modules.py @@ -267,6 +267,7 @@ def compute_kernels( return [ EmbeddingComputeKernel.FUSED.value, EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, ] def sharding_types(self, compute_device_type: str) -> List[str]: diff --git a/torchrec/distributed/planner/tests/test_enumerators.py b/torchrec/distributed/planner/tests/test_enumerators.py index 0b61be814..43a1f7289 100644 --- a/torchrec/distributed/planner/tests/test_enumerators.py +++ b/torchrec/distributed/planner/tests/test_enumerators.py @@ -782,7 +782,10 @@ def test_filter_compute_kernels_mch_ebc(self) -> None: self.assertEqual( set(allowed_compute_kernels), - {EmbeddingComputeKernel.FUSED.value}, + { + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM.value, + }, ) def test_filter_compute_kernels_mch_ebc_no_available(self) -> None: