[SYCL][NativeCPU] Fix kernel argument passing.
We were reading the kernel arguments at kernel execution time, but kernel
arguments are allowed to change between enqueuing and executing. Make
sure to create a copy of kernel arguments ahead of time.
hvdijk committed Feb 18, 2025
1 parent bae7012 commit 10ccd68
Showing 4 changed files with 87 additions and 60 deletions.
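
The hazard this commit fixes, in miniature: a task that reads its argument at execution time observes later updates, while a task that copies at enqueue time does not. A minimal standalone C++ sketch (illustrative only; none of these names come from the adapter):

#include <cstdio>
#include <functional>
#include <vector>

int main() {
  int arg = 1;
  std::vector<std::function<void()>> queue;

  // Buggy pattern: the task reads `arg` only when it finally runs.
  queue.push_back([&arg] { std::printf("by-reference sees %d\n", arg); });
  // Fixed pattern: the task captures a copy at enqueue time.
  queue.push_back([arg] { std::printf("by-value sees %d\n", arg); });

  arg = 2; // The application legally reuses the argument before execution...

  for (auto &task : queue)
    task(); // ...so only the by-value task still sees the enqueued value, 1.
}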
43 changes: 23 additions & 20 deletions unified-runtime/source/adapters/native_cpu/enqueue.cpp
@@ -106,9 +106,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
pLocalWorkSize);
auto &tp = hQueue->getDevice()->tp;
const size_t numParallelThreads = tp.num_threads();
hKernel->updateMemPool(numParallelThreads);
std::vector<std::future<void>> futures;
std::vector<std::function<void(size_t, ur_kernel_handle_t_)>> groups;
std::vector<std::function<void(size_t, ur_kernel_handle_t_ &)>> groups;
auto numWG0 = ndr.GlobalSize[0] / ndr.LocalSize[0];
auto numWG1 = ndr.GlobalSize[1] / ndr.LocalSize[1];
auto numWG2 = ndr.GlobalSize[2] / ndr.LocalSize[2];
@@ -119,16 +118,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
auto event = new ur_event_handle_t_(hQueue, UR_COMMAND_KERNEL_LAUNCH);
event->tick_start();

// Create a copy of the kernel and its arguments.
auto kernel = std::make_unique<ur_kernel_handle_t_>(*hKernel);
kernel->updateMemPool(numParallelThreads);

#ifndef NATIVECPU_USE_OCK
hKernel->handleLocalArgs(1, 0);
for (unsigned g2 = 0; g2 < numWG2; g2++) {
for (unsigned g1 = 0; g1 < numWG1; g1++) {
for (unsigned g0 = 0; g0 < numWG0; g0++) {
for (unsigned local2 = 0; local2 < ndr.LocalSize[2]; local2++) {
for (unsigned local1 = 0; local1 < ndr.LocalSize[1]; local1++) {
for (unsigned local0 = 0; local0 < ndr.LocalSize[0]; local0++) {
state.update(g0, g1, g2, local0, local1, local2);
hKernel->_subhandler(hKernel->getArgs().data(), &state);
kernel->_subhandler(kernel->getArgs(1, 0).data(), &state);
}
}
}
@@ -139,7 +141,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
bool isLocalSizeOne =
ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1;
if (isLocalSizeOne && ndr.GlobalSize[0] > numParallelThreads &&
!hKernel->hasLocalArgs()) {
!kernel->hasLocalArgs()) {
// If the local size is one, we make the assumption that we are running a
// parallel_for over a sycl::range.
// Todo: we could add more compiler checks and
@@ -160,7 +162,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
for (unsigned g1 = 0; g1 < numWG1; g1++) {
for (unsigned g0 = 0; g0 < new_num_work_groups_0; g0 += 1) {
futures.emplace_back(tp.schedule_task(
[ndr, itemsPerThread, kernel = *hKernel, g0, g1, g2](size_t) {
[ndr, itemsPerThread, &kernel = *kernel, g0, g1, g2](size_t) {
native_cpu::state resized_state =
getResizedState(ndr, itemsPerThread);
resized_state.update(g0, g1, g2);
@@ -172,7 +174,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
for (unsigned g0 = new_num_work_groups_0 * itemsPerThread; g0 < numWG0;
g0++) {
state.update(g0, g1, g2);
hKernel->_subhandler(hKernel->getArgs().data(), &state);
kernel->_subhandler(kernel->getArgs().data(), &state);
}
}
}
@@ -185,12 +187,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
for (unsigned g2 = 0; g2 < numWG2; g2++) {
for (unsigned g1 = 0; g1 < numWG1; g1++) {
futures.emplace_back(
tp.schedule_task([state, kernel = *hKernel, numWG0, g1, g2,
tp.schedule_task([state, &kernel = *kernel, numWG0, g1, g2,
numParallelThreads](size_t threadId) mutable {
for (unsigned g0 = 0; g0 < numWG0; g0++) {
kernel.handleLocalArgs(numParallelThreads, threadId);
state.update(g0, g1, g2);
kernel._subhandler(kernel.getArgs().data(), &state);
kernel._subhandler(
kernel.getArgs(numParallelThreads, threadId).data(),
&state);
}
}));
}
@@ -202,13 +205,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
for (unsigned g2 = 0; g2 < numWG2; g2++) {
for (unsigned g1 = 0; g1 < numWG1; g1++) {
for (unsigned g0 = 0; g0 < numWG0; g0++) {
groups.push_back(
[state, g0, g1, g2, numParallelThreads](
size_t threadId, ur_kernel_handle_t_ kernel) mutable {
kernel.handleLocalArgs(numParallelThreads, threadId);
state.update(g0, g1, g2);
kernel._subhandler(kernel.getArgs().data(), &state);
});
groups.push_back([state, g0, g1, g2, numParallelThreads](
size_t threadId,
ur_kernel_handle_t_ &kernel) mutable {
state.update(g0, g1, g2);
kernel._subhandler(
kernel.getArgs(numParallelThreads, threadId).data(), &state);
});
}
}
}
@@ -218,7 +221,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
for (unsigned thread = 0; thread < numParallelThreads; thread++) {
futures.emplace_back(
tp.schedule_task([groups, thread, groupsPerThread,
kernel = *hKernel](size_t threadId) {
&kernel = *kernel](size_t threadId) {
for (unsigned i = 0; i < groupsPerThread; i++) {
auto index = thread * groupsPerThread + i;
groups[index](threadId, kernel);
@@ -231,7 +234,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
futures.emplace_back(
tp.schedule_task([groups, remainder,
scheduled = numParallelThreads * groupsPerThread,
kernel = *hKernel](size_t threadId) {
&kernel = *kernel](size_t threadId) {
for (unsigned i = 0; i < remainder; i++) {
auto index = scheduled + i;
groups[index](threadId, kernel);
@@ -247,7 +250,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
if (phEvent) {
*phEvent = event;
}
event->set_callback([hKernel, event]() {
event->set_callback([kernel = std::move(kernel), hKernel, event]() {
event->tick_end();
// TODO: avoid calling clear() here.
hKernel->_localArgInfo.clear();
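
The scheduling at the end of this file flattens the work groups into `groups` and divides them evenly across the pool, with one trailing task for the leftover groups. A standalone sketch of that partitioning arithmetic (hypothetical names, not adapter code):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const size_t numGroups = 10, numThreads = 4;
  const size_t groupsPerThread = numGroups / numThreads; // 2 groups per task
  const size_t remainder = numGroups % numThreads;       // 2 groups left over

  std::vector<size_t> owner(numGroups);
  // One task per thread takes a contiguous block of groups...
  for (size_t thread = 0; thread < numThreads; thread++)
    for (size_t i = 0; i < groupsPerThread; i++)
      owner[thread * groupsPerThread + i] = thread;
  // ...and a single final task sweeps up the remainder.
  const size_t scheduled = numThreads * groupsPerThread;
  for (size_t i = 0; i < remainder; i++)
    owner[scheduled + i] = numThreads; // one extra task beyond the per-thread ones

  for (size_t g = 0; g < numGroups; g++)
    std::printf("group %zu -> task %zu\n", g, owner[g]);
}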
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/native_cpu/event.cpp
@@ -146,7 +146,7 @@ void ur_event_handle_t_::wait() {
// The callback may need to acquire the lock, so we unlock it here
lock.unlock();

if (callback)
if (callback.valid())
callback();
}

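
Context for the `callback.valid()` check above: std::function is testable with `if (callback)`, but std::packaged_task has no operator bool; valid() is the equivalent emptiness check, reporting whether the task holds a shared state. A small sketch of the behavior:

#include <future>
#include <iostream>

int main() {
  std::packaged_task<void()> cb;   // no callback installed yet
  std::cout << cb.valid() << '\n'; // 0: no shared state; invoking would throw
  cb = std::packaged_task<void()>([] { std::cout << "ran\n"; });
  std::cout << cb.valid() << '\n'; // 1: safe to invoke exactly once
  cb();
}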
6 changes: 4 additions & 2 deletions unified-runtime/source/adapters/native_cpu/event.hpp
@@ -21,7 +21,9 @@ struct ur_event_handle_t_ : RefCounted {

~ur_event_handle_t_();

void set_callback(const std::function<void()> &cb) { callback = cb; }
template <typename T> auto set_callback(T &&cb) {
callback = std::packaged_task<void()>(std::forward<T>(cb));
}

void wait();

@@ -60,7 +62,7 @@ struct ur_event_handle_t_ : RefCounted {
bool done;
std::mutex mutex;
std::vector<std::future<void>> futures;
std::function<void()> callback;
std::packaged_task<void()> callback;
uint64_t timestamp_start = 0;
uint64_t timestamp_end = 0;
};
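
The switch from std::function to std::packaged_task is what lets enqueue.cpp move the unique_ptr-owned kernel copy into the completion callback: std::function requires a copy-constructible callable, while std::packaged_task only requires a movable one. A minimal sketch of the distinction (illustrative, not adapter code):

#include <functional>
#include <future>
#include <memory>

int main() {
  // A callable that owns a move-only resource, like the kernel copy above.
  auto owning = [p = std::make_unique<int>(42)] { (void)*p; };

  // std::function<void()> f{std::move(owning)}; // ill-formed: std::function
  //                                             // needs a copyable callable
  std::packaged_task<void()> task{std::move(owning)}; // fine: move-only is OK
  task();
}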
96 changes: 59 additions & 37 deletions unified-runtime/source/adapters/native_cpu/kernel.hpp
@@ -35,18 +35,9 @@ struct ur_kernel_handle_t_ : RefCounted {
ur_kernel_handle_t_(const ur_kernel_handle_t_ &other)
: Args(other.Args), hProgram(other.hProgram), _name(other._name),
_subhandler(other._subhandler), _localArgInfo(other._localArgInfo),
_localMemPool(other._localMemPool),
_localMemPoolSize(other._localMemPoolSize),
ReqdWGSize(other.ReqdWGSize) {
incrementReferenceCount();
}
ReqdWGSize(other.ReqdWGSize) {}

~ur_kernel_handle_t_() {
if (decrementReferenceCount() == 0) {
free(_localMemPool);
Args.deallocate();
}
}
~ur_kernel_handle_t_() { free(_localMemPool); }

ur_kernel_handle_t_(ur_program_handle_t hProgram, const char *name,
nativecpu_task_t subhandler,
@@ -64,27 +55,62 @@ struct ur_kernel_handle_t_ : RefCounted {
std::vector<bool> OwnsMem;
static constexpr size_t MaxAlign = 16 * sizeof(double);

arguments() = default;

arguments(const arguments &Other)
: Indices(Other.Indices), ParamSizes(Other.ParamSizes),
OwnsMem(Other.OwnsMem.size(), false) {
for (size_t Index = 0; Index < Indices.size(); Index++) {
if (!Other.OwnsMem[Index]) {
continue;
}
addArg(Index, ParamSizes[Index], Indices[Index]);
}
}

arguments(arguments &&Other) : arguments() {
std::swap(Indices, Other.Indices);
std::swap(ParamSizes, Other.ParamSizes);
std::swap(OwnsMem, Other.OwnsMem);
}

~arguments() {
assert(OwnsMem.size() == Indices.size() && "Size mismatch");
for (size_t Index = 0; Index < Indices.size(); Index++) {
if (!OwnsMem[Index]) {
continue;
}
native_cpu::aligned_free(Indices[Index]);
}
}

/// Add an argument to the kernel.
/// If the argument existed before, it is replaced.
/// Otherwise, it is added.
/// Gaps are filled with empty arguments.
/// Implicit offset argument is kept at the back of the indices collection.
void addArg(size_t Index, size_t Size, const void *Arg) {
bool NeedAlloc = true;
if (Index + 1 > Indices.size()) {
Indices.resize(Index + 1);
OwnsMem.resize(Index + 1);
ParamSizes.resize(Index + 1);

// Update the stored value for the argument
Indices[Index] = native_cpu::aligned_malloc(MaxAlign, Size);
OwnsMem[Index] = true;
ParamSizes[Index] = Size;
} else {
if (ParamSizes[Index] != Size) {
Indices[Index] = realloc(Indices[Index], Size);
ParamSizes[Index] = Size;
} else if (OwnsMem[Index]) {
if (ParamSizes[Index] == Size) {
NeedAlloc = false;
} else {
native_cpu::aligned_free(Indices[Index]);
}
}
if (NeedAlloc) {
size_t Align = MaxAlign;
while (Align > Size) {
Align >>= 1;
}
Indices[Index] = native_cpu::aligned_malloc(Align, Size);
ParamSizes[Index] = Size;
OwnsMem[Index] = true;
}
std::memcpy(Indices[Index], Arg, Size);
}

@@ -100,17 +126,6 @@ struct ur_kernel_handle_t_ : RefCounted {
Indices[Index] = Arg;
}

// This is called by the destructor of ur_kernel_handle_t_, since
// ur_kernel_handle_t_ implements reference counting and we want
// to deallocate only when the reference count is 0.
void deallocate() {
assert(OwnsMem.size() == Indices.size() && "Size mismatch");
for (size_t Index = 0; Index < Indices.size(); Index++) {
if (OwnsMem[Index])
native_cpu::aligned_free(Indices[Index]);
}
}

const args_index_t &getIndices() const noexcept { return Indices; }

} Args;
@@ -144,19 +159,26 @@ struct ur_kernel_handle_t_ : RefCounted {

bool hasLocalArgs() const { return !_localArgInfo.empty(); }

// To be called before executing a work group if local args are present
void handleLocalArgs(size_t numParallelThread, size_t threadId) {
const std::vector<void *> &getArgs() const {
assert(!hasLocalArgs() && "For kernels with local arguments, thread "
"information must be supplied.");
return Args.getIndices();
}

std::vector<void *> getArgs(size_t numThreads, size_t threadId) const {
auto Result = Args.getIndices();

// For each local argument we have size*numthreads
size_t offset = 0;
for (auto &entry : _localArgInfo) {
Args.Indices[entry.argIndex] =
Result[entry.argIndex] =
_localMemPool + offset + (entry.argSize * threadId);
// update offset in the memory pool
offset += entry.argSize * numParallelThread;
offset += entry.argSize * numThreads;
}
}

const std::vector<void *> &getArgs() const { return Args.getIndices(); }
return Result;
}

void addArg(const void *Ptr, size_t Index, size_t Size) {
Args.addArg(Index, Size, Ptr);
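
The new getArgs(numThreads, threadId) overload replaces the old in-place handleLocalArgs mutation: the shared _localMemPool is laid out as numThreads consecutive slots per local argument, and each thread's pointer is its slot within the current argument's region. A sketch of that offset arithmetic (invented names, not the adapter's types):

#include <cstddef>
#include <cstdio>
#include <vector>

struct LocalArg {
  size_t argIndex;
  size_t argSize;
};

int main() {
  const size_t numThreads = 4, threadId = 2;
  const std::vector<LocalArg> localArgs = {{0, 64}, {1, 128}};

  // Pool layout: [arg0, one slot per thread][arg1, one slot per thread]...
  size_t offset = 0;
  for (const auto &entry : localArgs) {
    const size_t byteOffset = offset + entry.argSize * threadId;
    std::printf("local arg %zu, thread %zu: pool bytes [%zu, %zu)\n",
                entry.argIndex, threadId, byteOffset,
                byteOffset + entry.argSize);
    offset += entry.argSize * numThreads;
  }
}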
