Skip to content

Commit

Permalink
Add aten::multi_margin_loss and its variants (#895)
Browse files Browse the repository at this point in the history
- [x] multi_margin_loss
- [x] multi_margin_loss.out
- [x] multi_margin_loss_backward
- [x] multi_margin_loss_backward.grad_input

---------

Co-authored-by: Yutao Xu <[email protected]>
  • Loading branch information
chunhuanMeng and xytintel authored Oct 28, 2024
1 parent bdbda35 commit b189259
Show file tree
Hide file tree
Showing 7 changed files with 662 additions and 2 deletions.
64 changes: 64 additions & 0 deletions src/ATen/native/xpu/LossMultiMargin.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/xpu/sycl/MultiMarginLossKernels.h>

#include <ATen/ops/empty.h>
#include <xpu/ATen/ops/multi_margin_loss_backward_native.h>
#include <xpu/ATen/ops/multi_margin_loss_native.h>

namespace at::native {

/// Out-variant of the multi-margin-loss forward pass on XPU.
///
/// The SYCL kernel writes the computed loss directly into `out`; this
/// wrapper only forwards its arguments and returns the same tensor so the
/// call can be chained, matching the ATen `*.out` overload convention.
Tensor& multi_margin_loss_xpu_out(
    const Tensor& self,
    const Tensor& target,
    const Scalar& p,
    const Scalar& margin,
    const std::optional<Tensor>& weight,
    int64_t reduction,
    Tensor& out) {
  xpu::multi_margin_loss_kernel(self, target, p, margin, weight, reduction, out);
  return out;
}

/// Functional (allocating) variant of the multi-margin-loss forward pass
/// on XPU.
///
/// Allocates a result tensor with the same options as `self` and delegates
/// to the SYCL kernel, which fills it in place.
Tensor multi_margin_loss_xpu(
    const Tensor& self,
    const Tensor& target,
    const Scalar& p,
    const Scalar& margin,
    const std::optional<Tensor>& weight,
    int64_t reduction) {
  // NOTE(review): an empty {0}-shaped placeholder is passed in; the kernel
  // presumably resizes it to the proper output shape — confirm against the
  // kernel implementation.
  Tensor result = at::empty({0}, self.options());
  xpu::multi_margin_loss_kernel(
      self, target, p, margin, weight, reduction, result);
  return result;
}

/// Out-variant of the multi-margin-loss backward pass on XPU.
///
/// The SYCL backward kernel writes the input gradient directly into
/// `grad_input`; this wrapper forwards its arguments and returns that same
/// tensor, following the ATen `*.grad_input` overload convention.
Tensor& multi_margin_loss_xpu_backward_out(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    const Scalar& p,
    const Scalar& margin,
    const std::optional<Tensor>& weight,
    int64_t reduction,
    Tensor& grad_input) {
  xpu::multi_margin_loss_backward_kernel(
      grad_output, self, target, p, margin, weight, reduction, grad_input);
  return grad_input;
}

/// Functional (allocating) variant of the multi-margin-loss backward pass
/// on XPU.
///
/// Allocates the gradient tensor with the same options as `self` and
/// delegates to the SYCL backward kernel, which fills it in place.
Tensor multi_margin_loss_xpu_backward(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    const Scalar& p,
    const Scalar& margin,
    const std::optional<Tensor>& weight,
    int64_t reduction) {
  // NOTE(review): same placeholder pattern as the forward variant — an empty
  // {0}-shaped tensor that the kernel is expected to resize; confirm against
  // the kernel implementation.
  Tensor grad_in = at::empty({0}, self.options());
  xpu::multi_margin_loss_backward_kernel(
      grad_output, self, target, p, margin, weight, reduction, grad_in);
  return grad_in;
}

} // namespace at::native
2 changes: 0 additions & 2 deletions src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"lu_unpack.out",
"multilabel_margin_loss_backward",
"multilabel_margin_loss_forward",
"multi_margin_loss",
"multi_margin_loss_backward",
"ormqr",
"rrelu_with_noise",
"_scaled_dot_product_efficient_attention",
Expand Down
Loading

0 comments on commit b189259

Please sign in to comment.