Skip to content

Commit

Permalink
Reverts fbe8a54
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 582275667
  • Loading branch information
hawkinsp authored and copybara-github committed Nov 14, 2023
1 parent 548a766 commit 1e05b09
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 360 deletions.
2 changes: 0 additions & 2 deletions xla/pjrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,6 @@ cc_library(
"//xla/client:executable_build_options",
"//xla/client:xla_computation",
"//xla/hlo/ir:hlo",
"//xla/pjrt/distributed:topology_util",
"//xla/runtime:cpu_event",
"//xla/service:buffer_assignment",
"//xla/service:compiler",
Expand Down Expand Up @@ -700,7 +699,6 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3", # TODO(zhangqiaorjc): Remove if use TFRT threadpool.
"@llvm-project//mlir:IR",
Expand Down
62 changes: 19 additions & 43 deletions xla/pjrt/tfrt_cpu_pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
Expand All @@ -58,7 +57,6 @@ limitations under the License.
#include "xla/literal_util.h"
#include "xla/pjrt/abstract_tfrt_cpu_buffer.h"
#include "xla/pjrt/compile_options.pb.h"
#include "xla/pjrt/distributed/topology_util.h"
#include "xla/pjrt/mlir_to_hlo.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
Expand Down Expand Up @@ -238,11 +236,7 @@ class TfrtCpuAsyncHostToDeviceTransferManager

} // namespace

TfrtCpuDeviceDescription::TfrtCpuDeviceDescription(int id, int process_index,
int local_hardware_id)
: id_(id),
process_index_(process_index),
local_hardware_id_(local_hardware_id) {
TfrtCpuDeviceDescription::TfrtCpuDeviceDescription(int id) : id_(id) {
debug_string_ = absl::StrCat("TFRT_CPU_", id);
to_string_ = absl::StrCat("CpuDevice(id=", id, ")");
}
Expand All @@ -259,9 +253,8 @@ absl::string_view TfrtCpuDeviceDescription::ToString() const {
return to_string_;
}

TfrtCpuDevice::TfrtCpuDevice(int id, int process_index, int local_hardware_id,
int max_inflight_computations)
: description_(id, process_index, local_hardware_id),
TfrtCpuDevice::TfrtCpuDevice(int id, int max_inflight_computations)
: description_(id),
max_inflight_computations_semaphore_(
/*capacity=*/max_inflight_computations) {}

Expand All @@ -288,47 +281,30 @@ static int CpuDeviceCount() {
return GetDebugOptionsFromFlags().xla_force_host_platform_device_count();
}

static StatusOr<std::vector<std::unique_ptr<TfrtCpuDevice>>> GetTfrtCpuDevices(
int cpu_device_count, int max_inflight_computations_per_device) {
std::vector<std::unique_ptr<TfrtCpuDevice>> devices;
for (int i = 0; i < cpu_device_count; ++i) {
auto device = std::make_unique<TfrtCpuDevice>(
/*id=*/i, max_inflight_computations_per_device);
devices.push_back(std::move(device));
}
return std::move(devices);
}

StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(
const CpuClientOptions& options) {
// Need at least CpuDeviceCount threads to launch one collective.
int cpu_device_count = options.cpu_device_count.value_or(CpuDeviceCount());
size_t num_threads = std::max(DefaultThreadPoolSize(), cpu_device_count);

LocalTopologyProto local_topology;
local_topology.set_node_id(options.node_id);
std::string boot_id_str;
auto boot_id_str_or_status = GetBootIdString();
if (!boot_id_str_or_status.ok()) {
LOG(INFO) << boot_id_str_or_status.status();
} else {
boot_id_str = boot_id_str_or_status.value();
}
local_topology.set_boot_id(boot_id_str);
for (int i = 0; i < cpu_device_count; ++i) {
DeviceProto* device_proto = local_topology.add_devices();
device_proto->set_local_device_ordinal(i);
device_proto->set_name("cpu");
}

GlobalTopologyProto global_topology;
TF_RETURN_IF_ERROR(
ExchangeTopologies("cpu", options.node_id, options.num_nodes,
absl::Minutes(2), absl::Minutes(5), options.kv_get,
options.kv_put, local_topology, &global_topology));

std::vector<std::unique_ptr<TfrtCpuDevice>> devices;
for (const LocalTopologyProto& node : global_topology.nodes()) {
for (const DeviceProto& device_proto : node.devices()) {
auto device = std::make_unique<TfrtCpuDevice>(
/*id=*/device_proto.global_device_id(), node.node_id(),
device_proto.local_device_ordinal(),
options.max_inflight_computations_per_device);
devices.push_back(std::move(device));
}
}
TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<TfrtCpuDevice>> devices,
GetTfrtCpuDevices(cpu_device_count,
options.max_inflight_computations_per_device));

return std::unique_ptr<PjRtClient>(std::make_unique<TfrtCpuClient>(
/*process_index=*/options.node_id, std::move(devices), num_threads));
/*process_index=*/0, std::move(devices), num_threads));
}

