Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Implement work group memory extension #15178

Merged
merged 97 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
97 commits
Select commit Hold shift + click to select a range
652caa8
Preliminary implementation of work_group_memory extension
lbushi25 Aug 14, 2024
76daf77
Preliminary implementation of work_group_memory extension
lbushi25 Aug 14, 2024
21e082b
Implement work_group_memory extension
lbushi25 Aug 16, 2024
025cbc4
Implement work_group_memory extension
lbushi25 Aug 19, 2024
b94f7c9
Implement work group memory
lbushi25 Aug 20, 2024
0d6d694
Remove debug dumps
lbushi25 Aug 22, 2024
448071f
Merge branch 'intel:sycl' into work_group_memoy_new
lbushi25 Aug 22, 2024
9f2973a
Update work_group_memory.hpp
lbushi25 Aug 22, 2024
852315f
Remove include of deleted header file
lbushi25 Aug 22, 2024
4234022
Fix SPIRV compilation errors
lbushi25 Aug 22, 2024
ae5eb7e
Remove accidental change
lbushi25 Aug 23, 2024
8ce0280
Formatting changes
lbushi25 Aug 23, 2024
4ee31a5
Formatting changes
lbushi25 Aug 23, 2024
7b1b90b
Put the work group memory doc to supported
lbushi25 Aug 23, 2024
cf7476e
More formatting changes
lbushi25 Aug 23, 2024
50c0954
Delete sycl/include/sycl/ext/oneapi/experimental/test.cpp
lbushi25 Aug 23, 2024
44811b8
Yet more formatting changes
lbushi25 Aug 23, 2024
ad1046f
Merge branch 'work_group_memoy_new' of https://github.com/lbushi25/ll…
lbushi25 Aug 23, 2024
d343a2e
Fix warnings on Linux
lbushi25 Aug 23, 2024
3f1bc30
Remove unnecessary forward declaration from handler.hpp
lbushi25 Aug 23, 2024
103e233
Remove rvalue references in favor of const lvalue references
lbushi25 Aug 23, 2024
bfa5830
Fix syntax errors
lbushi25 Aug 23, 2024
2031478
Fix syntax errors
lbushi25 Aug 23, 2024
76f0acc
Don't explicitly make the work_group_memory class device-copyable as …
lbushi25 Aug 23, 2024
e0ad435
Don't explicitly make the work_group_memory class device-copyable as …
lbushi25 Aug 23, 2024
3513251
Remove some more unnecessary code
lbushi25 Aug 23, 2024
2cec997
Update work_group_memory.hpp
lbushi25 Aug 23, 2024
4c8b196
Update work_group_memory.hpp
lbushi25 Aug 23, 2024
a0b70e2
Formatting
lbushi25 Aug 23, 2024
ed8f125
Move doc to experimental folder
lbushi25 Aug 23, 2024
9460876
Update status section in doc
lbushi25 Aug 23, 2024
ac7130a
Merge branch 'work_group_memoy_new' of https://github.com/lbushi25/ll…
lbushi25 Aug 23, 2024
e2889b3
Final fixes
lbushi25 Aug 23, 2024
d6c78b9
Remove unnecessary include
lbushi25 Aug 23, 2024
8f7a07b
Add initial tests for work_group_memory extension
lbushi25 Aug 28, 2024
ae59899
Add E2E tests for work group memory
lbushi25 Aug 29, 2024
8cff603
Fix formatting
lbushi25 Aug 29, 2024
3ceead1
Resolve merge conflict
lbushi25 Sep 24, 2024
3228aeb
Revamp tests for work group memory extension
lbushi25 Sep 27, 2024
0e95ee5
Remove sanity test
lbushi25 Sep 27, 2024
52f13f0
Move extension doc to proposed
lbushi25 Sep 27, 2024
71d1013
Restore proposed status of work group memory doc
lbushi25 Sep 27, 2024
d48bc42
Fix unusd variable warning
lbushi25 Sep 27, 2024
f6515bc
Reduce test size to make sure UR does not run out or resources
lbushi25 Sep 27, 2024
3e4c73c
Replace sycl.hpp with core.hpp in the includes of E2E test
lbushi25 Sep 27, 2024
c84229e
Remove sycl.hpp include from tests
lbushi25 Sep 30, 2024
d2fddd8
Add support for unbounded arrays
lbushi25 Oct 2, 2024
0f677c2
Fix compilation errors
lbushi25 Oct 2, 2024
6ef823e
Improve swap test
lbushi25 Oct 2, 2024
6dc262a
Merge branch 'work_group_memoy_new' of https://github.com/lbushi25/ll…
lbushi25 Oct 2, 2024
4de6d50
Refactor CodeGenTypes.cpp changes
lbushi25 Oct 3, 2024
5653f04
Refactor CodeGenTypes.cpp changes
lbushi25 Oct 3, 2024
40eb63e
translate unbounded arrays to 1-sized arrays in LLVM IR in device com…
lbushi25 Oct 3, 2024
f6a0df7
Remove trailing spaces
lbushi25 Oct 3, 2024
026501c
Add unbounded array support by modifying LLVM IR -> SPIRV type lowering
lbushi25 Oct 3, 2024
2ce21b3
Revert CodeGenTypes.cpp changes
lbushi25 Oct 3, 2024
a9b2875
Fix merge conflicts
lbushi25 Oct 3, 2024
3821df4
Merge branch 'intel:sycl' into work_group_memoy_new
lbushi25 Oct 7, 2024
d73b0b1
Update SPIRVWriter.cpp
lbushi25 Oct 7, 2024
c1087ad
Merge branch 'work_group_memoy_new' of https://github.com/lbushi25/ll…
lbushi25 Oct 7, 2024
31481b8
Revert SPIRV translator changes
lbushi25 Oct 10, 2024
396169f
Revert SPIRV translator changes
lbushi25 Oct 10, 2024
dc37b2c
Merge branch 'intel:sycl' into work_group_memoy_new
lbushi25 Oct 10, 2024
236139f
Revert SPIRV translator changes
lbushi25 Oct 10, 2024
2beda8e
Stash changes
lbushi25 Oct 11, 2024
dbafe31
Add inital free function kernel support for work group memory
lbushi25 Oct 11, 2024
1b968df
merge latest upstream changes
lbushi25 Oct 11, 2024
e6b66c3
Apply suggestions
lbushi25 Oct 11, 2024
84ef6a8
Merge branch 'intel:sycl' into work_group_memoy_new
lbushi25 Oct 11, 2024
f24af09
Merge branch 'intel:sycl' into work_group_memoy_new
lbushi25 Oct 14, 2024
7dfa80b
Address reviews
lbushi25 Oct 14, 2024
3957cb5
Address reviews
lbushi25 Oct 14, 2024
91820d8
Address reviews
lbushi25 Oct 14, 2024
34bc23d
Update WorkGroupMemoryBackendArgument.cpp
lbushi25 Oct 14, 2024
3acf835
Update WorkGroupMemoryBackendArgument.cpp
lbushi25 Oct 14, 2024
604c640
Fix unit test implementation for backend kernel argument
lbushi25 Oct 15, 2024
5510208
Merge branch 'work_group_memoy_new' of https://github.com/lbushi25/ll…
lbushi25 Oct 15, 2024
d9418f9
Merge branch 'sycl' into work_group_memoy_new
lbushi25 Oct 15, 2024
e90a3b7
Add frontend tests
lbushi25 Oct 15, 2024
3b9a55a
Merge branch 'work_group_memoy_new' of https://github.com/lbushi25/ll…
lbushi25 Oct 15, 2024
b9ed6f4
Update work_group_memory.cpp
lbushi25 Oct 15, 2024
af08c19
Fix segmentation fault
lbushi25 Oct 15, 2024
b2a97a2
Merge branch 'work_group_memoy_new' of https://github.com/lbushi25/ll…
lbushi25 Oct 15, 2024
77a6de1
Move implementation details away from handler to handler_impl
lbushi25 Oct 16, 2024
3cb0ba4
Fix ABI breakage
lbushi25 Oct 16, 2024
1783f75
Update handler.hpp
lbushi25 Oct 16, 2024
5a6085f
Merge branch 'intel:sycl' into work_group_memoy_new
lbushi25 Oct 16, 2024
6affbc3
Add missing symbols to dumps
lbushi25 Oct 16, 2024
ed3c60f
Merge branch 'sycl' into work_group_memoy_new
lbushi25 Oct 16, 2024
fd89473
Fix compilation errors from changes in handler.hpp in another commit
lbushi25 Oct 17, 2024
24f87b0
Merge branch 'intel:sycl' into work_group_memoy_new
lbushi25 Oct 17, 2024
cba30a3
Merge branch 'intel:sycl' into work_group_memoy_new
lbushi25 Oct 17, 2024
3b10242
Refactor handler.hpp
lbushi25 Oct 18, 2024
f38f400
Revert "Refactor handler.hpp"
lbushi25 Oct 19, 2024
38a8d79
Fix typo in exception message
lbushi25 Oct 19, 2024
4df2f48
Improve frontend tests
lbushi25 Oct 22, 2024
9a7e3f1
Improve frontend tests
lbushi25 Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -1566,12 +1566,12 @@ def SYCLType: InheritableAttr {
let Subjects = SubjectList<[CXXRecord, Enum], ErrorDiag>;
let LangOpts = [SYCLIsDevice, SilentlyIgnoreSYCLIsHost];
let Args = [EnumArgument<"Type", "SYCLType", /*is_string=*/true,
["accessor", "local_accessor",
["accessor", "local_accessor", "work_group_memory",
"specialization_id", "kernel_handler", "buffer_location",
"no_alias", "accessor_property_list", "group",
"private_memory", "aspect", "annotated_ptr", "annotated_arg",
"stream", "sampler", "host_pipe", "multi_ptr"],
["accessor", "local_accessor",
["accessor", "local_accessor", "work_group_memory",
"specialization_id", "kernel_handler", "buffer_location",
"no_alias", "accessor_property_list", "group",
"private_memory", "aspect", "annotated_ptr", "annotated_arg",
Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/Sema/SemaSYCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class SYCLIntegrationHeader {
kind_pointer,
kind_specialization_constants_buffer,
kind_stream,
kind_last = kind_stream
kind_work_group_memory,
kind_last = kind_work_group_memory
};

public:
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4693,6 +4693,9 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
CurOffset + offsetOf(FD, FieldTy));
} else if (SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::stream)) {
addParam(FD, FieldTy, SYCLIntegrationHeader::kind_stream);
} else if (SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::work_group_memory)) {
addParam(FieldTy, SYCLIntegrationHeader::kind_work_group_memory,
offsetOf(FD, FieldTy));
} else if (SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::sampler) ||
SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::annotated_ptr) ||
SemaSYCL::isSyclType(FieldTy, SYCLTypeAttr::annotated_arg)) {
Expand Down Expand Up @@ -5773,6 +5776,7 @@ static const char *paramKind2Str(KernelParamKind K) {
CASE(stream);
CASE(specialization_constants_buffer);
CASE(pointer);
CASE(work_group_memory);
}
return "<ERROR>";

Expand Down
18 changes: 18 additions & 0 deletions clang/test/CodeGenSYCL/Inputs/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,24 @@ const stream& operator<<(const stream &S, T&&) {
return S;
}

// Dummy implementation of work_group_memory for use in CodeGenSYCL tests.
template <typename DataT>
class __attribute__((sycl_special_class))
__SYCL_TYPE(work_group_memory) work_group_memory {
public:
work_group_memory(handler &CGH) {}
#ifdef __SYCL_DEVICE_ONLY__
// Default constructor for objects later initialized with __init member.
work_group_memory() = default;
#endif

void __init(__attribute((opencl_local)) DataT *Ptr) { this->Ptr = Ptr; }
__attribute((opencl_local)) DataT *operator&() const { return Ptr; }

private:
__attribute((opencl_local)) DataT *Ptr;
};

template <typename T, int dimensions = 1,
typename AllocatorT = int /*fake type as AllocatorT is not used*/>
class buffer {
Expand Down
37 changes: 37 additions & 0 deletions clang/test/CodeGenSYCL/work_group_memory.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: %clang_cc1 -fsycl-is-device -triple spir64-unknown-unknown -disable-llvm-passes -emit-llvm %s -o %t.ll
// RUN: FileCheck < %t.ll %s --check-prefix CHECK-IR
// RUN: %clang_cc1 -fsycl-is-device -triple spir64-unknown-unknown -fsycl-int-header=%t.h %s
// RUN: FileCheck < %t.h %s --check-prefix CHECK-INT-HEADER
//
// Tests for work_group_memory kernel parameter using the dummy implementation in Inputs/sycl.hpp.
// The first two RUN commands verify that the init call is generated with the correct arguments in LLVM IR
// and the second two RUN commands verify the contents of the integration header produced by the frontend.
//
// CHECK-IR: define dso_local spir_kernel void @
// CHECK-IR-SAME: ptr addrspace(3) noundef align 4 [[PTR:%[a-zA-Z0-9_]+]]
//
// CHECK-IR: [[PTR]].addr = alloca ptr addrspace(3), align 8
// CHECK-IR: [[PTR]].addr.ascast = addrspacecast ptr [[PTR]].addr to ptr addrspace(4)
// CHECK-IR: store ptr addrspace(3) [[PTR]], ptr addrspace(4) [[PTR]].addr.ascast, align 8
// CHECK-IR: [[PTR_LOAD:%[a-zA-Z0-9_]+]] = load ptr addrspace(3), ptr addrspace(4) [[PTR]].addr.ascast, align 8
//
// CHECK-IR: call spir_func void @{{.*}}__init{{.*}}(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %{{[a-zA-Z0-9_]+}}, ptr addrspace(3) noundef [[PTR_LOAD]])
//
// CHECK-INT-HEADER: const kernel_param_desc_t kernel_signatures[] = {
// CHECK-INT-HEADER-NEXT: //--- _ZTSZZ4mainENKUlRN4sycl3_V17handlerEE_clES2_EUlNS0_4itemILi1EEEE_
// CHECK-INT-HEADER-NEXT: { kernel_param_kind_t::kind_work_group_memory, {{[4,8]}}, 0 },
// CHECK-INT-HEADER-EMPTY:
// CHECK-INT-HEADER-NEXT: { kernel_param_kind_t::kind_invalid, -987654321, -987654321 },
// CHECK-INT-HEADER-NEXT: };

#include "Inputs/sycl.hpp"

int main() {
sycl::queue Q;
Q.submit([&](sycl::handler &CGH) {
sycl::work_group_memory<int> mem;
sycl::range<1> ndr;
CGH.parallel_for(ndr, [=](sycl::item<1> it) { int *ptr = &mem; });
});
return 0;
}
1 change: 1 addition & 0 deletions clang/test/SemaSYCL/Inputs/sycl/detail/kernel_desc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace detail {
kind_pointer = 3,
kind_specialization_constants_buffer = 4,
kind_stream = 5,
kind_work_group_memory = 6,
kind_invalid = 0xf, // not a valid kernel kind
};

Expand Down
1 change: 1 addition & 0 deletions sycl-jit/common/include/Kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ enum class ParameterKind : uint32_t {
Pointer = 3,
SpecConstBuffer = 4,
Stream = 5,
WorkGroupMemory = 6,
Invalid = 0xF,
};

Expand Down
1 change: 1 addition & 0 deletions sycl/include/sycl/detail/kernel_desc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ enum class kernel_param_kind_t {
kind_pointer = 3,
kind_specialization_constants_buffer = 4,
kind_stream = 5,
kind_work_group_memory = 6,
kind_invalid = 0xf, // not a valid kernel kind
};

Expand Down
84 changes: 84 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/work_group_memory.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
//===-------------------- work_group_memory.hpp ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once

#include <type_traits>

namespace sycl {
inline namespace _V1 {
namespace detail {
template <typename T> struct is_unbounded_array : std::false_type {};

template <typename T> struct is_unbounded_array<T[]> : std::true_type {};

template <typename T>
inline constexpr bool is_unbounded_array_v = is_unbounded_array<T>::value;

class work_group_memory_impl {
public:
work_group_memory_impl() : buffer_size{0} {}
work_group_memory_impl(const work_group_memory_impl &rhs) = default;
work_group_memory_impl &
operator=(const work_group_memory_impl &rhs) = default;
work_group_memory_impl(size_t buffer_size) : buffer_size{buffer_size} {}

private:
size_t buffer_size;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't looked at this closely, but this looks suspicious. The work_group_memory object is required to correspond to just a single Level Zero kernel parameter, which is a pointer to work-group-local memory. This is the requirement specified here:

https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_free_function_kernels.asciidoc#dpc-guaranteed-compatibility-with-level-zero-and-opencl-backends

It looks like work_group_memory_impl has two member variables of type size_t. In addition, the work_group_memory type has a member variable of pointer type. Does this mean that the Level Zero kernel will end up with three kernel parameters?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initial testing shows that the L0 kernel only ends up with one parameter, but I will make this into a proper test.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder why there is only one Level Zero parameter. Is it because wgm_size and buffer_size are optimized away in the device code? I'd be concerned about this because the requirement of one Level Zero parameter holds even when optimization is disabled.

When I look at the code, I don't see any uses of buffer_size. Is that still needed?

If all we need is wgm_size, I wonder if it would be safer to implement work_group_memory such that its only data member was a union:

class work_group_memory {
 private:
  union {
    decoratedPtr ptr;  // Used only on device
    size_t size;       // Used only on host
  };
};

Or, you could use std::variant instead of a union in a similar way.

Copy link
Contributor Author

@lbushi25 lbushi25 Sep 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, my understanding is that the processArg function in handler.cpp file defines a mapping between the SYCL kernel parameters and the L0 kernel arguments. In the case of work_group_memory, I've defined the mapping to be such that whenever the runtime sees a work group memory parameter passed to a SYCL kernel, instead map that to a local memory buffer on the underlying backend where the size of the buffer is given by the buffer_size member of the work group memory object. If you look at my changes in processArg function, thats exactly what I'm doing. Therefore, with or without optimization, I believe there will only be one L0 kernel argument per work group memory object.

Copy link
Contributor Author

@lbushi25 lbushi25 Sep 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The relevant lines are 793-798 in handler.cpp file.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, thanks for the explanation!

Copy link
Contributor

@steffenlarsen steffenlarsen Oct 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lbushi25 is right that this will only result in a single kernel argument, which is due to it being a "SYCL special class" so the arguments of its __init function is used as the arguments of the kernel and the function body is then used for constructing the object.

All that said however, I also have some concerns here:

  1. These members are public and by public inheritance they will also be public in work_group_memory. They should not be, however. @AlexeySachkov corrected me and he's right. Since the inheritor is a class the default inheritance is private.
  2. Though this does not add extra arguments, they will increase register pressure (assuming the optimizer doesn't remove them.) We could move to a PIMPL implementation like many other SYCL classes (see for example device) where we have a std::shared_ptr<work_group_memory_impl> in this case, where work_group_memory_impl would be moved to a source file. Then, we could have something like:
#ifdef __SYCL_DEVICE_ONLY__
decoratedPtr ptr;
// To ensure we have the same object size on host and device, we add padding.
[[maybe_unused]] char padding[sizeof(std::shared_ptr<work_group_memory_impl>) - sizeof(decoratedPtr)];
#else
std::shared_ptr<work_group_memory_impl> impl;
#endif

Note that if we're sure we only ever need the two size_t, I don't think there's a big problem in having them directly in the class. Reading through it though, I don't fully understand why we need two size_t. If we could reduce it to one, we could do like Greg suggests further up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These members are public and by public inheritance they will also be public in work_group_memory. They should not be, however. @AlexeySachkov corrected me and he's right. Since the inheritor is a class the default inheritance is private.

But I was also confused at first, so I think it worth explicitly declaring members of the base class as protected

friend class sycl::handler;
};

} // namespace detail
namespace ext::oneapi::experimental {

template <typename DataT, typename PropertyListT = empty_properties_t>
class __SYCL_SPECIAL_CLASS __SYCL_TYPE(work_group_memory) work_group_memory
: sycl::detail::work_group_memory_impl {
public:
using value_type = std::remove_all_extents_t<DataT>;

private:
using decoratedPtr = typename sycl::detail::DecoratedType<
value_type, access::address_space::local_space>::type *;

public:
work_group_memory() = default;
work_group_memory(const work_group_memory &rhs) = default;
work_group_memory &operator=(const work_group_memory &rhs) = default;
template <typename T = DataT,
typename = std::enable_if_t<!sycl::detail::is_unbounded_array_v<T>>>
work_group_memory(handler &)
: sycl::detail::work_group_memory_impl(sizeof(DataT)) {}
template <typename T = DataT,
typename = std::enable_if_t<sycl::detail::is_unbounded_array_v<T>>>
work_group_memory(size_t num, handler &)
: sycl::detail::work_group_memory_impl(
num * sizeof(std::remove_extent_t<DataT>)) {}
template <access::decorated IsDecorated = access::decorated::no>
multi_ptr<value_type, access::address_space::local_space, IsDecorated>
get_multi_ptr() const {
return sycl::address_space_cast<access::address_space::local_space,
IsDecorated, value_type>(ptr);
}
DataT *operator&() const { return reinterpret_cast<DataT *>(ptr); }
operator DataT &() const { return *reinterpret_cast<DataT *>(ptr); }
template <typename T = DataT,
typename = std::enable_if_t<!std::is_array_v<T>>>
const work_group_memory &operator=(const DataT &value) const {
*ptr = value;
return *this;
}
#ifdef __SYCL_DEVICE_ONLY__
void __init(decoratedPtr ptr) { this->ptr = ptr; }
#endif
private:
decoratedPtr ptr;
};
} // namespace ext::oneapi::experimental
} // namespace _V1
} // namespace sycl
33 changes: 25 additions & 8 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ class pipe;
}

namespace ext ::oneapi ::experimental {
template <typename, typename>
class work_group_memory;
struct image_descriptor;
} // namespace ext::oneapi::experimental

Expand All @@ -171,6 +173,7 @@ class graph_impl;
} // namespace ext::oneapi::experimental::detail
namespace detail {

class work_group_memory_impl;
class handler_impl;
class kernel_impl;
class queue_impl;
Expand Down Expand Up @@ -564,8 +567,8 @@ class __SYCL_EXPORT handler {
// The version for regular(standard layout) argument.
template <typename T, typename... Ts>
void setArgsHelper(int ArgIndex, T &&Arg, Ts &&...Args) {
set_arg(ArgIndex, std::move(Arg));
setArgsHelper(++ArgIndex, std::move(Args)...);
set_arg(ArgIndex, std::forward<T>(Arg));
setArgsHelper(++ArgIndex, std::forward<Ts>(Args)...);
}

void setArgsHelper(int) {}
Expand Down Expand Up @@ -603,6 +606,8 @@ class __SYCL_EXPORT handler {
#endif
}

void setArgHelper(int ArgIndex, detail::work_group_memory_impl &Arg);

// setArgHelper for non local accessor argument.
template <typename DataT, int Dims, access::mode AccessMode,
access::target AccessTarget, access::placeholder IsPlaceholder>
Expand Down Expand Up @@ -1096,7 +1101,7 @@ class __SYCL_EXPORT handler {
KernelType KernelFunc) {
#ifndef __SYCL_DEVICE_ONLY__
throwIfActionIsCreated();
throwOnLocalAccessorMisuse<KernelName, KernelType>();
throwOnKernelParameterMisuse<KernelName, KernelType>();
if (!range_size_fits_in_size_t(UserRange))
throw sycl::exception(make_error_code(errc::runtime),
"The total number of work-items in "
Expand Down Expand Up @@ -1641,7 +1646,7 @@ class __SYCL_EXPORT handler {
kernel_single_task_wrapper<NameT, KernelType, PropertiesT>(KernelFunc);
#ifndef __SYCL_DEVICE_ONLY__
throwIfActionIsCreated();
throwOnLocalAccessorMisuse<KernelName, KernelType>();
throwOnKernelParameterMisuse<KernelName, KernelType>();
verifyUsedKernelBundleInternal(
detail::string_view{detail::getKernelName<NameT>()});
// No need to check if range is out of INT_MAX limits as it's compile-time
Expand Down Expand Up @@ -1840,6 +1845,14 @@ class __SYCL_EXPORT handler {
setArgHelper(ArgIndex, std::move(Arg));
}

template <typename DataT, typename PropertyListT =
ext::oneapi::experimental::empty_properties_t>
void set_arg(
int ArgIndex,
ext::oneapi::experimental::work_group_memory<DataT, PropertyListT> &Arg) {
setArgHelper(ArgIndex, Arg);
}

// set_arg for graph dynamic_parameters
template <typename T>
void set_arg(int argIndex,
Expand All @@ -1858,9 +1871,8 @@ class __SYCL_EXPORT handler {
///
/// \param Args are argument values to be set.
template <typename... Ts> void set_args(Ts &&...Args) {
setArgsHelper(0, std::move(Args)...);
setArgsHelper(0, std::forward<Ts>(Args)...);
}

/// Defines and invokes a SYCL kernel function as a function object type.
///
/// If it is a named function object and the function object type is
Expand Down Expand Up @@ -3233,7 +3245,6 @@ class __SYCL_EXPORT handler {
private:
std::shared_ptr<detail::handler_impl> impl;
std::shared_ptr<detail::queue_impl> MQueue;

std::vector<detail::LocalAccessorImplPtr> MLocalAccStorage;
std::vector<std::shared_ptr<detail::stream_impl>> MStreamStorage;
detail::string MKernelName;
Expand Down Expand Up @@ -3554,7 +3565,7 @@ class __SYCL_EXPORT handler {
/// must not be used in a SYCL kernel function that is invoked via single_task
/// or via the simple form of parallel_for that takes a range parameter.
template <typename KernelName, typename KernelType>
void throwOnLocalAccessorMisuse() const {
void throwOnKernelParameterMisuse() const {
using NameT =
typename detail::get_kernel_name_t<KernelName, KernelType>::name;
for (unsigned I = 0; I < detail::getKernelNumParams<NameT>(); ++I) {
Expand All @@ -3570,6 +3581,12 @@ class __SYCL_EXPORT handler {
"A local accessor must not be used in a SYCL kernel function "
"that is invoked via single_task or via the simple form of "
"parallel_for that takes a range parameter.");
if (Kind == detail::kernel_param_kind_t::kind_work_group_memory)
throw sycl::exception(
make_error_code(errc::kernel_argument),
"A work group memory object must not be used in a SYCL kernel "
"function that is invoked via single_task or via the simple form "
"of parallel_for that takes a range parameter.");
}
}

Expand Down
1 change: 1 addition & 0 deletions sycl/include/sycl/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
#include <sycl/ext/oneapi/experimental/raw_kernel_arg.hpp>
#include <sycl/ext/oneapi/experimental/root_group.hpp>
#include <sycl/ext/oneapi/experimental/tangle_group.hpp>
#include <sycl/ext/oneapi/experimental/work_group_memory.hpp>
#include <sycl/ext/oneapi/filter_selector.hpp>
#include <sycl/ext/oneapi/free_function_queries.hpp>
#include <sycl/ext/oneapi/functional.hpp>
Expand Down
3 changes: 3 additions & 0 deletions sycl/source/detail/handler_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ class handler_impl {

/// True if MCodeLoc is sycl entry point code location
bool MIsTopCodeLoc = true;

/// List of work group memory objects associated with this handler
std::vector<std::shared_ptr<detail::work_group_memory_impl>> MWorkGroupMemoryObjects;
};

} // namespace detail
Expand Down
2 changes: 2 additions & 0 deletions sycl/source/detail/jit_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ translateArgType(kernel_param_kind_t Kind) {
return PK::SpecConstBuffer;
case kind::kind_stream:
return PK::Stream;
case kind::kind_work_group_memory:
return PK::WorkGroupMemory;
case kind::kind_invalid:
return PK::Invalid;
}
Expand Down
2 changes: 2 additions & 0 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2297,6 +2297,8 @@ void SetArgBasedOnType(
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
const sycl::context &Context, detail::ArgDesc &Arg, size_t NextTrueIndex) {
switch (Arg.MType) {
case kernel_param_kind_t::kind_work_group_memory:
break;
case kernel_param_kind_t::kind_stream:
break;
case kernel_param_kind_t::kind_accessor: {
Expand Down
14 changes: 14 additions & 0 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <sycl/stream.hpp>

#include <sycl/ext/oneapi/bindless_images_memory.hpp>
#include <sycl/ext/oneapi/experimental/work_group_memory.hpp>
#include <sycl/ext/oneapi/memcpy2d.hpp>

namespace sycl {
Expand Down Expand Up @@ -795,6 +796,12 @@ void handler::processArg(void *Ptr, const detail::kernel_param_kind_t &Kind,
}
break;
}
case kernel_param_kind_t::kind_work_group_memory: {
addArg(kernel_param_kind_t::kind_std_layout, nullptr,
static_cast<detail::work_group_memory_impl *>(Ptr)->buffer_size,
Index + IndexShift);
break;
}
case kernel_param_kind_t::kind_sampler: {
addArg(kernel_param_kind_t::kind_sampler, Ptr, sizeof(sampler),
Index + IndexShift);
Expand All @@ -812,6 +819,13 @@ void handler::processArg(void *Ptr, const detail::kernel_param_kind_t &Kind,
}
}

void handler::setArgHelper(int ArgIndex, detail::work_group_memory_impl &Arg) {
impl->MWorkGroupMemoryObjects.push_back(
std::make_shared<detail::work_group_memory_impl>(Arg));
addArg(detail::kernel_param_kind_t::kind_work_group_memory,
impl->MWorkGroupMemoryObjects.back().get(), 0, ArgIndex);
}

// The argument can take up more space to store additional information about
// MAccessRange, MMemoryRange, and MOffset added with addArgsForGlobalAccessor.
// We use the worst-case estimate because the lifetime of the vector is short.
Expand Down
Loading
Loading