Skip to content

Commit

Permalink
enable tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxiaoli73 committed Feb 13, 2025
1 parent cc24e89 commit a5f1ca3
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
16 changes: 8 additions & 8 deletions test/distributed/_composable/fsdp/test_fully_shard_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,11 +441,11 @@ def _test_train_parity_multi_group(
orig_reduce_scatter = dist.reduce_scatter_tensor

def delayed_all_gather(*args, **kwargs):
torch.xpu._sleep(int(delay_in_ms * get_cycles_per_ms()))
# torch.xpu._sleep(int(delay_in_ms * get_cycles_per_ms()))
return orig_all_gather(*args, **kwargs)

def delayed_reduce_scatter(*args, **kwargs):
torch.xpu._sleep(int(delay_in_ms * get_cycles_per_ms()))
# torch.xpu._sleep(int(delay_in_ms * get_cycles_per_ms()))
return orig_reduce_scatter(*args, **kwargs)

torch.manual_seed(42 + self.rank + 1)
Expand All @@ -466,11 +466,11 @@ def delayed_reduce_scatter(*args, **kwargs):
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
if _model is model and delay_after_forward:
torch.xpu._sleep(int(delay_in_ms * get_cycles_per_ms()))
# if _model is model and delay_after_forward:
# torch.xpu._sleep(int(delay_in_ms * get_cycles_per_ms()))
losses[-1].backward()
if _model is model and delay_before_optim:
torch.xpu._sleep(int(delay_in_ms * get_cycles_per_ms()))
# if _model is model and delay_before_optim:
# torch.xpu._sleep(int(delay_in_ms * get_cycles_per_ms()))
_optim.step()
self.assertEqual(losses[0], losses[1])

Expand Down Expand Up @@ -509,7 +509,7 @@ def test_non_root_forward_backward(self):

root_loss = model(inp).sum()
root_loss.backward()
torch.xpu._sleep(int(100 * get_cycles_per_ms()))
# torch.xpu._sleep(int(100 * get_cycles_per_ms()))
optim.step()
optim.zero_grad()
nonroot_loss = model[0](inp).sum()
Expand Down Expand Up @@ -638,7 +638,7 @@ def step_post_hook(
optim.step()
# Sleep after the optimizer step to allow CPU to run ahead into the
# next iteration's forward, exercising the post-optim stream sync
torch.xpu._sleep(int(25 * get_cycles_per_ms()))
# torch.xpu._sleep(int(25 * get_cycles_per_ms()))
for ref_loss, loss in zip(ref_losses, losses):
self.assertEqual(ref_loss, loss)

Expand Down
2 changes: 2 additions & 0 deletions torch/testing/_internal/common_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,8 @@ def backend(self, device) -> str:
return "nccl"
elif "hpu" in device : # intel gaudi
return "hccl"
elif "xpu" in device:
return "xccl"
else :
return "gloo"

Expand Down
2 changes: 1 addition & 1 deletion torch/testing/_internal/distributed/multi_threaded_pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def _create_threaded_pg(prefix_store, rank, world_size, timeout):
return pg


dist.Backend.register_backend("threaded", _create_threaded_pg, devices=["cpu", "cuda"])
dist.Backend.register_backend("threaded", _create_threaded_pg, devices=["cpu", "cuda", "xpu"])


@dataclass
Expand Down

0 comments on commit a5f1ca3

Please sign in to comment.