From b1892595135fc4e92c6b2c76e8db7268fc9d7609 Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Mon, 28 Oct 2024 14:39:40 +0800 Subject: [PATCH] Add aten::multi_margin_loss and its variants (#895) - [x] multi_margin_loss - [x] multi_margin_loss.out - [x] multi_margin_loss_backward - [x] multi_margin_loss_backward.grad_input --------- Co-authored-by: Yutao Xu --- src/ATen/native/xpu/LossMultiMargin.cpp | 64 ++ src/ATen/native/xpu/XPUFallback.template | 2 - .../xpu/sycl/MultiMarginLossKernels.cpp | 548 ++++++++++++++++++ .../native/xpu/sycl/MultiMarginLossKernels.h | 25 + test/xpu/xpu_test_utils.py | 1 + yaml/native/native_functions.yaml | 20 + yaml/xpu_functions.yaml | 4 + 7 files changed, 662 insertions(+), 2 deletions(-) create mode 100644 src/ATen/native/xpu/LossMultiMargin.cpp create mode 100644 src/ATen/native/xpu/sycl/MultiMarginLossKernels.cpp create mode 100644 src/ATen/native/xpu/sycl/MultiMarginLossKernels.h diff --git a/src/ATen/native/xpu/LossMultiMargin.cpp b/src/ATen/native/xpu/LossMultiMargin.cpp new file mode 100644 index 000000000..2db427135 --- /dev/null +++ b/src/ATen/native/xpu/LossMultiMargin.cpp @@ -0,0 +1,64 @@ +#include +#include + +#include +#include +#include + +namespace at::native { + +Tensor& multi_margin_loss_xpu_out( + const Tensor& self, + const Tensor& target, + const Scalar& p, + const Scalar& margin, + const std::optional& weight, + int64_t reduction, + Tensor& out) { + xpu::multi_margin_loss_kernel( + self, target, p, margin, weight, reduction, out); + return out; +} + +Tensor multi_margin_loss_xpu( + const Tensor& self, + const Tensor& target, + const Scalar& p, + const Scalar& margin, + const std::optional& weight, + int64_t reduction) { + auto out = at::empty({0}, self.options()); + xpu::multi_margin_loss_kernel( + self, target, p, margin, weight, reduction, out); + return out; +} + +Tensor& multi_margin_loss_xpu_backward_out( + const Tensor& grad_output, + const Tensor& self, + const Tensor& target, + const Scalar& p, + const Scalar& margin, + const std::optional& weight, + int64_t reduction, + Tensor& grad_input) { + xpu::multi_margin_loss_backward_kernel( + grad_output, self, target, p, margin, weight, reduction, grad_input); + return grad_input; +} + +Tensor multi_margin_loss_xpu_backward( + const Tensor& grad_output, + const Tensor& self, + const Tensor& target, + const Scalar& p, + const Scalar& margin, + const std::optional& weight, + int64_t reduction) { + auto grad_input = at::empty({0}, self.options()); + xpu::multi_margin_loss_backward_kernel( + grad_output, self, target, p, margin, weight, reduction, grad_input); + return grad_input; +} + +} // namespace at::native diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 51b551765..052cb23ee 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -199,8 +199,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "lu_unpack.out", "multilabel_margin_loss_backward", "multilabel_margin_loss_forward", - "multi_margin_loss", - "multi_margin_loss_backward", "ormqr", "rrelu_with_noise", "_scaled_dot_product_efficient_attention", diff --git a/src/ATen/native/xpu/sycl/MultiMarginLossKernels.cpp b/src/ATen/native/xpu/sycl/MultiMarginLossKernels.cpp new file mode 100644 index 000000000..15c9b560c --- /dev/null +++ b/src/ATen/native/xpu/sycl/MultiMarginLossKernels.cpp @@ -0,0 +1,548 @@ +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native::xpu { + +using namespace at::xpu; + +void multi_margin_loss_shape_check( + int64_t& nframe, + int64_t& dim, + const int64_t& ndims, + const Tensor& input, + const Tensor& target, + const std::optional& weight) { + TORCH_CHECK( + (ndims == 2 && input.size(1) != 0) || + (ndims == 1 && input.size(0) != 0) || ndims == 0, + "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ", + input.sizes()); + + if (ndims <= 1) { + nframe = 1; + dim = ndims == 0 ? 1 : input.size(0); + } else { + nframe = input.size(0); + dim = input.size(1); + } + + TORCH_CHECK( + target.dim() <= 1 && target.numel() == nframe, + "inconsistent target size, expected ", + nframe, + " but got ", + target.sizes()); + if (weight && weight->defined()) { + TORCH_CHECK( + weight->dim() <= 1 && weight->numel() == dim, + "inconsistent weight size, expected ", + dim, + " but got ", + weight->sizes()); + } +} + +template +struct MultiMarginLossForwardKernelFunctor + : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<1> item) const { + int k = item.get_group(0); + const scalar_t* input_k = input_ + k * dim_; + scalar_t* output_k = output_ + k; + int target_k = static_cast(target_[k]); + SYCL_KERNEL_ASSERT( + target_k >= 0 && target_k < dim_ && "target index is out of bounds"); + scalar_t input_target_k = input_k[target_k]; + int i_start = item.get_local_linear_id(); + int i_end = dim_; + int i_step = item.get_local_range(0); + + smem_[item.get_local_linear_id()] = 0; + for (int i = i_start; i < i_end; i += i_step) { + scalar_t z = margin_ - input_target_k + input_k[i]; + if (i == target_k) { + continue; + } + + if (z > 0) { + scalar_t h = (P == 1) ? z : z * z; + if (weights_) { + h *= weights_[target_k]; + } + smem_[item.get_local_linear_id()] += h; + } + } + item.barrier(sycl_local_fence); + + // reduce + if (item.get_local_linear_id() == 0) { + accscalar_t sum = 0; + for (int i = 0; i < item.get_local_range(0); i++) + sum += smem_[i]; + + const int denom = sizeAverage_ ? nframe_ * dim_ : dim_; + *output_k = static_cast(sum / denom); + } + } + void sycl_ker_config_convention(sycl::handler& cgh) { + smem_ = sycl_local_acc_t(smem_size_, cgh); + } + MultiMarginLossForwardKernelFunctor( + scalar_t* output, + const scalar_t* input, + const int64_t* target, + const scalar_t* weights, + int nframe, + int dim, + bool sizeAverage, + scalar_t margin, + int64_t smem_size) + : output_(output), + input_(input), + target_(target), + weights_(weights), + nframe_(nframe), + dim_(dim), + sizeAverage_(sizeAverage), + margin_(margin), + smem_size_(smem_size) {} + + private: + scalar_t* output_; + const scalar_t* input_; + const int64_t* target_; + const scalar_t* weights_; + int nframe_; + int dim_; + bool sizeAverage_; + scalar_t margin_; + int64_t smem_size_; + sycl_local_acc_t smem_; +}; + +template +struct MultiMarginLossBackwardKernelFunctor + : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<1> item) const { + int k = item.get_group(0); + const scalar_t* input_k = input_ + k * dim_; + scalar_t* gradInput_k = gradInput_ + k * dim_; + int target_k = static_cast(target_[k]); + scalar_t input_target_k = input_k[target_k]; + + const scalar_t* gradOutput_k = gradOutput_; + if (!reduce_) { + gradOutput_k += k; + } + const int denom = sizeAverage_ && reduce_ ? nframe_ * dim_ : dim_; + const accscalar_t g = accscalar_t(1) / static_cast(denom); + int i_start = item.get_local_linear_id(); + int i_end = dim_; + int i_step = item.get_local_range(0); + + smem_[item.get_local_linear_id()] = 0; + for (int i = i_start; i < i_end; i += i_step) { + scalar_t z = margin_ - input_target_k + input_k[i]; + if (i == target_k) { + continue; + } + + if (z > 0) { + accscalar_t h = (P == 1) ? g : 2 * g * z; + if (weights_) { + h *= weights_[target_k]; + } + + smem_[item.get_local_linear_id()] -= static_cast(h); + gradInput_k[i] = static_cast(h); + } else { + gradInput_k[i] = static_cast(0); + } + } + item.barrier(sycl_local_fence); + + // reduce + if (item.get_local_linear_id() == 0) { + accscalar_t gradInput_target_k = 0; + + for (int i = 0; i < item.get_local_range(0); i++) { + gradInput_target_k += smem_[i]; + } + + gradInput_k[target_k] = static_cast(gradInput_target_k); + } + for (int i = i_start; i < i_end; i += i_step) { + gradInput_k[i] *= *gradOutput_k; + } + } + void sycl_ker_config_convention(sycl::handler& cgh) { + smem_ = sycl_local_acc_t(smem_size_, cgh); + } + MultiMarginLossBackwardKernelFunctor( + scalar_t* gradInput, + const scalar_t* gradOutput, + const scalar_t* input, + const int64_t* target, + const scalar_t* weights, + int nframe, + int dim, + bool sizeAverage, + scalar_t margin, + bool reduce, + int64_t smem_size) + : gradInput_(gradInput), + gradOutput_(gradOutput), + input_(input), + target_(target), + weights_(weights), + nframe_(nframe), + dim_(dim), + sizeAverage_(sizeAverage), + margin_(margin), + reduce_(reduce), + smem_size_(smem_size) {} + + private: + scalar_t* gradInput_; + const scalar_t* gradOutput_; + const scalar_t* input_; + const int64_t* target_; + const scalar_t* weights_; + int nframe_; + int dim_; + bool sizeAverage_; + scalar_t margin_; + bool reduce_; + int64_t smem_size_; + sycl_local_acc_t smem_; +}; + +Tensor& multi_margin_loss_kernel( + const Tensor& input_, + const Tensor& target_, + const Scalar& p_, + const Scalar& margin_, + const std::optional& weights_, + int64_t reduction, + Tensor& out_) { + auto p = p_.toLong(); + int64_t nframe, dim; + const auto ndims = input_.dim(); + TORCH_CHECK( + p == 1 || p == 2, + "multi_margin_loss: Invalid p, expected 1 or 2 but got ", + p); + + multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_, weights_); + + // produce a scalar output for 1d input + if (reduction == Reduction::None && target_.dim() > 0) { + resize_output(out_, {nframe}); + } else { + resize_output(out_, {}); + } + if (input_.numel() == 0) { + return out_; + } + + auto input = input_.contiguous(); + auto target = target_.contiguous(); + Tensor weights; + if (weights_ && weights_->defined()) { + weights = weights_->contiguous(); + } + auto out = + (out_.is_contiguous() ? out_ : at::empty(out_.sizes(), input.options())); + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, input.scalar_type(), "multi_margin_loss_xpu", [&] { + const scalar_t margin = margin_.to(); + using accscalar_t = acc_type_device; + if (input.dim() <= 1) { + TORCH_CHECK( + target.dim() <= 1 && target.numel() == nframe, + "inconsistent target size"); + + if (p == 1) { + using KernelClass = + MultiMarginLossForwardKernelFunctor<1, scalar_t, accscalar_t>; + int64_t local_size = syclMaxWorkGroupSize(); + auto kfn = KernelClass( + out.mutable_data_ptr(), + input.const_data_ptr(), + target.const_data_ptr(), + weights.defined() ? weights.const_data_ptr() + : nullptr, + 1, + input.dim() < 1 ? input.numel() : input.sizes()[0], + reduction == at::Reduction::Mean, + margin, + local_size); + sycl_kernel_submit( + local_size, local_size, getCurrentSYCLQueue(), kfn); + } else if (p == 2) { + using KernelClass = + MultiMarginLossForwardKernelFunctor<2, scalar_t, accscalar_t>; + int64_t local_size = syclMaxWorkGroupSize(); + auto kfn = KernelClass( + out.mutable_data_ptr(), + input.const_data_ptr(), + target.const_data_ptr(), + weights.defined() ? weights.const_data_ptr() + : nullptr, + 1, + input.dim() < 1 ? input.numel() : input.sizes()[0], + reduction == at::Reduction::Mean, + margin, + local_size); + sycl_kernel_submit( + local_size, local_size, getCurrentSYCLQueue(), kfn); + } + } else { + auto in_sizes = input.sizes(); + TORCH_INTERNAL_ASSERT(in_sizes.size() == 2); + // allow zero-dim target for 2D input. + TORCH_CHECK( + in_sizes[1] != 0 && target.dim() <= 1 && target.numel() == nframe, + "inconsistent target size"); + + if (reduction == at::Reduction::None) { + if (p == 1) { + using KernelClass = + MultiMarginLossForwardKernelFunctor<1, scalar_t, accscalar_t>; + int64_t local_size = syclMaxWorkGroupSize(); + auto kfn = KernelClass( + out.mutable_data_ptr(), + input.const_data_ptr(), + target.const_data_ptr(), + weights.defined() ? weights.const_data_ptr() + : nullptr, + nframe, + in_sizes[1], + false, + margin, + local_size); + sycl_kernel_submit( + nframe * local_size, local_size, getCurrentSYCLQueue(), kfn); + } else if (p == 2) { + using KernelClass = + MultiMarginLossForwardKernelFunctor<2, scalar_t, accscalar_t>; + int64_t local_size = syclMaxWorkGroupSize(); + auto kfn = KernelClass( + out.mutable_data_ptr(), + input.const_data_ptr(), + target.const_data_ptr(), + weights.defined() ? weights.const_data_ptr() + : nullptr, + nframe, + in_sizes[1], + false, + margin, + local_size); + sycl_kernel_submit( + nframe * local_size, local_size, getCurrentSYCLQueue(), kfn); + } + } else { + auto tmp_output = at::empty({nframe}, input.options()); + if (p == 1) { + using KernelClass = + MultiMarginLossForwardKernelFunctor<1, scalar_t, accscalar_t>; + int64_t local_size = syclMaxWorkGroupSize(); + auto kfn = KernelClass( + tmp_output.mutable_data_ptr(), + input.const_data_ptr(), + target.const_data_ptr(), + weights.defined() ? weights.const_data_ptr() + : nullptr, + nframe, + in_sizes[1], + reduction == Reduction::Mean, + margin, + local_size); + sycl_kernel_submit( + nframe * local_size, local_size, getCurrentSYCLQueue(), kfn); + + } else if (p == 2) { + using KernelClass = + MultiMarginLossForwardKernelFunctor<2, scalar_t, accscalar_t>; + int64_t local_size = syclMaxWorkGroupSize(); + auto kfn = KernelClass( + tmp_output.mutable_data_ptr(), + input.const_data_ptr(), + target.const_data_ptr(), + weights.defined() ? weights.const_data_ptr() + : nullptr, + nframe, + in_sizes[1], + reduction == Reduction::Mean, + margin, + local_size); + sycl_kernel_submit( + nframe * local_size, local_size, getCurrentSYCLQueue(), kfn); + } + at::sum_out(out, tmp_output, IntArrayRef{}); + } + } + }); + if (!out.is_alias_of(out_)) { + out_.copy_(out); + } + return out_; +} + +Tensor& multi_margin_loss_backward_kernel( + const Tensor& grad_output_, + const Tensor& input_, + const Tensor& target_, + const Scalar& p_, + const Scalar& margin_, + const std::optional& weights_, + int64_t reduction, + Tensor& grad_input_) { + auto p = p_.toLong(); + int64_t nframe, dim; + const auto ndims = input_.dim(); + + TORCH_CHECK( + p == 1 || p == 2, + "multi_margin_loss_backward: Invalid p, expected 1 or 2 but got ", + p); + + multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_, weights_); + resize_output(grad_input_, input_.sizes()); + + if (input_.numel() == 0) { + return grad_input_; + } + + auto input = input_.contiguous(); + auto grad_input = + (grad_input_.is_contiguous() + ? grad_input_ + : at::empty(grad_input_.sizes(), input.options())); + auto grad_output = grad_output_.contiguous(); + auto target = target_.contiguous(); + Tensor weights; + if (weights_ && weights_->defined()) { + weights = weights_->contiguous(); + } + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "multi_margin_loss_backward_xpu", + [&] { + const scalar_t margin = margin_.to(); + using accscalar_t = acc_type_device; + + if (input.dim() <= 1) { + if (p == 1) { + using KernelClass = + MultiMarginLossBackwardKernelFunctor<1, scalar_t, accscalar_t>; + int64_t local_size = syclMaxWorkGroupSize(); + auto kfn = KernelClass( + grad_input.mutable_data_ptr(), + grad_output.const_data_ptr(), + input.const_data_ptr(), + target.const_data_ptr(), + weights.defined() ? weights.const_data_ptr() + : nullptr, + 1, + input.dim() == 0 ? 1 : input.sizes()[0], + reduction == at::Reduction::Mean, + margin, + reduction != at::Reduction::None, + local_size); + sycl_kernel_submit( + local_size, local_size, getCurrentSYCLQueue(), kfn); + + } else if (p == 2) { + using KernelClass = + MultiMarginLossBackwardKernelFunctor<2, scalar_t, accscalar_t>; + int64_t local_size = syclMaxWorkGroupSize(); + auto kfn = KernelClass( + grad_input.mutable_data_ptr(), + grad_output.const_data_ptr(), + input.const_data_ptr(), + target.const_data_ptr(), + weights.defined() ? weights.const_data_ptr() + : nullptr, + 1, + input.dim() == 0 ? 1 : input.sizes()[0], + reduction == at::Reduction::Mean, + margin, + reduction != at::Reduction::None, + local_size); + sycl_kernel_submit( + local_size, local_size, getCurrentSYCLQueue(), kfn); + } + } else { + auto in_sizes = input.sizes(); + TORCH_INTERNAL_ASSERT(in_sizes.size() == 2); + TORCH_CHECK( + (in_sizes[1] != 0) && (target.dim() <= 1) && + (target.numel() == nframe), + "inconsistent target size"); + + if (p == 1) { + using KernelClass = + MultiMarginLossBackwardKernelFunctor<1, scalar_t, accscalar_t>; + int64_t local_size = syclMaxWorkGroupSize(); + auto kfn = KernelClass( + grad_input.mutable_data_ptr(), + grad_output.const_data_ptr(), + input.const_data_ptr(), + target.const_data_ptr(), + weights.defined() ? weights.const_data_ptr() + : nullptr, + nframe, + in_sizes[1], + reduction == at::Reduction::Mean, + margin, + reduction != at::Reduction::None, + local_size); + sycl_kernel_submit( + in_sizes[0] * local_size, + local_size, + getCurrentSYCLQueue(), + kfn); + + } else if (p == 2) { + using KernelClass = + MultiMarginLossBackwardKernelFunctor<2, scalar_t, accscalar_t>; + int64_t local_size = syclMaxWorkGroupSize(); + auto kfn = KernelClass( + grad_input.mutable_data_ptr(), + grad_output.const_data_ptr(), + input.const_data_ptr(), + target.const_data_ptr(), + weights.defined() ? weights.const_data_ptr() + : nullptr, + nframe, + in_sizes[1], + reduction == at::Reduction::Mean, + margin, + reduction != at::Reduction::None, + local_size); + sycl_kernel_submit( + in_sizes[0] * local_size, + local_size, + getCurrentSYCLQueue(), + kfn); + } + } + }); + + if (!grad_input.is_alias_of(grad_input_)) { + grad_input_.copy_(grad_input); + } + return grad_input_; +} + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/MultiMarginLossKernels.h b/src/ATen/native/xpu/sycl/MultiMarginLossKernels.h new file mode 100644 index 000000000..30f7f00d6 --- /dev/null +++ b/src/ATen/native/xpu/sycl/MultiMarginLossKernels.h @@ -0,0 +1,25 @@ +#pragma once +#include + +namespace at::native::xpu { + +TORCH_XPU_API Tensor& multi_margin_loss_kernel( + const Tensor& input, + const Tensor& target, + const Scalar& p, + const Scalar& margin, + const std::optional& weight, + int64_t reduction, + Tensor& out); + +TORCH_XPU_API Tensor& multi_margin_loss_backward_kernel( + const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + const Scalar& p, + const Scalar& margin, + const std::optional& weight, + int64_t reduction, + Tensor& grad_input); + +} // namespace at::native::xpu diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 9bcce04b3..a2c7522cd 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -202,6 +202,7 @@ "nn.functional.mse_loss", "nn.functional.binary_cross_entropy", "nn.functional.huber_loss", + "nn.functional.multi_margin_loss", "nn.functional.max_unpool2d", "nn.functional.max_unpool3d", "nn.functional.ctc_loss", diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 05097180a..c79bbf899 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -5834,6 +5834,26 @@ tags: dynamic_output_shape autogen: _unique2.out +- func: multi_margin_loss.out(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + dispatch: + XPU: multi_margin_loss_xpu_out + +- func: multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor + python_module: nn + dispatch: + XPU: multi_margin_loss_xpu + +- func: multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + dispatch: + XPU: multi_margin_loss_xpu_backward_out + +- func: multi_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean) -> Tensor + python_module: nn + dispatch: + XPU: multi_margin_loss_xpu_backward + - func: upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor python_module: nn autogen: upsample_linear1d.vec_out diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index a32a8aa75..e3e681578 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -625,6 +625,10 @@ supported: - acos_ - acos.out - acosh + - multi_margin_loss.out + - multi_margin_loss + - multi_margin_loss_backward.grad_input + - multi_margin_loss_backward - acosh_ - acosh.out - addr