Skip to content

Commit

Permalink
[L0 v2] Use ze_handle_wraper instead of custom raii handle
Browse files Browse the repository at this point in the history
in command_list_cache
  • Loading branch information
igchor committed Sep 4, 2024
1 parent 0061cdd commit 35b8ef5
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 19 deletions.
22 changes: 11 additions & 11 deletions source/adapters/level_zero/v2/command_list_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ inline size_t command_list_descriptor_hash_t::operator()(
command_list_cache_t::command_list_cache_t(ze_context_handle_t ZeContext)
: ZeContext{ZeContext} {}

raii::ze_command_list_t
raii::ze_command_list_handle_t
command_list_cache_t::createCommandList(const command_list_descriptor_t &desc) {
if (auto ImmCmdDesc =
std::get_if<immediate_command_list_descriptor_t>(&desc)) {
Expand All @@ -61,7 +61,7 @@ command_list_cache_t::createCommandList(const command_list_descriptor_t &desc) {
ZE2UR_CALL_THROWS(
zeCommandListCreateImmediate,
(ZeContext, ImmCmdDesc->ZeDevice, &QueueDesc, &ZeCommandList));
return raii::ze_command_list_t(ZeCommandList, &zeCommandListDestroy);
return raii::ze_command_list_handle_t(ZeCommandList);
} else {
auto RegCmdDesc = std::get<regular_command_list_descriptor_t>(desc);
ZeStruct<ze_command_list_desc_t> CmdListDesc;
Expand All @@ -72,7 +72,7 @@ command_list_cache_t::createCommandList(const command_list_descriptor_t &desc) {
ze_command_list_handle_t ZeCommandList;
ZE2UR_CALL_THROWS(zeCommandListCreate, (ZeContext, RegCmdDesc.ZeDevice,
&CmdListDesc, &ZeCommandList));
return raii::ze_command_list_t(ZeCommandList, &zeCommandListDestroy);
return raii::ze_command_list_handle_t(ZeCommandList);
}
}

Expand All @@ -94,8 +94,7 @@ command_list_cache_t::getImmediateCommandList(
auto CommandList = getCommandList(Desc).release();
return raii::cache_borrowed_command_list_t(
CommandList, [Cache = this, Desc](ze_command_list_handle_t CmdList) {
Cache->addCommandList(
Desc, raii::ze_command_list_t(CmdList, &zeCommandListDestroy));
Cache->addCommandList(Desc, raii::ze_command_list_handle_t(CmdList));
});
}

Expand All @@ -113,12 +112,11 @@ command_list_cache_t::getRegularCommandList(ze_device_handle_t ZeDevice,

return raii::cache_borrowed_command_list_t(
CommandList, [Cache = this, Desc](ze_command_list_handle_t CmdList) {
Cache->addCommandList(
Desc, raii::ze_command_list_t(CmdList, &zeCommandListDestroy));
Cache->addCommandList(Desc, raii::ze_command_list_handle_t(CmdList));
});
}

raii::ze_command_list_t
raii::ze_command_list_handle_t
command_list_cache_t::getCommandList(const command_list_descriptor_t &desc) {
std::unique_lock<ur_mutex> Lock(ZeCommandListCacheMutex);
auto it = ZeCommandListCache.find(desc);
Expand All @@ -129,7 +127,8 @@ command_list_cache_t::getCommandList(const command_list_descriptor_t &desc) {

assert(!it->second.empty());

raii::ze_command_list_t CommandListHandle = std::move(it->second.top());
raii::ze_command_list_handle_t CommandListHandle =
std::move(it->second.top());
it->second.pop();

if (it->second.empty())
Expand All @@ -138,8 +137,9 @@ command_list_cache_t::getCommandList(const command_list_descriptor_t &desc) {
return CommandListHandle;
}

void command_list_cache_t::addCommandList(const command_list_descriptor_t &desc,
raii::ze_command_list_t cmdList) {
void command_list_cache_t::addCommandList(
const command_list_descriptor_t &desc,
raii::ze_command_list_handle_t cmdList) {
// TODO: add a limit?
std::unique_lock<ur_mutex> Lock(ZeCommandListCacheMutex);
auto [it, _] = ZeCommandListCache.try_emplace(desc);
Expand Down
15 changes: 7 additions & 8 deletions source/adapters/level_zero/v2/command_list_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@
#include <ur_api.h>
#include <ze_api.h>

#include "../common.hpp"
#include "common.hpp"

namespace v2 {
namespace raii {
using ze_command_list_t = std::unique_ptr<::_ze_command_list_handle_t,
decltype(&zeCommandListDestroy)>;
using cache_borrowed_command_list_t =
std::unique_ptr<::_ze_command_list_handle_t,
std::function<void(ze_command_list_handle_t)>>;
std::function<void(::ze_command_list_handle_t)>>;
} // namespace raii

struct immediate_command_list_descriptor_t {
Expand Down Expand Up @@ -72,15 +70,16 @@ struct command_list_cache_t {
private:
ze_context_handle_t ZeContext;
std::unordered_map<command_list_descriptor_t,
std::stack<raii::ze_command_list_t>,
std::stack<raii::ze_command_list_handle_t>,
command_list_descriptor_hash_t>
ZeCommandListCache;
ur_mutex ZeCommandListCacheMutex;

raii::ze_command_list_t getCommandList(const command_list_descriptor_t &desc);
raii::ze_command_list_handle_t
getCommandList(const command_list_descriptor_t &desc);
void addCommandList(const command_list_descriptor_t &desc,
raii::ze_command_list_t cmdList);
raii::ze_command_list_t
raii::ze_command_list_handle_t cmdList);
raii::ze_command_list_handle_t
createCommandList(const command_list_descriptor_t &desc);
};
} // namespace v2
3 changes: 3 additions & 0 deletions source/adapters/level_zero/v2/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,8 @@ using ze_event_pool_handle_t =
using ze_context_handle_t =
ze_handle_wrapper<::ze_context_handle_t, zeContextDestroy>;

using ze_command_list_handle_t =
ze_handle_wrapper<::ze_command_list_handle_t, zeCommandListDestroy>;

} // namespace raii
} // namespace v2

0 comments on commit 35b8ef5

Please sign in to comment.