From 0061cdd9061e4a8e9b4c85bed4170ec5085ec4a0 Mon Sep 17 00:00:00 2001 From: Igor Chorazewicz Date: Wed, 28 Aug 2024 00:37:52 +0200 Subject: [PATCH 1/2] [L0 v2] Use ze_handle_wrapper to hold ze context This was the last place where we stored ze handle directly. Now, we always use ze_handle_wrapper for managing ownership which makes the code more robust. --- source/adapters/level_zero/v2/common.hpp | 5 +++++ source/adapters/level_zero/v2/context.cpp | 18 +++--------------- source/adapters/level_zero/v2/context.hpp | 6 +++--- .../level_zero/v2/event_provider_counter.cpp | 6 +++--- .../level_zero/v2/event_provider_normal.cpp | 2 +- .../level_zero/v2/command_list_cache_test.cpp | 4 ++-- 6 files changed, 17 insertions(+), 24 deletions(-) diff --git a/source/adapters/level_zero/v2/common.hpp b/source/adapters/level_zero/v2/common.hpp index ffef317ae8..6db588ee5d 100644 --- a/source/adapters/level_zero/v2/common.hpp +++ b/source/adapters/level_zero/v2/common.hpp @@ -54,6 +54,8 @@ struct ze_handle_wrapper { try { reset(); } catch (...) { + // TODO: add appropriate logging or pass the error + // to the caller (make the dtor noexcept(false) or use tls?) } } @@ -94,5 +96,8 @@ using ze_event_handle_t = using ze_event_pool_handle_t = ze_handle_wrapper<::ze_event_pool_handle_t, zeEventPoolDestroy>; +using ze_context_handle_t = + ze_handle_wrapper<::ze_context_handle_t, zeContextDestroy>; + } // namespace raii } // namespace v2 diff --git a/source/adapters/level_zero/v2/context.cpp b/source/adapters/level_zero/v2/context.cpp index 08032fe85e..2792694c74 100644 --- a/source/adapters/level_zero/v2/context.cpp +++ b/source/adapters/level_zero/v2/context.cpp @@ -17,8 +17,8 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext, uint32_t numDevices, const ur_device_handle_t *phDevices, bool ownZeContext) - : hContext(hContext), hDevices(phDevices, phDevices + numDevices), - commandListCache(hContext), + : hContext(hContext, ownZeContext), + hDevices(phDevices, phDevices + numDevices), commandListCache(hContext), eventPoolCache(phDevices[0]->Platform->getNumDevices(), [context = this, platform = phDevices[0]->Platform](DeviceId deviceId) { @@ -27,19 +27,7 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext, return std::make_unique( context, device, v2::EVENT_COUNTER, v2::QUEUE_IMMEDIATE); - }) { - std::ignore = ownZeContext; -} - -ur_context_handle_t_::~ur_context_handle_t_() noexcept(false) { - // ur_context_handle_t_ is only created/destroyed through urContextCreate - // and urContextRelease so it's safe to throw here - ZE2UR_CALL_THROWS(zeContextDestroy, (hContext)); -} - -ze_context_handle_t ur_context_handle_t_::getZeHandle() const { - return hContext; -} + }) {} ur_result_t ur_context_handle_t_::retain() { RefCount.increment(); diff --git a/source/adapters/level_zero/v2/context.hpp b/source/adapters/level_zero/v2/context.hpp index 69bf406594..0ed701400d 100644 --- a/source/adapters/level_zero/v2/context.hpp +++ b/source/adapters/level_zero/v2/context.hpp @@ -13,17 +13,17 @@ #include #include "command_list_cache.hpp" +#include "common.hpp" #include "event_pool_cache.hpp" struct ur_context_handle_t_ : _ur_object { ur_context_handle_t_(ze_context_handle_t hContext, uint32_t numDevices, const ur_device_handle_t *phDevices, bool ownZeContext); - ~ur_context_handle_t_() noexcept(false); ur_result_t retain(); ur_result_t release(); - ze_context_handle_t getZeHandle() const; + inline ze_context_handle_t getZeHandle() const { return hContext.get(); } ur_platform_handle_t getPlatform() const; const std::vector &getDevices() const; @@ -31,7 +31,7 @@ struct ur_context_handle_t_ : _ur_object { // For that the Device or its root devices need to be in the context. bool isValidDevice(ur_device_handle_t Device) const; - const ze_context_handle_t hContext; + const v2::raii::ze_context_handle_t hContext; const std::vector hDevices; v2::command_list_cache_t commandListCache; v2::event_pool_cache eventPoolCache; diff --git a/source/adapters/level_zero/v2/event_provider_counter.cpp b/source/adapters/level_zero/v2/event_provider_counter.cpp index 5334b2f888..353704f9ad 100644 --- a/source/adapters/level_zero/v2/event_provider_counter.cpp +++ b/source/adapters/level_zero/v2/event_provider_counter.cpp @@ -27,9 +27,9 @@ provider_counter::provider_counter(ur_platform_handle_t platform, ZE2UR_CALL_THROWS(zeDriverGetExtensionFunctionAddress, (platform->ZeDriver, "zexCounterBasedEventCreate", (void **)&this->eventCreateFunc)); - ZE2UR_CALL_THROWS( - zelLoaderTranslateHandle, - (ZEL_HANDLE_CONTEXT, context->hContext, (void **)&translatedContext)); + ZE2UR_CALL_THROWS(zelLoaderTranslateHandle, + (ZEL_HANDLE_CONTEXT, context->getZeHandle(), + (void **)&translatedContext)); ZE2UR_CALL_THROWS( zelLoaderTranslateHandle, (ZEL_HANDLE_DEVICE, device->ZeDevice, (void **)&translatedDevice)); diff --git a/source/adapters/level_zero/v2/event_provider_normal.cpp b/source/adapters/level_zero/v2/event_provider_normal.cpp index f5a1c940c6..808d795fc9 100644 --- a/source/adapters/level_zero/v2/event_provider_normal.cpp +++ b/source/adapters/level_zero/v2/event_provider_normal.cpp @@ -43,7 +43,7 @@ provider_pool::provider_pool(ur_context_handle_t context, } ZE2UR_CALL_THROWS(zeEventPoolCreate, - (context->hContext, &desc, 1, + (context->getZeHandle(), &desc, 1, const_cast(&device->ZeDevice), pool.ptr())); diff --git a/test/adapters/level_zero/v2/command_list_cache_test.cpp b/test/adapters/level_zero/v2/command_list_cache_test.cpp index 74bcbf4634..44755b699e 100644 --- a/test/adapters/level_zero/v2/command_list_cache_test.cpp +++ b/test/adapters/level_zero/v2/command_list_cache_test.cpp @@ -23,7 +23,7 @@ struct CommandListCacheTest : public uur::urContextTest {}; UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(CommandListCacheTest); TEST_P(CommandListCacheTest, CanStoreAndRetriveImmediateAndRegularCmdLists) { - v2::command_list_cache_t cache(context->hContext); + v2::command_list_cache_t cache(context->getZeHandle()); bool IsInOrder = false; uint32_t Ordinal = 0; @@ -75,7 +75,7 @@ TEST_P(CommandListCacheTest, CanStoreAndRetriveImmediateAndRegularCmdLists) { } TEST_P(CommandListCacheTest, ImmediateCommandListsHaveProperAttributes) { - v2::command_list_cache_t cache(context->hContext); + v2::command_list_cache_t cache(context->getZeHandle()); uint32_t numQueueGroups = 0; ASSERT_EQ(zeDeviceGetCommandQueueGroupProperties(device->ZeDevice, From 35b8ef5085e5b8330bc13de04d7099d981846c76 Mon Sep 17 00:00:00 2001 From: Igor Chorazewicz Date: Wed, 28 Aug 2024 00:47:10 +0200 Subject: [PATCH 2/2] [L0 v2] Use ze_handle_wraper instead of custom raii handle in command_list_cache --- .../level_zero/v2/command_list_cache.cpp | 22 +++++++++---------- .../level_zero/v2/command_list_cache.hpp | 15 ++++++------- source/adapters/level_zero/v2/common.hpp | 3 +++ 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/source/adapters/level_zero/v2/command_list_cache.cpp b/source/adapters/level_zero/v2/command_list_cache.cpp index eee6555f87..651cb5944a 100644 --- a/source/adapters/level_zero/v2/command_list_cache.cpp +++ b/source/adapters/level_zero/v2/command_list_cache.cpp @@ -43,7 +43,7 @@ inline size_t command_list_descriptor_hash_t::operator()( command_list_cache_t::command_list_cache_t(ze_context_handle_t ZeContext) : ZeContext{ZeContext} {} -raii::ze_command_list_t +raii::ze_command_list_handle_t command_list_cache_t::createCommandList(const command_list_descriptor_t &desc) { if (auto ImmCmdDesc = std::get_if(&desc)) { @@ -61,7 +61,7 @@ command_list_cache_t::createCommandList(const command_list_descriptor_t &desc) { ZE2UR_CALL_THROWS( zeCommandListCreateImmediate, (ZeContext, ImmCmdDesc->ZeDevice, &QueueDesc, &ZeCommandList)); - return raii::ze_command_list_t(ZeCommandList, &zeCommandListDestroy); + return raii::ze_command_list_handle_t(ZeCommandList); } else { auto RegCmdDesc = std::get(desc); ZeStruct CmdListDesc; @@ -72,7 +72,7 @@ command_list_cache_t::createCommandList(const command_list_descriptor_t &desc) { ze_command_list_handle_t ZeCommandList; ZE2UR_CALL_THROWS(zeCommandListCreate, (ZeContext, RegCmdDesc.ZeDevice, &CmdListDesc, &ZeCommandList)); - return raii::ze_command_list_t(ZeCommandList, &zeCommandListDestroy); + return raii::ze_command_list_handle_t(ZeCommandList); } } @@ -94,8 +94,7 @@ command_list_cache_t::getImmediateCommandList( auto CommandList = getCommandList(Desc).release(); return raii::cache_borrowed_command_list_t( CommandList, [Cache = this, Desc](ze_command_list_handle_t CmdList) { - Cache->addCommandList( - Desc, raii::ze_command_list_t(CmdList, &zeCommandListDestroy)); + Cache->addCommandList(Desc, raii::ze_command_list_handle_t(CmdList)); }); } @@ -113,12 +112,11 @@ command_list_cache_t::getRegularCommandList(ze_device_handle_t ZeDevice, return raii::cache_borrowed_command_list_t( CommandList, [Cache = this, Desc](ze_command_list_handle_t CmdList) { - Cache->addCommandList( - Desc, raii::ze_command_list_t(CmdList, &zeCommandListDestroy)); + Cache->addCommandList(Desc, raii::ze_command_list_handle_t(CmdList)); }); } -raii::ze_command_list_t +raii::ze_command_list_handle_t command_list_cache_t::getCommandList(const command_list_descriptor_t &desc) { std::unique_lock Lock(ZeCommandListCacheMutex); auto it = ZeCommandListCache.find(desc); @@ -129,7 +127,8 @@ command_list_cache_t::getCommandList(const command_list_descriptor_t &desc) { assert(!it->second.empty()); - raii::ze_command_list_t CommandListHandle = std::move(it->second.top()); + raii::ze_command_list_handle_t CommandListHandle = + std::move(it->second.top()); it->second.pop(); if (it->second.empty()) @@ -138,8 +137,9 @@ command_list_cache_t::getCommandList(const command_list_descriptor_t &desc) { return CommandListHandle; } -void command_list_cache_t::addCommandList(const command_list_descriptor_t &desc, - raii::ze_command_list_t cmdList) { +void command_list_cache_t::addCommandList( + const command_list_descriptor_t &desc, + raii::ze_command_list_handle_t cmdList) { // TODO: add a limit? std::unique_lock Lock(ZeCommandListCacheMutex); auto [it, _] = ZeCommandListCache.try_emplace(desc); diff --git a/source/adapters/level_zero/v2/command_list_cache.hpp b/source/adapters/level_zero/v2/command_list_cache.hpp index 1850a4334c..78692c975d 100644 --- a/source/adapters/level_zero/v2/command_list_cache.hpp +++ b/source/adapters/level_zero/v2/command_list_cache.hpp @@ -17,15 +17,13 @@ #include #include -#include "../common.hpp" +#include "common.hpp" namespace v2 { namespace raii { -using ze_command_list_t = std::unique_ptr<::_ze_command_list_handle_t, - decltype(&zeCommandListDestroy)>; using cache_borrowed_command_list_t = std::unique_ptr<::_ze_command_list_handle_t, - std::function>; + std::function>; } // namespace raii struct immediate_command_list_descriptor_t { @@ -72,15 +70,16 @@ struct command_list_cache_t { private: ze_context_handle_t ZeContext; std::unordered_map, + std::stack, command_list_descriptor_hash_t> ZeCommandListCache; ur_mutex ZeCommandListCacheMutex; - raii::ze_command_list_t getCommandList(const command_list_descriptor_t &desc); + raii::ze_command_list_handle_t + getCommandList(const command_list_descriptor_t &desc); void addCommandList(const command_list_descriptor_t &desc, - raii::ze_command_list_t cmdList); - raii::ze_command_list_t + raii::ze_command_list_handle_t cmdList); + raii::ze_command_list_handle_t createCommandList(const command_list_descriptor_t &desc); }; } // namespace v2 diff --git a/source/adapters/level_zero/v2/common.hpp b/source/adapters/level_zero/v2/common.hpp index 6db588ee5d..4fb851bad8 100644 --- a/source/adapters/level_zero/v2/common.hpp +++ b/source/adapters/level_zero/v2/common.hpp @@ -99,5 +99,8 @@ using ze_event_pool_handle_t = using ze_context_handle_t = ze_handle_wrapper<::ze_context_handle_t, zeContextDestroy>; +using ze_command_list_handle_t = + ze_handle_wrapper<::ze_command_list_handle_t, zeCommandListDestroy>; + } // namespace raii } // namespace v2