From 0cd30918614a7a66820c8491697607d17939d530 Mon Sep 17 00:00:00 2001
From: min-jean-cho <min.jean.cho@intel.com>
Date: Tue, 5 Nov 2024 00:06:11 -0800
Subject: [PATCH] Add aten::_standard_gamma (#1040)

- `_standard_gamma`
- `_standard_gamma_grad`
---
 src/ATen/native/xpu/Distributions.cpp      | 19 +++++++++++++++++++
 src/ATen/native/xpu/sycl/Distributions.cpp | 20 ++++++++++++++++++++
 src/ATen/native/xpu/sycl/Distributions.h   |  2 ++
 yaml/native/native_functions.yaml          | 13 +++++++++++++
 4 files changed, 54 insertions(+)

diff --git a/src/ATen/native/xpu/Distributions.cpp b/src/ATen/native/xpu/Distributions.cpp
index d7f602d6e..d6a683ff2 100644
--- a/src/ATen/native/xpu/Distributions.cpp
+++ b/src/ATen/native/xpu/Distributions.cpp
@@ -57,6 +57,14 @@ Tensor _s_binomial_xpu(
   return ret;
 }
 
+Tensor _s_gamma_xpu(const Tensor& alpha, c10::optional<Generator> gen_) {
+  auto gen = get_generator_or_default<at::XPUGeneratorImpl>(
+      gen_, at::xpu::detail::getDefaultXPUGenerator());
+  Tensor ret = at::empty(alpha.sizes(), alpha.options());
+  xpu::launch_gamma_kernel(ret, alpha, gen);
+  return ret;
+}
+
 Tensor _sample_dirichlet_xpu(
     const Tensor& alpha,
     std::optional<Generator> generator) {
@@ -74,6 +82,17 @@ Tensor _sample_dirichlet_xpu(
   return ret;
 }
 
+Tensor _standard_gamma_grad_xpu(const Tensor& self, const Tensor& output) {
+  Tensor ret = at::empty(self.sizes(), self.options());
+  TensorIterator iter = TensorIteratorConfig()
+                            .add_output(ret)
+                            .add_input(self)
+                            .add_input(output)
+                            .build();
+  xpu::launch_standard_gamma_grad_kernel(iter);
+  return ret;
+}
+
 Tensor _dirichlet_grad_xpu(
     const Tensor& x,
     const Tensor& alpha,
diff --git a/src/ATen/native/xpu/sycl/Distributions.cpp b/src/ATen/native/xpu/sycl/Distributions.cpp
index bdd18412d..b20bd10b6 100644
--- a/src/ATen/native/xpu/sycl/Distributions.cpp
+++ b/src/ATen/native/xpu/sycl/Distributions.cpp
@@ -199,6 +199,26 @@ void launch_gamma_kernel(
       [&] { gamma_kernel(ret, alpha, rng_engine_inputs); });
 }
 
+template <typename scalar_t, typename accscalar_t>
+struct StandardGammaGradKernelFunctor {
+  scalar_t operator()(scalar_t self_val, scalar_t output_val) const {
+    return standard_gamma_grad_one<scalar_t, accscalar_t>(self_val, output_val);
+  }
+};
+
+void launch_standard_gamma_grad_kernel(TensorIteratorBase& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.input_dtype(),
+      "_standard_gamma_grad_xpu",
+      [&] {
+        using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
+        StandardGammaGradKernelFunctor<scalar_t, accscalar_t> f;
+        gpu_kernel(iter, f);
+      });
+}
+
 template <typename scalar_t, typename accscalar_t>
 struct DirichletKernelFunctor {
   scalar_t operator()(scalar_t gamma, scalar_t gamma_sum) const {
diff --git a/src/ATen/native/xpu/sycl/Distributions.h b/src/ATen/native/xpu/sycl/Distributions.h
index 8cb5059d2..a85632efd 100644
--- a/src/ATen/native/xpu/sycl/Distributions.h
+++ b/src/ATen/native/xpu/sycl/Distributions.h
@@ -19,6 +19,8 @@ TORCH_XPU_API void launch_gamma_kernel(
     const Tensor& alpha,
     XPUGeneratorImpl* gen);
 
+TORCH_XPU_API void launch_standard_gamma_grad_kernel(TensorIteratorBase& iter);
+
 TORCH_XPU_API void launch_dirichlet_kernel(TensorIteratorBase& iter);
 
 TORCH_XPU_API void launch_dirichlet_grad_kernel(TensorIteratorBase& iter);
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index acb719b18..5996243d8 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -5377,6 +5377,19 @@
     tags: nondeterministic_seeded
     autogen: binomial.out
 
+- func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
+  variants: function
+  dispatch:
+    XPU: _standard_gamma_grad_xpu
+  autogen: _standard_gamma_grad.out
+
+- func: _standard_gamma(Tensor self, Generator? generator=None) -> Tensor
+  variants: function
+  dispatch:
+    XPU: _s_gamma_xpu
+  tags: nondeterministic_seeded
+  autogen: _standard_gamma.out
+
 - func: _sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor
   tags: nondeterministic_seeded
   variants: function