Skip to content

Commit

Permalink
Merge branch 'main' into xyt/topk
Browse files Browse the repository at this point in the history
  • Loading branch information
xytintel authored Jul 12, 2024
2 parents bc1a3c7 + 78575b6 commit 89c712d
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 2 deletions.
23 changes: 23 additions & 0 deletions src/ATen/native/xpu/BinaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <ATen/native/xpu/sycl/BinaryKernels.h>
#include <ATen/native/xpu/sycl/BinaryMiscBackwardOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryRemainderKernel.h>
#include <ATen/native/xpu/sycl/CopysignKernel.h>
#include <ATen/native/xpu/sycl/GcdLcmKernels.h>
#include <ATen/native/xpu/sycl/MaxMinElementwiseKernels.h>

Expand Down Expand Up @@ -502,4 +503,26 @@ Tensor& XPUNativeFunctions::atan2_out(
return out;
}

Tensor& XPUNativeFunctions::copysign_out(
const Tensor& self,
const Tensor& other,
Tensor& out) {
TensorIterator iter;
iter.build_borrowing_binary_float_op(out, self, other);
native::xpu::copysign_kernel(iter);
return out;
}

Tensor& XPUNativeFunctions::copysign_(Tensor& self, const Tensor& other) {
return XPUNativeFunctions::copysign_out(self, other, self);
}

Tensor XPUNativeFunctions::copysign(const Tensor& self, const Tensor& other) {
Tensor out;
TensorIterator iter;
iter.build_borrowing_binary_float_op(out, self, other);
native::xpu::copysign_kernel(iter);
return iter.output();
}

} // namespace at
1 change: 1 addition & 0 deletions src/ATen/native/xpu/Indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ Tensor XPUNativeFunctions::index_select(
auto out = at::empty({0}, self.options());
return index_select_out(self, dim, index, out);
}

} // namespace at
4 changes: 4 additions & 0 deletions src/ATen/native/xpu/TensorAdvancedIndexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1396,4 +1396,8 @@ Tensor& XPUNativeFunctions::gather_out(
return out;
}

Tensor XPUNativeFunctions::count_nonzero(const Tensor& self, IntArrayRef dims) {
return (self != 0).sum(dims);
}

} // namespace at
2 changes: 0 additions & 2 deletions src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"cholesky",
"cholesky_inverse",
"_cholesky_solve_helper",
"copysign.out",
"count_nonzero.dim_IntList",
"_ctc_loss",
"_ctc_loss_backward",
"_cummax_helper",
Expand Down
24 changes: 24 additions & 0 deletions src/ATen/native/xpu/sycl/CopysignKernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>

#include <ATen/native/xpu/sycl/Loops.h>

namespace at::native::xpu {

template <typename scalar_t>
struct CopysignFunctor {
scalar_t operator()(scalar_t a, scalar_t b) const {
return std::copysign(a, b);
}
};

void copysign_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
iter.common_dtype(),
"copysign_xpu",
[&]() { gpu_kernel_with_scalars(iter, CopysignFunctor<scalar_t>()); });
}

} // namespace at::native::xpu
9 changes: 9 additions & 0 deletions src/ATen/native/xpu/sycl/CopysignKernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pragma once

#include <ATen/native/TensorIterator.h>

namespace at::native::xpu {

void copysign_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
2 changes: 2 additions & 0 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@
"renorm",
"lerp",
"conj_physical",
"copysign",
"count_nonzero"
]


Expand Down
4 changes: 4 additions & 0 deletions yaml/xpu_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,10 @@ supported:
- randperm.generator_out
- _amp_foreach_non_finite_check_and_unscale_
- _amp_update_scale_
- copysign.out
- copysign.Tensor
- copysign_.Tensor
- count_nonzero.dim_IntList
- conj_physical.out
- conj_physical_
- ceil
Expand Down

0 comments on commit 89c712d

Please sign in to comment.