From 508565d1c3ba2b142a7644739cc8e0d5a05609ff Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 9 Dec 2024 15:49:16 -0800 Subject: [PATCH] IFRT proxy: Add profiler spans to all entrypoints at the client. PiperOrigin-RevId: 704444588 --- .../xla/xla/python/ifrt_proxy/client/BUILD | 4 +++ .../xla/xla/python/ifrt_proxy/client/array.cc | 33 ++++++++++++++++++- .../xla/python/ifrt_proxy/client/client.cc | 12 +++++++ .../xla/python/ifrt_proxy/client/compiler.cc | 13 ++++++-- .../python/ifrt_proxy/client/executable.cc | 23 +++++++++++++ 5 files changed, 82 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/python/ifrt_proxy/client/BUILD b/third_party/xla/xla/python/ifrt_proxy/client/BUILD index 0e9bc6f94bcd6f..2b382e414d415a 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/client/BUILD @@ -178,6 +178,7 @@ cc_library( "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -255,6 +256,7 @@ cc_library( "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -331,6 +333,7 @@ cc_library( "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -405,6 +408,7 @@ cc_library( "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", ], ) diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array.cc b/third_party/xla/xla/python/ifrt_proxy/client/array.cc index 546051bb10ef69..eabbbbd66e7987 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/array.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/array.cc @@ -54,6 +54,7 @@ #include "xla/tsl/concurrency/ref_count.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" namespace xla { namespace ifrt { @@ -113,6 +114,7 @@ Array::MakeArrayFromHostBuffer( return absl::UnimplementedError( "String arrays are not supported in ifrt-proxy version < 9"); } + tsl::profiler::TraceMe traceme("IfrtProxySerializeStringHostBuffer"); TF_ASSIGN_OR_RETURN( std::shared_ptr owned_data, SerializeStringHostBuffer(absl::MakeConstSpan( @@ -127,6 +129,12 @@ Array::MakeArrayFromHostBuffer( } }; } + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + [s = mem_region.size(), semantics]() { + return tsl::profiler::TraceMeEncode( + "IfrtProxyEntrypointMakeArrayFromHostBuffer", + {{"size", s}, {"semantics", static_cast(semantics)}}); + }); const uint64_t host_buffer_handle = rpc_helper->NextHandle(); @@ -226,6 +234,9 @@ void Array::Destruct(RpcHelper* rpc_helper, ArrayHandle handle) { } Future<> Array::GetReadyFuture() const { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointArrayGetReadyFuture"); + auto req = std::make_unique(); req->add_value_handles(handle_.handle); @@ -260,6 +271,8 @@ Future<> Array::Delete() { } bool Array::IsDeleted() const { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointIsDeleted"); if (GetGlobalClientFlags()->array_is_deleted_hack) { return false; } @@ -287,6 +300,14 @@ Array::AssembleArrayFromSingleDeviceArrays( absl::Span> arrays, ArrayCopySemantics array_copy_semantics, SingleDeviceShardSemantics single_device_shard_semantics) { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + [n_arrays = arrays.size(), single_device_shard_semantics]() { + return tsl::profiler::TraceMeEncode( + "IfrtProxyEntrypointAssembleArrayFromSingleDeviceArrays", + {{"n_arrays", n_arrays}, + {"sds_semantics", + static_cast(single_device_shard_semantics)}}); + }); if (single_device_shard_semantics == SingleDeviceShardSemantics::kAddressableShards && rpc_helper->version().protocol_version() < 8) { @@ -338,6 +359,10 @@ Array::RemapArrays(xla::ifrt::Client* client, std::shared_ptr rpc_helper, const RemapPlan& plan, absl::Span> arrays, ArrayCopySemantics semantics) { + tsl::profiler::TraceMe traceme_ifrt_entrypoint([n_arrays = arrays.size()]() { + return tsl::profiler::TraceMeEncode("IfrtProxyEntrypointRemapArrays", + {{"n_arrays", n_arrays}}); + }); auto req = std::make_unique(); TF_RET_CHECK(!arrays.empty()); TF_ASSIGN_OR_RETURN(*req->mutable_plan(), plan.ToProto()); @@ -393,6 +418,8 @@ absl::StatusOr>> Array::DisassembleIntoSingleDeviceArrays( ArrayCopySemantics array_copy_semantics, SingleDeviceShardSemantics single_device_shard_semantics) { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointDisassembleIntoSingleDeviceArrays"); if (single_device_shard_semantics == SingleDeviceShardSemantics::kAddressableShards && rpc_helper_->version().protocol_version() < 8) { @@ -446,6 +473,8 @@ Array::DisassembleIntoSingleDeviceArrays( absl::StatusOr> Array::FullyReplicatedShard( ArrayCopySemantics semantics) { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointFullyReplicatedShard"); auto req = std::make_unique(); req->set_array_handle(handle_.handle); req->set_copy_semantics(ToArrayCopySemanticsProto(semantics)); @@ -481,6 +510,8 @@ absl::StatusOr> Array::FullyReplicatedShard( Future<> Array::CopyToStringHostBuffer( void* data, std::optional> byte_strides, ArrayCopySemantics semantics) { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointCopyToStringHostBuffer"); if (rpc_helper_->version().protocol_version() < 9) { return Future<>(absl::UnimplementedError( "String arrays are not supported in ifrt-proxy version < 9")); @@ -540,7 +571,7 @@ Future<> Array::CopyToHostBuffer( if (dtype_.kind() == DType::kString) { return CopyToStringHostBuffer(data, byte_strides, semantics); } - + tsl::profiler::TraceMe traceme("IfrtProxyEntrypointCopyToHostBuffer"); const auto mem_region = ArrayMemRegion::FromZerothElementPointer( /*zeroth_element=*/data, dtype_, shape_, byte_strides); if (!mem_region.ok()) { diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client.cc b/third_party/xla/xla/python/ifrt_proxy/client/client.cc index be76fbff124220..b8ec4ab1ed9f60 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/client.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/client.cc @@ -57,6 +57,7 @@ #include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" #include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" namespace xla { namespace ifrt { @@ -66,6 +67,7 @@ char Client::ID = 0; absl::StatusOr> Client::Create( std::shared_ptr rpc_helper, InitResponse init_response) { + tsl::profiler::TraceMe traceme("IfrtProxyEntrypointClientCreate"); absl::flat_hash_set addressable_device_ids( init_response.addressable_device_ids().begin(), init_response.addressable_device_ids().end()); @@ -254,6 +256,10 @@ Client::CopyArrays( absl::Span> arrays, std::optional> devices, std::optional memory_kind, ArrayCopySemantics semantics) { + tsl::profiler::TraceMe traceme_ifrt_entrypoint([n_arrays = arrays.size()]() { + return tsl::profiler::TraceMeEncode("IfrtProxyEntrypointCopyArrays", + {{"n_arrays", n_arrays}}); + }); if (arrays.empty()) { return std::vector>(); } @@ -334,6 +340,10 @@ Client::RemapArrays(const RemapPlan& plan, xla::ifrt::Future<> Client::GetReadyFuture( absl::Span> values) { + tsl::profiler::TraceMe traceme_ifrt_entrypoint([n_values = values.size()]() { + return tsl::profiler::TraceMeEncode("IfrtProxyEntrypointGetReadyFuture", + {{"n_values", n_values}}); + }); absl::InlinedVector, 1> futures; auto req = std::make_unique(); @@ -364,6 +374,8 @@ absl::Span Client::GetAllDevices() const { absl::StatusOr Client::GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointGetDefaultDeviceAssignment"); auto req = std::make_unique(); req->set_num_replicas(num_replicas); req->set_num_partitions(num_partitions); diff --git a/third_party/xla/xla/python/ifrt_proxy/client/compiler.cc b/third_party/xla/xla/python/ifrt_proxy/client/compiler.cc index 53575b342edab4..08803931989329 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/compiler.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/compiler.cc @@ -44,6 +44,7 @@ #include "xla/tsl/concurrency/ref_count.h" #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" namespace xla { namespace ifrt { @@ -57,8 +58,16 @@ absl::StatusOr> Compiler::Compile( std::unique_ptr program, std::unique_ptr options) { auto request = std::make_unique(); - TF_ASSIGN_OR_RETURN(*request->mutable_program(), - Serialize(*program, /*options=*/nullptr)); + { + tsl::profiler::TraceMe traceme("IfrtProxyProgramSerialize"); + TF_ASSIGN_OR_RETURN(*request->mutable_program(), + Serialize(*program, /*options=*/nullptr)); + } + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + [prog_size = request->program().data().size()]() { + return tsl::profiler::TraceMeEncode( + "IfrtProxyEntrypointCompilerCompile", {{"prog_size", prog_size}}); + }); // Extract host callbacks from the XLA compile options. `XlaCompileOptions`'s // SerDes fails when it contains host callbacks, so the following diff --git a/third_party/xla/xla/python/ifrt_proxy/client/executable.cc b/third_party/xla/xla/python/ifrt_proxy/client/executable.cc index f68926b9985d39..81ef43ec5c0f3b 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/executable.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/executable.cc @@ -72,6 +72,7 @@ #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" +#include "tsl/profiler/lib/traceme.h" namespace xla { namespace ifrt { @@ -272,6 +273,9 @@ LoadedExecutable::LoadedExecutable( } } + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + + "IfrtProxyEntrypointLoadedExecutableCreate"); // Asynchronously fetch shardings. Since users of `LoadedExecutable` typically // require sharding information to invoke the executable, it is beneficial to // eagerly schedule this fetch since, in some implementations, it may take a @@ -362,6 +366,9 @@ LoadedExecutable::LoadedExecutable( } LoadedExecutable::~LoadedExecutable() { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointLoadedExecutableDestruct"); + auto req = std::make_unique(); req->set_loaded_executable_handle(handle_); @@ -406,6 +413,8 @@ absl::StatusOr LoadedExecutable::GetCompiledMemoryStats() std::optional> LoadedExecutable::GetParameterShardings() const { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointLoadedExecutableGetParameterShardings"); auto info = metadata_future_.Await(); if (!info.ok()) { return std::nullopt; @@ -415,6 +424,8 @@ std::optional> LoadedExecutable::GetParameterShardings() std::optional> LoadedExecutable::GetOutputShardings() const { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointLoadedExecutableGetOutputShardings"); auto info = metadata_future_.Await(); if (!info.ok()) { return std::nullopt; @@ -424,6 +435,8 @@ std::optional> LoadedExecutable::GetOutputShardings() absl::StatusOr>> LoadedExecutable::GetParameterLayouts() const { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointLoadedExecutableGetParameterLayouts"); TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); TF_RETURN_IF_ERROR(info->parameter_layouts.status()); @@ -437,6 +450,8 @@ LoadedExecutable::GetParameterLayouts() const { absl::StatusOr>> LoadedExecutable::GetOutputLayouts() const { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointLoadedExecutableGetOutputLayouts"); TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); TF_RETURN_IF_ERROR(info->output_layouts.status()); @@ -450,6 +465,8 @@ LoadedExecutable::GetOutputLayouts() const { absl::StatusOr>> LoadedExecutable::GetOutputMemoryKinds() const { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointLoadedExecutableGetOutputMemoryKinds"); TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); return info->output_memory_kinds; } @@ -471,6 +488,8 @@ LoadedExecutable::Execute( absl::Span> args, const ExecuteOptions& options, std::optional> devices) { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointLoadedExecutableExecute"); auto req = std::make_unique(); req->set_loaded_executable_handle(handle_); for (const auto& arg : args) { @@ -557,6 +576,8 @@ LoadedExecutable::Execute( } Future<> LoadedExecutable::Delete() { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointLoadedExecutableDelete"); auto req = std::make_unique(); req->set_loaded_executable_handle(handle_); @@ -580,6 +601,8 @@ Future<> LoadedExecutable::Delete() { } bool LoadedExecutable::IsDeleted() const { + tsl::profiler::TraceMe traceme_ifrt_entrypoint( + "IfrtProxyEntrypointLoadedExecutableIsDeleted"); auto req = std::make_unique(); req->set_loaded_executable_handle(handle_);