[DRAFT] call_jax sets the JAX contextual mesh #9043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft · wants to merge 3 commits into base: master

9 changes: 7 additions & 2 deletions docs/source/perf/assume_pure.md
@@ -122,8 +122,13 @@ a fixed up-front cost, and then later runs will reuse the cached XLA computation
## Limitations

Currently, all operations in a function wrapped with `@assume_pure` must be
PyTorch upstream operations (e.g. `torch.einsum`, `torch.sin`, ...). More
PyTorch/XLA operations (e.g. `mark_sharding`) will be supported in the future.
PyTorch upstream operations (e.g. `torch.einsum`, `torch.sin`, ...), or these
PyTorch/XLA operations:
* `torch_xla.experimental.assume_pure` (recursive `assume_pure`)
* `torch_xla.distributed.spmd.mark_sharding`

More PyTorch/XLA operations (e.g. `flash_attention`) will be supported in the
future.

<!-- xrefs -->

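For context, a minimal usage sketch of the documented capability, condensed from `test/test_assume_pure_spmd.py` in this PR. It assumes SPMD is enabled and more than one device is available; shapes are illustrative.

```python
import torch
import torch_xla
import torch_xla.runtime as xr
from torch_xla.experimental.assume_pure import assume_pure
from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh

xr.use_spmd()                          # enable SPMD mode
mesh = get_1d_mesh(axis_name="model")  # 1-D mesh spanning all devices
set_global_mesh(mesh)

x = torch.randn((8, 4, 5, 128), device='xla')
# mark_sharding may now appear inside an assume_pure region.
y = assume_pure(mark_sharding)(x, mesh, ("model", None, None, None))
torch_xla.sync()
```
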
2 changes: 1 addition & 1 deletion infra/ansible/config/pip.yaml
@@ -28,7 +28,7 @@ pip:
- tqdm
- typing_extensions
- sympy
- yapf==0.30.0
- yapf==0.40.2

build_amd64:
- mkl
2 changes: 2 additions & 0 deletions test/run_tests.sh
@@ -211,7 +211,9 @@ function run_xla_op_tests2 {
run_test "$CDIR/test_callback.py"
XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py"
run_test "$CDIR/test_jax_interop.py"
run_test "$CDIR/test_jax_interop_spmd.py"
run_test "$CDIR/test_assume_pure.py"
run_test "$CDIR/test_assume_pure_spmd.py"
}

# All the new xla op tests should go to run_xla_op_tests3
22 changes: 0 additions & 22 deletions test/scan/test_scan_spmd.py
@@ -8,7 +8,6 @@
import torch_xla
import torch.nn as nn
from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear, Mesh
from torch_xla.experimental.assume_pure import assume_pure
from torch_xla.experimental.scan import scan
from torch_xla.experimental.scan_layers import scan_layers
from torch_xla.distributed.spmd import mark_sharding, mark_sharding_with_gradients, set_global_mesh, get_1d_mesh, get_global_mesh
@@ -231,27 +230,6 @@ def check_dots_in_model(self, model, x, expect_pattern):
def count_regex(self, hlo_text, regex_str):
return len(re.findall(regex_str, hlo_text))

@unittest.skipIf(
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
)
def test_assume_pure_works_with_mark_sharding(self):
x = torch.randn((3, 4, 5, 128), device='xla')
assume_pure(mark_sharding)(x, self.spmd_mesh, ("model", None, None, None))
# assert not throwing

@unittest.skipIf(
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
)
def test_convert_to_jax_mesh(self):
jax_mesh = self.spmd_mesh.maybe_convert_and_get_jax_mesh()
self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)
np.testing.assert_equal(
np.array([dev.id for dev in jax_mesh.devices.flatten()]),
self.spmd_mesh.device_ids)
# assert not throwing


if __name__ == '__main__':
test = unittest.main()
81 changes: 81 additions & 0 deletions test/test_assume_pure_spmd.py
@@ -0,0 +1,81 @@
import os
import sys
import unittest

import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
from torch_xla.experimental.assume_pure import assume_pure
from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh, Mesh


class AssumePureSpmdTest(unittest.TestCase):

@classmethod
def setUpClass(cls):
# Activate SPMD
xr.use_spmd()

def setUp(self):
# Set up a simple SPMD mesh for these tests.
self.spmd_mesh = get_1d_mesh(axis_name="model")
set_global_mesh(self.spmd_mesh)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required")
@unittest.skipIf(
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
)
def test_assume_pure_works_with_mark_sharding(self):
x = torch.randn((8, 4, 5, 128), device='xla')
result = assume_pure(mark_sharding)(x, self.spmd_mesh,
("model", None, None, None))
torch_xla.sync(wait=True)
N = xr.global_runtime_device_count()
self.assertIn(f'devices=[{N}',
torch_xla._XLAC._get_xla_sharding_spec(result))

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required")
@unittest.skipIf(
torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
"TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
)
def test_convert_to_jax_mesh(self):
jax_mesh = self.spmd_mesh.get_jax_mesh()
self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)
np.testing.assert_equal(
np.array([dev.id for dev in jax_mesh.devices.flatten()]),
self.spmd_mesh.device_ids)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required")
@unittest.skipUnless(os.environ.get('PJRT_DEVICE') == 'TPU', "TPU only test")
def test_convert_to_jax_mesh_shuffled(self):
"""Test get_jax_mesh when the PyTorch/XLA mesh has a custom order."""

# Arrange
num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)
device_ids = np.random.permutation(device_ids)
self.spmd_mesh = Mesh(
device_ids, mesh_shape=(num_devices,), axis_names=('model',))

# Act
jax_mesh = self.spmd_mesh.get_jax_mesh()

# Assert
torch_xla_devices = np.array(
[xr.global_runtime_device_attributes()[i] for i in device_ids])
self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)
np.testing.assert_equal(
np.array([dev.coords for dev in jax_mesh.devices.flatten()]),
np.array([dev['coords'] for dev in torch_xla_devices.flatten()]),
)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
52 changes: 52 additions & 0 deletions test/test_jax_interop_spmd.py
@@ -0,0 +1,52 @@
import sys
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import set_global_mesh, get_1d_mesh


class TestJaxInteropSpmd(unittest.TestCase):

@classmethod
def setUpClass(cls):
# Activate SPMD
xr.use_spmd()

def setUp(self):
# Clear cached HLO between test cases.
xb._JAX_TO_XLA_COMPUTATION_CACHE.clear()
# Set up a simple SPMD mesh for these tests.
self.spmd_mesh = get_1d_mesh(axis_name="model")
set_global_mesh(self.spmd_mesh)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
"Multiple devices required")
def test_call_jax_sharding_constraints(self):
"""Test that we can call jax.lax.with_sharding_constraints from PyTorch/XLA."""

# Arrange
a = torch.ones((8, 8), device='xla')

def f(a, b):
import jax
from jax.sharding import PartitionSpec as P
import jax.numpy as jnp
return jax.lax.with_sharding_constraint(a, P("model",)) + jnp.sin(b)

# Act
result = xb.call_jax(f, (a, a))
torch_xla.sync(wait=True)

# Assert
N = xr.global_runtime_device_count()
self.assertIn(f'devices=[{N}',
torch_xla._XLAC._get_xla_sharding_spec(result))


if __name__ == "__main__":
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
2 changes: 2 additions & 0 deletions test/tpu/run_tests.sh
@@ -37,6 +37,8 @@ python3 "$TEST_CDIR/scan/test_scan_pallas.py"
python3 "$TEST_CDIR/scan/test_scan_layers.py"
python3 "$TEST_CDIR/test_gru.py"
python3 "$TEST_CDIR/test_assume_pure.py"
python3 "$TEST_CDIR/test_assume_pure_spmd.py"
python3 "$TEST_CDIR/test_jax_interop_spmd.py"
python3 "$TEST_CDIR/test_as_stride_use_slice.py"
run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
python3 "$TEST_CDIR/test_pallas.py" -v
15 changes: 13 additions & 2 deletions torch_xla/core/xla_builder.py
@@ -920,8 +920,19 @@ def get_xla_computation():
import torch_xla.debug.profiler as xp
# If we see this trace span in the profiler, we'll know that there's a cache miss.
with xp.Trace('jax_to_xla_computation'):
lowered = jax.jit(fn, keep_unused=True).lower(*sample_tensor_args)
hlo_ir = lowered.compiler_ir('hlo')
jitted = jax.jit(fn, keep_unused=True)

def do_lower():
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
if xr.is_spmd():
mesh = xs.get_global_mesh()
if mesh is not None:
with mesh.get_jax_mesh():
return jitted.lower(*sample_tensor_args)
return jitted.lower(*sample_tensor_args)

hlo_ir = do_lower().compiler_ir('hlo')
assert len(traced_out_spec) == 1, \
"fn must be traced to obtain the output tree spec"
spec = traced_out_spec[0]
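A rough usage sketch of what this change enables, condensed from `test/test_jax_interop_spmd.py` in this PR: once a global SPMD mesh is set, `call_jax` lowers the JAX function under that mesh, so axis names in sharding constraints resolve during lowering. Assumes SPMD mode and multiple devices.

```python
import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import set_global_mesh, get_1d_mesh

xr.use_spmd()
set_global_mesh(get_1d_mesh(axis_name="model"))


def f(a, b):
  import jax
  import jax.numpy as jnp
  from jax.sharding import PartitionSpec as P
  # "model" names an axis of the PyTorch/XLA global mesh, which call_jax now
  # installs as the JAX contextual mesh while lowering this function.
  return jax.lax.with_sharding_constraint(a, P("model")) + jnp.sin(b)


a = torch.ones((8, 8), device='xla')
result = xb.call_jax(f, (a, a))
torch_xla.sync()
```
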
4 changes: 2 additions & 2 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -182,7 +182,7 @@ def from_str(cls, mesh_str: str) -> Optional["Mesh"]:
return None

@requires_jax
def maybe_convert_and_get_jax_mesh(self):
def get_jax_mesh(self):
# Construct a JAX mesh object with the same device ids shape and ordering
# from torch_xla device mesh.
import jax
@@ -611,7 +611,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
if tx is not None and isinstance(t, tx.tensor.Tensor):
from jax.sharding import PartitionSpec as P, NamedSharding
op_sharding = tuple(str(i) if i is not None else i for i in partition_spec)
jmesh = mesh.maybe_convert_and_get_jax_mesh()
jmesh = mesh.get_jax_mesh()
t.shard_(NamedSharding(jmesh, P(*op_sharding)))
return t

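A small illustration of the renamed helper. This is a sketch under the assumption of a single-axis mesh over all devices; `get_jax_mesh()` returns a `jax.sharding.Mesh`, which can be used as a context manager the way `do_lower()` in `xla_builder.py` uses it.

```python
import numpy as np
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

num_devices = xr.global_runtime_device_count()
mesh = Mesh(
    np.arange(num_devices), mesh_shape=(num_devices,), axis_names=('model',))

jax_mesh = mesh.get_jax_mesh()  # jax.sharding.Mesh with matching device order
with jax_mesh:
  # JAX tracing/lowering inside this block sees jax_mesh as the contextual mesh.
  pass
```
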
9 changes: 6 additions & 3 deletions torch_xla/experimental/assume_pure.py
@@ -15,9 +15,12 @@ def assume_pure(fn):

Limitations:
- The decorated function can only use upstream PyTorch operators e.g.
`torch.einsum`, `torch.nn.functional.layer_norm`. Custom PyTorch/XLA
operations such as `mark_sharding` are not supported. This limitation
may be lifted in the future.
`torch.einsum`, `torch.nn.functional.layer_norm`, and a few PyTorch/XLA operators:
* `torch_xla.experimental.assume_pure` (recursive `assume_pure`)
* `torch_xla.distributed.spmd.mark_sharding`

- Other custom PyTorch/XLA operations such as `flash_attention` are not
supported. This limitation may be lifted in the future.
"""
from torchax.interop import jax_view
return j2t_autograd(jax_view(fn))
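For the "recursive `assume_pure`" case the docstring now lists, a minimal sketch; function names and shapes are hypothetical and only illustrate the nesting pattern.

```python
import torch
from torch_xla.experimental.assume_pure import assume_pure


@assume_pure
def pure_norm(x):
  return torch.nn.functional.layer_norm(x, x.shape[-1:])


@assume_pure
def pure_block(x):
  # Calling one assume_pure function from inside another is the recursive
  # assume_pure case listed in the docstring.
  return pure_norm(torch.sin(x)) + 1


y = pure_block(torch.randn(4, 8, device='xla'))
```
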
2 changes: 1 addition & 1 deletion torch_xla/experimental/splash_attention.py
@@ -86,7 +86,7 @@ def splash_attention_jax_wrapper(
splash_attention_kernel,
splash_attention_mask,
)
mesh = Mesh.from_str(config.mesh).maybe_convert_and_get_jax_mesh()
mesh = Mesh.from_str(config.mesh).get_jax_mesh()
# input q,k,v shape: [batch, #head, seq_len, head_dim]
if decoder_segment_ids is not None and not decoder_segment_ids.shape:
decoder_segment_ids = None