Commit 7d681a9

Improve assume_pure docs and tests (#9038)
1 parent caa40cc commit 7d681a9

9 files changed: +100 -31 lines changed

docs/source/perf/assume_pure.md (+7 -2)

@@ -122,8 +122,13 @@ a fixed up-front cost, and then later runs will reuse the cached XLA computation
 ## Limitations

 Currently, all operations in a function wrapped with `@assume_pure` must be
-PyTorch upstream operations (e.g. `torch.einsum`, `torch.sin`, ...). More
-PyTorch/XLA operations (e.g. `mark_sharding`) will be supported in the future.
+PyTorch upstream operations (e.g. `torch.einsum`, `torch.sin`, ...), or these
+PyTorch/XLA operations:
+* `torch_xla.experimental.assume_pure` (recursive `assume_pure`)
+* `torch_xla.distributed.spmd.mark_sharding`
+
+More PyTorch/XLA operations (e.g. `flash_attention`) will be supported in the
+future.

 <!-- xrefs -->
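To illustrate the pattern the updated doc now allows, here is a minimal sketch modeled on the test added later in this commit; the mesh axis name, tensor shape, and the printing of the sharding spec are illustrative, and it assumes an SPMD-capable runtime with more than one device:

```python
import torch
import torch_xla
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import get_1d_mesh, mark_sharding, set_global_mesh
from torch_xla.experimental.assume_pure import assume_pure

# Enable SPMD and set up a simple 1D mesh (same setup as the new test).
xr.use_spmd()
mesh = get_1d_mesh(axis_name="model")
set_global_mesh(mesh)

# `mark_sharding` is one of the PyTorch/XLA ops now documented as usable
# inside `assume_pure`.
x = torch.randn((8, 4, 5, 128), device='xla')
sharded = assume_pure(mark_sharding)(x, mesh, ("model", None, None, None))
torch_xla.sync(wait=True)
print(torch_xla._XLAC._get_xla_sharding_spec(sharded))
```

The new test_assume_pure_spmd.py later in this diff asserts that the resulting sharding spec mentions all devices.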

infra/ansible/config/pip.yaml (+1 -1)

@@ -28,7 +28,7 @@ pip:
 - tqdm
 - typing_extensions
 - sympy
-- yapf==0.30.0
+- yapf==0.40.2

 build_amd64:
 - mkl

test/run_tests.sh (+1)

@@ -212,6 +212,7 @@ function run_xla_op_tests2 {
   XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py"
   run_test "$CDIR/test_jax_interop.py"
   run_test "$CDIR/test_assume_pure.py"
+  run_test "$CDIR/test_assume_pure_spmd.py"
 }

 # All the new xla op tests should go to run_xla_op_tests3

test/scan/test_scan_spmd.py (-22)

@@ -8,7 +8,6 @@
 import torch_xla
 import torch.nn as nn
 from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear, Mesh
-from torch_xla.experimental.assume_pure import assume_pure
 from torch_xla.experimental.scan import scan
 from torch_xla.experimental.scan_layers import scan_layers
 from torch_xla.distributed.spmd import mark_sharding, mark_sharding_with_gradients, set_global_mesh, get_1d_mesh, get_global_mesh
@@ -231,27 +230,6 @@ def check_dots_in_model(self, model, x, expect_pattern):
   def count_regex(self, hlo_text, regex_str):
     return len(re.findall(regex_str, hlo_text))

-  @unittest.skipIf(
-      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
-      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
-  )
-  def test_assume_pure_works_with_mark_sharding(self):
-    x = torch.randn((3, 4, 5, 128), device='xla')
-    assume_pure(mark_sharding)(x, self.spmd_mesh, ("model", None, None, None))
-    # assert not throwing
-
-  @unittest.skipIf(
-      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
-      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
-  )
-  def test_convert_to_jax_mesh(self):
-    jax_mesh = self.spmd_mesh.maybe_convert_and_get_jax_mesh()
-    self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)
-    np.testing.assert_equal(
-        np.array([dev.id for dev in jax_mesh.devices.flatten()]),
-        self.spmd_mesh.device_ids)
-    # assert not throwing
-

 if __name__ == '__main__':
   test = unittest.main()

test/test_assume_pure_spmd.py (+81, new file)

@@ -0,0 +1,81 @@
+import os
+import sys
+import unittest
+
+import numpy as np
+import torch
+import torch_xla
+import torch_xla.runtime as xr
+from torch_xla.experimental.assume_pure import assume_pure
+from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh, Mesh
+
+
+class AssumePureSpmdTest(unittest.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    # Activate SPMD
+    xr.use_spmd()
+
+  def setUp(self):
+    # Set up a simple SPMD mesh for these tests.
+    self.spmd_mesh = get_1d_mesh(axis_name="model")
+    set_global_mesh(self.spmd_mesh)
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required")
+  @unittest.skipIf(
+      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
+      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
+  )
+  def test_assume_pure_works_with_mark_sharding(self):
+    x = torch.randn((8, 4, 5, 128), device='xla')
+    result = assume_pure(mark_sharding)(x, self.spmd_mesh,
+                                        ("model", None, None, None))
+    torch_xla.sync(wait=True)
+    N = xr.global_runtime_device_count()
+    self.assertIn(f'devices=[{N}',
+                  torch_xla._XLAC._get_xla_sharding_spec(result))
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required")
+  @unittest.skipIf(
+      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
+      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
+  )
+  def test_convert_to_jax_mesh(self):
+    jax_mesh = self.spmd_mesh.get_jax_mesh()
+    self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)
+    np.testing.assert_equal(
+        np.array([dev.id for dev in jax_mesh.devices.flatten()]),
+        self.spmd_mesh.device_ids)
+
+  @unittest.skipUnless(xr.global_runtime_device_count() > 1,
+                       "Multiple devices required")
+  @unittest.skipUnless(os.environ.get('PJRT_DEVICE') == 'TPU', "TPU only test")
+  def test_convert_to_jax_mesh_shuffled(self):
+    """Test get_jax_mesh when the PyTorch/XLA mesh has a custom order."""
+
+    # Arrange
+    num_devices = xr.global_runtime_device_count()
+    device_ids = np.arange(num_devices)
+    device_ids = np.random.permutation(device_ids)
+    self.spmd_mesh = Mesh(
+        device_ids, mesh_shape=(num_devices,), axis_names=('model',))
+
+    # Act
+    jax_mesh = self.spmd_mesh.get_jax_mesh()
+
+    # Assert
+    torch_xla_devices = np.array(
+        [xr.global_runtime_device_attributes()[i] for i in device_ids])
+    self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)
+    np.testing.assert_equal(
+        np.array([dev.coords for dev in jax_mesh.devices.flatten()]),
+        np.array([dev['coords'] for dev in torch_xla_devices.flatten()]),
+    )
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)

test/tpu/run_tests.sh (+1)

@@ -37,6 +37,7 @@ python3 "$TEST_CDIR/scan/test_scan_pallas.py"
 python3 "$TEST_CDIR/scan/test_scan_layers.py"
 python3 "$TEST_CDIR/test_gru.py"
 python3 "$TEST_CDIR/test_assume_pure.py"
+python3 "$TEST_CDIR/test_assume_pure_spmd.py"
 python3 "$TEST_CDIR/test_as_stride_use_slice.py"
 run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
 python3 "$TEST_CDIR/test_pallas.py" -v

torch_xla/distributed/spmd/xla_sharding.py (+2 -2)

@@ -182,7 +182,7 @@ def from_str(cls, mesh_str: str) -> Optional["Mesh"]:
     return None

   @requires_jax
-  def maybe_convert_and_get_jax_mesh(self):
+  def get_jax_mesh(self):
    # Construct a JAX mesh object with the same device ids shape and ordering
    # from torch_xla device mesh.
    import jax
@@ -611,7 +611,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
   if tx is not None and isinstance(t, tx.tensor.Tensor):
     from jax.sharding import PartitionSpec as P, NamedSharding
     op_sharding = tuple(str(i) if i is not None else i for i in partition_spec)
-    jmesh = mesh.maybe_convert_and_get_jax_mesh()
+    jmesh = mesh.get_jax_mesh()
     t.shard_(NamedSharding(jmesh, P(*op_sharding)))
     return t
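For code that called the old name, a minimal migration sketch, grounded in the tests added in this commit; it assumes JAX is installed and uses a single-axis mesh for illustration:

```python
import numpy as np
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

# Build a 1D PyTorch/XLA mesh over all addressable devices.
num_devices = xr.global_runtime_device_count()
mesh = Mesh(
    np.arange(num_devices), mesh_shape=(num_devices,), axis_names=('model',))

# Renamed in this commit: `maybe_convert_and_get_jax_mesh()` is now `get_jax_mesh()`.
jax_mesh = mesh.get_jax_mesh()
assert jax_mesh.devices.shape == mesh.mesh_shape
```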

torch_xla/experimental/assume_pure.py (+6 -3)

@@ -15,9 +15,12 @@ def assume_pure(fn):

   Limitations:
   - The decorated function can only use upstream PyTorch operators e.g.
-    `torch.einsum`, `torch.nn.functional.layer_norm`. Custom PyTorch/XLA
-    operations such as `mark_sharding` are not supported. This limitation
-    may be lifted in the future.
+    `torch.einsum`, `torch.nn.functional.layer_norm`, and a few PyTorch/XLA operators:
+    * `torch_xla.experimental.assume_pure` (recursive `assume_pure`)
+    * `torch_xla.distributed.spmd.mark_sharding`
+
+  - Other custom PyTorch/XLA operations such as `flash_attention` are not
+    supported. This limitation may be lifted in the future.
   """
   from torchax.interop import jax_view
   return j2t_autograd(jax_view(fn))
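A minimal sketch of the recursive `assume_pure` case the updated docstring lists as supported; the function names, shapes, and nesting depth are illustrative assumptions, and only upstream PyTorch ops are used inside the wrapped functions:

```python
import torch
import torch_xla
from torch_xla.experimental.assume_pure import assume_pure


@assume_pure
def activation(x):
  # Only upstream PyTorch ops inside.
  return torch.nn.functional.gelu(x)


@assume_pure
def block(x, w):
  # Calling another @assume_pure function from within an @assume_pure
  # function is the "recursive assume_pure" case named in the docstring.
  y = torch.einsum('bi,ij->bj', x, w)
  return activation(y)


x = torch.randn(4, 8, device='xla')
w = torch.randn(8, 16, device='xla')
out = block(x, w)
torch_xla.sync()
```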

torch_xla/experimental/splash_attention.py (+1 -1)

@@ -86,7 +86,7 @@ def splash_attention_jax_wrapper(
       splash_attention_kernel,
       splash_attention_mask,
   )
-  mesh = Mesh.from_str(config.mesh).maybe_convert_and_get_jax_mesh()
+  mesh = Mesh.from_str(config.mesh).get_jax_mesh()
   # input q,k,v shape: [batch, #head, seq_len, head_dim]
   if decoder_segment_ids is not None and not decoder_segment_ids.shape:
     decoder_segment_ids = None
