diff --git a/src/ATen/native/xpu/Loss.cpp b/src/ATen/native/xpu/Loss.cpp
index b5a8e465c..867304201 100644
--- a/src/ATen/native/xpu/Loss.cpp
+++ b/src/ATen/native/xpu/Loss.cpp
@@ -1,14 +1,24 @@
 #include
 #include
 #include
-#include
-
 #include
 #include
+#include
 #include
 
 namespace at {
 
+static inline at::Tensor apply_loss_reduction(
+    const at::Tensor& unreduced,
+    int64_t reduction) {
+  if (reduction == at::Reduction::Mean) {
+    return unreduced.mean();
+  } else if (reduction == at::Reduction::Sum) {
+    return unreduced.sum();
+  }
+  return unreduced;
+}
+
 Tensor& XPUNativeFunctions::mse_loss_out(
     const Tensor& input,
     const Tensor& target,
@@ -69,4 +79,53 @@ Tensor& XPUNativeFunctions::mse_loss_backward_out(
   return grad_input;
 }
 
+Tensor XPUNativeFunctions::huber_loss(
+    const Tensor& input,
+    const Tensor& target,
+    int64_t reduction,
+    double delta) {
+  TORCH_CHECK(
+      delta > 0, "huber_loss does not support non-positive values for delta.")
+  Tensor loss = at::empty_like(input);
+  auto iter = TensorIterator::borrowing_binary_op(loss, input, target);
+  native::xpu::huber_kernel(iter, delta);
+  return apply_loss_reduction(loss, reduction);
+}
+
+Tensor& XPUNativeFunctions::huber_loss_out(
+    const Tensor& input,
+    const Tensor& target,
+    int64_t reduction,
+    double delta,
+    Tensor& result) {
+  TORCH_CHECK(
+      delta > 0, "huber_loss does not support non-positive values for delta.")
+  auto iter = TensorIterator::borrowing_binary_op(result, input, target);
+  native::xpu::huber_kernel(iter, delta);
+  if (reduction != Reduction::None) {
+    auto reduced = apply_loss_reduction(result, reduction);
+    result.resize_({});
+    result.copy_(reduced);
+  }
+  return result;
+}
+
+Tensor& XPUNativeFunctions::huber_loss_backward_out(
+    const Tensor& grad_output,
+    const Tensor& input,
+    const Tensor& target,
+    int64_t reduction,
+    double delta,
+    Tensor& grad_input) {
+  auto norm = (reduction == Reduction::Mean) ? (1. / input.numel()) : 1.;
+  auto iter = at::TensorIteratorConfig()
+                  .add_output(grad_input)
+                  .add_const_input(input)
+                  .add_const_input(target)
+                  .add_const_input(grad_output)
+                  .build();
+  native::xpu::huber_backward_kernel(iter, norm, delta);
+  return grad_input;
+}
+
 } // namespace at
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index f1b861881..c3fc196e9 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -210,8 +210,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "hardshrink.out",
       "heaviside.out",
       "histc",
-      "huber_loss",
-      "huber_loss_backward.out",
       "i0.out",
       "igammac.out",
       "igamma.out",
diff --git a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp
index 5081dfc13..00c5398af 100644
--- a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp
+++ b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp
@@ -23,4 +23,25 @@ void mse_kernel(TensorIteratorBase& iter) {
       [&]() { gpu_kernel(iter, MSEFunctor<scalar_t>()); });
 }
 
+template <typename scalar_t>
+struct HuberFunctor {
+  scalar_t operator()(scalar_t a, scalar_t b) const {
+    auto z = std::abs(a - b);
+    return z < delta_val_ ? scalar_t(0.5) * z * z
+                          : delta_val_ * (z - scalar_t(0.5) * delta_val_);
+  }
+  HuberFunctor(scalar_t delta_val) : delta_val_(delta_val) {}
+
+ private:
+  scalar_t delta_val_;
+};
+
+void huber_kernel(TensorIterator& iter, double delta) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16, kHalf, iter.dtype(), "huber_xpu", [&iter, delta] {
+        scalar_t delta_val(delta);
+        gpu_kernel(iter, HuberFunctor<scalar_t>(delta_val));
+      });
+}
+
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h
index abdcc146c..94cfb7c90 100644
--- a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h
+++ b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h
@@ -6,4 +6,6 @@ namespace at::native::xpu {
 
 void mse_kernel(TensorIteratorBase& iter);
 
+void huber_kernel(TensorIterator& iter, double delta);
+
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp
index d38f511d7..9be4a3ef4 100644
--- a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp
+++ b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp
@@ -125,4 +125,41 @@ void mse_backward_kernel(TensorIterator& iter, const Scalar& value) {
       });
 }
 
+template <typename scalar_t>
+struct HuberBackwardFunctor {
+  scalar_t operator()(scalar_t input, scalar_t target, scalar_t grad_output)
+      const {
+    const auto x = input - target;
+    if (x < -delta_val_) {
+      return -norm_val_ * grad_output * delta_val_;
+    } else if (x > delta_val_) {
+      return norm_val_ * grad_output * delta_val_;
+    } else {
+      return norm_val_ * x * grad_output;
+    }
+  }
+  HuberBackwardFunctor(scalar_t norm_val, scalar_t delta_val)
+      : norm_val_(norm_val), delta_val_(delta_val) {}
+
+ private:
+  scalar_t norm_val_;
+  scalar_t delta_val_;
+};
+
+void huber_backward_kernel(
+    TensorIterator& iter,
+    const Scalar& norm,
+    double delta) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16,
+      kHalf,
+      iter.dtype(),
+      "huber_backward_xpu",
+      [&iter, &norm, delta] {
+        auto norm_val = norm.to<scalar_t>();
+        scalar_t delta_val(delta);
+        gpu_kernel(iter, HuberBackwardFunctor<scalar_t>(norm_val, delta_val));
+      });
+}
+
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h
index c775b88e5..586a64f3c 100644
--- a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h
+++ b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h
@@ -10,4 +10,9 @@ void addcdiv_kernel(TensorIterator& iter, Scalar value);
 
 void mse_backward_kernel(TensorIterator& iter, const Scalar& value);
 
+void huber_backward_kernel(
+    TensorIterator& iter,
+    const Scalar& norm,
+    double delta);
+
 } // namespace at::native::xpu
diff --git a/test/xpu/extended/run_test_with_skip.py b/test/xpu/extended/run_test_with_skip.py
index 108b3073d..dada6cbfd 100644
--- a/test/xpu/extended/run_test_with_skip.py
+++ b/test/xpu/extended/run_test_with_skip.py
@@ -125,6 +125,7 @@
     "test_compare_cpu_nn_functional_batch_norm_xpu_bfloat16",
     "test_compare_cpu__batch_norm_with_update_xpu_bfloat16",
     "test_compare_cpu__batch_norm_with_update_xpu_float16",
+    "test_compare_cpu_nn_functional_huber_loss_xpu_bfloat16",
 
     # Not implemented operators, aten::upsample_linear1d, aten::upsample_bilinear2d,
     # aten::upsample_trilinear3d
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index 16031eda2..f2661bf3a 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -145,6 +145,7 @@
     "nn.functional.upsample_nearest",
     # "nn.functional.nll_loss", # Lack of XPU implementation of aten::nll_loss2d_forward. Will retrieve the case, only if the op is implemented.
     "nn.functional.mse_loss",
+    "nn.functional.huber_loss",
     "sigmoid",
     "sgn",
     "nn.functional.embedding_bag",
diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index f17a99051..418215f2e 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -408,6 +408,9 @@ supported:
   - nll_loss_forward
   - nll_loss_backward.grad_input
  - nll_loss_backward
+  - huber_loss
+  - huber_loss.out
+  - huber_loss_backward.out
  - batch_norm_stats
  - batch_norm_elemt
  - batch_norm_elemt.out
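For reference, a minimal sketch of how the new XPU operators could be exercised from Python once these kernels are registered. This snippet is illustrative and not part of the patch: it assumes a torch-xpu-ops build with an "xpu" device available, and the shapes and delta value are arbitrary.

```python
# Illustrative only: assumes a torch-xpu-ops build exposing an "xpu" device.
import torch
import torch.nn.functional as F

inp = torch.randn(8, 4, dtype=torch.float32, requires_grad=True)
target = torch.randn(8, 4)

# CPU reference for the forward value and the input gradient.
ref = F.huber_loss(inp, target, reduction="mean", delta=1.5)
ref.backward()
ref_grad = inp.grad.clone()

# XPU path, hitting the new huber_kernel / huber_backward_kernel.
inp_xpu = inp.detach().to("xpu").requires_grad_(True)
out = F.huber_loss(inp_xpu, target.to("xpu"), reduction="mean", delta=1.5)
out.backward()

torch.testing.assert_close(out.detach().cpu(), ref.detach())
torch.testing.assert_close(inp_xpu.grad.cpu(), ref_grad)
```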