Skip to content

Commit 08b14da

Browse files
authored
[SYCL] Fix host device local accessor alignment (#5554)
Local kernel arguments must be aligned to the type size, and simply using `std::vector<char>` doesn't always provide the correct alignment. This patch therefore adds extra padding to the vector and ensures that the pointer returned for the accessor is actually aligned to the type size. The issue was exposed by intel/llvm-test-suite#608, a follow-up to fixing local accessor alignment for the CUDA plugin.
1 parent 27cc930 commit 08b14da

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

sycl/include/CL/sycl/detail/accessor_impl.hpp

+16-3
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,11 @@ class AccessorBaseHost {
170170

171171
class __SYCL_EXPORT LocalAccessorImplHost {
172172
public:
173+
// Allocate ElemSize more data to have sufficient padding to enforce
174+
// alignment.
173175
LocalAccessorImplHost(sycl::range<3> Size, int Dims, int ElemSize)
174176
: MSize(Size), MDims(Dims), MElemSize(ElemSize),
175-
MMem(Size[0] * Size[1] * Size[2] * ElemSize) {}
177+
MMem(Size[0] * Size[1] * Size[2] * ElemSize + ElemSize) {}
176178

177179
sycl::range<3> MSize;
178180
int MDims;
@@ -190,9 +192,20 @@ class LocalAccessorBaseHost {
190192
}
191193
sycl::range<3> &getSize() { return impl->MSize; }
192194
const sycl::range<3> &getSize() const { return impl->MSize; }
193-
void *getPtr() { return impl->MMem.data(); }
195+
void *getPtr() {
196+
// Const cast this in order to call the const getPtr.
197+
return const_cast<const LocalAccessorBaseHost *>(this)->getPtr();
198+
}
194199
void *getPtr() const {
195-
return const_cast<void *>(reinterpret_cast<void *>(impl->MMem.data()));
200+
char *ptr = impl->MMem.data();
201+
202+
// Align the pointer to MElemSize.
203+
size_t val = reinterpret_cast<size_t>(ptr);
204+
if (val % impl->MElemSize != 0) {
205+
ptr += impl->MElemSize - val % impl->MElemSize;
206+
}
207+
208+
return ptr;
196209
}
197210

198211
int getNumOfDims() { return impl->MDims; }

0 commit comments

Comments
 (0)