From d7b2c23265a738ffb4aefded7ee96c74ed680ae2 Mon Sep 17 00:00:00 2001 From: Ewan Crawford Date: Thu, 13 Feb 2025 16:09:23 +0000 Subject: [PATCH 1/2] [SYCL][Graph][UR] Propagate graph update list to UR Update the `urCommandBufferUpdateKernelLaunchExp` API for updating commands in a command-buffer to take a list of commands. The current API operates on a single command; this means that the SYCL-Graph `update(std::vector)` API needs to serialize the list into N calls to the UR API. Given that both OpenCL `clUpdateMutableCommandsKHR` and Level-Zero `zeCommandListUpdateMutableCommandsExp` can operate on a list of commands, this serialization at the UR layer of the stack introduces extra host API calls. This PR improves the `urCommandBufferUpdateKernelLaunchExp` API so that a list of commands is passed all the way from SYCL to the native backend API. As highlighted in https://github.com/oneapi-src/unified-runtime/issues/2671 this patch requires the handle translation auto-generated code to be able to handle a list of structs, which is not currently the case. This PR includes an API-specific temporary workaround in the mako file which will unblock this PR until a more permanent solution is completed. 
Co-authored-by: Ross Brunton --- .../sycl_ext_oneapi_graph.asciidoc | 4 + sycl/source/detail/graph_impl.cpp | 292 ++++++--- sycl/source/detail/graph_impl.hpp | 45 +- sycl/source/detail/scheduler/commands.cpp | 10 +- unified-runtime/include/ur_api.h | 113 ++-- unified-runtime/include/ur_ddi.h | 2 +- unified-runtime/include/ur_print.hpp | 27 +- .../scripts/core/EXP-COMMAND-BUFFER.rst | 7 +- .../scripts/core/exp-command-buffer.yml | 76 ++- .../scripts/templates/ldrddi.cpp.mako | 23 + .../source/adapters/cuda/command_buffer.cpp | 181 +++--- .../source/adapters/hip/command_buffer.cpp | 165 ++--- .../adapters/level_zero/command_buffer.cpp | 604 ++++++++++-------- .../level_zero/ur_interface_loader.hpp | 2 +- .../source/adapters/level_zero/v2/api.cpp | 2 +- .../source/adapters/mock/ur_mockddi.cpp | 11 +- .../adapters/native_cpu/command_buffer.cpp | 2 +- .../source/adapters/opencl/command_buffer.cpp | 170 +++-- .../loader/layers/tracing/ur_trcddi.cpp | 14 +- .../loader/layers/validation/ur_valddi.cpp | 22 +- unified-runtime/source/loader/ur_ldrddi.cpp | 72 ++- unified-runtime/source/loader/ur_libapi.cpp | 81 ++- unified-runtime/source/ur_api.cpp | 78 ++- .../update/buffer_fill_kernel_update.cpp | 60 +- .../update/buffer_saxpy_kernel_update.cpp | 5 +- .../update/invalid_update.cpp | 173 +++-- .../update/kernel_handle_update.cpp | 33 +- .../update/local_memory_update.cpp | 130 ++-- .../update/ndrange_update.cpp | 25 +- .../update/usm_fill_kernel_update.cpp | 42 +- .../update/usm_saxpy_kernel_update.cpp | 132 ++-- .../urMultiDeviceProgramCreateWithBinary.cpp | 4 +- 32 files changed, 1577 insertions(+), 1030 deletions(-) diff --git a/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc b/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc index 88c027beeab55..c1a665b42408d 100644 --- a/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc +++ b/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc @@ -1431,6 +1431,7 @@ Exceptions: 
created. * Throws with error code `invalid` if `node` is not part of the graph. +* If any other exception is thrown the state of the graph node is undefined. | [source,c++] @@ -1465,6 +1466,7 @@ Exceptions: `property::graph::updatable` was not set when the executable graph was created. * Throws with error code `invalid` if any node in `nodes` is not part of the graph. +* If any other exception is thrown the state of the graph nodes is undefined. | [source, c++] @@ -1517,6 +1519,8 @@ Exceptions: * Throws synchronously with error code `invalid` if `property::graph::updatable` was not set when the executable graph was created. + +* If any other exception is thrown the state of the graph nodes is undefined. |=== Table {counter: tableNumber}. Member functions of the `command_graph` class for diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index 301532b2ff618..2e8b3ced44ce3 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -1381,18 +1381,72 @@ void exec_graph_impl::update(std::shared_ptr Node) { void exec_graph_impl::update( const std::vector> &Nodes) { - if (!MIsUpdatable) { throw sycl::exception(sycl::make_error_code(errc::invalid), "update() cannot be called on a executable graph " "which was not created with property::updatable"); } + // If the graph contains host tasks we need special handling here because + // their state lives in the graph object itself, so we must do the update + // immediately here. Whereas all other command state lives in the backend so + // it can be scheduled along with other commands. + if (MContainsHostTask) { + updateHostTasksImpl(Nodes); + } + + // If there are any accessor requirements, we have to update through the + // scheduler to ensure that any allocations have taken place before trying + // to update. 
+ std::vector UpdateRequirements; + bool NeedScheduledUpdate = needsScheduledUpdate(Nodes, UpdateRequirements); + if (NeedScheduledUpdate) { + auto AllocaQueue = std::make_shared( + sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()), + sycl::detail::getSyclObjImpl(MGraphImpl->getContext()), + sycl::async_handler{}, sycl::property_list{}); + + // Track the event for the update command since execution may be blocked by + // other scheduler commands + auto UpdateEvent = + sycl::detail::Scheduler::getInstance().addCommandGraphUpdate( + this, Nodes, AllocaQueue, UpdateRequirements, MExecutionEvents); + + MExecutionEvents.push_back(UpdateEvent); + } else { + // For each partition in the executable graph, call UR update on the + // command-buffer with the nodes to update. + auto PartitionedNodes = getPartitionForNodes(Nodes); + for (auto It = PartitionedNodes.begin(); It != PartitionedNodes.end(); + It++) { + auto CommandBuffer = It->first->MCommandBuffers[MDevice]; + updateKernelsImpl(CommandBuffer, It->second); + } + } + + // Rebuild cached requirements and accessor storage for this graph with + // updated nodes + MRequirements.clear(); + MAccessors.clear(); + for (auto &Node : MNodeStorage) { + if (!Node->MCommandGroup) + continue; + MRequirements.insert(MRequirements.end(), + Node->MCommandGroup->getRequirements().begin(), + Node->MCommandGroup->getRequirements().end()); + MAccessors.insert(MAccessors.end(), + Node->MCommandGroup->getAccStorage().begin(), + Node->MCommandGroup->getAccStorage().end()); + } +} + +bool exec_graph_impl::needsScheduledUpdate( + const std::vector> &Nodes, + std::vector &UpdateRequirements) { // If there are any accessor requirements, we have to update through the // scheduler to ensure that any allocations have taken place before trying to // update. bool NeedScheduledUpdate = false; - std::vector UpdateRequirements; // At worst we may have as many requirements as there are for the entire graph // for updating. 
UpdateRequirements.reserve(MRequirements.size()); @@ -1435,94 +1489,17 @@ void exec_graph_impl::update( // ensure it is ordered correctly. NeedScheduledUpdate |= MExecutionEvents.size() > 0; - if (NeedScheduledUpdate) { - // Copy the list of nodes as we may need to modify it - auto NodesCopy = Nodes; - - // If the graph contains host tasks we need special handling here because - // their state lives in the graph object itself, so we must do the update - // immediately here. Whereas all other command state lives in the backend so - // it can be scheduled along with other commands. - if (MContainsHostTask) { - std::vector> HostTasks; - // Remove any nodes that are host tasks and put them in HostTasks - auto RemovedIter = std::remove_if( - NodesCopy.begin(), NodesCopy.end(), - [&HostTasks](const std::shared_ptr &Node) -> bool { - if (Node->MNodeType == node_type::host_task) { - HostTasks.push_back(Node); - return true; - } - return false; - }); - // Clean up extra elements in NodesCopy after the remove - NodesCopy.erase(RemovedIter, NodesCopy.end()); - - // Update host-tasks synchronously - for (auto &HostTaskNode : HostTasks) { - updateImpl(HostTaskNode); - } - } - - auto AllocaQueue = std::make_shared( - sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()), - sycl::detail::getSyclObjImpl(MGraphImpl->getContext()), - sycl::async_handler{}, sycl::property_list{}); - - // Track the event for the update command since execution may be blocked by - // other scheduler commands - auto UpdateEvent = - sycl::detail::Scheduler::getInstance().addCommandGraphUpdate( - this, std::move(NodesCopy), AllocaQueue, UpdateRequirements, - MExecutionEvents); - - MExecutionEvents.push_back(UpdateEvent); - } else { - for (auto &Node : Nodes) { - updateImpl(Node); - } - } - - // Rebuild cached requirements and accessor storage for this graph with - // updated nodes - MRequirements.clear(); - MAccessors.clear(); - for (auto &Node : MNodeStorage) { - if (!Node->MCommandGroup) - continue; - 
MRequirements.insert(MRequirements.end(), - Node->MCommandGroup->getRequirements().begin(), - Node->MCommandGroup->getRequirements().end()); - MAccessors.insert(MAccessors.end(), - Node->MCommandGroup->getAccStorage().begin(), - Node->MCommandGroup->getAccStorage().end()); - } + return NeedScheduledUpdate; } -void exec_graph_impl::updateImpl(std::shared_ptr Node) { - // Updating empty or barrier nodes is a no-op - if (Node->isEmpty() || Node->MNodeType == node_type::ext_oneapi_barrier) { - return; - } - - // Query the ID cache to find the equivalent exec node for the node passed to - // this function. - // TODO: Handle subgraphs or any other cases where multiple nodes may be - // associated with a single key, once those node types are supported for - // update. - auto ExecNode = MIDCache.find(Node->MID); - assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); - - // Update ExecNode with new values from Node, in case we ever need to - // rebuild the command buffers - ExecNode->second->updateFromOtherNode(Node); - - // Host task update only requires updating the node itself, so can return - // early - if (Node->MNodeType == node_type::host_task) { - return; - } - +void exec_graph_impl::populateURKernelUpdateStructs( + const std::shared_ptr &Node, + std::pair &BundleObjs, + std::vector &MemobjDescs, + std::vector &PtrDescs, + std::vector &ValueDescs, + sycl::detail::NDRDescT &NDRDesc, + ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) { auto ContextImpl = sycl::detail::getSyclObjImpl(MContext); const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter(); auto DeviceImpl = sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()); @@ -1533,9 +1510,8 @@ void exec_graph_impl::updateImpl(std::shared_ptr Node) { // Copy args because we may modify them std::vector NodeArgs = ExecCG.getArguments(); // Copy NDR desc since we need to modify it - auto NDRDesc = ExecCG.MNDRDesc; + NDRDesc = ExecCG.MNDRDesc; - ur_program_handle_t UrProgram = 
nullptr; ur_kernel_handle_t UrKernel = nullptr; auto Kernel = ExecCG.MSyclKernel; auto KernelBundleImplPtr = ExecCG.MKernelBundle; @@ -1560,9 +1536,11 @@ void exec_graph_impl::updateImpl(std::shared_ptr Node) { UrKernel = Kernel->getHandleRef(); EliminatedArgMask = Kernel->getKernelArgMask(); } else { + ur_program_handle_t UrProgram = nullptr; std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) = sycl::detail::ProgramManager::getInstance().getOrCreateKernel( ContextImpl, DeviceImpl, ExecCG.MKernelName); + BundleObjs = std::make_pair(UrProgram, UrKernel); } // Remove eliminated args @@ -1596,17 +1574,12 @@ void exec_graph_impl::updateImpl(std::shared_ptr Node) { if (EnforcedLocalSize) LocalSize = RequiredWGSize; } - // Create update descriptor // Storage for individual arg descriptors - std::vector MemobjDescs; - std::vector PtrDescs; - std::vector ValueDescs; MemobjDescs.reserve(MaskedArgs.size()); PtrDescs.reserve(MaskedArgs.size()); ValueDescs.reserve(MaskedArgs.size()); - ur_exp_command_buffer_update_kernel_launch_desc_t UpdateDesc{}; UpdateDesc.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC; UpdateDesc.pNext = nullptr; @@ -1675,20 +1648,131 @@ void exec_graph_impl::updateImpl(std::shared_ptr Node) { UpdateDesc.pNewLocalWorkSize = LocalSize; UpdateDesc.newWorkDim = NDRDesc.Dims; + // Query the ID cache to find the equivalent exec node for the node passed to + // this function. + // TODO: Handle subgraphs or any other cases where multiple nodes may be + // associated with a single key, once those node types are supported for + // update. 
+ auto ExecNode = MIDCache.find(Node->MID); + assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); + ur_exp_command_buffer_command_handle_t Command = MCommandMap[ExecNode->second]; - ur_result_t Res = Adapter->call_nocheck< - sycl::detail::UrApiKind::urCommandBufferUpdateKernelLaunchExp>( - Command, &UpdateDesc); + UpdateDesc.hCommand = Command; + + // Update ExecNode with new values from Node, in case we ever need to + // rebuild the command buffers + ExecNode->second->updateFromOtherNode(Node); +} + +std::map, std::vector>> +exec_graph_impl::getPartitionForNodes( + const std::vector> &Nodes) { + // Iterate over each partition in the executable graph, and find the nodes + // in "Nodes" that also exist in the partition. + std::map, std::vector>> + PartitionedNodes; + for (const auto &Partition : MPartitions) { + std::vector> NodesForPartition; + const auto PartitionBegin = Partition->MSchedule.begin(); + const auto PartitionEnd = Partition->MSchedule.end(); + for (auto &Node : Nodes) { + auto ExecNode = MIDCache.find(Node->MID); + assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); + + if (std::find_if(PartitionBegin, PartitionEnd, + [ExecNode](const auto &PartitionNode) { + return PartitionNode->MID == ExecNode->second->MID; + }) != PartitionEnd) { + NodesForPartition.push_back(Node); + } + } + if (!NodesForPartition.empty()) { + PartitionedNodes.insert({Partition, NodesForPartition}); + } + } + + return PartitionedNodes; +} + +void exec_graph_impl::updateHostTasksImpl( + const std::vector> &Nodes) { + for (auto &Node : Nodes) { + if (Node->MNodeType != node_type::host_task) { + continue; + } + // Query the ID cache to find the equivalent exec node for the node passed + // to this function. + // TODO: Handle subgraphs or any other cases where multiple nodes may be + // associated with a single key, once those node types are supported for + // update. 
+ auto ExecNode = MIDCache.find(Node->MID); + assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); - if (UrProgram) { - // We retained these objects by calling getOrCreateKernel() - Adapter->call(UrKernel); - Adapter->call(UrProgram); + ExecNode->second->updateFromOtherNode(Node); } +} - if (Res != UR_RESULT_SUCCESS) { - throw sycl::exception(errc::invalid, "Error updating command_graph"); +void exec_graph_impl::updateKernelsImpl( + ur_exp_command_buffer_handle_t CommandBuffer, + const std::vector> &Nodes) { + // Kernel node update is the only command type supported in UR for update. + // Updating any other types of nodes, e.g. empty & barrier nodes is a no-op. + size_t NumKernelNodes = 0; + for (auto &N : Nodes) { + if (N->MCGType == sycl::detail::CGType::Kernel) { + NumKernelNodes++; + } + } + + // Don't need to call through to UR if no kernel nodes to update + if (NumKernelNodes == 0) { + return; + } + + std::vector> + MemobjDescsList(NumKernelNodes); + std::vector> + PtrDescsList(NumKernelNodes); + std::vector> + ValueDescsList(NumKernelNodes); + std::vector NDRDescList(NumKernelNodes); + std::vector UpdateDescList( + NumKernelNodes); + std::vector> + KernelBundleObjList(NumKernelNodes); + + size_t StructListIndex = 0; + for (auto &Node : Nodes) { + if (Node->MCGType != sycl::detail::CGType::Kernel) { + continue; + } + + auto &MemobjDescs = MemobjDescsList[StructListIndex]; + auto &KernelBundleObjs = KernelBundleObjList[StructListIndex]; + auto &PtrDescs = PtrDescsList[StructListIndex]; + auto &ValueDescs = ValueDescsList[StructListIndex]; + auto &NDRDesc = NDRDescList[StructListIndex]; + auto &UpdateDesc = UpdateDescList[StructListIndex]; + populateURKernelUpdateStructs(Node, KernelBundleObjs, MemobjDescs, PtrDescs, + ValueDescs, NDRDesc, UpdateDesc); + StructListIndex++; + } + + auto ContextImpl = sycl::detail::getSyclObjImpl(MContext); + const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter(); + Adapter->call( + 
CommandBuffer, UpdateDescList.size(), UpdateDescList.data()); + + for (auto &BundleObjs : KernelBundleObjList) { + // We retained these objects inside populateURKernelUpdateStructs() by + // calling getOrCreateKernel() + if (auto &UrKernel = BundleObjs.second; nullptr != UrKernel) { + Adapter->call(UrKernel); + } + if (auto &UrProgram = BundleObjs.first; nullptr != UrProgram) { + Adapter->call(UrProgram); + } } } diff --git a/sycl/source/detail/graph_impl.hpp b/sycl/source/detail/graph_impl.hpp index 915002c5f8483..b2eeac7b09cdf 100644 --- a/sycl/source/detail/graph_impl.hpp +++ b/sycl/source/detail/graph_impl.hpp @@ -1303,7 +1303,18 @@ class exec_graph_impl { void update(std::shared_ptr Node); void update(const std::vector> &Nodes); - void updateImpl(std::shared_ptr NodeImpl); + /// Calls UR entry-point to update kernel nodes in command-buffer. + /// @param CommandBuffer The UR command-buffer to update commands in. + /// @param Nodes List of nodes to update. May contain nodes of non-kernel + /// type, but only kernel nodes from the list will be used for update + void updateKernelsImpl(ur_exp_command_buffer_handle_t CommandBuffer, + const std::vector> &Nodes); + + /// Splits a list of nodes into separate lists depending on partition. + /// @param Nodes List of nodes to split + /// @return Map of partitions to nodes + std::map, std::vector>> + getPartitionForNodes(const std::vector> &Nodes); unsigned long long getID() const { return MID; } @@ -1373,6 +1384,38 @@ class exec_graph_impl { Stream.close(); } + /// Determines if scheduler needs to be used for node update. + /// @param[in] Nodes List of nodes to be updated + /// @param[out] UpdateRequirements Accessor requirements found in @p Nodes. + /// @return True if update should be done through the scheduler. + bool needsScheduledUpdate( + const std::vector> &Nodes, + std::vector &UpdateRequirements); + + /// Sets the UR struct values required to update a graph node. + /// @param[in] Node The node to be updated. 
+ /// @param[out] BundleObjs UR objects created from kernel bundle. + /// Responsibility of the caller to release. + /// @param[out] MemobjDescs Memory object arguments to update. + /// @param[out] PtrDescs Pointer arguments to update. + /// @param[out] ValueDescs Value arguments to update. + /// @param[out] NDRDesc ND-Range to update. + /// @param[out] UpdateDesc Base struct in the pointer chain. + void populateURKernelUpdateStructs( + const std::shared_ptr &Node, + std::pair &BundleObjs, + std::vector &MemobjDescs, + std::vector &PtrDescs, + std::vector &ValueDescs, + sycl::detail::NDRDescT &NDRDesc, + ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc); + + /// Updates host-task nodes in the graph + /// @param Nodes List of nodes to update, any node that is not a host-task + /// will be ignored. + void + updateHostTasksImpl(const std::vector> &Nodes); + /// Execution schedule of nodes in the graph. std::list> MSchedule; /// Pointer to the modifiable graph impl associated with this executable diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index 175e1b3937259..f8e8a02ee73d9 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -3741,7 +3741,15 @@ ur_result_t UpdateCommandBufferCommand::enqueueImp() { default: break; } - MGraph->updateImpl(Node); + } + + // Split list of nodes into nodes per UR command-buffer partition, then + // call UR update on each command-buffer partition + auto PartitionedNodes = MGraph->getPartitionForNodes(MNodes); + auto Device = MQueue->get_device(); + for (auto It = PartitionedNodes.begin(); It != PartitionedNodes.end(); It++) { + auto CommandBuffer = It->first->MCommandBuffers[Device]; + MGraph->updateKernelsImpl(CommandBuffer, It->second); } return UR_RESULT_SUCCESS; diff --git a/unified-runtime/include/ur_api.h b/unified-runtime/include/ur_api.h index c390ed4410d16..18d485ddabdbf 100644 --- 
a/unified-runtime/include/ur_api.h +++ b/unified-runtime/include/ur_api.h @@ -9971,6 +9971,21 @@ typedef struct ur_exp_command_buffer_desc_t { } ur_exp_command_buffer_desc_t; +/////////////////////////////////////////////////////////////////////////////// +/// @brief A value that identifies a command inside of a command-buffer, used +/// for +/// defining dependencies between commands in the same command-buffer. +typedef uint32_t ur_exp_command_buffer_sync_point_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Handle of Command-Buffer object +typedef struct ur_exp_command_buffer_handle_t_ *ur_exp_command_buffer_handle_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Handle of a Command-Buffer command +typedef struct ur_exp_command_buffer_command_handle_t_ + *ur_exp_command_buffer_command_handle_t; + /////////////////////////////////////////////////////////////////////////////// /// @brief Descriptor type for updating a kernel command memobj argument. typedef struct ur_exp_command_buffer_update_memobj_arg_desc_t { @@ -10034,6 +10049,8 @@ typedef struct ur_exp_command_buffer_update_kernel_launch_desc_t { ur_structure_type_t stype; /// [in][optional] pointer to extension-specific structure const void *pNext; + /// [in] Handle of the command-buffer kernel command to update. + ur_exp_command_buffer_command_handle_t hCommand; /// [in][optional] The new kernel handle. If this parameter is nullptr, /// the current kernel handle in `hCommand` /// will be used. 
If a kernel handle is passed, it must be a valid kernel @@ -10083,21 +10100,6 @@ typedef struct ur_exp_command_buffer_update_kernel_launch_desc_t { } ur_exp_command_buffer_update_kernel_launch_desc_t; -/////////////////////////////////////////////////////////////////////////////// -/// @brief A value that identifies a command inside of a command-buffer, used -/// for -/// defining dependencies between commands in the same command-buffer. -typedef uint32_t ur_exp_command_buffer_sync_point_t; - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Handle of Command-Buffer object -typedef struct ur_exp_command_buffer_handle_t_ *ur_exp_command_buffer_handle_t; - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Handle of a Command-Buffer command -typedef struct ur_exp_command_buffer_command_handle_t_ - *ur_exp_command_buffer_command_handle_t; - /////////////////////////////////////////////////////////////////////////////// /// @brief Create a Command-Buffer object /// @@ -11045,7 +11047,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// /// @details /// This entry-point is synchronous and may block if the command-buffer is -/// executing when the entry-point is called. +/// executing when the entry-point is called. On error, the state of the +/// command-buffer commands being updated is undefined. 
/// /// @returns /// - ::UR_RESULT_SUCCESS @@ -11053,66 +11056,75 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// - ::UR_RESULT_ERROR_DEVICE_LOST /// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE -/// + `NULL == hCommand` +/// + `NULL == hCommandBuffer` +/// + `NULL == pUpdateKernelLaunch->hCommand` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == pUpdateKernelLaunch` +/// - ::UR_RESULT_ERROR_INVALID_COMMAND_BUFFER_EXP +/// - ::UR_RESULT_ERROR_INVALID_SIZE +/// + `numKernelUpdates == 0` /// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE /// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_ARGUMENTS -/// is not supported by the device, but any of -/// `pUpdateKernelLaunch->numNewMemObjArgs`, -/// `pUpdateKernelLaunch->numNewPointerArgs`, or -/// `pUpdateKernelLaunch->numNewValueArgs` are not zero. +/// is not supported by the device, and for any of any element of +/// `pUpdateKernelLaunch` the `numNewMemObjArgs`, `numNewPointerArgs`, +/// or `numNewValueArgs` members are not zero. /// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_LOCAL_WORK_SIZE is -/// not supported by the device but -/// `pUpdateKernelLaunch->pNewLocalWorkSize` is not nullptr. +/// not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `pNewLocalWorkSize` member is not nullptr. /// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_LOCAL_WORK_SIZE is -/// not supported by the device but -/// `pUpdateKernelLaunch->pNewLocalWorkSize` is nullptr and -/// `pUpdateKernelLaunch->pNewGlobalWorkSize` is not nullptr. +/// not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `pNewLocalWorkSize` member is nullptr and +/// `pNewGlobalWorkSize` is not nullptr. 
/// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_SIZE -/// is not supported by the device but -/// `pUpdateKernelLaunch->pNewGlobalWorkSize` is not nullptr +/// is not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `pNewGlobalWorkSize` member is not nullptr /// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_OFFSET -/// is not supported by the device but -/// `pUpdateKernelLaunch->pNewGlobalWorkOffset` is not nullptr. +/// is not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `pNewGlobalWorkOffset` member is not +/// nullptr. /// + If ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_HANDLE -/// is not supported by the device but `pUpdateKernelLaunch->hNewKernel` -/// is not nullptr. +/// is not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `hNewKernel` member is not nullptr. /// - ::UR_RESULT_ERROR_INVALID_OPERATION /// + If ::ur_exp_command_buffer_desc_t::isUpdatable was not set to true -/// on creation of the command-buffer `hCommand` belongs to. -/// + If the command-buffer `hCommand` belongs to has not been -/// finalized. +/// on creation of the `hCommandBuffer`. +/// + If `hCommandBuffer` has not been finalized. /// - ::UR_RESULT_ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP -/// + If `hCommand` is not a kernel execution command. +/// + If for any element of `pUpdateKernelLaunch` the `hCommand` member +/// is not a kernel execution command. +/// + If for any element of `pUpdateKernelLaunch` the `hCommand` member +/// was not created from `hCommandBuffer`. 
/// - ::UR_RESULT_ERROR_INVALID_MEM_OBJECT /// - ::UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX /// - ::UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE /// - ::UR_RESULT_ERROR_INVALID_ENUMERATION /// - ::UR_RESULT_ERROR_INVALID_WORK_DIMENSION -/// + `pUpdateKernelLaunch->newWorkDim < 1 || -/// pUpdateKernelLaunch->newWorkDim > 3` +/// + If for any element of `pUpdateKernelLaunch` the `newWorkDim` +/// member is less than 1 or greater than 3. /// - ::UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE /// - ::UR_RESULT_ERROR_INVALID_VALUE -/// + If `pUpdateKernelLaunch->hNewKernel` was not passed to the -/// `hKernel` or `phKernelAlternatives` parameters of -/// ::urCommandBufferAppendKernelLaunchExp when this command was -/// created. -/// + If `pUpdateKernelLaunch->newWorkDim` is different from the current -/// workDim in `hCommand` and, -/// `pUpdateKernelLaunch->pNewGlobalWorkSize`, or -/// `pUpdateKernelLaunch->pNewGlobalWorkOffset` are nullptr. +/// + If for any element of `pUpdateKernelLaunch` the `hNewKernel` +/// member was not passed to the `hKernel` or `phKernelAlternatives` +/// parameters of ::urCommandBufferAppendKernelLaunchExp when the +/// command was created. +/// + If for any element of `pUpdateKernelLaunch` the `newWorkDim` +/// member is different from the current workDim in the `hCommand` +/// member, and `pNewGlobalWorkSize` or `pNewGlobalWorkOffset` are +/// nullptr. /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( - /// [in] Handle of the command-buffer kernel command to update. - ur_exp_command_buffer_command_handle_t hCommand, - /// [in] Struct defining how the kernel command is to be updated. + /// [in] Handle of the command-buffer object. + ur_exp_command_buffer_handle_t hCommandBuffer, + /// [in] Length of pUpdateKernelLaunch. 
+ uint32_t numKernelUpdates, + /// [in][range(0, numKernelUpdates)] List of structs defining how a + /// kernel commands are to be updated. const ur_exp_command_buffer_update_kernel_launch_desc_t *pUpdateKernelLaunch); @@ -14203,7 +14215,8 @@ typedef struct ur_command_buffer_enqueue_exp_params_t { /// @details Each entry is a pointer to the parameter passed to the function; /// allowing the callback the ability to modify the parameter's value typedef struct ur_command_buffer_update_kernel_launch_exp_params_t { - ur_exp_command_buffer_command_handle_t *phCommand; + ur_exp_command_buffer_handle_t *phCommandBuffer; + uint32_t *pnumKernelUpdates; const ur_exp_command_buffer_update_kernel_launch_desc_t * *ppUpdateKernelLaunch; } ur_command_buffer_update_kernel_launch_exp_params_t; diff --git a/unified-runtime/include/ur_ddi.h b/unified-runtime/include/ur_ddi.h index c64aaa8d464b7..9a0d5cb8c3d7e 100644 --- a/unified-runtime/include/ur_ddi.h +++ b/unified-runtime/include/ur_ddi.h @@ -1599,7 +1599,7 @@ typedef ur_result_t(UR_APICALL *ur_pfnCommandBufferEnqueueExp_t)( /////////////////////////////////////////////////////////////////////////////// /// @brief Function-pointer for urCommandBufferUpdateKernelLaunchExp typedef ur_result_t(UR_APICALL *ur_pfnCommandBufferUpdateKernelLaunchExp_t)( - ur_exp_command_buffer_command_handle_t, + ur_exp_command_buffer_handle_t, uint32_t, const ur_exp_command_buffer_update_kernel_launch_desc_t *); /////////////////////////////////////////////////////////////////////////////// diff --git a/unified-runtime/include/ur_print.hpp b/unified-runtime/include/ur_print.hpp index 5c5f573477929..e9a5589a0dfaf 100644 --- a/unified-runtime/include/ur_print.hpp +++ b/unified-runtime/include/ur_print.hpp @@ -11201,6 +11201,11 @@ inline std::ostream &operator<<( ur::details::printStruct(os, (params.pNext)); + os << ", "; + os << ".hCommand = "; + + ur::details::printPtr(os, (params.hCommand)); + os << ", "; os << ".hNewKernel = "; @@ -18691,14 +18696,30 
@@ inline std::ostream & operator<<(std::ostream &os, [[maybe_unused]] const struct ur_command_buffer_update_kernel_launch_exp_params_t *params) { - os << ".hCommand = "; + os << ".hCommandBuffer = "; - ur::details::printPtr(os, *(params->phCommand)); + ur::details::printPtr(os, *(params->phCommandBuffer)); + + os << ", "; + os << ".numKernelUpdates = "; + + os << *(params->pnumKernelUpdates); os << ", "; os << ".pUpdateKernelLaunch = "; + ur::details::printPtr( + os, reinterpret_cast(*(params->ppUpdateKernelLaunch))); + if (*(params->ppUpdateKernelLaunch) != NULL) { + os << " {"; + for (size_t i = 0; i < *params->pnumKernelUpdates; ++i) { + if (i != 0) { + os << ", "; + } - ur::details::printPtr(os, *(params->ppUpdateKernelLaunch)); + os << (*(params->ppUpdateKernelLaunch))[i]; + } + os << "}"; + } return os; } diff --git a/unified-runtime/scripts/core/EXP-COMMAND-BUFFER.rst b/unified-runtime/scripts/core/EXP-COMMAND-BUFFER.rst index 1a4925e83fa0f..9d05a32ee1e3f 100644 --- a/unified-runtime/scripts/core/EXP-COMMAND-BUFFER.rst +++ b/unified-runtime/scripts/core/EXP-COMMAND-BUFFER.rst @@ -309,7 +309,8 @@ ${x}CommandBufferUpdateKernelLaunchExp. ${x}_exp_command_buffer_update_kernel_launch_desc_t update { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext - hNewKernel // hNewKernel + hCommand, // hCommand + hNewKernel, // hNewKernel 2, // numNewMemobjArgs 0, // numNewPointerArgs 0, // numNewValueArgs @@ -325,7 +326,7 @@ ${x}CommandBufferUpdateKernelLaunchExp. 
}; // Perform the update - ${x}CommandBufferUpdateKernelLaunchExp(hCommand, &update); + ${x}CommandBufferUpdateKernelLaunchExp(hCommandBuffer, 1, &update); Command Event Update ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ @@ -514,6 +515,8 @@ Changelog +-----------+-------------------------------------------------------+ | 1.7 | Remove command handle reference counting and querying | +-----------+-------------------------------------------------------+ +| 1.8 | Change Kernel command update API to take a list | ++-----------+-------------------------------------------------------+ Contributors -------------------------------------------------------------------------------- diff --git a/unified-runtime/scripts/core/exp-command-buffer.yml b/unified-runtime/scripts/core/exp-command-buffer.yml index 680bc60f8d184..babc674a7cbb7 100644 --- a/unified-runtime/scripts/core/exp-command-buffer.yml +++ b/unified-runtime/scripts/core/exp-command-buffer.yml @@ -150,6 +150,22 @@ members: name: enableProfiling desc: "[in] Command-buffer profiling is enabled." --- #-------------------------------------------------------------------------- +type: typedef +desc: "A value that identifies a command inside of a command-buffer, used for defining dependencies between commands in the same command-buffer." +class: $xCommandBuffer +name: $x_exp_command_buffer_sync_point_t +value: uint32_t +--- #-------------------------------------------------------------------------- +type: handle +desc: "Handle of Command-Buffer object" +class: $xCommandBuffer +name: "$x_exp_command_buffer_handle_t" +--- #-------------------------------------------------------------------------- +type: handle +desc: "Handle of a Command-Buffer command" +class: $xCommandBuffer +name: "$x_exp_command_buffer_command_handle_t" +--- #-------------------------------------------------------------------------- type: struct desc: "Descriptor type for updating a kernel command memobj argument." 
base: $x_base_desc_t @@ -203,6 +219,9 @@ desc: "Descriptor type for updating a kernel launch command." base: $x_base_desc_t name: $x_exp_command_buffer_update_kernel_launch_desc_t members: + - type: $x_exp_command_buffer_command_handle_t + name: hCommand + desc: "[in] Handle of the command-buffer kernel command to update." - type: $x_kernel_handle_t name: hNewKernel desc: | @@ -250,22 +269,6 @@ members: then the runtime implementation will choose the local work size. If `pNewGlobalWorkSize` is nullptr and `pNewLocalWorkSize` is nullptr, the current local work size in the command will be used. --- #-------------------------------------------------------------------------- -type: typedef -desc: "A value that identifies a command inside of a command-buffer, used for defining dependencies between commands in the same command-buffer." -class: $xCommandBuffer -name: $x_exp_command_buffer_sync_point_t -value: uint32_t ---- #-------------------------------------------------------------------------- -type: handle -desc: "Handle of Command-Buffer object" -class: $xCommandBuffer -name: "$x_exp_command_buffer_handle_t" ---- #-------------------------------------------------------------------------- -type: handle -desc: "Handle of a Command-Buffer command" -class: $xCommandBuffer -name: "$x_exp_command_buffer_command_handle_t" ---- #-------------------------------------------------------------------------- type: function desc: "Create a Command-Buffer object" class: $xCommandBuffer @@ -1166,39 +1169,46 @@ returns: --- #-------------------------------------------------------------------------- type: function desc: "Update a kernel launch command in a finalized command-buffer." -details: "This entry-point is synchronous and may block if the command-buffer is executing when the entry-point is called." +details: "This entry-point is synchronous and may block if the command-buffer is executing when the entry-point is called. 
On error, the state of the command-buffer commands being updated is undefined." class: $xCommandBuffer name: UpdateKernelLaunchExp params: - - type: $x_exp_command_buffer_command_handle_t - name: hCommand - desc: "[in] Handle of the command-buffer kernel command to update." + - type: $x_exp_command_buffer_handle_t + name: hCommandBuffer + desc: "[in] Handle of the command-buffer object." + - type: uint32_t + name: numKernelUpdates + desc: "[in] Length of pUpdateKernelLaunch." - type: "const $x_exp_command_buffer_update_kernel_launch_desc_t*" name: pUpdateKernelLaunch - desc: "[in] Struct defining how the kernel command is to be updated." + desc: "[in][range(0, numKernelUpdates)] List of structs defining how kernel commands are to be updated." returns: + - $X_RESULT_ERROR_INVALID_COMMAND_BUFFER_EXP + - $X_RESULT_ERROR_INVALID_SIZE: + - "`numKernelUpdates == 0`" - $X_RESULT_ERROR_UNSUPPORTED_FEATURE: - - "If $X_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_ARGUMENTS is not supported by the device, but any of `pUpdateKernelLaunch->numNewMemObjArgs`, `pUpdateKernelLaunch->numNewPointerArgs`, or `pUpdateKernelLaunch->numNewValueArgs` are not zero." - - "If $X_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_LOCAL_WORK_SIZE is not supported by the device but `pUpdateKernelLaunch->pNewLocalWorkSize` is not nullptr." - - "If $X_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_LOCAL_WORK_SIZE is not supported by the device but `pUpdateKernelLaunch->pNewLocalWorkSize` is nullptr and `pUpdateKernelLaunch->pNewGlobalWorkSize` is not nullptr." - - "If $X_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_SIZE is not supported by the device but `pUpdateKernelLaunch->pNewGlobalWorkSize` is not nullptr" - - "If $X_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_OFFSET is not supported by the device but `pUpdateKernelLaunch->pNewGlobalWorkOffset` is not nullptr." 
- - "If $X_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_HANDLE is not supported by the device but `pUpdateKernelLaunch->hNewKernel` is not nullptr." + - "If $X_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_ARGUMENTS is not supported by the device, and for any element of `pUpdateKernelLaunch` the `numNewMemObjArgs`, `numNewPointerArgs`, or `numNewValueArgs` members are not zero." + - "If $X_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_LOCAL_WORK_SIZE is not supported by the device, and for any element of `pUpdateKernelLaunch` the `pNewLocalWorkSize` member is not nullptr." + - "If $X_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_LOCAL_WORK_SIZE is not supported by the device, and for any element of `pUpdateKernelLaunch` the `pNewLocalWorkSize` member is nullptr and `pNewGlobalWorkSize` is not nullptr." + - "If $X_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_SIZE is not supported by the device, and for any element of `pUpdateKernelLaunch` the `pNewGlobalWorkSize` member is not nullptr" + - "If $X_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_OFFSET is not supported by the device, and for any element of `pUpdateKernelLaunch` the `pNewGlobalWorkOffset` member is not nullptr." + - "If $X_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_HANDLE is not supported by the device, and for any element of `pUpdateKernelLaunch` the `hNewKernel` member is not nullptr." - $X_RESULT_ERROR_INVALID_OPERATION: - - "If $x_exp_command_buffer_desc_t::isUpdatable was not set to true on creation of the command-buffer `hCommand` belongs to." - - "If the command-buffer `hCommand` belongs to has not been finalized." + - "If $x_exp_command_buffer_desc_t::isUpdatable was not set to true on creation of the `hCommandBuffer`." + - "If `hCommandBuffer` has not been finalized." - $X_RESULT_ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP: - - "If `hCommand` is not a kernel execution command." 
+ - "If for any element of `pUpdateKernelLaunch` the `hCommand` member is not a kernel execution command." + - "If for any element of `pUpdateKernelLaunch` the `hCommand` member was not created from `hCommandBuffer`." - $X_RESULT_ERROR_INVALID_MEM_OBJECT - $X_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX - $X_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE - $X_RESULT_ERROR_INVALID_ENUMERATION - $X_RESULT_ERROR_INVALID_WORK_DIMENSION: - - "`pUpdateKernelLaunch->newWorkDim < 1 || pUpdateKernelLaunch->newWorkDim > 3`" + - "If for any element of `pUpdateKernelLaunch` the `newWorkDim` member is less than 1 or greater than 3." - $X_RESULT_ERROR_INVALID_WORK_GROUP_SIZE - $X_RESULT_ERROR_INVALID_VALUE: - - "If `pUpdateKernelLaunch->hNewKernel` was not passed to the `hKernel` or `phKernelAlternatives` parameters of $xCommandBufferAppendKernelLaunchExp when this command was created." - - "If `pUpdateKernelLaunch->newWorkDim` is different from the current workDim in `hCommand` and, `pUpdateKernelLaunch->pNewGlobalWorkSize`, or `pUpdateKernelLaunch->pNewGlobalWorkOffset` are nullptr." + - "If for any element of `pUpdateKernelLaunch` the `hNewKernel` member was not passed to the `hKernel` or `phKernelAlternatives` parameters of $xCommandBufferAppendKernelLaunchExp when the command was created." + - "If for any element of `pUpdateKernelLaunch` the `newWorkDim` member is different from the current workDim in the `hCommand` member, and `pNewGlobalWorkSize` or `pNewGlobalWorkOffset` are nullptr." 
- $X_RESULT_ERROR_OUT_OF_HOST_MEMORY - $X_RESULT_ERROR_OUT_OF_RESOURCES --- #-------------------------------------------------------------------------- diff --git a/unified-runtime/scripts/templates/ldrddi.cpp.mako b/unified-runtime/scripts/templates/ldrddi.cpp.mako index ba191a9ceb30e..ee01b04487f3a 100644 --- a/unified-runtime/scripts/templates/ldrddi.cpp.mako +++ b/unified-runtime/scripts/templates/ldrddi.cpp.mako @@ -212,6 +212,16 @@ namespace ur_loader <% handle_structs = th.get_object_handle_structs_to_convert(n, tags, obj, meta) %> %if handle_structs: // Deal with any struct parameters that have handle members we need to convert. + %if func_basename == "CommandBufferUpdateKernelLaunchExp": + ## CommandBufferUpdateKernelLaunchExp entry-point takes a list of structs with + ## handle members, as well as members defining a nested list of structs + ## containing handles. This usage is not supported yet, so special case as + ## a temporary measure. + std::vector pUpdateKernelLaunchVector = {}; + std::vector> + ppUpdateKernelLaunchpNewMemObjArgList(numKernelUpdates); + for (size_t Offset = 0; Offset < numKernelUpdates; Offset ++) { + %endif %for struct in handle_structs: %if struct['optional']: ${struct['type']} ${struct['name']}Local = {}; @@ -239,7 +249,13 @@ namespace ur_loader range_end = member['range_end'] if not re.match(r"[0-9]+$", range_end): range_end = struct['name'] + "->" + member['parent'] + range_end %> + + %if func_basename == "CommandBufferUpdateKernelLaunchExp": + std::vector& + pUpdateKernelLaunchpNewMemObjArgList = ppUpdateKernelLaunchpNewMemObjArgList[Offset]; + %else: std::vector<${member['type']}> ${range_vector_name}; + %endif for(uint32_t i = ${range_start}; i < ${range_end}; i++) { ${member['type']} NewRangeStruct = ${struct['name']}Local.${member['parent']}${member['name']}[i]; %for handle_member in member['handle_members']: @@ -277,6 +293,12 @@ namespace ur_loader %endfor %endfor + %if func_basename == 
"CommandBufferUpdateKernelLaunchExp": + pUpdateKernelLaunchVector.push_back(pUpdateKernelLaunchLocal); + pUpdateKernelLaunch++; + } + pUpdateKernelLaunch = pUpdateKernelLaunchVector.data(); + %else: // Now that we've converted all the members update the param pointers %for struct in handle_structs: %if struct['optional']: @@ -285,6 +307,7 @@ namespace ur_loader ${struct['name']} = &${struct['name']}Local; %endfor %endif + %endif // forward to device-platform %if add_local: diff --git a/unified-runtime/source/adapters/cuda/command_buffer.cpp b/unified-runtime/source/adapters/cuda/command_buffer.cpp index 42ec8dbafc9ef..e474aab238de0 100644 --- a/unified-runtime/source/adapters/cuda/command_buffer.cpp +++ b/unified-runtime/source/adapters/cuda/command_buffer.cpp @@ -1161,28 +1161,37 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp( /** * Validates contents of the update command description. - * @param[in] Command The command which is being updated. + * @param[in] CommandBuffer The command-buffer which is being updated. * @param[in] UpdateCommandDesc The update command description. * @return UR_RESULT_SUCCESS or an error code on failure */ ur_result_t -validateCommandDesc(kernel_command_handle *Command, +validateCommandDesc(ur_exp_command_buffer_handle_t CommandBuffer, const ur_exp_command_buffer_update_kernel_launch_desc_t - *UpdateCommandDesc) { - auto CommandBuffer = Command->CommandBuffer; + &UpdateCommandDesc) { + if (UpdateCommandDesc.hCommand->getCommandType() != CommandType::Kernel) { + return UR_RESULT_ERROR_INVALID_VALUE; + } + + auto Command = + static_cast(UpdateCommandDesc.hCommand); + if (CommandBuffer != Command->CommandBuffer) { + return UR_RESULT_ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP; + } + // Update requires the command-buffer to be finalized and updatable. 
if (!CommandBuffer->CudaGraphExec || !CommandBuffer->IsUpdatable) { return UR_RESULT_ERROR_INVALID_OPERATION; } - if (UpdateCommandDesc->newWorkDim != Command->WorkDim && - (!UpdateCommandDesc->pNewGlobalWorkOffset || - !UpdateCommandDesc->pNewGlobalWorkSize)) { + if (UpdateCommandDesc.newWorkDim != Command->WorkDim && + (!UpdateCommandDesc.pNewGlobalWorkOffset || + !UpdateCommandDesc.pNewGlobalWorkSize)) { return UR_RESULT_ERROR_INVALID_VALUE; } - if (UpdateCommandDesc->hNewKernel && - !Command->ValidKernelHandles.count(UpdateCommandDesc->hNewKernel)) { + if (UpdateCommandDesc.hNewKernel && + !Command->ValidKernelHandles.count(UpdateCommandDesc.hNewKernel)) { return UR_RESULT_ERROR_INVALID_VALUE; } return UR_RESULT_SUCCESS; @@ -1190,24 +1199,22 @@ validateCommandDesc(kernel_command_handle *Command, /** * Updates the arguments of a kernel command. - * @param[in] Command The command associated with the kernel node being - * updated. * @param[in] UpdateCommandDesc The update command description that contains - * the new arguments. + * the new configuration. 
* @return UR_RESULT_SUCCESS or an error code on failure */ ur_result_t -updateKernelArguments(kernel_command_handle *Command, - const ur_exp_command_buffer_update_kernel_launch_desc_t - *UpdateCommandDesc) { - +updateKernelArguments(const ur_exp_command_buffer_update_kernel_launch_desc_t + &UpdateCommandDesc) { + auto Command = + static_cast(UpdateCommandDesc.hCommand); ur_kernel_handle_t Kernel = Command->Kernel; ur_device_handle_t Device = Command->CommandBuffer->Device; // Update pointer arguments to the kernel - uint32_t NumPointerArgs = UpdateCommandDesc->numNewPointerArgs; + uint32_t NumPointerArgs = UpdateCommandDesc.numNewPointerArgs; const ur_exp_command_buffer_update_pointer_arg_desc_t *ArgPointerList = - UpdateCommandDesc->pNewPointerArgList; + UpdateCommandDesc.pNewPointerArgList; for (uint32_t i = 0; i < NumPointerArgs; i++) { const auto &PointerArgDesc = ArgPointerList[i]; uint32_t ArgIndex = PointerArgDesc.argIndex; @@ -1223,9 +1230,9 @@ updateKernelArguments(kernel_command_handle *Command, } // Update memobj arguments to the kernel - uint32_t NumMemobjArgs = UpdateCommandDesc->numNewMemObjArgs; + uint32_t NumMemobjArgs = UpdateCommandDesc.numNewMemObjArgs; const ur_exp_command_buffer_update_memobj_arg_desc_t *ArgMemobjList = - UpdateCommandDesc->pNewMemObjArgList; + UpdateCommandDesc.pNewMemObjArgList; for (uint32_t i = 0; i < NumMemobjArgs; i++) { const auto &MemobjArgDesc = ArgMemobjList[i]; uint32_t ArgIndex = MemobjArgDesc.argIndex; @@ -1246,9 +1253,9 @@ updateKernelArguments(kernel_command_handle *Command, } // Update value arguments to the kernel - uint32_t NumValueArgs = UpdateCommandDesc->numNewValueArgs; + uint32_t NumValueArgs = UpdateCommandDesc.numNewValueArgs; const ur_exp_command_buffer_update_value_arg_desc_t *ArgValueList = - UpdateCommandDesc->pNewValueArgList; + UpdateCommandDesc.pNewValueArgList; for (uint32_t i = 0; i < NumValueArgs; i++) { const auto &ValueArgDesc = ArgValueList[i]; uint32_t ArgIndex = ValueArgDesc.argIndex; @@ 
-1275,94 +1282,100 @@ updateKernelArguments(kernel_command_handle *Command, /** * Updates the command-buffer command with new values from the update * description. - * @param[in] Command The command to be updated. * @param[in] UpdateCommandDesc The update command description. * @return UR_RESULT_SUCCESS or an error code on failure */ ur_result_t -updateCommand(kernel_command_handle *Command, - const ur_exp_command_buffer_update_kernel_launch_desc_t - *UpdateCommandDesc) { - if (UpdateCommandDesc->hNewKernel) { - Command->Kernel = UpdateCommandDesc->hNewKernel; +updateCommand(const ur_exp_command_buffer_update_kernel_launch_desc_t + &UpdateCommandDesc) { + auto Command = + static_cast(UpdateCommandDesc.hCommand); + if (UpdateCommandDesc.hNewKernel) { + Command->Kernel = UpdateCommandDesc.hNewKernel; } - if (UpdateCommandDesc->newWorkDim) { - Command->WorkDim = UpdateCommandDesc->newWorkDim; + if (UpdateCommandDesc.newWorkDim) { + Command->WorkDim = UpdateCommandDesc.newWorkDim; } - if (UpdateCommandDesc->pNewGlobalWorkOffset) { - Command->setGlobalOffset(UpdateCommandDesc->pNewGlobalWorkOffset); + if (UpdateCommandDesc.pNewGlobalWorkOffset) { + Command->setGlobalOffset(UpdateCommandDesc.pNewGlobalWorkOffset); } - if (UpdateCommandDesc->pNewGlobalWorkSize) { - Command->setGlobalSize(UpdateCommandDesc->pNewGlobalWorkSize); - if (!UpdateCommandDesc->pNewLocalWorkSize) { + if (UpdateCommandDesc.pNewGlobalWorkSize) { + Command->setGlobalSize(UpdateCommandDesc.pNewGlobalWorkSize); + if (!UpdateCommandDesc.pNewLocalWorkSize) { Command->setNullLocalSize(); } } - if (UpdateCommandDesc->pNewLocalWorkSize) { - Command->setLocalSize(UpdateCommandDesc->pNewLocalWorkSize); + if (UpdateCommandDesc.pNewLocalWorkSize) { + Command->setLocalSize(UpdateCommandDesc.pNewLocalWorkSize); } return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( - ur_exp_command_buffer_command_handle_t hCommand, + ur_exp_command_buffer_handle_t hCommandBuffer, 
uint32_t numKernelUpdates, const ur_exp_command_buffer_update_kernel_launch_desc_t *pUpdateKernelLaunch) try { + // First validate user inputs, as no update should be propagated if there + // are errors. + for (uint32_t i = 0; i < numKernelUpdates; i++) { + UR_CHECK_ERROR(validateCommandDesc(hCommandBuffer, pUpdateKernelLaunch[i])); + } - ur_exp_command_buffer_handle_t CommandBuffer = hCommand->CommandBuffer; - - if (hCommand->getCommandType() != CommandType::Kernel) { - return UR_RESULT_ERROR_INVALID_VALUE; + // Store changes in config struct in command handle object + for (uint32_t i = 0; i < numKernelUpdates; i++) { + UR_CHECK_ERROR(updateCommand(pUpdateKernelLaunch[i])); + UR_CHECK_ERROR(updateKernelArguments(pUpdateKernelLaunch[i])); } - auto KernelCommandHandle = static_cast(hCommand); + // Propagate changes to CUDA driver API + for (uint32_t i = 0; i < numKernelUpdates; i++) { + const auto &UpdateCommandDesc = pUpdateKernelLaunch[i]; - UR_CHECK_ERROR(validateCommandDesc(KernelCommandHandle, pUpdateKernelLaunch)); - UR_CHECK_ERROR(updateCommand(KernelCommandHandle, pUpdateKernelLaunch)); - UR_CHECK_ERROR( - updateKernelArguments(KernelCommandHandle, pUpdateKernelLaunch)); - - // If no work-size is provided make sure we pass nullptr to setKernelParams - // so it can guess the local work size. - const bool ProvidedLocalSize = !KernelCommandHandle->isNullLocalSize(); - size_t *LocalWorkSize = - ProvidedLocalSize ? KernelCommandHandle->LocalWorkSize : nullptr; - - // Set the number of threads per block to the number of threads per warp - // by default unless user has provided a better number. 
- size_t ThreadsPerBlock[3] = {32u, 1u, 1u}; - size_t BlocksPerGrid[3] = {1u, 1u, 1u}; - CUfunction CuFunc = KernelCommandHandle->Kernel->get(); - auto Result = setKernelParams( - CommandBuffer->Context, CommandBuffer->Device, - KernelCommandHandle->WorkDim, KernelCommandHandle->GlobalWorkOffset, - KernelCommandHandle->GlobalWorkSize, LocalWorkSize, - KernelCommandHandle->Kernel, CuFunc, ThreadsPerBlock, BlocksPerGrid); - if (Result != UR_RESULT_SUCCESS) { - return Result; - } - - CUDA_KERNEL_NODE_PARAMS &Params = KernelCommandHandle->Params; - - Params.func = CuFunc; - Params.gridDimX = BlocksPerGrid[0]; - Params.gridDimY = BlocksPerGrid[1]; - Params.gridDimZ = BlocksPerGrid[2]; - Params.blockDimX = ThreadsPerBlock[0]; - Params.blockDimY = ThreadsPerBlock[1]; - Params.blockDimZ = ThreadsPerBlock[2]; - Params.sharedMemBytes = KernelCommandHandle->Kernel->getLocalSize(); - Params.kernelParams = - const_cast(KernelCommandHandle->Kernel->getArgPointers().data()); - - CUgraphNode Node = KernelCommandHandle->Node; - CUgraphExec CudaGraphExec = CommandBuffer->CudaGraphExec; - UR_CHECK_ERROR(cuGraphExecKernelNodeSetParams(CudaGraphExec, Node, &Params)); + // If no work-size is provided make sure we pass nullptr to setKernelParams + // so it can guess the local work size. + auto KernelCommandHandle = + static_cast(UpdateCommandDesc.hCommand); + const bool ProvidedLocalSize = !KernelCommandHandle->isNullLocalSize(); + size_t *LocalWorkSize = + ProvidedLocalSize ? KernelCommandHandle->LocalWorkSize : nullptr; + + // Set the number of threads per block to the number of threads per warp + // by default unless user has provided a better number. 
+ size_t ThreadsPerBlock[3] = {32u, 1u, 1u}; + size_t BlocksPerGrid[3] = {1u, 1u, 1u}; + CUfunction CuFunc = KernelCommandHandle->Kernel->get(); + auto Result = setKernelParams( + hCommandBuffer->Context, hCommandBuffer->Device, + KernelCommandHandle->WorkDim, KernelCommandHandle->GlobalWorkOffset, + KernelCommandHandle->GlobalWorkSize, LocalWorkSize, + KernelCommandHandle->Kernel, CuFunc, ThreadsPerBlock, BlocksPerGrid); + if (Result != UR_RESULT_SUCCESS) { + return Result; + } + + CUDA_KERNEL_NODE_PARAMS &Params = KernelCommandHandle->Params; + + Params.func = CuFunc; + Params.gridDimX = BlocksPerGrid[0]; + Params.gridDimY = BlocksPerGrid[1]; + Params.gridDimZ = BlocksPerGrid[2]; + Params.blockDimX = ThreadsPerBlock[0]; + Params.blockDimY = ThreadsPerBlock[1]; + Params.blockDimZ = ThreadsPerBlock[2]; + Params.sharedMemBytes = KernelCommandHandle->Kernel->getLocalSize(); + Params.kernelParams = const_cast( + KernelCommandHandle->Kernel->getArgPointers().data()); + + CUgraphNode Node = KernelCommandHandle->Node; + CUgraphExec CudaGraphExec = hCommandBuffer->CudaGraphExec; + UR_CHECK_ERROR( + cuGraphExecKernelNodeSetParams(CudaGraphExec, Node, &Params)); + } return UR_RESULT_SUCCESS; } catch (ur_result_t Err) { return Err; diff --git a/unified-runtime/source/adapters/hip/command_buffer.cpp b/unified-runtime/source/adapters/hip/command_buffer.cpp index ce07332ce8d97..079b25675a0ff 100644 --- a/unified-runtime/source/adapters/hip/command_buffer.cpp +++ b/unified-runtime/source/adapters/hip/command_buffer.cpp @@ -865,30 +865,32 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp( /** * Validates contents of the update command description. - * @param[in] Command The command which is being updated. + * @param[in] CommandBuffer The command-buffer which is being updated. * @param[in] UpdateCommandDesc The update command description. 
* @return UR_RESULT_SUCCESS or an error code on failure */ ur_result_t -validateCommandDesc(ur_exp_command_buffer_command_handle_t Command, +validateCommandDesc(ur_exp_command_buffer_handle_t CommandBuffer, const ur_exp_command_buffer_update_kernel_launch_desc_t - *UpdateCommandDesc) { - - auto CommandBuffer = Command->CommandBuffer; - + &UpdateCommandDesc) { // Update requires the command-buffer to be finalized and updatable. if (!CommandBuffer->HIPGraphExec || !CommandBuffer->IsUpdatable) { return UR_RESULT_ERROR_INVALID_OPERATION; } - if (UpdateCommandDesc->newWorkDim != Command->WorkDim && - (!UpdateCommandDesc->pNewGlobalWorkOffset || - !UpdateCommandDesc->pNewGlobalWorkSize)) { + auto Command = UpdateCommandDesc.hCommand; + if (CommandBuffer != Command->CommandBuffer) { + return UR_RESULT_ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP; + } + + if (UpdateCommandDesc.newWorkDim != Command->WorkDim && + (!UpdateCommandDesc.pNewGlobalWorkOffset || + !UpdateCommandDesc.pNewGlobalWorkSize)) { return UR_RESULT_ERROR_INVALID_VALUE; } - if (UpdateCommandDesc->hNewKernel && - !Command->ValidKernelHandles.count(UpdateCommandDesc->hNewKernel)) { + if (UpdateCommandDesc.hNewKernel && + !Command->ValidKernelHandles.count(UpdateCommandDesc.hNewKernel)) { return UR_RESULT_ERROR_INVALID_VALUE; } @@ -897,23 +899,21 @@ validateCommandDesc(ur_exp_command_buffer_command_handle_t Command, /** * Updates the arguments of a kernel command. - * @param[in] Command The command associated with the kernel node being updated. * @param[in] UpdateCommandDesc The update command description that contains the - * new arguments. + * new configuration. 
* @return UR_RESULT_SUCCESS or an error code on failure */ ur_result_t -updateKernelArguments(ur_exp_command_buffer_command_handle_t Command, - const ur_exp_command_buffer_update_kernel_launch_desc_t - *UpdateCommandDesc) { - +updateKernelArguments(const ur_exp_command_buffer_update_kernel_launch_desc_t + &UpdateCommandDesc) { + auto Command = UpdateCommandDesc.hCommand; ur_kernel_handle_t Kernel = Command->Kernel; ur_device_handle_t Device = Command->CommandBuffer->Device; // Update pointer arguments to the kernel - uint32_t NumPointerArgs = UpdateCommandDesc->numNewPointerArgs; + uint32_t NumPointerArgs = UpdateCommandDesc.numNewPointerArgs; const ur_exp_command_buffer_update_pointer_arg_desc_t *ArgPointerList = - UpdateCommandDesc->pNewPointerArgList; + UpdateCommandDesc.pNewPointerArgList; for (uint32_t i = 0; i < NumPointerArgs; i++) { const auto &PointerArgDesc = ArgPointerList[i]; uint32_t ArgIndex = PointerArgDesc.argIndex; @@ -927,9 +927,9 @@ updateKernelArguments(ur_exp_command_buffer_command_handle_t Command, } // Update memobj arguments to the kernel - uint32_t NumMemobjArgs = UpdateCommandDesc->numNewMemObjArgs; + uint32_t NumMemobjArgs = UpdateCommandDesc.numNewMemObjArgs; const ur_exp_command_buffer_update_memobj_arg_desc_t *ArgMemobjList = - UpdateCommandDesc->pNewMemObjArgList; + UpdateCommandDesc.pNewMemObjArgList; for (uint32_t i = 0; i < NumMemobjArgs; i++) { const auto &MemobjArgDesc = ArgMemobjList[i]; uint32_t ArgIndex = MemobjArgDesc.argIndex; @@ -948,9 +948,9 @@ updateKernelArguments(ur_exp_command_buffer_command_handle_t Command, } // Update value arguments to the kernel - uint32_t NumValueArgs = UpdateCommandDesc->numNewValueArgs; + uint32_t NumValueArgs = UpdateCommandDesc.numNewValueArgs; const ur_exp_command_buffer_update_value_arg_desc_t *ArgValueList = - UpdateCommandDesc->pNewValueArgList; + UpdateCommandDesc.pNewValueArgList; for (uint32_t i = 0; i < NumValueArgs; i++) { const auto &ValueArgDesc = ArgValueList[i]; uint32_t ArgIndex 
= ValueArgDesc.argIndex; @@ -975,83 +975,94 @@ updateKernelArguments(ur_exp_command_buffer_command_handle_t Command, /** * Updates the command-buffer command with new values from the update * description. - * @param[in] Command The command to be updated. * @param[in] UpdateCommandDesc The update command description. * @return UR_RESULT_SUCCESS or an error code on failure */ ur_result_t -updateCommand(ur_exp_command_buffer_command_handle_t Command, - const ur_exp_command_buffer_update_kernel_launch_desc_t - *UpdateCommandDesc) { - - if (UpdateCommandDesc->hNewKernel) { - Command->Kernel = UpdateCommandDesc->hNewKernel; +updateCommand(const ur_exp_command_buffer_update_kernel_launch_desc_t + &UpdateCommandDesc) { + auto Command = UpdateCommandDesc.hCommand; + if (UpdateCommandDesc.hNewKernel) { + Command->Kernel = UpdateCommandDesc.hNewKernel; } - if (UpdateCommandDesc->hNewKernel) { - Command->WorkDim = UpdateCommandDesc->newWorkDim; + if (UpdateCommandDesc.newWorkDim) { + Command->WorkDim = UpdateCommandDesc.newWorkDim; } - if (UpdateCommandDesc->pNewGlobalWorkOffset) { - Command->setGlobalOffset(UpdateCommandDesc->pNewGlobalWorkOffset); + if (UpdateCommandDesc.pNewGlobalWorkOffset) { + Command->setGlobalOffset(UpdateCommandDesc.pNewGlobalWorkOffset); } - if (UpdateCommandDesc->pNewGlobalWorkSize) { - Command->setGlobalSize(UpdateCommandDesc->pNewGlobalWorkSize); - if (!UpdateCommandDesc->pNewLocalWorkSize) { + if (UpdateCommandDesc.pNewGlobalWorkSize) { + Command->setGlobalSize(UpdateCommandDesc.pNewGlobalWorkSize); + if (!UpdateCommandDesc.pNewLocalWorkSize) { Command->setNullLocalSize(); } } - if (UpdateCommandDesc->pNewLocalWorkSize) { - Command->setLocalSize(UpdateCommandDesc->pNewLocalWorkSize); + if (UpdateCommandDesc.pNewLocalWorkSize) { + Command->setLocalSize(UpdateCommandDesc.pNewLocalWorkSize); } return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( - ur_exp_command_buffer_command_handle_t hCommand, + 
ur_exp_command_buffer_handle_t hCommandBuffer, uint32_t numKernelUpdates, const ur_exp_command_buffer_update_kernel_launch_desc_t *pUpdateKernelLaunch) try { + // First validate user inputs, as no update should be propagated if there + // are errors. + for (uint32_t i = 0; i < numKernelUpdates; i++) { + UR_CHECK_ERROR(validateCommandDesc(hCommandBuffer, pUpdateKernelLaunch[i])); + } - ur_exp_command_buffer_handle_t CommandBuffer = hCommand->CommandBuffer; - - UR_CHECK_ERROR(validateCommandDesc(hCommand, pUpdateKernelLaunch)); - UR_CHECK_ERROR(updateCommand(hCommand, pUpdateKernelLaunch)); - UR_CHECK_ERROR(updateKernelArguments(hCommand, pUpdateKernelLaunch)); - - // If no worksize is provided make sure we pass nullptr to setKernelParams - // so it can guess the local work size. - const bool ProvidedLocalSize = !hCommand->isNullLocalSize(); - size_t *LocalWorkSize = ProvidedLocalSize ? hCommand->LocalWorkSize : nullptr; - - // Set the number of threads per block to the number of threads per warp - // by default unless user has provided a better number - size_t ThreadsPerBlock[3] = {32u, 1u, 1u}; - size_t BlocksPerGrid[3] = {1u, 1u, 1u}; - hipFunction_t HIPFunc = hCommand->Kernel->get(); - UR_CHECK_ERROR(setKernelParams( - CommandBuffer->Device, hCommand->WorkDim, hCommand->GlobalWorkOffset, - hCommand->GlobalWorkSize, LocalWorkSize, hCommand->Kernel, HIPFunc, - ThreadsPerBlock, BlocksPerGrid)); - - hipKernelNodeParams &Params = hCommand->Params; - - Params.func = HIPFunc; - Params.gridDim.x = BlocksPerGrid[0]; - Params.gridDim.y = BlocksPerGrid[1]; - Params.gridDim.z = BlocksPerGrid[2]; - Params.blockDim.x = ThreadsPerBlock[0]; - Params.blockDim.y = ThreadsPerBlock[1]; - Params.blockDim.z = ThreadsPerBlock[2]; - Params.sharedMemBytes = hCommand->Kernel->getLocalSize(); - Params.kernelParams = - const_cast(hCommand->Kernel->getArgPointers().data()); - - hipGraphNode_t Node = hCommand->Node; - hipGraphExec_t HipGraphExec = CommandBuffer->HIPGraphExec; - 
UR_CHECK_ERROR(hipGraphExecKernelNodeSetParams(HipGraphExec, Node, &Params)); + // Store changes in config struct in command handle object + for (uint32_t i = 0; i < numKernelUpdates; i++) { + UR_CHECK_ERROR(updateCommand(pUpdateKernelLaunch[i])); + UR_CHECK_ERROR(updateKernelArguments(pUpdateKernelLaunch[i])); + } + + // Propagate changes to HIP driver API + for (uint32_t i = 0; i < numKernelUpdates; i++) { + const auto &UpdateCommandDesc = pUpdateKernelLaunch[i]; + + // If no worksize is provided make sure we pass nullptr to setKernelParams + // so it can guess the local work size. + auto Command = UpdateCommandDesc.hCommand; + const bool ProvidedLocalSize = !Command->isNullLocalSize(); + size_t *LocalWorkSize = + ProvidedLocalSize ? Command->LocalWorkSize : nullptr; + + // Set the number of threads per block to the number of threads per warp + // by default unless user has provided a better number + size_t ThreadsPerBlock[3] = {32u, 1u, 1u}; + size_t BlocksPerGrid[3] = {1u, 1u, 1u}; + hipFunction_t HIPFunc = Command->Kernel->get(); + UR_CHECK_ERROR(setKernelParams( + hCommandBuffer->Device, Command->WorkDim, Command->GlobalWorkOffset, + Command->GlobalWorkSize, LocalWorkSize, Command->Kernel, HIPFunc, + ThreadsPerBlock, BlocksPerGrid)); + + hipKernelNodeParams &Params = Command->Params; + + Params.func = HIPFunc; + Params.gridDim.x = BlocksPerGrid[0]; + Params.gridDim.y = BlocksPerGrid[1]; + Params.gridDim.z = BlocksPerGrid[2]; + Params.blockDim.x = ThreadsPerBlock[0]; + Params.blockDim.y = ThreadsPerBlock[1]; + Params.blockDim.z = ThreadsPerBlock[2]; + Params.sharedMemBytes = Command->Kernel->getLocalSize(); + Params.kernelParams = + const_cast(Command->Kernel->getArgPointers().data()); + + hipGraphNode_t Node = Command->Node; + hipGraphExec_t HipGraphExec = hCommandBuffer->HIPGraphExec; + UR_CHECK_ERROR( + hipGraphExecKernelNodeSetParams(HipGraphExec, Node, &Params)); + } return UR_RESULT_SUCCESS; } catch (ur_result_t Err) { return Err; diff --git 
a/unified-runtime/source/adapters/level_zero/command_buffer.cpp b/unified-runtime/source/adapters/level_zero/command_buffer.cpp index 4705964190547..e677d8176f65a 100644 --- a/unified-runtime/source/adapters/level_zero/command_buffer.cpp +++ b/unified-runtime/source/adapters/level_zero/command_buffer.cpp @@ -1481,6 +1481,7 @@ ur_result_t waitForDependencies(ur_exp_command_buffer_handle_t CommandBuffer, ur_queue_handle_t Queue, uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList) { + std::scoped_lock Guard(CommandBuffer->Mutex); const bool UseCopyEngine = false; bool MustSignalWaitEvent = true; if (NumEventsInWaitList) { @@ -1761,40 +1762,46 @@ ur_result_t urCommandBufferEnqueueExp( return UR_RESULT_SUCCESS; } +// anonymous namespace of update helper functions +namespace { + /** * Validates contents of the update command description. - * @param[in] Command The command which is being updated. - * @param[in] CommandDesc The update command description. + * @param[in] CommandBuffer The command-buffer which is being updated. + * @param[in] CommandDesc The update command configuration. 
* @return UR_RESULT_SUCCESS or an error code on failure */ ur_result_t validateCommandDesc( - kernel_command_handle *Command, - const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) { + ur_exp_command_buffer_handle_t CommandBuffer, + const ur_exp_command_buffer_update_kernel_launch_desc_t &CommandDesc) { + std::scoped_lock Guard(CommandBuffer->Mutex); - auto CommandBuffer = Command->CommandBuffer; auto SupportedFeatures = - Command->CommandBuffer->Device->ZeDeviceMutableCmdListsProperties + CommandBuffer->Device->ZeDeviceMutableCmdListsProperties ->mutableCommandFlags; logger::debug("Mutable features supported by device {}", SupportedFeatures); + auto Command = static_cast(CommandDesc.hCommand); + UR_ASSERT(CommandBuffer == Command->CommandBuffer, + UR_RESULT_ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP); + UR_ASSERT( - !CommandDesc->hNewKernel || + !CommandDesc.hNewKernel || (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION), UR_RESULT_ERROR_UNSUPPORTED_FEATURE); // Check if the provided new kernel is in the list of valid alternatives. - if (CommandDesc->hNewKernel && - !Command->ValidKernelHandles.count(CommandDesc->hNewKernel)) { + if (CommandDesc.hNewKernel && + !Command->ValidKernelHandles.count(CommandDesc.hNewKernel)) { return UR_RESULT_ERROR_INVALID_VALUE; } - if (CommandDesc->newWorkDim != Command->WorkDim && - (!CommandDesc->pNewGlobalWorkOffset || - !CommandDesc->pNewGlobalWorkSize)) { + if (CommandDesc.newWorkDim != Command->WorkDim && + (!CommandDesc.pNewGlobalWorkOffset || !CommandDesc.pNewGlobalWorkSize)) { return UR_RESULT_ERROR_INVALID_VALUE; } // Check if new global offset is provided. 
- size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset; + size_t *NewGlobalWorkOffset = CommandDesc.pNewGlobalWorkOffset; UR_ASSERT(!NewGlobalWorkOffset || (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET), UR_RESULT_ERROR_UNSUPPORTED_FEATURE); @@ -1807,13 +1814,13 @@ ur_result_t validateCommandDesc( } // Check if new group size is provided. - size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize; + size_t *NewLocalWorkSize = CommandDesc.pNewLocalWorkSize; UR_ASSERT(!NewLocalWorkSize || (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE), UR_RESULT_ERROR_UNSUPPORTED_FEATURE); // Check if new global size is provided and we need to update group count. - size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize; + size_t *NewGlobalWorkSize = CommandDesc.pNewGlobalWorkSize; UR_ASSERT(!NewGlobalWorkSize || (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT), UR_RESULT_ERROR_UNSUPPORTED_FEATURE); @@ -1822,23 +1829,193 @@ ur_result_t validateCommandDesc( UR_RESULT_ERROR_UNSUPPORTED_FEATURE); UR_ASSERT( - (!CommandDesc->numNewMemObjArgs && !CommandDesc->numNewPointerArgs && - !CommandDesc->numNewValueArgs) || + (!CommandDesc.numNewMemObjArgs && !CommandDesc.numNewPointerArgs && + !CommandDesc.numNewValueArgs) || (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS), UR_RESULT_ERROR_UNSUPPORTED_FEATURE); return UR_RESULT_SUCCESS; } +ur_result_t updateKernelHandle(ur_exp_command_buffer_handle_t CommandBuffer, + ur_kernel_handle_t NewKernel, + kernel_command_handle *Command) { + auto Platform = CommandBuffer->Context->getPlatform(); + auto ZeDevice = CommandBuffer->Device->ZeDevice; + ze_kernel_handle_t KernelHandle{}; + ze_kernel_handle_t ZeNewKernel{}; + UR_CALL(getZeKernel(ZeDevice, NewKernel, &ZeNewKernel)); + + ze_command_list_handle_t ZeCommandList = CommandBuffer->ZeComputeCommandList; + KernelHandle = ZeNewKernel; + if (!Platform->ZeMutableCmdListExt.LoaderExtension) { + ZeCommandList = 
CommandBuffer->ZeComputeCommandListTranslated; + ZE2UR_CALL(zelLoaderTranslateHandle, + (ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&KernelHandle)); + } + + ZE2UR_CALL(Platform->ZeMutableCmdListExt + .zexCommandListUpdateMutableCommandKernelsExp, + (ZeCommandList, 1, &Command->CommandId, &KernelHandle)); + // Set current kernel to be the new kernel + Command->Kernel = NewKernel; + return UR_RESULT_SUCCESS; +} + +ur_result_t setMutableOffsetDesc( + std::unique_ptr> &Desc, + uint32_t Dim, size_t *NewGlobalWorkOffset, const void *NextDesc, + uint64_t CommandID) { + Desc->commandId = CommandID; + DEBUG_LOG(Desc->commandId); + Desc->pNext = NextDesc; + DEBUG_LOG(Desc->pNext); + Desc->offsetX = NewGlobalWorkOffset[0]; + DEBUG_LOG(Desc->offsetX); + Desc->offsetY = Dim >= 2 ? NewGlobalWorkOffset[1] : 0; + DEBUG_LOG(Desc->offsetY); + Desc->offsetZ = Dim == 3 ? NewGlobalWorkOffset[2] : 0; + DEBUG_LOG(Desc->offsetZ); + return UR_RESULT_SUCCESS; +} + +ur_result_t setMutableGroupSizeDesc( + std::unique_ptr> &Desc, + uint32_t Dim, uint32_t *NewLocalWorkSize, const void *NextDesc, + uint64_t CommandID) { + Desc->commandId = CommandID; + DEBUG_LOG(Desc->commandId); + Desc->pNext = NextDesc; + DEBUG_LOG(Desc->pNext); + Desc->groupSizeX = NewLocalWorkSize[0]; + DEBUG_LOG(Desc->groupSizeX); + Desc->groupSizeY = Dim >= 2 ? NewLocalWorkSize[1] : 1; + DEBUG_LOG(Desc->groupSizeY); + Desc->groupSizeZ = Dim == 3 ? 
NewLocalWorkSize[2] : 1; + DEBUG_LOG(Desc->groupSizeZ); + return UR_RESULT_SUCCESS; +} + +ur_result_t setMutableGroupCountDesc( + std::unique_ptr> &Desc, + ze_group_count_t *ZeThreadGroupDimensions, const void *NextDesc, + uint64_t CommandID) { + Desc->commandId = CommandID; + DEBUG_LOG(Desc->commandId); + Desc->pNext = NextDesc; + DEBUG_LOG(Desc->pNext); + Desc->pGroupCount = ZeThreadGroupDimensions; + DEBUG_LOG(Desc->pGroupCount->groupCountX); + DEBUG_LOG(Desc->pGroupCount->groupCountY); + DEBUG_LOG(Desc->pGroupCount->groupCountZ); + return UR_RESULT_SUCCESS; +} + +ur_result_t setMutableMemObjArgDesc( + ur_exp_command_buffer_handle_t CommandBuffer, + std::unique_ptr> &Desc, + const ur_exp_command_buffer_update_memobj_arg_desc_t &NewMemObjArgDesc, + const void *NextDesc, uint64_t CommandID) { + + const ur_kernel_arg_mem_obj_properties_t *Properties = + NewMemObjArgDesc.pProperties; + ur_mem_handle_t_::access_mode_t UrAccessMode = ur_mem_handle_t_::read_write; + if (Properties) { + switch (Properties->memoryAccess) { + case UR_MEM_FLAG_READ_WRITE: + UrAccessMode = ur_mem_handle_t_::read_write; + break; + case UR_MEM_FLAG_WRITE_ONLY: + UrAccessMode = ur_mem_handle_t_::write_only; + break; + case UR_MEM_FLAG_READ_ONLY: + UrAccessMode = ur_mem_handle_t_::read_only; + break; + default: + return UR_RESULT_ERROR_INVALID_ARGUMENT; + } + } + + ur_mem_handle_t NewMemObjArg = NewMemObjArgDesc.hNewMemObjArg; + // The NewMemObjArg may be a NULL pointer in which case a NULL value is used + // for the kernel argument declared as a pointer to global or constant + // memory. 
+ char **ZeHandlePtr = nullptr; + if (NewMemObjArg) { + UR_CALL(NewMemObjArg->getZeHandlePtr(ZeHandlePtr, UrAccessMode, + CommandBuffer->Device, nullptr, 0u)); + } + + Desc->commandId = CommandID; + DEBUG_LOG(Desc->commandId); + Desc->pNext = NextDesc; + DEBUG_LOG(Desc->pNext); + Desc->argIndex = NewMemObjArgDesc.argIndex; + DEBUG_LOG(Desc->argIndex); + Desc->argSize = sizeof(void *); + DEBUG_LOG(Desc->argSize); + Desc->pArgValue = ZeHandlePtr; + DEBUG_LOG(Desc->pArgValue); + return UR_RESULT_SUCCESS; +} + +ur_result_t setMutablePointerArgDesc( + std::unique_ptr> &Desc, + const ur_exp_command_buffer_update_pointer_arg_desc_t &NewPointerArgDesc, + const void *NextDesc, uint64_t CommandID) { + Desc->commandId = CommandID; + DEBUG_LOG(Desc->commandId); + Desc->pNext = NextDesc; + DEBUG_LOG(Desc->pNext); + Desc->argIndex = NewPointerArgDesc.argIndex; + DEBUG_LOG(Desc->argIndex); + Desc->argSize = sizeof(void *); + DEBUG_LOG(Desc->argSize); + Desc->pArgValue = NewPointerArgDesc.pNewPointerArg; + DEBUG_LOG(Desc->pArgValue); + return UR_RESULT_SUCCESS; +} + +ur_result_t setMutableValueArgDesc( + std::unique_ptr> &Desc, + const ur_exp_command_buffer_update_value_arg_desc_t &NewValueArgDesc, + const void *NextDesc, uint64_t CommandID) { + Desc->commandId = CommandID; + DEBUG_LOG(Desc->commandId); + Desc->pNext = NextDesc; + DEBUG_LOG(Desc->pNext); + Desc->argIndex = NewValueArgDesc.argIndex; + DEBUG_LOG(Desc->argIndex); + Desc->argSize = NewValueArgDesc.argSize; + DEBUG_LOG(Desc->argSize); + // OpenCL: "the arg_value pointer can be NULL or point to a NULL value + // in which case a NULL value will be used as the value for the argument + // declared as a pointer to global or constant memory in the kernel" + // + // We don't know the type of the argument but it seems that the only time + // SYCL RT would send a pointer to NULL in 'arg_value' is when the argument + // is a NULL pointer. Treat a pointer to NULL in 'arg_value' as a NULL. 
+ const void *ArgValuePtr = NewValueArgDesc.pNewValueArg; + if (NewValueArgDesc.argSize == sizeof(void *) && ArgValuePtr && + *(void **)(const_cast(ArgValuePtr)) == nullptr) { + ArgValuePtr = nullptr; + } + Desc->pArgValue = ArgValuePtr; + DEBUG_LOG(Desc->pArgValue); + return UR_RESULT_SUCCESS; +} + /** * Update the kernel command with the new values. - * @param[in] Command The command which is being updated. - * @param[in] CommandDesc The update command description. + * @param[in] CommandBuffer The command-buffer which is being updated. + * @param[in] NumKernelUpdates Length of \p CommandDescs. + * @param[in] CommandDescs List of update command descriptions. * @return UR_RESULT_SUCCESS or an error code on failure */ -ur_result_t updateKernelCommand( - kernel_command_handle *Command, - const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) { +ur_result_t updateCommandBuffer( + ur_exp_command_buffer_handle_t CommandBuffer, uint32_t NumKernelUpdates, + const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDescs) { + std::scoped_lock Guard(CommandBuffer->Mutex); // We need the created descriptors to live till the point when // zeCommandListUpdateMutableCommandsExp is called at the end of the @@ -1850,257 +2027,168 @@ ur_result_t updateKernelCommand( std::unique_ptr>>> Descs; - const auto CommandBuffer = Command->CommandBuffer; - const void *NextDesc = nullptr; - auto Platform = CommandBuffer->Context->getPlatform(); - auto ZeDevice = CommandBuffer->Device->ZeDevice; - - uint32_t Dim = CommandDesc->newWorkDim; - size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset; - size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize; - size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize; - - // Kernel handle must be updated first for a given CommandId if required - ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel; - - if (NewKernel && Command->Kernel != NewKernel) { - ze_kernel_handle_t KernelHandle{}; - ze_kernel_handle_t
ZeNewKernel{}; - UR_CALL(getZeKernel(ZeDevice, NewKernel, &ZeNewKernel)); - - ze_command_list_handle_t ZeCommandList = - CommandBuffer->ZeComputeCommandList; - KernelHandle = ZeNewKernel; - if (!Platform->ZeMutableCmdListExt.LoaderExtension) { - ZeCommandList = CommandBuffer->ZeComputeCommandListTranslated; - ZE2UR_CALL(zelLoaderTranslateHandle, - (ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&KernelHandle)); + std::vector ZeThreadGroupDimensionsList( + NumKernelUpdates, ze_group_count_t{1, 1, 1}); + const void *NextDesc = nullptr; // Used for pointer chaining + // Iterate over every UR update descriptor struct, which corresponds to + // several L0 update descriptor structs. + for (uint32_t i = 0; i < NumKernelUpdates; i++) { + const auto &CommandDesc = CommandDescs[i]; + auto Command = static_cast(CommandDesc.hCommand); + + std::scoped_lock Guard( + Command->Mutex, Command->Kernel->Mutex); + + // Kernel handle must be updated first for a given CommandId if required + ur_kernel_handle_t NewKernel = CommandDesc.hNewKernel; + if (NewKernel && Command->Kernel != NewKernel) { + updateKernelHandle(CommandBuffer, NewKernel, Command); } - ZE2UR_CALL(Platform->ZeMutableCmdListExt - .zexCommandListUpdateMutableCommandKernelsExp, - (ZeCommandList, 1, &Command->CommandId, &KernelHandle)); - // Set current kernel to be the new kernel - Command->Kernel = NewKernel; - } - - // Check if a new global offset is provided. - if (NewGlobalWorkOffset && Dim > 0) { - auto MutableGroupOffestDesc = - std::make_unique>(); - MutableGroupOffestDesc->commandId = Command->CommandId; - DEBUG_LOG(MutableGroupOffestDesc->commandId); - MutableGroupOffestDesc->pNext = NextDesc; - DEBUG_LOG(MutableGroupOffestDesc->pNext); - MutableGroupOffestDesc->offsetX = NewGlobalWorkOffset[0]; - DEBUG_LOG(MutableGroupOffestDesc->offsetX); - MutableGroupOffestDesc->offsetY = Dim >= 2 ? NewGlobalWorkOffset[1] : 0; - DEBUG_LOG(MutableGroupOffestDesc->offsetY); - MutableGroupOffestDesc->offsetZ = Dim == 3 ? 
NewGlobalWorkOffset[2] : 0; - DEBUG_LOG(MutableGroupOffestDesc->offsetZ); - - NextDesc = MutableGroupOffestDesc.get(); - Descs.push_back(std::move(MutableGroupOffestDesc)); - } - - // Check if a new group size is provided. - if (NewLocalWorkSize && Dim > 0) { - auto MutableGroupSizeDesc = - std::make_unique>(); - MutableGroupSizeDesc->commandId = Command->CommandId; - DEBUG_LOG(MutableGroupSizeDesc->commandId); - MutableGroupSizeDesc->pNext = NextDesc; - DEBUG_LOG(MutableGroupSizeDesc->pNext); - MutableGroupSizeDesc->groupSizeX = NewLocalWorkSize[0]; - DEBUG_LOG(MutableGroupSizeDesc->groupSizeX); - MutableGroupSizeDesc->groupSizeY = Dim >= 2 ? NewLocalWorkSize[1] : 1; - DEBUG_LOG(MutableGroupSizeDesc->groupSizeY); - MutableGroupSizeDesc->groupSizeZ = Dim == 3 ? NewLocalWorkSize[2] : 1; - DEBUG_LOG(MutableGroupSizeDesc->groupSizeZ); - - NextDesc = MutableGroupSizeDesc.get(); - Descs.push_back(std::move(MutableGroupSizeDesc)); - } - - // Check if a new global or local size is provided and if so we need to update - // the group count. - ze_group_count_t ZeThreadGroupDimensions{1, 1, 1}; - if ((NewGlobalWorkSize || NewLocalWorkSize) && Dim > 0) { - // If a new global work size is provided update that in the command, - // otherwise the previous work group size will be used - if (NewGlobalWorkSize) { - Command->WorkDim = Dim; - Command->setGlobalWorkSize(NewGlobalWorkSize); + uint32_t Dim = CommandDesc.newWorkDim; + // Update global offset if provided. 
+ if (size_t *NewGlobalWorkOffset = CommandDesc.pNewGlobalWorkOffset; + NewGlobalWorkOffset && Dim > 0) { + auto MutableGroupOffestDesc = + std::make_unique>(); + UR_CALL(setMutableOffsetDesc(MutableGroupOffestDesc, Dim, + NewGlobalWorkOffset, NextDesc, + Command->CommandId)); + NextDesc = MutableGroupOffestDesc.get(); + Descs.push_back(std::move(MutableGroupOffestDesc)); } - // If a new global work size is provided but a new local work size is not - // then we still need to update local work size based on the size suggested - // by the driver for the kernel. - bool UpdateWGSize = NewLocalWorkSize == nullptr; - - ze_kernel_handle_t ZeKernel{}; - UR_CALL(getZeKernel(ZeDevice, Command->Kernel, &ZeKernel)); - - uint32_t WG[3]; - UR_CALL(calculateKernelWorkDimensions( - ZeKernel, CommandBuffer->Device, ZeThreadGroupDimensions, WG, Dim, - Command->GlobalWorkSize, NewLocalWorkSize)); - - auto MutableGroupCountDesc = - std::make_unique>(); - MutableGroupCountDesc->commandId = Command->CommandId; - DEBUG_LOG(MutableGroupCountDesc->commandId); - MutableGroupCountDesc->pNext = NextDesc; - DEBUG_LOG(MutableGroupCountDesc->pNext); - MutableGroupCountDesc->pGroupCount = &ZeThreadGroupDimensions; - DEBUG_LOG(MutableGroupCountDesc->pGroupCount->groupCountX); - DEBUG_LOG(MutableGroupCountDesc->pGroupCount->groupCountY); - DEBUG_LOG(MutableGroupCountDesc->pGroupCount->groupCountZ); - - NextDesc = MutableGroupCountDesc.get(); - Descs.push_back(std::move(MutableGroupCountDesc)); - - if (UpdateWGSize) { + + // Update local-size/group-size if provided. 
+ size_t *NewLocalWorkSize = CommandDesc.pNewLocalWorkSize; + if (NewLocalWorkSize && Dim > 0) { auto MutableGroupSizeDesc = std::make_unique>(); - MutableGroupSizeDesc->commandId = Command->CommandId; - DEBUG_LOG(MutableGroupSizeDesc->commandId); - MutableGroupSizeDesc->pNext = NextDesc; - DEBUG_LOG(MutableGroupSizeDesc->pNext); - MutableGroupSizeDesc->groupSizeX = WG[0]; - DEBUG_LOG(MutableGroupSizeDesc->groupSizeX); - MutableGroupSizeDesc->groupSizeY = WG[1]; - DEBUG_LOG(MutableGroupSizeDesc->groupSizeY); - MutableGroupSizeDesc->groupSizeZ = WG[2]; - DEBUG_LOG(MutableGroupSizeDesc->groupSizeZ); + uint32_t WG[3] = {1, 1, 1}; + for (size_t d = 0; d < Dim; d++) { + WG[d] = NewLocalWorkSize[d]; + } + + UR_CALL(setMutableGroupSizeDesc(MutableGroupSizeDesc, Dim, WG, NextDesc, + Command->CommandId)); NextDesc = MutableGroupSizeDesc.get(); Descs.push_back(std::move(MutableGroupSizeDesc)); } - } - // Check if new memory object arguments are provided. - for (uint32_t NewMemObjArgNum = CommandDesc->numNewMemObjArgs; - NewMemObjArgNum-- > 0;) { - ur_exp_command_buffer_update_memobj_arg_desc_t NewMemObjArgDesc = - CommandDesc->pNewMemObjArgList[NewMemObjArgNum]; - const ur_kernel_arg_mem_obj_properties_t *Properties = - NewMemObjArgDesc.pProperties; - ur_mem_handle_t_::access_mode_t UrAccessMode = ur_mem_handle_t_::read_write; - if (Properties) { - switch (Properties->memoryAccess) { - case UR_MEM_FLAG_READ_WRITE: - UrAccessMode = ur_mem_handle_t_::read_write; - break; - case UR_MEM_FLAG_WRITE_ONLY: - UrAccessMode = ur_mem_handle_t_::write_only; - break; - case UR_MEM_FLAG_READ_ONLY: - UrAccessMode = ur_mem_handle_t_::read_only; - break; - default: - return UR_RESULT_ERROR_INVALID_ARGUMENT; + // Update global-size/group-count if provided, and also + // local-size/group-size if required + if (size_t *NewGlobalWorkSize = CommandDesc.pNewGlobalWorkSize; + (NewGlobalWorkSize || NewLocalWorkSize) && Dim > 0) { + + // If a new global work size is provided update that in the 
command, + // otherwise the previous work group size will be used + if (NewGlobalWorkSize) { + Command->WorkDim = Dim; + Command->setGlobalWorkSize(NewGlobalWorkSize); + } + + // If a new global work size is provided but a new local work size is not + // then we still need to update local work size based on the size + // suggested + // by the driver for the kernel. + bool UpdateWGSize = NewLocalWorkSize == nullptr; + + ze_kernel_handle_t ZeKernel{}; + auto ZeDevice = CommandBuffer->Device->ZeDevice; + UR_CALL(getZeKernel(ZeDevice, Command->Kernel, &ZeKernel)); + + uint32_t WG[3]; + + ze_group_count_t &ZeThreadGroupDimensions = + ZeThreadGroupDimensionsList[i]; + UR_CALL(calculateKernelWorkDimensions( + ZeKernel, CommandBuffer->Device, ZeThreadGroupDimensions, WG, Dim, + Command->GlobalWorkSize, NewLocalWorkSize)); + + auto MutableGroupCountDesc = + std::make_unique>(); + UR_CALL(setMutableGroupCountDesc(MutableGroupCountDesc, + &ZeThreadGroupDimensions, NextDesc, + Command->CommandId)); + NextDesc = MutableGroupCountDesc.get(); + Descs.push_back(std::move(MutableGroupCountDesc)); + + if (UpdateWGSize) { + auto MutableGroupSizeDesc = + std::make_unique>(); + UR_CALL(setMutableGroupSizeDesc(MutableGroupSizeDesc, Dim, WG, NextDesc, + Command->CommandId)); + NextDesc = MutableGroupSizeDesc.get(); + Descs.push_back(std::move(MutableGroupSizeDesc)); } } - ur_mem_handle_t NewMemObjArg = NewMemObjArgDesc.hNewMemObjArg; - // The NewMemObjArg may be a NULL pointer in which case a NULL value is used - // for the kernel argument declared as a pointer to global or constant - // memory. - char **ZeHandlePtr = nullptr; - if (NewMemObjArg) { - UR_CALL(NewMemObjArg->getZeHandlePtr(ZeHandlePtr, UrAccessMode, - CommandBuffer->Device, nullptr, 0u)); + // Update memory object arguments if provided. 
+ for (uint32_t NewMemObjArgNum = CommandDesc.numNewMemObjArgs; + NewMemObjArgNum-- > 0;) { + ur_exp_command_buffer_update_memobj_arg_desc_t NewMemObjArgDesc = + CommandDesc.pNewMemObjArgList[NewMemObjArgNum]; + + auto ZeMutableArgDesc = + std::make_unique>(); + + UR_CALL(setMutableMemObjArgDesc(CommandBuffer, ZeMutableArgDesc, + NewMemObjArgDesc, NextDesc, + Command->CommandId)); + + NextDesc = ZeMutableArgDesc.get(); + Descs.push_back(std::move(ZeMutableArgDesc)); } - auto ZeMutableArgDesc = - std::make_unique>(); - ZeMutableArgDesc->commandId = Command->CommandId; - DEBUG_LOG(ZeMutableArgDesc->commandId); - ZeMutableArgDesc->pNext = NextDesc; - DEBUG_LOG(ZeMutableArgDesc->pNext); - ZeMutableArgDesc->argIndex = NewMemObjArgDesc.argIndex; - DEBUG_LOG(ZeMutableArgDesc->argIndex); - ZeMutableArgDesc->argSize = sizeof(void *); - DEBUG_LOG(ZeMutableArgDesc->argSize); - ZeMutableArgDesc->pArgValue = ZeHandlePtr; - DEBUG_LOG(ZeMutableArgDesc->pArgValue); - - NextDesc = ZeMutableArgDesc.get(); - Descs.push_back(std::move(ZeMutableArgDesc)); - } - - // Check if there are new pointer arguments. - for (uint32_t NewPointerArgNum = CommandDesc->numNewPointerArgs; - NewPointerArgNum-- > 0;) { - ur_exp_command_buffer_update_pointer_arg_desc_t NewPointerArgDesc = - CommandDesc->pNewPointerArgList[NewPointerArgNum]; - - auto ZeMutableArgDesc = - std::make_unique>(); - ZeMutableArgDesc->commandId = Command->CommandId; - DEBUG_LOG(ZeMutableArgDesc->commandId); - ZeMutableArgDesc->pNext = NextDesc; - DEBUG_LOG(ZeMutableArgDesc->pNext); - ZeMutableArgDesc->argIndex = NewPointerArgDesc.argIndex; - DEBUG_LOG(ZeMutableArgDesc->argIndex); - ZeMutableArgDesc->argSize = sizeof(void *); - DEBUG_LOG(ZeMutableArgDesc->argSize); - ZeMutableArgDesc->pArgValue = NewPointerArgDesc.pNewPointerArg; - DEBUG_LOG(ZeMutableArgDesc->pArgValue); - - NextDesc = ZeMutableArgDesc.get(); - Descs.push_back(std::move(ZeMutableArgDesc)); - } - - // Check if there are new value arguments. 
- for (uint32_t NewValueArgNum = CommandDesc->numNewValueArgs; - NewValueArgNum-- > 0;) { - ur_exp_command_buffer_update_value_arg_desc_t NewValueArgDesc = - CommandDesc->pNewValueArgList[NewValueArgNum]; - - auto ZeMutableArgDesc = - std::make_unique>(); - ZeMutableArgDesc->commandId = Command->CommandId; - DEBUG_LOG(ZeMutableArgDesc->commandId); - ZeMutableArgDesc->pNext = NextDesc; - DEBUG_LOG(ZeMutableArgDesc->pNext); - ZeMutableArgDesc->argIndex = NewValueArgDesc.argIndex; - DEBUG_LOG(ZeMutableArgDesc->argIndex); - ZeMutableArgDesc->argSize = NewValueArgDesc.argSize; - DEBUG_LOG(ZeMutableArgDesc->argSize); - // OpenCL: "the arg_value pointer can be NULL or point to a NULL value - // in which case a NULL value will be used as the value for the argument - // declared as a pointer to global or constant memory in the kernel" - // - // We don't know the type of the argument but it seems that the only time - // SYCL RT would send a pointer to NULL in 'arg_value' is when the argument - // is a NULL pointer. Treat a pointer to NULL in 'arg_value' as a NULL. - const void *ArgValuePtr = NewValueArgDesc.pNewValueArg; - if (NewValueArgDesc.argSize == sizeof(void *) && ArgValuePtr && - *(void **)(const_cast(ArgValuePtr)) == nullptr) { - ArgValuePtr = nullptr; + // Update pointer arguments if provided. + for (uint32_t NewPointerArgNum = CommandDesc.numNewPointerArgs; + NewPointerArgNum-- > 0;) { + ur_exp_command_buffer_update_pointer_arg_desc_t NewPointerArgDesc = + CommandDesc.pNewPointerArgList[NewPointerArgNum]; + + auto ZeMutableArgDesc = + std::make_unique>(); + + UR_CALL(setMutablePointerArgDesc(ZeMutableArgDesc, NewPointerArgDesc, + NextDesc, Command->CommandId)); + + NextDesc = ZeMutableArgDesc.get(); + Descs.push_back(std::move(ZeMutableArgDesc)); } - ZeMutableArgDesc->pArgValue = ArgValuePtr; - DEBUG_LOG(ZeMutableArgDesc->pArgValue); - NextDesc = ZeMutableArgDesc.get(); - Descs.push_back(std::move(ZeMutableArgDesc)); - } + // Update value arguments if provided. 
+ for (uint32_t NewValueArgNum = CommandDesc.numNewValueArgs; + NewValueArgNum-- > 0;) { + ur_exp_command_buffer_update_value_arg_desc_t NewValueArgDesc = + CommandDesc.pNewValueArgList[NewValueArgNum]; - ZeStruct MutableCommandDesc; - MutableCommandDesc.pNext = NextDesc; - MutableCommandDesc.flags = 0; + auto ZeMutableArgDesc = + std::make_unique>(); + + UR_CALL(setMutableValueArgDesc(ZeMutableArgDesc, NewValueArgDesc, + NextDesc, Command->CommandId)); + + NextDesc = ZeMutableArgDesc.get(); + Descs.push_back(std::move(ZeMutableArgDesc)); + } + } + auto Platform = CommandBuffer->Context->getPlatform(); ze_command_list_handle_t ZeCommandList = CommandBuffer->ZeComputeCommandListTranslated; if (Platform->ZeMutableCmdListExt.LoaderExtension) { ZeCommandList = CommandBuffer->ZeComputeCommandList; } + ZeStruct MutableCommandDesc{}; + MutableCommandDesc.pNext = NextDesc; + MutableCommandDesc.flags = 0; ZE2UR_CALL( Platform->ZeMutableCmdListExt.zexCommandListUpdateMutableCommandsExp, (ZeCommandList, &MutableCommandDesc)); + ZE2UR_CALL(zeCommandListClose, (CommandBuffer->ZeComputeCommandList)); + return UR_RESULT_SUCCESS; } @@ -2135,30 +2223,20 @@ waitForOngoingExecution(ur_exp_command_buffer_handle_t CommandBuffer) { return UR_RESULT_SUCCESS; } +} // namespace + ur_result_t urCommandBufferUpdateKernelLaunchExp( - ur_exp_command_buffer_command_handle_t Command, + ur_exp_command_buffer_handle_t CommandBuffer, uint32_t numKernelUpdates, const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) { - UR_ASSERT(Command->CommandBuffer->IsUpdatable, - UR_RESULT_ERROR_INVALID_OPERATION); - - auto KernelCommandHandle = static_cast(Command); - - UR_ASSERT(KernelCommandHandle->Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE); - - // Lock command, kernel and command-buffer for update. 
- std::scoped_lock Guard( - Command->Mutex, Command->CommandBuffer->Mutex, - KernelCommandHandle->Kernel->Mutex); - - UR_ASSERT(Command->CommandBuffer->IsFinalized, + UR_ASSERT(CommandBuffer->IsUpdatable && CommandBuffer->IsFinalized, UR_RESULT_ERROR_INVALID_OPERATION); + for (uint32_t i = 0; i < numKernelUpdates; i++) { + UR_CALL(validateCommandDesc(CommandBuffer, CommandDesc[i])); + } - UR_CALL(validateCommandDesc(KernelCommandHandle, CommandDesc)); - UR_CALL(waitForOngoingExecution(Command->CommandBuffer)); - UR_CALL(updateKernelCommand(KernelCommandHandle, CommandDesc)); + UR_CALL(waitForOngoingExecution(CommandBuffer)); - ZE2UR_CALL(zeCommandListClose, - (Command->CommandBuffer->ZeComputeCommandList)); + UR_CALL(updateCommandBuffer(CommandBuffer, numKernelUpdates, CommandDesc)); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/level_zero/ur_interface_loader.hpp b/unified-runtime/source/adapters/level_zero/ur_interface_loader.hpp index 8803b86b07847..518204b6525f5 100644 --- a/unified-runtime/source/adapters/level_zero/ur_interface_loader.hpp +++ b/unified-runtime/source/adapters/level_zero/ur_interface_loader.hpp @@ -664,7 +664,7 @@ ur_result_t urCommandBufferEnqueueExp( uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent); ur_result_t urCommandBufferUpdateKernelLaunchExp( - ur_exp_command_buffer_command_handle_t hCommand, + ur_exp_command_buffer_handle_t hCommandBuffer, uint32_t numKernelUpdates, const ur_exp_command_buffer_update_kernel_launch_desc_t *pUpdateKernelLaunch); ur_result_t urCommandBufferUpdateSignalEventExp( diff --git a/unified-runtime/source/adapters/level_zero/v2/api.cpp b/unified-runtime/source/adapters/level_zero/v2/api.cpp index 129db02594b5e..b46e2c0c6fa4c 100644 --- a/unified-runtime/source/adapters/level_zero/v2/api.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/api.cpp @@ -171,7 +171,7 @@ ur_result_t urBindlessImagesReleaseExternalSemaphoreExp( } 
ur_result_t urCommandBufferUpdateKernelLaunchExp( - ur_exp_command_buffer_command_handle_t hCommand, + ur_exp_command_buffer_handle_t hCommandBuffer, uint32_t numKernelUpdates, const ur_exp_command_buffer_update_kernel_launch_desc_t *pUpdateKernelLaunch) { logger::error("{} function not implemented!", __FUNCTION__); diff --git a/unified-runtime/source/adapters/mock/ur_mockddi.cpp b/unified-runtime/source/adapters/mock/ur_mockddi.cpp index 9a4ace593e430..9e55370df3ab1 100644 --- a/unified-runtime/source/adapters/mock/ur_mockddi.cpp +++ b/unified-runtime/source/adapters/mock/ur_mockddi.cpp @@ -9820,15 +9820,18 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urCommandBufferUpdateKernelLaunchExp __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( - /// [in] Handle of the command-buffer kernel command to update. - ur_exp_command_buffer_command_handle_t hCommand, - /// [in] Struct defining how the kernel command is to be updated. + /// [in] Handle of the command-buffer object. + ur_exp_command_buffer_handle_t hCommandBuffer, + /// [in] Length of pUpdateKernelLaunch. + uint32_t numKernelUpdates, + /// [in][range(0, numKernelUpdates)] List of structs defining how a + /// kernel commands are to be updated. 
const ur_exp_command_buffer_update_kernel_launch_desc_t *pUpdateKernelLaunch) try { ur_result_t result = UR_RESULT_SUCCESS; ur_command_buffer_update_kernel_launch_exp_params_t params = { - &hCommand, &pUpdateKernelLaunch}; + &hCommandBuffer, &numKernelUpdates, &pUpdateKernelLaunch}; auto beforeCallback = reinterpret_cast( mock::getCallbacks().get_before_callback( diff --git a/unified-runtime/source/adapters/native_cpu/command_buffer.cpp b/unified-runtime/source/adapters/native_cpu/command_buffer.cpp index b49974b3c06ae..135b0b8de102c 100644 --- a/unified-runtime/source/adapters/native_cpu/command_buffer.cpp +++ b/unified-runtime/source/adapters/native_cpu/command_buffer.cpp @@ -176,7 +176,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp( } UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( - ur_exp_command_buffer_command_handle_t, + ur_exp_command_buffer_handle_t, uint32_t, const ur_exp_command_buffer_update_kernel_launch_desc_t *) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } diff --git a/unified-runtime/source/adapters/opencl/command_buffer.cpp b/unified-runtime/source/adapters/opencl/command_buffer.cpp index 3707dd4e5d3d0..6d0d284f5318b 100644 --- a/unified-runtime/source/adapters/opencl/command_buffer.cpp +++ b/unified-runtime/source/adapters/opencl/command_buffer.cpp @@ -470,14 +470,14 @@ namespace { void updateKernelPointerArgs( std::vector &CLUSMArgs, const ur_exp_command_buffer_update_kernel_launch_desc_t - *pUpdateKernelLaunch) { + &pUpdateKernelLaunch) { // WARNING - This relies on USM and SVM using the same implementation, // which is not guaranteed. 
// See https://github.com/KhronosGroup/OpenCL-Docs/issues/843 - const uint32_t NumPointerArgs = pUpdateKernelLaunch->numNewPointerArgs; + const uint32_t NumPointerArgs = pUpdateKernelLaunch.numNewPointerArgs; const ur_exp_command_buffer_update_pointer_arg_desc_t *ArgPointerList = - pUpdateKernelLaunch->pNewPointerArgList; + pUpdateKernelLaunch.pNewPointerArgList; CLUSMArgs.resize(NumPointerArgs); for (uint32_t i = 0; i < NumPointerArgs; i++) { @@ -491,13 +491,13 @@ void updateKernelPointerArgs( void updateKernelArgs(std::vector &CLArgs, const ur_exp_command_buffer_update_kernel_launch_desc_t - *pUpdateKernelLaunch) { - const uint32_t NumMemobjArgs = pUpdateKernelLaunch->numNewMemObjArgs; + &pUpdateKernelLaunch) { + const uint32_t NumMemobjArgs = pUpdateKernelLaunch.numNewMemObjArgs; const ur_exp_command_buffer_update_memobj_arg_desc_t *ArgMemobjList = - pUpdateKernelLaunch->pNewMemObjArgList; - const uint32_t NumValueArgs = pUpdateKernelLaunch->numNewValueArgs; + pUpdateKernelLaunch.pNewMemObjArgList; + const uint32_t NumValueArgs = pUpdateKernelLaunch.numNewValueArgs; const ur_exp_command_buffer_update_value_arg_desc_t *ArgValueList = - pUpdateKernelLaunch->pNewValueArgList; + pUpdateKernelLaunch.pNewValueArgList; for (uint32_t i = 0; i < NumMemobjArgs; i++) { const ur_exp_command_buffer_update_memobj_arg_desc_t &URMemObjArg = @@ -525,43 +525,52 @@ void updateKernelArgs(std::vector &CLArgs, } ur_result_t validateCommandDesc( - ur_exp_command_buffer_command_handle_t Command, - const ur_exp_command_buffer_update_kernel_launch_desc_t *UpdateDesc) { + ur_exp_command_buffer_handle_t CommandBuffer, + const ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) { + if (!CommandBuffer->IsFinalized || !CommandBuffer->IsUpdatable) { + return UR_RESULT_ERROR_INVALID_OPERATION; + } + + auto Command = UpdateDesc.hCommand; + if (CommandBuffer != Command->hCommandBuffer) { + return UR_RESULT_ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP; + } + // Kernel handle updates 
are not yet supported. - if (UpdateDesc->hNewKernel && UpdateDesc->hNewKernel != Command->Kernel) { + if (UpdateDesc.hNewKernel && UpdateDesc.hNewKernel != Command->Kernel) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } // Error if work-dim has changed but a new global size/offset hasn't been set - if (UpdateDesc->newWorkDim != Command->WorkDim && - (!UpdateDesc->pNewGlobalWorkOffset || !UpdateDesc->pNewGlobalWorkSize)) { - return UR_RESULT_ERROR_INVALID_OPERATION; + if (UpdateDesc.newWorkDim != Command->WorkDim && + (!UpdateDesc.pNewGlobalWorkOffset || !UpdateDesc.pNewGlobalWorkSize)) { + return UR_RESULT_ERROR_INVALID_VALUE; } // Verify that the device supports updating the aspects of the kernel that // the user is requesting. - ur_device_handle_t URDevice = Command->hCommandBuffer->hDevice; + ur_device_handle_t URDevice = CommandBuffer->hDevice; cl_device_id CLDevice = cl_adapter::cast(URDevice); ur_device_command_buffer_update_capability_flags_t UpdateCapabilities = 0; CL_RETURN_ON_FAILURE( getDeviceCommandBufferUpdateCapabilities(CLDevice, UpdateCapabilities)); - size_t *NewGlobalWorkOffset = UpdateDesc->pNewGlobalWorkOffset; + size_t *NewGlobalWorkOffset = UpdateDesc.pNewGlobalWorkOffset; UR_ASSERT( !NewGlobalWorkOffset || (UpdateCapabilities & UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_OFFSET), UR_RESULT_ERROR_UNSUPPORTED_FEATURE); - size_t *NewLocalWorkSize = UpdateDesc->pNewLocalWorkSize; + size_t *NewLocalWorkSize = UpdateDesc.pNewLocalWorkSize; UR_ASSERT( !NewLocalWorkSize || (UpdateCapabilities & UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_LOCAL_WORK_SIZE), UR_RESULT_ERROR_UNSUPPORTED_FEATURE); - size_t *NewGlobalWorkSize = UpdateDesc->pNewGlobalWorkSize; + size_t *NewGlobalWorkSize = UpdateDesc.pNewGlobalWorkSize; UR_ASSERT( !NewGlobalWorkSize || (UpdateCapabilities & @@ -574,8 +583,8 @@ ur_result_t validateCommandDesc( UR_RESULT_ERROR_UNSUPPORTED_FEATURE); UR_ASSERT( - (!UpdateDesc->numNewMemObjArgs && 
!UpdateDesc->numNewPointerArgs && - !UpdateDesc->numNewValueArgs) || + (!UpdateDesc.numNewMemObjArgs && !UpdateDesc.numNewPointerArgs && + !UpdateDesc.numNewValueArgs) || (UpdateCapabilities & UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_ARGUMENTS), UR_RESULT_ERROR_UNSUPPORTED_FEATURE); @@ -585,78 +594,97 @@ ur_result_t validateCommandDesc( } // end anonymous namespace UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( - ur_exp_command_buffer_command_handle_t hCommand, + ur_exp_command_buffer_handle_t hCommandBuffer, uint32_t numKernelUpdates, const ur_exp_command_buffer_update_kernel_launch_desc_t *pUpdateKernelLaunch) { + for (uint32_t i = 0; i < numKernelUpdates; i++) { + UR_RETURN_ON_FAILURE( + validateCommandDesc(hCommandBuffer, pUpdateKernelLaunch[i])); + } - UR_RETURN_ON_FAILURE(validateCommandDesc(hCommand, pUpdateKernelLaunch)); - - ur_exp_command_buffer_handle_t hCommandBuffer = hCommand->hCommandBuffer; cl_context CLContext = cl_adapter::cast(hCommandBuffer->hContext); - cl_ext::clUpdateMutableCommandsKHR_fn clUpdateMutableCommandsKHR = nullptr; UR_RETURN_ON_FAILURE( cl_ext::getExtFuncFromContext( CLContext, cl_ext::ExtFuncPtrCache->clUpdateMutableCommandsKHRCache, cl_ext::UpdateMutableCommandsName, &clUpdateMutableCommandsKHR)); - if (!hCommandBuffer->IsFinalized || !hCommandBuffer->IsUpdatable) - return UR_RESULT_ERROR_INVALID_OPERATION; - - // Find the CL USM pointer arguments to the kernel to update - std::vector CLUSMArgs; - updateKernelPointerArgs(CLUSMArgs, pUpdateKernelLaunch); - - // Find the memory object and scalar arguments to the kernel to update - std::vector CLArgs; - - updateKernelArgs(CLArgs, pUpdateKernelLaunch); + std::vector ConfigList(numKernelUpdates); + std::vector> CLUSMArgsList( + numKernelUpdates); + std::vector> CLArgsList( + numKernelUpdates); - // Find the updated ND-Range configuration of the kernel. 
- std::vector CLGlobalWorkOffset, CLGlobalWorkSize, CLLocalWorkSize; - cl_uint &CommandWorkDim = hCommand->WorkDim; + std::vector> CLGlobalWorkOffsetList(numKernelUpdates); + std::vector> CLGlobalWorkSizeList(numKernelUpdates); + std::vector> CLLocalWorkSizeList(numKernelUpdates); // Lambda for N-Dimensional update - auto updateNDRange = [CommandWorkDim](std::vector &NDRange, - size_t *UpdatePtr) { - NDRange.resize(CommandWorkDim, 0); - const size_t CopySize = sizeof(size_t) * CommandWorkDim; + auto updateNDRange = [](std::vector &NDRange, cl_uint WorkDim, + size_t *UpdatePtr) { + NDRange.resize(WorkDim, 0); + const size_t CopySize = sizeof(size_t) * WorkDim; std::memcpy(NDRange.data(), UpdatePtr, CopySize); }; - if (auto GlobalWorkOffsetPtr = pUpdateKernelLaunch->pNewGlobalWorkOffset) { - updateNDRange(CLGlobalWorkOffset, GlobalWorkOffsetPtr); - } + for (uint32_t i = 0; i < numKernelUpdates; i++) { + cl_mutable_dispatch_config_khr &Config = ConfigList[i]; + std::vector &CLUSMArgs = CLUSMArgsList[i]; + std::vector &CLArgs = CLArgsList[i]; + std::vector &CLGlobalWorkOffset = CLGlobalWorkOffsetList[i]; + std::vector &CLGlobalWorkSize = CLGlobalWorkSizeList[i]; + std::vector &CLLocalWorkSize = CLLocalWorkSizeList[i]; - if (auto GlobalWorkSizePtr = pUpdateKernelLaunch->pNewGlobalWorkSize) { - updateNDRange(CLGlobalWorkSize, GlobalWorkSizePtr); - } + const auto &UpdateDesc = pUpdateKernelLaunch[i]; + // Find the CL USM pointer arguments to the kernel to update + updateKernelPointerArgs(CLUSMArgs, UpdateDesc); + + // Find the memory object and scalar arguments to the kernel to update + updateKernelArgs(CLArgs, UpdateDesc); + + // Find the updated ND-Range configuration of the kernel. 
+ auto Command = UpdateDesc.hCommand; + cl_uint &CommandWorkDim = Command->WorkDim; - if (auto LocalWorkSizePtr = pUpdateKernelLaunch->pNewLocalWorkSize) { - updateNDRange(CLLocalWorkSize, LocalWorkSizePtr); + if (auto GlobalWorkOffsetPtr = UpdateDesc.pNewGlobalWorkOffset) { + updateNDRange(CLGlobalWorkOffset, CommandWorkDim, GlobalWorkOffsetPtr); + } + + if (auto GlobalWorkSizePtr = UpdateDesc.pNewGlobalWorkSize) { + updateNDRange(CLGlobalWorkSize, CommandWorkDim, GlobalWorkSizePtr); + } + + if (auto LocalWorkSizePtr = UpdateDesc.pNewLocalWorkSize) { + updateNDRange(CLLocalWorkSize, CommandWorkDim, LocalWorkSizePtr); + } + + cl_mutable_command_khr CLCommand = + cl_adapter::cast(Command->CLMutableCommand); + Config = cl_mutable_dispatch_config_khr{ + CLCommand, + static_cast(CLArgs.size()), // num_args + static_cast(CLUSMArgs.size()), // num_svm_args + 0, // num_exec_infos + CommandWorkDim, // work_dim + CLArgs.data(), // arg_list + CLUSMArgs.data(), // arg_svm_list + nullptr, // exec_info_list + CLGlobalWorkOffset.data(), // global_work_offset + CLGlobalWorkSize.data(), // global_work_size + CLLocalWorkSize.data(), // local_work_size + }; } - cl_mutable_command_khr command = - cl_adapter::cast(hCommand->CLMutableCommand); - cl_mutable_dispatch_config_khr dispatch_config = { - command, - static_cast(CLArgs.size()), // num_args - static_cast(CLUSMArgs.size()), // num_svm_args - 0, // num_exec_infos - CommandWorkDim, // work_dim - CLArgs.data(), // arg_list - CLUSMArgs.data(), // arg_svm_list - nullptr, // exec_info_list - CLGlobalWorkOffset.data(), // global_work_offset - CLGlobalWorkSize.data(), // global_work_size - CLLocalWorkSize.data(), // local_work_size - }; - cl_uint num_configs = 1; - cl_command_buffer_update_type_khr config_types[1] = { - CL_STRUCTURE_TYPE_MUTABLE_DISPATCH_CONFIG_KHR}; - const void *configs[1] = {&dispatch_config}; + cl_uint NumConfigs = ConfigList.size(); + std::vector ConfigTypes( + NumConfigs, 
CL_STRUCTURE_TYPE_MUTABLE_DISPATCH_CONFIG_KHR); + std::vector ConfigPtrs(NumConfigs); + for (cl_uint i = 0; i < NumConfigs; i++) { + ConfigPtrs[i] = &ConfigList[i]; + } CL_RETURN_ON_FAILURE(clUpdateMutableCommandsKHR( - hCommandBuffer->CLCommandBuffer, num_configs, config_types, configs)); + hCommandBuffer->CLCommandBuffer, NumConfigs, ConfigTypes.data(), + (const void **)ConfigPtrs.data())); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp b/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp index e716eaaa49372..dc49864990f0e 100644 --- a/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp +++ b/unified-runtime/source/loader/layers/tracing/ur_trcddi.cpp @@ -8189,9 +8189,12 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urCommandBufferUpdateKernelLaunchExp __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( - /// [in] Handle of the command-buffer kernel command to update. - ur_exp_command_buffer_command_handle_t hCommand, - /// [in] Struct defining how the kernel command is to be updated. + /// [in] Handle of the command-buffer object. + ur_exp_command_buffer_handle_t hCommandBuffer, + /// [in] Length of pUpdateKernelLaunch. + uint32_t numKernelUpdates, + /// [in][range(0, numKernelUpdates)] List of structs defining how a + /// kernel commands are to be updated. 
const ur_exp_command_buffer_update_kernel_launch_desc_t *pUpdateKernelLaunch) { auto pfnUpdateKernelLaunchExp = @@ -8201,7 +8204,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; ur_command_buffer_update_kernel_launch_exp_params_t params = { - &hCommand, &pUpdateKernelLaunch}; + &hCommandBuffer, &numKernelUpdates, &pUpdateKernelLaunch}; uint64_t instance = getContext()->notify_begin( UR_FUNCTION_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_EXP, "urCommandBufferUpdateKernelLaunchExp", ¶ms); @@ -8209,7 +8212,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( auto &logger = getContext()->logger; logger.info(" ---> urCommandBufferUpdateKernelLaunchExp\n"); - ur_result_t result = pfnUpdateKernelLaunchExp(hCommand, pUpdateKernelLaunch); + ur_result_t result = pfnUpdateKernelLaunchExp( + hCommandBuffer, numKernelUpdates, pUpdateKernelLaunch); getContext()->notify_end(UR_FUNCTION_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_EXP, "urCommandBufferUpdateKernelLaunchExp", ¶ms, diff --git a/unified-runtime/source/loader/layers/validation/ur_valddi.cpp b/unified-runtime/source/loader/layers/validation/ur_valddi.cpp index acb1f65c94ce5..fa49edf0ddaf1 100644 --- a/unified-runtime/source/loader/layers/validation/ur_valddi.cpp +++ b/unified-runtime/source/loader/layers/validation/ur_valddi.cpp @@ -8844,9 +8844,12 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urCommandBufferUpdateKernelLaunchExp __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( - /// [in] Handle of the command-buffer kernel command to update. - ur_exp_command_buffer_command_handle_t hCommand, - /// [in] Struct defining how the kernel command is to be updated. + /// [in] Handle of the command-buffer object. 
+ ur_exp_command_buffer_handle_t hCommandBuffer, + /// [in] Length of pUpdateKernelLaunch. + uint32_t numKernelUpdates, + /// [in][range(0, numKernelUpdates)] List of structs defining how a + /// kernel commands are to be updated. const ur_exp_command_buffer_update_kernel_launch_desc_t *pUpdateKernelLaunch) { auto pfnUpdateKernelLaunchExp = @@ -8857,18 +8860,21 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( } if (getContext()->enableParameterValidation) { - if (NULL == hCommand) + if (NULL == hCommandBuffer) + return UR_RESULT_ERROR_INVALID_NULL_HANDLE; + + if (NULL == pUpdateKernelLaunch->hCommand) return UR_RESULT_ERROR_INVALID_NULL_HANDLE; if (NULL == pUpdateKernelLaunch) return UR_RESULT_ERROR_INVALID_NULL_POINTER; - if (pUpdateKernelLaunch->newWorkDim < 1 || - pUpdateKernelLaunch->newWorkDim > 3) - return UR_RESULT_ERROR_INVALID_WORK_DIMENSION; + if (numKernelUpdates == 0) + return UR_RESULT_ERROR_INVALID_SIZE; } - ur_result_t result = pfnUpdateKernelLaunchExp(hCommand, pUpdateKernelLaunch); + ur_result_t result = pfnUpdateKernelLaunchExp( + hCommandBuffer, numKernelUpdates, pUpdateKernelLaunch); return result; } diff --git a/unified-runtime/source/loader/ur_ldrddi.cpp b/unified-runtime/source/loader/ur_ldrddi.cpp index c1e21fd58b7bc..cc6cfc304629e 100644 --- a/unified-runtime/source/loader/ur_ldrddi.cpp +++ b/unified-runtime/source/loader/ur_ldrddi.cpp @@ -8383,9 +8383,12 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urCommandBufferUpdateKernelLaunchExp __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( - /// [in] Handle of the command-buffer kernel command to update. - ur_exp_command_buffer_command_handle_t hCommand, - /// [in] Struct defining how the kernel command is to be updated. + /// [in] Handle of the command-buffer object. 
+ ur_exp_command_buffer_handle_t hCommandBuffer, + /// [in] Length of pUpdateKernelLaunch. + uint32_t numKernelUpdates, + /// [in][range(0, numKernelUpdates)] List of structs defining how a + /// kernel commands are to be updated. const ur_exp_command_buffer_update_kernel_launch_desc_t *pUpdateKernelLaunch) { ur_result_t result = UR_RESULT_SUCCESS; @@ -8394,7 +8397,7 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( // extract platform's function pointer table auto dditable = - reinterpret_cast(hCommand) + reinterpret_cast(hCommandBuffer) ->dditable; auto pfnUpdateKernelLaunchExp = dditable->ur.CommandBufferExp.pfnUpdateKernelLaunchExp; @@ -8402,40 +8405,53 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( return UR_RESULT_ERROR_UNINITIALIZED; // convert loader handle to platform handle - hCommand = - reinterpret_cast(hCommand) + hCommandBuffer = + reinterpret_cast(hCommandBuffer) ->handle; // Deal with any struct parameters that have handle members we need to // convert. 
- auto pUpdateKernelLaunchLocal = *pUpdateKernelLaunch; - - if (pUpdateKernelLaunchLocal.hNewKernel) - pUpdateKernelLaunchLocal.hNewKernel = - reinterpret_cast( - pUpdateKernelLaunchLocal.hNewKernel) + std::vector + pUpdateKernelLaunchVector = {}; + std::vector> + ppUpdateKernelLaunchpNewMemObjArgList(numKernelUpdates); + for (size_t Offset = 0; Offset < numKernelUpdates; Offset++) { + auto pUpdateKernelLaunchLocal = *pUpdateKernelLaunch; + + pUpdateKernelLaunchLocal.hCommand = + reinterpret_cast( + pUpdateKernelLaunchLocal.hCommand) ->handle; - - std::vector - pUpdateKernelLaunchpNewMemObjArgList; - for (uint32_t i = 0; i < pUpdateKernelLaunch->numNewMemObjArgs; i++) { - ur_exp_command_buffer_update_memobj_arg_desc_t NewRangeStruct = - pUpdateKernelLaunchLocal.pNewMemObjArgList[i]; - if (NewRangeStruct.hNewMemObjArg) - NewRangeStruct.hNewMemObjArg = - reinterpret_cast(NewRangeStruct.hNewMemObjArg) + if (pUpdateKernelLaunchLocal.hNewKernel) + pUpdateKernelLaunchLocal.hNewKernel = + reinterpret_cast( + pUpdateKernelLaunchLocal.hNewKernel) ->handle; - pUpdateKernelLaunchpNewMemObjArgList.push_back(NewRangeStruct); - } - pUpdateKernelLaunchLocal.pNewMemObjArgList = - pUpdateKernelLaunchpNewMemObjArgList.data(); + std::vector + &pUpdateKernelLaunchpNewMemObjArgList = + ppUpdateKernelLaunchpNewMemObjArgList[Offset]; + for (uint32_t i = 0; i < pUpdateKernelLaunch->numNewMemObjArgs; i++) { + ur_exp_command_buffer_update_memobj_arg_desc_t NewRangeStruct = + pUpdateKernelLaunchLocal.pNewMemObjArgList[i]; + if (NewRangeStruct.hNewMemObjArg) + NewRangeStruct.hNewMemObjArg = + reinterpret_cast(NewRangeStruct.hNewMemObjArg) + ->handle; + + pUpdateKernelLaunchpNewMemObjArgList.push_back(NewRangeStruct); + } + pUpdateKernelLaunchLocal.pNewMemObjArgList = + pUpdateKernelLaunchpNewMemObjArgList.data(); - // Now that we've converted all the members update the param pointers - pUpdateKernelLaunch = &pUpdateKernelLaunchLocal; + 
pUpdateKernelLaunchVector.push_back(pUpdateKernelLaunchLocal); + pUpdateKernelLaunch++; + } + pUpdateKernelLaunch = pUpdateKernelLaunchVector.data(); // forward to device-platform - result = pfnUpdateKernelLaunchExp(hCommand, pUpdateKernelLaunch); + result = pfnUpdateKernelLaunchExp(hCommandBuffer, numKernelUpdates, + pUpdateKernelLaunch); return result; } diff --git a/unified-runtime/source/loader/ur_libapi.cpp b/unified-runtime/source/loader/ur_libapi.cpp index e5797537632bf..427a7c6be77c2 100644 --- a/unified-runtime/source/loader/ur_libapi.cpp +++ b/unified-runtime/source/loader/ur_libapi.cpp @@ -8755,7 +8755,8 @@ ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// /// @details /// This entry-point is synchronous and may block if the command-buffer is -/// executing when the entry-point is called. +/// executing when the entry-point is called. On error, the state of the +/// command-buffer commands being updated is undefined. /// /// @returns /// - ::UR_RESULT_SUCCESS @@ -8763,66 +8764,75 @@ ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// - ::UR_RESULT_ERROR_DEVICE_LOST /// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE -/// + `NULL == hCommand` +/// + `NULL == hCommandBuffer` +/// + `NULL == pUpdateKernelLaunch->hCommand` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == pUpdateKernelLaunch` +/// - ::UR_RESULT_ERROR_INVALID_COMMAND_BUFFER_EXP +/// - ::UR_RESULT_ERROR_INVALID_SIZE +/// + `numKernelUpdates == 0` /// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE /// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_ARGUMENTS -/// is not supported by the device, but any of -/// `pUpdateKernelLaunch->numNewMemObjArgs`, -/// `pUpdateKernelLaunch->numNewPointerArgs`, or -/// `pUpdateKernelLaunch->numNewValueArgs` are not zero. 
+/// is not supported by the device, and for any of any element of +/// `pUpdateKernelLaunch` the `numNewMemObjArgs`, `numNewPointerArgs`, +/// or `numNewValueArgs` members are not zero. /// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_LOCAL_WORK_SIZE is -/// not supported by the device but -/// `pUpdateKernelLaunch->pNewLocalWorkSize` is not nullptr. +/// not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `pNewLocalWorkSize` member is not nullptr. /// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_LOCAL_WORK_SIZE is -/// not supported by the device but -/// `pUpdateKernelLaunch->pNewLocalWorkSize` is nullptr and -/// `pUpdateKernelLaunch->pNewGlobalWorkSize` is not nullptr. +/// not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `pNewLocalWorkSize` member is nullptr and +/// `pNewGlobalWorkSize` is not nullptr. /// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_SIZE -/// is not supported by the device but -/// `pUpdateKernelLaunch->pNewGlobalWorkSize` is not nullptr +/// is not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `pNewGlobalWorkSize` member is not nullptr /// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_OFFSET -/// is not supported by the device but -/// `pUpdateKernelLaunch->pNewGlobalWorkOffset` is not nullptr. +/// is not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `pNewGlobalWorkOffset` member is not +/// nullptr. /// + If ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_HANDLE -/// is not supported by the device but `pUpdateKernelLaunch->hNewKernel` -/// is not nullptr. +/// is not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `hNewKernel` member is not nullptr. 
/// - ::UR_RESULT_ERROR_INVALID_OPERATION /// + If ::ur_exp_command_buffer_desc_t::isUpdatable was not set to true -/// on creation of the command-buffer `hCommand` belongs to. -/// + If the command-buffer `hCommand` belongs to has not been -/// finalized. +/// on creation of the `hCommandBuffer`. +/// + If `hCommandBuffer` has not been finalized. /// - ::UR_RESULT_ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP -/// + If `hCommand` is not a kernel execution command. +/// + If for any element of `pUpdateKernelLaunch` the `hCommand` member +/// is not a kernel execution command. +/// + If for any element of `pUpdateKernelLaunch` the `hCommand` member +/// was not created from `hCommandBuffer`. /// - ::UR_RESULT_ERROR_INVALID_MEM_OBJECT /// - ::UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX /// - ::UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE /// - ::UR_RESULT_ERROR_INVALID_ENUMERATION /// - ::UR_RESULT_ERROR_INVALID_WORK_DIMENSION -/// + `pUpdateKernelLaunch->newWorkDim < 1 || -/// pUpdateKernelLaunch->newWorkDim > 3` +/// + If for any element of `pUpdateKernelLaunch` the `newWorkDim` +/// member is less than 1 or greater than 3. /// - ::UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE /// - ::UR_RESULT_ERROR_INVALID_VALUE -/// + If `pUpdateKernelLaunch->hNewKernel` was not passed to the -/// `hKernel` or `phKernelAlternatives` parameters of -/// ::urCommandBufferAppendKernelLaunchExp when this command was -/// created. -/// + If `pUpdateKernelLaunch->newWorkDim` is different from the current -/// workDim in `hCommand` and, -/// `pUpdateKernelLaunch->pNewGlobalWorkSize`, or -/// `pUpdateKernelLaunch->pNewGlobalWorkOffset` are nullptr. +/// + If for any element of `pUpdateKernelLaunch` the `hNewKernel` +/// member was not passed to the `hKernel` or `phKernelAlternatives` +/// parameters of ::urCommandBufferAppendKernelLaunchExp when the +/// command was created. 
+/// + If for any element of `pUpdateKernelLaunch` the `newWorkDim` +/// member is different from the current workDim in the `hCommand` +/// member, and `pNewGlobalWorkSize` or `pNewGlobalWorkOffset` are +/// nullptr. /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( - /// [in] Handle of the command-buffer kernel command to update. - ur_exp_command_buffer_command_handle_t hCommand, - /// [in] Struct defining how the kernel command is to be updated. + /// [in] Handle of the command-buffer object. + ur_exp_command_buffer_handle_t hCommandBuffer, + /// [in] Length of pUpdateKernelLaunch. + uint32_t numKernelUpdates, + /// [in][range(0, numKernelUpdates)] List of structs defining how a + /// kernel commands are to be updated. const ur_exp_command_buffer_update_kernel_launch_desc_t *pUpdateKernelLaunch) try { auto pfnUpdateKernelLaunchExp = @@ -8831,7 +8841,8 @@ ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( if (nullptr == pfnUpdateKernelLaunchExp) return UR_RESULT_ERROR_UNINITIALIZED; - return pfnUpdateKernelLaunchExp(hCommand, pUpdateKernelLaunch); + return pfnUpdateKernelLaunchExp(hCommandBuffer, numKernelUpdates, + pUpdateKernelLaunch); } catch (...) { return exceptionToResult(std::current_exception()); } diff --git a/unified-runtime/source/ur_api.cpp b/unified-runtime/source/ur_api.cpp index c5651c0fc4f83..9a0f9bf0ff272 100644 --- a/unified-runtime/source/ur_api.cpp +++ b/unified-runtime/source/ur_api.cpp @@ -7649,7 +7649,8 @@ ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// /// @details /// This entry-point is synchronous and may block if the command-buffer is -/// executing when the entry-point is called. +/// executing when the entry-point is called. On error, the state of the +/// command-buffer commands being updated is undefined. 
/// /// @returns /// - ::UR_RESULT_SUCCESS @@ -7657,66 +7658,75 @@ ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// - ::UR_RESULT_ERROR_DEVICE_LOST /// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE -/// + `NULL == hCommand` +/// + `NULL == hCommandBuffer` +/// + `NULL == pUpdateKernelLaunch->hCommand` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == pUpdateKernelLaunch` +/// - ::UR_RESULT_ERROR_INVALID_COMMAND_BUFFER_EXP +/// - ::UR_RESULT_ERROR_INVALID_SIZE +/// + `numKernelUpdates == 0` /// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE /// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_ARGUMENTS -/// is not supported by the device, but any of -/// `pUpdateKernelLaunch->numNewMemObjArgs`, -/// `pUpdateKernelLaunch->numNewPointerArgs`, or -/// `pUpdateKernelLaunch->numNewValueArgs` are not zero. +/// is not supported by the device, and for any of any element of +/// `pUpdateKernelLaunch` the `numNewMemObjArgs`, `numNewPointerArgs`, +/// or `numNewValueArgs` members are not zero. /// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_LOCAL_WORK_SIZE is -/// not supported by the device but -/// `pUpdateKernelLaunch->pNewLocalWorkSize` is not nullptr. +/// not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `pNewLocalWorkSize` member is not nullptr. /// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_LOCAL_WORK_SIZE is -/// not supported by the device but -/// `pUpdateKernelLaunch->pNewLocalWorkSize` is nullptr and -/// `pUpdateKernelLaunch->pNewGlobalWorkSize` is not nullptr. +/// not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `pNewLocalWorkSize` member is nullptr and +/// `pNewGlobalWorkSize` is not nullptr. 
/// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_SIZE -/// is not supported by the device but -/// `pUpdateKernelLaunch->pNewGlobalWorkSize` is not nullptr +/// is not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `pNewGlobalWorkSize` member is not nullptr /// + If /// ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_OFFSET -/// is not supported by the device but -/// `pUpdateKernelLaunch->pNewGlobalWorkOffset` is not nullptr. +/// is not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `pNewGlobalWorkOffset` member is not +/// nullptr. /// + If ::UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_HANDLE -/// is not supported by the device but `pUpdateKernelLaunch->hNewKernel` -/// is not nullptr. +/// is not supported by the device, and for any element of +/// `pUpdateKernelLaunch` the `hNewKernel` member is not nullptr. /// - ::UR_RESULT_ERROR_INVALID_OPERATION /// + If ::ur_exp_command_buffer_desc_t::isUpdatable was not set to true -/// on creation of the command-buffer `hCommand` belongs to. -/// + If the command-buffer `hCommand` belongs to has not been -/// finalized. +/// on creation of the `hCommandBuffer`. +/// + If `hCommandBuffer` has not been finalized. /// - ::UR_RESULT_ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP -/// + If `hCommand` is not a kernel execution command. +/// + If for any element of `pUpdateKernelLaunch` the `hCommand` member +/// is not a kernel execution command. +/// + If for any element of `pUpdateKernelLaunch` the `hCommand` member +/// was not created from `hCommandBuffer`. 
/// - ::UR_RESULT_ERROR_INVALID_MEM_OBJECT /// - ::UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX /// - ::UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE /// - ::UR_RESULT_ERROR_INVALID_ENUMERATION /// - ::UR_RESULT_ERROR_INVALID_WORK_DIMENSION -/// + `pUpdateKernelLaunch->newWorkDim < 1 || -/// pUpdateKernelLaunch->newWorkDim > 3` +/// + If for any element of `pUpdateKernelLaunch` the `newWorkDim` +/// member is less than 1 or greater than 3. /// - ::UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE /// - ::UR_RESULT_ERROR_INVALID_VALUE -/// + If `pUpdateKernelLaunch->hNewKernel` was not passed to the -/// `hKernel` or `phKernelAlternatives` parameters of -/// ::urCommandBufferAppendKernelLaunchExp when this command was -/// created. -/// + If `pUpdateKernelLaunch->newWorkDim` is different from the current -/// workDim in `hCommand` and, -/// `pUpdateKernelLaunch->pNewGlobalWorkSize`, or -/// `pUpdateKernelLaunch->pNewGlobalWorkOffset` are nullptr. +/// + If for any element of `pUpdateKernelLaunch` the `hNewKernel` +/// member was not passed to the `hKernel` or `phKernelAlternatives` +/// parameters of ::urCommandBufferAppendKernelLaunchExp when the +/// command was created. +/// + If for any element of `pUpdateKernelLaunch` the `newWorkDim` +/// member is different from the current workDim in the `hCommand` +/// member, and `pNewGlobalWorkSize` or `pNewGlobalWorkOffset` are +/// nullptr. /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( - /// [in] Handle of the command-buffer kernel command to update. - ur_exp_command_buffer_command_handle_t hCommand, - /// [in] Struct defining how the kernel command is to be updated. + /// [in] Handle of the command-buffer object. + ur_exp_command_buffer_handle_t hCommandBuffer, + /// [in] Length of pUpdateKernelLaunch. 
+ uint32_t numKernelUpdates, + /// [in][range(0, numKernelUpdates)] List of structs defining how + /// kernel commands are to be updated. const ur_exp_command_buffer_update_kernel_launch_desc_t *pUpdateKernelLaunch) { ur_result_t result = UR_RESULT_SUCCESS; diff --git a/unified-runtime/test/conformance/exp_command_buffer/update/buffer_fill_kernel_update.cpp b/unified-runtime/test/conformance/exp_command_buffer/update/buffer_fill_kernel_update.cpp index c76ac3e1112c2..9e9afff4770f2 100644 --- a/unified-runtime/test/conformance/exp_command_buffer/update/buffer_fill_kernel_update.cpp +++ b/unified-runtime/test/conformance/exp_command_buffer/update/buffer_fill_kernel_update.cpp @@ -119,6 +119,7 @@ TEST_P(BufferFillCommandTest, UpdateParameters) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 1, // numNewMemObjArgs 0, // numNewPointerArgs @@ -133,8 +134,8 @@ TEST_P(BufferFillCommandTest, UpdateParameters) { }; // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -172,6 +173,7 @@ TEST_P(BufferFillCommandTest, UpdateGlobalSize) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 1, // numNewMemObjArgs 0, // numNewPointerArgs @@ -185,8 +187,8 @@ TEST_P(BufferFillCommandTest, UpdateGlobalSize) { &new_local_size, // pNewLocalWorkSize }; - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); +
ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -223,6 +225,7 @@ TEST_P(BufferFillCommandTest, SeparateUpdateCalls) { ur_exp_command_buffer_update_kernel_launch_desc_t output_update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 1, // numNewMemObjArgs 0, // numNewPointerArgs @@ -235,8 +238,8 @@ TEST_P(BufferFillCommandTest, SeparateUpdateCalls) { nullptr, // pNewGlobalWorkSize nullptr, // pNewLocalWorkSize }; - ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(command_handle, - &output_update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &output_update_desc)); uint32_t new_val = 33; const uint32_t arg_index = (backend == UR_PLATFORM_BACKEND_HIP) ? 4 : 2; @@ -252,6 +255,7 @@ TEST_P(BufferFillCommandTest, SeparateUpdateCalls) { ur_exp_command_buffer_update_kernel_launch_desc_t input_update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -264,13 +268,14 @@ TEST_P(BufferFillCommandTest, SeparateUpdateCalls) { nullptr, // pNewGlobalWorkSize nullptr, // pNewLocalWorkSize }; - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &input_update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &input_update_desc)); size_t new_local_size = local_size; ur_exp_command_buffer_update_kernel_launch_desc_t global_size_update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -285,7 
+290,7 @@ TEST_P(BufferFillCommandTest, SeparateUpdateCalls) { }; ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( - command_handle, &global_size_update_desc)); + updatable_cmd_buf_handle, 1, &global_size_update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); @@ -316,6 +321,7 @@ TEST_P(BufferFillCommandTest, OverrideUpdate) { ur_exp_command_buffer_update_kernel_launch_desc_t first_update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -328,8 +334,8 @@ TEST_P(BufferFillCommandTest, OverrideUpdate) { nullptr, // pNewGlobalWorkSize nullptr, // pNewLocalWorkSize }; - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &first_update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &first_update_desc)); uint32_t second_val = -99; ur_exp_command_buffer_update_value_arg_desc_t second_input_desc = { @@ -344,6 +350,7 @@ TEST_P(BufferFillCommandTest, OverrideUpdate) { ur_exp_command_buffer_update_kernel_launch_desc_t second_update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -357,8 +364,8 @@ TEST_P(BufferFillCommandTest, OverrideUpdate) { nullptr, // pNewLocalWorkSize }; - ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(command_handle, - &second_update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &second_update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); @@ -401,21 +408,22 @@ TEST_P(BufferFillCommandTest, OverrideArgList) { ur_exp_command_buffer_update_kernel_launch_desc_t second_update_desc = { 
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext - kernel, // hNewKernel - 0, // numNewMemObjArgs - 0, // numNewPointerArgs - 2, // numNewValueArgs - n_dimensions, // newWorkDim - nullptr, // pNewMemObjArgList - nullptr, // pNewPointerArgList - input_descs, // pNewValueArgList - nullptr, // pNewGlobalWorkOffset - nullptr, // pNewGlobalWorkSize - nullptr, // pNewLocalWorkSize + command_handle, // hCommand + kernel, // hNewKernel + 0, // numNewMemObjArgs + 0, // numNewPointerArgs + 2, // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + nullptr, // pNewPointerArgList + input_descs, // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize }; - ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(command_handle, - &second_update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &second_update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); diff --git a/unified-runtime/test/conformance/exp_command_buffer/update/buffer_saxpy_kernel_update.cpp b/unified-runtime/test/conformance/exp_command_buffer/update/buffer_saxpy_kernel_update.cpp index 55b408b96bc49..01b13fc35bdd1 100644 --- a/unified-runtime/test/conformance/exp_command_buffer/update/buffer_saxpy_kernel_update.cpp +++ b/unified-runtime/test/conformance/exp_command_buffer/update/buffer_saxpy_kernel_update.cpp @@ -176,6 +176,7 @@ TEST_P(BufferSaxpyKernelTest, UpdateParameters) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 2, // numNewMemObjArgs 0, // numNewPointerArgs @@ -190,8 +191,8 @@ TEST_P(BufferSaxpyKernelTest, UpdateParameters) { }; // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - 
urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); diff --git a/unified-runtime/test/conformance/exp_command_buffer/update/invalid_update.cpp b/unified-runtime/test/conformance/exp_command_buffer/update/invalid_update.cpp index b93ffc26d4d8a..409a2e0601103 100644 --- a/unified-runtime/test/conformance/exp_command_buffer/update/invalid_update.cpp +++ b/unified-runtime/test/conformance/exp_command_buffer/update/invalid_update.cpp @@ -86,6 +86,7 @@ TEST_P(InvalidUpdateTest, NotFinalizedCommandBuffer) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -100,8 +101,8 @@ TEST_P(InvalidUpdateTest, NotFinalizedCommandBuffer) { }; // Update command to command-buffer that has not been finalized - ur_result_t result = - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc); + ur_result_t result = urCommandBufferUpdateKernelLaunchExp( + updatable_cmd_buf_handle, 1, &update_desc); ASSERT_EQ(UR_RESULT_ERROR_INVALID_OPERATION, result); } @@ -145,23 +146,24 @@ TEST_P(InvalidUpdateTest, NotUpdatableCommandBuffer) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext - kernel, // hNewKernel - 0, // numNewMemObjArgs - 0, // numNewPointerArgs - 1, // numNewValueArgs - n_dimensions, // newWorkDim - nullptr, // pNewMemObjArgList - nullptr, // pNewPointerArgList - &new_input_desc, // pNewValueArgList - nullptr, // pNewGlobalWorkOffset - nullptr, // pNewGlobalWorkSize - nullptr, // pNewLocalWorkSize + 
test_command_handle, // hCommand + kernel, // hNewKernel + 0, // numNewMemObjArgs + 0, // numNewPointerArgs + 1, // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + nullptr, // pNewPointerArgList + &new_input_desc, // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize }; // Since no command handle was returned Update command to command-buffer // should also be an error. - ur_result_t result = - urCommandBufferUpdateKernelLaunchExp(test_command_handle, &update_desc); + ur_result_t result = urCommandBufferUpdateKernelLaunchExp(test_cmd_buf_handle, + 1, &update_desc); EXPECT_EQ(UR_RESULT_ERROR_INVALID_NULL_HANDLE, result); if (test_cmd_buf_handle) { @@ -182,6 +184,7 @@ TEST_P(InvalidUpdateTest, InvalidDimensions) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -196,11 +199,13 @@ TEST_P(InvalidUpdateTest, InvalidDimensions) { }; ASSERT_EQ(UR_RESULT_ERROR_INVALID_VALUE, - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, 1, + &update_desc)); update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -215,7 +220,89 @@ TEST_P(InvalidUpdateTest, InvalidDimensions) { }; ASSERT_EQ(UR_RESULT_ERROR_INVALID_VALUE, - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, 1, + &update_desc)); +} + +// If the command-handle isn't valid an error should be returned +TEST_P(InvalidUpdateTest, InvalidCommandHandle) { + 
ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle)); + finalized = true; + + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + nullptr, // hCommand + kernel, // hNewKernel + 0, // numNewMemObjArgs + 0, // numNewPointerArgs + 0, // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + nullptr, // pNewPointerArgList + nullptr, // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize + }; + + ASSERT_EQ(UR_RESULT_ERROR_INVALID_NULL_HANDLE, + urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, 1, + &update_desc)); +} + +// Test error code is returned if command handle and command-buffer are +// mismatched +TEST_P(InvalidUpdateTest, CommandBufferMismatch) { + // Create a command-buffer with update enabled. + ur_exp_command_buffer_handle_t test_cmd_buf_handle = nullptr; + ur_exp_command_buffer_desc_t desc{UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC, + nullptr, true, false, false}; + ASSERT_SUCCESS( + urCommandBufferCreateExp(context, device, &desc, &test_cmd_buf_handle)); + EXPECT_NE(test_cmd_buf_handle, nullptr); + + EXPECT_SUCCESS(urCommandBufferFinalizeExp(test_cmd_buf_handle)); + EXPECT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle)); + finalized = true; + + // Set new value to use for fill at kernel index 1 + uint32_t new_val = 33; + ur_exp_command_buffer_update_value_arg_desc_t new_input_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext + 1, // argIndex + sizeof(new_val), // argSize + nullptr, // pProperties + &new_val, // hArgValue + }; + + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + command_handle, // hCommand + kernel, // hNewKernel + 0, // numNewMemObjArgs +
0, // numNewPointerArgs + 1, // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + nullptr, // pNewPointerArgList + &new_input_desc, // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize + }; + + // command_handle was not created from test_cmd_buf_handle, so updating it + // through test_cmd_buf_handle should be an error. + ur_result_t result = urCommandBufferUpdateKernelLaunchExp(test_cmd_buf_handle, + 1, &update_desc); + EXPECT_EQ(UR_RESULT_ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP, result); + + if (test_cmd_buf_handle) { + EXPECT_SUCCESS(urCommandBufferReleaseExp(test_cmd_buf_handle)); + } } // Tests that an error is thrown when trying to update a kernel capability @@ -327,6 +414,7 @@ TEST_P(InvalidUpdateCommandBufferExpExecutionTest, KernelArg) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand nullptr, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -340,8 +428,8 @@ TEST_P(InvalidUpdateCommandBufferExpExecutionTest, KernelArg) { nullptr, // pNewLocalWorkSize }; - ur_result_t result = - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc); + ur_result_t result = urCommandBufferUpdateKernelLaunchExp( + updatable_cmd_buf_handle, 1, &update_desc); ASSERT_EQ(UR_RESULT_ERROR_UNSUPPORTED_FEATURE, result); } @@ -356,6 +444,7 @@ TEST_P(InvalidUpdateCommandBufferExpExecutionTest, GlobalSize) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand nullptr, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -369,8 +458,8 @@ TEST_P(InvalidUpdateCommandBufferExpExecutionTest, GlobalSize) { nullptr, // pNewLocalWorkSize }; - ur_result_t result = -
urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc); + ur_result_t result = urCommandBufferUpdateKernelLaunchExp( + updatable_cmd_buf_handle, 1, &update_desc); ASSERT_EQ(UR_RESULT_ERROR_UNSUPPORTED_FEATURE, result); } @@ -385,6 +474,7 @@ TEST_P(InvalidUpdateCommandBufferExpExecutionTest, GlobalOffset) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand nullptr, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -398,8 +488,8 @@ TEST_P(InvalidUpdateCommandBufferExpExecutionTest, GlobalOffset) { nullptr, // pNewLocalWorkSize }; - ur_result_t result = - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc); + ur_result_t result = urCommandBufferUpdateKernelLaunchExp( + updatable_cmd_buf_handle, 1, &update_desc); ASSERT_EQ(UR_RESULT_ERROR_UNSUPPORTED_FEATURE, result); } @@ -414,6 +504,7 @@ TEST_P(InvalidUpdateCommandBufferExpExecutionTest, LocalSize) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand nullptr, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -427,8 +518,8 @@ TEST_P(InvalidUpdateCommandBufferExpExecutionTest, LocalSize) { &new_local_size, // pNewLocalWorkSize }; - ur_result_t result = - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc); + ur_result_t result = urCommandBufferUpdateKernelLaunchExp( + updatable_cmd_buf_handle, 1, &update_desc); ASSERT_EQ(UR_RESULT_ERROR_UNSUPPORTED_FEATURE, result); } @@ -449,6 +540,7 @@ TEST_P(InvalidUpdateCommandBufferExpExecutionTest, ImplChosenLocalSize) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand nullptr, // hNewKernel 0, // 
numNewMemObjArgs 0, // numNewPointerArgs @@ -462,8 +554,8 @@ TEST_P(InvalidUpdateCommandBufferExpExecutionTest, ImplChosenLocalSize) { nullptr, // pNewLocalWorkSize }; - ur_result_t result = - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc); + ur_result_t result = urCommandBufferUpdateKernelLaunchExp( + updatable_cmd_buf_handle, 1, &update_desc); ASSERT_EQ(UR_RESULT_ERROR_UNSUPPORTED_FEATURE, result); } @@ -477,20 +569,21 @@ TEST_P(InvalidUpdateCommandBufferExpExecutionTest, Kernel) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext - kernel_2, // hNewKernel - 0, // numNewMemObjArgs - 0, // numNewPointerArgs - 0, // numNewValueArgs - n_dimensions, // newWorkDim - nullptr, // pNewMemObjArgList - nullptr, // pNewPointerArgList - nullptr, // pNewValueArgList - nullptr, // pNewGlobalWorkOffset - nullptr, // pNewGlobalWorkSize - nullptr, // pNewLocalWorkSize + command_handle, // hCommand + kernel_2, // hNewKernel + 0, // numNewMemObjArgs + 0, // numNewPointerArgs + 0, // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + nullptr, // pNewPointerArgList + nullptr, // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize }; - ur_result_t result = - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc); + ur_result_t result = urCommandBufferUpdateKernelLaunchExp( + updatable_cmd_buf_handle, 1, &update_desc); ASSERT_EQ(UR_RESULT_ERROR_UNSUPPORTED_FEATURE, result); } diff --git a/unified-runtime/test/conformance/exp_command_buffer/update/kernel_handle_update.cpp b/unified-runtime/test/conformance/exp_command_buffer/update/kernel_handle_update.cpp index 998f664bc85e1..30b60f3ca4ad8 100644 --- a/unified-runtime/test/conformance/exp_command_buffer/update/kernel_handle_update.cpp +++ 
b/unified-runtime/test/conformance/exp_command_buffer/update/kernel_handle_update.cpp @@ -77,6 +77,7 @@ struct TestSaxpyKernel : public uur::command_buffer::TestKernel { UpdateDesc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + nullptr, // hCommand Kernel, // hNewKernel 0, // numNewMemObjArgs 3, // numNewPointerArgs @@ -171,6 +172,7 @@ struct TestFill2DKernel : public uur::command_buffer::TestKernel { UpdateDesc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + nullptr, // hCommand Kernel, // hNewKernel 0, // numNewMemObjArgs 1, // numNewPointerArgs @@ -275,8 +277,9 @@ TEST_P(urCommandBufferKernelHandleUpdateTest, Success) { ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); + FillUSM2DKernel->UpdateDesc.hCommand = CommandHandle; ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( - CommandHandle, &FillUSM2DKernel->UpdateDesc)); + updatable_cmd_buf_handle, 1, &FillUSM2DKernel->UpdateDesc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -304,8 +307,9 @@ TEST_P(urCommandBufferKernelHandleUpdateTest, UpdateAgain) { ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); + FillUSM2DKernel->UpdateDesc.hCommand = CommandHandle; ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( - CommandHandle, &FillUSM2DKernel->UpdateDesc)); + updatable_cmd_buf_handle, 1, &FillUSM2DKernel->UpdateDesc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -317,7 +321,7 @@ TEST_P(urCommandBufferKernelHandleUpdateTest, UpdateAgain) { // potentially fail since it would try to use the Saxpy kernel FillUSM2DKernel->Val = 78; 
ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( - CommandHandle, &FillUSM2DKernel->UpdateDesc)); + updatable_cmd_buf_handle, 1, &FillUSM2DKernel->UpdateDesc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -343,8 +347,9 @@ TEST_P(urCommandBufferKernelHandleUpdateTest, RestoreOriginalKernel) { ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); + FillUSM2DKernel->UpdateDesc.hCommand = CommandHandle; ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( - CommandHandle, &FillUSM2DKernel->UpdateDesc)); + updatable_cmd_buf_handle, 1, &FillUSM2DKernel->UpdateDesc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -355,8 +360,9 @@ TEST_P(urCommandBufferKernelHandleUpdateTest, RestoreOriginalKernel) { // Updating A, so that the second launch of the saxpy kernel actually has a // different output. 
SaxpyKernel->A = 20; + SaxpyKernel->UpdateDesc.hCommand = CommandHandle; ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( - CommandHandle, &SaxpyKernel->UpdateDesc)); + updatable_cmd_buf_handle, 1, &SaxpyKernel->UpdateDesc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -377,9 +383,11 @@ TEST_P(urCommandBufferKernelHandleUpdateTest, KernelAlternativeNotRegistered) { ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); - ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_VALUE, - urCommandBufferUpdateKernelLaunchExp( - CommandHandle, &FillUSM2DKernel->UpdateDesc)); + FillUSM2DKernel->UpdateDesc.hCommand = CommandHandle; + ASSERT_EQ_RESULT( + UR_RESULT_ERROR_INVALID_VALUE, + urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, 1, + &FillUSM2DKernel->UpdateDesc)); } TEST_P(urCommandBufferKernelHandleUpdateTest, @@ -430,8 +438,9 @@ TEST_P(urCommandBufferValidUpdateParametersTest, FillUSM2DKernel->UpdateDesc.newWorkDim = 1; FillUSM2DKernel->UpdateDesc.pNewGlobalWorkSize = &newGlobalWorkSize; FillUSM2DKernel->UpdateDesc.pNewGlobalWorkOffset = &newGlobalWorkOffset; + FillUSM2DKernel->UpdateDesc.hCommand = CommandHandle; ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( - CommandHandle, &FillUSM2DKernel->UpdateDesc)); + updatable_cmd_buf_handle, 1, &FillUSM2DKernel->UpdateDesc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -463,8 +472,9 @@ TEST_P(urCommandBufferValidUpdateParametersTest, UpdateOnlyLocalWorkSize) { SaxpyKernel->UpdateDesc.pNewGlobalWorkSize = nullptr; size_t newLocalSize = SaxpyKernel->LocalSize * 4; SaxpyKernel->UpdateDesc.pNewLocalWorkSize = &newLocalSize; + SaxpyKernel->UpdateDesc.hCommand = CommandHandle; ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( - CommandHandle, &SaxpyKernel->UpdateDesc)); + 
updatable_cmd_buf_handle, 1, &SaxpyKernel->UpdateDesc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -490,8 +500,9 @@ TEST_P(urCommandBufferValidUpdateParametersTest, SuccessNullptrHandle) { ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle)); SaxpyKernel->UpdateDesc.hNewKernel = nullptr; + SaxpyKernel->UpdateDesc.hCommand = CommandHandle; ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( - CommandHandle, &SaxpyKernel->UpdateDesc)); + updatable_cmd_buf_handle, 1, &SaxpyKernel->UpdateDesc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); diff --git a/unified-runtime/test/conformance/exp_command_buffer/update/local_memory_update.cpp b/unified-runtime/test/conformance/exp_command_buffer/update/local_memory_update.cpp index 1621b902a9aac..b4e442560d7a5 100644 --- a/unified-runtime/test/conformance/exp_command_buffer/update/local_memory_update.cpp +++ b/unified-runtime/test/conformance/exp_command_buffer/update/local_memory_update.cpp @@ -207,6 +207,7 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersSameLocalSize) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs new_input_descs.size(), // numNewPointerArgs @@ -221,8 +222,8 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersSameLocalSize) { }; // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -273,6 +274,7 @@ 
TEST_P(LocalMemoryUpdateTest, UpdateLocalOnly) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -287,8 +289,8 @@ TEST_P(LocalMemoryUpdateTest, UpdateLocalOnly) { }; // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -348,6 +350,7 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersEmptyLocalSize) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs new_input_descs.size(), // numNewPointerArgs @@ -362,8 +365,8 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersEmptyLocalSize) { }; // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -503,6 +506,7 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersSmallerLocalSize) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 2, // numNewPointerArgs @@ -517,8 +521,8 @@ TEST_P(LocalMemoryUpdateTest, 
UpdateParametersSmallerLocalSize) { }; // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -657,6 +661,7 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersLargerLocalSize) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 2, // numNewPointerArgs @@ -671,8 +676,8 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersLargerLocalSize) { }; // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -774,6 +779,7 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersPartialLocalSize) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 2, // numNewPointerArgs @@ -788,8 +794,8 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersPartialLocalSize) { }; // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); std::vector second_update_value_args{}; @@ -835,9 +841,10 @@ TEST_P(LocalMemoryUpdateTest, 
UpdateParametersPartialLocalSize) { ur_exp_command_buffer_update_kernel_launch_desc_t second_update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext - kernel, // hNewKernel - 0, // numNewMemObjArgs - 0, // numNewPointerArgs + command_handle, // hCommand + kernel, // hNewKernel + 0, // numNewMemObjArgs + 0, // numNewPointerArgs static_cast(second_update_value_args.size()), // numNewValueArgs n_dimensions, // newWorkDim nullptr, // pNewMemObjArgList @@ -847,8 +854,8 @@ TEST_P(LocalMemoryUpdateTest, UpdateParametersPartialLocalSize) { nullptr, // pNewGlobalWorkSize nullptr, // pNewLocalWorkSize }; - ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(command_handle, - &second_update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &second_update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); @@ -950,27 +957,31 @@ TEST_P(LocalMemoryMultiUpdateTest, UpdateParameters) { &shared_ptrs[4], // pArgValue }; - // Update kernel inputs - ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { - UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype - nullptr, // pNext - kernel, // hNewKernel - 0, // numNewMemObjArgs - new_input_descs.size(), // numNewPointerArgs - new_value_descs.size(), // numNewValueArgs - n_dimensions, // newWorkDim - nullptr, // pNewMemObjArgList - new_input_descs.data(), // pNewPointerArgList - new_value_descs.data(), // pNewValueArgList - nullptr, // pNewGlobalWorkOffset - nullptr, // pNewGlobalWorkSize - nullptr, // pNewLocalWorkSize - }; - - // Update kernel and enqueue command-buffer again + std::vector update_descs; for (auto &handle : command_handles) { - ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(handle, &update_desc)); + // Update kernel inputs + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + 
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + handle, // hCommand + kernel, // hNewKernel + 0, // numNewMemObjArgs + new_input_descs.size(), // numNewPointerArgs + new_value_descs.size(), // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + new_input_descs.data(), // pNewPointerArgList + new_value_descs.data(), // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize + }; + update_descs.push_back(update_desc); } + + // Update kernel and enqueue command-buffer again + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( + updatable_cmd_buf_handle, update_descs.size(), update_descs.data())); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -1039,30 +1050,36 @@ TEST_P(LocalMemoryMultiUpdateTest, UpdateWithoutBlocking) { &shared_ptrs[4], // pArgValue }; - // Update kernel inputs - ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { - UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype - nullptr, // pNext - kernel, // hNewKernel - 0, // numNewMemObjArgs - new_input_descs.size(), // numNewPointerArgs - new_value_descs.size(), // numNewValueArgs - n_dimensions, // newWorkDim - nullptr, // pNewMemObjArgList - new_input_descs.data(), // pNewPointerArgList - new_value_descs.data(), // pNewValueArgList - nullptr, // pNewGlobalWorkOffset - nullptr, // pNewGlobalWorkSize - nullptr, // pNewLocalWorkSize - }; // Enqueue without calling urQueueFinish after ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); - // Update kernel and enqueue command-buffer again + std::vector update_descs; for (auto &handle : command_handles) { - ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(handle, &update_desc)); + // Update kernel inputs + 
ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + handle, // hCommand + kernel, // hNewKernel + 0, // numNewMemObjArgs + new_input_descs.size(), // numNewPointerArgs + new_value_descs.size(), // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + new_input_descs.data(), // pNewPointerArgList + new_value_descs.data(), // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize + }; + update_descs.push_back(update_desc); } + + // Update kernel and enqueue command-buffer again + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( + updatable_cmd_buf_handle, update_descs.size(), update_descs.data())); + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -1250,6 +1267,7 @@ TEST_P(LocalMemoryUpdateTestOutOfOrder, UpdateAllParameters) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs new_input_descs.size(), // numNewPointerArgs @@ -1264,8 +1282,8 @@ TEST_P(LocalMemoryUpdateTestOutOfOrder, UpdateAllParameters) { }; // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); diff --git a/unified-runtime/test/conformance/exp_command_buffer/update/ndrange_update.cpp b/unified-runtime/test/conformance/exp_command_buffer/update/ndrange_update.cpp index 47b5e48d1ae7d..b1577ed6f4735 100644 --- 
a/unified-runtime/test/conformance/exp_command_buffer/update/ndrange_update.cpp +++ b/unified-runtime/test/conformance/exp_command_buffer/update/ndrange_update.cpp @@ -126,6 +126,7 @@ TEST_P(NDRangeUpdateTest, Update3D) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -140,8 +141,8 @@ TEST_P(NDRangeUpdateTest, Update3D) { }; // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -179,6 +180,7 @@ TEST_P(NDRangeUpdateTest, Update2D) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -197,8 +199,8 @@ TEST_P(NDRangeUpdateTest, Update2D) { std::memset(shared_ptr, 0, allocation_size); // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -232,6 +234,7 @@ TEST_P(NDRangeUpdateTest, Update1D) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 0, 
// numNewPointerArgs @@ -250,8 +253,8 @@ TEST_P(NDRangeUpdateTest, Update1D) { std::memset(shared_ptr, 0, allocation_size); // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -290,6 +293,7 @@ TEST_P(NDRangeUpdateTest, ImplToUserDefinedLocalSize) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -304,8 +308,8 @@ TEST_P(NDRangeUpdateTest, ImplToUserDefinedLocalSize) { }; // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -341,6 +345,7 @@ TEST_P(NDRangeUpdateTest, UserToImplDefinedLocalSize) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -355,8 +360,8 @@ TEST_P(NDRangeUpdateTest, UserToImplDefinedLocalSize) { }; // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); 
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); diff --git a/unified-runtime/test/conformance/exp_command_buffer/update/usm_fill_kernel_update.cpp b/unified-runtime/test/conformance/exp_command_buffer/update/usm_fill_kernel_update.cpp index 6613c2d4fb15d..084580e0e62c5 100644 --- a/unified-runtime/test/conformance/exp_command_buffer/update/usm_fill_kernel_update.cpp +++ b/unified-runtime/test/conformance/exp_command_buffer/update/usm_fill_kernel_update.cpp @@ -117,6 +117,7 @@ TEST_P(USMFillCommandTest, UpdateParameters) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 1, // numNewPointerArgs @@ -131,8 +132,8 @@ TEST_P(USMFillCommandTest, UpdateParameters) { }; // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -171,6 +172,7 @@ TEST_P(USMFillCommandTest, UpdateBeforeEnqueue) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 1, // numNewPointerArgs @@ -185,8 +187,8 @@ TEST_P(USMFillCommandTest, UpdateBeforeEnqueue) { }; // Update kernel and enqueue command-buffer - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); 
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -216,6 +218,7 @@ TEST_P(USMFillCommandTest, UpdateNull) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 1, // numNewPointerArgs @@ -231,8 +234,8 @@ TEST_P(USMFillCommandTest, UpdateNull) { // Verify update kernel succeeded but don't run to avoid dereferencing // the nullptr. - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); } // Test updating a command-buffer with multiple USM fill kernel commands @@ -362,21 +365,22 @@ TEST_P(USMMultipleFillCommandTest, UpdateAllKernels) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext - kernel, // hNewKernel - 0, // numNewMemObjArgs - 1, // numNewPointerArgs - 1, // numNewValueArgs - n_dimensions, // newWorkDim - nullptr, // pNewMemObjArgList - &new_output_desc, // pNewPointerArgList - &new_input_desc, // pNewValueArgList - nullptr, // pNewGlobalWorkOffset - nullptr, // pNewGlobalWorkSize - nullptr, // pNewLocalWorkSize + command_handles[k], // hCommand + kernel, // hNewKernel + 0, // numNewMemObjArgs + 1, // numNewPointerArgs + 1, // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + &new_output_desc, // pNewPointerArgList + &new_input_desc, // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize }; - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handles[k], &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( + updatable_cmd_buf_handle, 1, 
&update_desc)); } // Update kernel and enqueue command-buffer again diff --git a/unified-runtime/test/conformance/exp_command_buffer/update/usm_saxpy_kernel_update.cpp b/unified-runtime/test/conformance/exp_command_buffer/update/usm_saxpy_kernel_update.cpp index 6a645d315a255..2ad04603abd46 100644 --- a/unified-runtime/test/conformance/exp_command_buffer/update/usm_saxpy_kernel_update.cpp +++ b/unified-runtime/test/conformance/exp_command_buffer/update/usm_saxpy_kernel_update.cpp @@ -139,6 +139,7 @@ TEST_P(USMSaxpyKernelTest, UpdateParameters) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command_handle, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 2, // numNewPointerArgs @@ -153,8 +154,8 @@ TEST_P(USMSaxpyKernelTest, UpdateParameters) { }; // Update kernel and enqueue command-buffer again - ASSERT_SUCCESS( - urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(updatable_cmd_buf_handle, + 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -232,27 +233,33 @@ TEST_P(USMMultiSaxpyKernelTest, UpdateParameters) { &new_A, // hArgValue }; - // Update kernel inputs - ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { - UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype - nullptr, // pNext - kernel, // hNewKernel - 0, // numNewMemObjArgs - 2, // numNewPointerArgs - 1, // numNewValueArgs - n_dimensions, // newWorkDim - nullptr, // pNewMemObjArgList - new_input_descs, // pNewPointerArgList - &new_A_desc, // pNewValueArgList - nullptr, // pNewGlobalWorkOffset - nullptr, // pNewGlobalWorkSize - nullptr, // pNewLocalWorkSize - }; - - // Update kernel and enqueue command-buffer again + std::vector update_descs; for (auto &handle : 
command_handles) { - ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(handle, &update_desc)); + // Update kernel inputs + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + handle, // hCommand + kernel, // hNewKernel + 0, // numNewMemObjArgs + 2, // numNewPointerArgs + 1, // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + new_input_descs, // pNewPointerArgList + &new_A_desc, // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize + }; + + update_descs.push_back(update_desc); } + + // Update kernel and enqueue command-buffer again + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( + updatable_cmd_buf_handle, update_descs.size(), update_descs.data())); + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -287,26 +294,31 @@ TEST_P(USMMultiSaxpyKernelTest, UpdateNullptrKernel) { &new_A, // hArgValue }; - // Update kernel inputs - ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { - UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype - nullptr, // pNext - nullptr, // hNewKernel - 0, // numNewMemObjArgs - 0, // numNewPointerArgs - 1, // numNewValueArgs - n_dimensions, // newWorkDim - nullptr, // pNewMemObjArgList - nullptr, // pNewPointerArgList - &new_A_desc, // pNewValueArgList - nullptr, // pNewGlobalWorkOffset - nullptr, // pNewGlobalWorkSize - nullptr, // pNewLocalWorkSize - }; - + std::vector update_descs; for (auto &handle : command_handles) { - ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(handle, &update_desc)); + // Update kernel inputs + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + handle, // hCommand + 
nullptr, // hNewKernel + 0, // numNewMemObjArgs + 0, // numNewPointerArgs + 1, // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // pNewMemObjArgList + nullptr, // pNewPointerArgList + &new_A_desc, // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize + }; + + update_descs.push_back(update_desc); } + + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( + updatable_cmd_buf_handle, update_descs.size(), update_descs.data())); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); @@ -349,31 +361,35 @@ TEST_P(USMMultiSaxpyKernelTest, UpdateWithoutBlocking) { &new_A, // hArgValue }; - // Update kernel inputs - ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { - UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype - nullptr, // pNext - kernel, // hNewKernel - 0, // numNewMemObjArgs - 2, // numNewPointerArgs - 1, // numNewValueArgs - n_dimensions, // newWorkDim - nullptr, // pNewMemObjArgList - new_input_descs, // pNewPointerArgList - &new_A_desc, // pNewValueArgList - nullptr, // pNewGlobalWorkOffset - nullptr, // pNewGlobalWorkSize - nullptr, // pNewLocalWorkSize - }; - // Run command-buffer prior to update without doing a blocking wait after ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); - // Update kernel and enqueue command-buffer again + std::vector update_descs; for (auto &handle : command_handles) { - ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(handle, &update_desc)); + // Update kernel inputs + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + handle, // hCommand + kernel, // hNewKernel + 0, // numNewMemObjArgs + 2, // numNewPointerArgs + 1, // numNewValueArgs + n_dimensions, // newWorkDim + nullptr, // 
pNewMemObjArgList + new_input_descs, // pNewPointerArgList + &new_A_desc, // pNewValueArgList + nullptr, // pNewGlobalWorkOffset + nullptr, // pNewGlobalWorkSize + nullptr, // pNewLocalWorkSize + }; + update_descs.push_back(update_desc); } + + // Update kernel and enqueue command-buffer again + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( + updatable_cmd_buf_handle, update_descs.size(), update_descs.data())); ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queue)); diff --git a/unified-runtime/test/conformance/program/urMultiDeviceProgramCreateWithBinary.cpp b/unified-runtime/test/conformance/program/urMultiDeviceProgramCreateWithBinary.cpp index 0094caa274128..ff84f0046b32a 100644 --- a/unified-runtime/test/conformance/program/urMultiDeviceProgramCreateWithBinary.cpp +++ b/unified-runtime/test/conformance/program/urMultiDeviceProgramCreateWithBinary.cpp @@ -366,6 +366,7 @@ TEST_P(urMultiDeviceCommandBufferExpTest, Update) { ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype nullptr, // pNext + command, // hCommand kernel, // hNewKernel 0, // numNewMemObjArgs 0, // numNewPointerArgs @@ -378,7 +379,8 @@ TEST_P(urMultiDeviceCommandBufferExpTest, Update) { nullptr, // pNewGlobalWorkSize nullptr, // pNewLocalWorkSize }; - ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(command, &update_desc)); + ASSERT_SUCCESS( + urCommandBufferUpdateKernelLaunchExp(cmd_buf_handle, 1, &update_desc)); ASSERT_SUCCESS(urCommandBufferEnqueueExp(cmd_buf_handle, queues[i], 0, nullptr, nullptr)); ASSERT_SUCCESS(urQueueFinish(queues[i])); From fb60d962501d0f0f2b4f3aabe47dea6bbb862dec Mon Sep 17 00:00:00 2001 From: Ewan Crawford Date: Fri, 21 Feb 2025 10:36:14 +0000 Subject: [PATCH 2/2] Refactor how we find updatable partitions --- sycl/source/detail/graph_impl.cpp | 122 ++++++++++------------ 
sycl/source/detail/graph_impl.hpp | 38 ++++--- sycl/source/detail/scheduler/commands.cpp | 11 +- 3 files changed, 84 insertions(+), 87 deletions(-) diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index 2e8b3ced44ce3..31f2e24bdf167 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -1401,6 +1401,16 @@ void exec_graph_impl::update( std::vector UpdateRequirements; bool NeedScheduledUpdate = needsScheduledUpdate(Nodes, UpdateRequirements); if (NeedScheduledUpdate) { + // Clean up any execution events which have finished so we don't pass them + // to the scheduler. + for (auto It = MExecutionEvents.begin(); It != MExecutionEvents.end();) { + if ((*It)->isCompleted()) { + It = MExecutionEvents.erase(It); + continue; + } + ++It; + } + auto AllocaQueue = std::make_shared( sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()), sycl::detail::getSyclObjImpl(MGraphImpl->getContext()), @@ -1416,11 +1426,12 @@ void exec_graph_impl::update( } else { // For each partition in the executable graph, call UR update on the // command-buffer with the nodes to update. - auto PartitionedNodes = getPartitionForNodes(Nodes); + auto PartitionedNodes = getURUpdatableNodes(Nodes); for (auto It = PartitionedNodes.begin(); It != PartitionedNodes.end(); It++) { - auto CommandBuffer = It->first->MCommandBuffers[MDevice]; - updateKernelsImpl(CommandBuffer, It->second); + auto &Partition = MPartitions[It->first]; + auto CommandBuffer = Partition->MCommandBuffers[MDevice]; + updateURImpl(CommandBuffer, It->second); } } @@ -1475,16 +1486,6 @@ bool exec_graph_impl::needsScheduledUpdate( } } - // Clean up any execution events which have finished so we don't pass them to - // the scheduler. 
- for (auto It = MExecutionEvents.begin(); It != MExecutionEvents.end();) { - if ((*It)->isCompleted()) { - It = MExecutionEvents.erase(It); - continue; - } - ++It; - } - // If we have previous execution events do the update through the scheduler to // ensure it is ordered correctly. NeedScheduledUpdate |= MExecutionEvents.size() > 0; @@ -1499,7 +1500,7 @@ void exec_graph_impl::populateURKernelUpdateStructs( std::vector &PtrDescs, std::vector &ValueDescs, sycl::detail::NDRDescT &NDRDesc, - ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) { + ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) const { auto ContextImpl = sycl::detail::getSyclObjImpl(MContext); const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter(); auto DeviceImpl = sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()); @@ -1656,56 +1657,52 @@ void exec_graph_impl::populateURKernelUpdateStructs( auto ExecNode = MIDCache.find(Node->MID); assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); - ur_exp_command_buffer_command_handle_t Command = - MCommandMap[ExecNode->second]; - UpdateDesc.hCommand = Command; + auto Command = MCommandMap.find(ExecNode->second); + assert(Command != MCommandMap.end()); + UpdateDesc.hCommand = Command->second; // Update ExecNode with new values from Node, in case we ever need to // rebuild the command buffers ExecNode->second->updateFromOtherNode(Node); } -std::map, std::vector>> -exec_graph_impl::getPartitionForNodes( - const std::vector> &Nodes) { - // Iterate over each partition in the executable graph, and find the nodes - // in "Nodes" that also exist in the partition. 
- std::map, std::vector>> - PartitionedNodes; - for (const auto &Partition : MPartitions) { - std::vector> NodesForPartition; - const auto PartitionBegin = Partition->MSchedule.begin(); - const auto PartitionEnd = Partition->MSchedule.end(); - for (auto &Node : Nodes) { - auto ExecNode = MIDCache.find(Node->MID); - assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); - - if (std::find_if(PartitionBegin, PartitionEnd, - [ExecNode](const auto &PartitionNode) { - return PartitionNode->MID == ExecNode->second->MID; - }) != PartitionEnd) { - NodesForPartition.push_back(Node); - } - } - if (!NodesForPartition.empty()) { - PartitionedNodes.insert({Partition, NodesForPartition}); +std::map>> +exec_graph_impl::getURUpdatableNodes( + const std::vector> &Nodes) const { + // Iterate over the list of nodes, and for every node that can + // be updated through UR, add it to the list of nodes for + // that can be updated for the UR command-buffer partition. + std::map>> PartitionedNodes; + + // Initialize vector for each partition + for (size_t i = 0; i < MPartitions.size(); i++) { + PartitionedNodes[i] = {}; + } + + for (auto &Node : Nodes) { + // Kernel node update is the only command type supported in UR for update. + if (Node->MCGType != sycl::detail::CGType::Kernel) { + continue; } + + auto ExecNode = MIDCache.find(Node->MID); + assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); + auto PartitionIndex = MPartitionNodes.find(ExecNode->second); + assert(PartitionIndex != MPartitionNodes.end()); + PartitionedNodes[PartitionIndex->second].push_back(Node); } return PartitionedNodes; } void exec_graph_impl::updateHostTasksImpl( - const std::vector> &Nodes) { + const std::vector> &Nodes) const { for (auto &Node : Nodes) { if (Node->MNodeType != node_type::host_task) { continue; } // Query the ID cache to find the equivalent exec node for the node passed // to this function. 
- // TODO: Handle subgraphs or any other cases where multiple nodes may be - // associated with a single key, once those node types are supported for - // update. auto ExecNode = MIDCache.find(Node->MID); assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); @@ -1713,40 +1710,31 @@ void exec_graph_impl::updateHostTasksImpl( } } -void exec_graph_impl::updateKernelsImpl( +void exec_graph_impl::updateURImpl( ur_exp_command_buffer_handle_t CommandBuffer, - const std::vector> &Nodes) { - // Kernel node update is the only command type supported in UR for update. - // Updating any other types of nodes, e.g. empty & barrier nodes is a no-op. - size_t NumKernelNodes = 0; - for (auto &N : Nodes) { - if (N->MCGType == sycl::detail::CGType::Kernel) { - NumKernelNodes++; - } - } - - // Don't need to call through to UR if no kernel nodes to update - if (NumKernelNodes == 0) { + const std::vector> &Nodes) const { + const size_t NumUpdatableNodes = Nodes.size(); + if (NumUpdatableNodes == 0) { return; } std::vector> - MemobjDescsList(NumKernelNodes); + MemobjDescsList(NumUpdatableNodes); std::vector> - PtrDescsList(NumKernelNodes); + PtrDescsList(NumUpdatableNodes); std::vector> - ValueDescsList(NumKernelNodes); - std::vector NDRDescList(NumKernelNodes); + ValueDescsList(NumUpdatableNodes); + std::vector NDRDescList(NumUpdatableNodes); std::vector UpdateDescList( - NumKernelNodes); + NumUpdatableNodes); std::vector> - KernelBundleObjList(NumKernelNodes); + KernelBundleObjList(NumUpdatableNodes); size_t StructListIndex = 0; for (auto &Node : Nodes) { - if (Node->MCGType != sycl::detail::CGType::Kernel) { - continue; - } + // This should be the case when getURUpdatableNodes() is used to + // create the list of nodes. 
+ assert(Node->MCGType == sycl::detail::CGType::Kernel); auto &MemobjDescs = MemobjDescsList[StructListIndex]; auto &KernelBundleObjs = KernelBundleObjList[StructListIndex]; diff --git a/sycl/source/detail/graph_impl.hpp b/sycl/source/detail/graph_impl.hpp index b2eeac7b09cdf..303600a5dc611 100644 --- a/sycl/source/detail/graph_impl.hpp +++ b/sycl/source/detail/graph_impl.hpp @@ -1303,18 +1303,30 @@ class exec_graph_impl { void update(std::shared_ptr Node); void update(const std::vector> &Nodes); - /// Calls UR entry-point to update kernel nodes in command-buffer. + /// Calls UR entry-point to update nodes in command-buffer. /// @param CommandBuffer The UR command-buffer to update commands in. - /// @param Nodes List of nodes to update. May contain nodes of non-kernel - /// type, but only kernel nodes from the list will be used for update - void updateKernelsImpl(ur_exp_command_buffer_handle_t CommandBuffer, - const std::vector> &Nodes); + /// @param Nodes List of nodes to update. Only nodes which can be updated + /// through UR should be included in this list, currently this is only + /// nodes of kernel type. + void updateURImpl(ur_exp_command_buffer_handle_t CommandBuffer, + const std::vector> &Nodes) const; - /// Splits a list of nodes into separate lists depending on partition. + /// Update host-task nodes + /// @param Nodes List of nodes to update, any node that is not a host-task + /// will be ignored. + void updateHostTasksImpl( + const std::vector> &Nodes) const; + + /// Splits a list of nodes into separate lists of nodes for each + /// command-buffer partition. + /// + /// Only nodes that can be updated through the UR interface are included + /// in the list. Currently this is only kernel node types. 
+ /// /// @param Nodes List of nodes to split - /// @return Map of partitions to nodes - std::map, std::vector>> - getPartitionForNodes(const std::vector> &Nodes); + /// @return Map of partition indexes to nodes + std::map>> getURUpdatableNodes( + const std::vector> &Nodes) const; unsigned long long getID() const { return MID; } @@ -1408,13 +1420,7 @@ class exec_graph_impl { std::vector &PtrDescs, std::vector &ValueDescs, sycl::detail::NDRDescT &NDRDesc, - ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc); - - /// Updates host-task nodes in the graph - /// @param Nodes List of nodes to update, any node that is not a host-task - /// will be ignored. - void - updateHostTasksImpl(const std::vector> &Nodes); + ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) const; /// Execution schedule of nodes in the graph. std::list> MSchedule; diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index f8e8a02ee73d9..6ebc6352a8f13 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -3744,12 +3744,15 @@ ur_result_t UpdateCommandBufferCommand::enqueueImp() { } // Split list of nodes into nodes per UR command-buffer partition, then - // call UR update on each command-buffer partition - auto PartitionedNodes = MGraph->getPartitionForNodes(MNodes); + // call UR update on each command-buffer partition with those updatable + // nodes. + auto PartitionedNodes = MGraph->getURUpdatableNodes(MNodes); auto Device = MQueue->get_device(); + auto &Partitions = MGraph->getPartitions(); for (auto It = PartitionedNodes.begin(); It != PartitionedNodes.end(); It++) { - auto CommandBuffer = It->first->MCommandBuffers[Device]; - MGraph->updateKernelsImpl(CommandBuffer, It->second); + const int PartitionIndex = It->first; + auto CommandBuffer = Partitions[PartitionIndex]->MCommandBuffers[Device]; + MGraph->updateURImpl(CommandBuffer, It->second); } return UR_RESULT_SUCCESS;