Skip to content

Commit

Permalink
Add distribution operators (#826)
Browse files Browse the repository at this point in the history
- cauchy_
- geometric
- binomial
- binomial.out
- _dirichlet_grad
- _dirichlet_grad.out
- _sample_dirichlet
- _sample_dirichlet.out

---------

Co-authored-by: Yutao Xu <[email protected]>
  • Loading branch information
hjhee and xytintel authored Oct 18, 2024
1 parent f89ff3f commit c9a94ba
Show file tree
Hide file tree
Showing 12 changed files with 1,223 additions and 4 deletions.
60 changes: 60 additions & 0 deletions src/ATen/native/xpu/Distributions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <ATen/native/UnaryOps.h>

#include <ATen/native/xpu/sycl/DistributionKernels.h>
#include <ATen/native/xpu/sycl/Distributions.h>
#include <ATen/native/xpu/sycl/MultinomialKernel.h>
#include <ATen/ops/div.h>
#include <comm/xpu_aten.h>
Expand All @@ -29,5 +30,64 @@ REGISTER_XPU_DISPATCH(
multinomial_with_replacement_stub,
&xpu::multinomial_kernel);
REGISTER_XPU_DISPATCH(log_normal_stub, &xpu::log_normal_kernel);
REGISTER_XPU_DISPATCH(cauchy_stub, &xpu::cauchy_kernel);
REGISTER_XPU_DISPATCH(geometric_stub, &xpu::geometric_kernel);

Tensor _s_poisson_xpu(const Tensor& lambda, std::optional<Generator> gen_) {
auto gen = get_generator_or_default<at::XPUGeneratorImpl>(
gen_, at::xpu::detail::getDefaultXPUGenerator());
Tensor ret = at::empty(lambda.sizes(), lambda.options());
xpu::launch_poisson_kernel(ret, lambda, gen);
return ret;
}

Tensor _s_binomial_xpu(
const Tensor& count,
const Tensor& prob,
std::optional<Generator> generator) {
auto gen = get_generator_or_default<at::XPUGeneratorImpl>(
generator, at::xpu::detail::getDefaultXPUGenerator());
Tensor ret = at::empty(count.sizes(), count.options());
at::TensorIterator iter = at::TensorIteratorConfig()
.add_output(ret)
.add_input(count)
.add_input(prob)
.build();
xpu::launch_binomial_kernel(iter, gen);
return ret;
}

Tensor _sample_dirichlet_xpu(
const Tensor& alpha,
std::optional<Generator> generator) {
auto gen = get_generator_or_default<at::XPUGeneratorImpl>(
generator, at::xpu::detail::getDefaultXPUGenerator());
Tensor ret = at::empty(alpha.sizes(), alpha.options());
xpu::launch_gamma_kernel(ret, alpha, gen);
auto gamma_sum = ret.sum(/*dim=*/-1, /*keepdim=*/true);
auto iter = at::TensorIteratorConfig()
.add_output(ret)
.add_input(ret)
.add_input(gamma_sum)
.build();
xpu::launch_dirichlet_kernel(iter);
return ret;
}

Tensor _dirichlet_grad_xpu(
const Tensor& x,
const Tensor& alpha,
const Tensor& total) {
Tensor ret = at::empty(x.sizes(), x.options());
auto iter = at::TensorIteratorConfig()
.add_output(ret)
.add_input(x)
.add_input(alpha)
.add_input(total)
.build();
xpu::launch_dirichlet_grad_kernel(iter);
return ret;
}

} // namespace native
} // namespace at
2 changes: 0 additions & 2 deletions src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"adaptive_max_pool3d.out",
"avg_pool3d_backward.grad_input",
"avg_pool3d.out",
"cauchy_",
"_cdist_backward",
"cholesky",
"cholesky_inverse",
Expand All @@ -180,7 +179,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"fractional_max_pool3d.output",
"frexp.Tensor_out",
"_fused_moving_avg_obs_fq_helper",
"geometric_",
"geqrf",
"heaviside.out",
"histc",
Expand Down
17 changes: 17 additions & 0 deletions src/ATen/native/xpu/sycl/DistributionCauchyKernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/DistributionTemplates.h>
#include <ATen/xpu/XPUGeneratorImpl.h>

namespace at::native::xpu {

void cauchy_kernel(
TensorIteratorBase& iter,
double median,
double sigma,
c10::optional<Generator> gen) {
auto generator = get_generator_or_default<at::XPUGeneratorImpl>(
gen, at::xpu::detail::getDefaultXPUGenerator());
at::native::templates::xpu::cauchy_kernel(iter, median, sigma, generator);
}

} // namespace at::native::xpu
16 changes: 16 additions & 0 deletions src/ATen/native/xpu/sycl/DistributionGeometricKernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/DistributionTemplates.h>
#include <ATen/xpu/XPUGeneratorImpl.h>

namespace at::native::xpu {

void geometric_kernel(
TensorIteratorBase& iter,
double p_,
c10::optional<Generator> gen) {
auto generator = get_generator_or_default<at::XPUGeneratorImpl>(
gen, at::xpu::detail::getDefaultXPUGenerator());
at::native::templates::xpu::geometric_kernel(iter, p_, generator);
}

} // namespace at::native::xpu
11 changes: 11 additions & 0 deletions src/ATen/native/xpu/sycl/DistributionKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,15 @@ TORCH_XPU_API void log_normal_kernel(
double std,
std::optional<Generator> gen);

TORCH_XPU_API void cauchy_kernel(
TensorIteratorBase& iter,
double median,
double sigma,
c10::optional<Generator> gen);

TORCH_XPU_API void geometric_kernel(
TensorIteratorBase& iter,
double p_,
c10::optional<Generator> gen);

} // namespace at::native::xpu
Loading

0 comments on commit c9a94ba

Please sign in to comment.