From b8d8be50f1b3f965b467177bd99258b7abac3247 Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Thu, 24 Apr 2025 18:45:35 -0700 Subject: [PATCH 1/3] [call_jax] Bridge the torch_xla and JAX mesh Now we can run a JAX SPMD function that accesses the ambient SPMD mesh from xb.call_jax. Fixes https://github.com/pytorch/xla/issues/8972. Also I beefed up the assume_pure tests and updated the docs to mention that mark_sharding is supported thanks to qihqi@' https://github.com/pytorch/xla/pull/8989. --- benchmarks/benchmark_model.py | 5 +- docs/source/perf/assume_pure.md | 9 ++- test/pytorch_test_base.py | 4 +- test/run_tests.sh | 2 + test/scan/test_scan_spmd.py | 22 ------ test/spmd/test_xla_sharding.py | 27 +++---- test/test_assume_pure_spmd.py | 79 +++++++++++++++++++ test/test_jax_interop_spmd.py | 49 ++++++++++++ test/test_operations.py | 8 +- test/test_pallas.py | 8 +- test/test_pallas_spmd.py | 8 +- test/test_splash_attention.py | 8 +- test/tpu/run_tests.sh | 2 + torch_xla/core/xla_builder.py | 15 +++- torch_xla/distributed/spmd/xla_sharding.py | 4 +- torch_xla/distributed/xla_multiprocessing.py | 5 +- torch_xla/experimental/assume_pure.py | 9 ++- .../experimental/gradient_accumulation.py | 4 +- .../ragged_paged_attention_kernel.py | 12 +-- .../ragged_paged_attention_v2.py | 4 +- torch_xla/experimental/splash_attention.py | 2 +- 21 files changed, 206 insertions(+), 80 deletions(-) create mode 100644 test/test_assume_pure_spmd.py create mode 100644 test/test_jax_interop_spmd.py diff --git a/benchmarks/benchmark_model.py b/benchmarks/benchmark_model.py index 2b2f6c1957b7..29951c02a92d 100644 --- a/benchmarks/benchmark_model.py +++ b/benchmarks/benchmark_model.py @@ -227,8 +227,9 @@ def is_compatible(self, dummy_benchmark_model: BenchmarkModel, def get_benchmark_indices(self, length: int): start = self._args.partition_id * (length // self._args.total_partitions) end = ((self._args.partition_id + 1) * - (length // self._args.total_partitions) if self._args.partition_id - < self._args.total_partitions - 1 else length) + (length // self._args.total_partitions) + if self._args.partition_id < self._args.total_partitions - 1 else + length) return start, end def skip_model(self, model_name: str): diff --git a/docs/source/perf/assume_pure.md b/docs/source/perf/assume_pure.md index 912b83ea8115..2e6cba8d9c51 100644 --- a/docs/source/perf/assume_pure.md +++ b/docs/source/perf/assume_pure.md @@ -122,8 +122,13 @@ a fixed up-front cost, and then later runs will reuse the cached XLA computation ## Limitations Currently, all operations in a function wrapped with `@assume_pure` must be -PyTorch upstream operations (e.g. `torch.einsum`, `torch.sin`, ...). More -PyTorch/XLA operations (e.g. `mark_sharding`) will be supported in the future. +PyTorch upstream operations (e.g. `torch.einsum`, `torch.sin`, ...), or these +PyTorch/XLA operations: + * `torch_xla.experimental.assume_pure` (recursive `assume_pure`) + * `torch_xla.distributed.spmd.mark_sharding` + +More PyTorch/XLA operations (e.g. `flash_attention`) will be supported in the +future. 
diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py index b47ae3f3de6d..bcf3cdb32779 100644 --- a/test/pytorch_test_base.py +++ b/test/pytorch_test_base.py @@ -619,8 +619,8 @@ def skipped_test(self, *args, reason=reason, **kwargs): setattr(cls, dtype_test_name, disallowed_test) if not skipped: xla_dtypes.append( - dtype_combination if len(dtype_combination) > - 1 else dtype_combination[0]) + dtype_combination + if len(dtype_combination) > 1 else dtype_combination[0]) if len(xla_dtypes) != 0: test.dtypes[cls.device_type] = xla_dtypes super().instantiate_test(name, test, generic_cls=generic_cls) diff --git a/test/run_tests.sh b/test/run_tests.sh index 0e096148ae62..c90468981141 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -211,7 +211,9 @@ function run_xla_op_tests2 { run_test "$CDIR/test_callback.py" XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py" run_test "$CDIR/test_jax_interop.py" + run_test "$CDIR/test_jax_interop_spmd.py" run_test "$CDIR/test_assume_pure.py" + run_test "$CDIR/test_assume_pure_spmd.py" } # All the new xla op tests should go to run_xla_op_tests3 diff --git a/test/scan/test_scan_spmd.py b/test/scan/test_scan_spmd.py index a796bdddb0a2..9bf081527c72 100644 --- a/test/scan/test_scan_spmd.py +++ b/test/scan/test_scan_spmd.py @@ -8,7 +8,6 @@ import torch_xla import torch.nn as nn from torch_xla.distributed.spmd.xla_sharding import apply_xla_patch_to_nn_linear, Mesh -from torch_xla.experimental.assume_pure import assume_pure from torch_xla.experimental.scan import scan from torch_xla.experimental.scan_layers import scan_layers from torch_xla.distributed.spmd import mark_sharding, mark_sharding_with_gradients, set_global_mesh, get_1d_mesh, get_global_mesh @@ -231,27 +230,6 @@ def check_dots_in_model(self, model, x, expect_pattern): def count_regex(self, hlo_text, regex_str): return len(re.findall(regex_str, hlo_text)) - @unittest.skipIf( - torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA', - "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU" - ) - def test_assume_pure_works_with_mark_sharding(self): - x = torch.randn((3, 4, 5, 128), device='xla') - assume_pure(mark_sharding)(x, self.spmd_mesh, ("model", None, None, None)) - # assert not throwing - - @unittest.skipIf( - torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA', - "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU" - ) - def test_convert_to_jax_mesh(self): - jax_mesh = self.spmd_mesh.maybe_convert_and_get_jax_mesh() - self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape) - np.testing.assert_equal( - np.array([dev.id for dev in jax_mesh.devices.flatten()]), - self.spmd_mesh.device_ids) - # assert not throwing - if __name__ == '__main__': test = unittest.main() diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 5afbee0aedb2..28c78c635f8a 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -618,9 +618,9 @@ def test_inplace_add_with_sharding(self): # avoid calling xr.addressable_device_count here otherwise it will init the test # in non-spmd mode. 
- @unittest.skipIf( - xr.device_type() == 'CPU', - "sharding will be the same for both tensors on single device") + @unittest.skipIf(xr.device_type() == 'CPU', + "sharding will be the same for both tensors on single device" + ) def test_shard_hashing(self): xt1 = torch.ones(2, 2).to(xm.xla_device()) xt2 = torch.ones(2, 2).to(xm.xla_device()) @@ -1383,9 +1383,8 @@ def test_get_1d_mesh(self): self.assertEqual(mesh_without_name.mesh_shape, (xr.global_runtime_device_count(),)) - @unittest.skipUnless( - xr.global_runtime_device_count() > 1, - "Multiple devices required for dataloader sharding test") + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for dataloader sharding test") def test_data_loader_with_sharding(self): device = torch_xla.device() mesh = xs.get_1d_mesh("data") @@ -1406,9 +1405,8 @@ def test_data_loader_with_sharding(self): f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}" ) - @unittest.skipUnless( - xr.global_runtime_device_count() > 1, - "Multiple devices required for dataloader sharding test") + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for dataloader sharding test") def test_data_loader_with_non_batch_size(self): device = torch_xla.device() mesh = xs.get_1d_mesh("data") @@ -1429,9 +1427,8 @@ def test_data_loader_with_non_batch_size(self): f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}" ) - @unittest.skipUnless( - xr.global_runtime_device_count() > 1, - "Multiple devices required for dataloader sharding test") + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required for dataloader sharding test") def test_data_loader_with_non_batch_size_and_mini_batch(self): device = torch_xla.device() mesh = xs.get_1d_mesh("data") @@ -1663,9 +1660,9 @@ def test_get_logical_mesh(self): self.assertEqual(logical_mesh.shape, mesh_shape) np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids) - @unittest.skipIf( - xr.device_type() == 'CPU', - "sharding will be the same for both tensors on single device") + @unittest.skipIf(xr.device_type() == 'CPU', + "sharding will be the same for both tensors on single device" + ) def test_shard_as(self): mesh = self._get_mesh((self.n_devices,)) partition_spec = (0,) diff --git a/test/test_assume_pure_spmd.py b/test/test_assume_pure_spmd.py new file mode 100644 index 000000000000..4edf8cf45614 --- /dev/null +++ b/test/test_assume_pure_spmd.py @@ -0,0 +1,79 @@ +import os +import sys +import unittest + +import numpy as np +import torch +import torch_xla +import torch_xla.runtime as xr +from torch_xla.experimental.assume_pure import assume_pure +from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh, Mesh + + +class AssumePureSpmdTest(unittest.TestCase): + + def setUp(self): + # Activate SPMD + xr.use_spmd() + + # Set up a simple SPMD mesh for these tests. 
+ self.spmd_mesh = get_1d_mesh(axis_name="model") + set_global_mesh(self.spmd_mesh) + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required") + @unittest.skipIf( + torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA', + "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU" + ) + def test_assume_pure_works_with_mark_sharding(self): + x = torch.randn((8, 4, 5, 128), device='xla') + result = assume_pure(mark_sharding)(x, self.spmd_mesh, + ("model", None, None, None)) + torch_xla.sync(wait=True) + N = xr.global_runtime_device_count() + self.assertIn(f'devices=[{N}', + torch_xla._XLAC._get_xla_sharding_spec(result)) + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required") + @unittest.skipIf( + torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA', + "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU" + ) + def test_convert_to_jax_mesh(self): + jax_mesh = self.spmd_mesh.get_jax_mesh() + self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape) + np.testing.assert_equal( + np.array([dev.id for dev in jax_mesh.devices.flatten()]), + self.spmd_mesh.device_ids) + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required") + @unittest.skipUnless(os.environ.get('PJRT_DEVICE') == 'TPU', "TPU only test") + def test_convert_to_jax_mesh_shuffled(self): + """Test get_jax_mesh when the PyTorch/XLA mesh has a custom order.""" + + # Arrange + num_devices = xr.global_runtime_device_count() + device_ids = np.arange(num_devices) + device_ids = np.random.permutation(device_ids) + self.spmd_mesh = Mesh( + device_ids, mesh_shape=(num_devices,), axis_names=('model',)) + + # Act + jax_mesh = self.spmd_mesh.get_jax_mesh() + + # Assert + torch_xla_devices = np.array( + [xr.global_runtime_device_attributes()[i] for i in device_ids]) + self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape) + np.testing.assert_equal( + np.array([dev.coords for dev in jax_mesh.devices.flatten()]), + np.array([dev['coords'] for dev in torch_xla_devices.flatten()]), + ) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_jax_interop_spmd.py b/test/test_jax_interop_spmd.py new file mode 100644 index 000000000000..bf0a39d0eeac --- /dev/null +++ b/test/test_jax_interop_spmd.py @@ -0,0 +1,49 @@ +import sys +import unittest + +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.core.xla_builder as xb +import torch_xla.runtime as xr +from torch_xla.distributed.spmd import set_global_mesh, get_1d_mesh + + +class TestJaxInteropSpmd(unittest.TestCase): + + def setUp(self): + xb._JAX_TO_XLA_COMPUTATION_CACHE.clear() + # Activate SPMD + xr.use_spmd() + + # Set up a simple SPMD mesh for these tests. 
+ self.spmd_mesh = get_1d_mesh(axis_name="model") + set_global_mesh(self.spmd_mesh) + + @unittest.skipUnless(xr.global_runtime_device_count() > 1, + "Multiple devices required") + def test_call_jax_sharding_constraints(self): + """Test that we can call jax.lax.with_sharding_constraints from PyTorch/XLA.""" + + # Arrange + a = torch.ones((8, 8), device='xla') + + def f(a, b): + import jax + from jax.sharding import PartitionSpec as P + import jax.numpy as jnp + return jax.lax.with_sharding_constraint(a, P("model",)) + jnp.sin(b) + + # Act + result = xb.call_jax(f, (a, a)) + torch_xla.sync(wait=True) + + # Assert + N = xr.global_runtime_device_count() + self.assertIn(f'devices=[{N}', + torch_xla._XLAC._get_xla_sharding_spec(result)) + + +if __name__ == "__main__": + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/test_operations.py b/test/test_operations.py index aa6ac3e016df..33db55a3585a 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2959,9 +2959,11 @@ def test_dlpack_roundtrip_tensor(self, dtype): @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA - @parameterized.parameters( - *all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, - torch.uint16, torch.uint32, torch.uint64)) + @parameterized.parameters(*all_types_and_complex_and(torch.half, + torch.bfloat16, + torch.bool, torch.uint16, + torch.uint32, + torch.uint64)) def test_dlpack_roundtrip_scalar(self, dtype): xla_device = xm.xla_device() xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) diff --git a/test/test_pallas.py b/test/test_pallas.py index c1f9df9ba0c2..5340755506fc 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -41,10 +41,10 @@ class PallasTest(parameterized.TestCase): # therefore we use != instead of ==. def _make_attention_mask_from_segment_ids(self, q_segment_ids, kv_segment_ids): - return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1], - 1) != kv_segment_ids.view(kv_segment_ids.shape[0], - 1, 1, - kv_segment_ids.shape[1]) + return q_segment_ids.view(q_segment_ids.shape[0], 1, + q_segment_ids.shape[1], 1) != kv_segment_ids.view( + kv_segment_ids.shape[0], 1, 1, + kv_segment_ids.shape[1]) def _attention(self, q, k, v, *, attn_mask=None, ab=None): attn_weight = q @ k.transpose(-2, -1) diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py index 2e4de37fadc6..f723227217a9 100644 --- a/test/test_pallas_spmd.py +++ b/test/test_pallas_spmd.py @@ -41,10 +41,10 @@ class PallasTest(unittest.TestCase): # therefore we use != instead of ==. 
def _make_attention_mask_from_segment_ids(self, q_segment_ids, kv_segment_ids): - return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1], - 1) != kv_segment_ids.view(kv_segment_ids.shape[0], - 1, 1, - kv_segment_ids.shape[1]) + return q_segment_ids.view(q_segment_ids.shape[0], 1, + q_segment_ids.shape[1], 1) != kv_segment_ids.view( + kv_segment_ids.shape[0], 1, 1, + kv_segment_ids.shape[1]) def _attention(self, q, k, v, *, attn_mask=None, ab=None): attn_weight = q @ k.transpose(-2, -1) diff --git a/test/test_splash_attention.py b/test/test_splash_attention.py index 6e8bb56fe2c3..9fd6c1306ea9 100644 --- a/test/test_splash_attention.py +++ b/test/test_splash_attention.py @@ -62,10 +62,10 @@ def setUp(self): def _make_attention_mask_from_segment_ids(self, q_segment_ids, kv_segment_ids): - return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1], - 1) != kv_segment_ids.view(kv_segment_ids.shape[0], - 1, 1, - kv_segment_ids.shape[1]) + return q_segment_ids.view(q_segment_ids.shape[0], 1, + q_segment_ids.shape[1], 1) != kv_segment_ids.view( + kv_segment_ids.shape[0], 1, 1, + kv_segment_ids.shape[1]) def maybe_repeat_kv(self, hidden_state): if hidden_state.size(1) == self.NUM_Q_HEADS: diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 5cd1b4a45f9d..12d5e5222bab 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -37,6 +37,8 @@ python3 "$TEST_CDIR/scan/test_scan_pallas.py" python3 "$TEST_CDIR/scan/test_scan_layers.py" python3 "$TEST_CDIR/test_gru.py" python3 "$TEST_CDIR/test_assume_pure.py" +python3 "$TEST_CDIR/test_assume_pure_spmd.py" +python3 "$TEST_CDIR/test_jax_interop_spmd.py" python3 "$TEST_CDIR/test_as_stride_use_slice.py" run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py" python3 "$TEST_CDIR/test_pallas.py" -v diff --git a/torch_xla/core/xla_builder.py b/torch_xla/core/xla_builder.py index e4cf706be424..073fdc5ae5df 100644 --- a/torch_xla/core/xla_builder.py +++ b/torch_xla/core/xla_builder.py @@ -920,8 +920,19 @@ def get_xla_computation(): import torch_xla.debug.profiler as xp # If we see this trace span in the profiler, we'll know that there's a cache miss. with xp.Trace('jax_to_xla_computation'): - lowered = jax.jit(fn, keep_unused=True).lower(*sample_tensor_args) - hlo_ir = lowered.compiler_ir('hlo') + jitted = jax.jit(fn, keep_unused=True) + + def do_lower(): + import torch_xla.runtime as xr + import torch_xla.distributed.spmd as xs + if xr.is_spmd(): + mesh = xs.get_global_mesh() + if mesh is not None: + with mesh.get_jax_mesh(): + return jitted.lower(*sample_tensor_args) + return jitted.lower(*sample_tensor_args) + + hlo_ir = do_lower().compiler_ir('hlo') assert len(traced_out_spec) == 1, \ "fn must be traced to obtain the output tree spec" spec = traced_out_spec[0] diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 55fc2e3fa433..cb61158df903 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -182,7 +182,7 @@ def from_str(cls, mesh_str: str) -> Optional["Mesh"]: return None @requires_jax - def maybe_convert_and_get_jax_mesh(self): + def get_jax_mesh(self): # Construct a JAX mesh object with the same device ids shape and ordering # from torch_xla device mesh. 
import jax @@ -611,7 +611,7 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, if tx is not None and isinstance(t, tx.tensor.Tensor): from jax.sharding import PartitionSpec as P, NamedSharding op_sharding = tuple(str(i) if i is not None else i for i in partition_spec) - jmesh = mesh.maybe_convert_and_get_jax_mesh() + jmesh = mesh.get_jax_mesh() t.shard_(NamedSharding(jmesh, P(*op_sharding))) return t diff --git a/torch_xla/distributed/xla_multiprocessing.py b/torch_xla/distributed/xla_multiprocessing.py index d699abaebafb..36a2f514c06c 100644 --- a/torch_xla/distributed/xla_multiprocessing.py +++ b/torch_xla/distributed/xla_multiprocessing.py @@ -174,10 +174,7 @@ def _v6e_create_replica_groups() -> List | None: return None -device_kind_handler_dict: dict[ - str, - Callable[..., List | None], -] = { +device_kind_handler_dict: dict[str, Callable[..., List | None],] = { _TPU_V5P: _v5p_create_replica_groups, _TPU_V6E: _v6e_create_replica_groups } diff --git a/torch_xla/experimental/assume_pure.py b/torch_xla/experimental/assume_pure.py index c52ba997845e..6a6641bb4d1a 100644 --- a/torch_xla/experimental/assume_pure.py +++ b/torch_xla/experimental/assume_pure.py @@ -15,9 +15,12 @@ def assume_pure(fn): Limitations: - The decorated function can only use upstream PyTorch operators e.g. - `torch.einsum`, `torch.nn.functional.layer_norm`. Custom PyTorch/XLA - operations such as `mark_sharding` are not supported. This limitation - may be lifted in the future. + `torch.einsum`, `torch.nn.functional.layer_norm`, and a few PyTorch/XLA operators: + * `torch_xla.experimental.assume_pure` (recursive `assume_pure`) + * `torch_xla.distributed.spmd.mark_sharding` + + - Other custom PyTorch/XLA operations such as `flash_attention` are not + supported. This limitation may be lifted in the future. """ from torchax.interop import jax_view return j2t_autograd(jax_view(fn)) diff --git a/torch_xla/experimental/gradient_accumulation.py b/torch_xla/experimental/gradient_accumulation.py index 5299291861ff..ad299a393492 100644 --- a/torch_xla/experimental/gradient_accumulation.py +++ b/torch_xla/experimental/gradient_accumulation.py @@ -288,8 +288,8 @@ def add_to_mapping(val: torch.Tensor, iterable_tensors, fake_iterable_tensors, carried_tensors, fake_carried_tensors, params, grads) - def _body_fn_wrapper(curr_iter: xb.Op, curr_loss: xb.Op, *while_params: - xb.Op): + def _body_fn_wrapper(curr_iter: xb.Op, curr_loss: xb.Op, + *while_params: xb.Op): def dynamic_slice(xs: xb.Op, idx: xb.Op) -> xb.Op: indices = [idx] + [idx.zeros_like() for _ in range(xs.shape().rank - 1)] diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py index d4999cd8bb0b..fcb0a3e89aaa 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py @@ -243,8 +243,8 @@ def make_sequence_metadata( # # Remove tile visits that belong to a sequence not in our shard. 
iota = jnp.arange(num_sequences, dtype=jnp.int32) - active_sequence_mask = jnp.logical_and(iota <= end_sequence, iota - >= start_sequence) + active_sequence_mask = jnp.logical_and(iota <= end_sequence, + iota >= start_sequence) sequence_tiles = jnp.where(active_sequence_mask, sequence_tiles[:num_sequences], 0) num_tiles = sequence_tiles.sum() @@ -375,8 +375,8 @@ def _flash_attention( logical_q_blk_idx - 1, 0) is_first_processed_logical_q_blk = logical_q_blk_idx == 0 physical_q_blk_changed = ( - physical_q_tile_ids[logical_q_blk_idx] - != physical_q_tile_ids[prev_logical_q_blk_idx]) + physical_q_tile_ids[logical_q_blk_idx] != + physical_q_tile_ids[prev_logical_q_blk_idx]) first_time_seeing_physical_q_blk = jnp.logical_or( is_first_processed_logical_q_blk, physical_q_blk_changed) is_first_kv_blk = (kv_blk_idx == 0) @@ -509,8 +509,8 @@ def init_scratch_ref(): # pylint: disable=unused-variable logical_q_blk_idx + 1) is_last_logical_q_blk = (logical_q_blk_idx == num_logical_q_blks - 1) physical_q_blk_will_change = ( - physical_q_tile_ids[logical_q_blk_idx] - != physical_q_tile_ids[next_logical_q_blk_idx]) + physical_q_tile_ids[logical_q_blk_idx] != + physical_q_tile_ids[next_logical_q_blk_idx]) last_time_seeing_cur_physical_q_blk = jnp.logical_or( is_last_logical_q_blk, physical_q_blk_will_change) should_store_to_output = jnp.logical_and(is_last_kv_blk_idx, diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py index 48c28825418b..ee350d7fdacd 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py @@ -421,8 +421,8 @@ def init_scratch_ref(): ) causal_mask = row_ids < col_ids if sliding_window is not None: - causal_mask = jnp.logical_or(causal_mask, row_ids - sliding_window - >= col_ids) + causal_mask = jnp.logical_or(causal_mask, + row_ids - sliding_window >= col_ids) if soft_cap is not None: qk = soft_cap * jnp.tanh(qk / soft_cap) qk += jnp.where(causal_mask, mask_value, 0.0) diff --git a/torch_xla/experimental/splash_attention.py b/torch_xla/experimental/splash_attention.py index bdf05d03e5d1..d1d85220236a 100644 --- a/torch_xla/experimental/splash_attention.py +++ b/torch_xla/experimental/splash_attention.py @@ -86,7 +86,7 @@ def splash_attention_jax_wrapper( splash_attention_kernel, splash_attention_mask, ) - mesh = Mesh.from_str(config.mesh).maybe_convert_and_get_jax_mesh() + mesh = Mesh.from_str(config.mesh).get_jax_mesh() # input q,k,v shape: [batch, #head, seq_len, head_dim] if decoder_segment_ids is not None and not decoder_segment_ids.shape: decoder_segment_ids = None From de22e83beace8f68cebd960827dfcc1322c5fb10 Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Thu, 24 Apr 2025 18:52:06 -0700 Subject: [PATCH 2/3] update yapf --- benchmarks/benchmark_model.py | 5 ++-- infra/ansible/config/pip.yaml | 2 +- test/pytorch_test_base.py | 4 +-- test/spmd/test_xla_sharding.py | 27 ++++++++++--------- test/test_operations.py | 8 +++--- test/test_pallas.py | 8 +++--- test/test_pallas_spmd.py | 8 +++--- test/test_splash_attention.py | 8 +++--- torch_xla/distributed/xla_multiprocessing.py | 5 +++- .../experimental/gradient_accumulation.py | 4 +-- .../ragged_paged_attention_kernel.py | 12 ++++----- .../ragged_paged_attention_v2.py | 4 +-- 12 files changed, 49 insertions(+), 46 deletions(-) diff --git a/benchmarks/benchmark_model.py b/benchmarks/benchmark_model.py index 29951c02a92d..2b2f6c1957b7 100644 --- 
a/benchmarks/benchmark_model.py +++ b/benchmarks/benchmark_model.py @@ -227,9 +227,8 @@ def is_compatible(self, dummy_benchmark_model: BenchmarkModel, def get_benchmark_indices(self, length: int): start = self._args.partition_id * (length // self._args.total_partitions) end = ((self._args.partition_id + 1) * - (length // self._args.total_partitions) - if self._args.partition_id < self._args.total_partitions - 1 else - length) + (length // self._args.total_partitions) if self._args.partition_id + < self._args.total_partitions - 1 else length) return start, end def skip_model(self, model_name: str): diff --git a/infra/ansible/config/pip.yaml b/infra/ansible/config/pip.yaml index 33840cbb0c6b..695c4a917978 100644 --- a/infra/ansible/config/pip.yaml +++ b/infra/ansible/config/pip.yaml @@ -28,7 +28,7 @@ pip: - tqdm - typing_extensions - sympy - - yapf==0.30.0 + - yapf==0.40.2 build_amd64: - mkl diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py index bcf3cdb32779..b47ae3f3de6d 100644 --- a/test/pytorch_test_base.py +++ b/test/pytorch_test_base.py @@ -619,8 +619,8 @@ def skipped_test(self, *args, reason=reason, **kwargs): setattr(cls, dtype_test_name, disallowed_test) if not skipped: xla_dtypes.append( - dtype_combination - if len(dtype_combination) > 1 else dtype_combination[0]) + dtype_combination if len(dtype_combination) > + 1 else dtype_combination[0]) if len(xla_dtypes) != 0: test.dtypes[cls.device_type] = xla_dtypes super().instantiate_test(name, test, generic_cls=generic_cls) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 28c78c635f8a..5afbee0aedb2 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -618,9 +618,9 @@ def test_inplace_add_with_sharding(self): # avoid calling xr.addressable_device_count here otherwise it will init the test # in non-spmd mode. 
- @unittest.skipIf(xr.device_type() == 'CPU', - "sharding will be the same for both tensors on single device" - ) + @unittest.skipIf( + xr.device_type() == 'CPU', + "sharding will be the same for both tensors on single device") def test_shard_hashing(self): xt1 = torch.ones(2, 2).to(xm.xla_device()) xt2 = torch.ones(2, 2).to(xm.xla_device()) @@ -1383,8 +1383,9 @@ def test_get_1d_mesh(self): self.assertEqual(mesh_without_name.mesh_shape, (xr.global_runtime_device_count(),)) - @unittest.skipUnless(xr.global_runtime_device_count() > 1, - "Multiple devices required for dataloader sharding test") + @unittest.skipUnless( + xr.global_runtime_device_count() > 1, + "Multiple devices required for dataloader sharding test") def test_data_loader_with_sharding(self): device = torch_xla.device() mesh = xs.get_1d_mesh("data") @@ -1405,8 +1406,9 @@ def test_data_loader_with_sharding(self): f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}" ) - @unittest.skipUnless(xr.global_runtime_device_count() > 1, - "Multiple devices required for dataloader sharding test") + @unittest.skipUnless( + xr.global_runtime_device_count() > 1, + "Multiple devices required for dataloader sharding test") def test_data_loader_with_non_batch_size(self): device = torch_xla.device() mesh = xs.get_1d_mesh("data") @@ -1427,8 +1429,9 @@ def test_data_loader_with_non_batch_size(self): f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}" ) - @unittest.skipUnless(xr.global_runtime_device_count() > 1, - "Multiple devices required for dataloader sharding test") + @unittest.skipUnless( + xr.global_runtime_device_count() > 1, + "Multiple devices required for dataloader sharding test") def test_data_loader_with_non_batch_size_and_mini_batch(self): device = torch_xla.device() mesh = xs.get_1d_mesh("data") @@ -1660,9 +1663,9 @@ def test_get_logical_mesh(self): self.assertEqual(logical_mesh.shape, mesh_shape) np.testing.assert_array_equal(np.sort(logical_mesh.flatten()), device_ids) - @unittest.skipIf(xr.device_type() == 'CPU', - "sharding will be the same for both tensors on single device" - ) + @unittest.skipIf( + xr.device_type() == 'CPU', + "sharding will be the same for both tensors on single device") def test_shard_as(self): mesh = self._get_mesh((self.n_devices,)) partition_spec = (0,) diff --git a/test/test_operations.py b/test/test_operations.py index 33db55a3585a..aa6ac3e016df 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2959,11 +2959,9 @@ def test_dlpack_roundtrip_tensor(self, dtype): @onlyIfTorchSupportsCUDA @onlyIfPJRTDeviceIsCUDA - @parameterized.parameters(*all_types_and_complex_and(torch.half, - torch.bfloat16, - torch.bool, torch.uint16, - torch.uint32, - torch.uint64)) + @parameterized.parameters( + *all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, + torch.uint16, torch.uint32, torch.uint64)) def test_dlpack_roundtrip_scalar(self, dtype): xla_device = xm.xla_device() xla_tensor_0 = torch.tensor(42, dtype=dtype).to(xla_device) diff --git a/test/test_pallas.py b/test/test_pallas.py index 5340755506fc..c1f9df9ba0c2 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -41,10 +41,10 @@ class PallasTest(parameterized.TestCase): # therefore we use != instead of ==. 
def _make_attention_mask_from_segment_ids(self, q_segment_ids, kv_segment_ids): - return q_segment_ids.view(q_segment_ids.shape[0], 1, - q_segment_ids.shape[1], 1) != kv_segment_ids.view( - kv_segment_ids.shape[0], 1, 1, - kv_segment_ids.shape[1]) + return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1], + 1) != kv_segment_ids.view(kv_segment_ids.shape[0], + 1, 1, + kv_segment_ids.shape[1]) def _attention(self, q, k, v, *, attn_mask=None, ab=None): attn_weight = q @ k.transpose(-2, -1) diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py index f723227217a9..2e4de37fadc6 100644 --- a/test/test_pallas_spmd.py +++ b/test/test_pallas_spmd.py @@ -41,10 +41,10 @@ class PallasTest(unittest.TestCase): # therefore we use != instead of ==. def _make_attention_mask_from_segment_ids(self, q_segment_ids, kv_segment_ids): - return q_segment_ids.view(q_segment_ids.shape[0], 1, - q_segment_ids.shape[1], 1) != kv_segment_ids.view( - kv_segment_ids.shape[0], 1, 1, - kv_segment_ids.shape[1]) + return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1], + 1) != kv_segment_ids.view(kv_segment_ids.shape[0], + 1, 1, + kv_segment_ids.shape[1]) def _attention(self, q, k, v, *, attn_mask=None, ab=None): attn_weight = q @ k.transpose(-2, -1) diff --git a/test/test_splash_attention.py b/test/test_splash_attention.py index 9fd6c1306ea9..6e8bb56fe2c3 100644 --- a/test/test_splash_attention.py +++ b/test/test_splash_attention.py @@ -62,10 +62,10 @@ def setUp(self): def _make_attention_mask_from_segment_ids(self, q_segment_ids, kv_segment_ids): - return q_segment_ids.view(q_segment_ids.shape[0], 1, - q_segment_ids.shape[1], 1) != kv_segment_ids.view( - kv_segment_ids.shape[0], 1, 1, - kv_segment_ids.shape[1]) + return q_segment_ids.view(q_segment_ids.shape[0], 1, q_segment_ids.shape[1], + 1) != kv_segment_ids.view(kv_segment_ids.shape[0], + 1, 1, + kv_segment_ids.shape[1]) def maybe_repeat_kv(self, hidden_state): if hidden_state.size(1) == self.NUM_Q_HEADS: diff --git a/torch_xla/distributed/xla_multiprocessing.py b/torch_xla/distributed/xla_multiprocessing.py index 36a2f514c06c..d699abaebafb 100644 --- a/torch_xla/distributed/xla_multiprocessing.py +++ b/torch_xla/distributed/xla_multiprocessing.py @@ -174,7 +174,10 @@ def _v6e_create_replica_groups() -> List | None: return None -device_kind_handler_dict: dict[str, Callable[..., List | None],] = { +device_kind_handler_dict: dict[ + str, + Callable[..., List | None], +] = { _TPU_V5P: _v5p_create_replica_groups, _TPU_V6E: _v6e_create_replica_groups } diff --git a/torch_xla/experimental/gradient_accumulation.py b/torch_xla/experimental/gradient_accumulation.py index ad299a393492..5299291861ff 100644 --- a/torch_xla/experimental/gradient_accumulation.py +++ b/torch_xla/experimental/gradient_accumulation.py @@ -288,8 +288,8 @@ def add_to_mapping(val: torch.Tensor, iterable_tensors, fake_iterable_tensors, carried_tensors, fake_carried_tensors, params, grads) - def _body_fn_wrapper(curr_iter: xb.Op, curr_loss: xb.Op, - *while_params: xb.Op): + def _body_fn_wrapper(curr_iter: xb.Op, curr_loss: xb.Op, *while_params: + xb.Op): def dynamic_slice(xs: xb.Op, idx: xb.Op) -> xb.Op: indices = [idx] + [idx.zeros_like() for _ in range(xs.shape().rank - 1)] diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py index fcb0a3e89aaa..d4999cd8bb0b 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py +++ 
b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py @@ -243,8 +243,8 @@ def make_sequence_metadata( # # Remove tile visits that belong to a sequence not in our shard. iota = jnp.arange(num_sequences, dtype=jnp.int32) - active_sequence_mask = jnp.logical_and(iota <= end_sequence, - iota >= start_sequence) + active_sequence_mask = jnp.logical_and(iota <= end_sequence, iota + >= start_sequence) sequence_tiles = jnp.where(active_sequence_mask, sequence_tiles[:num_sequences], 0) num_tiles = sequence_tiles.sum() @@ -375,8 +375,8 @@ def _flash_attention( logical_q_blk_idx - 1, 0) is_first_processed_logical_q_blk = logical_q_blk_idx == 0 physical_q_blk_changed = ( - physical_q_tile_ids[logical_q_blk_idx] != - physical_q_tile_ids[prev_logical_q_blk_idx]) + physical_q_tile_ids[logical_q_blk_idx] + != physical_q_tile_ids[prev_logical_q_blk_idx]) first_time_seeing_physical_q_blk = jnp.logical_or( is_first_processed_logical_q_blk, physical_q_blk_changed) is_first_kv_blk = (kv_blk_idx == 0) @@ -509,8 +509,8 @@ def init_scratch_ref(): # pylint: disable=unused-variable logical_q_blk_idx + 1) is_last_logical_q_blk = (logical_q_blk_idx == num_logical_q_blks - 1) physical_q_blk_will_change = ( - physical_q_tile_ids[logical_q_blk_idx] != - physical_q_tile_ids[next_logical_q_blk_idx]) + physical_q_tile_ids[logical_q_blk_idx] + != physical_q_tile_ids[next_logical_q_blk_idx]) last_time_seeing_cur_physical_q_blk = jnp.logical_or( is_last_logical_q_blk, physical_q_blk_will_change) should_store_to_output = jnp.logical_and(is_last_kv_blk_idx, diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py index ee350d7fdacd..48c28825418b 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py @@ -421,8 +421,8 @@ def init_scratch_ref(): ) causal_mask = row_ids < col_ids if sliding_window is not None: - causal_mask = jnp.logical_or(causal_mask, - row_ids - sliding_window >= col_ids) + causal_mask = jnp.logical_or(causal_mask, row_ids - sliding_window + >= col_ids) if soft_cap is not None: qk = soft_cap * jnp.tanh(qk / soft_cap) qk += jnp.where(causal_mask, mask_value, 0.0) From d83ffca03e994c97d9245b6cb1fb200186304cec Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Fri, 25 Apr 2025 11:34:00 -0700 Subject: [PATCH 3/3] Address comments --- test/test_assume_pure_spmd.py | 4 +++- test/test_jax_interop_spmd.py | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/test/test_assume_pure_spmd.py b/test/test_assume_pure_spmd.py index 4edf8cf45614..6cc7615eddbe 100644 --- a/test/test_assume_pure_spmd.py +++ b/test/test_assume_pure_spmd.py @@ -12,10 +12,12 @@ class AssumePureSpmdTest(unittest.TestCase): - def setUp(self): + @classmethod + def setUpClass(cls): # Activate SPMD xr.use_spmd() + def setUp(self): # Set up a simple SPMD mesh for these tests. self.spmd_mesh = get_1d_mesh(axis_name="model") set_global_mesh(self.spmd_mesh) diff --git a/test/test_jax_interop_spmd.py b/test/test_jax_interop_spmd.py index bf0a39d0eeac..7712a50e533f 100644 --- a/test/test_jax_interop_spmd.py +++ b/test/test_jax_interop_spmd.py @@ -11,11 +11,14 @@ class TestJaxInteropSpmd(unittest.TestCase): - def setUp(self): - xb._JAX_TO_XLA_COMPUTATION_CACHE.clear() + @classmethod + def setUpClass(cls): # Activate SPMD xr.use_spmd() + def setUp(self): + # Clear cached HLO between test cases. 
+ xb._JAX_TO_XLA_COMPUTATION_CACHE.clear() # Set up a simple SPMD mesh for these tests. self.spmd_mesh = get_1d_mesh(axis_name="model") set_global_mesh(self.spmd_mesh)
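
Usage sketch of the feature this series adds (adapted from the new test/test_jax_interop_spmd.py above; assumes a multi-device, SPMD-capable environment such as a TPU host — illustrative only, not part of the patch):

import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import get_1d_mesh, set_global_mesh

# Activate SPMD and register a 1D mesh as the ambient global mesh.
xr.use_spmd()
set_global_mesh(get_1d_mesh(axis_name="model"))

a = torch.ones((8, 8), device='xla')


def f(a, b):
  # Runs under jax.jit; with this series, the global torch_xla mesh is
  # bridged to JAX, so named axes like "model" resolve inside the JAX fn.
  import jax
  import jax.numpy as jnp
  from jax.sharding import PartitionSpec as P
  return jax.lax.with_sharding_constraint(a, P("model")) + jnp.sin(b)


result = xb.call_jax(f, (a, a))
torch_xla.sync(wait=True)
# The output carries the sharding annotation requested inside the JAX function.
print(torch_xla._XLAC._get_xla_sharding_spec(result))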