Skip to content

Commit

Permalink
Merge pull request oneapi-src#2051 from igchor/ze_handle
Browse files Browse the repository at this point in the history
[L0 v2] Use ze_handle_wrappers to hold ze_handles
  • Loading branch information
pbalcer authored Sep 9, 2024
2 parents 76ef4d5 + 35b8ef5 commit 88c3287
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 43 deletions.
22 changes: 11 additions & 11 deletions source/adapters/level_zero/v2/command_list_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ inline size_t command_list_descriptor_hash_t::operator()(
command_list_cache_t::command_list_cache_t(ze_context_handle_t ZeContext)
: ZeContext{ZeContext} {}

raii::ze_command_list_t
raii::ze_command_list_handle_t
command_list_cache_t::createCommandList(const command_list_descriptor_t &desc) {
if (auto ImmCmdDesc =
std::get_if<immediate_command_list_descriptor_t>(&desc)) {
Expand All @@ -61,7 +61,7 @@ command_list_cache_t::createCommandList(const command_list_descriptor_t &desc) {
ZE2UR_CALL_THROWS(
zeCommandListCreateImmediate,
(ZeContext, ImmCmdDesc->ZeDevice, &QueueDesc, &ZeCommandList));
return raii::ze_command_list_t(ZeCommandList, &zeCommandListDestroy);
return raii::ze_command_list_handle_t(ZeCommandList);
} else {
auto RegCmdDesc = std::get<regular_command_list_descriptor_t>(desc);
ZeStruct<ze_command_list_desc_t> CmdListDesc;
Expand All @@ -72,7 +72,7 @@ command_list_cache_t::createCommandList(const command_list_descriptor_t &desc) {
ze_command_list_handle_t ZeCommandList;
ZE2UR_CALL_THROWS(zeCommandListCreate, (ZeContext, RegCmdDesc.ZeDevice,
&CmdListDesc, &ZeCommandList));
return raii::ze_command_list_t(ZeCommandList, &zeCommandListDestroy);
return raii::ze_command_list_handle_t(ZeCommandList);
}
}

Expand All @@ -94,8 +94,7 @@ command_list_cache_t::getImmediateCommandList(
auto CommandList = getCommandList(Desc).release();
return raii::cache_borrowed_command_list_t(
CommandList, [Cache = this, Desc](ze_command_list_handle_t CmdList) {
Cache->addCommandList(
Desc, raii::ze_command_list_t(CmdList, &zeCommandListDestroy));
Cache->addCommandList(Desc, raii::ze_command_list_handle_t(CmdList));
});
}

Expand All @@ -113,12 +112,11 @@ command_list_cache_t::getRegularCommandList(ze_device_handle_t ZeDevice,

return raii::cache_borrowed_command_list_t(
CommandList, [Cache = this, Desc](ze_command_list_handle_t CmdList) {
Cache->addCommandList(
Desc, raii::ze_command_list_t(CmdList, &zeCommandListDestroy));
Cache->addCommandList(Desc, raii::ze_command_list_handle_t(CmdList));
});
}

raii::ze_command_list_t
raii::ze_command_list_handle_t
command_list_cache_t::getCommandList(const command_list_descriptor_t &desc) {
std::unique_lock<ur_mutex> Lock(ZeCommandListCacheMutex);
auto it = ZeCommandListCache.find(desc);
Expand All @@ -129,7 +127,8 @@ command_list_cache_t::getCommandList(const command_list_descriptor_t &desc) {

assert(!it->second.empty());

raii::ze_command_list_t CommandListHandle = std::move(it->second.top());
raii::ze_command_list_handle_t CommandListHandle =
std::move(it->second.top());
it->second.pop();

if (it->second.empty())
Expand All @@ -138,8 +137,9 @@ command_list_cache_t::getCommandList(const command_list_descriptor_t &desc) {
return CommandListHandle;
}

void command_list_cache_t::addCommandList(const command_list_descriptor_t &desc,
raii::ze_command_list_t cmdList) {
void command_list_cache_t::addCommandList(
const command_list_descriptor_t &desc,
raii::ze_command_list_handle_t cmdList) {
// TODO: add a limit?
std::unique_lock<ur_mutex> Lock(ZeCommandListCacheMutex);
auto [it, _] = ZeCommandListCache.try_emplace(desc);
Expand Down
15 changes: 7 additions & 8 deletions source/adapters/level_zero/v2/command_list_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@
#include <ur_ddi.h>
#include <ze_api.h>

#include "../common.hpp"
#include "common.hpp"

namespace v2 {
namespace raii {
using ze_command_list_t = std::unique_ptr<::_ze_command_list_handle_t,
decltype(&zeCommandListDestroy)>;
using cache_borrowed_command_list_t =
std::unique_ptr<::_ze_command_list_handle_t,
std::function<void(ze_command_list_handle_t)>>;
std::function<void(::ze_command_list_handle_t)>>;
} // namespace raii

struct immediate_command_list_descriptor_t {
Expand Down Expand Up @@ -72,15 +70,16 @@ struct command_list_cache_t {
private:
ze_context_handle_t ZeContext;
std::unordered_map<command_list_descriptor_t,
std::stack<raii::ze_command_list_t>,
std::stack<raii::ze_command_list_handle_t>,
command_list_descriptor_hash_t>
ZeCommandListCache;
ur_mutex ZeCommandListCacheMutex;

raii::ze_command_list_t getCommandList(const command_list_descriptor_t &desc);
raii::ze_command_list_handle_t
getCommandList(const command_list_descriptor_t &desc);
void addCommandList(const command_list_descriptor_t &desc,
raii::ze_command_list_t cmdList);
raii::ze_command_list_t
raii::ze_command_list_handle_t cmdList);
raii::ze_command_list_handle_t
createCommandList(const command_list_descriptor_t &desc);
};
} // namespace v2
8 changes: 8 additions & 0 deletions source/adapters/level_zero/v2/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ struct ze_handle_wrapper {
try {
reset();
} catch (...) {
// TODO: add appropriate logging or pass the error
// to the caller (make the dtor noexcept(false) or use tls?)
}
}

Expand Down Expand Up @@ -94,5 +96,11 @@ using ze_event_handle_t =
using ze_event_pool_handle_t =
ze_handle_wrapper<::ze_event_pool_handle_t, zeEventPoolDestroy>;

using ze_context_handle_t =
ze_handle_wrapper<::ze_context_handle_t, zeContextDestroy>;

using ze_command_list_handle_t =
ze_handle_wrapper<::ze_command_list_handle_t, zeCommandListDestroy>;

} // namespace raii
} // namespace v2
18 changes: 3 additions & 15 deletions source/adapters/level_zero/v2/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
uint32_t numDevices,
const ur_device_handle_t *phDevices,
bool ownZeContext)
: hContext(hContext), hDevices(phDevices, phDevices + numDevices),
commandListCache(hContext),
: hContext(hContext, ownZeContext),
hDevices(phDevices, phDevices + numDevices), commandListCache(hContext),
eventPoolCache(phDevices[0]->Platform->getNumDevices(),
[context = this,
platform = phDevices[0]->Platform](DeviceId deviceId) {
Expand All @@ -27,19 +27,7 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
return std::make_unique<v2::provider_normal>(
context, device, v2::EVENT_COUNTER,
v2::QUEUE_IMMEDIATE);
}) {
std::ignore = ownZeContext;
}

ur_context_handle_t_::~ur_context_handle_t_() noexcept(false) {
// ur_context_handle_t_ is only created/destroyed through urContextCreate
// and urContextRelease so it's safe to throw here
ZE2UR_CALL_THROWS(zeContextDestroy, (hContext));
}

ze_context_handle_t ur_context_handle_t_::getZeHandle() const {
return hContext;
}
}) {}

ur_result_t ur_context_handle_t_::retain() {
RefCount.increment();
Expand Down
6 changes: 3 additions & 3 deletions source/adapters/level_zero/v2/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,25 @@
#include <ur_api.h>

#include "command_list_cache.hpp"
#include "common.hpp"
#include "event_pool_cache.hpp"

struct ur_context_handle_t_ : _ur_object {
ur_context_handle_t_(ze_context_handle_t hContext, uint32_t numDevices,
const ur_device_handle_t *phDevices, bool ownZeContext);
~ur_context_handle_t_() noexcept(false);

ur_result_t retain();
ur_result_t release();

ze_context_handle_t getZeHandle() const;
inline ze_context_handle_t getZeHandle() const { return hContext.get(); }
ur_platform_handle_t getPlatform() const;
const std::vector<ur_device_handle_t> &getDevices() const;

// Checks if Device is covered by this context.
// For that the Device or its root devices need to be in the context.
bool isValidDevice(ur_device_handle_t Device) const;

const ze_context_handle_t hContext;
const v2::raii::ze_context_handle_t hContext;
const std::vector<ur_device_handle_t> hDevices;
v2::command_list_cache_t commandListCache;
v2::event_pool_cache eventPoolCache;
Expand Down
6 changes: 3 additions & 3 deletions source/adapters/level_zero/v2/event_provider_counter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ provider_counter::provider_counter(ur_platform_handle_t platform,
ZE2UR_CALL_THROWS(zeDriverGetExtensionFunctionAddress,
(platform->ZeDriver, "zexCounterBasedEventCreate",
(void **)&this->eventCreateFunc));
ZE2UR_CALL_THROWS(
zelLoaderTranslateHandle,
(ZEL_HANDLE_CONTEXT, context->hContext, (void **)&translatedContext));
ZE2UR_CALL_THROWS(zelLoaderTranslateHandle,
(ZEL_HANDLE_CONTEXT, context->getZeHandle(),
(void **)&translatedContext));
ZE2UR_CALL_THROWS(
zelLoaderTranslateHandle,
(ZEL_HANDLE_DEVICE, device->ZeDevice, (void **)&translatedDevice));
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/level_zero/v2/event_provider_normal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ provider_pool::provider_pool(ur_context_handle_t context,
}

ZE2UR_CALL_THROWS(zeEventPoolCreate,
(context->hContext, &desc, 1,
(context->getZeHandle(), &desc, 1,
const_cast<ze_device_handle_t *>(&device->ZeDevice),
pool.ptr()));

Expand Down
4 changes: 2 additions & 2 deletions test/adapters/level_zero/v2/command_list_cache_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct CommandListCacheTest : public uur::urContextTest {};
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(CommandListCacheTest);

TEST_P(CommandListCacheTest, CanStoreAndRetriveImmediateAndRegularCmdLists) {
v2::command_list_cache_t cache(context->hContext);
v2::command_list_cache_t cache(context->getZeHandle());

bool IsInOrder = false;
uint32_t Ordinal = 0;
Expand Down Expand Up @@ -75,7 +75,7 @@ TEST_P(CommandListCacheTest, CanStoreAndRetriveImmediateAndRegularCmdLists) {
}

TEST_P(CommandListCacheTest, ImmediateCommandListsHaveProperAttributes) {
v2::command_list_cache_t cache(context->hContext);
v2::command_list_cache_t cache(context->getZeHandle());

uint32_t numQueueGroups = 0;
ASSERT_EQ(zeDeviceGetCommandQueueGroupProperties(device->ZeDevice,
Expand Down

0 comments on commit 88c3287

Please sign in to comment.