Skip to content

[SPMD] Add API to create global tensor from local CPU/TPU shards #8716

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,16 @@ at::Tensor to_meta(const at::Tensor& tensor) {

// Returns the XLA backend device mapped from `device`, or the current
// default XLA device when `device` is nullopt / has no XLA mapping.
torch::lazy::BackendDevice GetXlaDeviceOrCurrent(
    const std::optional<c10::Device>& device) {
  // NOTE(review): the diff contained both the old and the new declaration of
  // xla_device_opt (a redeclaration) plus debug std::cout prints; both removed.
  std::optional<torch::lazy::BackendDevice> xla_device_opt =
      bridge::GetXlaDevice(device);
  return xla_device_opt ? *xla_device_opt : bridge::GetCurrentDevice();
}

Expand Down Expand Up @@ -641,17 +650,22 @@ at::Tensor XLANativeFunctions::_copy_from(const at::Tensor& self,
auto dst_tensor = bridge::TryGetXlaTensor(dst);
auto self_tensor = bridge::TryGetXlaTensor(self);
if (!self_tensor) {
static bool sync_update =
std::cout << "in copy from 1" << std::endl;
std::cout << "check dst_tensor device" << dst_tensor->GetDevice().toString()
<< std::endl;
bool sync_update =
runtime::sys_util::GetEnvBool("XLA_TENSOR_UPDATE_SYNC", true) &&
!UseVirtualDevice();
dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update);
XLA_CHECK(dst_tensor);
} else if (!dst_tensor) {
std::cout << "in copy from 2" << std::endl;
at::Tensor tensor = self_tensor->ToTensor(/*detached=*/true);
at::Tensor typed_tensor =
torch::lazy::CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
dst.resize_as_(typed_tensor).copy_(typed_tensor);
} else {
std::cout << "in copy from 3" << std::endl;
tensor_methods::copy_(dst_tensor, self_tensor);
bridge::ReplaceXlaTensor(dst, dst_tensor);
}
Expand Down Expand Up @@ -725,7 +739,7 @@ at::Tensor XLANativeFunctions::_to_copy(
// Use the eager .to on the eager tensor.
return eager_tensor.to(options, non_blocking, /*copy=*/true);
}

std::cout << "check options.device() " << options.device() << std::endl;
// Case 2: Create a new XLA tensor with the supplied data and options.
auto new_tensor =
empty_symint(self.sym_sizes(), at::typeMetaToScalarType(options.dtype()),
Expand Down Expand Up @@ -1555,6 +1569,8 @@ at::Tensor XLANativeFunctions::empty_symint(
// s_copy_().
XLATensorPtr xla_tensor;
if (all_dims_static) {
std::cout << "in empty_symint, GetXlaDeviceOrCurrent(device): "
<< GetXlaDeviceOrCurrent(device).toString() << std::endl;
xla_tensor = tensor_methods::full(XlaHelpers::I64List(int_sizes.value()), 0,
GetXlaDeviceOrCurrent(device),
at::dtype_or_default(dtype));
Expand Down
35 changes: 35 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2419,6 +2419,41 @@ void InitXlaModuleBindings(py::module m) {
}
return result;
});
m.def(
"_global_tensor_from_tpu_shards",
[](const std::vector<at::Tensor>& shards, const xla::OpSharding& sharding,
std::optional<std::vector<int64_t>>& global_shape) -> at::Tensor {
std::vector<runtime::ComputationClient::DataPtr> handles;
std::vector<at::ScalarType> element_types;
for (auto& shard : shards) {
XLATensorPtr xtensor = bridge::GetXlaTensor(shard);
XLA_CHECK(xtensor->GetXlaData() != nullptr)
<< "Shard data is not available";
runtime::ComputationClient::DataPtr handle =
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
xtensor->GetXlaData());
handles.push_back(handle);
element_types.push_back(
MaybeUpcastToHostTorchType(handle->shape().element_type()));
}
auto local_devices = runtime::GetComputationClient()->GetLocalDevices();
auto device = GetVirtualDevice();
auto primitive_type =
MakeXlaPrimitiveType(shards[0].type().scalarType(), &device);
xla::Shape tensor_shape = MakeArrayShapeFromDimensions(
global_shape.value(), /*dynamic_dimensions=*/{}, primitive_type,
static_cast<XlaDeviceType>(device.type()));
auto sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
runtime::ComputationClient::DataPtr sharded_data =
ShardingUtil::CreateShardedDataFromShards(handles, local_devices,
sharding_spec);

XLATensorPtr xla_tensor = XLATensor::Create(std::move(sharded_data));
return bridge::AtenFromXlaTensor(std::move(xla_tensor));
},
py::arg("shards"), py::arg("sharding"),
py::arg("global_shape") = py::none());
// For each input tensors' local shards, returns the tuple:
// (replica_id: int, indices: Union[List[Slice], Ellipsis]),
// where `replica_id` is the replica the shard belongs to and `indices` index
Expand Down
5 changes: 5 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,11 @@ class ComputationClient {
std::string device, xla::Shape shape,
std::optional<xla::OpSharding> sharding = std::nullopt) = 0;

// Wraps the given per-device `shards` into a single sharded data handle on
// `device`, described by `global_shape` and `sharding`. Implementations are
// expected to reuse the existing shard buffers rather than transfer data
// (see PjRtComputationClient) — confirm per backend.
virtual DataPtr CreateShardedDataFromShards(std::vector<DataPtr> shards,
std::string device,
xla::Shape global_shape,
xla::OpSharding sharding) = 0;

// Returns data shards. We expect this to be called on PjRtShardedData to
// retrieve the shards. If other data type is passed, it returns the input
// wrapped inside a vector.
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/runtime/ifrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ ComputationClient::DataPtr IfrtComputationClient::CreateDataPlaceholder(
std::move(sharding));
}

