Add aten::multinomial #520

Merged: 16 commits, Jul 14, 2024
104 changes: 104 additions & 0 deletions src/ATen/native/xpu/Distributions.cpp
@@ -8,6 +8,8 @@
#include <ATen/xpu/XPUNativeFunctions.h>

#include <ATen/native/xpu/sycl/DistributionKernels.h>
#include <ATen/native/xpu/sycl/MultinomialKernel.h>
#include <ATen/ops/div.h>

namespace at {

@@ -189,4 +191,106 @@ Tensor& XPUNativeFunctions::random_(
return random_(self, 0, to, std::move(generator));
}

/* The largest consecutive integer representable in float32 (2^24) */
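/* 2^24 = 16,777,216: float32 has a 24-bit significand, so every integer up to 2^24 is exactly representable, but 2^24 + 1 is not. */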
constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = 1 << (24);

Tensor& XPUNativeFunctions::multinomial_out(
const Tensor& self,
int64_t n_sample,
bool with_replacement,
::std::optional<at::Generator> gen,
at::Tensor& result) {
TORCH_CHECK(
result.device() == self.device(),
"multinomial arguments must have the same device");
TORCH_CHECK(
self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim");
TORCH_CHECK(
at::isFloatingType(self.scalar_type()),
"multinomial only supports floating-point dtypes for input, got: ",
self.scalar_type());
TORCH_CHECK(
result.scalar_type() == ScalarType::Long,
"multinomial expects Long tensor out, got: ",
result.scalar_type());
TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples");
int64_t n_categories = self.size(-1);
TORCH_CHECK(
with_replacement || (n_sample <= n_categories),
"cannot sample n_sample > prob_dist.size(-1) samples without replacement");
// Since the index tensor is float, numCategories cannot exceed max
// float integer precision
TORCH_CHECK(
n_categories <= FLOAT32_MAX_CONSECUTIVE_INT,
"number of categories cannot exceed 2^24");

if (self.dim() == 1) {
result.resize_({n_sample});
} else {
const int64_t n_dist = self.size(0);
result.resize_({n_dist, n_sample});
}
if (result.numel() == 0) {
return result;
}

// Fast-path for no replacement or if only one sample is drawn.
// Reference:
// https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
if (!with_replacement || n_sample == 1) {
// Sanity checks on `self`.
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
TORCH_CHECK(
is_valid.to<bool>(),
"probability tensor contains either `inf`, `nan` or element < 0");
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool zero_prob_condition;
if (self.dim() == 1) {
zero_prob_condition = (self.sum() == 0).item().to<bool>();
} else {
zero_prob_condition = (self.sum(1) == 0).sum().item().to<bool>();
}
TORCH_CHECK(
!zero_prob_condition,
"invalid multinomial distribution (sum of probabilities <= 0)");

// The algorithm is from gumbel softmax.
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)
// Here we can apply exp to the formula which will not affect result of
// argmax or topk. Then we have
// s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).
// We can also simplify the formula above by
// s = argmax( p / q ) where q ~ Exp(1)
Tensor q = at::empty_like(self).exponential_(1, std::move(gen));
// In theory the probability to generate 0 from exponential distribution is
// 0. However, on CUDA side there is a protection to avoid 0s, but on CPU
// side, there is a very low probability to generate 0 from
// exponential<double>. The probability is about 2^(-DBL_MANT_DIG). We just
// ignore it here, but there may be some risk to get invalid output on CPU.
at::div_out(q, self, q);
if (n_sample == 1) {
at::argmax_out(result, q, /*dim=*/-1, /*keepdim=*/true);
} else {
Tensor vals = at::empty(result.sizes(), self.options());
at::topk_out(vals, result, q, n_sample);
}
return result;
}

at::native::xpu::multinomial_kernel(result, self, n_sample, gen);
return result;
}

Tensor XPUNativeFunctions::multinomial(
const Tensor& self,
int64_t n_sample,
bool with_replacement,
::std::optional<at::Generator> gen) {
Tensor result = at::empty({0}, self.options().dtype(kLong));

XPUNativeFunctions::multinomial_out(
self, n_sample, with_replacement, std::move(gen), result);
return result;
}

} // namespace at
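For reference, the replacement-free fast path above can be read as a standalone snippet. The following is an illustrative sketch only, not part of the diff; it assumes a standard ATen build, and the helper name draw_one_sample is hypothetical.

#include <ATen/ATen.h>

// Sketch of the fast path: argmax(p / q) with q ~ Exp(1) is equivalent to the
// Gumbel-max trick argmax(log p - log(-log(u))), u ~ U(0, 1), so it draws one
// index from the categorical distribution defined by p.
at::Tensor draw_one_sample(const at::Tensor& probs) {
  at::Tensor q = at::empty_like(probs).exponential_(1); // q ~ Exp(1), same shape as probs
  return at::argmax(probs / q, /*dim=*/-1, /*keepdim=*/true);
}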
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
@@ -274,7 +274,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"multilabel_margin_loss_forward",
"multi_margin_loss",
"multi_margin_loss_backward",
"multinomial",
"nanmedian",
"nanmedian.dim_values",
"nansum",
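With "multinomial" removed from the fallback list above, at::multinomial on XPU tensors dispatches to the new native implementation instead of falling back to CPU. A minimal usage sketch, illustrative only and assuming an XPU-enabled build:

#include <ATen/ATen.h>

void multinomial_xpu_example() {
  // An arbitrary probability vector on the XPU device.
  at::Tensor probs = at::softmax(
      at::rand({8}, at::TensorOptions().device(at::kXPU)), /*dim=*/0);
  // Dispatches to XPUNativeFunctions::multinomial rather than the CPU fallback.
  at::Tensor idx =
      at::multinomial(probs, /*num_samples=*/2, /*replacement=*/true);
  (void)idx;
}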
106 changes: 106 additions & 0 deletions src/ATen/native/xpu/sycl/Atomics.h
@@ -27,6 +27,93 @@ template <typename T>
using sycl_atomic_ref_rlx_wg_local_t =
sycl::atomic_ref<T, sycl_mem_odr_rlx, sycl_mem_scp_wg, sycl_local_space>;

template <typename T, size_t n>
struct AtomicIntegerImplLocal;

template <typename T>
struct AtomicIntegerImplLocal<T, 1> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
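// A single byte cannot be compare-and-swapped directly, so locate the aligned
// 32-bit word containing it and the byte's shift within that word; the update
// then goes through a 32-bit CAS on that word.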
size_t offset = (size_t)address & 3;
uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
uint32_t assumed = *address_as_ui;
uint32_t shift = offset * 8;
uint32_t newval;
uint32_t newval_byte;
sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);

do {
newval = assumed;
newval_byte = (newval >> shift) & 0xff;
// preserve size in initial cast. Casting directly to uint32_t pads
// negative signed values with 1's (e.g. signed -1 = unsigned ~0).
newval = static_cast<uint8_t>(func(val, static_cast<T>(newval_byte)));
newval = (assumed & ~(0x000000ff << shift)) | (newval << shift);
} while (!target.compare_exchange_strong(assumed, newval));
}
};

template <typename T>
struct AtomicIntegerImplLocal<T, 2> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
size_t offset = (size_t)address & 2;
uint32_t* address_as_ui = (uint32_t*)((char*)address - offset);
bool is_32_align = offset;
uint32_t assumed = *address_as_ui;
uint32_t newval;
uint32_t newval_bytes;
sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);

do {
newval = assumed;
newval_bytes = is_32_align ? newval >> 16 : newval & 0xffff;
// preserve size in initial cast. Casting directly to uint32_t pads
// negative signed values with 1's (e.g. signed -1 = unsigned ~0).
newval = static_cast<uint16_t>(func(val, static_cast<T>(newval_bytes)));
newval = is_32_align ? (assumed & 0xffff) | (newval << 16)
: (assumed & 0xffff0000) | newval;
} while (!target.compare_exchange_strong(assumed, newval));
}
};

template <typename T>
struct AtomicIntegerImplLocal<T, 4> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
uint32_t* address_as_ui = (uint32_t*)(address);
uint32_t assumed = *address_as_ui;
uint32_t newval;
sycl_atomic_ref_rlx_wg_local_t<uint32_t> target(*address_as_ui);

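// On failure, compare_exchange_strong writes the freshly observed value back
// into `assumed`, so each retry recomputes `newval` from up-to-date data.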
do {
newval = static_cast<uint32_t>(func(val, static_cast<T>(assumed)));
} while (!target.compare_exchange_strong(assumed, newval));
}
};

template <typename T>
struct AtomicIntegerImplLocal<T, 8> {
template <typename func_t>
inline void operator()(T* address, T val, const func_t& func) {
unsigned long long* address_as_ull = (unsigned long long*)(address);
unsigned long long assumed = *address_as_ull;
unsigned long long newval;
sycl_atomic_ref_rlx_wg_local_t<unsigned long long> target(*address_as_ull);

do {
newval = static_cast<uint64_t>(func(val, static_cast<T>(assumed)));
} while (!target.compare_exchange_strong(assumed, newval));
}
};

#define SYCL_ATOMIC_INTEGER_LOCAL(NAME, OP, DTYPE) \
static inline void atomic##NAME( \
const sycl_local_ptr<DTYPE>& address, DTYPE val) { \
AtomicIntegerImplLocal<DTYPE, sizeof(DTYPE)>()( \
address, val, [](DTYPE a, DTYPE b) { return OP; }); \
}

template <typename T, size_t n>
struct AtomicIntegerImpl;

@@ -268,6 +355,25 @@ SYCL_ATOMIC_FP(Mul, std::multiplies<at::Half>()(a, b), at::Half)
SYCL_ATOMIC_FP(Mul, std::multiplies<at::BFloat16>()(a, b), at::BFloat16)

// Atomic maximum implementation.

static inline void atomicMax(
const sycl_local_ptr<int32_t>& address,
int32_t val) {
sycl_atomic_ref_rlx_wg_local_t<int32_t> target(*address);
target.fetch_max(val);
}

static inline void atomicMax(
const sycl_local_ptr<int64_t>& address,
int64_t val) {
sycl_atomic_ref_rlx_wg_local_t<int64_t> target(*address);
target.fetch_max(val);
}

SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<uint8_t>(a, b), uint8_t)
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int8_t>(a, b), int8_t)
SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int16_t>(a, b), int16_t)

SYCL_ATOMIC_INTEGER(Max, safe_max<uint8_t>(a, b), uint8_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<int8_t>(a, b), int8_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<int16_t>(a, b), int16_t)
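// Illustrative expansion only (not part of the diff): for int16_t, the
// SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max<int16_t>(a, b), int16_t) invocation
// above generates, in effect, a local-memory atomicMax overload that routes a
// 2-byte max through the 32-bit CAS loop of AtomicIntegerImplLocal<int16_t, 2>:
//
//   static inline void atomicMax(
//       const sycl_local_ptr<int16_t>& address, int16_t val) {
//     AtomicIntegerImplLocal<int16_t, sizeof(int16_t)>()(
//         address, val, [](int16_t a, int16_t b) { return safe_max<int16_t>(a, b); });
//   }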
1 change: 1 addition & 0 deletions src/ATen/native/xpu/sycl/DistributionTemplates.h
@@ -8,6 +8,7 @@
#include <ATen/native/xpu/sycl/MemoryAccess.h>
#include <ATen/native/xpu/sycl/OffsetCalculator.h>
#include <ATen/native/xpu/sycl/Philox4x32.h>
#include <ATen/ops/empty.h>
#include <comm/DeviceProperties.h>
#include <comm/Runtime.h>
