add prod
huaiyuzh committed Jul 18, 2024
1 parent 4372ca7 commit 2e064ef
Showing 6 changed files with 114 additions and 12 deletions.
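
This commit wires up three prod overloads for the XPU backend (prod, prod.dim_int, prod.int_out; see the yaml section at the end). A rough usage sketch of what that enables, assuming a PyTorch build where the "xpu" device is available; the tensor values and the device string are illustrative, not part of the commit:

import torch

# Assumes an XPU-enabled build; otherwise swap "xpu" for "cpu".
x = torch.tensor([[1., 2.], [3., 4.]], device="xpu")

torch.prod(x)                   # prod: full reduction -> tensor(24.)
torch.prod(x, dim=1)            # prod.dim_int -> tensor([ 2., 12.])
out = torch.empty(2, device="xpu")
torch.prod(x, dim=1, out=out)   # prod.int_out: writes into `out`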
68 changes: 58 additions & 10 deletions src/ATen/native/xpu/ReduceOps.cpp
@@ -51,17 +51,14 @@ static void cum_ops_meta(
  if (result.defined()) {
    out_dtype = dtype.value_or(result.scalar_type());
    at::xpu::resize_out(
-       result,
-       self.sizes(),
-       {},
-       self.options().dtype(out_dtype));
+       result, self.sizes(), {}, self.options().dtype(out_dtype));
  } else {
-   auto is_integral = at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
-   out_dtype = dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
-   result = at::xpu::create_out(
-       self.sizes(),
-       {},
-       self.options().dtype(out_dtype));
+   auto is_integral =
+       at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
+   out_dtype =
+       dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
+   result =
+       at::xpu::create_out(self.sizes(), {}, self.options().dtype(out_dtype));
  }

  namedinference::propagate_names(result, self);
@@ -260,6 +257,57 @@ Tensor XPUNativeFunctions::mean(
  return out;
}

Tensor& prod_meta(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    std::optional<ScalarType> dtype,
    Tensor& out) {
  auto out_dtype = infer_dtype_from_optional(self, dtype, out);
  out = resize_reduction(out, self, dim, keepdim, out_dtype);
  return out;
}

Tensor& XPUNativeFunctions::prod_out(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    std::optional<ScalarType> dtype,
    Tensor& result) {
  result = prod_meta(self, dim, keepdim, dtype, result);
  // device is not CPU
  auto iter = at::meta::make_reduction_from_out_ty(
      self, result, dim, keepdim, result.scalar_type());
  if (iter.numel() == 0) {
    result.fill_(1);
  } else {
    native::xpu::prod_kernel(iter);
  }
  return result;
}

Tensor XPUNativeFunctions::prod(
    const Tensor& self,
    std::optional<ScalarType> opt_dtype) {
  auto dtype = at::native::get_dtype_from_self(self, opt_dtype, true);
  // auto shape = at::native::meta::get_reduction_shape(self, {}, false);
  // Tensor out = at::empty(shape, self.options().dtype(dtype));
  Tensor out;
  out = prod_meta(self, {}, false, dtype, out);
  return out;
}

Tensor XPUNativeFunctions::prod(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    std::optional<ScalarType> dtype) {
  Tensor out;
  out = prod_meta(self, dim, keepdim, dtype, out);
  out = XPUNativeFunctions::prod_out(self, dim, keepdim, dtype, out);
  return out;
}

inline TensorIterator get_allany_iter(
    const Tensor& self,
    const Tensor& result,
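A note on the default output dtype handled by infer_dtype_from_optional / get_dtype_from_self above: when no dtype is passed, integral (and bool) inputs are reduced in int64, mirroring stock PyTorch behavior for sum/prod. A small sketch of that behavior, shown with CPU tensors for illustration; the XPU path added here is expected to match:

import torch

i = torch.tensor([2, 3, 4], dtype=torch.int32)
print(torch.prod(i).dtype)                       # torch.int64 (promoted)
print(torch.prod(i, dtype=torch.float32).dtype)  # torch.float32 (explicit dtype wins)

b = torch.tensor([True, True, False])
print(torch.prod(b).dtype)                       # torch.int64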
2 changes: 0 additions & 2 deletions src/ATen/native/xpu/XPUFallback.template
@@ -274,8 +274,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"polygamma.out",
"_prelu_kernel",
"_prelu_kernel_backward",
"prod",
"prod.int_out",
"put_",
"renorm.out",
"repeat_interleave.Tensor",
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/ReduceOpsKernels.h
@@ -14,6 +14,8 @@ void mean_kernel(TensorIterator& iter);

void sum_kernel(TensorIterator& iter);

void prod_kernel(TensorIterator& iter);

void std_var_kernel(TensorIterator& iter, double correction, bool take_sqrt);

} // namespace at::native::xpu
50 changes: 50 additions & 0 deletions src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
@@ -54,6 +54,56 @@ void sum_kernel(TensorIterator& iter) {
      });
}

template <typename acc_t>
struct ProdFunctor {
  inline acc_t operator()(acc_t a, acc_t b) const {
    return a * b;
  }
};

template <>
struct ProdFunctor<bool> {
  inline bool operator()(bool a, bool b) const {
    return a && b;
  }
};

template <
    typename scalar_t,
    typename acc_t = scalar_t,
    typename out_t = scalar_t>
struct prod_functor {
  void operator()(TensorIterator& iter) {
    gpu_reduce_kernel<scalar_t, out_t>(
        iter, func_wrapper<out_t>(ProdFunctor<acc_t>()), 1.);
  }
};

template <>
struct prod_functor<bool> {
  void operator()(TensorIterator& iter) {
    gpu_reduce_kernel<bool, bool>(
        iter, func_wrapper<bool>(ProdFunctor<bool>()), 1);
  }
};

template <>
struct prod_functor<c10::complex<at::Half>> {
  void operator()(TensorIterator& iter) {
    using scalar_t = c10::complex<at::Half>;
    using acc_t = at::opmath_type<scalar_t>;
    gpu_reduce_kernel<scalar_t, scalar_t>(
        iter, func_wrapper<scalar_t>(ProdFunctor<acc_t>()), acc_t{1.});
  }
};

void prod_kernel(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      kComplexHalf, kBool, iter.dtype(), "prod_xpu", [&]() {
        prod_functor<scalar_t>{}(iter);
      });
}

} // namespace xpu
} // namespace native
} // namespace at
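
The kernel above is a plain multiply-reduction seeded with the identity 1 (logical AND for bool, opmath accumulation for complex half). A standalone Python sketch of that contract, not the kernel code itself:

from functools import reduce

def prod_reduce(values, identity=1):
    # Combine with multiplication, starting from the identity,
    # so an empty input yields 1 (cf. result.fill_(1) in prod_out).
    return reduce(lambda a, b: a * b, values, identity)

def bool_prod_reduce(values):
    # Mirrors ProdFunctor<bool>: logical AND instead of multiplication.
    return reduce(lambda a, b: a and b, values, True)

print(prod_reduce([2, 3, 4]))           # 24
print(prod_reduce([]))                  # 1
print(bool_prod_reduce([True, True]))   # True
print(bool_prod_reduce([True, False]))  # False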
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
@@ -124,6 +124,7 @@
"atanh",
"sqrt",
"sum",
"prod",
"amin",
"amax",
"std",
3 changes: 3 additions & 0 deletions yaml/xpu_functions.yaml
@@ -296,6 +296,9 @@ supported:
  - min.dim_min
  - sum.dim_IntList
  - sum.IntList_out
  - prod
  - prod.int_out
  - prod.dim_int
  - mean.out
  - mean.dim
  - std.correction
