Skip to content

Commit

Permalink
Add reflection_pad1d/3d and backwards (#482)
Browse files Browse the repository at this point in the history
Signed-off-by: majing <[email protected]>
Co-authored-by: Feng Yuan <[email protected]>
  • Loading branch information
majing921201 and fengyuan14 authored Jul 11, 2024
1 parent 0253fb9 commit 9c45ee2
Show file tree
Hide file tree
Showing 5 changed files with 865 additions and 12 deletions.
352 changes: 352 additions & 0 deletions src/ATen/native/xpu/ReflectionPad.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,322 @@
#include <ATen/Context.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/op_registration/adaption.h>
#include <ATen/native/Padding.h>
#include <ATen/native/xpu/sycl/ReflectionPadKernels.h>
#include <ATen/xpu/XPUNativeFunctions.h>
#include <comm/RegisterUtils.h>

namespace at {

void reflection_pad1d_meta(
Tensor& output,
const Tensor& input,
IntArrayRef padding) {
int64_t dim_plane = 0;
int64_t dim_w = 1;
int64_t nbatch = 1;

if (input.ndimension() == 3) {
nbatch = input.size(0);
dim_w++;
dim_plane++;
}

at::native::padding::check_valid_input<1>(input, padding);

/* sizes */
auto pad_l = padding[0];
auto pad_r = padding[1];

int64_t nplane = input.size(dim_plane);
int64_t input_w = input.size(dim_w);
int64_t output_w = input_w + pad_l + pad_r;

TORCH_CHECK(
pad_l < input_w && pad_r < input_w,
"Argument #4: Padding size "
"should be less than the corresponding input dimension, but got: padding (",
pad_l,
", ",
pad_r,
") at dimension ",
dim_w,
" of input ",
input.sizes());

TORCH_CHECK(
output_w >= 1,
"input (W: ",
input_w,
") is too small. Calculated output W: ",
output_w);

if (output.defined()) {
if (input.ndimension() == 2) {
xpu::resize_out(output, {nplane, output_w}, {}, input.options());
} else {
xpu::resize_out(output, {nbatch, nplane, output_w}, {}, input.options());
}
} else {
if (input.ndimension() == 2) {
output = xpu::create_out({nplane, output_w}, {}, input.options());
} else {
output = xpu::create_out({nbatch, nplane, output_w}, {}, input.options());
}
}
}

void reflection_pad1d_backward_meta(
Tensor& grad_input,
const Tensor& grad_output,
const Tensor& input,
IntArrayRef padding) {
int64_t dim_w = 1;
if (input.ndimension() == 3) {
dim_w++;
}

/* sizes */
auto pad_l = padding[0];
auto pad_r = padding[1];
int64_t input_w = input.size(dim_w);
int64_t output_w = input_w + pad_l + pad_r;

TORCH_CHECK(
pad_l < input_w && pad_r < input_w,
"Argument #4: Padding size "
"should be less than the corresponding input dimension, but got: padding (",
pad_l,
", ",
pad_r,
") at dimension ",
dim_w,
" of input ",
input.sizes());

TORCH_CHECK(
output_w == grad_output.size(dim_w),
"grad_output width unexpected."
" Expected: ",
output_w,
", Got: ",
grad_output.size(dim_w));

if (grad_input.defined()) {
xpu::resize_out(grad_input, input.sizes(), {}, input.options());
} else {
grad_input = xpu::create_out(input.sizes(), {}, input.options());
}
}

void reflection_pad3d_meta(
Tensor& output,
const Tensor& input,
IntArrayRef padding) {
int64_t pad_left = padding[0];
int64_t pad_right = padding[1];
int64_t pad_top = padding[2];
int64_t pad_bottom = padding[3];
int64_t pad_front = padding[4];
int64_t pad_back = padding[5];
int64_t dim_w = 3;
int64_t dim_h = 2;
int64_t dim_d = 1;
int64_t dim_plane = 0;

at::native::padding::check_valid_input<3>(input, padding);

bool batch_mode = (input.dim() == 5);
if (batch_mode) {
dim_w++;
dim_h++;
dim_d++;
dim_plane++;
}

int64_t nplane = input.size(dim_plane);
int64_t input_d = input.size(dim_d);
int64_t input_h = input.size(dim_h);
int64_t input_w = input.size(dim_w);
int64_t output_d = input_d + pad_front + pad_back;
int64_t output_h = input_h + pad_top + pad_bottom;
int64_t output_w = input_w + pad_left + pad_right;

TORCH_CHECK(
pad_left < input_w && pad_right < input_w,
"Argument #4: Padding size "
"should be less than the corresponding input dimension, but got: padding (",
pad_left,
", ",
pad_right,
") at dimension ",
dim_w,
" of input ",
input.sizes());
TORCH_CHECK(
pad_top < input_h && pad_bottom < input_h,
"Argument #6: Padding size "
"should be less than the corresponding input dimension, but got: padding (",
pad_top,
", ",
pad_bottom,
") at dimension ",
dim_h,
" of input ",
input.sizes());
TORCH_CHECK(
pad_front < input_d && pad_back < input_d,
"Argument #8: Padding size "
"should be less than the corresponding input dimension, but got: padding (",
pad_front,
", ",
pad_back,
") at dimension ",
dim_d,
" of input ",
input.sizes());

TORCH_CHECK(
output_w >= 1 || output_h >= 1 || output_d >= 1,
"input (D: ",
input_d,
" H: ",
input_h,
", W: ",
input_w,
") is too small."
" Calculated output D: ",
output_d,
" H: ",
output_h,
" W: ",
output_w);

if (output.defined()) {
if (batch_mode) {
xpu::resize_out(
output,
{input.size(0), nplane, output_d, output_h, output_w},
{},
input.options());
} else {
xpu::resize_out(
output, {nplane, output_d, output_h, output_w}, {}, input.options());
}
} else {
if (batch_mode) {
output = xpu::create_out(
{input.size(0), nplane, output_d, output_h, output_w},
{},
input.options());
} else {
output = xpu::create_out(
{nplane, output_d, output_h, output_w}, {}, input.options());
}
}
}

void reflection_pad3d_backward_meta(
Tensor& grad_input,
const Tensor& grad_output,
const Tensor& input,
IntArrayRef padding) {
TORCH_CHECK(padding.size() == 6, "padding size is expected to be 6");
TORCH_CHECK(input.dim() > 3);
TORCH_CHECK(grad_output.dim() == input.dim());

int64_t pad_left = padding[0];
int64_t pad_right = padding[1];
int64_t pad_top = padding[2];
int64_t pad_bottom = padding[3];
int64_t pad_front = padding[4];
int64_t pad_back = padding[5];
int64_t dim_w = 3;
int64_t dim_h = 2;
int64_t dim_d = 1;

if (input.dim() == 5) {
// batch mode
dim_w++;
dim_h++;
dim_d++;
}

int64_t input_d = input.size(dim_d);
int64_t input_h = input.size(dim_h);
int64_t input_w = input.size(dim_w);
int64_t output_d = input_d + pad_front + pad_back;
int64_t output_h = input_h + pad_top + pad_bottom;
int64_t output_w = input_w + pad_left + pad_right;

TORCH_CHECK(
output_w == grad_output.size(dim_w),
"grad_output width unexpected."
" Expected: ",
output_w,
", Got: ",
grad_output.size(dim_w));
TORCH_CHECK(
output_h == grad_output.size(dim_h),
"grad_output height unexpected."
" Expected: ",
output_h,
", Got: ",
grad_output.size(dim_h));
TORCH_CHECK(
output_d == grad_output.size(dim_d),
"grad_output depth unexpected."
" Expected: ",
output_d,
", Got: ",
grad_output.size(dim_d));

if (grad_input.defined()) {
xpu::resize_out(grad_input, input.sizes(), {}, input.options());
} else {
grad_input = xpu::create_out(input.sizes(), {}, input.options());
}
}

Tensor XPUNativeFunctions::reflection_pad1d(
const Tensor& input,
IntArrayRef padding) {
Tensor output;
reflection_pad1d_meta(output, input, padding);
native::xpu::reflection_pad1d_kernel(output, input, padding);
return output;
}

Tensor& XPUNativeFunctions::reflection_pad1d_out(
const Tensor& input,
IntArrayRef padding,
Tensor& output) {
reflection_pad1d_meta(output, input, padding);
native::xpu::reflection_pad1d_kernel(output, input, padding);
return output;
}

Tensor XPUNativeFunctions::reflection_pad1d_backward(
const Tensor& grad_output,
const Tensor& input,
IntArrayRef padding) {
Tensor grad_input;
reflection_pad1d_backward_meta(grad_input, grad_output, input, padding);
native::xpu::reflection_pad1d_backward_kernel(
grad_input, grad_output, input, padding);
return grad_input;
}

Tensor& XPUNativeFunctions::reflection_pad1d_backward_out(
const Tensor& grad_output,
const Tensor& input,
IntArrayRef padding,
Tensor& grad_input) {
native::xpu::reflection_pad1d_backward_kernel(
grad_input, grad_output, input, padding);
return grad_input;
}

Tensor& XPUNativeFunctions::reflection_pad2d_out(
const Tensor& input,
IntArrayRef padding,
Expand Down Expand Up @@ -49,4 +361,44 @@ Tensor XPUNativeFunctions::reflection_pad2d_backward(
return grad_input;
}

Tensor XPUNativeFunctions::reflection_pad3d(
const Tensor& input,
IntArrayRef padding) {
Tensor output;
reflection_pad3d_meta(output, input, padding);
native::xpu::reflection_pad3d_kernel(output, input, padding);
return output;
}

Tensor& XPUNativeFunctions::reflection_pad3d_out(
const Tensor& input,
IntArrayRef padding,
Tensor& output) {
reflection_pad3d_meta(output, input, padding);
native::xpu::reflection_pad3d_kernel(output, input, padding);
return output;
}

Tensor XPUNativeFunctions::reflection_pad3d_backward(
const Tensor& grad_output,
const Tensor& input,
at::IntArrayRef padding) {
Tensor grad_input;
reflection_pad3d_backward_meta(grad_input, grad_output, input, padding);
native::xpu::reflection_pad3d_backward_kernel(
grad_input, grad_output, input, padding);
return grad_input;
}

Tensor& XPUNativeFunctions::reflection_pad3d_backward_out(
const Tensor& grad_output,
const Tensor& input,
IntArrayRef padding,
Tensor& grad_input) {
reflection_pad3d_backward_meta(grad_input, grad_output, input, padding);
native::xpu::reflection_pad3d_backward_kernel(
grad_input, grad_output, input, padding);
return grad_input;
}

} // namespace at
4 changes: 0 additions & 4 deletions src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"prod",
"prod.int_out",
"put_",
"reflection_pad1d_backward.grad_input",
"reflection_pad1d.out",
"renorm.out",
"repeat_interleave.Tensor",
"replication_pad1d_backward.grad_input",
Expand Down Expand Up @@ -380,8 +378,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"vdot",
"xlogy.OutTensor",
"_upsample_bicubic2d_aa.out",
"reflection_pad3d.out",
"reflection_pad3d_backward.grad_input",
"replication_pad3d.out",
"replication_pad3d_backward",
};
Expand Down
Loading

0 comments on commit 9c45ee2

Please sign in to comment.