diff --git a/src/ATen/CMakeLists.txt b/src/ATen/CMakeLists.txt
index af9ef7d94..16b94e386 100644
--- a/src/ATen/CMakeLists.txt
+++ b/src/ATen/CMakeLists.txt
@@ -2,8 +2,8 @@
 file(GLOB xpu_h "xpu/*.h")
 file(GLOB xpu_cpp "xpu/*.cpp")
-file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp")
-file(GLOB xpu_sycl "native/xpu/sycl/*.cpp" "native/sparse/xpu/sycl/*.cpp" "native/transformers/sycl/*.cpp" "native/quantized/sycl/*.cpp")
+file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp")
+file(GLOB xpu_sycl "native/xpu/sycl/*.cpp" "native/sparse/xpu/sycl/*.cpp" "native/nested/xpu/sycl/*.cpp" "native/transformers/sycl/*.cpp" "native/quantized/sycl/*.cpp")
 list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
 list(APPEND ATen_XPU_NATIVE_CPP_SRCS ${xpu_native_cpp})
diff --git a/src/ATen/native/nested/xpu/NestedTensorTransformerFunctions.cpp b/src/ATen/native/nested/xpu/NestedTensorTransformerFunctions.cpp
new file mode 100644
index 000000000..52e8aa79b
--- /dev/null
+++ b/src/ATen/native/nested/xpu/NestedTensorTransformerFunctions.cpp
@@ -0,0 +1,208 @@
+#include
+#include
+#include
+#include
+#include
+
+namespace at::native {
+
+namespace {
+
+int64_t padded_tensor_numel(const Tensor& sizes) {
+  const auto sizes_num_rows = sizes.sizes()[0];
+  const auto sizes_row_length = sizes.sizes()[1];
+  const auto* sizes_data = sizes.data_ptr<int64_t>();
+  int64_t numel = 0;
+  for (const auto row_num : c10::irange(sizes_num_rows)) {
+    const auto* row_ptr = sizes_data + row_num * sizes_row_length;
+    int64_t prod = 1;
+    for (const auto idx : c10::irange(sizes_row_length)) {
+      prod *= row_ptr[idx];
+    }
+    numel += prod;
+  }
+  return numel;
+}
+
+} // namespace
+
+Tensor nested_from_padded_xpu(
+    const Tensor& padded,
+    const Tensor& sizes,
+    bool do_transform_0213) {
+  if (padded.dim() > 1 && padded.dim() < 5) {
+    // Instead of erroring, call the generic version
+    if (!(padded.dim() == 4 && do_transform_0213) &&
+        !(padded.dim() == 3 && !do_transform_0213)) {
+      return at::native::nested_from_padded_generic(
+          padded, sizes, do_transform_0213);
+    }
+    if (padded.dtype() != at::kFloat && padded.dtype() != kHalf) {
+      TORCH_WARN_ONCE(
+          "nested_from_padded XPU kernels only support fp32/fp16; falling "
+          "back to slower generic kernel");
+      return at::native::nested_from_padded_generic(
+          padded, sizes, do_transform_0213);
+    }
+    Tensor target_offsets =
+        at::native::NestedTensor_batch_offsets_from_size_tensor(sizes, 0);
+    Tensor padded_sizes_tensor = at::tensor(padded.sizes());
+    Tensor output = at::empty({padded_tensor_numel(sizes)}, padded.options());
+    Tensor target_size_sizes = sizes.reshape(-1);
+
+    Tensor metadata =
+        at::cat({target_size_sizes, padded_sizes_tensor, target_offsets});
+    metadata = metadata.to(at::Device(kXPU), kInt, true, true);
+
+    auto output_size_ptr = metadata.data_ptr<int>();
+    auto input_size_ptr = output_size_ptr + target_size_sizes.numel();
+    auto offsets_ptr = input_size_ptr + padded_sizes_tensor.numel();
+
+    Tensor padded_contiguous = padded.contiguous();
+    if (padded.dtype() == at::kFloat) {
+      if (do_transform_0213) {
+        xpu::remove_padding_transform0213_kernel_float(
+            padded_contiguous.data_ptr<float>(),
+            output.data_ptr<float>(),
+            offsets_ptr,
+            input_size_ptr,
+            output_size_ptr,
+            padded_contiguous.dim() - 2,
+            padded_contiguous.sizes()[0]);
+      } else {
+        xpu::remove_padding_kernel_float(
+            padded_contiguous.data_ptr<float>(),
+            output.data_ptr<float>(),
+            offsets_ptr,
+            input_size_ptr,
+            output_size_ptr,
+            padded_contiguous.dim() - 1,
+            padded_contiguous.sizes()[0]);
+      }
+    } else if (padded.dtype() == at::kHalf) {
+      if (do_transform_0213) {
+        xpu::remove_padding_transform0213_kernel_half(
+            padded_contiguous.data_ptr<c10::Half>(),
+            output.data_ptr<c10::Half>(),
+            offsets_ptr,
+            input_size_ptr,
+            output_size_ptr,
+            padded_contiguous.dim() - 2,
+            padded_contiguous.sizes()[0]);
+      } else {
+        xpu::remove_padding_kernel_half(
+            padded_contiguous.data_ptr<c10::Half>(),
+            output.data_ptr<c10::Half>(),
+            offsets_ptr,
+            input_size_ptr,
+            output_size_ptr,
+            padded_contiguous.dim() - 1,
+            padded_contiguous.sizes()[0]);
+      }
+    } else {
+      TORCH_CHECK(false, "Only support fp32/fp16 for padded input");
+    }
+    return at::detail::make_tensor<NestedTensorImpl>(
+        std::move(output), sizes);
+  } else {
+    return at::native::nested_from_padded_generic(padded, sizes);
+  }
+}
+
+static Tensor batch_offsets_from_efficient_size(const Tensor& ef_sizes) {
+  int64_t* nt_sizes_ptr = ef_sizes.data_ptr<int64_t>();
+  int64_t ef_sizes_size_0 = ef_sizes.sizes()[0];
+  Tensor offsets = at::empty({1 + ef_sizes_size_0}, at::kLong);
+  int64_t* offsets_ptr = offsets.mutable_data_ptr<int64_t>();
+  offsets_ptr[0] = 0;
+  int64_t ef_sizes_size_1 = ef_sizes.sizes()[1];
+  for (const auto i : c10::irange(ef_sizes_size_0)) {
+    int64_t prod = 1;
+    for (const auto j : c10::irange(ef_sizes_size_1)) {
+      prod = prod * nt_sizes_ptr[i * ef_sizes_size_1 + j];
+    }
+    offsets_ptr[i + 1] = offsets_ptr[i] + prod;
+  }
+  return offsets;
+}
+
+Tensor NestedTensor_to_padded_tensor_xpu(
+    const Tensor& t,
+    double padding,
+    OptionalIntArrayRef output_size) {
+  TORCH_CHECK(
+      t.numel() > 0,
+      "to_padded_tensor: at least one constituent tensor should have non-zero numel")
+  int64_t t_dim = t.dim();
+  if (t_dim >= 2 && t_dim <= 4 &&
+      (t.dtype() == at::kFloat || t.dtype() == at::kDouble ||
+       t.dtype() == at::kHalf)) {
+    auto* nt_input = get_nested_tensor_impl(t);
+    TORCH_CHECK(
+        nested_tensor_impl_is_contiguous(nt_input),
+        "for now to_padded_tensor only supports contiguous nested tensor");
+    const auto& nt_buffer = nt_input->get_buffer();
+
+    if (t_dim == 3 && nt_input->opt_size(2) && (*nt_input->opt_size(2) > 0) &&
+        !(output_size.has_value())) {
+      Tensor nt_sizes = nt_input->get_nested_sizes();
+      Tensor sizes_dim1 = at::native::narrow_symint(nt_sizes, 1, 0, 1);
+      Tensor sizes_dim2 = at::native::narrow_symint(nt_sizes, 1, 1, 1);
+      Tensor result = at::detail::make_tensor<NestedTensorImpl>(
+          nt_input->get_buffer(), sizes_dim1 * sizes_dim2[0]);
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.dim() == 2);
+      result = NestedTensor_to_padded_tensor_xpu(result, padding, output_size);
+      return result.reshape({result.sizes()[0], -1, *nt_input->opt_size(2)});
+    }
+
+    Tensor nt_sizes = nt_input->get_nested_sizes();
+    Tensor offsets = batch_offsets_from_efficient_size(nt_sizes);
+    auto new_size = NestedTensor_get_max_size(*nt_input);
+    new_size.insert(new_size.begin(), nt_sizes.sizes()[0]);
+
+    // Pad output tensor to output_size if provided
+    if (output_size.has_value()) {
+      auto output_size_ = output_size.value();
+      TORCH_CHECK(
+          output_size_.size() == new_size.size(),
+          "Length of output_size does not match NestedTensor dims. Broadcasting is not supported.");
+      for (uint64_t i = 0; i < new_size.size(); i++) {
+        TORCH_CHECK(
+            output_size_[i] >= new_size[i],
+            "Value in output_size is less than NestedTensor padded size. "
+            "Truncation is not supported.");
+        new_size[i] = output_size_[i];
+      }
+    }
+
+    Tensor output = at::empty(IntArrayRef(new_size), nt_buffer.options());
+
+    int64_t input_dim = nt_sizes.sizes()[1];
+    int64_t batch_size = nt_sizes.sizes()[0];
+    int64_t output_batch_size = new_size[0];
+    // TODO: Remove need for cat here
+    at::Tensor metadata = at::cat({offsets, nt_sizes.reshape(-1)});
+    metadata = metadata.to(at::Device(kXPU), at::kInt);
+
+    std::vector<Tensor> split =
+        at::split_with_sizes(metadata, {offsets.numel(), nt_sizes.numel()}, 0);
+
+    offsets = split[0];
+    nt_sizes = split[1];
+
+    xpu::add_padding_kernel(
+        nt_buffer,
+        output,
+        padding,
+        offsets,
+        nt_sizes,
+        input_dim,
+        new_size,
+        batch_size,
+        output_batch_size);
+
+    return output;
+  }
+  return NestedTensor_to_padded_tensor_generic(t, padding, output_size);
+}
+
+} // namespace at::native
diff --git a/src/ATen/native/nested/xpu/sycl/NestedTensorTransformerFunctionKernels.cpp b/src/ATen/native/nested/xpu/sycl/NestedTensorTransformerFunctionKernels.cpp
new file mode 100644
index 000000000..9d1ec1ae2
--- /dev/null
+++ b/src/ATen/native/nested/xpu/sycl/NestedTensorTransformerFunctionKernels.cpp
@@ -0,0 +1,616 @@
+#include
+#include
+#include
+
+// Keep the launch geometry aligned with the CUDA kernels: dim 1 of the global
+// range launches one work-group per (output) batch element, while dim 0 is
+// fixed at GRID_DIM_Y = 16 work-groups.
+#define GRID_DIM_Y 16
+#define BLOCK_DIM 256
+
+namespace at::native::xpu {
+
+template <typename T>
+struct RemovePaddingFunctor {
+  void operator()(sycl::nd_item<2> item) const {
+    const int batch_id = item.get_group(1);
+    const int grid_id = item.get_group(0);
+    const int tid = item.get_local_id(1) + grid_id * BLOCK_DIM;
+    const int grainsize = GRID_DIM_Y * BLOCK_DIM;
+    const int offset = offsets_[batch_id];
+    const int* sizes_i = output_sizes_ + batch_id * output_dim_;
+    const int numel_i = sizes_i[0] * sizes_i[1] * sizes_i[2];
+    int input_offset =
+        batch_id * input_sizes_[1] * input_sizes_[2] * input_sizes_[3];
+    for (int ii = 0; ii < (numel_i / grainsize); ii++) {
+      const int i = ii * grainsize + tid;
+      const int i0 = i / (sizes_i[1] * sizes_i[2]);
+      const int i1 = (i % (sizes_i[1] * sizes_i[2])) / sizes_i[2];
+      const int i2 = i % sizes_i[2];
+      const int i0_offset = i0 * input_sizes_[2] * input_sizes_[3];
+      const int i1_offset = i1 * input_sizes_[3];
+      output_[offset + i] = input_[input_offset + i0_offset + i1_offset + i2];
+    }
+    const int i = (numel_i / grainsize) * grainsize + tid;
+    if (i < numel_i) {
+      const int i0 = i / (sizes_i[1] * sizes_i[2]);
+      const int i1 = (i % (sizes_i[1] * sizes_i[2])) / sizes_i[2];
+      const int i2 = i % sizes_i[2];
+      const int i0_offset = i0 * input_sizes_[2] * input_sizes_[3];
+      const int i1_offset = i1 * input_sizes_[3];
+      output_[offset + i] = input_[input_offset + i0_offset + i1_offset + i2];
+    }
+  }
+
+  RemovePaddingFunctor(
+      const T* input,
+      T* output,
+      const int* offsets,
+      const int* input_sizes,
+      const int* output_sizes,
+      int output_dim,
+      const int batch_size)
+      : input_(input),
+        output_(output),
+        offsets_(offsets),
+        input_sizes_(input_sizes),
+        output_sizes_(output_sizes),
+        output_dim_(output_dim),
+        batch_size_(batch_size) {}
+
+ private:
+  const T* input_;
+  T* output_;
+  const int* offsets_;
+  const int* input_sizes_;
+  const int* output_sizes_;
+  int output_dim_;
+  const int batch_size_;
+};
+
+template <typename T>
+struct RemovePadding2Functor {
+  void operator()(sycl::nd_item<2> item) const {
+    const int batch_id = item.get_group(1);
+    const int grid_id = item.get_group(0);
+    const int tid = item.get_local_id(1) + grid_id * BLOCK_DIM;
+    const int
grainsize = GRID_DIM_Y * BLOCK_DIM; + const int offset = offsets_[batch_id]; + const int* sizes_i = output_sizes_ + batch_id * output_dim_; + const int numel_i = sizes_i[0] * sizes_i[1]; + int input_offset = batch_id * input_sizes_[1] * input_sizes_[2]; + for (int ii = 0; ii < (numel_i / grainsize); ii++) { + const int i = ii * grainsize + tid; + const int i0 = i / sizes_i[1]; + const int i1 = i % sizes_i[1]; + const int i0_offset = i0 * input_sizes_[2]; + output_[offset + i] = input_[input_offset + i0_offset + i1]; + } + const int i = (numel_i / grainsize) * grainsize + tid; + if (i < numel_i) { + const int i0 = i / sizes_i[1]; + const int i1 = i % sizes_i[1]; + const int i0_offset = i0 * input_sizes_[2]; + output_[offset + i] = input_[input_offset + i0_offset + i1]; + } + } + + RemovePadding2Functor( + const T* input, + T* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int output_dim, + const int batch_size) + : input_(input), + output_(output), + offsets_(offsets), + input_sizes_(input_sizes), + output_sizes_(output_sizes), + output_dim_(output_dim), + batch_size_(batch_size) {} + + const T* input_; + T* output_; + const int* offsets_; + const int* input_sizes_; + const int* output_sizes_; + int output_dim_; + const int batch_size_; +}; + +template +struct RemovePaddingTransform0213Functor { + void operator()(sycl::nd_item<2> item) const { + const int batch_id = item.get_group(1); + const int grid_id = item.get_group(0); + const int tid = item.get_local_id(1) + grid_id * BLOCK_DIM; + const int grainsize = GRID_DIM_Y * BLOCK_DIM; + const int offset = offsets_[batch_id]; + const int* sizes_i = output_sizes_ + batch_id * output_dim_; + const int numel_i = sizes_i[0] * sizes_i[1]; + int input_offset = + batch_id * input_sizes_[1] * input_sizes_[2] * input_sizes_[3]; + for (int ii = 0; ii < (numel_i / grainsize); ii++) { + const int i = ii * grainsize + tid; + const int i2 = i / sizes_i[1]; + const int i13 = i % sizes_i[1]; + const int i1 = i13 / (sizes_i[1] / input_sizes_[1]); + const int i3 = i13 % (sizes_i[1] / input_sizes_[1]); + + output_[offset + i] = input_ + [input_offset + i1 * input_sizes_[2] * input_sizes_[3] + + i2 * input_sizes_[3] + i3]; + } + const int i = (numel_i / grainsize) * grainsize + tid; + if (i < numel_i) { + const int i2 = i / sizes_i[1]; + const int i13 = i % sizes_i[1]; + const int i1 = i13 / (sizes_i[1] / input_sizes_[1]); + const int i3 = i13 % (sizes_i[1] / input_sizes_[1]); + output_[offset + i] = input_ + [input_offset + i1 * input_sizes_[2] * input_sizes_[3] + + i2 * input_sizes_[3] + i3]; + } + } + + RemovePaddingTransform0213Functor( + const T* input, + T* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int output_dim, + const int batch_size) + : input_(input), + output_(output), + offsets_(offsets), + input_sizes_(input_sizes), + output_sizes_(output_sizes), + output_dim_(output_dim), + batch_size_(batch_size) {} + + const T* input_; + T* output_; + const int* offsets_; + const int* input_sizes_; + const int* output_sizes_; + int output_dim_; + const int batch_size_; +}; + +template +void remove_padding_kernel( + const T* input, + T* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int64_t output_dim, + const int64_t batch_size) { + auto queue = getCurrentSYCLQueue(); + if (output_dim == 2) { + auto kfn = RemovePadding2Functor( + input, + output, + offsets, + input_sizes, + output_sizes, + output_dim, + batch_size); + int64_t max_wg_size = 
syclMaxWorkGroupSize(kfn); + sycl::range<2> global_range(GRID_DIM_Y, batch_size * max_wg_size); + sycl::range<2> local_range(1, max_wg_size); + sycl_kernel_submit(global_range, local_range, queue, kfn); + } else { + auto kfn = RemovePaddingFunctor( + input, + output, + offsets, + input_sizes, + output_sizes, + output_dim, + batch_size); + int64_t max_wg_size = syclMaxWorkGroupSize(kfn); + sycl::range<2> global_range(GRID_DIM_Y, batch_size * max_wg_size); + sycl::range<2> local_range(1, max_wg_size); + sycl_kernel_submit(global_range, local_range, queue, kfn); + } +} + +template +void remove_padding_transform0213_kernel( + const T* input, + T* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int64_t output_dim, + const int64_t batch_size) { + TORCH_CHECK( + output_dim == 2, + "remove padding transform0213 only support output dim == 2"); + + auto queue = getCurrentSYCLQueue(); + auto kfn = RemovePaddingTransform0213Functor( + input, + output, + offsets, + input_sizes, + output_sizes, + output_dim, + batch_size); + + int64_t max_wg_size = syclMaxWorkGroupSize(kfn); + sycl::range<2> global_range(GRID_DIM_Y, batch_size * max_wg_size); + sycl::range<2> local_range(1, max_wg_size); + + sycl_kernel_submit(global_range, local_range, queue, kfn); +} + +void remove_padding_kernel_float( + const float* input, + float* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int64_t output_dim, + const int64_t batch_size) { + remove_padding_kernel( + input, + output, + offsets, + input_sizes, + output_sizes, + output_dim, + batch_size); +} + +void remove_padding_kernel_half( + const c10::Half* input, + c10::Half* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int64_t output_dim, + const int64_t batch_size) { + remove_padding_kernel( + input, + output, + offsets, + input_sizes, + output_sizes, + output_dim, + batch_size); +} + +void remove_padding_transform0213_kernel_float( + const float* input, + float* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int64_t output_dim, + const int64_t batch_size) { + remove_padding_transform0213_kernel( + input, + output, + offsets, + input_sizes, + output_sizes, + output_dim, + batch_size); +} + +void remove_padding_transform0213_kernel_half( + const c10::Half* input, + c10::Half* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int64_t output_dim, + const int64_t batch_size) { + remove_padding_transform0213_kernel( + input, + output, + offsets, + input_sizes, + output_sizes, + output_dim, + batch_size); +} + +template +struct AddPadding1Functor { + void operator()(sycl::nd_item<2> item) const { + const int batch_id = item.get_group(1); + const int grid_id = item.get_group(0); + const int tid = item.get_local_id(1) + grid_id * BLOCK_DIM; + const int grainsize = GRID_DIM_Y * BLOCK_DIM; + const int* sizes_i = input_sizes_ + batch_id * input_dim_; + const int batch_output_offset = batch_id * output_sizes_1_; + for (int ii = 0; ii < (output_sizes_1_ / grainsize); ii++) { + const int i = ii * grainsize + tid; + const int output_offset = batch_output_offset + i; + if (batch_id < batch_size_ && i < sizes_i[0]) { + const int batch_input_offset = offsets_[batch_id]; + output_[output_offset] = input_[batch_input_offset + i]; + } else { + output_[output_offset] = padding_value_; + } + } + const int i = (output_sizes_1_ / grainsize) * grainsize + tid; + if (i < output_sizes_1_) { + const int output_offset 
= batch_output_offset + i; + if (batch_id < batch_size_ && (i < sizes_i[0])) { + const int batch_input_offset = offsets_[batch_id]; + output_[output_offset] = input_[batch_input_offset + i]; + } else { + output_[output_offset] = padding_value_; + } + } + } + AddPadding1Functor( + const T* input, + T* output, + T padding_value, + const int* offsets, + const int* input_sizes, + int input_dim, + int output_sizes_1, + const int batch_size) + : input_(input), + output_(output), + padding_value_(padding_value), + offsets_(offsets), + input_sizes_(input_sizes), + input_dim_(input_dim), + output_sizes_1_(output_sizes_1), + batch_size_(batch_size) {} + + private: + const T* input_; + T* output_; + T padding_value_; + const int* offsets_; + const int* input_sizes_; + int input_dim_; + int output_sizes_1_; + const int batch_size_; +}; + +template +struct AddPadding2Functor { + void operator()(sycl::nd_item<2> item) const { + const int batch_id = item.get_group(1); + const int grid_id = item.get_group(0); + const int tid = item.get_local_id(1) + grid_id * BLOCK_DIM; + const int grainsize = GRID_DIM_Y * BLOCK_DIM; + const int* sizes_i = input_sizes_ + batch_id * input_dim_; + const int output_offset = batch_id * output_sizes_1_ * output_sizes_2_; + const int output_numel = output_sizes_1_ * output_sizes_2_; + for (int ii = 0; ii < (output_numel / grainsize); ii++) { + const int i = ii * grainsize + tid; + const int i0 = i / (output_sizes_2_); + const int i1 = i - i0 * output_sizes_2_; + if (batch_id < batch_size_ && i0 < sizes_i[0] && i1 < sizes_i[1]) { + const int offset = offsets_[batch_id]; + const int input_offset = offset + i0 * sizes_i[1] + i1; + output_[output_offset + i] = input_[input_offset]; + } else { + output_[output_offset + i] = padding_value_; + } + } + const int i = (output_numel / grainsize) * grainsize + tid; + if (i < output_numel) { + const int i0 = i / (output_sizes_2_); + const int i1 = i - i0 * output_sizes_2_; + if (batch_id < batch_size_ && i0 < sizes_i[0] && i1 < sizes_i[1]) { + const int offset = offsets_[batch_id]; + const int input_offset = offset + i0 * sizes_i[1] + i1; + output_[output_offset + i] = input_[input_offset]; + } else { + output_[output_offset + i] = padding_value_; + } + } + } + AddPadding2Functor( + const T* input, + T* output, + T padding_value, + const int* offsets, + const int* input_sizes, + int input_dim, + int output_sizes_1, + int output_sizes_2, + const int batch_size) + : input_(input), + output_(output), + padding_value_(padding_value), + offsets_(offsets), + input_sizes_(input_sizes), + input_dim_(input_dim), + output_sizes_1_(output_sizes_1), + output_sizes_2_(output_sizes_2), + batch_size_(batch_size) {} + + private: + const T* input_; + T* output_; + T padding_value_; + const int* offsets_; + const int* input_sizes_; + int input_dim_; + int output_sizes_1_; + int output_sizes_2_; + const int batch_size_; +}; + +template +struct AddPadding3Functor { + void operator()(sycl::nd_item<2> item) const { + const int batch_id = item.get_group(1); + const int grid_id = item.get_group(0); + const int tid = item.get_local_id(1) + grid_id * BLOCK_DIM; + const int grainsize = GRID_DIM_Y * BLOCK_DIM; + const int* sizes_i = input_sizes_ + batch_id * input_dim_; + const int output_offset = + batch_id * output_sizes_1_ * output_sizes_2_ * output_sizes_3_; + const int output_numel = + output_sizes_1_ * output_sizes_2_ * output_sizes_3_; + for (int ii = 0; ii < (output_numel / grainsize); ii++) { + const int i = ii * grainsize + tid; + const int i0 = i / 
(output_sizes_2_ * output_sizes_3_); + const int i1 = + (i % (output_sizes_2_ * output_sizes_3_)) / output_sizes_3_; + const int i2 = i % output_sizes_3_; + if (batch_id < batch_size_ && i0 < sizes_i[0] && i1 < sizes_i[1] && + i2 < sizes_i[2]) { + const int offset = offsets_[batch_id]; + const int input_offset = + offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2; + output_[output_offset + i] = input_[input_offset]; + } else { + output_[output_offset + i] = padding_value_; + } + } + const int i = (output_numel / grainsize) * grainsize + tid; + if (i < output_numel) { + const int i0 = i / (output_sizes_2_ * output_sizes_3_); + const int i1 = + (i % (output_sizes_2_ * output_sizes_3_)) / output_sizes_3_; + const int i2 = i % output_sizes_3_; + if (batch_id < batch_size_ && i0 < sizes_i[0] && i1 < sizes_i[1] && + i2 < sizes_i[2]) { + const int offset = offsets_[batch_id]; + const int input_offset = + offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2; + output_[output_offset + i] = input_[input_offset]; + } else { + output_[output_offset + i] = padding_value_; + } + } + } + AddPadding3Functor( + const T* input, + T* output, + T padding_value, + const int* offsets, + const int* input_sizes, + int input_dim, + int output_sizes_1, + int output_sizes_2, + int output_sizes_3, + const int batch_size) + : input_(input), + output_(output), + padding_value_(padding_value), + offsets_(offsets), + input_sizes_(input_sizes), + input_dim_(input_dim), + output_sizes_1_(output_sizes_1), + output_sizes_2_(output_sizes_2), + output_sizes_3_(output_sizes_3), + batch_size_(batch_size) {} + + private: + const T* input_; + T* output_; + T padding_value_; + const int* offsets_; + const int* input_sizes_; + int input_dim_; + int output_sizes_1_; + int output_sizes_2_; + int output_sizes_3_; + const int batch_size_; +}; + +template +void add_padding_kernel_impl( + T* input, // [batch_size x None] + T* output, // [batch_size x max(input.nested_size(1)) x inner_size] + T padding_value, + const int* offsets, + const int* input_sizes, + int input_dim, + const std::vector& output_sizes, + const int batch_size, + const int output_batch_size) { + auto queue = getCurrentSYCLQueue(); + if (input_dim == 1) { + auto kfn = AddPadding1Functor( + input, + output, + padding_value, + offsets, + input_sizes, + input_dim, + output_sizes[1], + batch_size); + int64_t max_wg_size = syclMaxWorkGroupSize(kfn); + sycl::range<2> global_range(GRID_DIM_Y, output_batch_size * max_wg_size); + sycl::range<2> local_range(1, max_wg_size); + sycl_kernel_submit(global_range, local_range, queue, kfn); + } + if (input_dim == 2) { + auto kfn = AddPadding2Functor( + input, + output, + padding_value, + offsets, + input_sizes, + input_dim, + output_sizes[1], + output_sizes[2], + batch_size); + int64_t max_wg_size = syclMaxWorkGroupSize(kfn); + sycl::range<2> global_range(GRID_DIM_Y, output_batch_size * max_wg_size); + sycl::range<2> local_range(1, max_wg_size); + sycl_kernel_submit(global_range, local_range, queue, kfn); + } + if (input_dim == 3) { + auto kfn = AddPadding3Functor( + input, + output, + padding_value, + offsets, + input_sizes, + input_dim, + output_sizes[1], + output_sizes[2], + output_sizes[3], + batch_size); + int64_t max_wg_size = syclMaxWorkGroupSize(kfn); + sycl::range<2> global_range(GRID_DIM_Y, output_batch_size * max_wg_size); + sycl::range<2> local_range(1, max_wg_size); + sycl_kernel_submit(global_range, local_range, queue, kfn); + } +} + +void add_padding_kernel( + at::Tensor input, + at::Tensor output, + 
double padding, + const at::Tensor offsets, + const at::Tensor nt_sizes, + int input_dim, + const std::vector& new_size, + const int batch_size, + const int output_batch_size) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "NestedTensor_to_padded_tensor_xpu", [&]() { + add_padding_kernel_impl( + input.data_ptr(), + output.data_ptr(), + (scalar_t)(padding), + offsets.data_ptr(), + nt_sizes.data_ptr(), + input_dim, + new_size, + batch_size, + output_batch_size); + }); +} + +} // namespace at::native::xpu diff --git a/src/ATen/native/nested/xpu/sycl/NestedTensorTransformerFunctionKernels.h b/src/ATen/native/nested/xpu/sycl/NestedTensorTransformerFunctionKernels.h new file mode 100644 index 000000000..9000070b6 --- /dev/null +++ b/src/ATen/native/nested/xpu/sycl/NestedTensorTransformerFunctionKernels.h @@ -0,0 +1,54 @@ +#pragma once + +#include + +namespace at::native::xpu { + +TORCH_XPU_API void remove_padding_kernel_float( + const float* input, + float* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int64_t output_dim, + const int64_t batch_size); + +TORCH_XPU_API void remove_padding_kernel_half( + const c10::Half* input, + c10::Half* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int64_t output_dim, + const int64_t batch_size); + +TORCH_XPU_API void remove_padding_transform0213_kernel_float( + const float* input, + float* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int64_t output_dim, + const int64_t batch_size); + +TORCH_XPU_API void remove_padding_transform0213_kernel_half( + const c10::Half* input, + c10::Half* output, + const int* offsets, + const int* input_sizes, + const int* output_sizes, + int64_t output_dim, + const int64_t batch_size); + +TORCH_XPU_API void add_padding_kernel( + at::Tensor input, + at::Tensor output, + double padding, + const at::Tensor offsets, + const at::Tensor nt_sizes, + int input_dim, + const std::vector& new_size, + const int batch_size, + const int output_batch_size); + +} // namespace at::native::xpu diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 4a2fa42ae..7328d64e3 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -1818,6 +1818,14 @@ "test_scaled_mm_vs_emulated_row_wise_bfloat16_xpu", # AssertionError: Torch not compiled with CUDA enabled "test_zero_dim_tensorwise_which_dim_zero", + # New added case in 2.7 + "test_cublas_addmm_reduced_precision_fp16_accumulate_size_10000_xpu_bfloat16", + "test_cublas_addmm_reduced_precision_fp16_accumulate_size_10000_xpu_float16", + "test_cublas_addmm_reduced_precision_fp16_accumulate_size_1000_xpu_bfloat16", + "test_cublas_addmm_reduced_precision_fp16_accumulate_size_1000_xpu_float16", + "test_cublas_addmm_reduced_precision_fp16_accumulate_size_100_xpu_bfloat16", + "test_cublas_addmm_reduced_precision_fp16_accumulate_size_100_xpu_float16", + "test_cublas_and_lt_reduced_precision_fp16_accumulate_xpu", ), "test_maskedtensor_xpu.py": ( # Summary: SparseCsrXPU OPs are not supported diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index f3812e99e..a8367f060 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -2091,6 +2091,7 @@ - func: _safe_softmax(Tensor self, int dim, ScalarType? 
dtype=None) -> Tensor dispatch: XPU: _safe_softmax_xpu + NestedTensorXPU: _safe_softmax - func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor structured_delegate: _softmax_backward_data.out @@ -4430,6 +4431,12 @@ XPU: roll_xpu autogen: roll.out +- func: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor + device_check: NoCheck # cpu_nested_shape_example will always be on CPU + dispatch: + XPU: nested_from_padded_xpu + autogen: _nested_from_padded.out + - func: avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn structured: True @@ -8773,3 +8780,219 @@ dispatch: CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy autogen: _nested_view_from_buffer_copy.out + +- func: _nested_select_backward(Tensor grad_output, Tensor self, int dim, SymInt index) -> Tensor + variants: function + device_check: NoCheck + device_guard: False + dispatch: + NestedTensorXPU: _nested_select_backward_symint + +- func: chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[] + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeImplicitAutograd: chunk + NestedTensorXPU: chunk_nested_tensor + +- func: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor + dispatch: + CompositeExplicitAutograd: embedding_symint + NestedTensorXPU: NestedTensor_embedding + autogen: embedding.out + tags: core + +- func: is_same_size(Tensor self, Tensor other) -> bool + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + NestedTensorXPU: nested_is_same_size + CompositeExplicitAutograd: is_same_size + +- func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor + python_module: nn + dispatch: + CompositeImplicitAutograd: linear + NestedTensorXPU: nested_linear + +- func: linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + dispatch: + NestedTensorXPU: nested_linear_backward + autogen: linear_backward.out + +- func: matmul(Tensor self, Tensor other) -> Tensor + variants: function, method + dispatch: + CompositeImplicitAutograd: matmul + NestedTensorXPU: matmul_nested + +- func: matmul_backward(Tensor grad, Tensor self, Tensor other, bool[2] mask) -> (Tensor, Tensor) + dispatch: + NestedTensorXPU: matmul_backward_nested + autogen: matmul_backward.out + +- func: matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CompositeImplicitAutograd: matmul_out + NestedTensorXPU: matmul_out_nested + +- func: narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeImplicitAutograd: narrow_symint + NestedTensorXPU: narrow_nested_symint + +- func: ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: ones_like + NestedTensorXPU: ones_like + autogen: ones_like.out + +- func: split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: split_with_sizes + NestedTensorXPU: split_with_sizes_nested + tags: core + +- func: select_copy.int(Tensor self, int dim, SymInt index) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: select_copy_symint + tags: view_copy + autogen: select_copy.int_out + +- func: select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: select_symint + NestedTensorXPU: select_nested + tags: core + +- func: squeeze_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: squeeze_copy + tags: view_copy + autogen: squeeze_copy.out + +- func: squeeze_copy.dim(Tensor self, int dim) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: squeeze_copy_dim + tags: view_copy + autogen: squeeze_copy.dim_out + +- func: squeeze_copy.dims(Tensor self, int[] dim) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: squeeze_copy_dims + tags: view_copy + autogen: squeeze_copy.dims_out + +- func: squeeze(Tensor(a) self) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: squeeze + NestedTensorXPU: squeeze_nested + +- func: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: squeeze + NestedTensorXPU: squeeze_dim_nested + tags: core + +- func: squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: squeeze + NestedTensorXPU: squeeze_dim_nested + tags: core + +- func: detach_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: detach_copy + tags: view_copy + autogen: detach_copy.out + +- func: detach(Tensor(a) self) -> Tensor(a) + variants: function, method + dispatch: + CompositeExplicitAutograd: detach + NestedTensorXPU: detach + +- func: transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: transpose_copy_int + tags: view_copy + autogen: transpose_copy.int_out + +- func: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) + variants: function, method + device_check: NoCheck + device_guard: False + dispatch: + CompositeExplicitAutograd: transpose + NestedTensorXPU: transpose_nested + +- func: alias_copy(Tensor self) -> Tensor + variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: alias_copy + tags: view_copy + autogen: alias_copy.out + +- func: alias(Tensor(a) self) -> Tensor(a) + variants: method, function + dispatch: + CompositeExplicitAutograd: alias + NestedTensorXPU: alias_nested + tags: core + +- func: _test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor + dispatch: + # the NestedTensor keys are necessary because NestedTensor has 
been removed + # from the CompositeExplicitAutograd keyset see Note [NestedTensor Not Included in Backend Keys] + CompositeExplicitAutograd, NestedTensorXPU: _test_autograd_multiple_dispatch_fullcoverage + autogen: _test_autograd_multiple_dispatch.fullcoverage_out + +# Note: this function is only for testing. +- func: _test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor + dispatch: + CompositeImplicitAutograd, NestedTensorXPU: _test_autograd_multiple_dispatch_ntonly + +# Apparently, putting "forward" in the name will cause Python bindings to be skipped, so "fwd" it is. +- func: _transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None) -> Tensor + variants: function + dispatch: + XPU, NestedTensorXPU: transformer_encoder_layer_forward + autogen: _transformer_encoder_layer_fwd.out + +- func: nested_to_padded_tensor(Tensor self, float padding, int[]? output_size=None) -> Tensor + python_module: nested + variants: function + +- func: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor + variants: method + dispatch: + NestedTensorXPU: NestedTensor_to_padded_tensor_xpu + autogen: to_padded_tensor.out
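
Review note: a quick way to exercise the new NestedTensorXPU registrations end to end is a padded/nested round trip. The sketch below is not part of the patch; it assumes a build with this XPU backend and an "xpu" device available, relies on the private torch._nested_from_padded and Tensor._nested_tensor_size APIs that back the _nested_from_padded entry above, and the helper name is hypothetical.

import torch

# Sketch of a smoke test for the ops registered in this patch; fp32, dim-3
# padded input keeps us on the new SYCL fast path rather than the generic
# fallback.
def check_nested_padding_roundtrip():
    nt = torch.nested.nested_tensor(
        [torch.randn(2, 3), torch.randn(4, 3)],
        dtype=torch.float32,
        device="xpu",
    )

    # to_padded_tensor on a nested XPU tensor dispatches to
    # NestedTensor_to_padded_tensor_xpu; per-dim max sizes here are (2, 4, 3).
    padded = torch.nested.to_padded_tensor(nt, padding=0.0)
    assert padded.shape == (2, 4, 3)

    # _nested_from_padded dispatches to nested_from_padded_xpu; the nested
    # size tensor (kept on CPU) serves as the shape example.
    nt2 = torch._nested_from_padded(padded, nt._nested_tensor_size())
    for a, b in zip(nt.unbind(), nt2.unbind()):
        torch.testing.assert_close(a, b)

check_nested_padding_roundtrip()

The dim-3, no-transform, fp32 case is exactly what the new remove_padding kernels accept; other ranks or dtypes fall back to the generic implementations, so a fuller test would also cover fp16 and the output_size argument of to_padded_tensor.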