Skip to content

Commit

Permalink
rename functor and use fetch_max
Browse files Browse the repository at this point in the history
  • Loading branch information
chunhuanMeng committed Jul 9, 2024
1 parent 549aaab commit e75fb64
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 108 deletions.
3 changes: 1 addition & 2 deletions src/ATen/native/xpu/Distributions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,7 @@ Tensor& XPUNativeFunctions::multinomial_out(
return result;
}

at::native::xpu::multinomial_with_replacement_kernel(
result, self, n_sample, gen);
at::native::xpu::multinomial_kernel(result, self, n_sample, gen);
return result;
}

Expand Down
91 changes: 0 additions & 91 deletions src/ATen/native/xpu/sycl/Atomics.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,92 +115,6 @@ struct AtomicIntegerImpl<T, 8> {
address, val, [](DTYPE a, DTYPE b) { return OP; }); \
}

// CAS-emulated atomic read-modify-write on work-group local memory for
// integer types narrower than the hardware atomic width. The primary
// template is only declared; per-size specializations follow.
template <typename T, size_t n>
struct AtomicIntegerImplLocal;

template <typename T>
struct AtomicIntegerImplLocal<T, 1> {
  // Emulates an 8-bit atomic RMW by running a compare-exchange loop on the
  // 4-byte-aligned 32-bit word that contains `*address`, updating only the
  // byte at the computed offset. `func(val, old)` produces the new byte.
  template <typename func_t>
  inline void operator()(T* address, T val, const func_t& func) {
    size_t offset = (size_t)address & 3;
    uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
    uint32_t assumed = *address_as_ui;
    uint32_t shift = offset * 8;
    uint32_t newval;
    uint32_t newval_byte;
    sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);

    do {
      newval = assumed;
      newval_byte = (newval >> shift) & 0xff;
      // preserve size in initial cast. Casting directly to uint32_t pads
      // negative signed values with 1's (e.g. signed -1 = unsigned ~0).
      newval = static_cast<uint8_t>(func(val, static_cast<T>(newval_byte)));
      // Mask must be unsigned: with a plain int literal, 0xff << 24
      // overflows signed int, which is undefined behavior before C++20.
      newval = (assumed & ~(0x000000ffu << shift)) | (newval << shift);
      // On CAS failure `assumed` is reloaded with the current word, so the
      // next iteration recomputes the byte from fresh contents.
    } while (!target.compare_exchange_strong(assumed, newval));
  }
};

template <typename T>
struct AtomicIntegerImplLocal<T, 2> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
size_t offset = (size_t)address & 2;
uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
bool is_32_align = offset;
uint32_t assumed = *address_as_ui;
uint32_t newval;
uint32_t newval_bytes;
sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);

do {
newval = assumed;
newval_bytes = is_32_align ? newval >> 16 : newval & 0xffff;
// preserve size in initial cast. Casting directly to uint32_t pads
// negative signed values with 1's (e.g. signed -1 = unsigned ~0).
newval = static_cast<uint16_t>(func(val, static_cast<T>(newval_bytes)));
newval = is_32_align ? (assumed & 0xffff) | (newval << 16)
: (assumed & 0xffff0000) | newval;
} while (!target.compare_exchange_strong(assumed, newval));
}
};

template <typename T>
struct AtomicIntegerImplLocal<T, 4> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
uint32_t* address_as_ui = (uint32_t*)(address);
uint32_t assumed = *address_as_ui;
uint32_t newval;
sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);

do {
newval = static_cast<uint32_t>(func(val, static_cast<T>(assumed)));
} while (!target.compare_exchange_strong(assumed, newval));
}
};

template <typename T>
struct AtomicIntegerImplLocal<T, 8> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
unsigned long long* address_as_ull = (unsigned long long*)(address);
unsigned long long assumed = *address_as_ull;
unsigned long long newval;
sycl_atomic_ref_rlx_wg_local_t<unsigned long long> target(*address_as_ull);

do {
newval = static_cast<uint64_t>(func(val, static_cast<T>(assumed)));
} while (!target.compare_exchange_strong(assumed, newval));
}
};

// Generates atomic<NAME>(local_ptr, val) for work-group local memory by
// forwarding to the AtomicIntegerImplLocal specialization selected by
// sizeof(DTYPE). OP is an expression over `a` and `b`; the functor is
// invoked as func(val, old), so `a` receives the incoming value and `b`
// the currently stored one.
#define SYCL_ATOMIC_INTEGER_LOCAL(NAME, OP, DTYPE) \
static inline void atomic##NAME( \
const sycl_local_ptr<DTYPE>& address, DTYPE val) { \
AtomicIntegerImplLocal<DTYPE, sizeof(DTYPE)>()( \
address, val, [](DTYPE a, DTYPE b) { return OP; }); \
}

