Skip to content

Commit

Permalink
set the right cuda device in DataLoadingThread (pytorch#1742)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1742

We need to set the cuda device in DataLoadingThread.

Without it, GPU 0 will be used by the data loading threads of other GPUs.

Reviewed By: joshuadeng

Differential Revision: D54402994

fbshipit-source-id: 47a071ac40eebf3703e3a63fffdd256933b5b12c
  • Loading branch information
leitian authored and facebook-github-bot committed Mar 2, 2024
1 parent 75772b9 commit cd8b538
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torchrec/distributed/train_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def __init__(
self._buffer_empty_event.set()

def run(self) -> None:
if self._device.type == "cuda" and torch.cuda.is_available():
# set the current device the same as the one used in the main thread
torch.cuda.set_device(self._device)

while not self._stop:
self._buffer_empty_event.wait()
# Set the filled event to unblock progress() and return.
Expand Down

0 comments on commit cd8b538

Please sign in to comment.