Skip to content

Commit

Permalink
[SYCL][Graph] Batch graph updates
Browse files Browse the repository at this point in the history
Uses PR oneapi-src/unified-runtime#2666 to
pass a list of updates to UR in a single host call per
UR command-buffer in the graph, rather than
making N calls to UR for N nodes to update.
  • Loading branch information
EwanC committed Feb 18, 2025
1 parent 2a9ca31 commit 81e25a9
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 56 deletions.
2 changes: 1 addition & 1 deletion sycl/cmake/modules/FetchUnifiedRuntime.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ elseif(SYCL_UR_USE_FETCH_CONTENT)
CACHE PATH "Path to external '${name}' adapter source dir" FORCE)
endfunction()

set(UNIFIED_RUNTIME_REPO "https://github.com/oneapi-src/unified-runtime.git")
set(UNIFIED_RUNTIME_REPO "https://github.com/Bensuo/unified-runtime.git")
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules/UnifiedRuntimeTag.cmake)

set(UMF_BUILD_EXAMPLES OFF CACHE INTERNAL "EXAMPLES")
Expand Down
2 changes: 1 addition & 1 deletion sycl/cmake/modules/UnifiedRuntimeTag.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
# Date: Thu Feb 13 11:43:34 2025 +0000
# Merge pull request #2680 from ldorau/Set_UMF_CUDA_INCLUDE_DIR_to_not_fetch_cudart_from_gitlab
# Do not fetch cudart from gitlab for UMF
set(UNIFIED_RUNTIME_TAG d03f19a88e42cb98be9604ff24b61190d1e48727)
set(UNIFIED_RUNTIME_TAG "ewan/update_list")
Original file line number Diff line number Diff line change
Expand Up @@ -1396,6 +1396,7 @@ Exceptions:
created.
* Throws with error code `invalid` if `node` is not part of the
graph.
* If any other exception is thrown the state of the graph node is undefined.

|
[source,c++]
Expand Down Expand Up @@ -1430,6 +1431,7 @@ Exceptions:
`property::graph::updatable` was not set when the executable graph was created.
* Throws with error code `invalid` if any node in `nodes` is not part of the
graph.
* If any other exception is thrown the state of the graph nodes is undefined.

|
[source, c++]
Expand Down Expand Up @@ -1482,6 +1484,8 @@ Exceptions:
* Throws synchronously with error code `invalid` if
`property::graph::updatable` was not set when the executable graph was
created.

* If any other exception is thrown the state of the graph nodes is undefined.
|===

Table {counter: tableNumber}. Member functions of the `command_graph` class for
Expand Down
200 changes: 148 additions & 52 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1375,20 +1375,14 @@ void exec_graph_impl::update(std::shared_ptr<node_impl> Node) {
this->update(std::vector<std::shared_ptr<node_impl>>{Node});
}

