Add aten::binary_cross_entropy/backward and their variants (#561)
Task list:
- [x] binary_cross_entropy
- [x] binary_cross_entropy.out
- [x] binary_cross_entropy_backward
- [x] binary_cross_entropy_backward.grad_input

Co-authored-by: Feng Yuan <[email protected]>
1 parent 0f14bdf · commit 5f41843
Showing 6 changed files with 196 additions and 2 deletions.
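
For context, a minimal usage sketch (not part of this commit) of how the new kernels are intended to be exercised through the public ATen ops; it assumes an XPU-enabled PyTorch build with at least one XPU device available. The .out and .grad_input variants from the task list take the result tensor as an explicit argument.

#include <ATen/ATen.h>

int main() {
  // Assumed setup: XPU-enabled build, at least one XPU device available.
  at::Tensor input = at::rand({4, 3}, at::device(at::kXPU));
  at::Tensor target = at::rand({4, 3}, at::device(at::kXPU));

  // Forward: for XPU tensors this should reach binary_cross_entropy_kernel.
  at::Tensor loss = at::binary_cross_entropy(
      input, target, /*weight=*/{}, at::Reduction::Mean);

  // Backward, exercised via the explicit backward op, which should reach
  // binary_cross_entropy_backward_kernel. With Mean reduction the incoming
  // gradient is a scalar.
  at::Tensor grad_input = at::binary_cross_entropy_backward(
      at::ones_like(loss), input, target, /*weight=*/{}, at::Reduction::Mean);
  return 0;
}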
@@ -0,0 +1,115 @@
#include <ATen/ATen.h>
#include <ATen/core/Reduction.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/Loops.h>
#include <comm/SYCLContext.h>

namespace at::native::xpu {

template <typename scalar_t>
struct BinaryCrossEntropyFunctor {
  scalar_t operator()(scalar_t input_val, scalar_t target_val) const {
    const scalar_t zero = 0;
    const scalar_t one = 1;
    const scalar_t neg_100 = -100;

    SYCL_KERNEL_ASSERT(input_val >= zero && input_val <= one);
    SYCL_KERNEL_ASSERT(target_val >= zero && target_val <= one);

    scalar_t log_input_val = std::log(input_val);
    scalar_t log_1_minus_input_val = std::log1p(-input_val);

    log_input_val = std::max(log_input_val, neg_100);
    log_1_minus_input_val = std::max(log_1_minus_input_val, neg_100);

    return ((target_val - one) * log_1_minus_input_val) -
        (target_val * log_input_val);
  }
};
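
// Reference note: the functor above evaluates the standard per-element binary
// cross entropy,
//   loss(x, y) = -(y * log(x) + (1 - y) * log(1 - x))
//              = (y - 1) * log1p(-x) - y * log(x),
// with both logarithms clamped from below at -100 so the loss stays finite as
// the input approaches 0 or 1.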

Tensor& binary_cross_entropy_kernel(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    Tensor& loss) {
  Tensor loss_squeezed = at::squeeze(loss);

  TensorIterator iter = TensorIteratorConfig()
                            .add_output(loss_squeezed)
                            .add_owned_input(at::squeeze(input))
                            .add_owned_input(at::squeeze(target))
                            .build();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(),
      "binary_cross_entropy_xpu",
      [&]() { gpu_kernel(iter, BinaryCrossEntropyFunctor<scalar_t>()); });
  if (weight.defined()) {
    loss.mul_(weight);
  }

  if (reduction != at::Reduction::None) {
    Tensor loss_reduced;
    if (reduction == at::Reduction::Mean) {
      loss_reduced = loss.mean();
    } else if (reduction == at::Reduction::Sum) {
      loss_reduced = loss.sum();
    }
    loss.resize_as_(loss_reduced).copy_(loss_reduced);
  }

  return loss;
}

template <typename scalar_t>
struct BinaryCrossEntropyBackwardFunctor {
  scalar_t operator()(
      scalar_t grad_val,
      scalar_t input_val,
      scalar_t target_val) const {
    constexpr float EPSILON = 1e-12;
    const scalar_t one = 1;
    const scalar_t epsilon = EPSILON;

    scalar_t grad_input_denominator =
        std::max((one - input_val) * input_val, epsilon);

    return grad_val * (input_val - target_val) / grad_input_denominator;
  }
};
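
// Reference note: the functor above is the derivative of the per-element loss
// with respect to the input, scaled by the incoming gradient:
//   d(loss)/d(x) = (x - y) / max(x * (1 - x), 1e-12),
// where the clamp avoids division by zero when the input reaches 0 or 1.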

Tensor& binary_cross_entropy_backward_kernel(
    const Tensor& grad,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    Tensor& grad_input) {
  Tensor grad_expand = grad.expand_as(input);
  at::TensorIterator iter = TensorIteratorConfig()
                                .add_output(grad_input)
                                .add_input(grad_expand)
                                .add_input(input)
                                .add_input(target)
                                .build();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(),
      "binary_cross_entropy_backward_xpu",
      [&]() {
        gpu_kernel(iter, BinaryCrossEntropyBackwardFunctor<scalar_t>());
      });

  if (weight.defined()) {
    grad_input.mul_(weight);
  }
  if (reduction == at::Reduction::Mean) {
    grad_input.div_(input.numel());
  }
  return grad_input;
}
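
// Reference note: when a weight tensor is defined, each element's gradient is
// additionally multiplied by its weight; with Mean reduction each element
// contributes 1/N to the reduced loss, hence the final division by
// input.numel().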

} // namespace at::native::xpu
@@ -0,0 +1,21 @@
#pragma once
#include <ATen/native/TensorIterator.h>

namespace at::native::xpu {

Tensor& binary_cross_entropy_kernel(
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    Tensor& loss);

Tensor& binary_cross_entropy_backward_kernel(
    const Tensor& grad,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    Tensor& grad_input);

} // namespace at::native::xpu