Skip to content

Commit

Permalink
add atomicMax API
Browse files Browse the repository at this point in the history
  • Loading branch information
chunhuanMeng committed Jul 12, 2024
1 parent c698f1e commit 9945fd8
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 7 deletions.
106 changes: 106 additions & 0 deletions src/ATen/native/xpu/sycl/Atomics.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,93 @@ template <typename T>
using sycl_atomic_ref_rlx_wg_local_t =
sycl::atomic_ref<T, sycl_mem_odr_rlx, sycl_mem_scp_wg, sycl_local_space>;

template <typename T, size_t n>
struct AtomicIntegerImplLocal;

template <typename T>
struct AtomicIntegerImplLocal<T, 1> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
size_t offset = (size_t)address & 3;
uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
uint32_t assumed = *address_as_ui;
uint32_t shift = offset * 8;
uint32_t newval;
uint32_t newval_byte;
sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);

do {
newval = assumed;
newval_byte = (newval >> shift) & 0xff;
// preserve size in initial cast. Casting directly to uint32_t pads
// negative signed values with 1's (e.g. signed -1 = unsigned ~0).
newval = static_cast<uint8_t>(func(val, static_cast<T>(newval_byte)));
newval = (assumed & ~(0x000000ff << shift)) | (newval << shift);
} while (!target.compare_exchange_strong(assumed, newval));
}
};

template <typename T>
struct AtomicIntegerImplLocal<T, 2> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
size_t offset = (size_t)address & 2;
uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
bool is_32_align = offset;
uint32_t assumed = *address_as_ui;
uint32_t newval;
uint32_t newval_bytes;
sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);

do {
newval = assumed;
newval_bytes = is_32_align ? newval >> 16 : newval & 0xffff;
// preserve size in initial cast. Casting directly to uint32_t pads
// negative signed values with 1's (e.g. signed -1 = unsigned ~0).
newval = static_cast<uint16_t>(func(val, static_cast<T>(newval_bytes)));
newval = is_32_align ? (assumed & 0xffff) | (newval << 16)
: (assumed & 0xffff0000) | newval;
} while (!target.compare_exchange_strong(assumed, newval));
}
};

template <typename T>
struct AtomicIntegerImplLocal<T, 4> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
uint32_t* address_as_ui = (uint32_t*)(address);
uint32_t assumed = *address_as_ui;
uint32_t newval;
sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);

do {
newval = static_cast<uint32_t>(func(val, static_cast<T>(assumed)));
} while (!target.compare_exchange_strong(assumed, newval));
}
};

template <typename T>
struct AtomicIntegerImplLocal<T, 8> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
unsigned long long* address_as_ull = (unsigned long long*)(address);
unsigned long long assumed = *address_as_ull;
unsigned long long newval;
sycl_atomic_ref_rlx_wg_local_t<unsigned long long> target(*address_as_ull);

do {
newval = static_cast<uint64_t>(func(val, static_cast<T>(assumed)));
} while (!target.compare_exchange_strong(assumed, newval));
}
};

#define SYCL_ATOMIC_INTEGER_LOCAL(NAME, OP, DTYPE) \
static inline void atomic##NAME( \
const sycl_local_ptr<DTYPE>& address, DTYPE val) { \
AtomicIntegerImplLocal<DTYPE, sizeof(DTYPE)>()( \
address, val, [](DTYPE a, DTYPE b) { return OP; }); \
}

template <typename T, size_t n>
struct AtomicIntegerImpl;

Expand Down Expand Up @@ -268,6 +355,25 @@ SYCL_ATOMIC_FP(Mul, std::multiplies<at::Half>()(a, b), at::Half)
SYCL_ATOMIC_FP(Mul, std::multiplies<at::BFloat16>()(a, b), at::BFloat16)

// Atomic maximum implementation.

static inline void atomicMax(
const sycl_local_ptr<int32_t>& address,
int32_t val) {
sycl_atomic_ref_rlx_wg_local_t<int32_t> target(*address);
target.fetch_add(val);
}

static inline void atomicMax(
const sycl_local_ptr<int64_t>& address,
int64_t val) {
sycl_atomic_ref_rlx_wg_local_t<int64_t> target(*address);
target.fetch_add(val);
}

SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<uint8_t>(a, b), uint8_t)
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int8_t>(a, b), int8_t)
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int16_t>(a, b), int16_t)

SYCL_ATOMIC_INTEGER(Max, safe_max<uint8_t>(a, b), uint8_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<int8_t>(a, b), int8_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<int16_t>(a, b), int16_t)
Expand Down
13 changes: 6 additions & 7 deletions src/ATen/native/xpu/sycl/MultinomialKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,13 +336,12 @@ struct SampleMultinomialOnceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
// We're done; we have the sample
// Torch indices are 1-based

sycl::atomic_ref<
int,
sycl_mem_odr_rlx,
sycl_mem_scp_wg,
sycl_local_space>
target(foundPos_[0]);
target.fetch_max(cat, sycl_mem_odr_acq_rel, sycl_mem_scp_wg);
atomicMax(
sycl_local_ptr<int>(
foundPos_
.template get_multi_ptr<sycl::access::decorated::no>()
.get()),
cat);

found_[0] = true;
}
Expand Down

0 comments on commit 9945fd8

Please sign in to comment.