Improve assume_pure docs and tests #9038

Merged: 4 commits merged on Apr 26, 2025
9 changes: 7 additions & 2 deletions docs/source/perf/assume_pure.md
@@ -122,8 +122,13 @@ a fixed up-front cost, and then later runs will reuse the cached XLA computation
## Limitations

Currently, all operations in a function wrapped with `@assume_pure` must be
PyTorch upstream operations (e.g. `torch.einsum`, `torch.sin`, ...). More
PyTorch/XLA operations (e.g. `mark_sharding`) will be supported in the future.
PyTorch upstream operations (e.g. `torch.einsum`, `torch.sin`, ...), or these
PyTorch/XLA operations:
* `torch_xla.experimental.assume_pure` (recursive `assume_pure`)
* `torch_xla.distributed.spmd.mark_sharding`

More PyTorch/XLA operations (e.g. `flash_attention`) will be supported in the
future.
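
For example, `mark_sharding` calls can now run under `assume_pure`. A minimal sketch, assuming SPMD mode is enabled and more than one device is available (it mirrors the new `test_assume_pure_spmd.py` test):

```python
import torch
import torch_xla
import torch_xla.runtime as xr
from torch_xla.experimental.assume_pure import assume_pure
from torch_xla.distributed.spmd import mark_sharding, get_1d_mesh, set_global_mesh

# Enable SPMD and set up a simple 1-D mesh over all devices.
xr.use_spmd()
mesh = get_1d_mesh(axis_name="model")
set_global_mesh(mesh)

# Annotate sharding along the "model" axis from inside a pure function.
x = torch.randn((8, 4, 5, 128), device='xla')
sharded = assume_pure(mark_sharding)(x, mesh, ("model", None, None, None))
torch_xla.sync(wait=True)
```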

<!-- xrefs -->

2 changes: 1 addition & 1 deletion infra/ansible/config/pip.yaml
@@ -28,7 +28,7 @@ pip:
- tqdm
- typing_extensions
- sympy
- yapf==0.30.0
- yapf==0.40.2

build_amd64:
- mkl
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -212,6 +212,7 @@ function run_xla_op_tests2 {
XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py"
run_test "$CDIR/test_jax_interop.py"
run_test "$CDIR/test_assume_pure.py"
run_test "$CDIR/test_assume_pure_spmd.py"
}

# All the new xla op tests should go to run_xla_op_tests3
22 changes: 0 additions & 22 deletions test/scan/test_scan_spmd.py
@@ -8,7 +8,6 @@
import torch_xla
import torch.nn as nn
from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear, Mesh
from torch_xla.experimental.assume_pure import assume_pure
from torch_xla.experimental.scan import scan
from torch_xla.experimental.scan_layers import scan_layers
from torch_xla.distributed.spmd import mark_sharding, mark_sharding_with_gradients, set_global_mesh, get_1d_mesh, get_global_mesh
@@ -231,27 +230,6 @@ def check_dots_in_model(self, model, x, expect_pattern):
def count_regex(self, hlo_text, regex_str):
return len(re.findall(regex_str, hlo_text))

@unittest.skipIf(
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
)
def test_assume_pure_works_with_mark_sharding(self):
x = torch.randn((3, 4, 5, 128), device='xla')
assume_pure(mark_sharding)(x, self.spmd_mesh, ("model", None, None, None))
# assert not throwing

@unittest.skipIf(
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
)
def test_convert_to_jax_mesh(self):
jax_mesh = self.spmd_mesh.maybe_convert_and_get_jax_mesh()
self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)
np.testing.assert_equal(
np.array([dev.id for dev in jax_mesh.devices.flatten()]),
self.spmd_mesh.device_ids)
# assert not throwing


if __name__ == '__main__':
test = unittest.main()
81 changes: 81 additions & 0 deletions test/test_assume_pure_spmd.py
@@ -0,0 +1,81 @@
import os
import sys
import unittest

import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
from torch_xla.experimental.assume_pure import assume_pure
from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh, Mesh


class AssumePureSpmdTest(unittest.TestCase):

@classmethod
def setUpClass(cls):
# Activate SPMD
xr.use_spmd()

def setUp(self):
# Set up a simple SPMD mesh for these tests.
self.spmd_mesh = get_1d_mesh(axis_name="model")
set_global_mesh(self.spmd_mesh)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required")
@unittest.skipIf(
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
)
def test_assume_pure_works_with_mark_sharding(self):
x = torch.randn((8, 4, 5, 128), device='xla')
result = assume_pure(mark_sharding)(x, self.spmd_mesh,
("model", None, None, None))
torch_xla.sync(wait=True)
N = xr.global_runtime_device_count()
self.assertIn(f'devices=[{N}',
torch_xla._XLAC._get_xla_sharding_spec(result))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required")
@unittest.skipIf(
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
)
def test_convert_to_jax_mesh(self):
jax_mesh = self.spmd_mesh.get_jax_mesh()
self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)
np.testing.assert_equal(
np.array([dev.id for dev in jax_mesh.devices.flatten()]),
self.spmd_mesh.device_ids)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required")
@unittest.skipUnless(os.environ.get('PJRT_DEVICE') == 'TPU', "TPU only test")
def test_convert_to_jax_mesh_shuffled(self):
"""Test get_jax_mesh when the PyTorch/XLA mesh has a custom order."""

# Arrange
num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)
device_ids = np.random.permutation(device_ids)
self.spmd_mesh = Mesh(
device_ids, mesh_shape=(num_devices,), axis_names=('model',))

# Act
jax_mesh = self.spmd_mesh.get_jax_mesh()

# Assert
torch_xla_devices = np.array(
[xr.global_runtime_device_attributes()[i] for i in device_ids])
self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)
np.testing.assert_equal(
np.array([dev.coords for dev in jax_mesh.devices.flatten()]),
np.array([dev['coords'] for dev in torch_xla_devices.flatten()]),
)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -37,6 +37,7 @@ python3 "$TEST_CDIR/scan/test_scan_pallas.py"
python3 "$TEST_CDIR/scan/test_scan_layers.py"
python3 "$TEST_CDIR/test_gru.py"
python3 "$TEST_CDIR/test_assume_pure.py"
python3 "$TEST_CDIR/test_assume_pure_spmd.py"
python3 "$TEST_CDIR/test_as_stride_use_slice.py"
run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
python3 "$TEST_CDIR/test_pallas.py" -v
4 changes: 2 additions & 2 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -182,7 +182,7 @@ def from_str(cls, mesh_str: str) -> Optional["Mesh"]:
return None

@requires_jax
def maybe_convert_and_get_jax_mesh(self):
def get_jax_mesh(self):
# Construct a JAX mesh object with the same device ids shape and ordering
# from torch_xla device mesh.
import jax
@@ -611,7 +611,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
if tx is not None and isinstance(t, tx.tensor.Tensor):
from jax.sharding import PartitionSpec as P, NamedSharding
op_sharding = tuple(str(i) if i is not None else i for i in partition_spec)
jmesh = mesh.maybe_convert_and_get_jax_mesh()
jmesh = mesh.get_jax_mesh()
t.shard_(NamedSharding(jmesh, P(*op_sharding)))
return t

9 changes: 6 additions & 3 deletions torch_xla/experimental/assume_pure.py
@@ -15,9 +15,12 @@ def assume_pure(fn):

Limitations:
- The decorated function can only use upstream PyTorch operators e.g.
`torch.einsum`, `torch.nn.functional.layer_norm`. Custom PyTorch/XLA
operations such as `mark_sharding` are not supported. This limitation
may be lifted in the future.
`torch.einsum`, `torch.nn.functional.layer_norm`, and a few PyTorch/XLA operators:
* `torch_xla.experimental.assume_pure` (recursive `assume_pure`)
* `torch_xla.distributed.spmd.mark_sharding`

- Other custom PyTorch/XLA operations such as `flash_attention` are not
supported. This limitation may be lifted in the future.
"""
from torchax.interop import jax_view
return j2t_autograd(jax_view(fn))
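
A minimal sketch of the newly documented combination, assuming a global SPMD mesh with a `"model"` axis has already been set (the layer and shapes are illustrative, not part of this change):

```python
import torch
from torch_xla.experimental.assume_pure import assume_pure
from torch_xla.distributed.spmd import mark_sharding, get_global_mesh

mesh = get_global_mesh()  # assumes set_global_mesh() was called earlier

@assume_pure
def pure_layer(x):
  # Recursive `assume_pure` and `mark_sharding` are both supported here.
  y = assume_pure(torch.sin)(x)
  return mark_sharding(y, mesh, ("model", None))

out = pure_layer(torch.randn(8, 128, device='xla'))
```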
2 changes: 1 addition & 1 deletion torch_xla/experimental/splash_attention.py
@@ -86,7 +86,7 @@ def splash_attention_jax_wrapper(
splash_attention_kernel,
splash_attention_mask,
)
mesh = Mesh.from_str(config.mesh).maybe_convert_and_get_jax_mesh()
mesh = Mesh.from_str(config.mesh).get_jax_mesh()
# input q,k,v shape: [batch, #head, seq_len, head_dim]
if decoder_segment_ids is not None and not decoder_segment_ids.shape:
decoder_segment_ids = None