Skip to content

Commit

Permalink
IFRT proxy: Add profiler spans to all entrypoints at the client.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 704444588
  • Loading branch information
tensorflower-gardener committed Dec 9, 2024
1 parent a3ec1cc commit 508565d
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 3 deletions.
4 changes: 4 additions & 0 deletions third_party/xla/xla/python/ifrt_proxy/client/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@local_tsl//tsl/platform:casts",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/profiler/lib:traceme",
],
)

Expand Down Expand Up @@ -255,6 +256,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/profiler/lib:traceme",
],
)

Expand Down Expand Up @@ -331,6 +333,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@local_tsl//tsl/platform:status_to_from_proto",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/profiler/lib:traceme",
],
)

Expand Down Expand Up @@ -405,6 +408,7 @@ cc_library(
"@local_tsl//tsl/platform:protobuf",
"@local_tsl//tsl/platform:status_to_from_proto",
"@local_tsl//tsl/platform:statusor",
"@local_tsl//tsl/profiler/lib:traceme",
],
)

Expand Down
33 changes: 32 additions & 1 deletion third_party/xla/xla/python/ifrt_proxy/client/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#include "xla/tsl/concurrency/ref_count.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla {
namespace ifrt {
Expand Down Expand Up @@ -113,6 +114,7 @@ Array::MakeArrayFromHostBuffer(
return absl::UnimplementedError(
"String arrays are not supported in ifrt-proxy version < 9");
}
tsl::profiler::TraceMe traceme("IfrtProxySerializeStringHostBuffer");
TF_ASSIGN_OR_RETURN(
std::shared_ptr<std::string> owned_data,
SerializeStringHostBuffer(absl::MakeConstSpan(
Expand All @@ -127,6 +129,12 @@ Array::MakeArrayFromHostBuffer(
}
};
}
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
[s = mem_region.size(), semantics]() {
return tsl::profiler::TraceMeEncode(
"IfrtProxyEntrypointMakeArrayFromHostBuffer",
{{"size", s}, {"semantics", static_cast<int>(semantics)}});
});

const uint64_t host_buffer_handle = rpc_helper->NextHandle();

Expand Down Expand Up @@ -226,6 +234,9 @@ void Array::Destruct(RpcHelper* rpc_helper, ArrayHandle handle) {
}

Future<> Array::GetReadyFuture() const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointArrayGetReadyFuture");

auto req = std::make_unique<CheckValueReadyRequest>();
req->add_value_handles(handle_.handle);

Expand Down Expand Up @@ -260,6 +271,8 @@ Future<> Array::Delete() {
}

bool Array::IsDeleted() const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointIsDeleted");
if (GetGlobalClientFlags()->array_is_deleted_hack) {
return false;
}
Expand Down Expand Up @@ -287,6 +300,14 @@ Array::AssembleArrayFromSingleDeviceArrays(
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
[n_arrays = arrays.size(), single_device_shard_semantics]() {
return tsl::profiler::TraceMeEncode(
"IfrtProxyEntrypointAssembleArrayFromSingleDeviceArrays",
{{"n_arrays", n_arrays},
{"sds_semantics",
static_cast<int>(single_device_shard_semantics)}});
});
if (single_device_shard_semantics ==
SingleDeviceShardSemantics::kAddressableShards &&
rpc_helper->version().protocol_version() < 8) {
Expand Down Expand Up @@ -338,6 +359,10 @@ Array::RemapArrays(xla::ifrt::Client* client,
std::shared_ptr<RpcHelper> rpc_helper, const RemapPlan& plan,
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
ArrayCopySemantics semantics) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint([n_arrays = arrays.size()]() {
return tsl::profiler::TraceMeEncode("IfrtProxyEntrypointRemapArrays",
{{"n_arrays", n_arrays}});
});
auto req = std::make_unique<RemapArraysRequest>();
TF_RET_CHECK(!arrays.empty());
TF_ASSIGN_OR_RETURN(*req->mutable_plan(), plan.ToProto());
Expand Down Expand Up @@ -393,6 +418,8 @@ absl::StatusOr<std::vector<tsl::RCReference<xla::ifrt::Array>>>
Array::DisassembleIntoSingleDeviceArrays(
ArrayCopySemantics array_copy_semantics,
SingleDeviceShardSemantics single_device_shard_semantics) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointDisassembleIntoSingleDeviceArrays");
if (single_device_shard_semantics ==
SingleDeviceShardSemantics::kAddressableShards &&
rpc_helper_->version().protocol_version() < 8) {
Expand Down Expand Up @@ -446,6 +473,8 @@ Array::DisassembleIntoSingleDeviceArrays(

absl::StatusOr<tsl::RCReference<xla::ifrt::Array>> Array::FullyReplicatedShard(
ArrayCopySemantics semantics) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointFullyReplicatedShard");
auto req = std::make_unique<FullyReplicatedShardRequest>();
req->set_array_handle(handle_.handle);
req->set_copy_semantics(ToArrayCopySemanticsProto(semantics));
Expand Down Expand Up @@ -481,6 +510,8 @@ absl::StatusOr<tsl::RCReference<xla::ifrt::Array>> Array::FullyReplicatedShard(
Future<> Array::CopyToStringHostBuffer(
void* data, std::optional<absl::Span<const int64_t>> byte_strides,
ArrayCopySemantics semantics) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointCopyToStringHostBuffer");
if (rpc_helper_->version().protocol_version() < 9) {
return Future<>(absl::UnimplementedError(
"String arrays are not supported in ifrt-proxy version < 9"));
Expand Down Expand Up @@ -540,7 +571,7 @@ Future<> Array::CopyToHostBuffer(
if (dtype_.kind() == DType::kString) {
return CopyToStringHostBuffer(data, byte_strides, semantics);
}

tsl::profiler::TraceMe traceme("IfrtProxyEntrypointCopyToHostBuffer");
const auto mem_region = ArrayMemRegion::FromZerothElementPointer(
/*zeroth_element=*/data, dtype_, shape_, byte_strides);
if (!mem_region.ok()) {
Expand Down
12 changes: 12 additions & 0 deletions third_party/xla/xla/python/ifrt_proxy/client/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include "xla/xla_data.pb.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla {
namespace ifrt {
Expand All @@ -66,6 +67,7 @@ char Client::ID = 0;

absl::StatusOr<std::unique_ptr<Client>> Client::Create(
std::shared_ptr<RpcHelper> rpc_helper, InitResponse init_response) {
tsl::profiler::TraceMe traceme("IfrtProxyEntrypointClientCreate");
absl::flat_hash_set<int> addressable_device_ids(
init_response.addressable_device_ids().begin(),
init_response.addressable_device_ids().end());
Expand Down Expand Up @@ -254,6 +256,10 @@ Client::CopyArrays(
absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
std::optional<tsl::RCReference<xla::ifrt::DeviceList>> devices,
std::optional<MemoryKind> memory_kind, ArrayCopySemantics semantics) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint([n_arrays = arrays.size()]() {
return tsl::profiler::TraceMeEncode("IfrtProxyEntrypointCopyArrays",
{{"n_arrays", n_arrays}});
});
if (arrays.empty()) {
return std::vector<tsl::RCReference<xla::ifrt::Array>>();
}
Expand Down Expand Up @@ -334,6 +340,10 @@ Client::RemapArrays(const RemapPlan& plan,

xla::ifrt::Future<> Client::GetReadyFuture(
absl::Span<const tsl::RCReference<xla::ifrt::Value>> values) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint([n_values = values.size()]() {
return tsl::profiler::TraceMeEncode("IfrtProxyEntrypointGetReadyFuture",
{{"n_values", n_values}});
});
absl::InlinedVector<Future<>, 1> futures;

auto req = std::make_unique<CheckValueReadyRequest>();
Expand Down Expand Up @@ -364,6 +374,8 @@ absl::Span<xla::ifrt::Device* const> Client::GetAllDevices() const {

absl::StatusOr<DeviceAssignment> Client::GetDefaultDeviceAssignment(
int num_replicas, int num_partitions) const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointGetDefaultDeviceAssignment");
auto req = std::make_unique<GetDefaultDeviceAssignmentRequest>();
req->set_num_replicas(num_replicas);
req->set_num_partitions(num_partitions);
Expand Down
13 changes: 11 additions & 2 deletions third_party/xla/xla/python/ifrt_proxy/client/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "xla/tsl/concurrency/ref_count.h"
#include "tsl/platform/status_to_from_proto.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla {
namespace ifrt {
Expand All @@ -57,8 +58,16 @@ absl::StatusOr<std::unique_ptr<xla::ifrt::LoadedExecutable>> Compiler::Compile(
std::unique_ptr<Program> program,
std::unique_ptr<xla::ifrt::CompileOptions> options) {
auto request = std::make_unique<CompileRequest>();
TF_ASSIGN_OR_RETURN(*request->mutable_program(),
Serialize(*program, /*options=*/nullptr));
{
tsl::profiler::TraceMe traceme("IfrtProxyProgramSerialize");
TF_ASSIGN_OR_RETURN(*request->mutable_program(),
Serialize(*program, /*options=*/nullptr));
}
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
[prog_size = request->program().data().size()]() {
return tsl::profiler::TraceMeEncode(
"IfrtProxyEntrypointCompilerCompile", {{"prog_size", prog_size}});
});

// Extract host callbacks from the XLA compile options. `XlaCompileOptions`'s
// SerDes fails when it contains host callbacks, so the following
Expand Down
23 changes: 23 additions & 0 deletions third_party/xla/xla/python/ifrt_proxy/client/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
#include "tsl/platform/status_to_from_proto.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/threadpool.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla {
namespace ifrt {
Expand Down Expand Up @@ -272,6 +273,9 @@ LoadedExecutable::LoadedExecutable(
}
}

tsl::profiler::TraceMe traceme_ifrt_entrypoint(

"IfrtProxyEntrypointLoadedExecutableCreate");
// Asynchronously fetch shardings. Since users of `LoadedExecutable` typically
// require sharding information to invoke the executable, it is beneficial to
// eagerly schedule this fetch since, in some implementations, it may take a
Expand Down Expand Up @@ -362,6 +366,9 @@ LoadedExecutable::LoadedExecutable(
}

LoadedExecutable::~LoadedExecutable() {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableDestruct");

auto req = std::make_unique<LoadedExecutableDestructRequest>();
req->set_loaded_executable_handle(handle_);

Expand Down Expand Up @@ -406,6 +413,8 @@ absl::StatusOr<CompiledMemoryStats> LoadedExecutable::GetCompiledMemoryStats()

std::optional<std::vector<OpSharding>> LoadedExecutable::GetParameterShardings()
const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableGetParameterShardings");
auto info = metadata_future_.Await();
if (!info.ok()) {
return std::nullopt;
Expand All @@ -415,6 +424,8 @@ std::optional<std::vector<OpSharding>> LoadedExecutable::GetParameterShardings()

std::optional<std::vector<OpSharding>> LoadedExecutable::GetOutputShardings()
const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableGetOutputShardings");
auto info = metadata_future_.Await();
if (!info.ok()) {
return std::nullopt;
Expand All @@ -424,6 +435,8 @@ std::optional<std::vector<OpSharding>> LoadedExecutable::GetOutputShardings()

absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>
LoadedExecutable::GetParameterLayouts() const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableGetParameterLayouts");
TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await());
TF_RETURN_IF_ERROR(info->parameter_layouts.status());

Expand All @@ -437,6 +450,8 @@ LoadedExecutable::GetParameterLayouts() const {

absl::StatusOr<std::vector<std::unique_ptr<xla::PjRtLayout>>>
LoadedExecutable::GetOutputLayouts() const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableGetOutputLayouts");
TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await());
TF_RETURN_IF_ERROR(info->output_layouts.status());

Expand All @@ -450,6 +465,8 @@ LoadedExecutable::GetOutputLayouts() const {

absl::StatusOr<std::vector<std::vector<absl::string_view>>>
LoadedExecutable::GetOutputMemoryKinds() const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableGetOutputMemoryKinds");
TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await());
return info->output_memory_kinds;
}
Expand All @@ -471,6 +488,8 @@ LoadedExecutable::Execute(
absl::Span<tsl::RCReference<xla::ifrt::Array>> args,
const ExecuteOptions& options,
std::optional<tsl::RCReference<xla::ifrt::DeviceList>> devices) {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableExecute");
auto req = std::make_unique<LoadedExecutableExecuteRequest>();
req->set_loaded_executable_handle(handle_);
for (const auto& arg : args) {
Expand Down Expand Up @@ -557,6 +576,8 @@ LoadedExecutable::Execute(
}

Future<> LoadedExecutable::Delete() {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableDelete");
auto req = std::make_unique<LoadedExecutableDeleteRequest>();
req->set_loaded_executable_handle(handle_);

Expand All @@ -580,6 +601,8 @@ Future<> LoadedExecutable::Delete() {
}

bool LoadedExecutable::IsDeleted() const {
tsl::profiler::TraceMe traceme_ifrt_entrypoint(
"IfrtProxyEntrypointLoadedExecutableIsDeleted");
auto req = std::make_unique<LoadedExecutableIsDeletedRequest>();
req->set_loaded_executable_handle(handle_);

Expand Down

0 comments on commit 508565d

Please sign in to comment.