Skip to content

Commit 1fe5eaa

Browse files
[SYCL][ESIMD][EMU] Handle intrinsic operations promoted to 4-byte element type (#5727)
* [SYCL][ESIMD][EMU] Handling intrinsic operations promoted to 4-byte element type - scatter_impl() and gather_impl() in memory.hpp promote argument/return vector to 4-byte elements
1 parent 86cf56a commit 1fe5eaa

File tree

3 files changed

+44
-28
lines changed

3 files changed

+44
-28
lines changed

sycl/include/sycl/ext/intel/esimd/detail/math_intrin.hpp

+2 -4
Original file line number | Diff line number | Diff line change
@@ -698,7 +698,7 @@ __esimd_pack_mask(__ESIMD_DNS::vector_type_t<uint16_t, N> src0) {
698698
// wrapper code (which does the checks already)
699699
uint32_t retv = 0;
700700
for (int i = 0; i < N; i++) {
701-
if (src0[i] & 0x1) {
701+
if (src0[i] != 0) {
702702
retv |= 0x1 << i;
703703
}
704704
}
@@ -709,12 +709,10 @@ __esimd_pack_mask(__ESIMD_DNS::vector_type_t<uint16_t, N> src0) {
709709
template <int N>
710710
__ESIMD_INTRIN __ESIMD_DNS::vector_type_t<uint16_t, N>
711711
__esimd_unpack_mask(uint32_t src0) {
712-
__ESIMD_DNS::vector_type_t<uint16_t, N> retv;
712+
__ESIMD_DNS::vector_type_t<uint16_t, N> retv = 0;
713713
for (int i = 0; i < N; i++) {
714714
if ((src0 >> i) & 0x1) {
715715
retv[i] = 1;
716-
} else {
717-
retv[i] = 0;
718716
}
719717
}
720718
return retv;

sycl/include/sycl/ext/intel/esimd/detail/memory_intrin.hpp

+40 -10
Original file line number | Diff line number | Diff line change
@@ -42,12 +42,15 @@ namespace __ESIMD_DNS {
4242
// Provides access to sycl accessor class' private members.
4343
class AccessorPrivateProxy {
4444
public:
45-
#ifdef __SYCL_DEVICE_ONLY__
4645
template <typename AccessorTy>
4746
static auto getNativeImageObj(const AccessorTy &Acc) {
47+
#ifdef __SYCL_DEVICE_ONLY__
4848
return Acc.getNativeImageObj();
49-
}
5049
#else // __SYCL_DEVICE_ONLY__
50+
return Acc;
51+
#endif // __SYCL_DEVICE_ONLY__
52+
}
53+
#ifndef __SYCL_DEVICE_ONLY__
5154
static void *getPtr(const sycl::detail::AccessorBaseHost &Acc) {
5255
return Acc.getPtr();
5356
}
@@ -421,18 +424,32 @@ __esimd_scatter_scaled(__ESIMD_DNS::simd_mask_storage_t<N> pred,
421424
static_assert(TySizeLog2 <= 2);
422425
static_assert(std::is_integral<Ty>::value || TySizeLog2 == 2);
423426

427+
// determine the original element's type size (as __esimd_scatter_scaled
428+
// requires vals to be a vector of 4-byte integers)
429+
constexpr size_t OrigSize = __ESIMD_DNS::ElemsPerAddrDecoding(TySizeLog2);
430+
using RestoredTy = __ESIMD_DNS::uint_type_t<OrigSize>;
431+
424432
sycl::detail::ESIMDDeviceInterface *I =
425433
sycl::detail::getESIMDDeviceInterface();
426434

435+
__ESIMD_DNS::vector_type_t<RestoredTy, N> TypeAdjustedVals;
436+
if constexpr (OrigSize == 4) {
437+
TypeAdjustedVals = __ESIMD_DNS::bitcast<RestoredTy, Ty, N>(vals);
438+
} else {
439+
static_assert(OrigSize == 1 || OrigSize == 2);
440+
TypeAdjustedVals = __ESIMD_DNS::convert_vector<RestoredTy, Ty, N>(vals);
441+
}
442+
427443
if (surf_ind == __ESIMD_NS::detail::SLM_BTI) {
428444
// Scattered-store for Shared Local Memory
429445
// __ESIMD_NS::detail::SLM_BTI is special binding table index for SLM
430446
assert(global_offset == 0);
431447
char *SlmBase = I->__cm_emu_get_slm_ptr();
432448
for (int i = 0; i < N; ++i) {
433449
if (pred[i]) {
434-
Ty *addr = reinterpret_cast<Ty *>(elem_offsets[i] + SlmBase);
435-
*addr = vals[i];
450+
RestoredTy *addr =
451+
reinterpret_cast<RestoredTy *>(elem_offsets[i] + SlmBase);
452+
*addr = TypeAdjustedVals[i];
436453
}
437454
}
438455
} else {
@@ -449,8 +466,9 @@ __esimd_scatter_scaled(__ESIMD_DNS::simd_mask_storage_t<N> pred,
449466

450467
for (int idx = 0; idx < N; idx++) {
451468
if (pred[idx]) {
452-
Ty *addr = reinterpret_cast<Ty *>(elem_offsets[idx] + writeBase);
453-
*addr = vals[idx];
469+
RestoredTy *addr =
470+
reinterpret_cast<RestoredTy *>(elem_offsets[idx] + writeBase);
471+
*addr = TypeAdjustedVals[idx];
454472
}
455473
}
456474

@@ -629,7 +647,12 @@ __esimd_gather_masked_scaled2(SurfIndAliasTy surf_ind, uint32_t global_offset,
629647
{
630648
static_assert(Scale == 0);
631649

632-
__ESIMD_DNS::vector_type_t<Ty, N> retv = 0;
650+
// determine the original element's type size (as __esimd_scatter_scaled
651+
// requires vals to be a vector of 4-byte integers)
652+
constexpr size_t OrigSize = __ESIMD_DNS::ElemsPerAddrDecoding(TySizeLog2);
653+
using RestoredTy = __ESIMD_DNS::uint_type_t<OrigSize>;
654+
655+
__ESIMD_DNS::vector_type_t<RestoredTy, N> retv = 0;
633656
sycl::detail::ESIMDDeviceInterface *I =
634657
sycl::detail::getESIMDDeviceInterface();
635658

@@ -639,7 +662,8 @@ __esimd_gather_masked_scaled2(SurfIndAliasTy surf_ind, uint32_t global_offset,
639662
char *SlmBase = I->__cm_emu_get_slm_ptr();
640663
for (int idx = 0; idx < N; ++idx) {
641664
if (pred[idx]) {
642-
Ty *addr = reinterpret_cast<Ty *>(offsets[idx] + SlmBase);
665+
RestoredTy *addr =
666+
reinterpret_cast<RestoredTy *>(offsets[idx] + SlmBase);
643667
retv[idx] = *addr;
644668
}
645669
}
@@ -655,15 +679,21 @@ __esimd_gather_masked_scaled2(SurfIndAliasTy surf_ind, uint32_t global_offset,
655679
std::unique_lock<std::mutex> lock(*mutexLock);
656680
for (int idx = 0; idx < N; idx++) {
657681
if (pred[idx]) {
658-
Ty *addr = reinterpret_cast<Ty *>(offsets[idx] + readBase);
682+
RestoredTy *addr =
683+
reinterpret_cast<RestoredTy *>(offsets[idx] + readBase);
659684
retv[idx] = *addr;
660685
}
661686
}
662687

663688
// TODO : Optimize
664689
I->cm_fence_ptr();
665690
}
666-
return retv;
691+
692+
if constexpr (OrigSize == 4) {
693+
return __ESIMD_DNS::bitcast<Ty, RestoredTy, N>(retv);
694+
} else {
695+
return __ESIMD_DNS::convert_vector<Ty, RestoredTy, N>(retv);
696+
}
667697
}
668698
#endif // __SYCL_DEVICE_ONLY__
669699

sycl/include/sycl/ext/intel/esimd/memory.hpp

+2 -14
Original file line number | Diff line number | Diff line change
@@ -61,12 +61,8 @@ __ESIMD_API SurfaceIndex get_surface_index(AccessorTy acc) {
6161
if constexpr (std::is_same_v<detail::LocalAccessorMarker, AccessorTy>) {
6262
return detail::SLM_BTI;
6363
} else {
64-
#ifdef __SYCL_DEVICE_ONLY__
65-
const auto mem_obj = detail::AccessorPrivateProxy::getNativeImageObj(acc);
66-
return __esimd_get_surface_index(mem_obj);
67-
#else // __SYCL_DEVICE_ONLY__
68-
return __esimd_get_surface_index(acc);
69-
#endif // __SYCL_DEVICE_ONLY__
64+
return __esimd_get_surface_index(
65+
detail::AccessorPrivateProxy::getNativeImageObj(acc));
7066
}
7167
}
7268

@@ -253,12 +249,8 @@ __ESIMD_API simd<Tx, N> block_load(AccessorTy acc, uint32_t offset,
253249
static_assert(Sz <= 8 * detail::OperandSize::OWORD,
254250
"block size must be at most 8 owords");
255251

256-
#if defined(__SYCL_DEVICE_ONLY__)
257252
auto surf_ind = __esimd_get_surface_index(
258253
detail::AccessorPrivateProxy::getNativeImageObj(acc));
259-
#else // __SYCL_DEVICE_ONLY__
260-
auto surf_ind = __esimd_get_surface_index(acc);
261-
#endif // __SYCL_DEVICE_ONLY__
262254

263255
if constexpr (Flags::template alignment<simd<T, N>> >=
264256
detail::OperandSize::OWORD) {
@@ -317,12 +309,8 @@ __ESIMD_API void block_store(AccessorTy acc, uint32_t offset,
317309
static_assert(Sz <= 8 * detail::OperandSize::OWORD,
318310
"block size must be at most 8 owords");
319311

320-
#if defined(__SYCL_DEVICE_ONLY__)
321312
auto surf_ind = __esimd_get_surface_index(
322313
detail::AccessorPrivateProxy::getNativeImageObj(acc));
323-
#else //
324-
auto surf_ind = __esimd_get_surface_index(acc);
325-
#endif
326314
__esimd_oword_st<T, N>(surf_ind, offset >> 4, vals.data());
327315
}
328316

0 commit comments

Comments
 (0)