From 35b8ef5085e5b8330bc13de04d7099d981846c76 Mon Sep 17 00:00:00 2001 From: Igor Chorazewicz Date: Wed, 28 Aug 2024 00:47:10 +0200 Subject: [PATCH] [L0 v2] Use ze_handle_wraper instead of custom raii handle in command_list_cache --- .../level_zero/v2/command_list_cache.cpp | 22 +++++++++---------- .../level_zero/v2/command_list_cache.hpp | 15 ++++++------- source/adapters/level_zero/v2/common.hpp | 3 +++ 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/source/adapters/level_zero/v2/command_list_cache.cpp b/source/adapters/level_zero/v2/command_list_cache.cpp index eee6555f87..651cb5944a 100644 --- a/source/adapters/level_zero/v2/command_list_cache.cpp +++ b/source/adapters/level_zero/v2/command_list_cache.cpp @@ -43,7 +43,7 @@ inline size_t command_list_descriptor_hash_t::operator()( command_list_cache_t::command_list_cache_t(ze_context_handle_t ZeContext) : ZeContext{ZeContext} {} -raii::ze_command_list_t +raii::ze_command_list_handle_t command_list_cache_t::createCommandList(const command_list_descriptor_t &desc) { if (auto ImmCmdDesc = std::get_if(&desc)) { @@ -61,7 +61,7 @@ command_list_cache_t::createCommandList(const command_list_descriptor_t &desc) { ZE2UR_CALL_THROWS( zeCommandListCreateImmediate, (ZeContext, ImmCmdDesc->ZeDevice, &QueueDesc, &ZeCommandList)); - return raii::ze_command_list_t(ZeCommandList, &zeCommandListDestroy); + return raii::ze_command_list_handle_t(ZeCommandList); } else { auto RegCmdDesc = std::get(desc); ZeStruct CmdListDesc; @@ -72,7 +72,7 @@ command_list_cache_t::createCommandList(const command_list_descriptor_t &desc) { ze_command_list_handle_t ZeCommandList; ZE2UR_CALL_THROWS(zeCommandListCreate, (ZeContext, RegCmdDesc.ZeDevice, &CmdListDesc, &ZeCommandList)); - return raii::ze_command_list_t(ZeCommandList, &zeCommandListDestroy); + return raii::ze_command_list_handle_t(ZeCommandList); } } @@ -94,8 +94,7 @@ command_list_cache_t::getImmediateCommandList( auto CommandList = getCommandList(Desc).release(); return raii::cache_borrowed_command_list_t( CommandList, [Cache = this, Desc](ze_command_list_handle_t CmdList) { - Cache->addCommandList( - Desc, raii::ze_command_list_t(CmdList, &zeCommandListDestroy)); + Cache->addCommandList(Desc, raii::ze_command_list_handle_t(CmdList)); }); } @@ -113,12 +112,11 @@ command_list_cache_t::getRegularCommandList(ze_device_handle_t ZeDevice, return raii::cache_borrowed_command_list_t( CommandList, [Cache = this, Desc](ze_command_list_handle_t CmdList) { - Cache->addCommandList( - Desc, raii::ze_command_list_t(CmdList, &zeCommandListDestroy)); + Cache->addCommandList(Desc, raii::ze_command_list_handle_t(CmdList)); }); } -raii::ze_command_list_t +raii::ze_command_list_handle_t command_list_cache_t::getCommandList(const command_list_descriptor_t &desc) { std::unique_lock Lock(ZeCommandListCacheMutex); auto it = ZeCommandListCache.find(desc); @@ -129,7 +127,8 @@ command_list_cache_t::getCommandList(const command_list_descriptor_t &desc) { assert(!it->second.empty()); - raii::ze_command_list_t CommandListHandle = std::move(it->second.top()); + raii::ze_command_list_handle_t CommandListHandle = + std::move(it->second.top()); it->second.pop(); if (it->second.empty()) @@ -138,8 +137,9 @@ command_list_cache_t::getCommandList(const command_list_descriptor_t &desc) { return CommandListHandle; } -void command_list_cache_t::addCommandList(const command_list_descriptor_t &desc, - raii::ze_command_list_t cmdList) { +void command_list_cache_t::addCommandList( + const command_list_descriptor_t &desc, + raii::ze_command_list_handle_t cmdList) { // TODO: add a limit? std::unique_lock Lock(ZeCommandListCacheMutex); auto [it, _] = ZeCommandListCache.try_emplace(desc); diff --git a/source/adapters/level_zero/v2/command_list_cache.hpp b/source/adapters/level_zero/v2/command_list_cache.hpp index 1850a4334c..78692c975d 100644 --- a/source/adapters/level_zero/v2/command_list_cache.hpp +++ b/source/adapters/level_zero/v2/command_list_cache.hpp @@ -17,15 +17,13 @@ #include #include -#include "../common.hpp" +#include "common.hpp" namespace v2 { namespace raii { -using ze_command_list_t = std::unique_ptr<::_ze_command_list_handle_t, - decltype(&zeCommandListDestroy)>; using cache_borrowed_command_list_t = std::unique_ptr<::_ze_command_list_handle_t, - std::function>; + std::function>; } // namespace raii struct immediate_command_list_descriptor_t { @@ -72,15 +70,16 @@ struct command_list_cache_t { private: ze_context_handle_t ZeContext; std::unordered_map, + std::stack, command_list_descriptor_hash_t> ZeCommandListCache; ur_mutex ZeCommandListCacheMutex; - raii::ze_command_list_t getCommandList(const command_list_descriptor_t &desc); + raii::ze_command_list_handle_t + getCommandList(const command_list_descriptor_t &desc); void addCommandList(const command_list_descriptor_t &desc, - raii::ze_command_list_t cmdList); - raii::ze_command_list_t + raii::ze_command_list_handle_t cmdList); + raii::ze_command_list_handle_t createCommandList(const command_list_descriptor_t &desc); }; } // namespace v2 diff --git a/source/adapters/level_zero/v2/common.hpp b/source/adapters/level_zero/v2/common.hpp index 6db588ee5d..4fb851bad8 100644 --- a/source/adapters/level_zero/v2/common.hpp +++ b/source/adapters/level_zero/v2/common.hpp @@ -99,5 +99,8 @@ using ze_event_pool_handle_t = using ze_context_handle_t = ze_handle_wrapper<::ze_context_handle_t, zeContextDestroy>; +using ze_command_list_handle_t = + ze_handle_wrapper<::ze_command_list_handle_t, zeCommandListDestroy>; + } // namespace raii } // namespace v2