Skip to content

Commit 9d89a1f

Browse files
committed
enable xpu tests
1 parent 3d97844 commit 9d89a1f

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

torch/testing/_internal/common_utils.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -5179,14 +5179,18 @@ def get_cycles_per_ms() -> float:
51795179
"""
51805180

51815181
def measure() -> float:
5182-
start = torch.cuda.Event(enable_timing=True)
5183-
end = torch.cuda.Event(enable_timing=True)
5184-
start.record()
5185-
torch.cuda._sleep(1000000)
5186-
end.record()
5187-
end.synchronize()
5188-
cycles_per_ms = 1000000 / start.elapsed_time(end)
5189-
return cycles_per_ms
5182+
if torch.cuda.is_available():
5183+
start = torch.cuda.Event(enable_timing=True)
5184+
end = torch.cuda.Event(enable_timing=True)
5185+
start.record()
5186+
torch.cuda._sleep(1000000)
5187+
end.record()
5188+
end.synchronize()
5189+
cycles_per_ms = 1000000 / start.elapsed_time(end)
5190+
return cycles_per_ms
5191+
elif torch.xpu.is_available():
5192+
cycles_per_ms = 1000000 / 1000.0
5193+
return cycles_per_ms
51905194

51915195
# Get 10 values and remove the 2 max and 2 min and return the avg.
51925196
# This is to avoid system disturbance that skew the results, e.g.

torch/testing/_internal/distributed/_tensor/common_dtensor.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
from torch.testing._internal.common_utils import (
3333
TEST_HPU,
3434
TEST_CUDA,
35+
TEST_XPU
3536
)
3637
from torch.testing._internal.common_distributed import (
3738
MultiProcessTestCase,
@@ -52,6 +53,10 @@
5253
DEVICE_TYPE = "hpu"
5354
PG_BACKEND = "hccl"
5455
DEVICE_COUNT = _get_device_module("hpu").device_count()
56+
elif TEST_XPU:
57+
DEVICE_TYPE = "xpu"
58+
PG_BACKEND = "xccl"
59+
DEVICE_COUNT = _get_device_module("xpu").device_count()
5560
else:
5661
DEVICE_TYPE = "cpu"
5762
PG_BACKEND = "gloo"
@@ -325,6 +330,8 @@ def backend(self) -> str:
325330
backend = "nccl"
326331
elif TEST_HPU:
327332
backend = "hccl"
333+
elif TEST_XPU:
334+
backend = "xccl"
328335
else:
329336
backend = "gloo"
330337
return backend
@@ -396,10 +403,10 @@ def wrapper(
396403
self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc]
397404
) -> None:
398405
# if enough GPU we can use GPU, otherwise we fallback to CPU
399-
if not TEST_CUDA or torch.cuda.device_count() < self.world_size:
400-
self.device_type = "cpu"
401-
else:
402-
self.device_type = DEVICE_TYPE
406+
# if not TEST_CUDA or torch.cuda.device_count() < self.world_size:
407+
# self.device_type = "cpu"
408+
# else:
409+
self.device_type = DEVICE_TYPE #zl_debug need to refine
403410

404411
self.init_pg(eager_init)
405412

0 commit comments

Comments (0)