Add aten::multinomial (#520)
chunhuanMeng authored Jul 14, 2024
1 parent 7dc5e58 commit 822214e
Showing 9 changed files with 809 additions and 1 deletion.
104 changes: 104 additions & 0 deletions src/ATen/native/xpu/Distributions.cpp
@@ -8,6 +8,8 @@
#include <ATen/xpu/XPUNativeFunctions.h>

#include <ATen/native/xpu/sycl/DistributionKernels.h>
#include <ATen/native/xpu/sycl/MultinomialKernel.h>
#include <ATen/ops/div.h>

namespace at {

@@ -189,4 +191,106 @@ Tensor& XPUNativeFunctions::random_(
return random_(self, 0, to, std::move(generator));
}

/* The largest consecutive integer representable in float32 (2^24) */
constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = 1 << (24);

Tensor& XPUNativeFunctions::multinomial_out(
const Tensor& self,
int64_t n_sample,
bool with_replacement,
::std::optional<at::Generator> gen,
at::Tensor& result) {
TORCH_CHECK(
result.device() == self.device(),
"multinomial arguments must have the same device");
TORCH_CHECK(
self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim");
TORCH_CHECK(
at::isFloatingType(self.scalar_type()),
"multinomial only supports floating-point dtypes for input, got: ",
self.scalar_type());
TORCH_CHECK(
result.scalar_type() == ScalarType::Long,
"multinomial expects Long tensor out, got: ",
result.scalar_type());
TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples");
int64_t n_categories = self.size(-1);
TORCH_CHECK(
with_replacement || (n_sample <= n_categories),
"cannot sample n_sample > prob_dist.size(-1) samples without replacement");
// Since category indices are computed in float32, n_categories cannot
// exceed the largest consecutive integer representable in float32 (2^24).
TORCH_CHECK(
n_categories <= FLOAT32_MAX_CONSECUTIVE_INT,
"number of categories cannot exceed 2^24");

if (self.dim() == 1) {
result.resize_({n_sample});
} else {
const int64_t n_dist = self.size(0);
result.resize_({n_dist, n_sample});
}
if (result.numel() == 0) {
return result;
}

// Fast-path for no replacement or if only one sample is drawn.
// Reference:
// https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
if (!with_replacement || n_sample == 1) {
// Sanity checks on `self`.
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
TORCH_CHECK(
is_valid.to<bool>(),
"probability tensor contains either `inf`, `nan` or element < 0");
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool zero_prob_condition;
if (self.dim() == 1) {
zero_prob_condition = (self.sum() == 0).item().to<bool>();
} else {
zero_prob_condition = (self.sum(1) == 0).sum().item().to<bool>();
}
TORCH_CHECK(
!zero_prob_condition,
"invalid multinomial distribution (sum of probabilities <= 0)");

// The algorithm is from gumbel softmax.
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)
// Here we can apply exp to the formula which will not affect result of
// argmax or topk. Then we have
// s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).
// We can also simplify the formula above by
// s = argmax( p / q ) where q ~ Exp(1)
Tensor q = at::empty_like(self).exponential_(1, std::move(gen));
// In theory the probability to generate 0 from exponential distribution is
// 0. However, on CUDA side there is a protection to avoid 0s, but on CPU
// side, there is a very low probability to generate 0 from
// exponential<double>. The probability is about 2^(-DBL_MANT_DIG). We just
// ignore it here, but there may be some risk to get invalid output on CPU.
at::div_out(q, self, q);
if (n_sample == 1) {
at::argmax_out(result, q, /*dim=*/-1, /*keepdim=*/true);
} else {
Tensor vals = at::empty(result.sizes(), self.options());
at::topk_out(vals, result, q, n_sample);
}
return result;
}

at::native::xpu::multinomial_kernel(result, self, n_sample, gen);
return result;
}

Tensor XPUNativeFunctions::multinomial(
const Tensor& self,
int64_t n_sample,
bool with_replacement,
::std::optional<at::Generator> gen) {
Tensor result = at::empty({0}, self.options().dtype(kLong));

XPUNativeFunctions::multinomial_out(
self, n_sample, with_replacement, std::move(gen), result);
return result;
}

} // namespace at
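
The fast path in multinomial_out above uses the exponential-race form of the Gumbel trick: for q_i ~ Exp(1), argmax(p_i / q_i) selects index i with probability p_i / sum(p). A minimal standalone sketch (illustrative only, plain C++, not part of this commit) that checks the property empirically:

// Editor-added sketch: empirically verifies that argmax(p_i / q_i), with
// q_i ~ Exp(1), samples index i with probability p_i / sum(p).
#include <iostream>
#include <random>
#include <vector>

int main() {
  const std::vector<double> p = {0.1, 0.2, 0.7}; // need not be normalized
  std::mt19937 rng(42);
  std::exponential_distribution<double> exp1(1.0);

  const int trials = 100000;
  std::vector<int> counts(p.size(), 0);
  for (int t = 0; t < trials; ++t) {
    std::size_t best = 0;
    double best_val = p[0] / exp1(rng);
    for (std::size_t i = 1; i < p.size(); ++i) {
      const double v = p[i] / exp1(rng);
      if (v > best_val) {
        best_val = v;
        best = i;
      }
    }
    ++counts[best];
  }
  for (std::size_t i = 0; i < p.size(); ++i) {
    // Empirical frequencies should be close to 0.1, 0.2 and 0.7.
    std::cout << "category " << i << ": "
              << static_cast<double>(counts[i]) / trials << "\n";
  }
  return 0;
}

For n_sample > 1 without replacement, the same ratios are ranked with topk instead of argmax, which is what the code above does.
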
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
@@ -270,7 +270,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"multilabel_margin_loss_forward",
"multi_margin_loss",
"multi_margin_loss_backward",
"multinomial",
"nanmedian",
"nanmedian.dim_values",
"nansum",
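
With "multinomial" removed from the CPU-fallback list above, calls now dispatch to the native XPU implementation. A minimal usage sketch (illustrative only; assumes an XPU-enabled libtorch build with an available XPU device, and uses the usual ATen TensorOptions helpers and the at::kXPU device constant):

// Editor-added sketch: draw samples with the newly registered XPU operator.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Unnormalized probabilities for 2 distributions over 3 categories.
  at::Tensor probs = at::rand({2, 3}, at::device(at::kXPU).dtype(at::kFloat));
  // Draw 4 samples per row with replacement; the result is a Long tensor on XPU.
  at::Tensor idx =
      at::multinomial(probs, /*num_samples=*/4, /*replacement=*/true);
  std::cout << idx.cpu() << std::endl;
  return 0;
}
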
106 changes: 106 additions & 0 deletions src/ATen/native/xpu/sycl/Atomics.h
Expand Up @@ -27,6 +27,93 @@ template <typename T>
using sycl_atomic_ref_rlx_wg_local_t =
sycl::atomic_ref<T, sycl_mem_odr_rlx, sycl_mem_scp_wg, sycl_local_space>;

template <typename T, size_t n>
struct AtomicIntegerImplLocal;

template <typename T>
struct AtomicIntegerImplLocal<T, 1> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
size_t offset = (size_t)address & 3;
uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
uint32_t assumed = *address_as_ui;
uint32_t shift = offset * 8;
uint32_t newval;
uint32_t newval_byte;
sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);

do {
newval = assumed;
newval_byte = (newval >> shift) & 0xff;
// preserve size in initial cast. Casting directly to uint32_t pads
// negative signed values with 1's (e.g. signed -1 = unsigned ~0).
newval = static_cast<uint8_t>(func(val, static_cast<T>(newval_byte)));
newval = (assumed & ~(0x000000ff << shift)) | (newval << shift);
} while (!target.compare_exchange_strong(assumed, newval));
}
};

template <typename T>
struct AtomicIntegerImplLocal<T, 2> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
size_t offset = (size_t)address & 2;
uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
bool is_32_align = offset;
uint32_t assumed = *address_as_ui;
uint32_t newval;
uint32_t newval_bytes;
sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);

do {
newval = assumed;
newval_bytes = is_32_align ? newval >> 16 : newval & 0xffff;
// preserve size in initial cast. Casting directly to uint32_t pads
// negative signed values with 1's (e.g. signed -1 = unsigned ~0).
newval = static_cast<uint16_t>(func(val, static_cast<T>(newval_bytes)));
newval = is_32_align ? (assumed & 0xffff) | (newval << 16)
: (assumed & 0xffff0000) | newval;
} while (!target.compare_exchange_strong(assumed, newval));
}
};

template <typename T>
struct AtomicIntegerImplLocal<T, 4> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
uint32_t* address_as_ui = (uint32_t*)(address);
uint32_t assumed = *address_as_ui;
uint32_t newval;
sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);

do {
newval = static_cast<uint32_t>(func(val, static_cast<T>(assumed)));
} while (!target.compare_exchange_strong(assumed, newval));
}
};

template <typename T>
struct AtomicIntegerImplLocal<T, 8> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
unsigned long long* address_as_ull = (unsigned long long*)(address);
unsigned long long assumed = *address_as_ull;
unsigned long long newval;
sycl_atomic_ref_rlx_wg_local_t<unsigned long long> target(*address_as_ull);

do {
newval = static_cast<uint64_t>(func(val, static_cast<T>(assumed)));
} while (!target.compare_exchange_strong(assumed, newval));
}
};

#define SYCL_ATOMIC_INTEGER_LOCAL(NAME, OP, DTYPE) \
static inline void atomic##NAME( \
const sycl_local_ptr<DTYPE>& address, DTYPE val) { \
AtomicIntegerImplLocal<DTYPE, sizeof(DTYPE)>()( \
address, val, [](DTYPE a, DTYPE b) { return OP; }); \
}

template <typename T, size_t n>
struct AtomicIntegerImpl;

@@ -268,6 +355,25 @@ SYCL_ATOMIC_FP(Mul, std::multiplies<at::Half>()(a, b), at::Half)
SYCL_ATOMIC_FP(Mul, std::multiplies<at::BFloat16>()(a, b), at::BFloat16)

// Atomic maximum implementation.

static inline void atomicMax(
const sycl_local_ptr<int32_t>& address,
int32_t val) {
sycl_atomic_ref_rlx_wg_local_t<int32_t> target(*address);
target.fetch_max(val);
}

static inline void atomicMax(
const sycl_local_ptr<int64_t>& address,
int64_t val) {
sycl_atomic_ref_rlx_wg_local_t<int64_t> target(*address);
target.fetch_max(val);
}

SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<uint8_t>(a, b), uint8_t)
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int8_t>(a, b), int8_t)
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int16_t>(a, b), int16_t)

SYCL_ATOMIC_INTEGER(Max, safe_max<uint8_t>(a, b), uint8_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<int8_t>(a, b), int8_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<int16_t>(a, b), int16_t)
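
The 1- and 2-byte variants above emulate narrow atomics with a compare-and-swap loop over the aligned 32-bit word that contains the target byte or half-word. A standalone sketch of the same technique (illustrative only, written against std::atomic rather than SYCL):

// Editor-added sketch: atomic max on a single byte, emulated with a CAS loop
// on the aligned 32-bit word that contains it (same idea as
// AtomicIntegerImplLocal<T, 1> above, but using std::atomic).
#include <atomic>
#include <cstdint>
#include <cstdio>

void atomic_max_u8(std::atomic<uint32_t>* word, unsigned byte_index, uint8_t val) {
  const unsigned shift = byte_index * 8;
  uint32_t expected = word->load(std::memory_order_relaxed);
  uint32_t desired;
  do {
    const uint8_t current = static_cast<uint8_t>((expected >> shift) & 0xff);
    const uint8_t updated = current > val ? current : val;
    desired = (expected & ~(0xffu << shift)) |
              (static_cast<uint32_t>(updated) << shift);
  } while (!word->compare_exchange_strong(expected, desired,
                                          std::memory_order_relaxed));
}

int main() {
  std::atomic<uint32_t> word{0x00112233u};
  atomic_max_u8(&word, /*byte_index=*/1, /*val=*/0x77); // byte 0x22 -> 0x77
  std::printf("0x%08x\n", word.load());                 // prints 0x00117733
  return 0;
}

The 32- and 64-bit local-memory overloads do not need the loop, since sycl::atomic_ref provides fetch_max natively at those widths.
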
1 change: 1 addition & 0 deletions src/ATen/native/xpu/sycl/DistributionTemplates.h
@@ -8,6 +8,7 @@
#include <ATen/native/xpu/sycl/MemoryAccess.h>
#include <ATen/native/xpu/sycl/OffsetCalculator.h>
#include <ATen/native/xpu/sycl/Philox4x32.h>
#include <ATen/ops/empty.h>
#include <comm/DeviceProperties.h>
#include <comm/Runtime.h>
