diff --git a/src/ATen/native/xpu/sycl/Atomics.h b/src/ATen/native/xpu/sycl/Atomics.h
index 4b124041a..c99b02b3a 100644
--- a/src/ATen/native/xpu/sycl/Atomics.h
+++ b/src/ATen/native/xpu/sycl/Atomics.h
@@ -27,6 +27,105 @@ template <typename T>
 using sycl_atomic_ref_rlx_wg_local_t =
     sycl::atomic_ref<T, sycl_mem_odr_rlx, sycl_mem_scp_wg, sycl_local_space>;
 
+// Read-modify-write helpers for integers addressed through the local-space
+// atomic_ref alias. AtomicIntegerImplLocal<T, n> applies `func(val, current)`
+// to the n-byte integer at `address` and publishes the result with a
+// compare_exchange_strong retry loop on sycl_atomic_ref_rlx_wg_local_t.
+template <typename T, size_t n>
+struct AtomicIntegerImplLocal;
+
+// 1-byte integers: the exchange is done on the enclosing 4-byte-aligned word;
+// only the addressed byte is replaced, the other three are carried over.
+template <typename T>
+struct AtomicIntegerImplLocal<T, 1> {
+  template <typename func_t>
+  inline void operator()(T* address, T val, const func_t& func) {
+    size_t offset = (size_t)address & 3;
+    uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
+    uint32_t assumed = *address_as_ui;
+    uint32_t shift = offset * 8;
+    uint32_t newval;
+    uint32_t newval_byte;
+    sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);
+
+    do {
+      newval = assumed;
+      newval_byte = (newval >> shift) & 0xff;
+      // preserve size in initial cast. Casting directly to uint32_t pads
+      // negative signed values with 1's (e.g. signed -1 = unsigned ~0).
+      newval = static_cast<uint8_t>(func(val, static_cast<T>(newval_byte)));
+      newval = (assumed & ~(0x000000ff << shift)) | (newval << shift);
+    } while (!target.compare_exchange_strong(assumed, newval));
+  }
+};
+
+// 2-byte integers: same word-based scheme; bit 1 of the address selects the
+// upper (`is_32_align` true) or lower 16 bits of the enclosing 32-bit word.
+template <typename T>
+struct AtomicIntegerImplLocal<T, 2> {
+  template <typename func_t>
+  inline void operator()(T* address, T val, const func_t& func) {
+    size_t offset = (size_t)address & 2;
+    uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
+    bool is_32_align = offset;
+    uint32_t assumed = *address_as_ui;
+    uint32_t newval;
+    uint32_t newval_bytes;
+    sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);
+
+    do {
+      newval = assumed;
+      newval_bytes = is_32_align ? newval >> 16 : newval & 0xffff;
+      // preserve size in initial cast. Casting directly to uint32_t pads
+      // negative signed values with 1's (e.g. signed -1 = unsigned ~0).
+      newval = static_cast<uint16_t>(func(val, static_cast<T>(newval_bytes)));
+      newval = is_32_align ? (assumed & 0xffff) | (newval << 16)
+                           : (assumed & 0xffff0000) | newval;
+    } while (!target.compare_exchange_strong(assumed, newval));
+  }
+};
+
+// 4-byte integers: the value itself is reinterpreted as uint32_t and exchanged.
+template <typename T>
+struct AtomicIntegerImplLocal<T, 4> {
+  template <typename func_t>
+  inline void operator()(T* address, T val, const func_t& func) {
+    uint32_t* address_as_ui = (uint32_t*)(address);
+    uint32_t assumed = *address_as_ui;
+    uint32_t newval;
+    sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);
+
+    do {
+      newval = static_cast<uint32_t>(func(val, static_cast<T>(assumed)));
+    } while (!target.compare_exchange_strong(assumed, newval));
+  }
+};
+
+// 8-byte integers: the value is reinterpreted as unsigned long long.
+template <typename T>
+struct AtomicIntegerImplLocal<T, 8> {
+  template <typename func_t>
+  inline void operator()(T* address, T val, const func_t& func) {
+    unsigned long long* address_as_ull = (unsigned long long*)(address);
+    unsigned long long assumed = *address_as_ull;
+    unsigned long long newval;
+    sycl_atomic_ref_rlx_wg_local_t<unsigned long long> target(*address_as_ull);
+
+    do {
+      newval = static_cast<uint64_t>(func(val, static_cast<T>(assumed)));
+    } while (!target.compare_exchange_strong(assumed, newval));
+  }
+};
+
+// Defines atomic##NAME(const sycl_local_ptr<DTYPE>&, DTYPE). OP is an
+// expression over `a` (the incoming val) and `b` (the value currently stored).
+#define SYCL_ATOMIC_INTEGER_LOCAL(NAME, OP, DTYPE)          \
+  static inline void atomic##NAME(                          \
+      const sycl_local_ptr<DTYPE>& address, DTYPE val) {    \
+    AtomicIntegerImplLocal<DTYPE, sizeof(DTYPE)>()(         \
+        address, val, [](DTYPE a, DTYPE b) { return OP; }); \
+  }
+
 template <typename T, size_t n>
 struct AtomicIntegerImpl;
 
@@ -268,6 +367,25 @@ SYCL_ATOMIC_FP(Mul, std::multiplies<at::Half>()(a, b), at::Half)
 SYCL_ATOMIC_FP(Mul, std::multiplies<at::BFloat16>()(a, b), at::BFloat16)
 
 // Atomic maximum implementation.
+
+// Work-group-local max for int32_t: stores max(*address, val) via fetch_max.
+static inline void atomicMax(
+    const sycl_local_ptr<int32_t>& address, int32_t val) {
+  sycl_atomic_ref_rlx_wg_local_t<int32_t> target(*address);
+  target.fetch_max(val);
+}
+
+// Work-group-local max for int64_t: stores max(*address, val) via fetch_max.
+static inline void atomicMax(
+    const sycl_local_ptr<int64_t>& address, int64_t val) {
+  sycl_atomic_ref_rlx_wg_local_t<int64_t> target(*address);
+  target.fetch_max(val);
+}
+
+SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max(a, b), uint8_t)
+SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max(a, b), int8_t)
+SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max(a, b), int16_t)
+
 SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
 SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
 SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t)
diff --git a/src/ATen/native/xpu/sycl/MultinomialKernel.cpp b/src/ATen/native/xpu/sycl/MultinomialKernel.cpp
index f332d3362..9b1e612c4 100644
--- a/src/ATen/native/xpu/sycl/MultinomialKernel.cpp
+++ b/src/ATen/native/xpu/sycl/MultinomialKernel.cpp
@@ -336,13 +336,12 @@ struct SampleMultinomialOnceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
             // We're done; we have the sample
             // Torch indices are 1-based
-            sycl::atomic_ref<
-                int,
-                sycl_mem_odr_rlx,
-                sycl_mem_scp_wg,
-                sycl_local_space>
-                target(foundPos_[0]);
-            target.fetch_max(cat, sycl_mem_odr_acq_rel, sycl_mem_scp_wg);
+            atomicMax(
+                sycl_local_ptr<int>(
+                    foundPos_
+                        .template get_multi_ptr<sycl::access::decorated::no>()
+                        .get()),
+                cat);
             found_[0] = true;
           }