diff --git a/torchrec/distributed/tests/test_pt2_multiprocess.py b/torchrec/distributed/tests/test_pt2_multiprocess.py index 250b01542..9960db8b8 100644 --- a/torchrec/distributed/tests/test_pt2_multiprocess.py +++ b/torchrec/distributed/tests/test_pt2_multiprocess.py @@ -509,6 +509,7 @@ def disable_cuda_tf32(self) -> bool: kernel_type=st.sampled_from( [ EmbeddingComputeKernel.DENSE.value, + EmbeddingComputeKernel.FUSED.value, ], ), given_config_tuple=st.sampled_from(