diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 7d979158f162..6fce95198434 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -251,7 +251,16 @@ at::Tensor to_meta(const at::Tensor& tensor) {
 
 torch::lazy::BackendDevice GetXlaDeviceOrCurrent(
     const std::optional<c10::Device>& device) {
-  auto xla_device_opt = bridge::GetXlaDevice(device);
+  std::optional<torch::lazy::BackendDevice> xla_device_opt =
+      bridge::GetXlaDevice(device);
+  std::cout << "in GetXlaDeviceOrCurrent" << std::endl;
+  if (xla_device_opt.has_value()) {
+    std::cout << "xla_device_opt: " << (*xla_device_opt).toString()
+              << std::endl;
+  } else {
+    std::cout << "bridge::GetCurrentDevice(): "
+              << bridge::GetCurrentDevice().toString() << std::endl;
+  }
   return xla_device_opt ? *xla_device_opt : bridge::GetCurrentDevice();
 }
 
@@ -641,17 +650,22 @@ at::Tensor XLANativeFunctions::_copy_from(const at::Tensor& self,
   auto dst_tensor = bridge::TryGetXlaTensor(dst);
   auto self_tensor = bridge::TryGetXlaTensor(self);
   if (!self_tensor) {
-    static bool sync_update =
+    std::cout << "in copy from 1" << std::endl;
+    std::cout << "check dst_tensor device" << dst_tensor->GetDevice().toString()
+              << std::endl;
+    bool sync_update =
         runtime::sys_util::GetEnvBool("XLA_TENSOR_UPDATE_SYNC", true) &&
         !UseVirtualDevice();
     dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update);
     XLA_CHECK(dst_tensor);
   } else if (!dst_tensor) {
+    std::cout << "in copy from 2" << std::endl;
     at::Tensor tensor = self_tensor->ToTensor(/*detached=*/true);
     at::Tensor typed_tensor =
         torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
     dst.resize_as_(typed_tensor).copy_(typed_tensor);
   } else {
+    std::cout << "in copy from 3" << std::endl;
     tensor_methods::copy_(dst_tensor, self_tensor);
     bridge::ReplaceXlaTensor(dst, dst_tensor);
   }
@@ -725,7 +739,7 @@ at::Tensor XLANativeFunctions::_to_copy(
     // Use the eager .to on the eager tensor.
     return eager_tensor.to(options, non_blocking, /*copy=*/true);
   }
-
+  std::cout << "check options.device() " << options.device() << std::endl;
   // Case 2: Create a new XLA tensor with the supplied data and options.
   auto new_tensor =
       empty_symint(self.sym_sizes(), at::typeMetaToScalarType(options.dtype()),
@@ -1555,6 +1569,8 @@ at::Tensor XLANativeFunctions::empty_symint(
   // s_copy_().
   XLATensorPtr xla_tensor;
   if (all_dims_static) {
+    std::cout << "in empty_symint, GetXlaDeviceOrCurrent(device): "
+              << GetXlaDeviceOrCurrent(device).toString() << std::endl;
     xla_tensor = tensor_methods::full(XlaHelpers::I64List(int_sizes.value()),
                                       0, GetXlaDeviceOrCurrent(device),
                                       at::dtype_or_default(dtype));
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 04dcbf526ed0..389e0f6d795a 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2419,6 +2419,41 @@ void InitXlaModuleBindings(py::module m) {
         }
         return result;
       });
+  m.def(
+      "_global_tensor_from_tpu_shards",
+      [](const std::vector<at::Tensor>& shards, const xla::OpSharding& sharding,
+         std::optional<std::vector<int64_t>>& global_shape) -> at::Tensor {
+        std::vector<runtime::ComputationClient::DataPtr> handles;
+        std::vector<at::ScalarType> element_types;
+        for (auto& shard : shards) {
+          XLATensorPtr xtensor = bridge::GetXlaTensor(shard);
+          XLA_CHECK(xtensor->GetXlaData() != nullptr)
+              << "Shard data is not available";
+          runtime::ComputationClient::DataPtr handle =
+              std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
+                  xtensor->GetXlaData());
+          handles.push_back(handle);
+          element_types.push_back(
+              MaybeUpcastToHostTorchType(handle->shape().element_type()));
+        }
+        auto local_devices = runtime::GetComputationClient()->GetLocalDevices();
+        auto device = GetVirtualDevice();
+        auto primitive_type =
+            MakeXlaPrimitiveType(shards[0].type().scalarType(), &device);
+        xla::Shape tensor_shape = MakeArrayShapeFromDimensions(
+            global_shape.value(), /*dynamic_dimensions=*/{}, primitive_type,
+            static_cast<XlaDeviceType>(device.type()));
+        auto sharding_spec =
+            std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
+        runtime::ComputationClient::DataPtr sharded_data =
+            ShardingUtil::CreateShardedDataFromShards(handles, local_devices,
+                                                      sharding_spec);
+
+        XLATensorPtr xla_tensor = XLATensor::Create(std::move(sharded_data));
+        return bridge::AtenFromXlaTensor(std::move(xla_tensor));
+      },
+      py::arg("shards"), py::arg("sharding"),
+      py::arg("global_shape") = py::none());
   // For each input tensors' local shards, returns the tuple:
   // (replica_id: int, indices: Union[List[Slice], Ellipsis]),
   // where `replica_id` is the replica the shard belongs to and `indices` index
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index b192d8d2e149..e32240558145 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -273,6 +273,11 @@ class ComputationClient {
       std::string device, xla::Shape shape,
       std::optional<xla::OpSharding> sharding = std::nullopt) = 0;
 
+  virtual DataPtr CreateShardedDataFromShards(std::vector<DataPtr> shards,
+                                              std::string device,
+                                              xla::Shape global_shape,
+                                              xla::OpSharding sharding) = 0;
+
   // Returns data shards. We expect this to be called on PjRtShardedData to
   // retrieve the shards. If other data type is passed, it returns the input
   // wrapped inside a vector.
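
A minimal Python sketch of the contract the new `_global_tensor_from_tpu_shards`
binding assumes: one shard per addressable device, already resident on that
device, plus an OpSharding and a global shape. The Mesh and get_op_sharding
calls are existing SPMD APIs; `place_on_local_device` is a hypothetical
stand-in for whatever the dataloading path will do to materialize shard i on
local device i, and is not part of this patch.

    import numpy as np
    import torch
    import torch_xla
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()
    num_devices = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))
    # Shard dim 0 across the 'data' axis, replicate dim 1.
    op_sharding = mesh.get_op_sharding(('data', None))


    def place_on_local_device(cpu_tensor: torch.Tensor, i: int) -> torch.Tensor:
      # Hypothetical placement helper (not part of this patch): the binding
      # expects shards[i] to already live on local device i.
      raise NotImplementedError("per-device placement comes from the loader")


    shards = [place_on_local_device(torch.randn(8, 16), i)
              for i in range(num_devices)]
    global_t = torch_xla._XLAC._global_tensor_from_tpu_shards(
        shards, op_sharding, [num_devices * 8, 16])
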
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc
index 4a2e528e26d8..1c63ab4dcca5 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc
@@ -192,6 +192,12 @@ ComputationClient::DataPtr IfrtComputationClient::CreateDataPlaceholder(
       std::move(sharding));
 }
 
+ComputationClient::DataPtr IfrtComputationClient::CreateShardedDataFromShards(
+    std::vector<DataPtr> shards, std::string device,
+    xla::Shape global_shape, xla::OpSharding sharding) {
+  XLA_ERROR() << __FUNCTION__ << " not implemented";
+}
+
 std::vector<ComputationClient::DataPtr> IfrtComputationClient::GetDataShards(
     ComputationClient::DataPtr data) {
   tsl::profiler::TraceMe activity("IfrtComputationClient::GetDataShards",
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h
index c83a705abbbd..a8e29e92c99d 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.h
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -35,6 +35,11 @@ class IfrtComputationClient : public ComputationClient {
      std::string device, xla::Shape shape,
      std::optional<xla::OpSharding> sharding = std::nullopt) override;
 
+  DataPtr CreateShardedDataFromShards(std::vector<DataPtr> shards,
+                                      std::string device,
+                                      xla::Shape global_shape,
+                                      xla::OpSharding sharding) override;
+
   std::vector<DataPtr> GetDataShards(DataPtr data) override;
 
   DataPtr GetDataShard(DataPtr data, size_t index) override;
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 8caad6d230fa..e3d4e11ab7b0 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -309,6 +309,20 @@ ComputationClient::DataPtr PjRtComputationClient::TransferShardsToDevice(
       sharding);
 }
 
+ComputationClient::DataPtr PjRtComputationClient::CreateShardedDataFromShards(
+    std::vector<DataPtr> shards, std::string device, xla::Shape global_shape,
+    xla::OpSharding sharding) {
+  tsl::profiler::TraceMe activity(
+      "PjRtComputationClient::CreateShardedDataFromShards",
+      tsl::profiler::TraceMeLevel::kInfo);
+  std::vector<std::shared_ptr<PjRtData>> pjrt_shards;
+  for (auto shard : shards) {
+    pjrt_shards.push_back(std::dynamic_pointer_cast<PjRtData>(shard));
+  }
+  return std::make_shared<PjRtShardedData>(device, global_shape, pjrt_shards,
+                                           sharding);
+}
+
 ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(
     ComputationClient::DataPtr data, std::string dst) {
   tsl::profiler::TraceMe activity("PjRtComputationClient::CopyToDevice",
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index 6530ce768b4b..fb1fe27cf631 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -36,6 +36,11 @@ class PjRtComputationClient : public ComputationClient {
   static DataPtr CreateData(std::string device, xla::Shape shape,
                             std::shared_ptr<xla::PjRtBuffer> pjrt_buffer);
 
+  DataPtr CreateShardedDataFromShards(std::vector<DataPtr> shards,
+                                      std::string device,
+                                      xla::Shape global_shape,
+                                      xla::OpSharding sharding) override;
+
   std::vector<DataPtr> GetDataShards(DataPtr data) override;
 
   DataPtr GetDataShard(DataPtr data, size_t index) override;
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index d58144d6844a..79f91783f748 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -622,10 +622,26 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
     source_tensors.push_back(std::make_shared<runtime::AtenSource>(
         local_shards[j], shard_shape, devices[j]));
   }
+  std::cout << "ShardingUtil::CreateShardedData check global shape"
+            << global_shape.ToString() << std::endl;
   return runtime::GetComputationClient()->TransferShardsToDevice(
       source_tensors, GetVirtualDevice().toString(), global_shape, sharding);
 }
 
+runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedDataFromShards(
+    const std::vector<runtime::ComputationClient::DataPtr>& local_shards,
+    const std::vector<std::string>& devices,
+    const XLATensor::ShardingSpecPtr& sharding_spec) {
+  XLA_CHECK(local_shards.size() == devices.size())
+      << "A device must be specified for each shard";
+  XLA_CHECK(sharding_spec != nullptr)
+      << "A sharding spec must be specified";
+  xla::Shape global_shape = sharding_spec->shape;
+  xla::OpSharding sharding = sharding_spec->sharding;
+  return runtime::GetComputationClient()->CreateShardedDataFromShards(
+      local_shards, GetVirtualDevice().toString(), global_shape, sharding);
+}
+
 std::vector<int64_t> ShardingUtil::GetAutoShardingMesh() {
   // Auto-sharding uses mesh_shape = {n_devices, 1} if XLA_AUTO_SPMD_MESH
   // is not set. XLA_AUTO_SPMD_MESH takes a form of string, "2,2" which
@@ -779,11 +795,14 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
   XLA_CHECK(sharding.type() != xla::OpSharding::UNKNOWN)
       << "Can't explicilty annotate with UNKNOWN sharding type.";
   XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+  auto xla_shape = MakeShapeWithDeviceLayout(
+      xtensor->shape(),
+      static_cast<XlaDeviceType>(xtensor->GetDevice().type()));
   XLATensor::ShardingSpecPtr new_sharding_spec =
-      std::make_shared<XLATensor::ShardingSpec>(
-          sharding, MakeShapeWithDeviceLayout(
-                        xtensor->shape(), static_cast<XlaDeviceType>(
-                                              xtensor->GetDevice().type())));
+      std::make_shared<XLATensor::ShardingSpec>(sharding, xla_shape);
+  std::cout << "XlaMarkSharding... "
+            << new_sharding_spec->sharding.DebugString() << std::endl;
+  std::cout << "xla shape..." << xla_shape.ToString() << std::endl;
 
   // For Non DeviceData IR values, we directly attach the sharding spec
   // to the xtensor.
diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h
index d243c8872a31..2223977a52f3 100644
--- a/torch_xla/csrc/xla_sharding_util.h
+++ b/torch_xla/csrc/xla_sharding_util.h
@@ -120,6 +120,11 @@ class ShardingUtil {
       const std::vector<std::string>& devices,
       const XLATensor::ShardingSpecPtr& sharding_spec);
 
+  static runtime::ComputationClient::DataPtr CreateShardedDataFromShards(
+      const std::vector<runtime::ComputationClient::DataPtr>& local_shards,
+      const std::vector<std::string>& devices,
+      const XLATensor::ShardingSpecPtr& sharding_spec);
+
   static void XlaMarkSharding(const at::Tensor& input,
                               xla::OpSharding sharding);
 
diff --git a/torch_xla/distributed/spmd/dataloading_utils.py b/torch_xla/distributed/spmd/dataloading_utils.py
new file mode 100644
index 000000000000..f2302d15eb25
--- /dev/null
+++ b/torch_xla/distributed/spmd/dataloading_utils.py
@@ -0,0 +1,22 @@
+from typing import Any, List, Optional, Sequence, Set, Tuple, Union
+
+import torch
+import torch_xla
+
+
+def create_global_tensor_from_shards(shape: Tuple[int, ...],
+                                     sharding: torch_xla._XLAC.OpSharding,
+                                     shards: Sequence[torch.Tensor]):
+  """
+  Similar to jax.make_array_from_single_device_arrays.
+  """
+  # For now, this relies on the caller to pass a matching sharding and shape.
+  # TODO(lsy323): Check that shape, sharding and shards are consistent.
+  if shards[0].device.type == 'cpu':
+    assert all(
+        s.device.type == 'cpu' for s in shards), "All shards must be on CPU."
+    from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
+    return from_cpu_shards(shards, sharding, shape)
+  else:
+    from_tpu_shards = torch_xla._XLAC._global_tensor_from_tpu_shards
+    return from_tpu_shards(shards, sharding, shape)
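
A companion sketch for the CPU path of `create_global_tensor_from_shards`,
which bottoms out in the existing `_global_tensor_from_cpu_shards` binding.
Assumes a single host and a 1D 'data' mesh; the shard and global shapes are
illustrative only.

    import numpy as np
    import torch
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs
    from torch_xla.distributed.spmd.dataloading_utils import (
        create_global_tensor_from_shards)

    xr.use_spmd()
    num_devices = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))
    op_sharding = mesh.get_op_sharding(('data', None))

    # One CPU shard per device; the helper assembles a global tensor of shape
    # (num_devices * 8, 16), sharded along dim 0 across the 'data' axis.
    shards = [torch.randn(8, 16) for _ in range(num_devices)]
    global_t = create_global_tensor_from_shards((num_devices * 8, 16),
                                                op_sharding, shards)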