void exec_graph_impl::update(
const std::vector<std::shared_ptr<node_impl>> Nodes) {

if (!MIsUpdatable) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"update() cannot be called on a executable graph "
"which was not created with property::updatable");
}
bool exec_graph_impl::needsScheduledUpdate(
const std::vector<std::shared_ptr<node_impl>> &Nodes,
std::vector<sycl::detail::AccessorImplHost *> &UpdateRequirements) {

// If there are any accessor requirements, we have to update through the
// scheduler to ensure that any allocations have taken place before trying to
// update.
bool NeedScheduledUpdate = false;
std::vector<sycl::detail::AccessorImplHost *> UpdateRequirements;
// At worst we may have as many requirements as there are for the entire graph
// for updating.
UpdateRequirements.reserve(MRequirements.size());
Expand Down Expand Up @@ -1431,37 +1425,47 @@ void exec_graph_impl::update(
// ensure it is ordered correctly.
NeedScheduledUpdate |= MExecutionEvents.size() > 0;

if (NeedScheduledUpdate) {
auto AllocaQueue = std::make_shared<sycl::detail::queue_impl>(
sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()),
sycl::detail::getSyclObjImpl(MGraphImpl->getContext()),
sycl::async_handler{}, sycl::property_list{});
// Don't need to care about the return event here because it is synchronous
sycl::detail::Scheduler::getInstance().addCommandGraphUpdate(
this, Nodes, AllocaQueue, UpdateRequirements, MExecutionEvents);
} else {
return NeedScheduledUpdate;
}

std::map<std::shared_ptr<partition>, std::vector<std::shared_ptr<node_impl>>>
exec_graph_impl::getPartitionForNodes(
const std::vector<std::shared_ptr<node_impl>> &Nodes) {
// Iterate over each partition in the executable graph, and find the nodes
// in "Nodes" that also exist in the partition.
std::map<std::shared_ptr<partition>, std::vector<std::shared_ptr<node_impl>>>
PartitionedNodes;
for (const auto &Partition : MPartitions) {
std::vector<std::shared_ptr<node_impl>> NodesForPartition;
const auto PartitionBegin = Partition->MSchedule.begin();
const auto PartitionEnd = Partition->MSchedule.end();
for (auto &Node : Nodes) {
updateImpl(Node);
auto ExecNode = MIDCache.find(Node->MID);
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");

if (std::find_if(PartitionBegin, PartitionEnd,
[ExecNode](const auto &PartitionNode) {
return PartitionNode->MID == ExecNode->second->MID;
}) != PartitionEnd) {
NodesForPartition.push_back(Node);
}
}
if (!NodesForPartition.empty()) {
PartitionedNodes.insert({Partition, NodesForPartition});
}
}

// Rebuild cached requirements for this graph with updated nodes
MRequirements.clear();
for (auto &Node : MNodeStorage) {
if (!Node->MCommandGroup)
continue;
MRequirements.insert(MRequirements.end(),
Node->MCommandGroup->getRequirements().begin(),
Node->MCommandGroup->getRequirements().end());
}
return PartitionedNodes;
}

void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
// Kernel node update is the only command type supported in UR for update.
// Updating any other types of nodes, e.g. empty & barrier nodes is a no-op.
if (Node->MCGType != sycl::detail::CGType::Kernel) {
return;
}
void exec_graph_impl::populateUpdateStruct(
const std::shared_ptr<node_impl> &Node,
std::pair<ur_program_handle_t, ur_kernel_handle_t> &BundleObjs,
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> &MemobjDescs,
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
std::vector<ur_exp_command_buffer_update_value_arg_desc_t> &ValueDescs,
sycl::detail::NDRDescT &NDRDesc,
ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) {
auto ContextImpl = sycl::detail::getSyclObjImpl(MContext);
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
auto DeviceImpl = sycl::detail::getSyclObjImpl(MGraphImpl->getDevice());
Expand All @@ -1472,9 +1476,8 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
// Copy args because we may modify them
std::vector<sycl::detail::ArgDesc> NodeArgs = ExecCG.getArguments();
// Copy NDR desc since we need to modify it
auto NDRDesc = ExecCG.MNDRDesc;
NDRDesc = ExecCG.MNDRDesc;

ur_program_handle_t UrProgram = nullptr;
ur_kernel_handle_t UrKernel = nullptr;
auto Kernel = ExecCG.MSyclKernel;
auto KernelBundleImplPtr = ExecCG.MKernelBundle;
Expand All @@ -1499,9 +1502,11 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
UrKernel = Kernel->getHandleRef();
EliminatedArgMask = Kernel->getKernelArgMask();
} else {
ur_program_handle_t UrProgram = nullptr;
std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) =
sycl::detail::ProgramManager::getInstance().getOrCreateKernel(
ContextImpl, DeviceImpl, ExecCG.MKernelName);
BundleObjs = std::make_pair(UrProgram, UrKernel);
}

// Remove eliminated args
Expand Down Expand Up @@ -1535,17 +1540,12 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
if (EnforcedLocalSize)
LocalSize = RequiredWGSize;
}
// Create update descriptor

// Storage for individual arg descriptors
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> MemobjDescs;
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> PtrDescs;
std::vector<ur_exp_command_buffer_update_value_arg_desc_t> ValueDescs;
MemobjDescs.reserve(MaskedArgs.size());
PtrDescs.reserve(MaskedArgs.size());
ValueDescs.reserve(MaskedArgs.size());

ur_exp_command_buffer_update_kernel_launch_desc_t UpdateDesc{};
UpdateDesc.stype =
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC;
UpdateDesc.pNext = nullptr;
Expand Down Expand Up @@ -1622,24 +1622,120 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
auto ExecNode = MIDCache.find(Node->MID);
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");

ur_exp_command_buffer_command_handle_t Command =
MCommandMap[ExecNode->second];
UpdateDesc.hCommand = Command;

