diff --git a/test/cpp/test_xla_sharding.cpp b/test/cpp/test_xla_sharding.cpp
index e1f908b5c80..ed9b5b7677c 100644
--- a/test/cpp/test_xla_sharding.cpp
+++ b/test/cpp/test_xla_sharding.cpp
@@ -222,6 +222,113 @@ TEST_F(XLAShardingTest, ShardTensor) {
   EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
 }
 
+TEST_F(XLAShardingTest, ShardTensorLocalMesh) {
+  // Test sharding with a local mesh.
+  std::vector<std::string> devices = {"TPU:8",  "TPU:9",  "TPU:10", "TPU:11",
+                                      "TPU:12", "TPU:13", "TPU:14", "TPU:15"};
+
+  // 1D tiled
+  at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
+  xla::Shape tensor_shape =
+      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  xla::OpSharding sharding =
+      xla::HloSharding::Tile1D(
+          CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
+          devices.size())
+          .ToProto();
+  auto sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                          /*padded=*/false);
+  EXPECT_EQ(shards.size(), 8);
+  for (auto shard : shards) {
+    EXPECT_EQ(shard.sizes(), c10::ArrayRef<long>({1}));
+  }
+
+  // 2D tiled, the first dim is halved and the last replicated. The last shard
+  // size should be smaller in dim=1 because it's not evenly divisible.
+  tensor = at::ones({8, 7, 4}, at::TensorOptions(at::kFloat));
+  tensor_shape =
+      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  xla::Array2D<int64_t> mesh({
+      {0, 1, 2, 3},
+      {4, 5, 6, 7},
+  });
+  sharding = xla::HloSharding::Tile(mesh).ToProto();
+  sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
+  EXPECT_EQ(shards.size(), 8);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({4, 2, 4}));
+  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({4, 1, 4}));
+
+  // 3D tiled, the first dim is replicated and the last halved. The last shard
+  // size should be smaller in dim=1 because it's not evenly divisible.
+  xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
+  sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto();
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
+  EXPECT_EQ(shards.size(), 8);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 2, 2}));
+  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 1, 2}));
+
+  // Replicated, all shards should be identical.
+  sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
+  EXPECT_EQ(shards.size(), 8);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 7, 4}));
+  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 7, 4}));
+
+  // 4D tiled, the first and second dims are replicated and the last halved.
+  // The last shard size should be smaller in dim=2 because it's not evenly
+  // divisible.
+  tensor = at::ones({1, 8, 7, 4}, at::TensorOptions(at::kFloat));
+  tensor_shape =
+      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
+  sharding = xla::HloSharding::Tile(tesseract).ToProto();
+  sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
+  EXPECT_EQ(shards.size(), 8);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));
+  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({1, 8, 1, 2}));
+
+  // 4D tiled and padded, all shard sizes should be identical.
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/true);
+  EXPECT_EQ(shards.size(), 8);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));
+  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({1, 8, 2, 2}));
+
+  // 5D tiled, the first and second dims are replicated and the last halved.
+  // The last shard size should be smaller in dim=2 because it's not evenly
+  // divisible.
+  tensor = at::ones({10, 1, 8, 7, 4}, at::TensorOptions(at::kFloat));
+  tensor_shape =
+      CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
+  xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
+  hypercube.FillIota(0);
+  sharding = xla::HloSharding::Tile(hypercube).ToProto();
+  sharding_spec =
+      std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/false);
+  EXPECT_EQ(shards.size(), 8);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
+  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 3, 2}));
+
+  // 5D tiled and padded, all shard sizes should be identical.
+  shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
+                                     /*padded=*/true);
+  EXPECT_EQ(shards.size(), 8);
+  EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
+  EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({10, 1, 4, 4, 2}));
+}
+
 TEST_F(XLAShardingTest, ShardTensorMultiHost) {
   std::vector<std::string> devices = {"TPU:4", "TPU:5", "TPU:6", "TPU:7"};
 
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 98012ea2d35..020979a4880 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1482,7 +1482,7 @@ void InitXlaModuleBindings(py::module m) {
     if (UseVirtualDevice()) {
       return 1;
     } else {
-      return runtime::GetComputationClient()->GetNumDevices();
+      return runtime::GetComputationClient()->GetNumLocalDevices();
     }
   });
   m.def("_xla_get_all_devices", []() {
@@ -1500,13 +1500,16 @@ void InitXlaModuleBindings(py::module m) {
   m.def("_xla_get_runtime_devices",
         []() { return runtime::GetComputationClient()->GetLocalDevices(); });
   m.def("_xla_num_runtime_devices", []() -> int64_t {
-    return runtime::GetComputationClient()->GetNumDevices();
+    return runtime::GetComputationClient()->GetNumLocalDevices();
   });
   m.def("_xla_get_all_runtime_devices", []() {
     std::vector<std::string> all_devices =
        runtime::GetComputationClient()->GetAllDevices();
     return all_devices;
   });
+  m.def("_xla_num_global_devices", []() -> int64_t {
+    return runtime::GetComputationClient()->GetNumGlobalDevices();
+  });
   m.def(
       "_xla_real_devices",
       [](const std::optional<std::vector<std::string>> devices) {
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index b192d8d2e14..20915de32e2 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -374,7 +374,9 @@ class ComputationClient {
 
   virtual std::intptr_t GetCudaStreamForDevice(int local_device_id) const = 0;
 
-  virtual size_t GetNumDevices() const = 0;
+  virtual size_t GetNumLocalDevices() const = 0;
+
+  virtual size_t GetNumGlobalDevices() const = 0;
 
   virtual std::vector<std::string> GetLocalDevices() const = 0;
 
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc
index a197aec460e..11aaa1a0b8d 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -613,10 +613,14 @@ IfrtComputationClient::ExecuteReplicated(
   return data_handles;
 }
 
-size_t IfrtComputationClient::GetNumDevices() const {
+size_t IfrtComputationClient::GetNumLocalDevices() const {
   return client_->addressable_device_count();
 }
 
+size_t IfrtComputationClient::GetNumGlobalDevices() const {
+  return client_->device_count();
+}
+
 std::string IfrtComputationClient::GetDefaultDevice() const {
   return IfrtDeviceToString(client_->addressable_devices()[0]);
 }
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h
index 73b8e21c9f0..26135f65ab5 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.h
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -79,7 +79,9 @@ class IfrtComputationClient : public ComputationClient {
       absl::Span<const std::string> devices,
       const ExecuteReplicatedOptions& options) override;
 
-  size_t GetNumDevices() const override;
+  size_t GetNumLocalDevices() const override;
+
+  size_t GetNumGlobalDevices() const override;
 
   std::string GetDefaultDevice() const override;
 
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 749419f66cd..3783bb61b5d 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -559,7 +559,10 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
           .set_allow_spmd_sharding_propagation_to_output(
               {instance.allow_spmd_sharding_propagation_to_output});
 
-      int num_partitions = client_->device_count();
+      int num_partitions = GetNumGlobalDevices();
+      if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) {
+        num_partitions = GetNumLocalDevices();
+      }
       compile_options.executable_build_options.set_num_partitions(
           num_partitions);
       compile_options.executable_build_options.set_num_replicas(1);
@@ -589,11 +592,18 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
       }
 
       // TODO(244391366) verify this is correct for the collectives ops
-      xla::DeviceAssignment device_assignment(1, client_->device_count());
+      xla::DeviceAssignment device_assignment(1, num_partitions);
       // DeviceAssignment values must be the PjRtDevice ID, so we need to
       // unwind the global ordinal mapping.
-      for (const auto& [device_id, global_ordinal] : global_ordinals_) {
-        device_assignment(0, global_ordinal) = device_id;
+      if (runtime::sys_util::GetEnvBool("XLA_USE_LOCAL_SPMD", false)) {
+        auto local_pjrt_devices = client_->addressable_devices();
+        for (int i = 0; i < local_pjrt_devices.size(); ++i) {
+          device_assignment(0, i) = local_pjrt_devices[i]->id();
+        }
+      } else {
+        for (const auto& [device_id, global_ordinal] : global_ordinals_) {
+          device_assignment(0, global_ordinal) = device_id;
+        }
       }
       compile_options.executable_build_options.set_device_assignment(
           device_assignment);
@@ -649,7 +659,6 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
     CreateCompileHandlesCounter()->AddValue(1);
   }
 
-
   return computations;
 }
 
@@ -917,10 +926,14 @@ PjRtComputationClient::ExecuteReplicated(
   return data_handles;
 }
 
-size_t PjRtComputationClient::GetNumDevices() const {
+size_t PjRtComputationClient::GetNumLocalDevices() const {
   return client_->addressable_device_count();
 }
 
+size_t PjRtComputationClient::GetNumGlobalDevices() const {
+  return client_->device_count();
+}
+
 std::string PjRtComputationClient::GetDefaultDevice() const {
   return PjRtDeviceToString(client_->addressable_devices()[0]);
 }
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index 9791f32381b..090ff952fdf 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -86,7 +86,9 @@ class PjRtComputationClient : public ComputationClient {
      absl::Span<const std::string> devices,
      const ExecuteReplicatedOptions& options) override;
 
-  size_t GetNumDevices() const override;
+  size_t GetNumLocalDevices() const override;
+
+  size_t GetNumGlobalDevices() const override;
 
   std::string GetDefaultDevice() const override;
 
diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp
index 4e69127ff81..fcf793ff5bc 100644
--- a/torch_xla/csrc/tensor_impl.cpp
+++ b/torch_xla/csrc/tensor_impl.cpp
@@ -57,7 +57,7 @@ struct XLAGuardImpl : public c10::impl::DeviceGuardImplInterface {
       return 0;
     }
 
-    return client->GetNumDevices();
+    return client->GetNumLocalDevices();
   }
 };
 
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 0b8c5489798..c33b5431455 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -1422,10 +1422,10 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
       program_shape.result(), static_cast<XlaDeviceType>(coll.device.type()));
 
   std::vector<runtime::ComputationClient::CompileInstance> instances;
-  instances.push_back({std::move(computation), coll.device.toString(),
-                       runtime::GetComputationClient()->GetCompilationDevices(
-                           coll.device.toString(), devices),
-                       &shape, should_wrap_parameter, is_sharded});
+  instances.emplace_back(std::move(computation), coll.device.toString(),
+                         runtime::GetComputationClient()->GetCompilationDevices(
+                             coll.device.toString(), devices),
+                         &shape, should_wrap_parameter, is_sharded);
   instances.front().eager_mode = UseEagerMode();
   if (use_autosharding) {
     TF_VLOG(5) << "use_auto_spmd_partitioning is set.";
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index d58144d6844..2058e7490a6 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -85,10 +85,20 @@ std::vector<int64_t> TileAssignmentDimensions(
 // order of the output corresponds to the order of the `devices`, which can be
 // arbitrarily set by the caller.
 std::unordered_map<int, int> build_index_map(
-    const std::vector<std::string>& devices) {
+    const std::vector<std::string>& devices, size_t num_mesh_devices) {
   std::unordered_map<int, int> device_index;
   for (int i = 0; i < devices.size(); ++i) {
-    int global_ordinal = ParseDeviceString(devices[i]).ordinal();
+    // The global ordinal here is the device's ordinal in the mesh, which can
+    // be different from the physical device index.
+    // We only support 2 cases here:
+    //   1. The mesh contains all global devices.
+    //   2. The mesh contains only local devices (in a multi-host scenario).
+    // Example: In multi-host v6e-8, each host has a mesh of its local
+    // devices; host 1 has devices TPU:{4, 5, 6, 7}. In this case the global
+    // ordinal of TPU:4 is 0, TPU:5 is 1, and so on.
+    int global_ordinal =
+        ParseDeviceString(devices[i]).ordinal() % num_mesh_devices;
     device_index[global_ordinal] = i;
   }
   return device_index;
@@ -371,7 +381,12 @@ ShardingUtil::GetShardReplicaAndIndicesForDevices(
       shard_indices[i] = std::make_pair(global_ordinal, indices);
     }
   } else if (sharding.type() == xla::OpSharding::OTHER) {
-    auto device_index = build_index_map(devices);
+    size_t num_tiles =
+        std::accumulate(sharding.tile_assignment_dimensions().begin(),
+                        sharding.tile_assignment_dimensions().end(), 1,
+                        [](int a, int b) { return a * b; });
+    std::unordered_map<int, int> device_index =
+        build_index_map(devices, num_tiles);
     std::vector<int64_t> tile_assignment_devices(
         sharding.tile_assignment_devices().begin(),
         sharding.tile_assignment_devices().end());
@@ -442,7 +457,6 @@ std::vector<at::Tensor> ShardingUtil::ShardTensor(
   }
   TF_VLOG(5) << "ShardTensor with sharding type(" << sharding.type()
              << ")... and minibatch = " << minibatch << std::endl;
-  auto device_index = build_index_map(devices);
   std::vector<at::Tensor> shards(devices.size());
   if (shardings == nullptr || sharding.type() == xla::OpSharding::REPLICATED ||
       sharding.type() == xla::OpSharding::UNKNOWN) {
diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index a1cd9540fd1..d1a1db4b644 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -1,25 +1,25 @@
 import collections
-from collections.abc import Generator, MutableMapping
+import functools
+import itertools
 import math
+import os
 from collections import OrderedDict, defaultdict
+from collections.abc import Generator, MutableMapping
 from dataclasses import dataclass, field
+from enum import IntEnum
+from typing import Any, List, Optional, Sequence, Set, Tuple, Union
+
+import numpy as np
 import torch
-from torch import Tensor
-from torch.library import custom_op
 import torch_xla
-import torch_xla.core.xla_model as xm
 import torch_xla._internal.utils as _utils
-from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard
-import torch_xla.runtime as xr
+import torch_xla.core.xla_model as xm
 import torch_xla.debug.profiler as xp
-
-import numpy as np
-import functools
-import itertools
-from typing import Tuple, Union, List, Sequence, Any, Optional, Set
-from enum import IntEnum
-
-from torch.amp import custom_fwd, custom_bwd
+import torch_xla.runtime as xr
+from torch import Tensor
+from torch.amp import custom_bwd, custom_fwd
+from torch.library import custom_op
+from torch_xla.distributed.spmd import XLAShard, XLAShardedTensor
 
 
 class Mesh:
@@ -63,12 +63,24 @@ def __init__(self,
       device_ids = np.array(device_ids)
     assert (axis_names is None) or (len(mesh_shape) == len(axis_names))
     assert axis_names is None or (len(set(axis_names)) == len(axis_names))
+    # size of device_ids matches mesh_shape
     assert (len(device_ids) == np.prod(mesh_shape))
+    # device ids are unique
     assert len(device_ids) == len(np.unique(device_ids))
     self.device_ids = device_ids
     self.mesh_shape = mesh_shape
     self.axis_names = axis_names
-    assert all(d < self.size() for d in device_ids)
+    # device ids are contiguous
+    if os.environ.get('XLA_USE_LOCAL_SPMD', '0') == '1':
+      # In local SPMD the mesh only contains local devices.
+      min_device_idx = xr.process_index() * xr.addressable_runtime_device_count(
+      )
+      assert min_device_idx == np.min(
+          device_ids
+      ), "If not creating a mesh with all global devices, must use local devices."
+      assert all(d < self.size() for d in device_ids - np.min(device_ids))
+    else:
+      assert all(d < self.size() for d in device_ids)
 
   def size(self):
     return np.prod(self.mesh_shape)
@@ -140,6 +152,7 @@ def __str__(self):
   def from_str(cls, mesh_str: str) -> Optional["Mesh"]:
     """Create Mesh from string representation."""
     import ast
+
     import numpy as np
     try:
       dict_str = mesh_str.replace('Mesh', '')
@@ -377,6 +390,20 @@ def _get_sharding_type(partition_spec: Tuple[Union[int, None]],
   return sharding_type
 
 
+def _normalize_logical_mesh(device_mesh: np.ndarray) -> np.ndarray:
+  """
+  Normalize the device mesh so that its device ids start from 0.
+
+  This is needed when the mesh doesn't include all global devices
+  (e.g. in a multi-host setup, each host has a mesh containing only its local
+  devices). Because the HLO graph always uses logical device ids in the
+  sharding annotation, we need to normalize the physical device ids to
+  generate the correct HLO sharding annotation.
+  """
+  device_id_min = np.min(device_mesh)
+  return device_mesh.copy() - device_id_min
+
+
 def _get_tile_assignment(
     mesh: Mesh,
     partition_spec: Tuple[Union[Tuple[int], int, None]]) -> np.ndarray:
@@ -393,8 +420,8 @@ def _get_tile_assignment(
   tiled_dims = [x for x in partition_spec if x is not None]
   permutation = np.hstack(tiled_dims).tolist() if tiled_dims else []
   missing_axes = sorted(set(range(len(mesh.shape()))) - set(permutation))
-  tile_assignment = mesh.get_logical_mesh().transpose(permutation +
-                                                      missing_axes)
+  tile_assignment = _normalize_logical_mesh(
+      mesh.get_logical_mesh()).transpose(permutation + missing_axes)
 
   # For any tuples in the partition_spec, the grouped axes will be adjacent
   # after the permutation. Combine these dimensions into a single axis.
@@ -548,8 +575,9 @@ def mark_sharding(
     >>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
   """
   num_devices = xr.global_runtime_device_count()
+  num_local_devices = xr.addressable_runtime_device_count()
   assert num_devices > 0, "This requires XLA supported device(s)."
-  assert mesh.size() == num_devices, \
+  assert mesh.size() == num_devices or mesh.size() == num_local_devices, \
     f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
   # We only allow fully specified `partition_spec` to be applicable, as opposed
   # to filling in the unspecified replicated dims. Fully specified `partiion_spec`
diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index 1946ae05a52..a17b1c57e3b 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -212,7 +212,7 @@ def global_runtime_device_attributes() -> List[Dict[str, object]]:
 @functools.lru_cache()
 def global_runtime_device_count() -> int:
   """Returns the total number of runtime devices across all processes/hosts, especially useful for SPMD."""
-  return len(torch_xla._XLAC._xla_get_all_runtime_devices())
+  return torch_xla._XLAC._xla_num_global_devices()
 
 
 def addressable_runtime_device_count() -> int:
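Usage sketch (not part of the patch): the snippet below is a hypothetical, minimal illustration of how the local-SPMD path added above might be exercised from Python, assuming XLA_USE_LOCAL_SPMD is read as an opt-in environment variable as in the C++ changes. The mesh construction mirrors the new Mesh assertion (local device ids starting at process_index * addressable_runtime_device_count); the rest of the workflow is an assumption, not something this diff defines.

# Hypothetical local-SPMD usage sketch; the env-var toggle and local-mesh
# construction follow the patch, everything else is illustrative only.
import os

os.environ["XLA_USE_LOCAL_SPMD"] = "1"  # opt into local SPMD (assumed toggle)

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()

# Build a mesh from this host's devices only, e.g. TPU:{4, 5, 6, 7} on host 1.
num_local = xr.addressable_runtime_device_count()
first_local = xr.process_index() * num_local
device_ids = np.arange(first_local, first_local + num_local)
mesh = xs.Mesh(device_ids, (num_local,), ("data",))

# Shard dim 0 across the local mesh; with this patch, mark_sharding accepts a
# mesh whose size equals the local (addressable) device count.
t = torch.ones(8, 128).to(xm.xla_device())
xs.mark_sharding(t, mesh, (0, None))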