
Commit 97dc6c3

[call_jax] Bridge the torch_xla and JAX mesh
Now we can run a JAX SPMD function that accesses the ambient SPMD mesh from xb.call_jax. Fixes #8972. Also beefed up the assume_pure tests and updated the docs to mention that mark_sharding is now supported, thanks to qihqi@'s #8989.
1 parent 1b6630e commit 97dc6c3
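
Below is a minimal sketch of what the mesh bridge enables, adapted from the new test/test_jax_interop_spmd.py added in this commit; it assumes a multi-device PJRT runtime (e.g. a TPU VM), and the printed sharding spec is illustrative.

# Sketch (not part of the diff): call a JAX function through xb.call_jax that
# refers to the ambient SPMD mesh. Adapted from test/test_jax_interop_spmd.py;
# assumes a multi-device PJRT runtime such as a TPU VM.
import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import get_1d_mesh, set_global_mesh

xr.use_spmd()
set_global_mesh(get_1d_mesh(axis_name="model"))  # the ambient SPMD mesh


def f(a, b):
  # "model" resolves against the global torch_xla mesh, which call_jax now
  # bridges into a JAX mesh before lowering.
  import jax
  import jax.numpy as jnp
  from jax.sharding import PartitionSpec as P
  return jax.lax.with_sharding_constraint(a, P("model")) + jnp.sin(b)


a = torch.ones((8, 8), device='xla')
result = xb.call_jax(f, (a, a))
torch_xla.sync(wait=True)
# The result now carries a sharding spec, e.g. "{devices=[N,1]0,1,...}".
print(torch_xla._XLAC._get_xla_sharding_spec(result))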

21 files changed (+206 -80 lines)

benchmarks/benchmark_model.py (+3 -2)

@@ -227,8 +227,9 @@ def is_compatible(self, dummy_benchmark_model: BenchmarkModel,
   def get_benchmark_indices(self, length: int):
     start = self._args.partition_id * (length // self._args.total_partitions)
     end = ((self._args.partition_id + 1) *
-           (length // self._args.total_partitions) if self._args.partition_id
-           < self._args.total_partitions - 1 else length)
+           (length // self._args.total_partitions)
+           if self._args.partition_id < self._args.total_partitions - 1 else
+           length)
     return start, end

   def skip_model(self, model_name: str):

docs/source/perf/assume_pure.md (+7 -2)

@@ -122,8 +122,13 @@ a fixed up-front cost, and then later runs will reuse the cached XLA computation
 ## Limitations

 Currently, all operations in a function wrapped with `@assume_pure` must be
-PyTorch upstream operations (e.g. `torch.einsum`, `torch.sin`, ...). More
-PyTorch/XLA operations (e.g. `mark_sharding`) will be supported in the future.
+PyTorch upstream operations (e.g. `torch.einsum`, `torch.sin`, ...), or these
+PyTorch/XLA operations:
+* `torch_xla.experimental.assume_pure` (recursive `assume_pure`)
+* `torch_xla.distributed.spmd.mark_sharding`
+
+More PyTorch/XLA operations (e.g. `flash_attention`) will be supported in the
+future.

 <!-- xrefs -->
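
For context, a minimal usage sketch of the newly documented `mark_sharding` support inside `@assume_pure`, adapted from test/test_assume_pure_spmd.py added in this commit; it assumes an active SPMD runtime with more than one device.

# Sketch (not part of the diff): mark_sharding is now allowed inside a function
# wrapped with assume_pure. Adapted from test/test_assume_pure_spmd.py; assumes
# SPMD is active and more than one device is available.
import torch
import torch_xla
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import get_1d_mesh, mark_sharding, set_global_mesh
from torch_xla.experimental.assume_pure import assume_pure

xr.use_spmd()
mesh = get_1d_mesh(axis_name="model")
set_global_mesh(mesh)

x = torch.randn((8, 4, 5, 128), device='xla')
# Shard the first dimension across the "model" axis from inside assume_pure.
result = assume_pure(mark_sharding)(x, mesh, ("model", None, None, None))
torch_xla.sync(wait=True)
print(torch_xla._XLAC._get_xla_sharding_spec(result))  # e.g. "{devices=[N,1,1,1]...}"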

test/pytorch_test_base.py (+2 -2)

@@ -619,8 +619,8 @@ def skipped_test(self, *args, reason=reason, **kwargs):
         setattr(cls, dtype_test_name, disallowed_test)
       if not skipped:
         xla_dtypes.append(
-            dtype_combination if len(dtype_combination) >
-            1 else dtype_combination[0])
+            dtype_combination
+            if len(dtype_combination) > 1 else dtype_combination[0])
     if len(xla_dtypes) != 0:
       test.dtypes[cls.device_type] = xla_dtypes
     super().instantiate_test(name, test, generic_cls=generic_cls)

test/run_tests.sh (+2)

@@ -211,7 +211,9 @@ function run_xla_op_tests2 {
   run_test "$CDIR/test_callback.py"
   XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py"
   run_test "$CDIR/test_jax_interop.py"
+  run_test "$CDIR/test_jax_interop_spmd.py"
   run_test "$CDIR/test_assume_pure.py"
+  run_test "$CDIR/test_assume_pure_spmd.py"
 }

 # All the new xla op tests should go to run_xla_op_tests3

test/scan/test_scan_spmd.py (-22)

@@ -8,7 +8,6 @@
 import torch_xla
 import torch.nn as nn
 from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear, Mesh
-from torch_xla.experimental.assume_pure import assume_pure
 from torch_xla.experimental.scan import scan
 from torch_xla.experimental.scan_layers import scan_layers
 from torch_xla.distributed.spmd import mark_sharding, mark_sharding_with_gradients, set_global_mesh, get_1d_mesh, get_global_mesh
@@ -231,27 +230,6 @@ def check_dots_in_model(self, model, x, expect_pattern):
   def count_regex(self, hlo_text, regex_str):
     return len(re.findall(regex_str, hlo_text))

-  @unittest.skipIf(
-      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
-      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
-  )
-  def test_assume_pure_works_with_mark_sharding(self):
-    x = torch.randn((3, 4, 5, 128), device='xla')
-    assume_pure(mark_sharding)(x, self.spmd_mesh, ("model", None, None, None))
-    # assert not throwing
-
-  @unittest.skipIf(
-      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
-      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
-  )
-  def test_convert_to_jax_mesh(self):
-    jax_mesh = self.spmd_mesh.maybe_convert_and_get_jax_mesh()
-    self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)
-    np.testing.assert_equal(
-        np.array([dev.id for dev in jax_mesh.devices.flatten()]),
-        self.spmd_mesh.device_ids)
-    # assert not throwing
-

 if __name__ == '__main__':
   test = unittest.main()

test/spmd/test_xla_sharding.py (+12 -15)

@@ -618,9 +618,9 @@ def test_inplace_add_with_sharding(self):

   # avoid calling xr.addressable_device_count here otherwise it will init the test
   # in non-spmd mode.
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      "sharding will be the same for both tensors on single device")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   "sharding will be the same for both tensors on single device"
+                  )
   def test_shard_hashing(self):
     xt1 = torch.ones(2, 2).to(xm.xla_device())
     xt2 = torch.ones(2, 2).to(xm.xla_device())
@@ -1383,9 +1383,8 @@ def test_get_1d_mesh(self):
     self.assertEqual(mesh_without_name.mesh_shape,
                      (xr.global_runtime_device_count(),))

-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_sharding(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1406,9 +1405,8 @@ def test_data_loader_with_sharding(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )

-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1429,9 +1427,8 @@ def test_data_loader_with_non_batch_size(self):
         f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
     )

-  @unittest.skipUnless(
-      xr.global_runtime_device_count() > 1,
-      "Multiple devices required for dataloader sharding test")
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required for dataloader sharding test")
   def test_data_loader_with_non_batch_size_and_mini_batch(self):
     device = torch_xla.device()
     mesh = xs.get_1d_mesh("data")
@@ -1663,9 +1660,9 @@ def test_get_logical_mesh(self):
     self.assertEqual(logical_mesh.shape, mesh_shape)
     np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids)

-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      "sharding will be the same for both tensors on single device")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   "sharding will be the same for both tensors on single device"
+                  )
   def test_shard_as(self):
     mesh = self._get_mesh((self.n_devices,))
     partition_spec = (0,)

test/test_assume_pure_spmd.py (+79, new file)

@@ -0,0 +1,79 @@
+import os
+import sys
+import unittest
+
+import numpy as np
+import torch
+import torch_xla
+import torch_xla.runtime as xr
+from torch_xla.experimental.assume_pure import assume_pure
+from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh, Mesh
+
+
+class AssumePureSpmdTest(unittest.TestCase):
+
+  def setUp(self):
+    # Activate SPMD
+    xr.use_spmd()
+
+    # Set up a simple SPMD mesh for these tests.
+    self.spmd_mesh = get_1d_mesh(axis_name="model")
+    set_global_mesh(self.spmd_mesh)
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required")
+  @unittest.skipIf(
+      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
+      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
+  )
+  def test_assume_pure_works_with_mark_sharding(self):
+    x = torch.randn((8, 4, 5, 128), device='xla')
+    result = assume_pure(mark_sharding)(x, self.spmd_mesh,
+                                        ("model", None, None, None))
+    torch_xla.sync(wait=True)
+    N = xr.global_runtime_device_count()
+    self.assertIn(f'devices=[{N}',
+                  torch_xla._XLAC._get_xla_sharding_spec(result))
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required")
+  @unittest.skipIf(
+      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
+      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
+  )
+  def test_convert_to_jax_mesh(self):
+    jax_mesh = self.spmd_mesh.get_jax_mesh()
+    self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)
+    np.testing.assert_equal(
+        np.array([dev.id for dev in jax_mesh.devices.flatten()]),
+        self.spmd_mesh.device_ids)
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required")
+  @unittest.skipUnless(os.environ.get('PJRT_DEVICE') == 'TPU', "TPU only test")
+  def test_convert_to_jax_mesh_shuffled(self):
+    """Test get_jax_mesh when the PyTorch/XLA mesh has a custom order."""
+
+    # Arrange
+    num_devices = xr.global_runtime_device_count()
+    device_ids = np.arange(num_devices)
+    device_ids = np.random.permutation(device_ids)
+    self.spmd_mesh = Mesh(
+        device_ids, mesh_shape=(num_devices,), axis_names=('model',))
+
+    # Act
+    jax_mesh = self.spmd_mesh.get_jax_mesh()
+
+    # Assert
+    torch_xla_devices = np.array(
+        [xr.global_runtime_device_attributes()[i] for i in device_ids])
+    self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)
+    np.testing.assert_equal(
+        np.array([dev.coords for dev in jax_mesh.devices.flatten()]),
+        np.array([dev['coords'] for dev in torch_xla_devices.flatten()]),
+    )
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)

test/test_jax_interop_spmd.py (+49, new file)

@@ -0,0 +1,49 @@
+import sys
+import unittest
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.core.xla_builder as xb
+import torch_xla.runtime as xr
+from torch_xla.distributed.spmd import set_global_mesh, get_1d_mesh
+
+
+class TestJaxInteropSpmd(unittest.TestCase):
+
+  def setUp(self):
+    xb._JAX_TO_XLA_COMPUTATION_CACHE.clear()
+    # Activate SPMD
+    xr.use_spmd()
+
+    # Set up a simple SPMD mesh for these tests.
+    self.spmd_mesh = get_1d_mesh(axis_name="model")
+    set_global_mesh(self.spmd_mesh)
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required")
+  def test_call_jax_sharding_constraints(self):
+    """Test that we can call jax.lax.with_sharding_constraints from PyTorch/XLA."""
+
+    # Arrange
+    a = torch.ones((8, 8), device='xla')
+
+    def f(a, b):
+      import jax
+      from jax.sharding import PartitionSpec as P
+      import jax.numpy as jnp
+      return jax.lax.with_sharding_constraint(a, P("model",)) + jnp.sin(b)
+
+    # Act
+    result = xb.call_jax(f, (a, a))
+    torch_xla.sync(wait=True)
+
+    # Assert
+    N = xr.global_runtime_device_count()
+    self.assertIn(f'devices=[{N}',
+                  torch_xla._XLAC._get_xla_sharding_spec(result))
+
+
+if __name__ == "__main__":
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)

test/test_operations.py (+5 -3)

@@ -2959,9 +2959,11 @@ def test_dlpack_roundtrip_tensor(self, dtype):

   @onlyIfTorchSupportsCUDA
   @onlyIfPJRTDeviceIsCUDA
-  @parameterized.parameters(
-      *all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool,
-                                 torch.uint16, torch.uint32, torch.uint64))
+  @parameterized.parameters(*all_types_and_complex_and(torch.half,
+                                                       torch.bfloat16,
+                                                       torch.bool, torch.uint16,
+                                                       torch.uint32,
+                                                       torch.uint64))
   def test_dlpack_roundtrip_scalar(self, dtype):
     xla_device = xm.xla_device()
     xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device)

test/test_pallas.py (+4 -4)

@@ -41,10 +41,10 @@ class PallasTest(parameterized.TestCase):
   # therefore we use != instead of ==.
   def _make_attention_mask_from_segment_ids(self, q_segment_ids,
                                             kv_segment_ids):
-    return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1],
-                              1) != kv_segment_ids.view(kv_segment_ids.shape[0],
-                                                        1, 1,
-                                                        kv_segment_ids.shape[1])
+    return q_segment_ids.view(q_segment_ids.shape[0], 1,
+                              q_segment_ids.shape[1], 1) != kv_segment_ids.view(
+                                  kv_segment_ids.shape[0], 1, 1,
+                                  kv_segment_ids.shape[1])

   def _attention(self, q, k, v, *, attn_mask=None, ab=None):
     attn_weight = q @ k.transpose(-2, -1)

test/test_pallas_spmd.py (+4 -4)

@@ -41,10 +41,10 @@ class PallasTest(unittest.TestCase):
   # therefore we use != instead of ==.
   def _make_attention_mask_from_segment_ids(self, q_segment_ids,
                                             kv_segment_ids):
-    return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1],
-                              1) != kv_segment_ids.view(kv_segment_ids.shape[0],
-                                                        1, 1,
-                                                        kv_segment_ids.shape[1])
+    return q_segment_ids.view(q_segment_ids.shape[0], 1,
+                              q_segment_ids.shape[1], 1) != kv_segment_ids.view(
+                                  kv_segment_ids.shape[0], 1, 1,
+                                  kv_segment_ids.shape[1])

   def _attention(self, q, k, v, *, attn_mask=None, ab=None):
     attn_weight = q @ k.transpose(-2, -1)

test/test_splash_attention.py (+4 -4)

@@ -62,10 +62,10 @@ def setUp(self):

   def _make_attention_mask_from_segment_ids(self, q_segment_ids,
                                             kv_segment_ids):
-    return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1],
-                              1) != kv_segment_ids.view(kv_segment_ids.shape[0],
-                                                        1, 1,
-                                                        kv_segment_ids.shape[1])
+    return q_segment_ids.view(q_segment_ids.shape[0], 1,
+                              q_segment_ids.shape[1], 1) != kv_segment_ids.view(
+                                  kv_segment_ids.shape[0], 1, 1,
+                                  kv_segment_ids.shape[1])

   def maybe_repeat_kv(self, hidden_state):
     if hidden_state.size(1) == self.NUM_Q_HEADS:

test/tpu/run_tests.sh (+2)

@@ -37,6 +37,8 @@ python3 "$TEST_CDIR/scan/test_scan_pallas.py"
 python3 "$TEST_CDIR/scan/test_scan_layers.py"
 python3 "$TEST_CDIR/test_gru.py"
 python3 "$TEST_CDIR/test_assume_pure.py"
+python3 "$TEST_CDIR/test_assume_pure_spmd.py"
+python3 "$TEST_CDIR/test_jax_interop_spmd.py"
 python3 "$TEST_CDIR/test_as_stride_use_slice.py"
 run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
 python3 "$TEST_CDIR/test_pallas.py" -v

torch_xla/core/xla_builder.py (+13 -2)

@@ -920,8 +920,19 @@ def get_xla_computation():
     import torch_xla.debug.profiler as xp
     # If we see this trace span in the profiler, we'll know that there's a cache miss.
     with xp.Trace('jax_to_xla_computation'):
-      lowered = jax.jit(fn, keep_unused=True).lower(*sample_tensor_args)
-      hlo_ir = lowered.compiler_ir('hlo')
+      jitted = jax.jit(fn, keep_unused=True)
+
+      def do_lower():
+        import torch_xla.runtime as xr
+        import torch_xla.distributed.spmd as xs
+        if xr.is_spmd():
+          mesh = xs.get_global_mesh()
+          if mesh is not None:
+            with mesh.get_jax_mesh():
+              return jitted.lower(*sample_tensor_args)
+        return jitted.lower(*sample_tensor_args)
+
+      hlo_ir = do_lower().compiler_ir('hlo')
       assert len(traced_out_spec) == 1, \
           "fn must be traced to obtain the output tree spec"
       spec = traced_out_spec[0]
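
Why do_lower() enters the mesh: in JAX, a bare PartitionSpec used inside the jitted function is resolved against whatever jax.sharding.Mesh is active at lowering time. A self-contained sketch of that behavior in plain JAX follows; the device layout and axis name are illustrative, not taken from the diff.

# Plain-JAX illustration of the pattern do_lower() relies on.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P


def f(a):
  # The bare PartitionSpec names an axis of the ambient mesh.
  return jax.lax.with_sharding_constraint(a, P("model")) * 2


mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
jitted = jax.jit(f)
with mesh:  # analogous to `with mesh.get_jax_mesh():` in do_lower()
  lowered = jitted.lower(jnp.ones((8, 8)))
print(lowered.as_text())  # the lowered module carries the sharding annotation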

torch_xla/distributed/spmd/xla_sharding.py (+2 -2)

@@ -182,7 +182,7 @@ def from_str(cls, mesh_str: str) -> Optional["Mesh"]:
       return None

   @requires_jax
-  def maybe_convert_and_get_jax_mesh(self):
+  def get_jax_mesh(self):
     # Construct a JAX mesh object with the same device ids shape and ordering
     # from torch_xla device mesh.
     import jax
@@ -611,7 +611,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
   if tx is not None and isinstance(t, tx.tensor.Tensor):
     from jax.sharding import PartitionSpec as P, NamedSharding
     op_sharding = tuple(str(i) if i is not None else i for i in partition_spec)
-    jmesh = mesh.maybe_convert_and_get_jax_mesh()
+    jmesh = mesh.get_jax_mesh()
     t.shard_(NamedSharding(jmesh, P(*op_sharding)))
     return t

torch_xla/distributed/xla_multiprocessing.py (+1 -4)

@@ -174,10 +174,7 @@ def _v6e_create_replica_groups() -> List | None:
   return None


-device_kind_handler_dict: dict[
-    str,
-    Callable[..., List | None],
-] = {
+device_kind_handler_dict: dict[str, Callable[..., List | None],] = {
     _TPU_V5P: _v5p_create_replica_groups,
     _TPU_V6E: _v6e_create_replica_groups
 }