// Update ExecNode with new values from Node, in case we ever need to
// rebuild the command buffers
ExecNode->second->updateFromOtherNode(Node);
}

ur_exp_command_buffer_command_handle_t Command =
MCommandMap[ExecNode->second];
ur_result_t Res = Adapter->call_nocheck<
sycl::detail::UrApiKind::urCommandBufferUpdateKernelLaunchExp>(
Command, &UpdateDesc);
void exec_graph_impl::update(
const std::vector<std::shared_ptr<node_impl>> Nodes) {

if (UrProgram) {
// We retained these objects by calling getOrCreateKernel()
Adapter->call<sycl::detail::UrApiKind::urKernelRelease>(UrKernel);
Adapter->call<sycl::detail::UrApiKind::urProgramRelease>(UrProgram);
if (!MIsUpdatable) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"update() cannot be called on a executable graph "
"which was not created with property::updatable");
}

if (Res != UR_RESULT_SUCCESS) {
throw sycl::exception(errc::invalid, "Error updating command_graph");
// If there are any accessor requirements, we have to update through the
// scheduler to ensure that any allocations have taken place before trying
// to update.
std::vector<sycl::detail::AccessorImplHost *> UpdateRequirements;
bool NeedScheduledUpdate = needsScheduledUpdate(Nodes, UpdateRequirements);
if (NeedScheduledUpdate) {
auto AllocaQueue = std::make_shared<sycl::detail::queue_impl>(
sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()),
sycl::detail::getSyclObjImpl(MGraphImpl->getContext()),
sycl::async_handler{}, sycl::property_list{});
// Don't need to care about the return event here because it is
// synchronous
sycl::detail::Scheduler::getInstance().addCommandGraphUpdate(
this, Nodes, AllocaQueue, UpdateRequirements, MExecutionEvents);
} else {
// For each partition in the executable graph, call UR update on the
// command-buffer with the nodes to update.
auto PartitionedNodes = getPartitionForNodes(Nodes);
for (auto It = PartitionedNodes.begin(); It != PartitionedNodes.end();
It++) {
auto CommandBuffer = It->first->MCommandBuffers[MDevice];
updateImpl(CommandBuffer, It->second);
}
}

// Rebuild cached requirements for this graph with updated nodes
MRequirements.clear();
for (auto &Node : MNodeStorage) {
if (!Node->MCommandGroup)
continue;
MRequirements.insert(MRequirements.end(),
Node->MCommandGroup->getRequirements().begin(),
Node->MCommandGroup->getRequirements().end());
}
}

void exec_graph_impl::updateImpl(
ur_exp_command_buffer_handle_t CommandBuffer,
std::vector<std::shared_ptr<node_impl>> &Nodes) {
// Kernel node update is the only command type supported in UR for update.
// Updating any other types of nodes, e.g. empty & barrier nodes is a no-op.
size_t NumKernelNodes = 0;
for (auto &N : Nodes) {
if (N->MCGType == sycl::detail::CGType::Kernel) {
NumKernelNodes++;
}
}

// Don't need to call through to UR if no kernel nodes to update
if (NumKernelNodes == 0) {
return;
}

std::vector<std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t>>
MemobjDescsList(NumKernelNodes);
std::vector<std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t>>
PtrDescsList(NumKernelNodes);
std::vector<std::vector<ur_exp_command_buffer_update_value_arg_desc_t>>
ValueDescsList(NumKernelNodes);
std::vector<sycl::detail::NDRDescT> NDRDescList(NumKernelNodes);
std::vector<ur_exp_command_buffer_update_kernel_launch_desc_t> UpdateDescList(
NumKernelNodes);
std::vector<std::pair<ur_program_handle_t, ur_kernel_handle_t>>
KernelBundleObjList(NumKernelNodes);

size_t StructListIndex = 0;
for (auto &Node : Nodes) {
if (Node->MCGType != sycl::detail::CGType::Kernel) {
continue;
}

auto &MemobjDescs = MemobjDescsList[StructListIndex];
auto &KernelBundleObjs = KernelBundleObjList[StructListIndex];
auto &PtrDescs = PtrDescsList[StructListIndex];
auto &ValueDescs = ValueDescsList[StructListIndex];
auto &NDRDesc = NDRDescList[StructListIndex];
auto &UpdateDesc = UpdateDescList[StructListIndex];
populateUpdateStruct(Node, KernelBundleObjs, MemobjDescs, PtrDescs,
ValueDescs, NDRDesc, UpdateDesc);
StructListIndex++;
}

auto ContextImpl = sycl::detail::getSyclObjImpl(MContext);
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
Adapter->call<sycl::detail::UrApiKind::urCommandBufferUpdateKernelLaunchExp>(
CommandBuffer, UpdateDescList.size(), UpdateDescList.data());

for (auto &BundleObjs : KernelBundleObjList) {
// We retained these objects by inside populateUpdateStruct() by calling
// getOrCreateKernel()
if (auto &UrKernel = BundleObjs.second; nullptr != UrKernel) {
Adapter->call<sycl::detail::UrApiKind::urKernelRelease>(UrKernel);
}
if (auto &UrProgram = BundleObjs.first; nullptr != UrProgram) {
Adapter->call<sycl::detail::UrApiKind::urProgramRelease>(UrProgram);
}
}
}

