Skip to content

Commit

Permalink
Add fused compute kernel to PT2 multiprocess test (#2235)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2235

As titled

Reviewed By: TroyGarden

Differential Revision: D59872821

fbshipit-source-id: 241d09e3d5836803fca6734553c76b567b57958f
  • Loading branch information
gnahzg authored and facebook-github-bot committed Jul 19, 2024
1 parent 50ecc5c commit a68a99f
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions torchrec/distributed/tests/test_pt2_multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ def disable_cuda_tf32(self) -> bool:
kernel_type=st.sampled_from(
[
EmbeddingComputeKernel.DENSE.value,
EmbeddingComputeKernel.FUSED.value,
],
),
given_config_tuple=st.sampled_from(
Expand Down

0 comments on commit a68a99f

Please sign in to comment.