Add aten::binary_cross_entropy/backward and their variants (#561)
Task list:
- [x] binary_cross_entropy
- [x] binary_cross_entropy.out
- [x] binary_cross_entropy_backward
- [x] binary_cross_entropy_backward.grad_input

---------

Co-authored-by: Feng Yuan <[email protected]>
xytintel and fengyuan14 authored Jul 22, 2024
1 parent 0f14bdf commit 5f41843
Showing 6 changed files with 196 additions and 2 deletions.
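
For a quick end-to-end check, the four new entries can be exercised through the public ATen API. The sketch below is illustrative only and not part of the commit; the tensor shapes and the availability of an XPU device are assumptions:

#include <ATen/ATen.h>
#include <ATen/core/Reduction.h>

void bce_smoke_test() {
  auto opts = at::TensorOptions().device(at::kXPU).dtype(at::kFloat);
  auto input = at::rand({4, 3}, opts);   // probabilities in [0, 1]
  auto target = at::rand({4, 3}, opts);  // targets in [0, 1]

  // Routes to binary_cross_entropy / binary_cross_entropy.out.
  auto loss = at::binary_cross_entropy(
      input, target, /*weight=*/{}, at::Reduction::Mean);
  auto loss_buf = at::empty_like(input);
  at::binary_cross_entropy_out(
      loss_buf, input, target, /*weight=*/{}, at::Reduction::None);

  // Routes to binary_cross_entropy_backward(.grad_input).
  auto grad_input = at::binary_cross_entropy_backward(
      at::ones_like(loss), input, target, /*weight=*/{}, at::Reduction::Mean);
}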
55 changes: 55 additions & 0 deletions src/ATen/native/xpu/Loss.cpp
@@ -2,6 +2,7 @@
#include <ATen/core/Reduction.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/xpu/sycl/BinaryMiscOpsKernels.h>
#include <ATen/native/xpu/sycl/LossKernels.h>
#include <ATen/native/xpu/sycl/PointwiseOpsKernels.h>
#include <ATen/xpu/XPUNativeFunctions.h>
#include <comm/RegisterUtils.h>
@@ -79,6 +80,60 @@ Tensor& XPUNativeFunctions::mse_loss_backward_out(
  return grad_input;
}

Tensor XPUNativeFunctions::binary_cross_entropy(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction) {
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  Tensor loss = at::empty_like(self);
  return native::xpu::binary_cross_entropy_kernel(
      self, target, weight, reduction, loss);
}

Tensor& XPUNativeFunctions::binary_cross_entropy_out(
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    Tensor& loss) {
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  return native::xpu::binary_cross_entropy_kernel(
      self, target, weight, reduction, loss);
}

Tensor XPUNativeFunctions::binary_cross_entropy_backward(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction) {
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  Tensor grad_input = at::empty_like(self);
  return native::xpu::binary_cross_entropy_backward_kernel(
      grad_output, self, target, weight, reduction, grad_input);
}

Tensor& XPUNativeFunctions::binary_cross_entropy_backward_out(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight_opt,
    int64_t reduction,
    Tensor& grad_input) {
  c10::MaybeOwned<Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  return native::xpu::binary_cross_entropy_backward_kernel(
      grad_output, self, target, weight, reduction, grad_input);
}

Tensor XPUNativeFunctions::huber_loss(
    const Tensor& input,
    const Tensor& target,
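Note on the wrappers above: the functional variants allocate their result with at::empty_like(self) and delegate to the shared kernels, so the .out variants reuse the same code path with a caller-provided buffer. The forward kernel applies any requested reduction in place (resizing the buffer to hold the scalar result), while the backward kernel keeps the gradient elementwise and, for Mean reduction, scales it by 1/numel.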
2 changes: 0 additions & 2 deletions src/ATen/native/xpu/XPUFallback.template
@@ -160,8 +160,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"angle",
"avg_pool3d_backward.grad_input",
"avg_pool3d.out",
"binary_cross_entropy",
"binary_cross_entropy_backward",
"bitwise_left_shift.Tensor_out",
"bitwise_right_shift.Tensor_out",
"cauchy_",
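With these two entries removed from the fallback list, binary_cross_entropy and its backward on XPU now dispatch to the native SYCL kernels below instead of taking the generic fallback path.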
115 changes: 115 additions & 0 deletions src/ATen/native/xpu/sycl/LossKernels.cpp
@@ -0,0 +1,115 @@
#include <ATen/ATen.h>
#include <ATen/core/Reduction.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/Loops.h>
#include <comm/SYCLContext.h>

namespace at::native::xpu {

// Elementwise forward: loss = -(t * log(x) + (1 - t) * log(1 - x)),
// with both log terms clamped at -100 so inputs of exactly 0 or 1
// produce a large finite loss instead of inf, matching the CPU/CUDA
// reference kernels.
template <typename scalar_t>
struct BinaryCrossEntropyFunctor {
  scalar_t operator()(scalar_t input_val, scalar_t target_val) const {
    const scalar_t zero = 0;
    const scalar_t one = 1;
    const scalar_t neg_100 = -100;

    SYCL_KERNEL_ASSERT(input_val >= zero && input_val <= one);
    SYCL_KERNEL_ASSERT(target_val >= zero && target_val <= one);

    scalar_t log_input_val = std::log(input_val);
    scalar_t log_1_minus_input_val = std::log1p(-input_val);

    log_input_val = std::max(log_input_val, neg_100);
    log_1_minus_input_val = std::max(log_1_minus_input_val, neg_100);

    return ((target_val - one) * log_1_minus_input_val) -
        (target_val * log_input_val);
  }
};

Tensor& binary_cross_entropy_kernel(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    Tensor& loss) {
  Tensor loss_squeezed = at::squeeze(loss);

  TensorIterator iter = TensorIteratorConfig()
                            .add_output(loss_squeezed)
                            .add_owned_input(at::squeeze(input))
                            .add_owned_input(at::squeeze(target))
                            .build();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(),
      "binary_cross_entropy_xpu",
      [&]() { gpu_kernel(iter, BinaryCrossEntropyFunctor<scalar_t>()); });
  // Apply the optional per-element weight after the elementwise loss.
  if (weight.defined()) {
    loss.mul_(weight);
  }

  // Reduce in place so the caller-provided buffer holds the final result.
  if (reduction != at::Reduction::None) {
    Tensor loss_reduced;
    if (reduction == at::Reduction::Mean) {
      loss_reduced = loss.mean();
    } else if (reduction == at::Reduction::Sum) {
      loss_reduced = loss.sum();
    }
    loss.resize_as_(loss_reduced).copy_(loss_reduced);
  }

  return loss;
}

// Elementwise backward: d(loss)/dx = (x - t) / max((1 - x) * x, eps),
// with eps = 1e-12 guarding against division by zero at x == 0 or x == 1.
template <typename scalar_t>
struct BinaryCrossEntropyBackwardFunctor {
  scalar_t operator()(
      scalar_t grad_val,
      scalar_t input_val,
      scalar_t target_val) const {
    constexpr float EPSILON = 1e-12;
    const scalar_t one = 1;
    const scalar_t epsilon = EPSILON;

    scalar_t grad_input_denominator =
        std::max((one - input_val) * input_val, epsilon);

    return grad_val * (input_val - target_val) / grad_input_denominator;
  }
};

Tensor& binary_cross_entropy_backward_kernel(
    const Tensor& grad,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    Tensor& grad_input) {
  // After a reduced forward loss, grad is 0-dim; broadcast it to input's shape.
  Tensor grad_expand = grad.expand_as(input);
  at::TensorIterator iter = TensorIteratorConfig()
                                .add_output(grad_input)
                                .add_input(grad_expand)
                                .add_input(input)
                                .add_input(target)
                                .build();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(),
      "binary_cross_entropy_backward_xpu",
      [&]() {
        gpu_kernel(iter, BinaryCrossEntropyBackwardFunctor<scalar_t>());
      });

  if (weight.defined()) {
    grad_input.mul_(weight);
  }
  // Mean reduction scales the forward loss by 1/N, so the gradient does too.
  if (reduction == at::Reduction::Mean) {
    grad_input.div_(input.numel());
  }
  return grad_input;
}

} // namespace at::native::xpu
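
For reference, the two functors above implement the standard binary cross-entropy and its derivative (per upstream gradient element), with the clamps shown in the code:

\ell(x, t) = -\bigl[\, t \cdot \max(\log x, -100) + (1 - t) \cdot \max(\log(1 - x), -100) \,\bigr]

\frac{\partial \ell}{\partial x} = \frac{x - t}{\max\bigl((1 - x)\,x,\ 10^{-12}\bigr)}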
21 changes: 21 additions & 0 deletions src/ATen/native/xpu/sycl/LossKernels.h
@@ -0,0 +1,21 @@
#pragma once
#include <ATen/native/TensorIterator.h>

namespace at::native::xpu {

Tensor& binary_cross_entropy_kernel(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    Tensor& loss);

Tensor& binary_cross_entropy_backward_kernel(
    const Tensor& grad,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    Tensor& grad_input);

} // namespace at::native::xpu
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
@@ -176,6 +176,7 @@
"nn.functional.upsample_nearest",
"nn.functional.nll_loss",
"nn.functional.mse_loss",
"nn.functional.binary_cross_entropy",
"nn.functional.huber_loss",
"sigmoid",
"logsigmoid",
4 changes: 4 additions & 0 deletions yaml/xpu_functions.yaml
@@ -477,6 +477,10 @@ supported:
- mse_loss.out
- mse_loss_backward
- mse_loss_backward.grad_input
- binary_cross_entropy
- binary_cross_entropy.out
- binary_cross_entropy_backward
- binary_cross_entropy_backward.grad_input
- nll_loss_forward.output
- nll_loss_forward
- nll_loss_backward.grad_input
