Refactor bernoulli kernel to align with PyTorch semantics #610

Merged: 6 commits, Jul 22, 2024
114 changes: 87 additions & 27 deletions src/ATen/native/xpu/sycl/DistributionTemplates.h
@@ -8,6 +8,7 @@
 #include <ATen/native/xpu/sycl/MemoryAccess.h>
 #include <ATen/native/xpu/sycl/OffsetCalculator.h>
 #include <ATen/native/xpu/sycl/Philox4x32.h>
+#include <ATen/native/xpu/sycl/TensorApplyUtils.h>
 #include <ATen/ops/empty.h>
 #include <comm/DeviceProperties.h>
 #include <comm/Runtime.h>
@@ -516,51 +517,110 @@ void uniform_kernel(

 // ====================== Bernoulli ======================

-template <typename scalar_t, typename accscalar_t>
-struct BernoulliFunctor {
-  scalar_t operator()(scalar_t out, accscalar_t p) const {
-    return static_cast<scalar_t>((accscalar_t)out < p);
+template <typename scalar_t, typename prob_t>
+struct BernoulliTensorApplyFunctor {
+  void operator()(
+      sycl::nd_item<1> item,
+      int n,
+      scalar_t& v1,
+      scalar_t& v2,
+      scalar_t& v3,
+      scalar_t& v4,
+      const prob_t& p1,
+      const prob_t& p2,
+      const prob_t& p3,
+      const prob_t& p4) const {
+    auto seeds = philox_unpack(philox_args_);
+    randStatePhilox4_32_10_t state;
+    rand_init(
+        std::get<0>(seeds),
+        item.get_group(0) * item.get_local_range(0) + item.get_local_id(0),
+        std::get<1>(seeds),
+        &state);
+    auto rand = rand_uniform4(&state);
+    switch (n) {
+      case 4: {
+        SYCL_KERNEL_ASSERT(0 <= p4 && p4 <= 1);
+        v4 = static_cast<scalar_t>(rand.w <= p4);
+        [[fallthrough]];
+      }
+      case 3: {
+        SYCL_KERNEL_ASSERT(0 <= p3 && p3 <= 1);
+        v3 = static_cast<scalar_t>(rand.z <= p3);
+        [[fallthrough]];
+      }
+      case 2: {
+        SYCL_KERNEL_ASSERT(0 <= p2 && p2 <= 1);
+        v2 = static_cast<scalar_t>(rand.y <= p2);
+        [[fallthrough]];
+      }
+      case 1: {
+        SYCL_KERNEL_ASSERT(0 <= p1 && p1 <= 1);
+        v1 = static_cast<scalar_t>(rand.x <= p1);
+      }
+    }
   }
+  BernoulliTensorApplyFunctor(std::pair<uint64_t, uint64_t> rng_engine_inputs)
+      : philox_args_(
+            std::get<0>(rng_engine_inputs),
+            std::get<1>(rng_engine_inputs)) {}
+
+ private:
+  PhiloxState philox_args_;
 };

+template <typename scalar_t, typename prob_t>
+void bernoulli_tensor_kernel(
+    TensorBase& ret,
+    TensorBase& p,
+    std::pair<uint64_t, uint64_t> rng_engine_inputs) {
+  auto functor =
+      BernoulliTensorApplyFunctor<scalar_t, prob_t>(rng_engine_inputs);
+  // The template argument `4` below indicates that we operate on four
+  // elements at a time.
+  at::native::xpu::tensor_apply2<
+      scalar_t,
+      prob_t,
+      4,
+      decltype(functor),
+      /*threads_per_group=*/512>(ret, p, functor);
+}
+
 template <typename RNG>
 void bernoulli_kernel(const TensorBase& self, const TensorBase& p_, RNG gen) {
+  std::pair<uint64_t, uint64_t> rng_engine_inputs;
+  {
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard<std::mutex> lock(gen->mutex_);
+    rng_engine_inputs = gen->philox_engine_inputs(10);
+  }
   TORCH_CHECK(
       at::isFloatingType(p_.scalar_type()),
       "expected probabilities tensor to have floating type, got ",
       p_.scalar_type());
   // cast probabilities tensor to double for double `self` tensor, and to
   // `float` for everything else
   const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat;
   auto p_xpu = p_.to(TensorOptions().device(self.device()).dtype(p_type));
   auto p = expand_inplace(self, p_xpu);

-  Tensor self_float;
-  auto self_type = self.scalar_type();
-  if (!(self_type == at::ScalarType::Float ||
-        self_type == at::ScalarType::Double))
-    self_float = at::empty(self.sizes(), self.options().dtype(at::kFloat));
-  else
-    self_float = self;
-
-  auto iter_uniform = at::TensorIterator::borrowing_nullary_op(self_float);
-  uniform_kernel<RNG>(iter_uniform, 0.0, 1.0, gen);
-
-  auto iter = TensorIteratorConfig()
-                  .add_output(self)
-                  .add_input(self_float)
-                  .add_input(*p)
-                  .check_all_same_dtype(false)
-                  .build();
-
   AT_DISPATCH_ALL_TYPES_AND3(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
       at::ScalarType::Bool,
       self.scalar_type(),
-      "bernoulli_xpu",
+      "bernoulli_tensor_xpu_self_",
       [&] {
-        using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
-        auto f = BernoulliFunctor<scalar_t, accscalar_t>();
-        gpu_kernel(iter, f);
+        if (std::is_same<scalar_t, double>::value) {
+          return bernoulli_tensor_kernel<double, double>(
+              const_cast<TensorBase&>(self),
+              const_cast<TensorBase&>(*p),
+              rng_engine_inputs);
+        } else {
+          return bernoulli_tensor_kernel<scalar_t, float>(
+              const_cast<TensorBase&>(self),
+              const_cast<TensorBase&>(*p),
+              rng_engine_inputs);
+        }
       });
 }

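For readers skimming the diff: below is a minimal, self-contained host-side sketch of the element-wise rule the new `BernoulliTensorApplyFunctor` applies. Each output element becomes `1` exactly when a uniform draw `u` in [0, 1) satisfies `u <= p`, and the probability is asserted to lie in [0, 1], mirroring the `SYCL_KERNEL_ASSERT` checks above (the real kernel draws four uniforms at once via `rand_uniform4` and writes up to four outputs per call). The sketch is not part of the PR; `bernoulli_element` and `draw_uniform` are illustrative names, and `std::mt19937` stands in for the Philox engine.

```cpp
#include <cassert>
#include <cstdio>
#include <random>

// Element-wise Bernoulli rule: 1 with probability p, 0 otherwise.
template <typename scalar_t, typename prob_t>
scalar_t bernoulli_element(prob_t p, prob_t u) {
  assert(p >= prob_t(0) && p <= prob_t(1)); // same bound check as SYCL_KERNEL_ASSERT
  return static_cast<scalar_t>(u <= p);
}

int main() {
  std::mt19937 rng(42);
  std::uniform_real_distribution<float> draw_uniform(0.0f, 1.0f);
  const float p = 0.3f;
  int ones = 0;
  for (int i = 0; i < 10000; ++i) {
    ones += bernoulli_element<int>(p, draw_uniform(rng));
  }
  std::printf("empirical rate: %.3f (expected ~0.300)\n", ones / 10000.0);
  return 0;
}
```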
79 changes: 79 additions & 0 deletions src/ATen/native/xpu/sycl/IndexUtils.h
@@ -0,0 +1,79 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <vector>
+
+namespace at {
+namespace xpu {
+namespace detail {
+
+struct SizeAndStride {
+  int64_t size;
+  int64_t stride;
+};
+
+/*
+A comparator that will sort SizeAndStride structs by stride,
+in ascending order.
+*/
+inline int compareSizeAndStride(const void* a, const void* b) {
+  const SizeAndStride* aS = (const SizeAndStride*)a;
+  const SizeAndStride* bS = (const SizeAndStride*)b;
+
+  if (aS->stride < bS->stride)
+    return -1;
+  if (aS->stride == bS->stride)
+    return 0;
+  return 1;
+}
+
+/*
+Returns false if there is no possibility that the tensor
+has "overlapping" indices and true otherwise.
+"Overlapping" indices are two+ valid indices that specify
+the same offset within the tensor.
+The function does this by checking for a sufficient but not
+necessary condition of no overlap. In particular, that
+there exists an ordering of the tensor's dimensions
+that is nicely "nested," with each dimension contained
+within the next one.
+*/
+inline bool maybeOverlappingIndices(const TensorBase& t) {
+  /* Extract size/stride arrays; only consider size >1 dims. */
+  std::vector<SizeAndStride> info(t.dim());
+  int dims = t.dim();
+  int nonSize1Dims = 0;
+  for (int i = 0; i < dims; ++i) {
+    int64_t size = t.size(i);
+    if (size > 1) {
+      info[nonSize1Dims].size = size;
+      info[nonSize1Dims].stride = t.stride(i);
+
+      if (info[nonSize1Dims].stride < 1) {
+        return true;
+      }
+
+      ++nonSize1Dims;
+    }
+  }
+
+  // Short-circuits if tensor is a single element.
+  if (nonSize1Dims == 0) {
+    return false;
+  }
+
+  /* Ascending order (innermost dimension in sorted view is at [0]) */
+  qsort(info.data(), nonSize1Dims, sizeof(SizeAndStride), compareSizeAndStride);
+
+  for (int i = 0; i < (nonSize1Dims - 1); ++i) {
+    if (((info[i].size - 1) * info[i].stride) >= info[i + 1].stride) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+} // namespace detail
+} // namespace xpu
+} // namespace at
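For illustration, the "nested dimensions" test in `maybeOverlappingIndices` can be exercised on raw (size, stride) pairs without ATen. This is a standalone sketch, not part of the PR; `may_overlap` is an illustrative name, whereas the real helper takes a `TensorBase` and reads its sizes and strides.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Same sufficient-condition check as maybeOverlappingIndices, on {size, stride} pairs.
static bool may_overlap(std::vector<std::pair<int64_t, int64_t>> dims) {
  std::vector<std::pair<int64_t, int64_t>> info; // keep only size > 1 dims
  for (const auto& d : dims) {
    if (d.first > 1) {
      if (d.second < 1) return true; // zero/negative stride: indices can collide
      info.push_back(d);
    }
  }
  if (info.empty()) return false; // at most one addressable element
  // Sort by stride ascending; each dimension's extent must stay strictly
  // below the next larger stride for the layout to be "nested".
  std::sort(info.begin(), info.end(),
            [](const auto& a, const auto& b) { return a.second < b.second; });
  for (size_t i = 0; i + 1 < info.size(); ++i) {
    if ((info[i].first - 1) * info[i].second >= info[i + 1].second) return true;
  }
  return false;
}

int main() {
  assert(!may_overlap({{2, 3}, {3, 1}})); // contiguous 2x3: no overlap possible
  assert(may_overlap({{2, 0}, {3, 1}}));  // expanded dim (stride 0): offsets repeat
  assert(may_overlap({{2, 2}, {3, 1}}));  // strides {2, 1}: rows share offset 2
  return 0;
}
```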