diff --git a/src/ATen/native/xpu/Distributions.cpp b/src/ATen/native/xpu/Distributions.cpp index ecb947a5e..a2e89743a 100644 --- a/src/ATen/native/xpu/Distributions.cpp +++ b/src/ATen/native/xpu/Distributions.cpp @@ -8,6 +8,8 @@ #include #include +#include +#include namespace at { @@ -189,4 +191,106 @@ Tensor& XPUNativeFunctions::random_( return random_(self, 0, to, std::move(generator)); } +/* The largest consecutive integer representable in float32 (2^24) */ +constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = 1 << (24); + +Tensor& XPUNativeFunctions::multinomial_out( + const Tensor& self, + int64_t n_sample, + bool with_replacement, + ::std::optional gen, + at::Tensor& result) { + TORCH_CHECK( + result.device() == self.device(), + "multinomial arguments must have the same device"); + TORCH_CHECK( + self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim"); + TORCH_CHECK( + at::isFloatingType(self.scalar_type()), + "multinomial only supports floating-point dtypes for input, got: ", + self.scalar_type()); + TORCH_CHECK( + result.scalar_type() == ScalarType::Long, + "multinomial expects Long tensor out, got: ", + result.scalar_type()); + TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples"); + int64_t n_categories = self.size(-1); + TORCH_CHECK( + with_replacement || (n_sample <= n_categories), + "cannot sample n_sample > prob_dist.size(-1) samples without replacement"); + // Since the index tensor is float, numCategories cannot exceed max + // float integer precision + TORCH_CHECK( + n_categories <= FLOAT32_MAX_CONSECUTIVE_INT, + "number of categories cannot exceed 2^24"); + + if (self.dim() == 1) { + result.resize_({n_sample}); + } else { + const int64_t n_dist = self.size(0); + result.resize_({n_dist, n_sample}); + } + if (result.numel() == 0) { + return result; + } + + // Fast-path for no replacement or if only one sample is drawn. + // Reference: + // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503 + if (!with_replacement || n_sample == 1) { + // Sanity checks on `self`. + auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item(); + TORCH_CHECK( + is_valid.to(), + "probability tensor contains either `inf`, `nan` or element < 0"); + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + bool zero_prob_condition; + if (self.dim() == 1) { + zero_prob_condition = (self.sum() == 0).item().to(); + } else { + zero_prob_condition = (self.sum(1) == 0).sum().item().to(); + } + TORCH_CHECK( + !zero_prob_condition, + "invalid multinomial distribution (sum of probabilities <= 0)"); + + // The algorithm is from gumbel softmax. + // s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1) + // Here we can apply exp to the formula which will not affect result of + // argmax or topk. Then we have + // s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1). + // We can also simplify the formula above by + // s = argmax( p / q ) where q ~ Exp(1) + Tensor q = at::empty_like(self).exponential_(1, std::move(gen)); + // In theory the probability to generate 0 from exponential distribution is + // 0. However, on CUDA side there is a protection to avoid 0s, but on CPU + // side, there is a very low probability to generate 0 from + // exponential. The probability is about 2^(-DBL_MANT_DIG). We just + // ignore it here, but there may be some risk to get invalid output on CPU. + at::div_out(q, self, q); + if (n_sample == 1) { + at::argmax_out(result, q, /*dim=*/-1, /*keepdim=*/true); + } else { + Tensor vals = at::empty(result.sizes(), self.options()); + at::topk_out(vals, result, q, n_sample); + } + return result; + } + + at::native::xpu::multinomial_kernel(result, self, n_sample, gen); + return result; +} + +Tensor XPUNativeFunctions::multinomial( + const Tensor& self, + int64_t n_sample, + bool with_replacement, + ::std::optional gen) { + Tensor result = at::empty({0}, self.options().dtype(kLong)); + + XPUNativeFunctions::multinomial_out( + self, n_sample, with_replacement, std::move(gen), result); + return result; +} + } // namespace at diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index e69183cce..4c2440a7f 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -270,7 +270,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "multilabel_margin_loss_forward", "multi_margin_loss", "multi_margin_loss_backward", - "multinomial", "nanmedian", "nanmedian.dim_values", "nansum", diff --git a/src/ATen/native/xpu/sycl/Atomics.h b/src/ATen/native/xpu/sycl/Atomics.h index 4b124041a..c99b02b3a 100644 --- a/src/ATen/native/xpu/sycl/Atomics.h +++ b/src/ATen/native/xpu/sycl/Atomics.h @@ -27,6 +27,93 @@ template using sycl_atomic_ref_rlx_wg_local_t = sycl::atomic_ref; +template +struct AtomicIntegerImplLocal; + +template +struct AtomicIntegerImplLocal { + template + inline void operator()(T* address, T val, const func_t& func) { + size_t offset = (size_t)address & 3; + uint32_t* address_as_ui = (uint32_t*)((char*)address - offset); + uint32_t assumed = *address_as_ui; + uint32_t shift = offset * 8; + uint32_t newval; + uint32_t newval_byte; + sycl_atomic_ref_rlx_wg_local_t target(*address_as_ui); + + do { + newval = assumed; + newval_byte = (newval >> shift) & 0xff; + // preserve size in initial cast. Casting directly to uint32_t pads + // negative signed values with 1's (e.g. signed -1 = unsigned ~0). + newval = static_cast(func(val, static_cast(newval_byte))); + newval = (assumed & ~(0x000000ff << shift)) | (newval << shift); + } while (!target.compare_exchange_strong(assumed, newval)); + } +}; + +template +struct AtomicIntegerImplLocal { + template + inline void operator()(T* address, T val, const func_t& func) { + size_t offset = (size_t)address & 2; + uint32_t* address_as_ui = (uint32_t*)((char*)address - offset); + bool is_32_align = offset; + uint32_t assumed = *address_as_ui; + uint32_t newval; + uint32_t newval_bytes; + sycl_atomic_ref_rlx_wg_local_t target(*address_as_ui); + + do { + newval = assumed; + newval_bytes = is_32_align ? newval >> 16 : newval & 0xffff; + // preserve size in initial cast. Casting directly to uint32_t pads + // negative signed values with 1's (e.g. signed -1 = unsigned ~0). + newval = static_cast(func(val, static_cast(newval_bytes))); + newval = is_32_align ? (assumed & 0xffff) | (newval << 16) + : (assumed & 0xffff0000) | newval; + } while (!target.compare_exchange_strong(assumed, newval)); + } +}; + +template +struct AtomicIntegerImplLocal { + template + inline void operator()(T* address, T val, const func_t& func) { + uint32_t* address_as_ui = (uint32_t*)(address); + uint32_t assumed = *address_as_ui; + uint32_t newval; + sycl_atomic_ref_rlx_wg_local_t target(*address_as_ui); + + do { + newval = static_cast(func(val, static_cast(assumed))); + } while (!target.compare_exchange_strong(assumed, newval)); + } +}; + +template +struct AtomicIntegerImplLocal { + template + inline void operator()(T* address, T val, const func_t& func) { + unsigned long long* address_as_ull = (unsigned long long*)(address); + unsigned long long assumed = *address_as_ull; + unsigned long long newval; + sycl_atomic_ref_rlx_wg_local_t target(*address_as_ull); + + do { + newval = static_cast(func(val, static_cast(assumed))); + } while (!target.compare_exchange_strong(assumed, newval)); + } +}; + +#define SYCL_ATOMIC_INTEGER_LOCAL(NAME, OP, DTYPE) \ + static inline void atomic##NAME( \ + const sycl_local_ptr& address, DTYPE val) { \ + AtomicIntegerImplLocal()( \ + address, val, [](DTYPE a, DTYPE b) { return OP; }); \ + } + template struct AtomicIntegerImpl; @@ -268,6 +355,25 @@ SYCL_ATOMIC_FP(Mul, std::multiplies()(a, b), at::Half) SYCL_ATOMIC_FP(Mul, std::multiplies()(a, b), at::BFloat16) // Atomic maximum implementation. + +static inline void atomicMax( + const sycl_local_ptr& address, + int32_t val) { + sycl_atomic_ref_rlx_wg_local_t target(*address); + target.fetch_add(val); +} + +static inline void atomicMax( + const sycl_local_ptr& address, + int64_t val) { + sycl_atomic_ref_rlx_wg_local_t target(*address); + target.fetch_add(val); +} + +SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max(a, b), uint8_t) +SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max(a, b), int8_t) +SYCL_ATOMIC_INTEGER_LOCAL(Max, safe_max(a, b), int16_t) + SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t) SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t) SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t) diff --git a/src/ATen/native/xpu/sycl/DistributionTemplates.h b/src/ATen/native/xpu/sycl/DistributionTemplates.h index 6b9a16703..e4345bfad 100644 --- a/src/ATen/native/xpu/sycl/DistributionTemplates.h +++ b/src/ATen/native/xpu/sycl/DistributionTemplates.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include diff --git a/src/ATen/native/xpu/sycl/MultinomialKernel.cpp b/src/ATen/native/xpu/sycl/MultinomialKernel.cpp new file mode 100644 index 000000000..9b1e612c4 --- /dev/null +++ b/src/ATen/native/xpu/sycl/MultinomialKernel.cpp @@ -0,0 +1,544 @@ +#pragma clang diagnostic push +#pragma GCC diagnostic push +// Avoid SYCL compiler return-type error +#pragma clang diagnostic ignored "-Wreturn-type" +#pragma GCC diagnostic ignored "-Wreturn-type" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace at::native::xpu { + +template +inline void renormRowsL1( + item_t& item, + scalar_t* dist, + int64_t rows, + int64_t cols, + unsigned char* my_smem) { + auto thread_idx = item.get_local_id(0); + auto thread_range = item.get_local_range(0); + auto group_idx = item.get_group(0); + auto group_range = item.get_group_range(0); + + scalar_t* smem = reinterpret_cast(my_smem); + scalar_t zero = static_cast(0); + scalar_t val; + for (int64_t row = group_idx; row < rows; row += group_range) { + scalar_t sum = static_cast(0); + for (int64_t col = thread_idx; col < cols; col += thread_range) { + val = dist[row * cols + col]; + sum = sum + val; + } + + sum = GroupReduceSumSGSizeEqualstoNumSG(item, sum, smem); + if (thread_idx == 0) { + smem[0] = sum; + } + item.barrier(sycl_local_fence); + + sum = smem[0]; + if (sum > zero) { + for (int64_t col = thread_idx; col < cols; col += thread_range) { + dist[row * cols + col] = dist[row * cols + col] / sum; + } + } + } +} + +template +struct RenormRowsKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<1> item) const { + renormRowsL1( + item, + t_ptr, + rows, + cols, + (unsigned char*)(smem.template get_multi_ptr< + sycl::access::decorated::no>() + .get())); + } + void sycl_ker_config_convention(sycl::handler& cgh) { + smem = sycl_local_acc_t(group_size_ / 8, cgh); + // We use the smallest subgroup size to ensure enough space + } + RenormRowsKernelFunctor( + int64_t rows_, + int64_t cols_, + scalar_t* t_ptr_, + int group_size) + : rows(rows_), cols(cols_), t_ptr(t_ptr_), group_size_(group_size) {} + + private: + int64_t rows; + int64_t cols; + scalar_t* t_ptr; + int group_size_; + sycl_local_acc_t smem; +}; + +inline void renormRows(Tensor& t) { + TORCH_CHECK(t.dim() == 2); + int64_t rows = t.size(0); + int64_t cols = t.size(1); + int subgroup_size = syclMaxSubGroupSize(); + int group_size = + std::min(int(syclMaxWorkItemsPerEU()), subgroup_size* subgroup_size); + int num_groups = (rows + group_size - 1) / group_size; + int hw_max_groups = syclMaxWorkItemsPerTile() / group_size; + num_groups = num_groups > hw_max_groups ? hw_max_groups : num_groups; + + auto& sycl_queue = at::xpu::getCurrentSYCLQueue(); + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::BFloat16, + at::ScalarType::Half, + t.scalar_type(), + "renormRows_xpu", + [&] { + auto t_ptr = t.data_ptr(); + auto kfn = + RenormRowsKernelFunctor(rows, cols, t_ptr, group_size); + sycl_kernel_submit( + num_groups * group_size, group_size, sycl_queue, kfn); + }); +} + +template +inline int binarySearchForMultinomial( + scalar_t* cumdist, + scalar_t* dist, + int size, + scalar_t val) { + int start = 0; + int end = size; + // cumdist[size - 1] = 0 => all zero prob dist + + while (end - start > 0) { + int mid = start + (end - start) / 2; + + scalar_t midVal = cumdist[mid]; + if (midVal < val) { + start = mid + 1; + } else { + end = mid; + } + } + + if (start == size) { + // No probability mass or precision problems; just return the + // first non-zero element by setting start to size-1 here, + // the code below will move it to the last non-zero probability + // this actually can happen when the random number is 1 + // (github pytorch issue #4858). + start = size - 1; + } + + while (start >= 1 && dist[start] == 0) + start--; + + return start; +} + +template +inline void sampleMultinomialWithReplacement( + item_t& item, + PhiloxState philox_args, + int totalSamples, + int64_t* dest, + int64_t distributions, + int categories, + scalar_t* normDistPrefixSum, + scalar_t* normDist) { + auto thread_idx = item.get_local_id(1); + auto thread_range = item.get_local_range(1); + auto group_idx_x = item.get_group(1); + auto group_idx_y = item.get_group(0); + auto group_range_x = item.get_group_range(1); + auto group_range_y = item.get_group_range(0); + + // At the moment, each subgroup computes one sample value in the binary + // search due to divergence. It seems possible to compute multiple + // values and limit divergence though later on. + + auto seeds = philox_unpack(philox_args); + + // global index formula for 2D grid of 1D group + int idx = group_idx_y * group_range_x * thread_range + + group_idx_x * thread_range + thread_idx; + + randStatePhilox4_32_10_t state; + rand_init(std::get<0>(seeds), idx, std::get<1>(seeds), &state); + + // The block determines the distribution for which we generate a point + for (int64_t curDist = group_idx_y; curDist < distributions; + curDist += group_range_y) { + for (int sample = group_idx_x * thread_range + thread_idx; + sample < totalSamples; + sample += thread_range * group_range_x) { + // we are losing 3 out of 4 generated numbers but it's ok + // this kernel is not very efficient anyway + auto rand = rand_uniform4(&state); + scalar_t r = static_cast(rand.x); + + // Find the bucket that a uniform sample lies in + int choice = binarySearchForMultinomial( + normDistPrefixSum + curDist * categories, + normDist + curDist * categories, + categories, + r); + + dest[curDist * totalSamples + sample] = choice; + } + } +} + +template +struct MultinomialWithReplacementKernelImplFunctor { + void operator()(sycl::nd_item<2> item) const { + sampleMultinomialWithReplacement( + item, + rng_engine_inputs, + n_sample, + result_ptr, + numDist, + numCategories, + prefixSum_ptr, + normDist_ptr); + } + MultinomialWithReplacementKernelImplFunctor( + PhiloxState rng_engine_inputs_, + const int64_t n_sample_, + int64_t* result_ptr_, + int64_t numDist_, + int numCategories_, + scalar_t* prefixSum_ptr_, + scalar_t* normDist_ptr_) + : rng_engine_inputs(rng_engine_inputs_), + n_sample(n_sample_), + result_ptr(result_ptr_), + numDist(numDist_), + numCategories(numCategories_), + prefixSum_ptr(prefixSum_ptr_), + normDist_ptr(normDist_ptr_) {} + + private: + PhiloxState rng_engine_inputs; + const int64_t n_sample; + int64_t* result_ptr; + int64_t numDist; + int numCategories; + scalar_t* prefixSum_ptr; + scalar_t* normDist_ptr; +}; +template +struct SampleMultinomialOnceFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<1> item) const { + accscalar_t* smem = reinterpret_cast( + smem_.template get_multi_ptr().get()); + + accscalar_t accZero = static_cast(0); + scalar_t zero = static_cast(0); + int local_id = item.get_local_id(0); + int local_range = item.get_local_range(0); + + for (int64_t curDist = item.get_group(0); curDist < distributions_; + curDist += item.get_group_range(0)) { + // First pass, find the total sum of the distribution + accscalar_t sum = accZero; + scalar_t val; + for (int cat = item.get_local_id(0); cat < categories_; + cat += item.get_local_range(0)) { + val = dist_[curDist * stride_dist_ + cat * stride_categories_]; + SYCL_KERNEL_ASSERT(!at::_isnan(val)); + SYCL_KERNEL_ASSERT(!_isinf(val)); + SYCL_KERNEL_ASSERT(!(val < zero)); + sum = sum + static_cast(val); + } + + sum = GroupReduceSumSGSizeEqualstoNumSG(item, sum, smem); + + // Broadcast sum and sample value + if (item.get_local_id(0) == 0) { + // Make sure the sum of our distribution didn't overflow + SYCL_KERNEL_ASSERT(!_isinf(val)); + SYCL_KERNEL_ASSERT(sum > accZero); + + foundPos_[0] = 0; + smem[0] = sum; + smem[1] = sampled_[curDist]; + } + item.barrier(sycl_local_fence); + + sum = smem[0]; + scalar_t sample = static_cast(smem[1]); + item.barrier(sycl_local_fence); + + if (sum == accZero) { + // Choose the first element + if (local_id == 0) { + dest_[curDist] = 0; + } + + continue; + } + + int chunks = (categories_ + (int)local_range - 1) / local_range; + accscalar_t prevHighProb = accZero; + found_[0] = false; + + for (int chunk = 0; chunk < chunks && !found_[0]; ++chunk) { + // All threads in bounds load a value + int cat = chunk * local_range + local_id; + + accscalar_t dist_val = cat < categories_ + ? static_cast( + dist_[curDist * stride_dist_ + cat * stride_categories_]) / + sum + : accZero; + + smem[local_id] = dist_val; + item.barrier(sycl_local_fence); + + // Perform an inclusive prefix sum of the shared memory contents + for (int offset = 1; offset < local_range; offset *= 2) { + accscalar_t val = accZero; + + if (local_id >= offset) { + val = smem[local_id - offset] + smem[local_id]; + } + + item.barrier(sycl_local_fence); + if (local_id >= offset) { + smem[local_id] = val; + } + item.barrier(sycl_local_fence); + } + + // Each thread will check to see if the sample falls in its + // bucket + scalar_t curBucket = + static_cast(smem[local_id] + prevHighProb); + scalar_t prevBucket = static_cast( + local_id == 0 ? prevHighProb : smem[local_id - 1] + prevHighProb); + bool inBucket = (cat < categories_) && + (!(sample >= curBucket) && (sample >= prevBucket) && + (dist_val > zero)); + + if (inBucket) { + // We're done; we have the sample + // Torch indices are 1-based + + atomicMax( + sycl_local_ptr( + foundPos_ + .template get_multi_ptr() + .get()), + cat); + + found_[0] = true; + } + + // Store the previous scan's high value for future use + prevHighProb = prevHighProb + smem[local_range - 1]; + + item.barrier(sycl_local_fence); + } + + if (local_id == 0) { + if (found_[0]) { + dest_[curDist] = foundPos_[0]; + } else { + // This should address a rare bug where we don't select a valid index. + // This likely occurs when due to floating point arithmetic rounding + // errors, our cumulative sum does not add up to 1, but and our + // uniform sample is greater than this value. In this case we likely + // have unitialized memory in dest[curDist]. So basically we will loop + // through the distribution and pick the largest index where the + // distribution is non-zero. This is obviously terribly inefficient, + // but due to the rarity in which this occurs, this should not be an + // issue. + for (int cat = categories_ - 1; cat >= 0; --cat) { + if (dist_[curDist * stride_dist_ + cat * stride_categories_] > + zero) { + dest_[curDist] = cat; + break; + } + } + } + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + smem_ = sycl_local_acc_t(group_size_, cgh); + found_ = sycl_local_acc_t(1, cgh); + foundPos_ = sycl_local_acc_t(1, cgh); + } + + SampleMultinomialOnceFunctor( + int64_t* dest, + int64_t distributions, + int categories, + const scalar_t* sampled, + const scalar_t* dist, + int stride_dist, // dist->stride(0) + int stride_categories, // dist->stride(1) + int group_size) + : dest_(dest), + distributions_(distributions), + categories_(categories), + sampled_(sampled), + dist_(dist), + stride_dist_(stride_dist), + stride_categories_(stride_categories), + group_size_(group_size) {} + + private: + int64_t* dest_; + int64_t distributions_; + int categories_; + const scalar_t* sampled_; + const scalar_t* dist_; + int stride_dist_; + int stride_categories_; + int group_size_; + sycl_local_acc_t smem_; + sycl_local_acc_t found_; + sycl_local_acc_t foundPos_; +}; + +void multinomial_kernel( + Tensor& result, + const Tensor& self, + const int64_t n_sample, + c10::optional generator) { + auto& sycl_queue = at::xpu::getCurrentSYCLQueue(); + auto gen = get_generator_or_default( + generator, at::xpu::detail::getDefaultXPUGenerator()); + + int inputSize = self.dim(); + int64_t numDist = inputSize == 1 ? 1 : self.size(0); + int numCategories = inputSize == 1 ? self.size(0) : self.size(1); + + // Restructure data for 2d + auto self_v = inputSize == 1 ? self.view({numDist, numCategories}) : self; + + result.resize_({numDist, n_sample}); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::BFloat16, + at::ScalarType::Half, + self_v.scalar_type(), + "multinomial_kernel_xpu", + [&] { + using accscalar_t = acc_type; + using KernelClass = SampleMultinomialOnceFunctor; + int maxThreads = syclMaxWorkGroupSize(); + int maxShared = syclLocalMemSize(); + + int SubGroupSize = syclMinSubGroupSize(); + int requiredSubGroups = at::ceil_div(numCategories, SubGroupSize); + int requiredThreads = + std::min(maxThreads, requiredSubGroups * SubGroupSize); + int requiredShared = requiredThreads * sizeof(accscalar_t); + if (n_sample == 1 && maxShared >= requiredShared) { + Tensor sampled = + at::detail::empty_xpu({numDist, n_sample}, self_v.options()); + at::native::uniform_(sampled, 0.0, 1.0, generator); + int group_size = requiredThreads; + int group_range = numDist; + auto kfn = KernelClass( + result.mutable_data_ptr(), + numDist, + numCategories, + sampled.const_data_ptr(), + self_v.const_data_ptr(), + self_v.stride(0), + self_v.stride(1), + group_size); + + sycl_kernel_submit( + group_range * group_size, group_size, sycl_queue, kfn); + } else { + Tensor origDist = native::empty_like( + self_v, + c10::nullopt /* dtype */, + c10::nullopt /* layout */, + c10::nullopt /* device */, + c10::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + origDist.copy_(self_v); + + Tensor normDist = native::empty_like( + self_v, + c10::nullopt /* dtype */, + c10::nullopt /* layout */, + c10::nullopt /* device */, + c10::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + Tensor prefixSum = native::empty_like( + self_v, + c10::nullopt /* dtype */, + c10::nullopt /* layout */, + c10::nullopt /* device */, + c10::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + // Renorm along rows + normDist.copy_(origDist); + renormRows(normDist); + + // Prefix sum along rows + at::cumsum_out(prefixSum, normDist, 1); + int group_size = syclMaxWorkItemsPerEU(); + int group_range_y = numDist; + int group_range_x = (n_sample - 1) / group_size + 1; + + std::pair rng_engine_inputs_; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + auto offset = ((numDist - 1) / group_range_y + 1) * 4; + rng_engine_inputs_ = gen->philox_engine_inputs(offset); + } + auto rng_engine_inputs = PhiloxState( + std::get<0>(rng_engine_inputs_), std::get<1>(rng_engine_inputs_)); + // Sample with replacement + + auto result_ptr = result.data_ptr(); + auto prefixSum_ptr = prefixSum.data_ptr(); + auto normDist_ptr = normDist.data_ptr(); + auto kfn = MultinomialWithReplacementKernelImplFunctor( + rng_engine_inputs, + n_sample, + result_ptr, + numDist, + numCategories, + prefixSum_ptr, + normDist_ptr); + + sycl_kernel_submit( + sycl::range<2>(group_range_y, group_range_x * group_size), + sycl::range<2>(1, group_size), + sycl_queue, + kfn); + } + }); + + if (inputSize == 1) { + result.resize_({n_sample}); + } +} + +} // namespace at::native::xpu +#pragma GCC diagnostic pop +#pragma clang diagnostic pop diff --git a/src/ATen/native/xpu/sycl/MultinomialKernel.h b/src/ATen/native/xpu/sycl/MultinomialKernel.h new file mode 100644 index 000000000..d400e51cd --- /dev/null +++ b/src/ATen/native/xpu/sycl/MultinomialKernel.h @@ -0,0 +1,12 @@ +#pragma once +#include + +namespace at::native::xpu { + +void multinomial_kernel( + Tensor& result, + const Tensor& self, + const int64_t n_sample, + c10::optional generator); + +} // namespace at::native::xpu \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/SYCLGroupAlgorithm.h b/src/ATen/native/xpu/sycl/SYCLGroupAlgorithm.h new file mode 100644 index 000000000..35f1d54a5 --- /dev/null +++ b/src/ATen/native/xpu/sycl/SYCLGroupAlgorithm.h @@ -0,0 +1,39 @@ +#pragma once + +#include + +template +inline T GroupReduceSumSGSizeEqualstoNumSG(sg_t& sg, T val) { + auto sg_size = sg.get_local_range()[0]; + for (int offset = (sg_size >> 1); offset > 0; offset >>= 1) { + val += sg.shuffle_down(val, offset); + } + return val; +} + +// function GroupReduceSumSGSizeEqualstoNumSG will firstly reduce elements in +// each subgroups, after that it will store the results of each subgroup into a +// subgroup, and reduce this subgroup for the final result. So, pls notice, when +// using this method, the maximun work_group size should be equals to sub_group +// size * sub_group size, or some element will not be calculated into the final +// result. +template +inline T GroupReduceSumSGSizeEqualstoNumSG(item_t& item, T val, T* shared) { + auto thread_idx = item.get_local_id(0); + auto group_size = item.get_local_range(0); + auto sg = item.get_sub_group(); + auto sg_size = sg.get_local_range()[0]; + int lid = thread_idx % sg_size; + int wid = thread_idx / sg_size; + val = GroupReduceSumSGSizeEqualstoNumSG(sg, val); + item.barrier(sycl_local_fence); + if (lid == 0) { + shared[wid] = val; + } + item.barrier(sycl_local_fence); + val = (thread_idx < group_size / sg_size) ? shared[lid] : T(0); + if (wid == 0) { + val = GroupReduceSumSGSizeEqualstoNumSG(sg, val); + } + return val; +} diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 4282d4b2f..aea77c9b8 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -174,6 +174,7 @@ "bincount", "cross", "renorm", + "multinomial", "lerp", "conj_physical", "copysign", diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index d0a5dcd2c..9fbcf0ba7 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -499,6 +499,8 @@ supported: - erfc - erfc_ - erfc.out + - multinomial + - multinomial.out - erf - erf_ - erf.out