From c9a94baf6d82bdf8d6179fc8a1f45cb5a0f9f836 Mon Sep 17 00:00:00 2001
From: hjhee
Date: Sat, 19 Oct 2024 00:43:43 +0800
Subject: [PATCH] Add distribution operators (#826)

- cauchy_
- geometric
- binomial
- binomial.out
- _dirichlet_grad
- _dirichlet_grad.out
- _sample_dirichlet
- _sample_dirichlet.out

---------

Co-authored-by: Yutao Xu
---
 src/ATen/native/xpu/Distributions.cpp         |  60 +++
 src/ATen/native/xpu/XPUFallback.template      |   2 -
 .../xpu/sycl/DistributionCauchyKernel.cpp     |  17 +
 .../xpu/sycl/DistributionGeometricKernel.cpp  |  16 +
 .../native/xpu/sycl/DistributionKernels.h     |  11 +
 .../native/xpu/sycl/DistributionTemplates.h   | 303 ++++++++++++++-
 src/ATen/native/xpu/sycl/Distributions.cpp    | 242 ++++++++++++
 src/ATen/native/xpu/sycl/Distributions.h      |  26 ++
 src/ATen/native/xpu/sycl/Philox4x32.h         | 153 ++++++++
 test/xpu/test_distributions_xpu.py            | 350 ++++++++++++++++++
 test/xpu/xpu_test_utils.py                    |   3 +
 yaml/native/native_functions.yaml             |  44 +++
 12 files changed, 1223 insertions(+), 4 deletions(-)
 create mode 100644 src/ATen/native/xpu/sycl/DistributionCauchyKernel.cpp
 create mode 100644 src/ATen/native/xpu/sycl/DistributionGeometricKernel.cpp
 create mode 100644 src/ATen/native/xpu/sycl/Distributions.cpp
 create mode 100644 src/ATen/native/xpu/sycl/Distributions.h
 create mode 100644 test/xpu/test_distributions_xpu.py

diff --git a/src/ATen/native/xpu/Distributions.cpp b/src/ATen/native/xpu/Distributions.cpp
index 264a368f8..d7f602d6e 100644
--- a/src/ATen/native/xpu/Distributions.cpp
+++ b/src/ATen/native/xpu/Distributions.cpp
@@ -9,6 +9,7 @@
 #include <...>
 #include <...>
+#include <ATen/native/xpu/sycl/Distributions.h>
 #include <...>
 #include <...>
 #include <...>
@@ -29,5 +30,64 @@ REGISTER_XPU_DISPATCH(
     multinomial_with_replacement_stub,
     &xpu::multinomial_kernel);
 REGISTER_XPU_DISPATCH(log_normal_stub, &xpu::log_normal_kernel);
+REGISTER_XPU_DISPATCH(cauchy_stub, &xpu::cauchy_kernel);
+REGISTER_XPU_DISPATCH(geometric_stub, &xpu::geometric_kernel);
+
+Tensor _s_poisson_xpu(const Tensor& lambda, std::optional<Generator> gen_) {
+  auto gen = get_generator_or_default<at::XPUGeneratorImpl>(
+      gen_, at::xpu::detail::getDefaultXPUGenerator());
+  Tensor ret = at::empty(lambda.sizes(), lambda.options());
+  xpu::launch_poisson_kernel(ret, lambda, gen);
+  return ret;
+}
+
+Tensor _s_binomial_xpu(
+    const Tensor& count,
+    const Tensor& prob,
+    std::optional<Generator> generator) {
+  auto gen = get_generator_or_default<at::XPUGeneratorImpl>(
+      generator, at::xpu::detail::getDefaultXPUGenerator());
+  Tensor ret = at::empty(count.sizes(), count.options());
+  at::TensorIterator iter = at::TensorIteratorConfig()
+                                .add_output(ret)
+                                .add_input(count)
+                                .add_input(prob)
+                                .build();
+  xpu::launch_binomial_kernel(iter, gen);
+  return ret;
+}
+
+Tensor _sample_dirichlet_xpu(
+    const Tensor& alpha,
+    std::optional<Generator> generator) {
+  auto gen = get_generator_or_default<at::XPUGeneratorImpl>(
+      generator, at::xpu::detail::getDefaultXPUGenerator());
+  Tensor ret = at::empty(alpha.sizes(), alpha.options());
+  xpu::launch_gamma_kernel(ret, alpha, gen);
+  auto gamma_sum = ret.sum(/*dim=*/-1, /*keepdim=*/true);
+  auto iter = at::TensorIteratorConfig()
+                  .add_output(ret)
+                  .add_input(ret)
+                  .add_input(gamma_sum)
+                  .build();
+  xpu::launch_dirichlet_kernel(iter);
+  return ret;
+}
+
+Tensor _dirichlet_grad_xpu(
+    const Tensor& x,
+    const Tensor& alpha,
+    const Tensor& total) {
+  Tensor ret = at::empty(x.sizes(), x.options());
+  auto iter = at::TensorIteratorConfig()
+                  .add_output(ret)
+                  .add_input(x)
+                  .add_input(alpha)
+                  .add_input(total)
+                  .build();
+  xpu::launch_dirichlet_grad_kernel(iter);
+  return ret;
+}
+
 } // namespace native
 } // namespace at
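Note: _sample_dirichlet_xpu above relies on the standard identity that independent Gamma(alpha_i, 1) draws, normalized by their sum, form a Dirichlet(alpha) sample. A minimal host-side sketch of that identity using only the C++ standard library (illustrative only, not part of the patch):

#include <iostream>
#include <random>

int main() {
  std::mt19937 rng(42);
  const double alpha[3] = {0.5, 1.0, 2.0};
  double sample[3];
  double sum = 0.0;
  for (int i = 0; i < 3; ++i) {
    // Gamma(shape = alpha_i, scale = 1) draw, as in launch_gamma_kernel.
    std::gamma_distribution<double> gamma(alpha[i], 1.0);
    sample[i] = gamma(rng);
    sum += sample[i];
  }
  for (int i = 0; i < 3; ++i) {
    sample[i] /= sum; // normalized Gamma draws are one Dirichlet(alpha) draw
    std::cout << sample[i] << ' '; // components lie in (0, 1) and sum to 1
  }
  std::cout << '\n';
}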
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 2195e5f08..7038b1675 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -159,7 +159,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "adaptive_max_pool3d.out",
       "avg_pool3d_backward.grad_input",
       "avg_pool3d.out",
-      "cauchy_",
       "_cdist_backward",
       "cholesky",
       "cholesky_inverse",
@@ -180,7 +179,6 @@
       "fractional_max_pool3d.output",
       "frexp.Tensor_out",
       "_fused_moving_avg_obs_fq_helper",
-      "geometric_",
       "geqrf",
       "heaviside.out",
       "histc",
diff --git a/src/ATen/native/xpu/sycl/DistributionCauchyKernel.cpp b/src/ATen/native/xpu/sycl/DistributionCauchyKernel.cpp
new file mode 100644
index 000000000..873510602
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/DistributionCauchyKernel.cpp
@@ -0,0 +1,17 @@
+#include <...>
+#include <...>
+#include <...>
+
+namespace at::native::xpu {
+
+void cauchy_kernel(
+    TensorIteratorBase& iter,
+    double median,
+    double sigma,
+    c10::optional<Generator> gen) {
+  auto generator = get_generator_or_default<XPUGeneratorImpl>(
+      gen, at::xpu::detail::getDefaultXPUGenerator());
+  at::native::templates::xpu::cauchy_kernel(iter, median, sigma, generator);
+}
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/DistributionGeometricKernel.cpp b/src/ATen/native/xpu/sycl/DistributionGeometricKernel.cpp
new file mode 100644
index 000000000..f42ff972d
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/DistributionGeometricKernel.cpp
@@ -0,0 +1,16 @@
+#include <...>
+#include <...>
+#include <...>
+
+namespace at::native::xpu {
+
+void geometric_kernel(
+    TensorIteratorBase& iter,
+    double p_,
+    c10::optional<Generator> gen) {
+  auto generator = get_generator_or_default<XPUGeneratorImpl>(
+      gen, at::xpu::detail::getDefaultXPUGenerator());
+  at::native::templates::xpu::geometric_kernel(iter, p_, generator);
+}
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/DistributionKernels.h b/src/ATen/native/xpu/sycl/DistributionKernels.h
index 494033dbb..2f7133a7a 100644
--- a/src/ATen/native/xpu/sycl/DistributionKernels.h
+++ b/src/ATen/native/xpu/sycl/DistributionKernels.h
@@ -51,4 +51,15 @@ TORCH_XPU_API void log_normal_kernel(
     double std,
     std::optional<Generator> gen);
 
+TORCH_XPU_API void cauchy_kernel(
+    TensorIteratorBase& iter,
+    double median,
+    double sigma,
+    c10::optional<Generator> gen);
+
+TORCH_XPU_API void geometric_kernel(
+    TensorIteratorBase& iter,
+    double p_,
+    c10::optional<Generator> gen);
+
 } // namespace at::native::xpu
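Note: the entry points above, and the templates in the next hunk, follow the counter-based RNG pattern: each work-item derives an independent random stream from (seed, subsequence index, offset) instead of sharing mutable generator state, so results are reproducible regardless of launch geometry. A toy host-side illustration of the idea using a splitmix64-style mixer (an assumption for illustration; the patch itself uses Philox4x32-10):

#include <cstdint>
#include <iostream>

// splitmix64 finalizer: a stateless bit mixer.
uint64_t mix(uint64_t z) {
  z += 0x9e3779b97f4a7c15ULL;
  z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ULL;
  z = (z ^ (z >> 27)) * 0x94d049bb133111ebULL;
  return z ^ (z >> 31);
}

// Value i of the stream owned by work-item `idx` under `seed`:
// purely a function of the counter triple, no shared state.
uint64_t stream_value(uint64_t seed, uint64_t idx, uint64_t i) {
  return mix(seed ^ mix(idx ^ mix(i)));
}

int main() {
  // Two "work-items" produce decorrelated, deterministic streams.
  for (uint64_t i = 0; i < 3; ++i)
    std::cout << stream_value(42, 0, i) << ' ' << stream_value(42, 1, i)
              << '\n';
}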
diff --git a/src/ATen/native/xpu/sycl/DistributionTemplates.h b/src/ATen/native/xpu/sycl/DistributionTemplates.h
index f696e50d1..48b7659c2 100644
--- a/src/ATen/native/xpu/sycl/DistributionTemplates.h
+++ b/src/ATen/native/xpu/sycl/DistributionTemplates.h
@@ -221,6 +221,236 @@
   }
 }
 
+// Unary kernel
+template <
+    typename scalar1_t,
+    typename scalar2_t,
+    typename func_t,
+    typename inp_offset_calc_t,
+    typename out_offset_calc_t>
+struct DistributionUnaryElementwiseKernelFunctor {
+  void operator()(sycl::nd_item<1> item) const {
+    int group_size = item.get_local_range(0);
+    int global_size = item.get_global_range(0);
+    int global_idx = item.get_group(0) * group_size + item.get_local_id(0);
+
+    auto seeds = philox_unpack(philox_args_);
+    randStatePhilox4_32_10_t state;
+    rand_init(std::get<0>(seeds), global_idx, std::get<1>(seeds), &state);
+
+    for (int i = global_idx; i < numel_; i += global_size) {
+      auto in_offsets = inp_calc_.get(i);
+      auto out_offsets = out_calc_.get(i);
+      f_(state, output_data_[out_offsets[0]], input_data_[in_offsets[0]]);
+    }
+  }
+  DistributionUnaryElementwiseKernelFunctor(
+      int numel,
+      const func_t f,
+      PhiloxState philox_args,
+      scalar1_t* output_data,
+      const scalar2_t* input_data,
+      inp_offset_calc_t input_offset_calculator,
+      out_offset_calc_t output_offset_calculator)
+      : numel_(numel),
+        f_(f),
+        philox_args_(philox_args),
+        output_data_(output_data),
+        input_data_(input_data),
+        inp_calc_(input_offset_calculator),
+        out_calc_(output_offset_calculator) {}
+
+ private:
+  int numel_;
+  const func_t f_;
+  PhiloxState philox_args_;
+  scalar1_t* output_data_;
+  const scalar2_t* input_data_;
+  inp_offset_calc_t inp_calc_;
+  out_offset_calc_t out_calc_;
+};
+
+template <typename scalar1_t, typename scalar2_t, typename func_t>
+void distribution_unary_kernel(
+    TensorIterator& iter,
+    PhiloxState philox_args,
+    func_t f) {
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      distribution_unary_kernel<scalar1_t, scalar2_t, func_t>(
+          sub_iter, philox_args, f);
+    }
+    return;
+  }
+
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing());
+
+  int64_t numel = iter.numel();
+  if (numel == 0) {
+    return;
+  }
+
+  auto execution_policy = calc_execution_policy(numel);
+  auto num_groups = std::get<1>(execution_policy);
+  auto group_size = std::get<2>(execution_policy);
+
+  scalar1_t* output_data = static_cast<scalar1_t*>(iter.data_ptr(0));
+  const scalar2_t* input_data =
+      static_cast<const scalar2_t*>(iter.data_ptr(1));
+
+  if (iter.is_contiguous()) {
+    auto input_offset_calculator = TrivialOffsetCalculator<1>();
+    auto output_offset_calculator = TrivialOffsetCalculator<1>();
+    auto caller = DistributionUnaryElementwiseKernelFunctor<
+        scalar1_t,
+        scalar2_t,
+        func_t,
+        decltype(input_offset_calculator),
+        decltype(output_offset_calculator)>(
+        numel,
+        f,
+        philox_args,
+        output_data,
+        input_data,
+        input_offset_calculator,
+        output_offset_calculator);
+    sycl_kernel_submit(
+        num_groups * group_size, group_size, getCurrentSYCLQueue(), caller);
+  } else {
+    auto input_offset_calculator = make_input_offset_calculator<1>(iter);
+    auto output_offset_calculator = make_output_offset_calculator(iter);
+    auto caller = DistributionUnaryElementwiseKernelFunctor<
+        scalar1_t,
+        scalar2_t,
+        func_t,
+        decltype(input_offset_calculator),
+        decltype(output_offset_calculator)>(
+        numel,
+        f,
+        philox_args,
+        output_data,
+        input_data,
+        input_offset_calculator,
+        output_offset_calculator);
+    sycl_kernel_submit(
+        num_groups * group_size, group_size, getCurrentSYCLQueue(), caller);
+  }
+}
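Note: distribution_unary_kernel above (and the binary kernel that follows) first splits any iterator that needs 64-bit indexing into 32-bit-indexable sub-iterators, and the functor then covers numel elements with a grid-stride loop, so a single fixed-size launch handles any tensor. Host-side sketch of the same striding, with hypothetical sizes:

#include <cstdio>
#include <vector>

// Collect the elements a single "work-item" would visit.
std::vector<int> grid_stride(int numel, int global_size, int global_idx) {
  std::vector<int> touched;
  for (int i = global_idx; i < numel; i += global_size)
    touched.push_back(i); // this item handles i, i + G, i + 2G, ...
  return touched;
}

int main() {
  for (int idx : grid_stride(/*numel=*/10, /*global_size=*/4, /*global_idx=*/1))
    std::printf("%d ", idx); // prints: 1 5 9
  std::printf("\n");
}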
+
+// Binary kernel
+template <
+    typename func_t,
+    typename inp_offset_calc_t,
+    typename out_offset_calc_t>
+struct DistributionBinaryElementwiseKernelFunctor {
+  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
+  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;
+  using output_t = typename function_traits<func_t>::result_type;
+  void operator()(sycl::nd_item<1> item) const {
+    int group_size = item.get_local_range(0);
+    int global_size = item.get_global_range(0);
+    int global_idx = item.get_group(0) * group_size + item.get_local_id(0);
+
+    auto seeds = philox_unpack(philox_args_);
+
+    randStatePhilox4_32_10_t state;
+    rand_init(std::get<0>(seeds), global_idx, std::get<1>(seeds), &state);
+
+    for (int i = global_idx; i < numel_; i += global_size) {
+      auto in_offsets = inp_calc_.get(i);
+      auto out_offsets = out_calc_.get(i);
+      out_data_[out_offsets[0]] =
+          f_(state, inp_data_1_[in_offsets[0]], inp_data_2_[in_offsets[1]]);
+    }
+  }
+
+  DistributionBinaryElementwiseKernelFunctor(
+      int numel,
+      func_t f,
+      PhiloxState philox_args,
+      output_t* output_data,
+      const input_t_1* input_data_1,
+      const input_t_2* input_data_2,
+      inp_offset_calc_t inp_calc,
+      out_offset_calc_t out_calc)
+      : numel_(numel),
+        f_(f),
+        philox_args_(philox_args),
+        out_data_(output_data),
+        inp_data_1_(input_data_1),
+        inp_data_2_(input_data_2),
+        inp_calc_(inp_calc),
+        out_calc_(out_calc) {}
+
+ private:
+  int64_t numel_;
+  func_t f_;
+  PhiloxState philox_args_;
+  output_t* out_data_;
+  const input_t_1* inp_data_1_;
+  const input_t_2* inp_data_2_;
+  inp_offset_calc_t inp_calc_;
+  out_offset_calc_t out_calc_;
+};
+
+template <typename func_t>
+void distribution_binary_kernel(
+    TensorIteratorBase& iter,
+    PhiloxState philox_args,
+    const func_t& f) {
+  static_assert(
+      std::is_same<
+          typename function_traits<func_t>::template arg<0>::type,
+          randStatePhilox4_32_10_t&>::value,
+      "the first argument of functor must be randStatePhilox4_32_10_t");
+  using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
+  using input_t_2 = typename function_traits<func_t>::template arg<2>::type;
+  using output_t = typename function_traits<func_t>::result_type;
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      distribution_binary_kernel(sub_iter, philox_args, f);
+    }
+    return;
+  }
+
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing());
+
+  int64_t numel = iter.numel();
+  if (numel == 0) {
+    return;
+  }
+
+  auto execution_policy = calc_execution_policy(numel);
+  auto num_groups = std::get<1>(execution_policy);
+  auto group_size = std::get<2>(execution_policy);
+
+  output_t* output_data = static_cast<output_t*>(iter.data_ptr(0));
+  const input_t_1* input_data_1 =
+      static_cast<const input_t_1*>(iter.data_ptr(1));
+  const input_t_2* input_data_2 =
+      static_cast<const input_t_2*>(iter.data_ptr(2));
+
+  if (iter.is_contiguous()) {
+    auto input_offset_calculator = TrivialOffsetCalculator<2>();
+    auto output_offset_calculator = TrivialOffsetCalculator<1>();
+    auto caller = DistributionBinaryElementwiseKernelFunctor<
+        func_t,
+        decltype(input_offset_calculator),
+        decltype(output_offset_calculator)>(
+        numel,
+        f,
+        philox_args,
+        output_data,
+        input_data_1,
+        input_data_2,
+        input_offset_calculator,
+        output_offset_calculator);
+    sycl_kernel_submit(
+        num_groups * group_size, group_size, getCurrentSYCLQueue(), caller);
+  } else {
+    auto input_offset_calculator = make_input_offset_calculator<2>(iter);
+    auto output_offset_calculator = make_output_offset_calculator(iter);
+    auto caller = DistributionBinaryElementwiseKernelFunctor<
+        func_t,
+        decltype(input_offset_calculator),
+        decltype(output_offset_calculator)>(
+        numel,
+        f,
+        philox_args,
+        output_data,
+        input_data_1,
+        input_data_2,
+        input_offset_calculator,
+        output_offset_calculator);
+    sycl_kernel_submit(
+        num_groups * group_size, group_size, getCurrentSYCLQueue(), caller);
+  }
+}
+
 } // namespace xpu
 } // namespace native
 } // namespace at
@@ -688,7 +918,7 @@ void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG gen) {
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
       iter.dtype(),
-      "exponential__xpu_",
+      "exponential_xpu",
       [&] {
         using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
         auto lambd = static_cast<accscalar_t>(lambda);
@@ -724,7 +954,7 @@ void log_normal_kernel(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
       iter.dtype(),
-      "log_normal_xpu_",
+      "log_normal_xpu",
       [&] {
         using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
         auto mean_ = static_cast<accscalar_t>(mean);
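Note: the Cauchy section added in the next hunk maps uniform draws through the inverse CDF, u -> median + sigma * tan(pi * (u - 1/2)), which is what transformation::cauchy computes in the shared transformation helpers. Standalone host sketch of the transform (illustrative only):

#include <cmath>
#include <cstdio>
#include <random>

double cauchy_from_uniform(double u, double median, double sigma) {
  const double pi = std::acos(-1.0);
  return median + sigma * std::tan(pi * (u - 0.5)); // inverse Cauchy CDF
}

int main() {
  std::mt19937 rng(0);
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  for (int i = 0; i < 5; ++i)
    std::printf("%f\n", cauchy_from_uniform(uniform(rng), 0.0, 1.0));
}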
@@ -736,6 +966,75 @@
       });
 }
 
+// ====================== Cauchy ======================
+
+template <typename scalar_t, typename accscalar_t>
+struct CauchyFunctor {
+  scalar_t operator()(accscalar_t rand) const {
+    return static_cast<scalar_t>(
+        transformation::cauchy<accscalar_t>(rand, median_, sigma_));
+  }
+
+  CauchyFunctor(accscalar_t median, accscalar_t sigma)
+      : median_(median), sigma_(sigma) {}
+
+ private:
+  accscalar_t median_;
+  accscalar_t sigma_;
+};
+
+template <typename RNG>
+void cauchy_kernel(
+    TensorIteratorBase& iter,
+    double median,
+    double sigma,
+    RNG gen) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.dtype(),
+      "cauchy_xpu",
+      [&] {
+        using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
+        auto median_ = static_cast<accscalar_t>(median);
+        auto sigma_ = static_cast<accscalar_t>(sigma);
+        CauchyFunctor<scalar_t, accscalar_t> cauchy_func(median_, sigma_);
+        uniform_and_transform<scalar_t, accscalar_t, rand4_engine_calls>(
+            iter, gen, cauchy_func);
+      });
+}
+
+// ====================== Geometric ======================
+
+template <typename scalar_t, typename accscalar_t>
+struct GeometricFunctor {
+  scalar_t operator()(accscalar_t rand) const {
+    return static_cast<scalar_t>(
+        transformation::geometric<accscalar_t>(rand, p_));
+  }
+
+  GeometricFunctor(accscalar_t p) : p_(p) {}
+
+ private:
+  accscalar_t p_;
+};
+
+template <typename RNG>
+void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) {
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.dtype(),
+      "geometric_xpu",
+      [&] {
+        using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
+        auto p_ = static_cast<accscalar_t>(p);
+        GeometricFunctor<scalar_t, accscalar_t> geometric_func(p_);
+        uniform_and_transform<scalar_t, accscalar_t, rand4_engine_calls>(
+            iter, gen, geometric_func);
+      });
+}
+
 } // namespace xpu
 } // namespace templates
 } // namespace native
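Note: transformation::geometric used above follows the usual inverse-CDF transform for a geometric variable on {1, 2, ...}: k = ceil(log(u) / log1p(-p)) for uniform u in (0, 1]. A host sketch checking that the sample mean comes out near 1/p (illustrative only; the exact helper signature lives in the shared transformation header):

#include <cmath>
#include <cstdio>
#include <limits>
#include <random>

// k = ceil(log(u) / log1p(-p)) is one geometric draw on {1, 2, ...}.
double geometric_from_uniform(double u, double p) {
  return std::ceil(std::log(u) / std::log1p(-p));
}

int main() {
  std::mt19937 rng(7);
  std::uniform_real_distribution<double> uniform(
      std::numeric_limits<double>::min(), 1.0); // avoid log(0)
  const double p = 0.3;
  const int n = 100000;
  double sum = 0.0;
  for (int i = 0; i < n; ++i)
    sum += geometric_from_uniform(uniform(rng), p);
  std::printf("sample mean %.3f vs expected 1/p = %.3f\n", sum / n, 1.0 / p);
}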
diff --git a/src/ATen/native/xpu/sycl/Distributions.cpp b/src/ATen/native/xpu/sycl/Distributions.cpp
new file mode 100644
index 000000000..ab3f10243
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/Distributions.cpp
@@ -0,0 +1,242 @@
+#include <...>
+#include <...>
+#include <...>
+#include <...>
+#include <...>
+#include <...>
+
+namespace at::native::xpu {
+
+template <typename scalar_t>
+struct PoissonTensorApplyFunctor {
+  void operator()(
+      sycl::nd_item<1> item,
+      scalar_t& ret_val,
+      const scalar_t& lambda) const {
+    SYCL_KERNEL_ASSERT(
+        lambda >= 0 &&
+        "invalid Poisson rate, expected rate to be non-negative");
+    auto seeds = philox_unpack(philox_args_);
+    randStatePhilox4_32_10_t state;
+    rand_init(
+        std::get<0>(seeds),
+        item.get_group(0) * item.get_local_range(0) + item.get_local_id(0),
+        std::get<1>(seeds),
+        &state);
+    ret_val = static_cast<scalar_t>(rand_poisson(&state, lambda));
+  }
+  PoissonTensorApplyFunctor(std::pair<uint64_t, uint64_t> rng_engine_inputs)
+      : philox_args_(
+            std::get<0>(rng_engine_inputs),
+            std::get<1>(rng_engine_inputs)) {}
+
+ private:
+  PhiloxState philox_args_;
+};
+
+template <typename scalar_t>
+void poisson_kernel(
+    const at::TensorBase& ret,
+    const at::TensorBase& lambda,
+    std::pair<uint64_t, uint64_t> rng_engine_inputs) {
+  auto functor = PoissonTensorApplyFunctor<scalar_t>(rng_engine_inputs);
+  at::native::xpu::tensor_apply2<
+      scalar_t,
+      scalar_t,
+      decltype(functor),
+      /*max_threads_per_block=*/512>(
+      const_cast<at::TensorBase&>(ret),
+      const_cast<at::TensorBase&>(lambda),
+      functor);
+}
+
+void launch_poisson_kernel(
+    const TensorBase& ret,
+    const TensorBase& lambda,
+    at::XPUGeneratorImpl* gen) {
+  std::pair<uint64_t, uint64_t> rng_engine_inputs;
+  {
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard<std::mutex> lock(gen->mutex_);
+    rng_engine_inputs = gen->philox_engine_inputs(20);
+  }
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      ret.scalar_type(),
+      "poisson_xpu",
+      [&] { poisson_kernel<scalar_t>(ret, lambda, rng_engine_inputs); });
+}
+
+struct rand_uniform_wrapper {
+  rand_uniform_wrapper(randStatePhilox4_32_10_t& state) : state_(state) {}
+
+  float operator()() {
+    uint32_t val = rand(&state_); // need just bits
+    constexpr auto MASK = static_cast<uint32_t>(
+        (static_cast<uint64_t>(1) << std::numeric_limits<float>::digits) - 1);
+    constexpr auto DIVISOR = static_cast<float>(1) /
+        (static_cast<uint64_t>(1) << std::numeric_limits<float>::digits);
+    return (val & MASK) * DIVISOR;
+  }
+
+  randStatePhilox4_32_10_t& state_;
+};
+
+template <typename scalar_t, typename accscalar_t>
+struct BinomialFunctor {
+  scalar_t operator()(
+      randStatePhilox4_32_10_t& state,
+      scalar_t count,
+      scalar_t prob) const {
+    auto uniform_lambda = rand_uniform_wrapper(state);
+    BaseSampler<accscalar_t, decltype(uniform_lambda)> standard_uniform(
+        uniform_lambda);
+    auto sample =
+        sample_binomial<scalar_t, accscalar_t, decltype(uniform_lambda)>(
+            count, prob, standard_uniform);
+    return static_cast<scalar_t>(sample);
+  }
+};
+
+template <typename scalar_t>
+void binomial_kernel(TensorIteratorBase& iter, PhiloxState philox_args) {
+  using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
+  BinomialFunctor<scalar_t, accscalar_t> f;
+  at::native::xpu::distribution_binary_kernel(iter, philox_args, f);
+}
+
+void launch_binomial_kernel(TensorIteratorBase& iter, XPUGeneratorImpl* gen) {
+  std::pair<uint64_t, uint64_t> engine_inputs;
+  {
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard<std::mutex> lock(gen->mutex_);
+    engine_inputs = gen->philox_engine_inputs(42);
+  }
+  PhiloxState rng_engine_inputs(
+      std::get<0>(engine_inputs), std::get<1>(engine_inputs));
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.input_dtype(),
+      "binomial_xpu",
+      [&] { binomial_kernel<scalar_t>(iter, rng_engine_inputs); });
+}
+
+template <typename scalar_t, typename accscalar_t>
+struct GammaTensorApplyFunctor {
+  void operator()(
+      sycl::nd_item<1> item,
+      scalar_t& ret_val,
+      const scalar_t& alpha) const {
+    auto seeds = philox_unpack(philox_args_);
+    randStatePhilox4_32_10_t state;
+    rand_init(
+        std::get<0>(seeds),
+        item.get_group(0) * item.get_local_range(0) + item.get_local_id(0),
+        std::get<1>(seeds),
+        &state);
+
+    auto uniform_lambda = [&state]() { return rand_uniform(&state); };
+    BaseSampler<accscalar_t, decltype(uniform_lambda)> standard_uniform(
+        uniform_lambda);
+
+    auto normal_lambda = [&state]() { return rand_normal(&state); };
+    BaseSampler<accscalar_t, decltype(normal_lambda)> standard_normal(
+        normal_lambda);
+
+    auto sample = sample_gamma<
+        scalar_t,
+        accscalar_t,
+        decltype(uniform_lambda),
+        decltype(normal_lambda)>(alpha, standard_uniform, standard_normal);
+    auto min_value = std::numeric_limits<scalar_t>::min();
+    ret_val = (min_value > sample) ? min_value : sample;
+  }
+
+  GammaTensorApplyFunctor(PhiloxState philox_args)
+      : philox_args_(philox_args) {}
+
+ private:
+  PhiloxState philox_args_;
+};
+
+template <typename scalar_t>
+void gamma_kernel(
+    const at::TensorBase& ret,
+    const at::TensorBase& alpha,
+    PhiloxState philox_args) {
+  using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
+  GammaTensorApplyFunctor<scalar_t, accscalar_t> functor(philox_args);
+  at::native::xpu::tensor_apply2<
+      scalar_t,
+      scalar_t,
+      decltype(functor),
+      /*max_threads_per_block=*/512>(
+      const_cast<at::TensorBase&>(ret),
+      const_cast<at::TensorBase&>(alpha),
+      functor);
+}
+
+void launch_gamma_kernel(
+    Tensor& ret,
+    const Tensor& alpha,
+    XPUGeneratorImpl* gen) {
+  std::pair<uint64_t, uint64_t> engine_inputs;
+  {
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard<std::mutex> lock(gen->mutex_);
+    engine_inputs = gen->philox_engine_inputs(42);
+  }
+  PhiloxState rng_engine_inputs(
+      std::get<0>(engine_inputs), std::get<1>(engine_inputs));
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      ret.scalar_type(),
+      "gamma_xpu",
+      [&] { gamma_kernel<scalar_t>(ret, alpha, rng_engine_inputs); });
+}
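Note: sample_gamma above comes from the shared distribution helpers; in the CPU/CUDA code it mirrors, it implements the Marsaglia-Tsang squeeze method, which is why the functor supplies both a uniform and a normal sampler. Host sketch of the core accept/reject step for alpha >= 1 (the shared helper additionally boosts alpha < 1 with a uniform power factor; that branch is omitted here as an assumption):

#include <cmath>
#include <cstdio>
#include <random>

// Marsaglia-Tsang (2000): Gamma(alpha, 1) for alpha >= 1.
double marsaglia_tsang(std::mt19937& rng, double alpha) {
  std::normal_distribution<double> normal(0.0, 1.0);
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  const double d = alpha - 1.0 / 3.0;
  const double c = 1.0 / std::sqrt(9.0 * d);
  while (true) {
    double x = normal(rng);
    double v = 1.0 + c * x;
    if (v <= 0.0)
      continue; // cube would be non-positive; reject early
    v = v * v * v;
    double u = uniform(rng);
    if (std::log(u) < 0.5 * x * x + d - d * v + d * std::log(v))
      return d * v; // accepted Gamma(alpha, 1) draw
  }
}

int main() {
  std::mt19937 rng(1);
  double sum = 0.0;
  const int n = 100000;
  for (int i = 0; i < n; ++i)
    sum += marsaglia_tsang(rng, 4.0);
  std::printf("sample mean %.3f vs expected 4.0\n", sum / n);
}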
+
+template <typename scalar_t>
+struct DirichletKernelFunctor {
+  scalar_t operator()(scalar_t gamma, scalar_t gamma_sum) const {
+    auto ret_val = gamma / gamma_sum;
+    auto min_value = std::numeric_limits<scalar_t>::min();
+    auto max_value = 1 - std::numeric_limits<scalar_t>::epsilon();
+    ret_val = (min_value > ret_val) ? min_value : ret_val;
+    ret_val = (max_value < ret_val) ? max_value : ret_val;
+    return ret_val;
+  }
+};
+
+void launch_dirichlet_kernel(TensorIteratorBase& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.input_dtype(),
+      "dirichlet_xpu",
+      [&] {
+        DirichletKernelFunctor<scalar_t> f;
+        gpu_kernel(iter, f);
+      });
+}
+
+template <typename scalar_t, typename accscalar_t>
+struct DirichletGradKernelFunctor {
+  scalar_t operator()(scalar_t x_val, scalar_t alpha_val, scalar_t total_val)
+      const {
+    return dirichlet_grad_one<scalar_t, accscalar_t>(
+        x_val, alpha_val, total_val);
+  }
+};
+
+void launch_dirichlet_grad_kernel(TensorIteratorBase& iter) {
+  AT_DISPATCH_FLOATING_TYPES(iter.input_dtype(), "_dirichlet_grad_xpu", [&] {
+    using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
+    DirichletGradKernelFunctor<scalar_t, accscalar_t> f;
+    gpu_kernel(iter, f);
+  });
+}
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/Distributions.h b/src/ATen/native/xpu/sycl/Distributions.h
new file mode 100644
index 000000000..8cb5059d2
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/Distributions.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include <...>
+#include <...>
+
+namespace at::native::xpu {
+
+TORCH_XPU_API void launch_poisson_kernel(
+    const TensorBase& ret,
+    const TensorBase& lambda,
+    XPUGeneratorImpl* gen);
+
+TORCH_XPU_API void launch_binomial_kernel(
+    TensorIteratorBase& iter,
+    XPUGeneratorImpl* gen);
+
+TORCH_XPU_API void launch_gamma_kernel(
+    Tensor& ret,
+    const Tensor& alpha,
+    XPUGeneratorImpl* gen);
+
+TORCH_XPU_API void launch_dirichlet_kernel(TensorIteratorBase& iter);
+
+TORCH_XPU_API void launch_dirichlet_grad_kernel(TensorIteratorBase& iter);
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/Philox4x32.h b/src/ATen/native/xpu/sycl/Philox4x32.h
index 6ed22adbe..91278e24b 100644
--- a/src/ATen/native/xpu/sycl/Philox4x32.h
+++ b/src/ATen/native/xpu/sycl/Philox4x32.h
@@ -366,6 +366,159 @@ static inline double rand_normal_double(randStatePhilox4_32_10_t* state) {
   return state->boxmuller_extra;
 }
 
+static inline double lgamma_integer(int a) {
+  double s;
+  double t;
+  double fa = fabs((float)a);
+  double sum;
+
+  if (a > 8) {
+    /* Stirling approximation; coefficients from Hart et al., "Computer
+     * Approximations", Wiley 1968. Approximation 5404.
+     */
+    s = 1.0 / fa;
+    t = s * s;
+    sum = -0.1633436431e-2;
+    sum = sum * t + 0.83645878922e-3;
+    sum = sum * t - 0.5951896861197e-3;
+    sum = sum * t + 0.793650576493454e-3;
+    sum = sum * t - 0.277777777735865004e-2;
+    sum = sum * t + 0.833333333333331018375e-1;
+    sum = sum * s + 0.918938533204672;
+    s = 0.5 * logf(fa);
+    t = fa - 0.5;
+    s = s * t;
+    t = s - fa;
+    s = s + sum;
+    t = t + s;
+    return t;
+  } else {
+    switch (a) {
+      case 1:
+        return 0.000000000000000000e-1;
+      case 2:
+        return 0.000000000000000000e-1;
+      case 3:
+        return 6.931471805599453094e-1;
+      case 4:
+        return 1.791759469228055001e0;
+      case 5:
+        return 3.178053830347945620e0;
+      case 6:
+        return 4.787491742782045994e0;
+      case 7:
+        return 6.579251212010100995e0;
+      case 8:
+        return 8.525161361065414300e0;
+      default:
+        return 1.060460290274525023e1;
+    }
+  }
+}
+
+/* Computes regularized gamma function: gammainc(a,x)/gamma(a) */
+static inline float pgammainc(float a, float x) {
+  float t, alpha, beta;
+
+  /* First level parametrization constants */
+  float ma1 = 1.43248035075540910f, ma2 = 0.12400979329415655f,
+        ma3 = 0.00025361074907033f, mb1 = 0.21096734870196546f,
+        mb2 = 1.97381164089999420f, mb3 = 0.94201734077887530f;
+
+  /* Second level parametrization constants (depends only on a) */
+
+  alpha = 1 / sqrtf(a - ma2);
+  alpha = ma1 * alpha + ma3;
+  beta = 1 / sqrtf(a - mb2);
+  beta = mb1 * beta + mb3;
+
+  /* Final approximation (depends on a and x) */
+
+  t = a - x;
+  t = alpha * t - beta;
+  t = 1.0f + expf(t);
+  t = t * t;
+  t = 1 / t;
+
+  /* Negative a,x or a,x=NAN requires special handling */
+  // t = !(x > 0 && a >= 0) ? 0.0 : t;
+  return t;
+}
+
+/* Computes inverse of pgammainc */
+static inline float pgammaincinv(float a, float y) {
+  float t, alpha, beta;
+
+  /* First level parametrization constants */
+
+  float ma1 = 1.43248035075540910f, ma2 = 0.12400979329415655f,
+        ma3 = 0.00025361074907033f, mb1 = 0.21096734870196546f,
+        mb2 = 1.97381164089999420f, mb3 = 0.94201734077887530f;
+
+  /* Second level parametrization constants (depends only on a) */
+
+  alpha = 1.0f / sqrtf(a - ma2);
+  alpha = ma1 * alpha + ma3;
+  beta = 1.0f / sqrtf(a - mb2);
+  beta = mb1 * beta + mb3;
+
+  /* Final approximation (depends on a and y) */
+
+  t = 1.0f / sqrtf(y) - 1.0f;
+  t = logf(t);
+  t = beta + t;
+  t = -t * (1 / alpha) + a;
+  /* Negative a,x or a,x=NAN requires special handling */
+  // t = !(y > 0 && a >= 0) ? 0.0 : t;
+  return t;
+}
+
+/* Rejection Method for Poisson distribution based on gammainc approximation */
+static inline unsigned int rand_poisson_gammainc(
+    randStatePhilox4_32_10_t* state,
+    float lambda) {
+  float y, x, t, z, v;
+  float logl = logf(lambda);
+  while (true) {
+    y = rand_uniform(state);
+    x = pgammaincinv(lambda, y);
+    x = floorf(x);
+    z = rand_uniform(state);
+    v = (pgammainc(lambda, x + 1.0f) - pgammainc(lambda, x)) * 1.3f;
+    z = z * v;
+    t = (float)expf(
+        -lambda + x * logl - (float)lgamma_integer((int)(1.0f + x)));
+    if ((z < t) && (v >= 1e-20))
+      break;
+  }
+  return (unsigned int)x;
+}
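Note: the Knuth sampler that follows multiplies uniform draws until the running product drops below exp(-lambda); seeding the product with expf(lambda) and comparing against 1.0, as the code does, is algebraically the same trick and avoids one exponential per comparison. Host sketch of the identical loop (illustrative only):

#include <cmath>
#include <cstdio>
#include <random>

unsigned knuth_poisson(std::mt19937& rng, double lambda) {
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  double p = std::exp(lambda); // folded-in exp(-lambda) threshold
  unsigned k = 0;
  do {
    ++k;
    p *= uniform(rng);
  } while (p > 1.0); // expected iteration count grows linearly with lambda
  return k - 1;
}

int main() {
  std::mt19937 rng(3);
  double sum = 0.0;
  const int n = 200000;
  for (int i = 0; i < n; ++i)
    sum += knuth_poisson(rng, 4.5);
  std::printf("sample mean %.3f vs expected 4.5\n", sum / n);
}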
+
+// Donald E. Knuth, Seminumerical Algorithms: The Art of Computer Programming,
+// Volume 2
+static inline unsigned int rand_poisson_knuth(
+    randStatePhilox4_32_10_t* state,
+    float lambda) {
+  unsigned int k = 0;
+  float p = expf(lambda);
+  do {
+    k++;
+    p *= rand_uniform(state);
+  } while (p > 1.0);
+  return k - 1;
+}
+
+static inline unsigned int rand_poisson(
+    randStatePhilox4_32_10_t* state,
+    double lambda) {
+  if (lambda < 64)
+    return rand_poisson_knuth(state, (float)lambda);
+  if (lambda > 4000)
+    return (unsigned int)(
+        (std::sqrt(lambda) * rand_normal_double(state)) + lambda +
+        0.5); // Round to nearest
+  return rand_poisson_gammainc(state, (float)lambda);
+}
+
 } // namespace xpu
 } // namespace native
 } // namespace at
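Note: rand_poisson above picks a regime by rate: Knuth's O(lambda) loop below 64, the gammainc rejection sampler in between, and for lambda > 4000 the normal approximation round(lambda + sqrt(lambda) * Z). Host sketch of the large-lambda branch, checking that the sample mean and variance both land near lambda (illustrative only):

#include <cmath>
#include <cstdio>
#include <random>

int main() {
  std::mt19937 rng(9);
  std::normal_distribution<double> normal(0.0, 1.0);
  const double lambda = 5000.0;
  const int n = 100000;
  double sum = 0.0, sumsq = 0.0;
  for (int i = 0; i < n; ++i) {
    // round(lambda + sqrt(lambda) * Z), as in the lambda > 4000 branch
    double x = std::floor(lambda + std::sqrt(lambda) * normal(rng) + 0.5);
    sum += x;
    sumsq += x * x;
  }
  double mean = sum / n;
  std::printf("mean %.1f, var %.1f (both should be near %.1f)\n",
              mean, sumsq / n - mean * mean, lambda);
}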
diff --git a/test/xpu/test_distributions_xpu.py b/test/xpu/test_distributions_xpu.py
new file mode 100644
index 000000000..1cbf1a20f
--- /dev/null
+++ b/test/xpu/test_distributions_xpu.py
@@ -0,0 +1,350 @@
+# Owner(s): ["module: intel"]
+
+
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
+from torch.testing._internal.common_utils import run_tests
+
+try:
+    from xpu_test_utils import XPUPatchForImport
+except Exception as e:
+    from .xpu_test_utils import XPUPatchForImport
+with XPUPatchForImport(False):
+    from test_distributions import (
+        TestDistributions,
+        TestRsample,
+        TestDistributionShapes,
+        TestKL,
+        TestConstraints,
+        TestNumericalStability,
+        TestLazyLogitsInitialization,
+        TestAgainstScipy,
+        TestFunctors,
+        TestValidation,
+        TestJit,
+        _get_examples,
+        pairwise,
+    )
+import torch
+from torch.distributions import (
+    Bernoulli,
+    Beta,
+    Binomial,
+    Categorical,
+    Cauchy,
+    Chi2,
+    constraints,
+    ContinuousBernoulli,
+    Dirichlet,
+    Distribution,
+    Exponential,
+    ExponentialFamily,
+    FisherSnedecor,
+    Gamma,
+    Geometric,
+    Gumbel,
+    HalfCauchy,
+    HalfNormal,
+    Independent,
+    InverseGamma,
+    kl_divergence,
+    Kumaraswamy,
+    Laplace,
+    LKJCholesky,
+    LogisticNormal,
+    LogNormal,
+    LowRankMultivariateNormal,
+    MixtureSameFamily,
+    Multinomial,
+    MultivariateNormal,
+    NegativeBinomial,
+    Normal,
+    OneHotCategorical,
+    OneHotCategoricalStraightThrough,
+    Pareto,
+    Poisson,
+    RelaxedBernoulli,
+    RelaxedOneHotCategorical,
+    StudentT,
+    TransformedDistribution,
+    Uniform,
+    VonMises,
+    Weibull,
+    Wishart,
+)
+from torch.testing._internal.common_utils import set_rng_seed
+
+
+def _test_beta_underflow_gpu(self):
+    set_rng_seed(1)
+    num_samples = 50000
+    conc = torch.tensor(1e-2, dtype=torch.float64).xpu()
+    beta_samples = Beta(conc, conc).sample([num_samples])
+    self.assertEqual((beta_samples == 0).sum(), 0)
+    self.assertEqual((beta_samples == 1).sum(), 0)
+    # assert support is concentrated around 0 and 1
+    frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
+    frac_ones = float((beta_samples > 0.9).sum()) / num_samples
+    # TODO: increase precision once imbalance on GPU is fixed.
+    self.assertEqual(frac_zeros, 0.5, atol=0.12, rtol=0)
+    self.assertEqual(frac_ones, 0.5, atol=0.12, rtol=0)
+
+
+def _test_zero_excluded_binomial(self):
+    vals = Binomial(
+        total_count=torch.tensor(1.0).xpu(), probs=torch.tensor(0.9).xpu()
+    ).sample(torch.Size((100000000,)))
+    self.assertTrue((vals >= 0).all())
+    vals = Binomial(
+        total_count=torch.tensor(1.0).xpu(), probs=torch.tensor(0.1).xpu()
+    ).sample(torch.Size((100000000,)))
+    self.assertTrue((vals < 2).all())
+    vals = Binomial(
+        total_count=torch.tensor(1.0).xpu(), probs=torch.tensor(0.5).xpu()
+    ).sample(torch.Size((10000,)))
+    # vals should be roughly half zeroes, half ones
+    assert (vals == 0.0).sum() > 4000
+    assert (vals == 1.0).sum() > 4000
+
+
+TestDistributions.test_beta_underflow_gpu = _test_beta_underflow_gpu
+TestDistributions.test_zero_excluded_binomial = _test_zero_excluded_binomial
+instantiate_device_type_tests(TestDistributions, globals(), only_for="xpu", allow_xpu=True)
+instantiate_device_type_tests(TestRsample, globals(), only_for="xpu", allow_xpu=True)
+
+
+def setup_1(self):
+    self.scalar_sample = 1
+    self.tensor_sample_1 = torch.ones(3, 2)
+    self.tensor_sample_2 = torch.ones(3, 2, 3)
+
+
+TestDistributionShapes.setUp = setup_1
+instantiate_device_type_tests(TestDistributionShapes, globals(), only_for="xpu", allow_xpu=True)
+
+
+def setup_3(self):
+    class Binomial30(Binomial):
+        def __init__(self, probs):
+            super().__init__(30, probs)
+
+    # These are pairs of distributions with 4 x 4 parameters as specified.
+    # The first of the pair e.g. bernoulli[0] varies column-wise and the second
+    # e.g. bernoulli[1] varies row-wise; that way we test all param pairs.
+    bernoulli = pairwise(Bernoulli, [0.1, 0.2, 0.6, 0.9])
+    binomial30 = pairwise(Binomial30, [0.1, 0.2, 0.6, 0.9])
+    binomial_vectorized_count = (
+        Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
+        Binomial(torch.tensor([3, 4]), torch.tensor([0.5, 0.8])),
+    )
+    beta = pairwise(Beta, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
+    categorical = pairwise(
+        Categorical,
+        [[0.4, 0.3, 0.3], [0.2, 0.7, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.6]],
+    )
+    cauchy = pairwise(Cauchy, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
+    chi2 = pairwise(Chi2, [1.0, 2.0, 2.5, 5.0])
+    dirichlet = pairwise(
+        Dirichlet,
+        [[0.1, 0.2, 0.7], [0.5, 0.4, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.4]],
+    )
+    exponential = pairwise(Exponential, [1.0, 2.5, 5.0, 10.0])
+    gamma = pairwise(Gamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
+    gumbel = pairwise(Gumbel, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
+    halfnormal = pairwise(HalfNormal, [1.0, 2.0, 1.0, 2.0])
+    inversegamma = pairwise(
+        InverseGamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5]
+    )
+    laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
+    lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
+    normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
+    independent = (Independent(normal[0], 1), Independent(normal[1], 1))
+    onehotcategorical = pairwise(
+        OneHotCategorical,
+        [[0.4, 0.3, 0.3], [0.2, 0.7, 0.1], [0.33, 0.33, 0.34], [0.2, 0.2, 0.6]],
+    )
+    pareto = (
+        Pareto(
+            torch.tensor([2.5, 4.0, 2.5, 4.0]).expand(4, 4),
+            torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4),
+        ),
+        Pareto(
+            torch.tensor([2.25, 3.75, 2.25, 3.8]).expand(4, 4),
+            torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4),
+        ),
+    )
+    poisson = pairwise(Poisson, [0.3, 1.0, 5.0, 10.0])
+    uniform_within_unit = pairwise(
+        Uniform, [0.1, 0.9, 0.2, 0.75], [0.15, 0.95, 0.25, 0.8]
+    )
+    uniform_positive = pairwise(Uniform, [1, 1.5, 2, 4], [1.2, 2.0, 3, 7])
+    uniform_real = pairwise(Uniform, [-2.0, -1, 0, 2], [-1.0, 1, 1, 4])
+    uniform_pareto = pairwise(Uniform, [6.5, 7.5, 6.5, 8.5], [7.5, 8.5, 9.5, 9.5])
+    continuous_bernoulli = pairwise(ContinuousBernoulli, [0.1, 0.2, 0.5, 0.9])
+
+    # These tests should pass with precision = 0.01, but that makes tests very expensive.
+    # Instead, we test with precision = 0.1 and only test with higher precision locally
+    # when adding a new KL implementation.
+    # The following pairs are not tested due to very high variance of the monte carlo
+    # estimator; their implementations have been reviewed with extra care:
+    # - (pareto, normal)
+    self.precision = 0.1  # Set this to 0.01 when testing a new KL implementation.
+    self.max_samples = int(1e07)  # Increase this when testing at smaller precision.
+    self.samples_per_batch = int(1e04)
+    self.finite_examples = [
+        (bernoulli, bernoulli),
+        (bernoulli, poisson),
+        (beta, beta),
+        (beta, chi2),
+        (beta, exponential),
+        (beta, gamma),
+        (beta, normal),
+        (binomial30, binomial30),
+        (binomial_vectorized_count, binomial_vectorized_count),
+        (categorical, categorical),
+        (cauchy, cauchy),
+        (chi2, chi2),
+        (chi2, exponential),
+        (chi2, gamma),
+        (chi2, normal),
+        (dirichlet, dirichlet),
+        (exponential, chi2),
+        (exponential, exponential),
+        (exponential, gamma),
+        (exponential, gumbel),
+        (exponential, normal),
+        (gamma, chi2),
+        (gamma, exponential),
+        (gamma, gamma),
+        (gamma, gumbel),
+        (gamma, normal),
+        (gumbel, gumbel),
+        (gumbel, normal),
+        (halfnormal, halfnormal),
+        (independent, independent),
+        (inversegamma, inversegamma),
+        (laplace, laplace),
+        (lognormal, lognormal),
+        (laplace, normal),
+        (normal, gumbel),
+        (normal, laplace),
+        (normal, normal),
+        (onehotcategorical, onehotcategorical),
+        (pareto, chi2),
+        (pareto, pareto),
+        (pareto, exponential),
+        (pareto, gamma),
+        (poisson, poisson),
+        (uniform_within_unit, beta),
+        (uniform_positive, chi2),
+        (uniform_positive, exponential),
+        (uniform_positive, gamma),
+        (uniform_real, gumbel),
+        (uniform_real, normal),
+        (uniform_pareto, pareto),
+        (continuous_bernoulli, continuous_bernoulli),
+        (continuous_bernoulli, exponential),
+        (continuous_bernoulli, normal),
+        (beta, continuous_bernoulli),
+    ]
+
+    self.infinite_examples = [
+        (Bernoulli(0), Bernoulli(1)),
+        (Bernoulli(1), Bernoulli(0)),
+        (
+            Categorical(torch.tensor([0.9, 0.1])),
+            Categorical(torch.tensor([1.0, 0.0])),
+        ),
+        (
+            Categorical(torch.tensor([[0.9, 0.1], [0.9, 0.1]])),
+            Categorical(torch.tensor([1.0, 0.0])),
+        ),
+        (Beta(1, 2), Uniform(0.25, 1)),
+        (Beta(1, 2), Uniform(0, 0.75)),
+        (Beta(1, 2), Uniform(0.25, 0.75)),
+        (Beta(1, 2), Pareto(1, 2)),
+        (Binomial(31, 0.7), Binomial(30, 0.3)),
+        (
+            Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
+            Binomial(torch.tensor([2, 3]), torch.tensor([0.5, 0.8])),
+        ),
+        (Chi2(1), Beta(2, 3)),
+        (Chi2(1), Pareto(2, 3)),
+        (Chi2(1), Uniform(-2, 3)),
+        (Exponential(1), Beta(2, 3)),
+        (Exponential(1), Pareto(2, 3)),
+        (Exponential(1), Uniform(-2, 3)),
+        (Gamma(1, 2), Beta(3, 4)),
+        (Gamma(1, 2), Pareto(3, 4)),
+        (Gamma(1, 2), Uniform(-3, 4)),
+        (Gumbel(-1, 2), Beta(3, 4)),
+        (Gumbel(-1, 2), Chi2(3)),
+        (Gumbel(-1, 2), Exponential(3)),
+        (Gumbel(-1, 2), Gamma(3, 4)),
+        (Gumbel(-1, 2), Pareto(3, 4)),
+        (Gumbel(-1, 2), Uniform(-3, 4)),
+        (Laplace(-1, 2), Beta(3, 4)),
+        (Laplace(-1, 2), Chi2(3)),
+        (Laplace(-1, 2), Exponential(3)),
+        (Laplace(-1, 2), Gamma(3, 4)),
+        (Laplace(-1, 2), Pareto(3, 4)),
+        (Laplace(-1, 2), Uniform(-3, 4)),
+        (Normal(-1, 2), Beta(3, 4)),
+        (Normal(-1, 2), Chi2(3)),
+        (Normal(-1, 2), Exponential(3)),
+        (Normal(-1, 2), Gamma(3, 4)),
+        (Normal(-1, 2), Pareto(3, 4)),
+        (Normal(-1, 2), Uniform(-3, 4)),
+        (Pareto(2, 1), Chi2(3)),
+        (Pareto(2, 1), Exponential(3)),
+        (Pareto(2, 1), Gamma(3, 4)),
+        (Pareto(1, 2), Normal(-3, 4)),
+        (Pareto(1, 2), Pareto(3, 4)),
+        (Poisson(2), Bernoulli(0.5)),
+        (Poisson(2.3), Binomial(10, 0.2)),
+        (Uniform(-1, 1), Beta(2, 2)),
+        (Uniform(0, 2), Beta(3, 4)),
+        (Uniform(-1, 2), Beta(3, 4)),
+        (Uniform(-1, 2), Chi2(3)),
+        (Uniform(-1, 2), Exponential(3)),
+        (Uniform(-1, 2), Gamma(3, 4)),
+        (Uniform(-1, 2), Pareto(3, 4)),
+        (ContinuousBernoulli(0.25), Uniform(0.25, 1)),
+        (ContinuousBernoulli(0.25), Uniform(0, 0.75)),
+        (ContinuousBernoulli(0.25), Uniform(0.25, 0.75)),
+        (ContinuousBernoulli(0.25), Pareto(1, 2)),
+        (Exponential(1), ContinuousBernoulli(0.75)),
+        (Gamma(1, 2), ContinuousBernoulli(0.75)),
+        (Gumbel(-1, 2), ContinuousBernoulli(0.75)),
+        (Laplace(-1, 2), ContinuousBernoulli(0.75)),
+        (Normal(-1, 2), ContinuousBernoulli(0.75)),
+        (Uniform(-1, 1), ContinuousBernoulli(0.75)),
+        (Uniform(0, 2), ContinuousBernoulli(0.75)),
+        (Uniform(-1, 2), ContinuousBernoulli(0.75)),
+    ]
+
+
+TestKL.setUp = setup_3
+instantiate_device_type_tests(TestKL, globals(), only_for="xpu", allow_xpu=True)
+instantiate_device_type_tests(TestConstraints, globals(), only_for="xpu", allow_xpu=True)
+instantiate_device_type_tests(TestNumericalStability, globals(), only_for="xpu", allow_xpu=True)
+
+
+def setup_2(self):
+    self.examples = [
+        e
+        for e in _get_examples()
+        if e.Dist
+        in (Categorical, OneHotCategorical, Bernoulli, Binomial, Multinomial)
+    ]
+
+
+TestLazyLogitsInitialization.setUp = setup_2
+instantiate_device_type_tests(TestLazyLogitsInitialization, globals(), only_for="xpu", allow_xpu=True)
+instantiate_device_type_tests(TestAgainstScipy, globals(), only_for="xpu", allow_xpu=True)
+instantiate_device_type_tests(TestFunctors, globals(), only_for="xpu", allow_xpu=True)
+instantiate_device_type_tests(TestValidation, globals(), only_for="xpu", allow_xpu=True)
+instantiate_device_type_tests(TestJit, globals(), only_for="xpu", allow_xpu=True)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index 1f88dd914..1e3f52847 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -254,6 +254,8 @@
     "square",
     "heaviside",
     "argsort",
+    "cauchy",
+    "geometric",
     "log_normal",
 ]
 
@@ -707,6 +709,7 @@ def __init__(self, patch_test_case=True) -> None:
         self.test_package = (
             os.path.dirname(os.path.abspath(__file__)) + "/../../../../test",
            os.path.dirname(os.path.abspath(__file__)) + "/../../../../test/nn",
+            os.path.dirname(os.path.abspath(__file__)) + "/../../../../test/distributions",
         )
         self.patch_test_case = patch_test_case
         self.original_path = sys.path.copy()
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index feab4d218..b169c6b59 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -4721,6 +4721,43 @@
     CompositeExplicitAutograd: binary_cross_entropy_with_logits
   autogen: binary_cross_entropy_with_logits.out
 
+- func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
+  device_check: NoCheck # TensorIterator
+  variants: method
+  tags: nondeterministic_seeded
+  dispatch:
+    XPU: cauchy_
+  autogen: cauchy, cauchy.out
+
+- func: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)
+  device_check: NoCheck # TensorIterator
+  tags: nondeterministic_seeded
+  variants: method
+  dispatch:
+    XPU: geometric_
+
+  # wrappers for TH functions
+  autogen: geometric, geometric.out
+
+- func: binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor
+  device_check: NoCheck # TensorIterator
+  dispatch:
+    XPU: _s_binomial_xpu
+  tags: nondeterministic_seeded
+  autogen: binomial.out
+
+- func: _sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor
+  tags: nondeterministic_seeded
+  variants: function
+  dispatch:
+    XPU: _sample_dirichlet_xpu
+  autogen: _sample_dirichlet.out
+
+- func: _dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor
+  dispatch:
+    XPU: _dirichlet_grad_xpu
+  autogen: _dirichlet_grad.out
+
 - func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
   structured_delegate: argmin.out
   device_check: NoCheck # TensorIterator
@@ -5919,6 +5956,13 @@
     XPU: angle_out
   tags: pointwise
 
+- func: poisson(Tensor self, Generator? generator=None) -> Tensor
+  device_check: NoCheck # TensorIterator
+  dispatch:
+    XPU: _s_poisson_xpu
+  tags: nondeterministic_seeded
+  autogen: poisson.out
+
 - func: channel_shuffle(Tensor self, SymInt groups) -> Tensor
   dispatch:
     XPU: channel_shuffle
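Note: once the YAML entries above register the XPU dispatch keys, the operators are reachable through the regular ATen surface. A minimal C++ usage sketch (assumes a libtorch build with XPU support enabled; illustrative only):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor t = at::empty({4}, at::device(at::kXPU).dtype(at::kFloat));
  t.cauchy_(/*median=*/0.0, /*sigma=*/1.0); // dispatches to cauchy_ (XPU)
  t.geometric_(/*p=*/0.5);                  // dispatches to geometric_ (XPU)

  at::Tensor lam = at::full({4}, 3.0, at::device(at::kXPU));
  at::Tensor pois = at::poisson(lam);       // dispatches to _s_poisson_xpu
  std::cout << pois.cpu() << '\n';
}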