Fix invalid record_stream arg for Request awaitable (pytorch#2634)
Summary:
Pull Request resolved: pytorch#2634

pytorch#2598 causes failures when running on CPU:

```
RuntimeError: unknown parameter type
```

This is because `record_stream` doesn't accept CPU streams. This diff forward-fixes the issue by guarding the call so `record_stream` only runs when the returned tensor lives on a CUDA device.
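
For context, a minimal sketch of the failure mode the guard avoids (assuming a recent PyTorch that provides `torch.get_device_module` and the `torch.cpu` stream stubs; not part of this diff):

```python
import torch

t = torch.randn(4)  # a CPU tensor
# For CPU devices, torch.get_device_module returns the torch.cpu module,
# whose current_stream() yields a torch.cpu.Stream. Tensor.record_stream
# only accepts CUDA streams, so calling it with a CPU stream fails with
# "RuntimeError: unknown parameter type".
stream = torch.get_device_module(t.device).current_stream()
if t.device.type == "cuda":  # the guard this diff adds
    t.record_stream(stream)
```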

Differential Revision: D67171994

fbshipit-source-id: a9ffddfcd978d343a58fb039a4fcd89c336ceee6
sarckk authored and facebook-github-bot committed Dec 13, 2024
1 parent 4a2b291 commit 3928a1b
Showing 3 changed files with 28 additions and 5 deletions.
2 changes: 1 addition & 1 deletion torchrec/distributed/comm_ops.py

```diff
@@ -119,7 +119,7 @@ def _wait_impl(self) -> W:
         """
 
         ret = self.wait_function.apply(self.pg, self, self.dummy_tensor)
-        if isinstance(ret, torch.Tensor):
+        if isinstance(ret, torch.Tensor) and ret.device.type == "cuda":
             ret.record_stream(torch.get_device_module(ret.device).current_stream())
         self.req = None
         self.tensor = None
```
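
As background on the retained CUDA branch (not part of the diff): `record_stream` tells the caching allocator that the tensor is in use on the given stream, so its memory is not recycled before that stream's pending work finishes. A minimal sketch of the pattern, assuming a CUDA device is available:

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(1024, device="cuda")  # allocated on the current stream
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        y = x * 2  # x is consumed on the side stream
    # Without this, x's memory could be handed out to new allocations on the
    # current stream while `side` is still reading it.
    x.record_stream(side)
    torch.cuda.current_stream().wait_stream(side)
```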
27 changes: 27 additions & 0 deletions torchrec/distributed/tests/test_comm.py

```diff
@@ -767,3 +767,30 @@ def test_all_gather_base_pooled(
             specify_pg=specify_pg,
             gradient_division=gradient_division,
         )
+
+    @classmethod
+    def _test_all_gather_base_pooled_cpu(
+        cls,
+        rank: int,
+        world_size: int,
+        backend: str,
+    ) -> None:
+        pg = GroupMember.WORLD
+        if pg is None:
+            dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
+            pg = GroupMember.WORLD
+
+        device = torch.device("cpu")
+        input_tensor = torch.randn([4, 4], requires_grad=True).to(device)
+        comm_ops.all_gather_base_pooled(input_tensor, pg).wait()
+        dist.destroy_process_group()
+
+    def test_all_gather_base_pooled_cpu(
+        self,
+    ) -> None:
+        self._run_multi_process_test(
+            world_size=self.WORLD_SIZE,
+            backend="gloo",
+            # pyre-ignore [6]
+            callable=self._test_all_gather_base_pooled_cpu,
+        )
```
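
For reference, the new test's CPU path can be exercised standalone with a single-process gloo group — a hedged sketch (the import path and env-var setup are assumptions, not part of the diff):

```python
import os

import torch
import torch.distributed as dist
import torchrec.distributed.comm_ops as comm_ops

# Single-process gloo group so the collective runs on CPU tensors.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

t = torch.randn([4, 4], requires_grad=True)
# Before this fix, wait() crashed here on CPU with "unknown parameter type".
out = comm_ops.all_gather_base_pooled(t, dist.GroupMember.WORLD).wait()
dist.destroy_process_group()
```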
4 changes: 0 additions & 4 deletions torchrec/distributed/tests/test_sharding_plan.py

```diff
@@ -119,10 +119,6 @@ def _test_sharding(
 
 @skip_if_asan_class
 class ConstructParameterShardingAndShardTest(MultiProcessTestBase):
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 1,
-        "Not enough GPUs, this test requires at least two GPUs",
-    )
     # pyre-fixme[56]
     @given(
         per_param_sharding=st.sampled_from(
```