template <typename T>
struct AtomicFPImpl;
Expand Down Expand Up @@ -361,11 +275,6 @@ SYCL_ATOMIC_INTEGER(Max, safe_max<int16_t>(a, b), int16_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<int32_t>(a, b), int32_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<int64_t>(a, b), int64_t)

// Work-group-local atomicMax overloads for the integer widths used by the
// XPU kernels; safe_max combines the incoming and stored values (max is
// commutative, so the func(val, old) argument order is immaterial here).
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<uint8_t>(a, b), uint8_t)
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int8_t>(a, b), int8_t)
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int16_t>(a, b), int16_t)
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int32_t>(a, b), int32_t)
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int64_t>(a, b), int64_t)

SYCL_ATOMIC_FP(Max, safe_max<float>(a, b), float)
SYCL_ATOMIC_FP(Max, safe_max<double>(a, b), double)
Expand Down
39 changes: 25 additions & 14 deletions src/ATen/native/xpu/sycl/MultinomialKernel.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
#pragma clang diagnostic push
#pragma GCC diagnostic push
// Avoid SYCL compiler return-type error
#pragma clang diagnostic ignored "-Wreturn-type"
#pragma GCC diagnostic ignored "-Wreturn-type"

#include <ATen/AccumulateType.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/xpu/sycl/Atomics.h>
Expand Down Expand Up @@ -83,8 +89,9 @@ inline void renormRows(Tensor& t) {
TORCH_CHECK(t.dim() == 2);
int64_t rows = t.size(0);
int64_t cols = t.size(1);

int group_size = syclMaxWorkItemsPerEU();
int subgroup_size = syclMaxSubGroupSize();
int group_size =
std::min(int(syclMaxWorkItemsPerEU()), subgroup_size* subgroup_size);
int num_groups = (rows + group_size - 1) / group_size;
int hw_max_groups = syclMaxWorkItemsPerTile() / group_size;
num_groups = num_groups > hw_max_groups ? hw_max_groups : num_groups;
Expand Down Expand Up @@ -232,7 +239,7 @@ struct MultinomialWithReplacementKernelImplFunctor {
scalar_t* normDist_ptr;
};
template <typename scalar_t, typename accscalar_t>
struct sampleMultinomialOnceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
struct SampleMultinomialOnceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
void operator()(sycl::nd_item<1> item) const {
accscalar_t* smem = reinterpret_cast<accscalar_t*>(
smem_.template get_multi_ptr<sycl::access::decorated::no>().get());
Expand Down Expand Up @@ -328,13 +335,15 @@ struct sampleMultinomialOnceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
if (inBucket) {
// We're done; we have the sample
// Torch indices are 1-based
atomicMax(
sycl_local_ptr<int>(
foundPos_
.template get_multi_ptr<sycl::access::decorated::no>()
.get()),
cat);
// foundPos_[0] = 1;

sycl::atomic_ref<
int,
sycl_mem_odr_rlx,
sycl_mem_scp_wg,
sycl_local_space>
target(foundPos_[0]);
target.fetch_max(cat, sycl_mem_odr_acq_rel, sycl_mem_scp_wg);

found_[0] = true;
}

Expand Down Expand Up @@ -375,7 +384,7 @@ struct sampleMultinomialOnceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
foundPos_ = sycl_local_acc_t<int>(1, cgh);
}

sampleMultinomialOnceFunctor(
SampleMultinomialOnceFunctor(
int64_t* dest,
int64_t distributions,
int categories,
Expand Down Expand Up @@ -407,7 +416,7 @@ struct sampleMultinomialOnceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
sycl_local_acc_t<int> foundPos_;
};

void multinomial_with_replacement_kernel(
void multinomial_kernel(
Tensor& result,
const Tensor& self,
const int64_t n_sample,
Expand Down Expand Up @@ -446,7 +455,7 @@ void multinomial_with_replacement_kernel(
at::native::uniform_(sampled, 0.0, 1.0, generator);
int group_size = requiredThreads;
int group_range = numDist;
auto kfn = sampleMultinomialOnceFunctor<scalar_t, accscalar_t>(
auto kfn = SampleMultinomialOnceFunctor<scalar_t, accscalar_t>(
result.mutable_data_ptr<int64_t>(),
numDist,
numCategories,
Expand Down Expand Up @@ -530,4 +539,6 @@ void multinomial_with_replacement_kernel(
}
}

} // namespace at::native::xpu
} // namespace at::native::xpu
#pragma GCC diagnostic pop
#pragma clang diagnostic pop
2 changes: 1 addition & 1 deletion src/ATen/native/xpu/sycl/MultinomialKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace at::native::xpu {

void multinomial_with_replacement_kernel(
void multinomial_kernel(
Tensor& result,
const Tensor& self,
const int64_t n_sample,
Expand Down

0 comments on commit e75fb64

Please sign in to comment.