Expand Down
38 changes: 37 additions & 1 deletion sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1301,7 +1301,17 @@ class exec_graph_impl {
void update(std::shared_ptr<node_impl> Node);
void update(const std::vector<std::shared_ptr<node_impl>> Nodes);

void updateImpl(std::shared_ptr<node_impl> NodeImpl);
/// Calls UR entry-point to update kernel nodes in command-buffer.
/// @param CommandBuffer The UR command-buffer to update commands in.
/// @param Nodes List of nodes to update.
void updateImpl(ur_exp_command_buffer_handle_t CommandBuffer,
std::vector<std::shared_ptr<node_impl>> &Nodes);

/// Splits a list of nodes into separate lists depending on partition.
/// @param Nodes List of nodes to split
/// @return Map of partitions to nodes
std::map<std::shared_ptr<partition>, std::vector<std::shared_ptr<node_impl>>>
getPartitionForNodes(const std::vector<std::shared_ptr<node_impl>> &Nodes);

unsigned long long getID() const { return MID; }

Expand Down Expand Up @@ -1371,6 +1381,32 @@ class exec_graph_impl {
Stream.close();
}

/// Determines if scheduler needs to be used for node update.
/// @param[in] Nodes List of nodes to be updated
/// @param[out] UpdateRequirements Accessor requirements found in /p Nodes.
/// return True if update should be done through the scheduler.
bool needsScheduledUpdate(
const std::vector<std::shared_ptr<node_impl>> &Nodes,
std::vector<sycl::detail::AccessorImplHost *> &UpdateRequirements);

/// Sets the UR struct values required to update a graph node.
/// @param[in] Node The node to be updated.
/// @param[out] BundleObjs UR objects created from kernel bundle.
/// Responsibility of the caller to release.
/// @param[out] MemobjDescs Memory object arguments to update.
/// @param[out] PtrDescs Pointer arguments to update.
/// @param[out] ValueDescs Value arguments to update.
/// @param[out] NDRDesc ND-Range to update.
/// @param[out] UpdateDesc Base struct in the pointer chain.
void populateUpdateStruct(
const std::shared_ptr<node_impl> &Node,
std::pair<ur_program_handle_t, ur_kernel_handle_t> &BundleObjs,
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> &MemobjDescs,
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
std::vector<ur_exp_command_buffer_update_value_arg_desc_t> &ValueDescs,
sycl::detail::NDRDescT &NDRDesc,
ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc);

/// Execution schedule of nodes in the graph.
std::list<std::shared_ptr<node_impl>> MSchedule;
/// Pointer to the modifiable graph impl associated with this executable
Expand Down
10 changes: 9 additions & 1 deletion sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3729,7 +3729,15 @@ ur_result_t UpdateCommandBufferCommand::enqueueImp() {
}
}
}
MGraph->updateImpl(Node);
}

// Split list of nodes into nodes per UR command-buffer partition, then
// call UR update on each command-buffer partition
auto PartitionedNodes = MGraph->getPartitionForNodes(MNodes);
auto Device = MQueue->get_device();
for (auto It = PartitionedNodes.begin(); It != PartitionedNodes.end(); It++) {
auto CommandBuffer = It->first->MCommandBuffers[Device];
MGraph->updateImpl(CommandBuffer, It->second);
}

return UR_RESULT_SUCCESS;
Expand Down

0 comments on commit 81e25a9

Please sign in to comment.