// Patch summary (commentary placed ahead of the first "diff --git" header).
//
// Purpose: stop calling the Win32 function SetThreadDescription directly and
// instead resolve it at run time through the DSO loader, calling it only when
// the lookup produced a non-null pointer.
//
// tensorflow/core/common_runtime/dml/dml_execution_context.cc
//   * Adds an #include of stream_executor's dso_loader.h.
//   * Under `#if _WIN32`, declares the function-pointer type
//     SetThreadDescriptionFn and a file-static pointer g_setThreadDescription
//     that starts out as nullptr.
//   * In the DmlExecutionContext constructor, asks
//     CachedDsoLoader::GetKernel32DsoHandle() for a handle and, when that
//     StatusOr is ok(), calls Env::Default()->GetSymbolFromLibrary() with the
//     symbol name "SetThreadDescription" to fill g_setThreadDescription.
//   * In the routine taking (command_queue, batch_flush_size,
//     batch_flush_time_us), replaces the direct SetThreadDescription call
//     with a call through g_setThreadDescription guarded by a null check.
//
// tensorflow/stream_executor/platform/default/dso_loader.cc
//   * Adds GetKernel32DsoHandle() ahead of the `// namespace DsoLoader`
//     closing brace: on _WIN32 it returns GetDsoHandle("Kernel32", ""),
//     otherwise a port::Status with port::error::UNIMPLEMENTED.
//   * Adds CachedDsoLoader::GetKernel32DsoHandle(), which evaluates
//     DsoLoader::GetKernel32DsoHandle() into a function-local static
//     allocated with `new` and returns a copy of it.
//
// tensorflow/stream_executor/platform/default/dso_loader.h
//   * Declares GetKernel32DsoHandle() in two places; the second is inside
//     namespace CachedDsoLoader, the first presumably inside namespace
//     DsoLoader (its enclosing namespace is not visible in the hunk).
//
// NOTE(review): presumably the run-time lookup exists because the
//   SetThreadDescription export is not present in every Windows release;
//   confirm against the Win32 API documentation.
// NOTE(review): the status returned by GetSymbolFromLibrary() is not
//   inspected; a failed lookup is only observable through the pointer
//   remaining null.
// NOTE(review): g_setThreadDescription is assigned on every constructor
//   call and no synchronization is visible in the hunk; confirm that
//   contexts are never constructed concurrently.
// NOTE(review): in this copy `reinterpret_cast(`, `std::make_shared(`,
//   `std::shared_ptr ` and `port::StatusOr ` carry no template arguments and
//   the patch's line breaks are collapsed; this looks like extraction damage,
//   so verify against the upstream commit before applying.
diff --git a/tensorflow/core/common_runtime/dml/dml_execution_context.cc b/tensorflow/core/common_runtime/dml/dml_execution_context.cc index 8a24539a89..056b4588ce 100644 --- a/tensorflow/core/common_runtime/dml/dml_execution_context.cc +++ b/tensorflow/core/common_runtime/dml/dml_execution_context.cc @@ -20,6 +20,14 @@ limitations under the License. #include "dml_tracing.h" #include "dml_util.h" #include "tensorflow/core/util/env_var.h" +#include "tensorflow/stream_executor/platform/default/dso_loader.h" + +#if _WIN32 +typedef HRESULT(WINAPI* SetThreadDescriptionFn)(HANDLE hThread, + PCWSTR lpThreadDescription); + +static SetThreadDescriptionFn g_setThreadDescription = nullptr; +#endif namespace tensorflow { @@ -27,6 +35,17 @@ DmlExecutionContext::DmlExecutionContext(ID3D12Device* d3d_device, IDMLDevice* dml_device, ID3D12CommandQueue* queue, DmlAllocator* allocator) { +#if _WIN32 + auto kernel32_handle_or = + stream_executor::internal::CachedDsoLoader::GetKernel32DsoHandle(); + + if (kernel32_handle_or.ok()) { + tensorflow::Env::Default()->GetSymbolFromLibrary( + kernel32_handle_or.ValueOrDie(), "SetThreadDescription", + reinterpret_cast(&g_setThreadDescription)); + } +#endif + dml_command_queue_ = std::make_shared(queue); batch_state_ = std::make_shared(); @@ -212,7 +231,9 @@ D3D12_COMMAND_LIST_TYPE DmlExecutionContext::GetCommandListTypeForQueue() std::shared_ptr command_queue, uint32_t batch_flush_size, uint32_t batch_flush_time_us) { #if _WIN32 - SetThreadDescription(GetCurrentThread(), L"TFDML Execution Thread"); + if (g_setThreadDescription) { + g_setThreadDescription(GetCurrentThread(), L"TFDML Execution Thread"); + } #endif auto last_flush_time = std::chrono::steady_clock::now(); diff --git a/tensorflow/stream_executor/platform/default/dso_loader.cc b/tensorflow/stream_executor/platform/default/dso_loader.cc index b365d868e6..b93d192e00 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.cc +++ 
b/tensorflow/stream_executor/platform/default/dso_loader.cc @@ -234,6 +234,14 @@ port::StatusOr GetPixDsoHandle() { #endif } +port::StatusOr GetKernel32DsoHandle() { +#if _WIN32 + return GetDsoHandle("Kernel32", ""); +#else + return port::Status(port::error::UNIMPLEMENTED, "Kernel32.dll is only available on Windows"); +#endif +} + } // namespace DsoLoader namespace CachedDsoLoader { @@ -322,6 +330,11 @@ port::StatusOr GetPixDsoHandle() { return *result; } +port::StatusOr GetKernel32DsoHandle() { + static auto result = new auto(DsoLoader::GetKernel32DsoHandle()); + return *result; +} + } // namespace CachedDsoLoader } // namespace internal } // namespace stream_executor diff --git a/tensorflow/stream_executor/platform/default/dso_loader.h b/tensorflow/stream_executor/platform/default/dso_loader.h index cab8c3cbc9..721c6f7ff5 100644 --- a/tensorflow/stream_executor/platform/default/dso_loader.h +++ b/tensorflow/stream_executor/platform/default/dso_loader.h @@ -54,6 +54,7 @@ port::StatusOr GetHipDsoHandle(); port::StatusOr GetDirectMLDsoHandle(); port::StatusOr GetDirectMLDebugDsoHandle(); port::StatusOr GetPixDsoHandle(); +port::StatusOr GetKernel32DsoHandle(); // The following method tries to dlopen all necessary GPU libraries for the GPU // platform TF is built with (CUDA or ROCm) only when these libraries should be @@ -89,6 +90,7 @@ port::StatusOr GetHipDsoHandle(); port::StatusOr GetDirectMLDsoHandle(); port::StatusOr GetDirectMLDebugDsoHandle(); port::StatusOr GetPixDsoHandle(); +port::StatusOr GetKernel32DsoHandle(); } // namespace CachedDsoLoader } // namespace internal