// Not implemented for the IFRT client: always raises a fatal error.
// NOTE(review): there is no return statement after XLA_ERROR(); presumably
// XLA_ERROR() throws/aborts so the fall-through is unreachable — confirm,
// otherwise this is UB for a value-returning function.
ComputationClient::DataPtr IfrtComputationClient::CreateShardedDataFromShards(
std::vector<ComputationClient::DataPtr> shards, std::string device,
xla::Shape global_shape, xla::OpSharding sharding) {
XLA_ERROR() << __FUNCTION__ << " not implemented";
}

std::vector<ComputationClient::DataPtr> IfrtComputationClient::GetDataShards(
ComputationClient::DataPtr data) {
tsl::profiler::TraceMe activity("IfrtComputationClient::GetDataShards",
Expand Down
5 changes: 5 additions & 0 deletions torch_xla/csrc/runtime/ifrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ class IfrtComputationClient : public ComputationClient {
std::string device, xla::Shape shape,
std::optional<xla::OpSharding> sharding = std::nullopt) override;

// Not implemented for IFRT; the definition raises a fatal error when called.
DataPtr CreateShardedDataFromShards(std::vector<DataPtr> shards,
std::string device,
xla::Shape global_shape,
xla::OpSharding sharding) override;

std::vector<DataPtr> GetDataShards(DataPtr data) override;

DataPtr GetDataShard(DataPtr data, size_t index) override;
Expand Down
14 changes: 14 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,20 @@ ComputationClient::DataPtr PjRtComputationClient::TransferShardsToDevice(
sharding);
}

// Wraps per-device PjRt shards into one PjRtShardedData handle for `device`
// with the given `global_shape` and `sharding`. The shard buffers are reused
// as-is; no data is copied or transferred.
ComputationClient::DataPtr PjRtComputationClient::CreateShardedDataFromShards(
    std::vector<DataPtr> shards, std::string device, xla::Shape global_shape,
    xla::OpSharding sharding) {
  tsl::profiler::TraceMe activity(
      "PjRtComputationClient::CreateShardedDataFromShards",
      tsl::profiler::TraceMeLevel::kInfo);
  std::vector<std::shared_ptr<PjRtData>> pjrt_shards;
  pjrt_shards.reserve(shards.size());
  // Iterate by const reference: copying a shared_ptr per iteration costs an
  // atomic refcount bump for no benefit.
  for (const auto& shard : shards) {
    auto pjrt_shard = std::dynamic_pointer_cast<PjRtData>(shard);
    // The cast returns nullptr for non-PjRtData inputs (e.g. already-sharded
    // data); fail loudly instead of silently storing a null shard.
    XLA_CHECK(pjrt_shard != nullptr) << "Shard must be a PjRtData";
    pjrt_shards.push_back(std::move(pjrt_shard));
  }
  return std::make_shared<PjRtShardedData>(device, global_shape, pjrt_shards,
                                           sharding);
}

ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(
ComputationClient::DataPtr data, std::string dst) {
tsl::profiler::TraceMe activity("PjRtComputationClient::CopyToDevice",
Expand Down
5 changes: 5 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ class PjRtComputationClient : public ComputationClient {
static DataPtr CreateData(std::string device, xla::Shape shape,
std::shared_ptr<xla::PjRtBuffer> pjrt_buffer);

// Wraps per-device PjRt shards into a single PjRtShardedData handle; the
// existing device buffers are reused without any host transfer.
DataPtr CreateShardedDataFromShards(std::vector<DataPtr> shards,
std::string device,
xla::Shape global_shape,
xla::OpSharding sharding) override;

std::vector<DataPtr> GetDataShards(DataPtr data) override;

DataPtr GetDataShard(DataPtr data, size_t index) override;
Expand Down
27 changes: 23 additions & 4 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -622,10 +622,26 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
source_tensors.push_back(std::make_shared<runtime::AtenSource>(
local_shards[j], shard_shape, devices[j]));
}
std::cout << "ShardingUtil::CreateShardedData check global shape"
<< global_shape.ToString() << std::endl;
return runtime::GetComputationClient()->TransferShardsToDevice(
source_tensors, GetVirtualDevice().toString(), global_shape, sharding);
}

// Wraps already-on-device `local_shards` into one sharded data handle placed
// on the SPMD virtual device, laid out according to `sharding_spec`.
// `devices` must supply exactly one device per shard.
runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedDataFromShards(
    const std::vector<runtime::ComputationClient::DataPtr>& local_shards,
    const std::vector<std::string>& devices,
    const XLATensor::ShardingSpecPtr& sharding_spec) {
  // Fixed copy-pasted/misspelled check messages: the second check previously
  // reported "A device must be speficied for each shard" for a null spec.
  XLA_CHECK(local_shards.size() == devices.size())
      << "A device must be specified for each shard";
  XLA_CHECK(sharding_spec != nullptr)
      << "A sharding spec must be specified";
  // Pass the spec's shape/sharding directly; no need for local copies.
  return runtime::GetComputationClient()->CreateShardedDataFromShards(
      local_shards, GetVirtualDevice().toString(), sharding_spec->shape,
      sharding_spec->sharding);
}

std::vector<int64_t> ShardingUtil::GetAutoShardingMesh() {
// Auto-sharding uses mesh_shape = {n_devices, 1} if XLA_AUTO_SPMD_MESH
// is not set. XLA_AUTO_SPMD_MESH takes a form of string, "2,2" which
Expand Down Expand Up @@ -779,11 +795,14 @@ void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
XLA_CHECK(sharding.type() != xla::OpSharding::UNKNOWN)
<< "Can't explicilty annotate with UNKNOWN sharding type.";
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
auto xla_shape = MakeShapeWithDeviceLayout(
xtensor->shape(),
static_cast<XlaDeviceType>(xtensor->GetDevice().type()));
XLATensor::ShardingSpecPtr new_sharding_spec =
std::make_shared<XLATensor::ShardingSpec>(
sharding, MakeShapeWithDeviceLayout(
xtensor->shape(), static_cast<XlaDeviceType>(
xtensor->GetDevice().type())));
std::make_shared<XLATensor::ShardingSpec>(sharding, xla_shape);
std::cout << "XlaMarkSharding... "
<< new_sharding_spec->sharding.DebugString() << std::endl;
std::cout << "xla shape..." << xla_shape.ToString() << std::endl;

// For Non DeviceData IR values, we directly attach the sharding spec
// to the xtensor.
Expand Down
5 changes: 5 additions & 0 deletions torch_xla/csrc/xla_sharding_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ class ShardingUtil {
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec);

// Creates a sharded data handle on the SPMD virtual device from shards whose
// data already resides on the local devices (one device per shard). Unlike
// CreateShardedData above, no host-to-device transfer is issued.
static runtime::ComputationClient::DataPtr CreateShardedDataFromShards(
const std::vector<runtime::ComputationClient::DataPtr>& local_shards,
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec);

static void XlaMarkSharding(const at::Tensor& input,
xla::OpSharding sharding);

Expand Down
22 changes: 22 additions & 0 deletions torch_xla/distributed/spmd/dataloading_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any, List, Optional, Sequence, Set, Tuple, Union

import torch
import torch_xla


def create_global_tensor_from_shards(shape: Tuple[int],
                                     sharding: torch_xla._XLAC.OpSharding,
                                     shards: Sequence[torch.Tensor]):
  """Assemble a global SPMD tensor from per-device local shards.

  Similar to jax.make_array_from_single_device_arrays.

  Args:
    shape: Global shape of the resulting tensor.
    sharding: OpSharding describing how ``shards`` tile the global tensor.
    shards: Local shards; when the first shard is on CPU, all must be on CPU.

  Returns:
    A tensor assembled from the given shards.
  """
  # Now this function relies on caller to pass matching sharding and shape.
  # TODO(lsy323): Check if shape, sharding and shards are matching.
  assert shards, "At least one shard is required."
  if shards[0].device.type == 'cpu':
    assert all(
        s.device.type == 'cpu' for s in shards), "All shards must be on CPU."
  # Both the CPU and on-device paths go through the same C++ binding; the
  # previous if/else branches were identical and have been collapsed.
  return torch_xla._XLAC._global_tensor_from_tpu_shards(shards, sharding, shape)
Loading