Add aten::_standard_gamma (#1040)
- `_standard_gamma`
- `_standard_gamma_grad`
min-jean-cho authored Nov 5, 2024
1 parent 2ac2c45 commit 0cd3091
Showing 4 changed files with 54 additions and 0 deletions.
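The two operators added here are the XPU implementations of PyTorch's existing `_standard_gamma` / `_standard_gamma_grad` pair. As a quick orientation before the diffs, a minimal Python sketch of what they compute, assuming a PyTorch build with this XPU backend available:

import torch

# Concentration (alpha) parameter placed on the XPU device.
alpha = torch.full((4,), 2.0, device="xpu")

# Draw Gamma(alpha, 1) samples through the newly registered kernel.
sample = torch._standard_gamma(alpha)

# Pathwise derivative d(sample)/d(alpha), consumed by autograd's backward formula.
grad = torch._standard_gamma_grad(alpha, sample)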
19 changes: 19 additions & 0 deletions src/ATen/native/xpu/Distributions.cpp
@@ -57,6 +57,14 @@ Tensor _s_binomial_xpu(
  return ret;
}

Tensor _s_gamma_xpu(const Tensor& alpha, c10::optional<Generator> gen_) {
  auto gen = get_generator_or_default<at::XPUGeneratorImpl>(
      gen_, at::xpu::detail::getDefaultXPUGenerator());
  Tensor ret = at::empty(alpha.sizes(), alpha.options());
  xpu::launch_gamma_kernel(ret, alpha, gen);
  return ret;
}

Tensor _sample_dirichlet_xpu(
    const Tensor& alpha,
    std::optional<Generator> generator) {
@@ -74,6 +82,17 @@ Tensor _sample_dirichlet_xpu(
  return ret;
}

Tensor _standard_gamma_grad_xpu(const Tensor& self, const Tensor& output) {
  Tensor ret = at::empty(self.sizes(), self.options());
  TensorIterator iter = TensorIteratorConfig()
                            .add_output(ret)
                            .add_input(self)
                            .add_input(output)
                            .build();
  xpu::launch_standard_gamma_grad_kernel(iter);
  return ret;
}

Tensor _dirichlet_grad_xpu(
    const Tensor& x,
    const Tensor& alpha,
20 changes: 20 additions & 0 deletions src/ATen/native/xpu/sycl/Distributions.cpp
@@ -199,6 +199,26 @@ void launch_gamma_kernel(
      [&] { gamma_kernel<scalar_t>(ret, alpha, rng_engine_inputs); });
}

template <typename scalar_t, typename accscalar_t>
struct StandardGammaGradKernelFunctor {
  scalar_t operator()(scalar_t self_val, scalar_t output_val) const {
    return standard_gamma_grad_one<scalar_t, accscalar_t>(
        self_val, output_val);
  }
};

void launch_standard_gamma_grad_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.input_dtype(),
      "_standard_gamma_grad_xpu",
      [&] {
        using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
        StandardGammaGradKernelFunctor<scalar_t, accscalar_t> f;
        gpu_kernel(iter, f);
      });
}
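For context (not part of this diff): `standard_gamma_grad_one` is the per-element helper presumably shared with the CPU/CUDA backends via ATen/native/Distributions.h. It evaluates the pathwise derivative of a Gamma(α, 1) sample x with respect to the concentration α, obtained by implicitly differentiating the CDF F(x; α):

    dx/dα = -(∂F(x; α)/∂α) / (∂F(x; α)/∂x) = -(∂F(x; α)/∂α) / f(x; α)

where f is the gamma density; the helper approximates this ratio with series expansions rather than evaluating F itself.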

template <typename scalar_t>
struct DirichletKernelFunctor {
  scalar_t operator()(scalar_t gamma, scalar_t gamma_sum) const {
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/Distributions.h
@@ -19,6 +19,8 @@ TORCH_XPU_API void launch_gamma_kernel(
    const Tensor& alpha,
    XPUGeneratorImpl* gen);

TORCH_XPU_API void launch_standard_gamma_grad_kernel(TensorIteratorBase& iter);

TORCH_XPU_API void launch_dirichlet_kernel(TensorIteratorBase& iter);

TORCH_XPU_API void launch_dirichlet_grad_kernel(TensorIteratorBase& iter);
13 changes: 13 additions & 0 deletions yaml/native/native_functions.yaml
@@ -5377,6 +5377,19 @@
  tags: nondeterministic_seeded
  autogen: binomial.out

- func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
  variants: function
  dispatch:
    XPU: _standard_gamma_grad_xpu
  autogen: _standard_gamma_grad.out

- func: _standard_gamma(Tensor self, Generator? generator=None) -> Tensor
  variants: function
  dispatch:
    XPU: _s_gamma_xpu
  tags: nondeterministic_seeded
  autogen: _standard_gamma.out

- func: _sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor
  tags: nondeterministic_seeded
  variants: function
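Taken together, the two `_standard_gamma*` entries make reparameterized gamma sampling differentiable on XPU: `torch.distributions.Gamma.rsample` draws through `_standard_gamma`, and the backward formula registered for it in PyTorch calls `_standard_gamma_grad`. A minimal end-to-end sketch, again assuming an XPU-enabled build:

import torch

concentration = torch.tensor([2.0, 5.0], device="xpu", requires_grad=True)
dist = torch.distributions.Gamma(concentration, torch.ones_like(concentration))

# rsample routes through _standard_gamma; its backward uses _standard_gamma_grad.
sample = dist.rsample()
sample.sum().backward()
print(concentration.grad)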
