Add aten::_foreach_copy_ (#985)
Add the implementation of the operator _foreach_copy_.
Tested on PVC.

---------

Co-authored-by: Yutao Xu <[email protected]>
cfgfung and xytintel authored Oct 28, 2024
1 parent b189259 commit f69c52f
Showing 4 changed files with 88 additions and 0 deletions.
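
For readers landing here from the op name: `_foreach_copy_` is the batched, in-place counterpart of `Tensor.copy_` — it copies each tensor in `src` into the matching tensor in `self`, ideally in a single fused kernel launch. A minimal usage sketch (the device selection is illustrative; the kernel added here targets XPU, but the op also runs elsewhere via the slow path):

```python
import torch

# Pick XPU when this build exposes it; otherwise fall back to CPU (slow path).
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

dst = [torch.zeros(4, device=device), torch.zeros(2, 3, device=device)]
src = [torch.arange(4.0, device=device), torch.ones(2, 3, device=device)]

torch._foreach_copy_(dst, src)  # in place: dst[i] now holds src[i]'s values
print(dst[0])  # -> [0., 1., 2., 3.]
```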
22 changes: 22 additions & 0 deletions src/ATen/native/xpu/ForeachOpList.cpp
@@ -7,11 +7,14 @@
#include <ATen/ops/_foreach_div_native.h>
#include <ATen/ops/_foreach_lerp_native.h>
#include <ATen/ops/_foreach_mul_native.h>
#include <ATen/ops/_foreach_clamp_min_native.h>
#include <ATen/ops/_foreach_copy_native.h>
#include <ATen/ops/_foreach_pow_native.h>

#include <ATen/native/xpu/sycl/ForeachBinaryOpListKernels.h>
#include <ATen/native/xpu/sycl/ForeachPointwiseOpListKernels.h>
#include <ATen/native/xpu/sycl/ForeachTernaryOpListKernels.h>
#include <ATen/native/xpu/sycl/ForeachCopyKernels.h>

#include <ATen/ops/empty_like.h>

@@ -147,5 +150,24 @@ void foreach_tensor_lerp_ternary_xpu_(
}
}

void foreach_tensor_copy_list_kernel_xpu_(
    TensorList self,
    TensorList src,
    bool non_blocking) {
  check_foreach_api_restrictions(self, src);
  if (!can_use_fast_route(
          self, src, /* does_op_promote_integer_inputs_to_float */ false)) {
    return foreach_tensor_copy_list_kernel_slow_(self, src, non_blocking);
  }

  xpu::foreach_copy_list_kernel_(self, src);

  // Bump version counters by hand: the fused kernel mutates the tensors
  // in place but bypasses the dispatcher bookkeeping that would do this.
  for (const auto& t : self) {
    t.unsafeGetTensorImpl()->bump_version();
  }
}

} // namespace native
} // namespace at
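
The wrapper above follows the pattern of the other foreach ops in this file: `check_foreach_api_restrictions` validates that the lists match up, `can_use_fast_route` bails out to the per-tensor slow path when dtypes, layouts, or devices rule out the fused kernel, and the version counters are bumped manually because the raw SYCL kernel skips the dispatcher bookkeeping that normally does it. That bump is what lets autograd detect the in-place mutation; it is observable from Python through the internal `_version` attribute. A small sketch (CPU is fine here — the slow path bumps versions too, via `copy_`):

```python
import torch

dst = [torch.zeros(3), torch.zeros(3)]
src = [torch.ones(3), torch.full((3,), 2.0)]

before = [t._version for t in dst]
torch._foreach_copy_(dst, src)
after = [t._version for t in dst]

# Each destination was written in place, so its version counter advanced.
assert all(a > b for a, b in zip(after, before))
```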
42 changes: 42 additions & 0 deletions src/ATen/native/xpu/sycl/ForeachCopyKernels.cpp
@@ -0,0 +1,42 @@
#include <ATen/Dispatch.h>
#include <comm/xpu_aten.h>

#include <ATen/native/xpu/sycl/ForeachFunctors.h>
#include <ATen/native/xpu/sycl/MultiTensorApply.h>

#include <ATen/native/xpu/sycl/ForeachCopyKernels.h>


namespace at::native::xpu {

// Pass-through functor: copying is just the unary "identity" op.
template <typename T>
struct Identity {
  T operator()(const T& x) {
    return x;
  }
};

void foreach_copy_list_kernel_(
    TensorList self,
    TensorList src) {
  // src first, self second: res_arg_index 1 below writes into self.
  std::vector<std::vector<at::Tensor>> tensor_lists{src.vec(), self.vec()};

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      at::ScalarType::Bool,
      self[0].scalar_type(),
      "foreach_tensor_copy",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        multi_tensor_apply<2>(
            tensor_lists,
            UnaryOpFunctor<
                scalar_t,
                /* depth */ 2,
                /* r_args_depth */ 1,
                /* res_arg_index */ 1>(),
            Identity<opmath_t>());
      });
}

} // namespace at::native::xpu
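
The kernel itself is a thin reuse of the generic foreach machinery: `multi_tensor_apply<2>` walks two tensor lists in fixed-size chunks (note that `tensor_lists` puts `src` first and `self` second, matching `res_arg_index = 1`), and `Identity` turns the generic unary op into a plain element copy. Semantically the whole launch is equivalent to a pairwise `copy_` loop, just batched into far fewer kernel launches. A reference rendering of that equivalence (the names below are mine, not from the patch):

```python
import torch

def foreach_copy_reference(dst_list, src_list):
    # What the fused kernel computes, one pair at a time:
    # element-wise identity, written into the destination.
    for d, s in zip(dst_list, src_list):
        d.copy_(s)

src = [torch.randn(5), torch.randn(7)]
dst_ref = [torch.empty(5), torch.empty(7)]
dst_fused = [torch.empty(5), torch.empty(7)]

foreach_copy_reference(dst_ref, src)
torch._foreach_copy_(dst_fused, src)  # the batched op, same result
assert all(torch.equal(a, b) for a, b in zip(dst_ref, dst_fused))
```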
10 changes: 10 additions & 0 deletions src/ATen/native/xpu/sycl/ForeachCopyKernels.h
@@ -0,0 +1,10 @@
#pragma once
#include <comm/xpu_aten.h>

namespace at::native::xpu {

TORCH_XPU_API void foreach_copy_list_kernel_(
    TensorList self,
    TensorList src);

} // namespace at::native::xpu
14 changes: 14 additions & 0 deletions yaml/native/native_functions.yaml
@@ -2749,6 +2749,20 @@
    XPU: foreach_tensor_zero_xpu_
  autogen: _foreach_zero, _foreach_zero.out

- func: _foreach_copy_(Tensor(a!)[] self, Tensor[] src, bool non_blocking=False) -> ()
  device_check: NoCheck # foreach kernels fall back to the slow path when tensors are on different devices
  variants: function
  dispatch:
    CompositeExplicitAutograd: foreach_tensor_copy_list_kernel_slow_
    XPU: foreach_tensor_copy_list_kernel_xpu_
  autogen: _foreach_copy.out

- func: _foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out
  device_check: NoCheck
  variants: function
  dispatch:
    CompositeExplicitAutograd: _foreach_copy

- func: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
  dispatch:
    XPU: layer_norm_xpu
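
The YAML entries wire everything up: the in-place `_foreach_copy_` dispatches to the new XPU kernel (with the slow composite kernel as the generic fallback), `autogen: _foreach_copy.out` derives the out variant, and the functional `_foreach_copy` returns fresh tensors instead of mutating `self`. Assuming your PyTorch build exposes the functional variant, the distinction looks like this:

```python
import torch

src = [torch.arange(3.0), torch.ones(2)]
dst = [torch.zeros(3), torch.zeros(2)]

outs = torch._foreach_copy(dst, src)  # functional: dst is left untouched
print([torch.equal(o, s) for o, s in zip(outs, src)])  # [True, True]
print(dst[0])  # still tensor([0., 0., 0.])
```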