TfrtCpuClient::TfrtCpuClient(
Expand Down
27 changes: 5 additions & 22 deletions xla/pjrt/tfrt_cpu_pjrt_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,11 @@ namespace xla {

class TfrtCpuDeviceDescription final : public PjRtDeviceDescription {
public:
TfrtCpuDeviceDescription(int id, int process_index, int local_hardware_id);
explicit TfrtCpuDeviceDescription(int id);

int id() const override { return id_; }

int process_index() const override { return process_index_; }

int local_hardware_id() const { return local_hardware_id_; }
int process_index() const override { return 0; }

absl::string_view device_kind() const override;

Expand All @@ -84,17 +82,14 @@ class TfrtCpuDeviceDescription final : public PjRtDeviceDescription {

private:
int id_;
int process_index_;
int local_hardware_id_;
std::string debug_string_;
std::string to_string_;
absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_ = {};
};

class TfrtCpuDevice final : public PjRtDevice {
public:
explicit TfrtCpuDevice(int id, int process_index, int local_hardware_id,
int max_inflight_computations = 32);
explicit TfrtCpuDevice(int id, int max_inflight_computations = 32);

const TfrtCpuDeviceDescription& description() const override {
return description_;
Expand All @@ -111,9 +106,8 @@ class TfrtCpuDevice final : public PjRtDevice {
return process_index() == client()->process_index();
}

int local_hardware_id() const override {
return description_.local_hardware_id();
}
// Used as `device_ordinal`.
int local_hardware_id() const override { return id(); }

Status TransferToInfeed(const LiteralSlice& literal) override;

Expand Down Expand Up @@ -524,17 +518,6 @@ struct CpuClientOptions {
std::optional<int> cpu_device_count = std::nullopt;

int max_inflight_computations_per_device = 32;

// Number of distributed nodes. node_id, kv_get, and kv_put are ignored if
// this is set to 1.
int num_nodes = 1;

// My node ID.
int node_id = 0;

// KV store primitives for sharing topology information.
PjRtClient::KeyValueGetCallback kv_get = nullptr;
PjRtClient::KeyValuePutCallback kv_put = nullptr;
};
StatusOr<std::unique_ptr<PjRtClient>> GetTfrtCpuClient(
const CpuClientOptions& options);
Expand Down
1 change: 0 additions & 1 deletion xla/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,6 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@local_config_python//:python_headers", # buildcleaner: keep
"//xla:literal",
Expand Down
27 changes: 2 additions & 25 deletions xla/python/xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ limitations under the License.
#include "absl/strings/ascii.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "pybind11/attr.h" // from @pybind11
#include "pybind11/cast.h" // from @pybind11
Expand Down Expand Up @@ -493,38 +492,16 @@ static void Init(py::module_& m) {

m.def(
"get_tfrt_cpu_client",
[](bool asynchronous,
std::shared_ptr<DistributedRuntimeClient> distributed_client,
int node_id, int num_nodes) -> std::shared_ptr<PyClient> {
[](bool asynchronous) -> std::shared_ptr<PyClient> {
py::gil_scoped_release gil_release;
CpuClientOptions options;
if (distributed_client != nullptr) {
std::string key_prefix = "cpu:";
options.kv_get =
[distributed_client, key_prefix](
const std::string& k,
absl::Duration timeout) -> xla::StatusOr<std::string> {
return distributed_client->BlockingKeyValueGet(
absl::StrCat(key_prefix, k), timeout);
};
options.kv_put = [distributed_client, key_prefix](
const std::string& k,
const std::string& v) -> xla::Status {
return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k),
v);
};
options.node_id = node_id;
options.num_nodes = num_nodes;
}

options.asynchronous = asynchronous;
std::unique_ptr<PjRtClient> client =
xla::ValueOrThrow(GetTfrtCpuClient(options));
return std::make_shared<PyClient>(
ifrt::PjRtClient::Create(std::move(client)));
},
py::arg("asynchronous") = true, py::arg("distributed_client") = nullptr,
py::arg("node_id") = 0, py::arg("num_nodes") = 1);
py::arg("asynchronous") = true);
m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool {
xla::StatusOr<const PJRT_Api*> pjrt_api = pjrt::PjrtApi(platform_name);
return pjrt_api.ok();
Expand Down
Loading

0 comments on commit 1e05b09

Please sign in to comment.