diff --git a/src/ATen/native/xpu/NestedTensorTransformerFunctions.cpp b/src/ATen/native/xpu/NestedTensorTransformerFunctions.cpp
new file mode 100644
index 000000000..33d222572
--- /dev/null
+++ b/src/ATen/native/xpu/NestedTensorTransformerFunctions.cpp
@@ -0,0 +1,108 @@
+#include <ATen/ATen.h>
+#include <ATen/NestedTensorImpl.h>
+#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
+#include <ATen/native/xpu/NestedTensorTransformerFunctions.h>
+
+namespace at::native {
+namespace {
+int64_t padded_tensor_numel(const Tensor& sizes) {
+  const auto sizes_num_rows = sizes.sizes()[0];
+  const auto sizes_row_length = sizes.sizes()[1];
+  const auto* sizes_data = sizes.data_ptr<int64_t>();
+  int64_t numel = 0;
+  for (const auto row_num : c10::irange(sizes_num_rows)) {
+    const auto* row_ptr = sizes_data + row_num * sizes_row_length;
+    int64_t prod = 1;
+    for (const auto idx : c10::irange(sizes_row_length)) {
+      prod *= row_ptr[idx];
+    }
+    numel += prod;
+  }
+  return numel;
+}
+} // namespace
+Tensor nested_from_padded_xpu(
+    const Tensor& padded,
+    const Tensor& sizes,
+    bool do_transform_0213) {
+  if (padded.dim() > 1 && padded.dim() < 5) {
+    // Instead of erroring, call the generic version
+    if (!(padded.dim() == 4 && do_transform_0213) &&
+        !(padded.dim() == 3 && !do_transform_0213)) {
+      return at::native::nested_from_padded_generic(
+          padded, sizes, do_transform_0213);
+    }
+    if (padded.dtype() != at::kFloat && padded.dtype() != kHalf) {
+      TORCH_WARN_ONCE(
+          "nested_from_padded XPU kernels only support fp32/fp16; falling "
+          "back to slower generic kernel");
+      return at::native::nested_from_padded_generic(
+          padded, sizes, do_transform_0213);
+    }
+    Tensor target_offsets =
+        at::native::NestedTensor_batch_offsets_from_size_tensor(sizes, 0);
+    Tensor padded_sizes_tensor = at::tensor(padded.sizes());
+    Tensor output = at::empty({padded_tensor_numel(sizes)}, padded.options());
+    Tensor target_size_sizes = sizes.reshape(-1);
+
+    target_offsets = target_offsets.to(at::Device(kXPU), at::kInt);
+    padded_sizes_tensor = padded_sizes_tensor.to(at::Device(kXPU), at::kInt);
+    target_size_sizes = target_size_sizes.to(at::Device(kXPU), at::kInt);
+
+    auto output_size_ptr = target_size_sizes.data_ptr<int>();
+    auto input_size_ptr = padded_sizes_tensor.data_ptr<int>();
+    auto offsets_ptr = target_offsets.data_ptr<int>();
+
+    Tensor padded_contiguous = padded.contiguous();
+
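+    // Dispatch on dtype: the SYCL kernels below are instantiated only for
+    // fp32 and fp16 (see NestedTensorTransformerFunctionsKernels.cpp); any
+    // other dtype was already routed to the generic fallback above.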
+    if (padded.dtype() == at::kFloat) {
+      if (do_transform_0213) {
+        xpu::launch_remove_padding_transform0213_kernel(
+            padded_contiguous.data_ptr<float>(),
+            output.data_ptr<float>(),
+            offsets_ptr,
+            input_size_ptr,
+            output_size_ptr,
+            padded_contiguous.dim() - 2,
+            padded_contiguous.sizes()[0]);
+      } else {
+        xpu::launch_remove_padding_kernel(
+            padded_contiguous.data_ptr<float>(),
+            output.data_ptr<float>(),
+            offsets_ptr,
+            input_size_ptr,
+            output_size_ptr,
+            padded_contiguous.dim() - 1,
+            padded_contiguous.sizes()[0]);
+      }
+    } else if (padded.dtype() == at::kHalf) {
+      if (do_transform_0213) {
+        xpu::launch_remove_padding_transform0213_kernel(
+            padded_contiguous.data_ptr<c10::Half>(),
+            output.data_ptr<c10::Half>(),
+            offsets_ptr,
+            input_size_ptr,
+            output_size_ptr,
+            padded_contiguous.dim() - 2,
+            padded_contiguous.sizes()[0]);
+      } else {
+        xpu::launch_remove_padding_kernel(
+            padded_contiguous.data_ptr<c10::Half>(),
+            output.data_ptr<c10::Half>(),
+            offsets_ptr,
+            input_size_ptr,
+            output_size_ptr,
+            padded_contiguous.dim() - 1,
+            padded_contiguous.sizes()[0]);
+      }
+    } else {
+      AT_ERROR("Only support fp32/fp16 for padded input");
+    }
+    return at::detail::make_tensor<NestedTensorImpl>(
+        std::move(output), sizes);
+  } else {
+    return at::native::nested_from_padded_generic(padded, sizes);
+  }
+}
+
+} // namespace at::native
\ No newline at end of file
diff --git a/src/ATen/native/xpu/NestedTensorTransformerFunctions.h b/src/ATen/native/xpu/NestedTensorTransformerFunctions.h
new file mode 100644
index 000000000..db60e7113
--- /dev/null
+++ b/src/ATen/native/xpu/NestedTensorTransformerFunctions.h
@@ -0,0 +1,26 @@
+#pragma once
+#include <ATen/core/Tensor.h>
+
+namespace at::native::xpu {
+
+template <typename T>
+TORCH_XPU_API void launch_remove_padding_kernel(
+    const T* input,
+    T* output,
+    const int* offsets,
+    const int* input_sizes,
+    const int* output_sizes,
+    int output_dim,
+    int batch_size);
+
+template <typename T>
+TORCH_XPU_API void launch_remove_padding_transform0213_kernel(
+    const T* input,
+    T* output,
+    const int* offsets,
+    const int* input_sizes,
+    const int* output_sizes,
+    int output_dim,
+    const int batch_size);
+
+} // namespace at::native::xpu
\ No newline at end of file
diff --git a/src/ATen/native/xpu/sycl/NestedTensorTransformerFunctionsKernels.cpp b/src/ATen/native/xpu/sycl/NestedTensorTransformerFunctionsKernels.cpp
new file mode 100644
index 000000000..41dacfa91
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/NestedTensorTransformerFunctionsKernels.cpp
@@ -0,0 +1,290 @@
+#include <ATen/native/xpu/NestedTensorTransformerFunctions.h>
+#include <comm/SYCLContext.h>
+
+// Keep aligned with the CUDA launch: grid dim 0 spans the output batch size
+// and grid dim 1 is fixed at GRID_DIM_Y (16) work-groups.
+#define GRID_DIM_Y 16
+#define BLOCK_DIM 1024
+
+namespace at::native::xpu {
+
+template <typename T>
+struct remove_padding_functor {
+  void operator()(sycl::nd_item<2> item) const {
+    const int batch_id = item.get_group(0);
+    const int grid_id = item.get_group(1);
+    const int tid = item.get_local_id()[0] + grid_id * BLOCK_DIM;
+    const int grainsize = GRID_DIM_Y * BLOCK_DIM;
+    const int offset = offsets[batch_id];
+
+    const int* sizes_i = output_sizes + batch_id * output_dim;
+    const int numel_i = sizes_i[0] * sizes_i[1] * sizes_i[2];
+    int input_offset =
+        batch_id * input_sizes[1] * input_sizes[2] * input_sizes[3];
+    for (int ii = 0; ii < (numel_i / grainsize); ii++) {
+      const int i = ii * grainsize + tid;
+      const int i0 = i / (sizes_i[1] * sizes_i[2]);
+      const int i1 = (i % (sizes_i[1] * sizes_i[2])) / sizes_i[2];
+      const int i2 = i % sizes_i[2];
+      const int i0_offset = i0 * input_sizes[2] * input_sizes[3];
+      const int i1_offset = i1 * input_sizes[3];
+      output[offset + i] = input[input_offset + i0_offset + i1_offset + i2];
+    }
+    const int i = (numel_i / grainsize) * grainsize + tid;
+    if (i < numel_i) {
+      const int i0 = i / (sizes_i[1] * sizes_i[2]);
+      const int i1 = (i % (sizes_i[1] * sizes_i[2])) / sizes_i[2];
+      const int i2 = i % sizes_i[2];
+      const int i0_offset = i0 * input_sizes[2] * input_sizes[3];
+      const int i1_offset = i1 * input_sizes[3];
+      output[offset + i] = input[input_offset + i0_offset + i1_offset + i2];
+    }
+  }
+
+  remove_padding_functor(
+      const T* input_,
+      T* output_,
+      const int* offsets_,
+      const int* input_sizes_,
+      const int* output_sizes_,
+      int output_dim_,
+      const int batch_size_)
+      : input(input_),
+        output(output_),
+        offsets(offsets_),
+        input_sizes(input_sizes_),
+        output_sizes(output_sizes_),
+        output_dim(output_dim_),
+        batch_size(batch_size_) {}
+
+ private:
+  const T* input;
+  T* output;
+  const int* offsets;
+  const int* input_sizes;
+  const int* output_sizes;
+  int output_dim;
+  const int batch_size;
+};
+
+template <typename T>
+struct remove_padding_2_functor {
+  void operator()(sycl::nd_item<2> item) const {
+    const int batch_id = item.get_group(0);
+    const int grid_id = item.get_group(1);
+    const int tid = item.get_local_id()[0] + grid_id * BLOCK_DIM;
+    const int grainsize = GRID_DIM_Y * BLOCK_DIM;
+    const int offset = offsets[batch_id];
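+    // Each work-item strides over this batch row's unpadded elements with a
+    // step of grainsize; sizes_i gives the row's unpadded extents and
+    // input_sizes the padded sizes used to compute the input offset.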
+    const int* sizes_i = output_sizes + batch_id * output_dim;
+    const int numel_i = sizes_i[0] * sizes_i[1];
+    int input_offset = batch_id * input_sizes[1] * input_sizes[2];
+    for (int ii = 0; ii < (numel_i / grainsize); ii++) {
+      const int i = ii * grainsize + tid;
+      const int i0 = i / sizes_i[1];
+      const int i1 = i % sizes_i[1];
+      const int i0_offset = i0 * input_sizes[2];
+      output[offset + i] = input[input_offset + i0_offset + i1];
+    }
+    const int i = (numel_i / grainsize) * grainsize + tid;
+    if (i < numel_i) {
+      const int i0 = i / sizes_i[1];
+      const int i1 = i % sizes_i[1];
+      const int i0_offset = i0 * input_sizes[2];
+      output[offset + i] = input[input_offset + i0_offset + i1];
+    }
+  }
+
+  remove_padding_2_functor(
+      const T* input_,
+      T* output_,
+      const int* offsets_,
+      const int* input_sizes_,
+      const int* output_sizes_,
+      int output_dim_,
+      const int batch_size_)
+      : input(input_),
+        output(output_),
+        offsets(offsets_),
+        input_sizes(input_sizes_),
+        output_sizes(output_sizes_),
+        output_dim(output_dim_),
+        batch_size(batch_size_) {}
+
+  const T* input;
+  T* output;
+  const int* offsets;
+  const int* input_sizes;
+  const int* output_sizes;
+  int output_dim;
+  const int batch_size;
+};
+
+template <typename T>
+struct remove_padding_transform0213_functor {
+  void operator()(sycl::nd_item<2> item) const {
+    const int batch_id = item.get_group(0);
+    const int grid_id = item.get_group(1);
+    const int tid = item.get_local_id()[0] + grid_id * BLOCK_DIM;
+    const int grainsize = GRID_DIM_Y * BLOCK_DIM;
+    const int offset = offsets[batch_id];
+    const int* sizes_i = output_sizes + batch_id * output_dim;
+    const int numel_i = sizes_i[0] * sizes_i[1];
+    int input_offset =
+        batch_id * input_sizes[1] * input_sizes[2] * input_sizes[3];
+    for (int ii = 0; ii < (numel_i / grainsize); ii++) {
+      const int i = ii * grainsize + tid;
+      const int i2 = i / sizes_i[1];
+      const int i13 = i % sizes_i[1];
+      const int i1 = i13 / (sizes_i[1] / input_sizes[1]);
+      const int i3 = i13 % (sizes_i[1] / input_sizes[1]);
+
+      output[offset + i] = input
+          [input_offset + i1 * input_sizes[2] * input_sizes[3] +
+           i2 * input_sizes[3] + i3];
+    }
+    const int i = (numel_i / grainsize) * grainsize + tid;
+    if (i < numel_i) {
+      const int i2 = i / sizes_i[1];
+      const int i13 = i % sizes_i[1];
+      const int i1 = i13 / (sizes_i[1] / input_sizes[1]);
+      const int i3 = i13 % (sizes_i[1] / input_sizes[1]);
+      output[offset + i] = input
+          [input_offset + i1 * input_sizes[2] * input_sizes[3] +
+           i2 * input_sizes[3] + i3];
+    }
+  }
+
+  remove_padding_transform0213_functor(
+      const T* input_,
+      T* output_,
+      const int* offsets_,
+      const int* input_sizes_,
+      const int* output_sizes_,
+      int output_dim_,
+      const int batch_size_)
+      : input(input_),
+        output(output_),
+        offsets(offsets_),
+        input_sizes(input_sizes_),
+        output_sizes(output_sizes_),
+        output_dim(output_dim_),
+        batch_size(batch_size_) {}
+
+  const T* input;
+  T* output;
+  const int* offsets;
+  const int* input_sizes;
+  const int* output_sizes;
+  int output_dim;
+  const int batch_size;
+};
+
+template <typename T>
+void launch_remove_padding_transform0213_kernel(
+    const T* input,
+    T* output,
+    const int* offsets,
+    const int* input_sizes,
+    const int* output_sizes,
+    int output_dim,
+    const int batch_size) {
+  TORCH_CHECK(
+      output_dim == 2,
+      "remove padding transform0213 only supports output dim == 2");
+
+  auto queue = getCurrentSYCLQueue();
+  auto kfn = remove_padding_transform0213_functor<T>(
+      input,
+      output,
+      offsets,
+      input_sizes,
+      output_sizes,
+      output_dim,
+      batch_size);
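+  // One work-group per batch entry along dim 0 and GRID_DIM_Y work-groups
+  // along dim 1, mirroring the CUDA grid; the local size is the largest
+  // work-group size the device supports for this kernel.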
+  int64_t max_wg_size = syclMaxWorkGroupSize(kfn);
+  sycl::range<2> global_range{(size_t)batch_size * max_wg_size, GRID_DIM_Y};
+  sycl::range<2> local_range{(size_t)max_wg_size, 1};
+
+  sycl_kernel_submit(global_range, local_range, queue, kfn);
+}
+
+template <typename T>
+void launch_remove_padding_kernel(
+    const T* input,
+    T* output,
+    const int* offsets,
+    const int* input_sizes,
+    const int* output_sizes,
+    int output_dim,
+    const int batch_size) {
+  auto queue = getCurrentSYCLQueue();
+
+  if (output_dim == 2) {
+    auto kfn = remove_padding_2_functor<T>(
+        input,
+        output,
+        offsets,
+        input_sizes,
+        output_sizes,
+        output_dim,
+        batch_size);
+    int64_t max_wg_size = syclMaxWorkGroupSize(kfn);
+    sycl::range<2> global_range{(size_t)batch_size * max_wg_size, GRID_DIM_Y};
+    sycl::range<2> local_range{(size_t)max_wg_size, 1};
+
+    sycl_kernel_submit(global_range, local_range, queue, kfn);
+  } else {
+    auto kfn = remove_padding_functor<T>(
+        input,
+        output,
+        offsets,
+        input_sizes,
+        output_sizes,
+        output_dim,
+        batch_size);
+
+    int64_t max_wg_size = syclMaxWorkGroupSize(kfn);
+    sycl::range<2> global_range{(size_t)batch_size * max_wg_size, GRID_DIM_Y};
+    sycl::range<2> local_range{(size_t)max_wg_size, 1};
+
+    sycl_kernel_submit(global_range, local_range, queue, kfn);
+  }
+}
+
+template void launch_remove_padding_kernel(
+    const float* input,
+    float* output,
+    const int* offsets,
+    const int* input_sizes,
+    const int* output_sizes,
+    int output_dim,
+    int batch_size);
+
+template void launch_remove_padding_kernel(
+    const c10::Half* input,
+    c10::Half* output,
+    const int* offsets,
+    const int* input_sizes,
+    const int* output_sizes,
+    int output_dim,
+    int batch_size);
+
+template void launch_remove_padding_transform0213_kernel(
+    const float* input,
+    float* output,
+    const int* offsets,
+    const int* input_sizes,
+    const int* output_sizes,
+    int output_dim,
+    const int batch_size);
+
+template void launch_remove_padding_transform0213_kernel(
+    const c10::Half* input,
+    c10::Half* output,
+    const int* offsets,
+    const int* input_sizes,
+    const int* output_sizes,
+    int output_dim,
+    const int batch_size);
+
+} // namespace at::native::xpu
\ No newline at end of file
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index 40b710c12..fa0dbe8ab 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -4338,6 +4338,12 @@
     XPU: roll_xpu
   autogen: roll.out
 
+- func: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor
+  device_check: NoCheck  # cpu_nested_shape_example will always be on CPU
+  dispatch:
+    XPU: nested_from_padded_xpu
+  autogen: _nested_from_padded.out
+
 - func: avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
   structured: True
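
Usage sketch (illustrative, not part of the patch): the snippet below assumes an XPU-enabled libtorch build and goes through the generated ATen API; the shapes, the zero-filled padded input, and the main() harness are example values only. A 3-D fp32 input with fuse_transform_0213=false is exactly the case that reaches nested_from_padded_xpu and launch_remove_padding_kernel above.

#include <ATen/ATen.h>

int main() {
  // Padded batch of 2 sequences, max length 4, embedding dim 3, on the XPU device.
  auto padded = at::zeros(
      {2, 4, 3}, at::TensorOptions().device(at::kXPU).dtype(at::kFloat));
  // Per-row unpadded sizes {2x3, 4x3}; the schema expects this tensor on the CPU.
  auto sizes = at::tensor(c10::IntArrayRef{2, 3, 4, 3}).reshape({2, 2});
  // Dispatches to nested_from_padded_xpu, which strips the padding on the device.
  auto nested =
      at::_nested_from_padded(padded, sizes, /*fuse_transform_0213=*/false);
  return nested.is_nested() ? 0 : 1;
}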