@@ -42,12 +42,15 @@ namespace __ESIMD_DNS {
42
42
// Provides access to sycl accessor class' private members.
43
43
class AccessorPrivateProxy {
44
44
public:
45
- #ifdef __SYCL_DEVICE_ONLY__
46
45
template <typename AccessorTy>
47
46
static auto getNativeImageObj (const AccessorTy &Acc) {
47
+ #ifdef __SYCL_DEVICE_ONLY__
48
48
return Acc.getNativeImageObj ();
49
- }
50
49
#else // __SYCL_DEVICE_ONLY__
50
+ return Acc;
51
+ #endif // __SYCL_DEVICE_ONLY__
52
+ }
53
+ #ifndef __SYCL_DEVICE_ONLY__
51
54
static void *getPtr (const sycl::detail::AccessorBaseHost &Acc) {
52
55
return Acc.getPtr ();
53
56
}
@@ -421,18 +424,32 @@ __esimd_scatter_scaled(__ESIMD_DNS::simd_mask_storage_t<N> pred,
421
424
static_assert (TySizeLog2 <= 2 );
422
425
static_assert (std::is_integral<Ty>::value || TySizeLog2 == 2 );
423
426
427
+ // determine the original element's type size (as __esimd_scatter_scaled
428
+ // requires vals to be a vector of 4-byte integers)
429
+ constexpr size_t OrigSize = __ESIMD_DNS::ElemsPerAddrDecoding (TySizeLog2);
430
+ using RestoredTy = __ESIMD_DNS::uint_type_t <OrigSize>;
431
+
424
432
sycl::detail::ESIMDDeviceInterface *I =
425
433
sycl::detail::getESIMDDeviceInterface ();
426
434
435
+ __ESIMD_DNS::vector_type_t <RestoredTy, N> TypeAdjustedVals;
436
+ if constexpr (OrigSize == 4 ) {
437
+ TypeAdjustedVals = __ESIMD_DNS::bitcast<RestoredTy, Ty, N>(vals);
438
+ } else {
439
+ static_assert (OrigSize == 1 || OrigSize == 2 );
440
+ TypeAdjustedVals = __ESIMD_DNS::convert_vector<RestoredTy, Ty, N>(vals);
441
+ }
442
+
427
443
if (surf_ind == __ESIMD_NS::detail::SLM_BTI) {
428
444
// Scattered-store for Shared Local Memory
429
445
// __ESIMD_NS::detail::SLM_BTI is special binding table index for SLM
430
446
assert (global_offset == 0 );
431
447
char *SlmBase = I->__cm_emu_get_slm_ptr ();
432
448
for (int i = 0 ; i < N; ++i) {
433
449
if (pred[i]) {
434
- Ty *addr = reinterpret_cast <Ty *>(elem_offsets[i] + SlmBase);
435
- *addr = vals[i];
450
+ RestoredTy *addr =
451
+ reinterpret_cast <RestoredTy *>(elem_offsets[i] + SlmBase);
452
+ *addr = TypeAdjustedVals[i];
436
453
}
437
454
}
438
455
} else {
@@ -449,8 +466,9 @@ __esimd_scatter_scaled(__ESIMD_DNS::simd_mask_storage_t<N> pred,
449
466
450
467
for (int idx = 0 ; idx < N; idx++) {
451
468
if (pred[idx]) {
452
- Ty *addr = reinterpret_cast <Ty *>(elem_offsets[idx] + writeBase);
453
- *addr = vals[idx];
469
+ RestoredTy *addr =
470
+ reinterpret_cast <RestoredTy *>(elem_offsets[idx] + writeBase);
471
+ *addr = TypeAdjustedVals[idx];
454
472
}
455
473
}
456
474
@@ -629,7 +647,12 @@ __esimd_gather_masked_scaled2(SurfIndAliasTy surf_ind, uint32_t global_offset,
629
647
{
630
648
static_assert (Scale == 0 );
631
649
632
- __ESIMD_DNS::vector_type_t <Ty, N> retv = 0 ;
650
+ // determine the original element's type size (as __esimd_scatter_scaled
651
+ // requires vals to be a vector of 4-byte integers)
652
+ constexpr size_t OrigSize = __ESIMD_DNS::ElemsPerAddrDecoding (TySizeLog2);
653
+ using RestoredTy = __ESIMD_DNS::uint_type_t <OrigSize>;
654
+
655
+ __ESIMD_DNS::vector_type_t <RestoredTy, N> retv = 0 ;
633
656
sycl::detail::ESIMDDeviceInterface *I =
634
657
sycl::detail::getESIMDDeviceInterface ();
635
658
@@ -639,7 +662,8 @@ __esimd_gather_masked_scaled2(SurfIndAliasTy surf_ind, uint32_t global_offset,
639
662
char *SlmBase = I->__cm_emu_get_slm_ptr ();
640
663
for (int idx = 0 ; idx < N; ++idx) {
641
664
if (pred[idx]) {
642
- Ty *addr = reinterpret_cast <Ty *>(offsets[idx] + SlmBase);
665
+ RestoredTy *addr =
666
+ reinterpret_cast <RestoredTy *>(offsets[idx] + SlmBase);
643
667
retv[idx] = *addr;
644
668
}
645
669
}
@@ -655,15 +679,21 @@ __esimd_gather_masked_scaled2(SurfIndAliasTy surf_ind, uint32_t global_offset,
655
679
std::unique_lock<std::mutex> lock (*mutexLock);
656
680
for (int idx = 0 ; idx < N; idx++) {
657
681
if (pred[idx]) {
658
- Ty *addr = reinterpret_cast <Ty *>(offsets[idx] + readBase);
682
+ RestoredTy *addr =
683
+ reinterpret_cast <RestoredTy *>(offsets[idx] + readBase);
659
684
retv[idx] = *addr;
660
685
}
661
686
}
662
687
663
688
// TODO : Optimize
664
689
I->cm_fence_ptr ();
665
690
}
666
- return retv;
691
+
692
+ if constexpr (OrigSize == 4 ) {
693
+ return __ESIMD_DNS::bitcast<Ty, RestoredTy, N>(retv);
694
+ } else {
695
+ return __ESIMD_DNS::convert_vector<Ty, RestoredTy, N>(retv);
696
+ }
667
697
}
668
698
#endif // __SYCL_DEVICE_ONLY__
669
699
0 commit comments