[SYCL][Graph][UR] Propagate graph update list to UR
Update the `urCommandBufferUpdateKernelLaunchExp` API for updating commands in a command-buffer to take a list of commands.

The current API operates on a single command, which means that the SYCL-Graph `update(std::vector<nodes>)` API needs to serialize the list into N calls to the UR API. Given that both OpenCL `clUpdateMutableCommandsKHR` and Level-Zero
`zeCommandListUpdateMutableCommandsExp` can operate on a list of commands, this serialization at the UR layer of the stack introduces extra host API calls.

This PR improves the `urCommandBufferUpdateKernelLaunchExp` API so that a list of commands is passed all the way from SYCL to the native backend API.
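
To illustrate the host-call saving described above, here is a minimal sketch that uses a mock backend. `MockBackend`, `MockUpdateDesc`, `updateSerialized` and `updateBatched` are hypothetical names invented for this example; they are not UR, OpenCL or Level Zero APIs, and the real entry points take different arguments.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical update descriptor: which command to update and its new value.
struct MockUpdateDesc {
  std::size_t CommandId; // index of the command inside the command-buffer
  int NewArgValue;       // new value for an illustrative kernel argument
};

// Hypothetical backend that counts how many host API calls it receives.
struct MockBackend {
  std::size_t HostCalls = 0;
  std::vector<int> Args; // current argument value per command

  explicit MockBackend(std::size_t NumCommands) : Args(NumCommands, 0) {}

  // Models a native entry point that accepts a list of updates in one call.
  void updateCommands(const MockUpdateDesc *Descs, std::size_t Count) {
    ++HostCalls;
    for (std::size_t I = 0; I < Count; ++I)
      Args[Descs[I].CommandId] = Descs[I].NewArgValue;
  }
};

// Old shape: the list is serialized into one backend call per command.
void updateSerialized(MockBackend &Backend,
                      const std::vector<MockUpdateDesc> &Descs) {
  for (const auto &Desc : Descs)
    Backend.updateCommands(&Desc, 1);
}

// New shape: the whole list is forwarded in a single backend call.
void updateBatched(MockBackend &Backend,
                   const std::vector<MockUpdateDesc> &Descs) {
  if (Descs.empty())
    return; // nothing to update, so no host call is made
  Backend.updateCommands(Descs.data(), Descs.size());
}
```

Both functions leave the mock backend in the same final state; only the number of host calls differs (N versus 1 for a list of N updates).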

As highlighted in oneapi-src/unified-runtime#2671, this patch requires the auto-generated handle translation code to be able to handle a list of structs, which is not currently the case. This PR includes an API-specific temporary workaround in the mako file, which will unblock this PR until a more permanent solution is completed.

Co-authored-by: Ross Brunton <[email protected]>
EwanC and RossBrunton committed Feb 19, 2025
1 parent 42a9485 commit 1fd571d
Showing 32 changed files with 1,576 additions and 1,024 deletions.
@@ -1431,6 +1431,7 @@ Exceptions:
created.
* Throws with error code `invalid` if `node` is not part of the
graph.
* If any other exception is thrown the state of the graph node is undefined.

|
[source,c++]
@@ -1465,6 +1466,7 @@ Exceptions:
`property::graph::updatable` was not set when the executable graph was created.
* Throws with error code `invalid` if any node in `nodes` is not part of the
graph.
* If any other exception is thrown the state of the graph nodes is undefined.

|
[source, c++]
@@ -1517,6 +1519,8 @@ Exceptions:
* Throws synchronously with error code `invalid` if
`property::graph::updatable` was not set when the executable graph was
created.

* If any other exception is thrown the state of the graph nodes is undefined.
|===

Table {counter: tableNumber}. Member functions of the `command_graph` class for
296 changes: 192 additions & 104 deletions sycl/source/detail/graph_impl.cpp
@@ -1381,18 +1381,74 @@ void exec_graph_impl::update(std::shared_ptr<node_impl> Node) {

void exec_graph_impl::update(
const std::vector<std::shared_ptr<node_impl>> &Nodes) {

if (!MIsUpdatable) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"update() cannot be called on an executable graph "
"which was not created with property::updatable");
}

// If the graph contains host tasks we need special handling here because
// their state lives in the graph object itself, so we must do the update
// immediately here. Whereas all other command state lives in the backend so
// it can be scheduled along with other commands.
if (MContainsHostTask) {
updateHostTasksImpl(Nodes);
}

// If there are any accessor requirements, we have to update through the
// scheduler to ensure that any allocations have taken place before trying
// to update.
std::vector<sycl::detail::AccessorImplHost *> UpdateRequirements;
bool NeedScheduledUpdate = needsScheduledUpdate(Nodes, UpdateRequirements);
if (NeedScheduledUpdate) {
auto AllocaQueue = std::make_shared<sycl::detail::queue_impl>(
sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()),
sycl::detail::getSyclObjImpl(MGraphImpl->getContext()),
sycl::async_handler{}, sycl::property_list{});

// Track the event for the update command since execution may be blocked by
// other scheduler commands
auto UpdateEvent =
sycl::detail::Scheduler::getInstance().addCommandGraphUpdate(
this, Nodes, AllocaQueue, UpdateRequirements, MExecutionEvents);

MExecutionEvents.push_back(UpdateEvent);
} else {
// For each partition in the executable graph, call UR update on the
// command-buffer with the nodes to update.
auto PartitionedNodes = getPartitionForNodes(Nodes);
for (auto It = PartitionedNodes.begin(); It != PartitionedNodes.end();
It++) {
auto CommandBuffer = It->first->MCommandBuffers[MDevice];
updateKernelsImpl(CommandBuffer, It->second);
}
}

// Rebuild cached requirements and accessor storage for this graph with
// updated nodes
MRequirements.clear();
MAccessors.clear();
for (auto &Node : MNodeStorage) {
if (!Node->MCommandGroup)
continue;
MRequirements.insert(MRequirements.end(),
Node->MCommandGroup->getRequirements().begin(),
Node->MCommandGroup->getRequirements().end());
MAccessors.insert(MAccessors.end(),
Node->MCommandGroup->getAccStorage().begin(),
Node->MCommandGroup->getAccStorage().end());
}
}

bool exec_graph_impl::needsScheduledUpdate(
const std::vector<std::shared_ptr<node_impl>> &Nodes,
std::vector<sycl::detail::AccessorImplHost *> &UpdateRequirements) {
// If there are any accessor requirements, we have to update through the
// scheduler to ensure that any allocations have taken place before trying to
// update.
bool NeedScheduledUpdate = false;
std::vector<sycl::detail::AccessorImplHost *> UpdateRequirements;
// At worst we may have as many requirements as there are for the entire graph
// for updating.
UpdateRequirements.reserve(MRequirements.size());
@@ -1435,94 +1491,17 @@ void exec_graph_impl::update(
// ensure it is ordered correctly.
NeedScheduledUpdate |= MExecutionEvents.size() > 0;

if (NeedScheduledUpdate) {
// Copy the list of nodes as we may need to modify it
auto NodesCopy = Nodes;

// If the graph contains host tasks we need special handling here because
// their state lives in the graph object itself, so we must do the update
// immediately here. Whereas all other command state lives in the backend so
// it can be scheduled along with other commands.
if (MContainsHostTask) {
std::vector<std::shared_ptr<node_impl>> HostTasks;
// Remove any nodes that are host tasks and put them in HostTasks
auto RemovedIter = std::remove_if(
NodesCopy.begin(), NodesCopy.end(),
[&HostTasks](const std::shared_ptr<node_impl> &Node) -> bool {
if (Node->MNodeType == node_type::host_task) {
HostTasks.push_back(Node);
return true;
}
return false;
});
// Clean up extra elements in NodesCopy after the remove
NodesCopy.erase(RemovedIter, NodesCopy.end());

// Update host-tasks synchronously
for (auto &HostTaskNode : HostTasks) {
updateImpl(HostTaskNode);
}
}

auto AllocaQueue = std::make_shared<sycl::detail::queue_impl>(
sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()),
sycl::detail::getSyclObjImpl(MGraphImpl->getContext()),
sycl::async_handler{}, sycl::property_list{});

// Track the event for the update command since execution may be blocked by
// other scheduler commands
auto UpdateEvent =
sycl::detail::Scheduler::getInstance().addCommandGraphUpdate(
this, std::move(NodesCopy), AllocaQueue, UpdateRequirements,
MExecutionEvents);

MExecutionEvents.push_back(UpdateEvent);
} else {
for (auto &Node : Nodes) {
updateImpl(Node);
}
}

// Rebuild cached requirements and accessor storage for this graph with
// updated nodes
MRequirements.clear();
MAccessors.clear();
for (auto &Node : MNodeStorage) {
if (!Node->MCommandGroup)
continue;
MRequirements.insert(MRequirements.end(),
Node->MCommandGroup->getRequirements().begin(),
Node->MCommandGroup->getRequirements().end());
MAccessors.insert(MAccessors.end(),
Node->MCommandGroup->getAccStorage().begin(),
Node->MCommandGroup->getAccStorage().end());
}
return NeedScheduledUpdate;
}

void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
// Updating empty or barrier nodes is a no-op
if (Node->isEmpty() || Node->MNodeType == node_type::ext_oneapi_barrier) {
return;
}

// Query the ID cache to find the equivalent exec node for the node passed to
// this function.
// TODO: Handle subgraphs or any other cases where multiple nodes may be
// associated with a single key, once those node types are supported for
// update.
auto ExecNode = MIDCache.find(Node->MID);
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");

// Update ExecNode with new values from Node, in case we ever need to
// rebuild the command buffers
ExecNode->second->updateFromOtherNode(Node);

// Host task update only requires updating the node itself, so can return
// early
if (Node->MNodeType == node_type::host_task) {
return;
}

void exec_graph_impl::populateURKernelUpdateStructs(
const std::shared_ptr<node_impl> &Node,
std::pair<ur_program_handle_t, ur_kernel_handle_t> &BundleObjs,
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> &MemobjDescs,
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
std::vector<ur_exp_command_buffer_update_value_arg_desc_t> &ValueDescs,
sycl::detail::NDRDescT &NDRDesc,
ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) {
auto ContextImpl = sycl::detail::getSyclObjImpl(MContext);
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
auto DeviceImpl = sycl::detail::getSyclObjImpl(MGraphImpl->getDevice());
@@ -1533,9 +1512,8 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
// Copy args because we may modify them
std::vector<sycl::detail::ArgDesc> NodeArgs = ExecCG.getArguments();
// Copy NDR desc since we need to modify it
auto NDRDesc = ExecCG.MNDRDesc;
NDRDesc = ExecCG.MNDRDesc;

ur_program_handle_t UrProgram = nullptr;
ur_kernel_handle_t UrKernel = nullptr;
auto Kernel = ExecCG.MSyclKernel;
auto KernelBundleImplPtr = ExecCG.MKernelBundle;
@@ -1560,9 +1538,11 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
UrKernel = Kernel->getHandleRef();
EliminatedArgMask = Kernel->getKernelArgMask();
} else {
ur_program_handle_t UrProgram = nullptr;
std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) =
sycl::detail::ProgramManager::getInstance().getOrCreateKernel(
ContextImpl, DeviceImpl, ExecCG.MKernelName);
BundleObjs = std::make_pair(UrProgram, UrKernel);
}

// Remove eliminated args
@@ -1596,17 +1576,12 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
if (EnforcedLocalSize)
LocalSize = RequiredWGSize;
}
// Create update descriptor

// Storage for individual arg descriptors
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> MemobjDescs;
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> PtrDescs;
std::vector<ur_exp_command_buffer_update_value_arg_desc_t> ValueDescs;
MemobjDescs.reserve(MaskedArgs.size());
PtrDescs.reserve(MaskedArgs.size());
ValueDescs.reserve(MaskedArgs.size());

ur_exp_command_buffer_update_kernel_launch_desc_t UpdateDesc{};
UpdateDesc.stype =
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC;
UpdateDesc.pNext = nullptr;
@@ -1675,20 +1650,133 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
UpdateDesc.pNewLocalWorkSize = LocalSize;
UpdateDesc.newWorkDim = NDRDesc.Dims;

// Query the ID cache to find the equivalent exec node for the node passed to
// this function.
// TODO: Handle subgraphs or any other cases where multiple nodes may be
// associated with a single key, once those node types are supported for
// update.
auto ExecNode = MIDCache.find(Node->MID);
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");

ur_exp_command_buffer_command_handle_t Command =
MCommandMap[ExecNode->second];
ur_result_t Res = Adapter->call_nocheck<
sycl::detail::UrApiKind::urCommandBufferUpdateKernelLaunchExp>(
Command, &UpdateDesc);
UpdateDesc.hCommand = Command;

// Update ExecNode with new values from Node, in case we ever need to
// rebuild the command buffers
ExecNode->second->updateFromOtherNode(Node);
}

std::map<std::shared_ptr<partition>, std::vector<std::shared_ptr<node_impl>>>
exec_graph_impl::getPartitionForNodes(
const std::vector<std::shared_ptr<node_impl>> &Nodes) {
// Iterate over each partition in the executable graph, and find the nodes
// in "Nodes" that also exist in the partition.
std::map<std::shared_ptr<partition>, std::vector<std::shared_ptr<node_impl>>>
PartitionedNodes;
for (const auto &Partition : MPartitions) {
std::vector<std::shared_ptr<node_impl>> NodesForPartition;
const auto PartitionBegin = Partition->MSchedule.begin();
const auto PartitionEnd = Partition->MSchedule.end();
for (auto &Node : Nodes) {
auto ExecNode = MIDCache.find(Node->MID);
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");

if (std::find_if(PartitionBegin, PartitionEnd,
[ExecNode](const auto &PartitionNode) {
return PartitionNode->MID == ExecNode->second->MID;
}) != PartitionEnd) {
NodesForPartition.push_back(Node);
}
}
if (!NodesForPartition.empty()) {
PartitionedNodes.insert({Partition, NodesForPartition});
}
}

return PartitionedNodes;
}
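
The grouping performed by `getPartitionForNodes()` can be sketched in isolation as follows. This is a simplified illustration, not the code from this commit: `MockPartition` and `groupByPartition` are hypothetical names, nodes are reduced to plain integer IDs, and partitions are keyed by index rather than by `std::shared_ptr<partition>`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Hypothetical stand-in for a partition: just the IDs of its scheduled nodes.
struct MockPartition {
  std::vector<int> ScheduleIds;
};

// For each partition, collect the requested node IDs that are scheduled in it.
// Partitions containing none of the requested nodes get no map entry.
std::map<std::size_t, std::vector<int>>
groupByPartition(const std::vector<MockPartition> &Partitions,
                 const std::vector<int> &NodeIds) {
  std::map<std::size_t, std::vector<int>> Grouped;
  for (std::size_t P = 0; P < Partitions.size(); ++P) {
    const auto &Schedule = Partitions[P].ScheduleIds;
    for (int Id : NodeIds) {
      if (std::find(Schedule.begin(), Schedule.end(), Id) != Schedule.end())
        Grouped[P].push_back(Id); // entry is created on the first match only
    }
  }
  return Grouped;
}
```

Each map entry then corresponds to one batched update call on that partition's command-buffer. The nested linear search mirrors the shape of the diff; a lookup table from node ID to partition would be an alternative if the node lists were large.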

void exec_graph_impl::updateHostTasksImpl(
const std::vector<std::shared_ptr<node_impl>> &Nodes) {
for (auto &Node : Nodes) {
if (Node->MNodeType != node_type::host_task) {
continue;
}
// Query the ID cache to find the equivalent exec node for the node passed
// to this function.
// TODO: Handle subgraphs or any other cases where multiple nodes may be
// associated with a single key, once those node types are supported for
// update.
auto ExecNode = MIDCache.find(Node->MID);
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");

if (UrProgram) {
// We retained these objects by calling getOrCreateKernel()
Adapter->call<sycl::detail::UrApiKind::urKernelRelease>(UrKernel);
Adapter->call<sycl::detail::UrApiKind::urProgramRelease>(UrProgram);
// Update ExecNode with new values from Node, in case we ever need to
// rebuild the command buffers
ExecNode->second->updateFromOtherNode(Node);
}
}

if (Res != UR_RESULT_SUCCESS) {
throw sycl::exception(errc::invalid, "Error updating command_graph");
void exec_graph_impl::updateKernelsImpl(
ur_exp_command_buffer_handle_t CommandBuffer,
const std::vector<std::shared_ptr<node_impl>> &Nodes) {
// Kernel node update is the only command type supported in UR for update.
// Updating any other types of nodes, e.g. empty & barrier nodes is a no-op.
size_t NumKernelNodes = 0;
for (auto &N : Nodes) {
if (N->MCGType == sycl::detail::CGType::Kernel) {
NumKernelNodes++;
}
}

// Don't need to call through to UR if no kernel nodes to update
if (NumKernelNodes == 0) {
return;
}

std::vector<std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t>>
MemobjDescsList(NumKernelNodes);
std::vector<std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t>>
PtrDescsList(NumKernelNodes);
std::vector<std::vector<ur_exp_command_buffer_update_value_arg_desc_t>>
ValueDescsList(NumKernelNodes);
std::vector<sycl::detail::NDRDescT> NDRDescList(NumKernelNodes);
std::vector<ur_exp_command_buffer_update_kernel_launch_desc_t> UpdateDescList(
NumKernelNodes);
std::vector<std::pair<ur_program_handle_t, ur_kernel_handle_t>>
KernelBundleObjList(NumKernelNodes);

size_t StructListIndex = 0;
for (auto &Node : Nodes) {
if (Node->MCGType != sycl::detail::CGType::Kernel) {
continue;
}

auto &MemobjDescs = MemobjDescsList[StructListIndex];
auto &KernelBundleObjs = KernelBundleObjList[StructListIndex];
auto &PtrDescs = PtrDescsList[StructListIndex];
auto &ValueDescs = ValueDescsList[StructListIndex];
auto &NDRDesc = NDRDescList[StructListIndex];
auto &UpdateDesc = UpdateDescList[StructListIndex];
populateURKernelUpdateStructs(Node, KernelBundleObjs, MemobjDescs, PtrDescs,
ValueDescs, NDRDesc, UpdateDesc);
StructListIndex++;
}

auto ContextImpl = sycl::detail::getSyclObjImpl(MContext);
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
Adapter->call<sycl::detail::UrApiKind::urCommandBufferUpdateKernelLaunchExp>(
CommandBuffer, UpdateDescList.size(), UpdateDescList.data());

for (auto &BundleObjs : KernelBundleObjList) {
// We retained these objects inside populateURKernelUpdateStructs() by
// calling getOrCreateKernel()
if (auto &UrKernel = BundleObjs.second; nullptr != UrKernel) {
Adapter->call<sycl::detail::UrApiKind::urKernelRelease>(UrKernel);
}
if (auto &UrProgram = BundleObjs.first; nullptr != UrProgram) {
Adapter->call<sycl::detail::UrApiKind::urProgramRelease>(UrProgram);
}
}
}
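
One detail of `updateKernelsImpl()` worth spelling out: the per-node storage vectors are sized up front, presumably so that pointers stored in each update descriptor remain valid until the single batched call is made. The sketch below shows that pattern with hypothetical types (`MockNode`, `MockLaunchDesc`, `MockBatch`, `buildBatch`); it assumes each descriptor holds a raw pointer into storage owned elsewhere, which is an inference from the diff rather than something the commit message states.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical node: only kernel nodes take part in the update.
struct MockNode {
  bool IsKernel;
  int NewArg;
};

// Hypothetical descriptor holding a raw pointer into externally owned storage.
struct MockLaunchDesc {
  const int *Args = nullptr;
  std::size_t NumArgs = 0;
};

// Owns both the argument storage and the descriptors that point into it.
// Note: copying a MockBatch would leave the copy's descriptors pointing at
// the original's storage; moving it keeps the pointers valid.
struct MockBatch {
  std::vector<std::vector<int>> ArgStorage;
  std::vector<MockLaunchDesc> Descs;
};

MockBatch buildBatch(const std::vector<MockNode> &Nodes) {
  // First pass: count kernel nodes so all outer vectors can be sized once.
  std::size_t NumKernels = 0;
  for (const auto &Node : Nodes)
    if (Node.IsKernel)
      ++NumKernels;

  MockBatch Batch;
  Batch.ArgStorage.resize(NumKernels); // no outer reallocation after this
  Batch.Descs.resize(NumKernels);

  // Second pass: fill slot Index for each kernel node, skipping other nodes.
  std::size_t Index = 0;
  for (const auto &Node : Nodes) {
    if (!Node.IsKernel)
      continue;
    auto &Args = Batch.ArgStorage[Index];
    Args.push_back(Node.NewArg);
    Batch.Descs[Index].Args = Args.data(); // taken after Args is complete
    Batch.Descs[Index].NumArgs = Args.size();
    ++Index;
  }
  return Batch;
}
```

Growing the outer vectors with `push_back` inside the loop could reallocate them and move the inner vectors; sizing them before taking any addresses avoids having to reason about that.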
