Add aten::_foreach_{max, rsqrt, rsqrt_, lerp.ScalarList, lerp_.ScalarList} (#1114)

- _foreach_max
- _foreach_rsqrt
- _foreach_rsqrt_
- _foreach_lerp.ScalarList
- _foreach_lerp_.ScalarList
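
The new ops can be exercised end-to-end through the generated at::_foreach_* dispatcher entry points. A minimal sketch follows; the exact C++ signatures are assumed from the corresponding native_functions.yaml declarations rather than shown in this diff, and the tensors here are plain CPU tensors, so the calls fall back to the existing reference paths when no XPU device is present.

#include <ATen/ATen.h>
#include <vector>

int main() {
  std::vector<at::Tensor> xs = {at::rand({4}), at::rand({8})};
  std::vector<at::Tensor> ys = {at::rand({4}), at::rand({8})};
  std::vector<at::Scalar> weights = {0.25, 0.75};  // one weight per tensor

  auto maxima = at::_foreach_max(xs);   // one 0-dim max per input tensor
  auto rs = at::_foreach_rsqrt(xs);     // out-of-place 1/sqrt(x)
  at::_foreach_rsqrt_(xs);              // in-place variant
  auto lerped = at::_foreach_lerp(xs, ys, weights);  // ScalarList overload
  at::_foreach_lerp_(xs, ys, weights);               // in-place ScalarList
  return 0;
}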

---------

Co-authored-by: Yutao Xu <[email protected]>
LuFinch and xytintel authored Nov 25, 2024
1 parent 3af6f73 commit 15f6d65
Showing 12 changed files with 562 additions and 1 deletion.
36 changes: 36 additions & 0 deletions src/ATen/native/xpu/ForeachOpScalarList.cpp
@@ -6,12 +6,15 @@
#include <ATen/ops/_foreach_clamp_max_native.h>
#include <ATen/ops/_foreach_clamp_min_native.h>
#include <ATen/ops/_foreach_div_native.h>
#include <ATen/ops/_foreach_lerp_native.h>
#include <ATen/ops/_foreach_mul_native.h>
#include <ATen/ops/_foreach_pow_native.h>
#include <ATen/ops/_foreach_sub_native.h>
#include <ATen/ops/empty_like.h>

#include <ATen/native/xpu/sycl/ForeachBinaryOpScalarListKernels.h>
#include <ATen/native/xpu/sycl/ForeachPointwiseOpScalarListKernels.h>
#include <ATen/native/xpu/sycl/ForeachTernaryOpScalarListKernels.h>

#include <xpu/ATen/ops/_foreach_add_native.h>
#include <xpu/ATen/ops/_foreach_mul_native.h>
@@ -118,5 +121,38 @@ FOREACH_BINARY_OP_SCALARLIST(pow, true);
FOREACH_POINTWISE_OP_SCALARLIST(addcmul)
FOREACH_POINTWISE_OP_SCALARLIST(addcdiv)

std::vector<at::Tensor> foreach_tensor_lerp_scalarlist_xpu(
TensorList tensors1,
TensorList tensors2,
at::ArrayRef<Scalar> scalars) {
check_foreach_api_restrictions(tensors1, tensors2, scalars);
if (!can_use_fast_route({tensors1, tensors2}, scalars, true)) {
return foreach_tensor_lerp_scalarlist_kernel_slow(
tensors1, tensors2, scalars);
}

std::vector<at::Tensor> vec_res;
vec_res.reserve(tensors1.size());
for (const auto& t : tensors1) {
vec_res.emplace_back(at::empty_like(t));
}

xpu::foreach_lerp_scalarlist_kernel(tensors1, tensors2, scalars, vec_res);
return vec_res;
}

void foreach_tensor_lerp_scalarlist_xpu_(
TensorList tensors1,
TensorList tensors2,
at::ArrayRef<Scalar> scalars) {
check_foreach_api_restrictions(tensors1, tensors2, scalars);
if (!can_use_fast_route({tensors1, tensors2}, scalars, true)) {
return foreach_tensor_lerp_scalarlist_kernel_slow_(
tensors1, tensors2, scalars);
}

xpu::foreach_lerp_scalarlist_kernel_(tensors1, tensors2, scalars);
}

}; // namespace native
} // namespace at
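
As a point of reference for the fast-route/slow-path split above, a minimal sketch of the semantics the ScalarList lerp kernel implements (illustration only, not the repository's slow-path code): each output is lerp(a, b, w) = a + w * (b - a), evaluated tensor by tensor with that tensor's own scalar weight.

#include <ATen/ATen.h>
#include <vector>

// Illustrative reference for _foreach_lerp.ScalarList; not the actual slow path.
std::vector<at::Tensor> lerp_scalarlist_reference(
    at::TensorList tensors1,
    at::TensorList tensors2,
    at::ArrayRef<at::Scalar> scalars) {
  std::vector<at::Tensor> out;
  out.reserve(tensors1.size());
  for (size_t i = 0; i < tensors1.size(); ++i) {
    // at::lerp(self, end, weight) = self + weight * (end - self)
    out.push_back(at::lerp(tensors1[i], tensors2[i], scalars[i]));
  }
  return out;
}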
19 changes: 19 additions & 0 deletions src/ATen/native/xpu/ForeachReduceOp.cpp
@@ -1,6 +1,7 @@
#include <ATen/native/ForeachUtils.h>

#include <ATen/native/xpu/sycl/ForeachReduceKernels.h>
#include <xpu/ATen/ops/_foreach_max_native.h>
#include <xpu/ATen/ops/_foreach_norm_native.h>

namespace at {
@@ -70,5 +71,23 @@ std::vector<Tensor> foreach_tensor_norm_xpu(

return native::xpu::foreach_norm_kernel(tensors, ord, p, dtype);
}

std::vector<Tensor> foreach_tensor_max_xpu(TensorList tensors) {
check_foreach_api_restrictions(tensors);
if (!can_use_fast_route(tensors)) {
return foreach_tensor_max_slow(tensors);
}

// for parity with max in ReduceAllOps.cpp, as max(empty) is ???
TORCH_CHECK(
std::all_of(
tensors.begin(),
tensors.end(),
[](const auto& t) { return t.numel() > 0; }),
"max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.");

return at::native::xpu::foreach_max_kernel(tensors);
}

} // namespace native
} // namespace at
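
For clarity, a sketch of the per-tensor semantics the fused kernel provides (illustration only, not the actual kernel): one 0-dim result per input, with empty inputs rejected just as the TORCH_CHECK above does.

#include <ATen/ATen.h>
#include <vector>

// Illustrative reference for _foreach_max; not the fused XPU implementation.
std::vector<at::Tensor> foreach_max_reference(at::TensorList tensors) {
  std::vector<at::Tensor> out;
  out.reserve(tensors.size());
  for (const at::Tensor& t : tensors) {
    // Global max of each tensor as a 0-dim tensor; an empty tensor throws,
    // matching the check performed before dispatching to the fused kernel.
    out.push_back(t.max());
  }
  return out;
}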
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/ForeachUnaryOp.cpp
@@ -25,6 +25,7 @@
#include <ATen/ops/_foreach_neg_native.h>
#include <ATen/ops/_foreach_reciprocal_native.h>
#include <ATen/ops/_foreach_round_native.h>
#include <ATen/ops/_foreach_rsqrt_native.h>
#include <ATen/ops/_foreach_sigmoid_native.h>
#include <ATen/ops/_foreach_sign_native.h>
#include <ATen/ops/_foreach_sin_native.h>
@@ -93,6 +94,7 @@ FOREACH_UNARY_OP(round);
FOREACH_UNARY_OP(frac);
FOREACH_UNARY_OP(reciprocal);
FOREACH_UNARY_OP(sign);
FOREACH_UNARY_OP(rsqrt);

std::vector<Tensor> foreach_tensor_neg_xpu(TensorList tensors) {
at::native::check_foreach_api_restrictions(tensors);
71 changes: 71 additions & 0 deletions src/ATen/native/xpu/sycl/ForeachFunctors.h
@@ -834,6 +834,77 @@ struct TernaryOpScalarFunctor {
}
};

template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpScalarListFunctor {
using opmath_t = at::opmath_type<T>;

template <typename TLA, typename TLW, typename Op>
void operator()(
int chunk_size,
TLA tlAddress,
TLW tlWGMeta,
sycl::nd_item<1> item_id,
Op op) const {
static_assert(depth == 2 || depth == 3, "");
static_assert(depth >= r_args_depth, "");
static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
auto item_idx = item_id.get_local_id(0);
auto item_range = item_id.get_local_range(0);
auto group_idx = item_id.get_group(0);
int tensor_loc = tlWGMeta[group_idx].wg_to_tensor;
int chunk_idx = tlWGMeta[group_idx].wg_to_chunk;
int64_t n = tlAddress[tensor_loc].numel_to_tensor;

T* args[depth];
const bool all_aligned =
init_args<depth>(args, tlAddress, chunk_idx, chunk_size, tensor_loc);
n -= chunk_idx * chunk_size;
T r_args[r_args_depth][kILP];
const opmath_t scalar = tlAddress[tensor_loc].scalar_vals;

// to make things simple, we put aligned case in a different code path
if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
for (int64_t i_start = item_idx;
i_start * kILP < n && i_start * kILP < chunk_size;
i_start += item_range) {
// load
load_store(r_args[0], args[0], 0, i_start);
load_store(r_args[1], args[1], 0, i_start);
#pragma unroll
for (int ii = 0; ii < kILP; ii++) {
r_args[0][ii] =
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(r_args[1][ii]),
scalar);
}
// store
load_store(args[res_arg_index], r_args[0], i_start, 0);
}
} else {
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
i_start += item_range * kILP) {
load_args<r_args_depth>(
r_args, args, i_start, chunk_size, n, item_idx, item_range);
#pragma unroll
for (int ii = 0; ii < kILP; ii++) {
r_args[0][ii] =
op(static_cast<opmath_t>(r_args[0][ii]),
static_cast<opmath_t>(r_args[1][ii]),
scalar);
}
store_args(
args[res_arg_index],
r_args[0],
i_start,
chunk_size,
n,
item_idx,
item_range);
}
}
}
};
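
The Op parameter is the elementwise ternary functor supplied by the kernel that launches this struct; for the ScalarList lerp path it would be along the lines of the hypothetical sketch below (an assumption for illustration; the actual functor lives in the kernel implementation files, not here), operating on the two loaded operands and the per-tensor scalar in opmath precision.

// Hypothetical Op for the ScalarList lerp path (assumption, for illustration):
// takes the two loaded values plus the per-tensor scalar, all as opmath_t, and
// returns the value stored back through res_arg_index.
template <typename opmath_t>
struct LerpScalarOp {
  opmath_t operator()(opmath_t a, opmath_t b, opmath_t weight) const {
    return a + weight * (b - a); // lerp(a, b, w)
  }
};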

template <typename T>
struct power_functor {
T operator()(const T& a, const T& b) const {
